MNIST Lab 기본 구현 3편
Affine은 완전연결층이다. forward에서는 xW+b를 계산하고, backward에서는 이전 입력으로 보낼 dx와 파라미터를 고칠 dW, db를 만든다.
1. Affine은 차원을 바꾸는 계층이다
| 값 | shape 예시 | 역할 |
|---|---|---|
| x | (5, 4) | batch 5개, 입력 feature 4개 |
| W | (4, 3) | 4차원 입력을 3차원 출력으로 바꾸는 가중치 |
| b | (3,) | 출력 feature마다 더하는 편향 |
| out | (5, 3) | 각 데이터가 3차원 점수로 변환된 결과 |
이 계층은 숫자를 섞는 계층이면서 동시에 차원을 바꾸는 계층이다. 그래서 forward 결과만 맞아도 부족하고, backward에서 dx, dW, db가 각각 원래 자리로 돌아갈 수 있는 shape인지까지 확인해야 한다.
2. forward shape를 눈으로 따라가기
작은 예시로 보면 Affine의 역할이 훨씬 분명하다. batch가 2개이고 입력 feature가 3개라면, 2개 데이터를 각각 2차원 출력으로 바꾸는 계산은 아래처럼 읽으면 된다.
x shape = (2, 3). 데이터 2개가 있고, 각 데이터는 feature 3개를 가진다.
W shape = (3, 2). 3차원 입력을 2차원 출력으로 바꾸는 변환 표다.
x @ W shape = (2, 2). batch 개수 2는 유지되고, feature 차원만 3에서 2로 바뀐다.
b shape = (2,)가 각 데이터의 출력 feature마다 더해져 최종 out shape도 (2, 2)가 된다.
| 검산 지점 | 확인할 내용 |
|---|---|
| batch 축 | forward와 backward를 지나도 데이터 개수 축은 사라지면 안 된다. |
| feature 축 | W가 입력 feature를 출력 feature로 바꾸는 방향이어야 한다. |
| broadcast | b는 batch마다 반복해서 더해지므로 shape가 출력 feature와 맞아야 한다. |
3. 최종 구현 코드
class Affine:
def __init__(self, W, b):
self.W = W
self.b = b
self.x = None
self.dW = None
self.db = None
def forward(self, x):
self.x = x
out = np.dot(x, self.W) + self.b
return out
def backward(self, dout):
dx = np.dot(dout, self.W.T)
self.dW = np.dot(self.x.T, dout)
self.db = np.sum(dout, axis=0)
return dx
4. backward 공식은 shape로 검산한다
| 계산 | 결과 shape | 왜 이 순서인가 |
|---|---|---|
dx = dout @ W.T |
x와 같은 shape | 이전 계층으로 gradient를 돌려보내야 한다. |
dW = x.T @ dout |
W와 같은 shape | optimizer가 W를 수정하려면 W와 shape가 같아야 한다. |
db = sum(dout, axis=0) |
b와 같은 shape | b는 batch의 모든 데이터에 더해졌으므로 batch 방향으로 합친다. |
Affine backward는 수식보다 shape 검산이 먼저다. dW가 W와 같은 모양이 아니면 optimizer가 파라미터를 갱신할 수 없다.
5. 테스트가 묻는 것
| 테스트 | 확인 조건 |
|---|---|
| test_affine_forward_shape | forward 결과가 x @ W + b와 같은가 |
| test_affine_backward_grad_shape | dx, dW, db의 shape가 각각 x, W, b와 맞는가 |
이번 글에서 기억할 것
Affine 계층은 forward의 xW+b보다 backward의 dx, dW, db shape를 정확히 맞추는 것이 핵심이다.
스스로 점검
- 왜 forward에서 self.x를 저장해야 하는가?
- dW 계산에서 x.T가 앞에 오는 이유는 무엇인가?
- db는 왜 axis=0으로 합치는가?
다음 글 예고
다음 글에서는 Cross Entropy Loss를 구현한다. 예측 확률 중 정답 칸만 뽑아 벌점을 만드는 함수다.