카테고리 없음

MNIST Lab 3편 - Affine 계층 xW+b와 backward shape 구현하기

cedis 2026. 5. 28. 02:06

MNIST Lab 기본 구현 3편

Affine은 완전연결층이다. forward에서는 xW+b를 계산하고, backward에서는 이전 입력으로 보낼 dx와 파라미터를 고칠 dW, db를 만든다.

1. Affine은 차원을 바꾸는 계층이다

shape 예시 역할
x (5, 4) batch 5개, 입력 feature 4개
W (4, 3) 4차원 입력을 3차원 출력으로 바꾸는 가중치
b (3,) 출력 feature마다 더하는 편향
out (5, 3) 각 데이터가 3차원 점수로 변환된 결과
Affine을 볼 때 먼저 봐야 하는 것

이 계층은 숫자를 섞는 계층이면서 동시에 차원을 바꾸는 계층이다. 그래서 forward 결과만 맞아도 부족하고, backward에서 dx, dW, db가 각각 원래 자리로 돌아갈 수 있는 shape인지까지 확인해야 한다.

2. forward shape를 눈으로 따라가기

작은 예시로 보면 Affine의 역할이 훨씬 분명하다. batch가 2개이고 입력 feature가 3개라면, 2개 데이터를 각각 2차원 출력으로 바꾸는 계산은 아래처럼 읽으면 된다.

1
입력 묶음

x shape = (2, 3). 데이터 2개가 있고, 각 데이터는 feature 3개를 가진다.

2
가중치

W shape = (3, 2). 3차원 입력을 2차원 출력으로 바꾸는 변환 표다.

3
행렬 곱

x @ W shape = (2, 2). batch 개수 2는 유지되고, feature 차원만 3에서 2로 바뀐다.

4
편향 더하기

b shape = (2,)가 각 데이터의 출력 feature마다 더해져 최종 out shape도 (2, 2)가 된다.

검산 지점 확인할 내용
batch 축 forward와 backward를 지나도 데이터 개수 축은 사라지면 안 된다.
feature 축 W가 입력 feature를 출력 feature로 바꾸는 방향이어야 한다.
broadcast b는 batch마다 반복해서 더해지므로 shape가 출력 feature와 맞아야 한다.

3. 최종 구현 코드

class Affine:
    def __init__(self, W, b):
        self.W = W
        self.b = b
        self.x = None
        self.dW = None
        self.db = None

    def forward(self, x):
        self.x = x
        out = np.dot(x, self.W) + self.b
        return out

    def backward(self, dout):
        dx = np.dot(dout, self.W.T)
        self.dW = np.dot(self.x.T, dout)
        self.db = np.sum(dout, axis=0)
        return dx

4. backward 공식은 shape로 검산한다

계산 결과 shape 왜 이 순서인가
dx = dout @ W.T x와 같은 shape 이전 계층으로 gradient를 돌려보내야 한다.
dW = x.T @ dout W와 같은 shape optimizer가 W를 수정하려면 W와 shape가 같아야 한다.
db = sum(dout, axis=0) b와 같은 shape b는 batch의 모든 데이터에 더해졌으므로 batch 방향으로 합친다.
shape가 틀리면 구현이 거의 틀린 것이다

Affine backward는 수식보다 shape 검산이 먼저다. dW가 W와 같은 모양이 아니면 optimizer가 파라미터를 갱신할 수 없다.

5. 테스트가 묻는 것

테스트 확인 조건
test_affine_forward_shape forward 결과가 x @ W + b와 같은가
test_affine_backward_grad_shape dx, dW, db의 shape가 각각 x, W, b와 맞는가

이번 글에서 기억할 것

Affine 계층은 forward의 xW+b보다 backward의 dx, dW, db shape를 정확히 맞추는 것이 핵심이다.

스스로 점검

  1. 왜 forward에서 self.x를 저장해야 하는가?
  2. dW 계산에서 x.T가 앞에 오는 이유는 무엇인가?
  3. db는 왜 axis=0으로 합치는가?

다음 글 예고

다음 글에서는 Cross Entropy Loss를 구현한다. 예측 확률 중 정답 칸만 뽑아 벌점을 만드는 함수다.

한 줄 정리

Affine 계층은 forward의 xW+b보다 backward의 dx, dW, db shape를 정확히 맞추는 것이 핵심이다.