카테고리 없음

MNIST Lab 8편 - BatchNorm forward/backward 직접 구현하기

cedis 2026. 5. 29. 00:23

MNIST Lab 기본 구현 8편

BatchNorm은 미니배치의 평균과 분산을 이용해 값을 정규화한다. 구현에서는 학습 모드와 평가 모드가 다르고, backward에서는 gamma, beta, x에 대한 gradient를 모두 계산해야 한다.

1. BatchNorm의 역할

1
평균 계산

현재 미니배치 feature별 평균을 구한다.

2
분산 계산

feature별 값이 평균 주변에 얼마나 퍼졌는지 구한다.

3
정규화

평균 0, 분산 1에 가까운 값으로 바꾼다.

4
scale/shift

gamma와 beta로 다시 적절한 크기와 위치로 조정한다.

2. forward에서 값이 어떻게 변하는가

BatchNorm은 입력 shape를 바꾸지 않는다. 대신 각 feature 열의 분포를 정리한다. 예를 들어 x가 (128, 512)라면 128개 데이터마다 512개 feature가 있고, 평균과 분산은 feature별로 512개씩 만들어진다.

단계 계산 shape 감각 의미
평균 mu = mean(x, axis=0) (512,) batch 128개를 훑어 feature별 중심을 잡는다.
분산 var = var(x, axis=0) (512,) feature별 값이 얼마나 퍼졌는지 본다.
정규화 x_norm = (x - mu) / sqrt(var + eps) (128, 512) 입력과 같은 shape를 유지한 채 분포만 맞춘다.
복원 조정 out = gamma * x_norm + beta (128, 512) 학습 가능한 scale과 shift로 필요한 표현력을 되돌린다.
평균 0, 분산 1로 끝내지 않는 이유

정규화만 하면 모든 계층의 출력 분포를 강제로 비슷하게 만든다. gamma와 beta는 모델이 필요하면 다시 크게 만들거나 위치를 옮길 수 있게 해주는 학습 가능한 손잡이다.

3. forward 최종 코드

def forward(self, x, train=True):
    if train:
        mu = np.mean(x, axis=0)
        var = np.var(x, axis=0)

        self.x_centered = x - mu
        self.std_inv = 1.0 / np.sqrt(var + self.eps)
        self.x_norm = self.x_centered * self.std_inv
        out = self.gamma * self.x_norm + self.beta

        self.x = x
        self.mu = mu
        self.var = var
        self.batch_size = x.shape[0]

        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        return out

    x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
    out = self.gamma * x_norm + self.beta
    return out
train=True와 train=False 차이

학습할 때는 현재 배치 통계를 쓰고, 평가할 때는 학습 중 누적한 running_mean/running_var를 쓴다. test 데이터 한 묶음의 통계로 매번 결과가 흔들리면 안 되기 때문이다.

4. backward 최종 코드

def backward(self, dout):
    self.dbeta = np.sum(dout, axis=0)
    self.dgamma = np.sum(dout * self.x_norm, axis=0)

    N = self.batch_size
    dx_norm = dout * self.gamma
    dstd_inv = np.sum(dx_norm * self.x_centered, axis=0)
    dx_centered1 = dx_norm * self.std_inv

    dvar = dstd_inv * (-0.5) * (self.var + self.eps) ** (-1.5)
    dx_centered2 = (2.0 / N) * self.x_centered * dvar

    dx_centered = dx_centered1 + dx_centered2
    dmu = -np.sum(dx_centered, axis=0)
    dx_mu = dmu / N
    dx = dx_centered + dx_mu
    return dx
gradient 계산 의미 어디로 가는가
dbeta 출력에 그대로 더해진 beta가 손실에 얼마나 영향을 줬는지 grads['beta1'], grads['beta2']로 모인다.
dgamma 정규화된 값 x_norm을 얼마나 키우거나 줄여야 하는지 grads['gamma1'], grads['gamma2']로 모인다.
dx BatchNorm 앞 계층으로 되돌려 보낼 gradient 이전 Affine/ReLU 방향으로 계속 역전파된다.
backward가 길어지는 이유

BatchNorm forward는 mean, var, sqrt, 나눗셈, gamma, beta가 이어진 합성 함수다. backward는 이 길을 거꾸로 따라가야 하므로 중간값을 저장하고 단계별로 gradient를 풀어야 한다.

5. 테스트가 묻는 것

테스트 확인 조건
test_batchnorm_forward_shape forward 출력 shape가 입력과 같은가
test_batchnorm_backward_shape backward 출력 dx shape가 입력과 같은가
테스트보다 더 봐야 할 것

현재 테스트는 shape 중심이다. 하지만 실제 학습에서는 running_mean/running_var가 평가 모드에서 쓰이는지, dgamma/dbeta가 grads로 모이는지도 함께 확인해야 한다.

이번 글에서 기억할 것

BatchNorm은 학습 때 batch 통계로 정규화하고, 평가 때 running 통계를 사용하는 계층이다.

스스로 점검

  1. 평가 모드에서 현재 batch 평균을 쓰면 왜 문제가 되는가?
  2. gamma와 beta는 각각 무엇을 조정하는가?
  3. BatchNorm backward가 복잡해지는 이유는 무엇인가?

다음 글 예고

다음 글에서는 Dropout을 구현한다. 학습 때 일부 뉴런을 끄고, 평가 때는 평균적인 출력 크기를 맞춘다.

한 줄 정리

BatchNorm은 학습 때 batch 통계로 정규화하고, 평가 때 running 통계를 사용하는 계층이다.