🛑

CenterNet: Objects as Points

Tags
Object Detection
Created
2021/01/29 07:01
Publication
Rate
3
Summary
굉장히 빠르면서도 준수한 성능을 보이는 one-shot detector를 제안한다. (상세 페이지 참고)

Reference

Introduction

2-stage object detection의 경우에는 물체의 위치를 먼저 추론한 다음, 추론된 위치들에 대한 개별적인 classification을 수행한다. 이 경우 학습은 물론이거니와 추론에 꽤 오랜 시간이 걸리는 심각한 문제점이 발생한다. 한편 YOLO와 같은 이전의 1-sage object detection의 경우에는 단 한 번의 forward propagation으로 추론이 가능하지만, 다수의 anchor box를 사용함으로써 1) anchor box를 직접 선택해줘야 하고, 2) 여러 anchor box에 의해 동일한 물체가 인식됨으로써 NMS(Non-Max Suppression)라는 후처리를 따로 가해줘야 한다는 문제가 있다.
이러한 문제들을 해결하기 위해 CenterNet은 물체를 그 중심점으로써 인식하고, box size나 다른 feature들(예를 들어 pose estimation을 위한 joint들)을 이미지로부터 직접적으로 regress하는 방식을 택한다. CenterNet은 중심점을 찾아내기 위해 중심점에 대한 heatmap을 생성하고, 그렇게 생성된 heatmap의 peak point(예컨대 주변 9 그리드 중 가장 값이 높은 그리드)를 중심점으로 선택(그리고 그 중심점에 대해 다른 feature들을 regress)하는데, 이로써 후처리가 필요하지 않은 1-stage detection이 가능하게 된다.
기본적으로는 YOLO 등의 1-stage detection과 접근법이 유사한 편인데, 논문에서 제시하는 핵심적인 차별점은 다음과 같다.
First, our CenterNet assigns the "anchor" based solely on location, not box overlap. We have no manual thresholds for foreground and background classification.
Second, we only have one positive "anchor" per object, and hence do not need Non-Maximum Suppression. We simply extract local peaks in the keypoint heatmap.
Third, CenterNet uses a larger output resolution (output stride of 4) compared to traditional object detectors. This eliminates the need for multiple anchors.
한편 CornerNet(박스의 좌상, 우하점으로 물체를 인식)이나 ExtremeNet(상하좌우 극단점과 중심점으로 물체를 인식) 등 점으로써 물체를 인식하려는 알고리즘은 이전에도 있었지만, CenterNet은 이들보다 간단하면서(?)도 효과적인 detection을 자랑한다.

Preliminary

