🎑

Deconstructing the Regularization of BatchNorm

Tags
General
Created
2021/04/24 06:45
Publication
ICLR'21
Rate
4
Summary
모델의 generalization과 정확도 성능 향상이라는 측면에서 2015년에 발표된 Batch Normalization은 이제 딥러닝 모델 설계의 표준적인 component가 되었다. 최근의 SOTA 모델들은 그 분야에 상관 없이 대부분이 BN layer를 포함한다. 하지만 이러한 유행(?)에도 불구하고, Batch Normalization이 대체 왜 좋은지, 어떻게 이러한 영향을 주게 되는지에 대한 mechanism에 대한 이해는 부족한 상태이다. 특히 Batch Normalization은 1) covariate shift 현상을 해소하고, 2) 더 높은 learning rate을 사용할 수 있도록 하며, 3) model initialization을 개선하고, 4) conditioning(?)을 개선하는 등 여러 가지 효과를 가지고 있는 것으로 알려져 있는데, 이렇게 여러 효과들이 서로 섞여 오히려 정확한 mechanism을 파악하는 것을 어렵게 만들고는 한다. 이에 따라 이 페이퍼에서는 Batch Normalization의 효과를 여러 개의 단순한 component로 분해(deconstruct)하여 각각의 영향력들을 조사하고, 이로써 Batch Normalization의 mechanism을 더욱 자세히 이해하고자 한다.

1. Introduction

모델의 generalization과 정확도 성능 향상이라는 측면에서 2015년에 발표된 Batch Normalization은 이제 딥러닝 모델 설계의 표준적인 component가 되었다. 최근의 SOTA 모델들은 그 분야에 상관 없이 대부분이 BN layer를 포함한다. 하지만 이러한 유행(?)에도 불구하고, Batch Normalization이 대체 왜 좋은지, 어떻게 이러한 영향을 주게 되는지에 대한 mechanism에 대한 이해는 부족한 상태이다.
특히 Batch Normalization은 1) covariate shift 현상을 해소하고, 2) 더 높은 learning rate을 사용할 수 있도록 하며, 3) model initialization을 개선하고, 4) conditioning(?)을 개선하는 등 여러 가지 효과를 가지고 있는 것으로 알려져 있는데, 이렇게 여러 효과들이 서로 섞여 오히려 정확한 mechanism을 파악하는 것을 어렵게 만들고는 한다.
이에 따라 이 페이퍼에서는 Batch Normalization의 효과를 여러 개의 단순한 component로 분해(deconstruct)하여 각각의 영향력들을 조사하고, 이로써 Batch Normalization의 mechanism을 더욱 자세히 이해하고자 한다.
이 페이퍼의 contribution을 다음과 같이 요약한다.
How does normalization help generalization?
Batch Normalization의 normalization 효과를 여러 component(i.e. intermediate layer의 standardization, feature embedding layer의 normalization, output layer의 normalization)로 분해하여 그 효과를 additive penality와 ablation을 통해 연구한다. 이를 통해 Fixup initialization과 additive penality를 통해 initialization과 학습 과정에서 발생하는 explosive growth를 regularization함으로써 Batch Normalization의 generalization boost 효과를 상당 부분 수복할 수 있다는 사실을 발견했다.
Links to Fixup and Dropout
Batch Normalization과 final layer에서의 regularization, Dropout regularization, Fixup initialization과의 관계를 각각 규명한다.
Simplicity in regularization
첫 번째 contribution에서 밝혀낸 additive penality는 매우 단순하면서도, 그 자체로 상당한 개선 효과를 가지는 standalone regularization으로 활용될 수 있다.

2. Decomposing the Regularization Effects of Batch Normalization

이 섹션에서는 Batch Normalization을 여러 개의 단순한 mechanism으로 분해해본다. 목적은 Batch Normalization의 generalization boost 효과가 어떠한 mechanism에 의해 발생하는지 보다 명확히 이해하는 것이다.
특히 Batch Normalization은 1) intermediate layer의 standardization, 2) final layer의 norm regularization이라는 implicit effect를 가지는데, 이를 분리하여 연구하기 위해 ablation과 additive penalty를 사용하고자 한다.

2.1. Regularizing Against Explosive Growth in the Final Layer

