Reference
Hard example mining이라 함은, 모델이 잘 맞춰내지 못하는 데이터들을 집중적으로 모델에 학습시킴으로써 보다 효율적인 학습을 꾀하는 방법론이다. 즉 모델이 이미 잘 추론해내는 데이터를 계속해서 학습에 사용함으로써 발생하는 비효율을 최소화하기 위해, 모델이 어려워하는 유의미한 데이터를 잘 골라 모델에 먹여주는 것이다.
이러한 방법론은 그다지 새로운 것은 아닌데, 기존의 bootstrapping이 이러한 hard example mining의 대표적인 예시이다. Boostrapping은 대표적으로 SVM의 학습에 종종 사용되던 방법이었는데, 간단하게 설명하자면 1) 전체 데이터셋의 subset에 모델이 수렴할 때까지 학습시키고, 2) 학습된 모델을 통해 추론한 결과의 loss가 높은 데이터를 위주로 subset을 재편성하고, 3) 새로운 subset에 다시 모델이 수렴할 때까지 학습시키는 것을 반복한다. (참고로 통계학에서 사용하는 boostrapping과는 별개의 개념이다. 통계학의 boostrapping은 데이터가 부족한 상황에서 어떠한 지표의 불확실성(분산)을 평가하기 위해, 주어진 데이터셋에서 작은 크기의 subset들을 반복적으로 sampling하는 방법이다.)
하지만 이러한 boostrapping은 딥러닝 프레임워크에 직접적으로 적용하기에는 적절하지 않은데, 작은 batch size에 대한 step-wise learning을 반복하는 딥러닝 모델에서 매 step마다 데이터셋을 재편성하는 것은 더욱 큰 비효율을 야기할 것이기 때문이다. 이를 해결하기 위해 딥러닝 프레임워크 - 특히 object detection task를 위한 - 에 맞추어 hard example mining을 개조한 알고리즘이 바로 OHEM(Online Hard Example Mining)이다.
착상은 다음과 같다: RCNN 계열의 object detection model은 학습에 한두 개의 이미지밖에 사용하지 않지만, 하나의 이미지에 수천 개의 RoI(Region of Interest)들이 존재한다. 즉 OHEM은 각각의 step에서 학습 이미지로부터 도출된 RoI 중 hard example들을 뽑아내어, 이들을 대상으로 모델을 학습한다. 알고리즘을 정리하면 다음과 같다.
•
이미지 내의 모든 RoI들에 대해 forward propagation을 수행한다.
•
각 이미지별로 loss가 가장 높은 $B/N$개의 RoI들을 뽑는다.
•
단, 겹치는 RoI가 학습에 포함되는 것을 방지하기 위해 NMS를 수행한다.
•
선정된 총 $B$개의 RoI들을 대상으로만 backward propagation을 수행한다.
Boostrapping에서는 data batch에 일정한 hard/easy example의 비율을 유지하려 하는데, OHEM에서는 hard example들만 뽑는다는 것이 인상적이다. 페이퍼에서는 "어차피 easy example 데이터를 잘 못 맞추게 되면, 다시 이러한 example이 학습에 포함될 것이기 때문에" 문제가 없을 것이라고 주장한다.
이러한 알고리즘을 구현하는 방법에는 대표적으로 두 가지가 있을 수 있는데, 첫 번째는 easy RoI들의 loss를 0으로 설정하여 back propagation을 수행하는 것이다. 쉽지만, 모든 RoI를 backward propagation에 불필요하게 포함함으로써 memory inefficiency가 발생한다. 그래서 페이퍼에서 제시하는 두 번째 방법은 1) gradient를 계산하지 않는(read-only) 모델을 포함한 두 개의 모델을 복제한 뒤, 2) 전체 RoI에 대해 read-only 모델로 forward propagation을 수행하여 hard example을 선정하고, 3) hard RoI에 대해서 두 번째 모델로 forward/backward propagation을 수행하는 것이다. 첫 번째 옵션보다 두 배 이상 빠르다고 한다.
또한 페이퍼에서는 전체 RoI를 대상으로 학습한 것보다, 혹은 매우 큰 를 대상으로 학습한 것보다 적절한 를 설정하여 OHEM으로 학습하는 것이 유의미하게 더 높은 학습 성과를 보여줌을 empirical하게 증명하였다. 어려운 데이터만을 골라 모델에 먹여주는 것이 효과가 있었던 것이다. 하지만 RCNN 계열의 object detector를 고려하여 고안된 알고리즘이기 때문에 다른 task에 직접적으로 generalize하기는 어려워보인다.
E.O.D.