CenterNet에서는 이미지 IRW×H×3I \in R^{W \times H \times 3}의 mapping Y^[0,1]WR×HR×C\hat{Y} \in [0, 1]^{\frac{W}{R} \times \frac{H}{R} \times C}를 학습하는 것을 목적으로 한다. 이 때
W,HW, H는 각각 이미지의 너비와 높이, RR은 output stride이다.
CC는 keypoint type이다. 예컨대 detection에서 CC는 category의 수, pose estimation에서 CC는 관절의 수(논문에서는 17을 사용)가 될 수 있다.
Y^=1\hat{Y} = 1은 해당 그리드가 keypoint임을 의미한다.
학습을 위해서는 이미지에 있는 각각의 keypoint의 low-resolution 좌표를 구하고, Gaussian Kernel을 이용해 각각의 그리드의 값을 구한다. 여기서 Gaussian Kernel은 해당 그리드가 (가장 인접한 - 즉 같은 class의 Kernel값이 겹칠 경우, 값이 큰 것을 택한다) 중심점과 얼마나 가까운지를 측정한다. 즉 중심점에 가까울수록 그 값이 크고, 멀수록 그 값이 작다. 이렇게 구한 값들이 히트맵이 되는 것이다.
Yxyc=exp((xpx~)2+(ypy~)22σp2)Y_{xyc} = exp( -\frac{ (x - \tilde{p_x})^2 + (y - \tilde{p_y})^2 }{ 2\sigma_p^2 } )
그러면 다음과 같은 Focal Loss를 이용해 히트맵을 학습한다.
Lk=1Nxyc{(1Y^xyc)αlog(Y^xyc),if Yxyc=1(1Yxyc)β(Y^xyc)αlog(1Y^xyc),otherwise L_k = -\frac{1}{N} \sum_{xyc} \begin{cases} (1 - \hat{Y}{xyc})^{\alpha} \log{( \hat{Y}{xyc} )}, & \text{if}\ Y_{xyc} = 1 \\ (1 - Y_{xyc})^{\beta} (\hat{Y}{xyc})^{\alpha} \log{( 1 - \hat{Y}{xyc} )}, & \text{otherwise} \end{cases} 
위의 Focal Loss를 해석해보자면,
첫째 항은 해당 그리드가 중심점일 경우, Y^\hat{Y}가 1에서 얼마나 멀리 떨어져 있는지를 측정하며,
둘째 항은 해당 그리드가 중심점이 아닐 경우, Y^\hat{Y}가 0에서 얼마나 멀리 떨어져 있는지를 측정하지만, 그 값을 Kernel 값에 대해 보정하여 계산한다. 즉 Kernel 값이 작을수록(중심점에서 멀리 떨어져 있을수록) 0에서 멀리 떨어져 있음으로써 발생하는 loss가 커진다.
여기서 α,β\alpha, \beta는 각각 하이퍼파라미터, NN은 이미지 II의 keypoint의 수이다. 논문에서는 α=2,β=4\alpha = 2, \beta = 4를 사용한다.
한편 히트맵을 그리기 위해 원본 이미지의 해상도를 RR로 나누어 줄이게 되는데(해당 그리드의 중앙값을 중심 좌표로 사용한다), 이로 인해 발생하는 정보의 손실을 고려할 수 있다. 이는 그 offset에 대한 L1 loss를 cost에 포함함으로써 가능하다. 이러한 offset loss는 keypoint에 해당하는 그리드에만 적용한다. 즉 다른 그리드는 masking하여 loss를 계산한다.
Loff=1NpO^p~(pRp~)L_{off} = \frac{1}{N} \sum_p{| \hat{O}_{\tilde{p}} - (\frac{p}{R} - \tilde{p}) |}

Objects as Points

CenterNet의 강점 중 하나는, object detection뿐 아니라 pose estimation, 3D detection 등 다양한 task에 대한 generalization이 쉽다는 것이다. 이는 기본적으로 관련 feature들을 이미지에서 직접적으로 regression하기에 가능하다.
먼저, object size를 중심점에 해당하는 그리드에서 다음과 같은 L1 loss를 계산함으로써 학습한다.
Lsize=1NkNS^pkskL_{size} = \frac{1}{N} \sum_k^N{| \hat{S}_{p_k} - s_k |}
이제 detection을 위한 loss는 다음과 같이 계산된다. 논문에서는 λsize=0.1,λoff=1\lambda_{size} = 0.1, \lambda_{off} = 1를 사용한다.
Ldet=Lk+λsizeLsize+λoff+LoffL_{det} = L_k + \lambda_{size} L_{size} + \lambda_{off} + L_{off}
이에 따라 각 그리드의 출력값의 크기는 class별 heat(C), size(2), offset(2)를 포함하여 C+4C + 4가 된다. 여기서 값이 가장 높은 100개의 peaks들을 남기고, 이들에 대해 위의 정보들을 종합하여 bounding box를 생성함으로써 detection이 가능해진다.
그런데... 여기서 NMS가 따로 필요 없다고 하는데, 그러면 이미지마다 꼭 100개의 bounding box를 생성하게 되는 건가...?

Human Pose Estimation

3D detection은 어려운데 조만간 쓸 일은 없을 것 같으니 생략하고, CenterNet을 pose estimation task로 generalize하는 것을 정리한다.
기본적으로 human pose estimation는 인간의 k=17k = 17개의 관절을 파악함으로써 가능하다고 전제한다. 따라서, CenterNet에서는 pose estimation을 위해 k×2k \times 2 dimension의 "인간 object의 중심점으로부터의 관절 object에 중심점까지의 offset" feature(JJ)를 L1 loss를 통해 학습한다.
여기서 joint keypoint detection 결과를 개선하기 위해, kk class의 관절에 대한 keypoint heatmap을 위에서 설명한 방식과 동일하게 학습한다. 즉 Gaussian Kernel을 통해 히트맵을 생성하고, 각각 관절의 중심점에 대한 offset을 포함하여 학습한다.
그리고 처음에 학습한 관절 offset J^\hat{J}에 대해, 해당 관절 class의 가장 가까운 heatmap keypoint(confidence > 0.1인 peaks 들을 필터링하여 뽑아냄) 좌표을 할당함으로써 우리가 원하는 관절의 좌표를 찾아낸다! 즉 진짜 좌표는 히트맵으로 찾는 것이고, 인간의 중심점과 그에 대한 각각의 관절 offset은 grouping cue 역할을 하는 것이다. (해석이 틀렸을 수 있다. 혹시나 틀렸다면 댓글이나 연락 감사합니다.)
이러한 CenterNet의 성능은, 꽤 빠르고 적당한 정확도를 보여주는 정도인 것 같다. SOTA 급의 정확도는 당연히 아니고, 그렇다고 막 미친듯이 빠른 것도 아니지만, 그 밸런스를 잘 맞춘 상태에서 상당한 성능을 자랑하는 CenterNet(?)!
E.O.D.