MNIST Lab 기본 구현 6편
Adam은 SGD보다 코드가 길지만, 생각은 단순하다. gradient의 방향 기록 m과 크기 기록 v를 따로 저장하고, 초반 보정을 거쳐 파라미터별 이동량을 조절한다.
1. Adam이 기억하는 두 가지
| 기록 | 의미 | 직관 |
|---|---|---|
| m | gradient의 이동평균 | 최근 어느 방향으로 계속 가려 했는지 기억한다. |
| v | gradient 제곱의 이동평균 | 어느 파라미터가 크게 흔들렸는지 기억한다. |
| t | update 횟수 | 초반에 m, v가 0에서 시작해 작게 잡히는 문제를 보정한다. |
SGD가 매 순간 gradient만 보고 걷는다면, Adam은 최근 방향과 흔들림 크기를 함께 보고 보폭을 조절한다.
2. 최종 구현 코드
class Adam:
def __init__(self, lr=0.001):
self.lr = lr
self.m, self.v = {}, {}
self.t = 0
def update(self, params, grads):
beta1, beta2 = 0.9, 0.999
eps = 1e-8
self.t += 1
for key in params.keys():
if key not in self.m:
self.m[key] = np.zeros_like(params[key])
self.v[key] = np.zeros_like(params[key])
self.m[key] = beta1 * self.m[key] + (1 - beta1) * grads[key]
self.v[key] = beta2 * self.v[key] + (1 - beta2) * (grads[key] ** 2)
m_hat = self.m[key] / (1 - beta1 ** self.t)
v_hat = self.v[key] / (1 - beta2 ** self.t)
params[key] -= self.lr * m_hat / (np.sqrt(v_hat) + eps)
3. 코드 흐름을 단계별로 읽기
처음 보는 파라미터 key라면 m, v를 같은 shape의 0 배열로 만든다.
이전 방향 기록에 현재 gradient를 조금 섞는다.
이전 크기 기록에 현재 gradient 제곱을 조금 섞는다.
0에서 시작한 m, v가 초반에 너무 작게 잡히는 것을 보정한다.
방향 기록을 크기 기록으로 나누어 파라미터별 이동량을 조절한다.
4. Adam을 SGD와 비교해서 읽기
Adam이 어려워 보이는 이유는 식이 길기 때문이다. 하지만 역할로 나누면 SGD에 두 개의 기억 장치가 붙은 구조로 볼 수 있다.
| 상황 | SGD의 반응 | Adam의 반응 |
|---|---|---|
| gradient 방향이 계속 비슷함 | 매번 그 순간의 gradient만큼 이동한다. | m이 같은 방향을 누적해 안정적으로 이동한다. |
| 특정 파라미터 gradient가 크게 흔들림 | 그 흔들림에도 같은 lr을 적용한다. | v가 흔들림 크기를 기억해 이동량을 조절한다. |
| 학습 초반 | 별도 보정 없이 바로 이동한다. | m, v가 0에서 시작해 작게 잡히므로 m_hat, v_hat으로 보정한다. |
| 0으로 나눌 위험 | 보통 단순식이라 덜 드러난다. | sqrt(v_hat)이 0에 가까울 수 있어 eps를 더한다. |
Adam 코드를 읽을 때는 수식을 한 번에 외우려 하지 말고 key별 저장소 생성, m 갱신, v 갱신, bias correction, params 갱신 순서로 끊어 읽으면 된다.
5. 테스트가 묻는 것
| 테스트 | 확인 조건 |
|---|---|
| test_adam_update_changes_params | Adam update 후 params 값이 초기값 그대로 남아 있지 않은가 |
Adam의 수식을 전부 검산하는 테스트가 아니라, 최소한 update가 동작해 파라미터가 변하는지 보는 테스트다. 그래서 글에서는 m, v, bias correction의 의미를 별도로 이해해야 한다.
이번 글에서 기억할 것
Adam은 gradient의 방향 기록과 크기 기록을 함께 사용해 파라미터별 이동량을 조절하는 optimizer다.
스스로 점검
- m과 v는 각각 무엇을 기억하는가?
- 왜 m_hat, v_hat으로 보정하는가?
- eps는 어떤 상황을 막기 위해 들어가는가?
다음 글 예고
다음 글에서는 NeuralNetwork를 구현한다. 지금까지 만든 계층들을 순서대로 조립해 하나의 모델로 만든다.