먼저 final layer의 normalization 효과를 먼저 규명해본다.
Neural network NN(x)=WEmb(x)NN(x) = W \text{Emb}(x) 와 loss LL 에 대해서, Emb(x)=Swish(γBatchNorm(PreEmb(x))+β)\text{Emb}(x) = \text{Swish}(\gamma \text{BatchNorm}(\text{PreEmb}(x)) + \beta) 라고 하자. Swish()\text{Swish}()ρ\rho 를 parameter로 가진 swish function이고, γ,β\gamma, \beta 는 Batch Normalization parameter이다.
이 때 Batch Normalization은 parameter γ,β\gamma, \beta 에 대한 weight decay regularization이 다음과 같은 additive penalty와 equivalent 하도록 유도한다. (자세한 유도는 appendix 참조. Swish function에 대한 Taylor expansion을 진행한 다음 Batch Normalization의 normalization effect - embedding의 기댓값이 0이 되고, 제곱의 기댓값이 1이 된다 - 를 활용하여 유도할 수 있다.)
L(NN(x))+λγ2+λβ2=L(NN(x))+λ4E[Emb(x)2+O(ρ)L(NN(x)) + \lambda||\gamma||^2 + \lambda||\beta||^2 = L(NN(x)) + \frac{\lambda}{4}E[||\text{Emb}(x)^2|| + O(|\rho|)
즉 Batch Normalization parameter의 norm이 직접적으로 feature embedding의 norm을 control할 수 있게 되며, 이에 따라 Batch Normalization paremeter가 작게 유지되는 한 feature embedding norm 역시 explosive growth 할 수 없다. 또 실제로는 weight decay regularization을 포함하지 않더라도, SGD가 low norm parameter를 선호하는 경향이 있기 때문에 여전히 이러한 효과를 누릴 수 있다고 한다.
단 이러한 equivalency는 Batch Normalization을 사용할 때에만 성립한다. 그렇지 않으면 embedding network의 activation이 feature embedding norm에 큰 영향을 미치게 되기 때문인데, 여러 논문에 의해 Batch Normalization을 포함하지 않은 residual network의 activation이 깊이에 대해 exponential한 growth를 보인다는 사실이 밝혀졌다. 실제로 이후 섹션의 실험에서 Batch Normalization을 포함하지 않은 network의 경우 feature embedding norm이 크게 증가함을 확인할 수 있다.
Feature Embedding L2 (EL2)
이제 위와 같은 regularization의 영향을 다음과 같은 additive penalty를 통해 탐구해보고자 한다.
REL2(NN)=1HE[Emb(x)2]R_{EL2}(NN) = \frac{1}{H}E[||\text{Emb}(x)||^2]
위와 같은 regularization을 loss에 추가함으로써, Batch Normalization과 독립적으로 final layer normalization의 영향을 확인해볼 수 있다. 이러한 regularization을 embedding L2 라 부르며, classification layer 바로 직전에 적용한다. 이후 실험으로 인해 이러한 feature embedding normalization이 Batch Normalization의 generalization boost 효과의 대부분을 수복함을 확인할 수 있다.
Functional L2 (FL2)
Feature embedding의 normalization은 final output norm의 regularization을 간접적으로 유도한다. 이러한 영향을 실험하기 위해 다음과 같은 additive penalty를 활용한다.
RFL2(NN)=1KE[NN(x)2]R_{FL2}(NN) = \frac{1}{K}E[||NN(x)||^2]
이러한 regularization을 functional L2 라 부르며, classification model의 logit에 적용한다. 추후 실험을 통해 이러한 FL2와 EL2 중 어느 것이 generalization boost와 더 관련이 깊은지 살펴본다.

2.2. Standardizing the Intermediate Activations of the Model

또한 각 intermediate layer의 normalization 효과를 측정하기 위한 standardization loss 역시 고려한다. 위 두 regularization의 효과를 측정하기 위한 reference로 삼기 위함이다.
DKL(P(x)N(x0,I)=12i(μi2+σi2logσi21)D_{KL}(P(x)||N(x|0, I) = \frac{1}{2}\sum_i(\mu_i^2 + \sigma_i^2 - \log\sigma_i^2 - 1)
μi\mu_iσi\sigma_i 는 각각 intermediate layer의 mean과 variance로, Batch Normalization statistic과 동일한 그것이다.

3. Drawing Links to Other Methods

이 섹션에서는 Batch Normalization과 다른 방법들의 관계를 규명해본다.

3.1. Dropout Regularization

Dropout layer를 intermediate layer에 적용하면 상대적으로 효과가 적으며, final layer의 input에 적용될 경우 가장 큰 효과를 얻을 수 있음이 드러났다. 이 때 이러한 final layer의 Dropout과 MSE loss의 조합에서 Dropout regularization이 다음과 같은 additive penalty와 동치임을 보일 수 있다고 한다.
LDropout=1NE[(WDropout(Emb(xi))yi)2]=1N(WEmb(xi)yi)2+λtr(Wdiag(E[Emb(x)Emb(x)T])WT)L_{\text{Dropout}} = \frac{1}{N}\sum E[(W\text{Dropout}(\text{Emb}(x_i)) - y_i)^2] \\ = \frac{1}{N}(W\text{Emb}(x_i) - y_i)^2 + \lambda tr(W\text{diag}(E[\text{Emb}(x)\text{Emb}(x)^T])W^T)
이 때 RDropout=tr(Wdiag(E[Emb(x)Emb(x)T])WT)R_{\text{Dropout}} = tr(W\text{diag}(E[\text{Emb}(x)\text{Emb}(x)^T])W^T) 라 정의한다.
이제 위와 같은 Dropout regularization이 섹션 2.1에서의 mechanism과 다음과 같이 연결됨을 알 수 있다. 단 이러한 approximation은 feature들이 서로 상대적으로 decorrelate 되어 있다는 전제 하에 가능하다.
KRFL2(NN)=E[NN(x)2]=E[WEmb(x)2]=E[WEmb(x)Emb(x)TWT]RDropout(NN)=tr(Wdiag(E[Emb(x)Emb(x)T])WT)14W4+H4REL2(NN)KR_{FL2}(NN) = E[||NN(x)||^2] \\ {} \\ = E[||W\text{Emb(x)}||^2] = E[W\text{Emb(x)}\text{Emb(x)}^TW^T] \\ {} \\ \approx R_{\text{Dropout}}(NN) = tr(W\text{diag}(E[\text{Emb}(x)\text{Emb}(x)^T])W^T) \\ {} \\ \leq \frac{1}{4}||W||^4 + \frac{H}{4}R_{EL2}(NN)
이러한 upper bound는 Batch Normalization과 weight decay로 학습된 모델이 Dropout robustness를 가지게 됨을 어느 정도 보장해주는데, 이는 즉 Batch Normalization이 이미 Dropout이 가져다주는 regularization effect를 상당 부분 포함하고 있음으로 해석된다. 이러한 해석은 Batch Normalization을 포함한 모델에 Dropout layer를 추가하는 것이 그다지 많은 도움이 되지 않는 현상을 설명해준다.

3.2. Fixup Initialization

Batch normalization 없는 residual network는 conventional initialization의 상황에서 activation의 scale이 depth에 따라 기하급수적으로 증가한다. 이러한 상황에서는 신경망의 학습이 매우 어려워질 수 있다.
Fixup initialization은 이러한 상황에서 normalization 없이 initialization으로 학습을 안정화하는 것을 목표로 한다. Fixup은 첫 gradient step에서 output이 급격하게 변하지 않도록 network를 초기화함으로써 이 목표를 달성하고자 한다. 이 페이퍼에서는 Fixup initialization이 다음과 같은 initial penalty를 최소화함을 주장한다.
REL2(NNFixupt=0),RFL2(NNFixupt=0)O(1)R_{EL2}(NN^{t=0}_{Fixup}), R_{FL2}(NN^{t=0}_{Fixup}) \in O(1)

4. Experiments

4.1. Datasets

CIFAR-10, SVHN, ImageNet 데이터셋을 활용하여 실험을 진행한다.

4.2. Implementation Details

WideResNet, ResNet-50, EfficientNet을 활용하여 실험을 진행한다. 단 Batch Normalization을 제외한 경우에는 Fixup initialization을 활용한다.

4.3. Results: Ablation of BatchNorm and Interventions

가장 먼저 CIFAR-10과 ImageNet 모두에서, Fixup initialization을 활용하였음에도 BatchNorm 모델과 baseline(BatchNorm을 제외하고 Fixup initialization을 포함한 모델)의 성능이 큰 차이를 보임을 확인할 수 있다.
또 standardization regularization(intermediate layer의 normalization)이 CIFAR-10에서는 1) robustness를 개선시켜주지만 2) accuracy는 개선시켜주지 못하고, ImageNet에서는 1) 약간의 정확도 개선을 보여줌을 확인한다. 하지만 페이퍼에서는 이 결과가 standardization regularization의 실패가 아니라 regularization을 위해 수행한 특정한 방식(standardization loss)의 실패라고 주장한다.
한편 final layer normalization을 수행한 경우에는 CIFAR-10과 ImageNet 모두에서 더 나은 결과를 얻을 수 있었다. ImageNet에서는 BatchNorm의 generalization boost 효과를 따라잡을 수 있었지만 CIFAR-10에서는 BatchNorm에 비해서는 살짝 못 미치는 효과를 보여준다. 이러한 결과는 final layer normalization이 Batch Normalization의 효과의 상당 부분을 설명하고 있다는 주장을 뒷받침한다.
Batch Normalization의 장점 중 하나로는 학습 시 더 높은 learning rate을 활용할 수 있게 해 준다는 것인데, 실험 결과 ImageNet에 대해 Batch Normalization 모델에서의 최적 학습률 0.1과 EL2, FL2에서의 최적 학습률 0.1이 같게 나타났으며, CIFAR-10에 대해서는 Batch Normalization 모델에서의 최적 학습률 0.1이 오히려 0.2로 증가한 것을 확인할 수 있었다.
마지막으로 CIFAR-10에서의 robustness test 결과 오히려 baseline이 Batch Normalization에 비해 정확도는 낮지만 robustness는 높은 것으로 확인되는데, Batch Normalization의 intermediate layer standardization은 robustness 개선 효과를 가지지만, 다른 side effect로 인해 이러한 효과가 상쇄되고 있을 것이라고 추측한다. FL2와 EL2에서도 Batch Normalization에 비해 robustness가 개선되는 결과를 확인할 수 있다.

4.4. Analysis: How Much Does This Explain the Regularization of BatchNorm?

이제 위의 결과를 기반으로, 제시한 mechanism과 Batch Normalization의 관계를 설명한다.
How much does BatchNorm regularize the norm at the last layer?
CIFAR-10 (상), ImageNet (하)
학습 초기에는 Fixup initialization으로 인해 학습이 상당히 안정적이지만, epoch이 지날수록 점차 generalization gap이 벌어지는 현상을 확인할 수 있다. 그리고 이러한 generalization gap이 embedding norm과 output norm의 explosive growth와 관련이 있음을 확인할 수 있다.
Does regularizing the norm to the same level as BatchNorm improve results?
우선 FL2, EL2 coefficient가 상승함에 따라 embedding/output norm이 작아지고 정확도가 상승하는 현상을 확인할 수 있다.
또 output norm이 Batch Normalization과 유사한 수준일 때 FL2 모델의 정확도가 훨씬 낮음을 확인할 수 있다. Output norm이 유사한 수준이 되는 penalty coefficient가 상당히 작은 수준이다.
한편 embedding norm이 유사한 수준일 때는 성능 역시 comparable함을 확인할 수 있다. 따라서 Batch Normalization과 embedding norm이 유사한 수준이 될 정도로 coefficient가 높아야 함을 시사하며, feature embedding normalization이 Batch Normalization의 gradient boost와 보다 밀접하게 관련됨을 암시한다.
Can we disentangle BatchNorm's effect on initialization to its effect on the norm at the last layer? Can we train networks with only a single BatchNorm layer at the feature embedding level?
이번에는 마지막 feature embedding layer에만 Batch Normalization을 적용하고 Fixup ablation을 수행함으로써 Batch Normalization의 initialization이 last layer norm에 주는 영향을 조사해본다.
우선 Fixup 없이는 보통 성능이 아주 엉망이 될 텐데, 73.8%라는 준수한 정확도를 보여준다. Fixup을 적용할 경우에는 74.9%로 정확도가 상승하는데, 적절한 initialization과 학습 동안의 feature embedding norm normalization만으로 Batch Normalization 효과의 상당 부분을 수복할 수 있음을 보여준다.
How much do these methods increase natural robustness to Dropout?
섹션 3.1에서 Batch Normalization과 Dropout, 그리고 EL2의 관계를 규명해보았는데, 실제로 Dropout 없이 학습된 모델이라도 Dropout에 대한 natural robustness를 상당 부분 지니게 됨을 확인할 수 있다.

4.5. Regularization with BatchNorm

Batch Normalization이 이미 feature embedding norm을 어느 정도 줄여주지만, 그 효과는 간접적이다. 한편 embedding L2는 이를 직접적으로 줄여준다.
이에 따라 Batch Normalization과 함께 embedding L2 regularization을 적용하면, (EfficientNet B8 모델에 Dropout을 사용하지 않은 상황에서) Dropout에 맞먹는 모델 성능 향상을 얻을 수 있음을 확인했다. Embedding L2가 Dropout에 비해 훨씬 단순한 regularization을 감안하면 놀라운 결과라고 주장한다.

5. Conclusion

Initialization에서 feature embedding norm의 explosive growth를 방지하고, 학습 중에 이러한 feature embedding norm을 regularization함으로써 Batch Normalization의 generalization 개선 효과를 상당 부분 재현할 수 있음을 확인했다. 또한 feature embedding norm regularization이라는 아주 간단한 additive regularization으로 Dropout에 맞먹는 결과를 얻을 수 있음을 확인했다.
E.O.D.