MNIST Lab 기본 구현 8편
BatchNorm은 미니배치의 평균과 분산을 이용해 값을 정규화한다. 구현에서는 학습 모드와 평가 모드가 다르고, backward에서는 gamma, beta, x에 대한 gradient를 모두 계산해야 한다.
1. BatchNorm의 역할
현재 미니배치 feature별 평균을 구한다.
feature별 값이 평균 주변에 얼마나 퍼졌는지 구한다.
평균 0, 분산 1에 가까운 값으로 바꾼다.
gamma와 beta로 다시 적절한 크기와 위치로 조정한다.
2. forward에서 값이 어떻게 변하는가
BatchNorm은 입력 shape를 바꾸지 않는다. 대신 각 feature 열의 분포를 정리한다. 예를 들어 x가 (128, 512)라면 128개 데이터마다 512개 feature가 있고, 평균과 분산은 feature별로 512개씩 만들어진다.
| 단계 | 계산 | shape 감각 | 의미 |
|---|---|---|---|
| 평균 | mu = mean(x, axis=0) |
(512,) | batch 128개를 훑어 feature별 중심을 잡는다. |
| 분산 | var = var(x, axis=0) |
(512,) | feature별 값이 얼마나 퍼졌는지 본다. |
| 정규화 | x_norm = (x - mu) / sqrt(var + eps) |
(128, 512) | 입력과 같은 shape를 유지한 채 분포만 맞춘다. |
| 복원 조정 | out = gamma * x_norm + beta |
(128, 512) | 학습 가능한 scale과 shift로 필요한 표현력을 되돌린다. |
정규화만 하면 모든 계층의 출력 분포를 강제로 비슷하게 만든다. gamma와 beta는 모델이 필요하면 다시 크게 만들거나 위치를 옮길 수 있게 해주는 학습 가능한 손잡이다.
3. forward 최종 코드
def forward(self, x, train=True):
if train:
mu = np.mean(x, axis=0)
var = np.var(x, axis=0)
self.x_centered = x - mu
self.std_inv = 1.0 / np.sqrt(var + self.eps)
self.x_norm = self.x_centered * self.std_inv
out = self.gamma * self.x_norm + self.beta
self.x = x
self.mu = mu
self.var = var
self.batch_size = x.shape[0]
self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
return out
x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
out = self.gamma * x_norm + self.beta
return out
학습할 때는 현재 배치 통계를 쓰고, 평가할 때는 학습 중 누적한 running_mean/running_var를 쓴다. test 데이터 한 묶음의 통계로 매번 결과가 흔들리면 안 되기 때문이다.
4. backward 최종 코드
def backward(self, dout):
self.dbeta = np.sum(dout, axis=0)
self.dgamma = np.sum(dout * self.x_norm, axis=0)
N = self.batch_size
dx_norm = dout * self.gamma
dstd_inv = np.sum(dx_norm * self.x_centered, axis=0)
dx_centered1 = dx_norm * self.std_inv
dvar = dstd_inv * (-0.5) * (self.var + self.eps) ** (-1.5)
dx_centered2 = (2.0 / N) * self.x_centered * dvar
dx_centered = dx_centered1 + dx_centered2
dmu = -np.sum(dx_centered, axis=0)
dx_mu = dmu / N
dx = dx_centered + dx_mu
return dx
| gradient | 계산 의미 | 어디로 가는가 |
|---|---|---|
| dbeta | 출력에 그대로 더해진 beta가 손실에 얼마나 영향을 줬는지 | grads['beta1'], grads['beta2']로 모인다. |
| dgamma | 정규화된 값 x_norm을 얼마나 키우거나 줄여야 하는지 | grads['gamma1'], grads['gamma2']로 모인다. |
| dx | BatchNorm 앞 계층으로 되돌려 보낼 gradient | 이전 Affine/ReLU 방향으로 계속 역전파된다. |
BatchNorm forward는 mean, var, sqrt, 나눗셈, gamma, beta가 이어진 합성 함수다. backward는 이 길을 거꾸로 따라가야 하므로 중간값을 저장하고 단계별로 gradient를 풀어야 한다.
5. 테스트가 묻는 것
| 테스트 | 확인 조건 |
|---|---|
| test_batchnorm_forward_shape | forward 출력 shape가 입력과 같은가 |
| test_batchnorm_backward_shape | backward 출력 dx shape가 입력과 같은가 |
현재 테스트는 shape 중심이다. 하지만 실제 학습에서는 running_mean/running_var가 평가 모드에서 쓰이는지, dgamma/dbeta가 grads로 모이는지도 함께 확인해야 한다.
이번 글에서 기억할 것
BatchNorm은 학습 때 batch 통계로 정규화하고, 평가 때 running 통계를 사용하는 계층이다.
스스로 점검
- 평가 모드에서 현재 batch 평균을 쓰면 왜 문제가 되는가?
- gamma와 beta는 각각 무엇을 조정하는가?
- BatchNorm backward가 복잡해지는 이유는 무엇인가?
다음 글 예고
다음 글에서는 Dropout을 구현한다. 학습 때 일부 뉴런을 끄고, 평가 때는 평균적인 출력 크기를 맞춘다.