밑바닥부터 시작하는 딥러닝 1 - 계산 그래프와 오차역전파법
수치 미분은 이해하기 쉽지만 느리다. 매개변수가 많아질수록 하나씩 흔들어보는 방식은 현실적이지 않다.
오차역전파법은 최종 손실에서 시작해 각 노드가 저장해 둔 값을 꺼내며 거꾸로 이동하고, 그 과정에서 `dW`, `db`처럼 실제로 갱신에 필요한 기울기를 한 번의 backward 흐름으로 모은다.
이번 글에서 잡을 것
- 순전파는 입력에서 결과로 가는 계산이다.
- 역전파는 결과에서 입력 방향으로 미분값을 전달하는 계산이다.
- 덧셈 노드는 미분값을 그대로 흘린다.
- 곱셈 노드는 순전파 때의 입력값을 서로 바꿔 곱한다.
- `z=x*y` 같은 곱셈 노드는 x 방향으로는 y를, y 방향으로는 x를 곱해 미분값을 보낸다.
사과 쇼핑 예시
사과 가격이 100원, 사과 개수가 2개, 소비세가 1.1이라면 최종 금액은 220원이다. 순전파는 이 값을 계산하는 과정이다.
역전파는 무엇을 묻는가
역전파는 '최종 금액이 각 입력에 얼마나 민감한가'를 묻는다. 사과 가격이 1원 오르면 최종 금액은 얼마나 오르는가, 사과 개수가 1개 늘면 얼마나 오르는가를 계산한다.
| 입력 | 역전파 결과 | 해석 |
|---|---|---|
| 사과 가격 | 2.2 | 가격이 1원 오르면 최종 금액은 2.2원 증가 |
| 사과 개수 | 110 | 사과 1개가 최종 금액에 110원만큼 영향 |
| 소비세 | 200 | 세금 계수가 1 증가하면 200만큼 영향 |
왜 곱셈 노드는 두 갈래로 나뉘나
곱셈 노드에는 순전파 때 입력이 두 개 들어왔다. 역전파는 두 입력 각각이 최종 결과에 미친 영향을 구해야 하므로 두 갈래로 나뉜다. 입력 변수가 끝 지점이면 그곳에서 미분값 계산은 끝난다.
덧셈
z=x+y이면 x로 미분해도 1, y로 미분해도 1이다.
곱셈
z=x*y이면 x로 미분하면 y, y로 미분하면 x가 남는다.
연쇄법칙
상류에서 온 미분값에 현재 노드의 국소적 미분을 곱한다.
왜 순전파 값을 저장해야 할까
역전파 때 곱셈 노드는 순전파 때 들어온 x와 y를 다시 사용한다. ReLU는 입력이 0 이하였던 위치를 기억해야 하고, Sigmoid는 순전파 출력값을 사용한다. 그래서 계층은 순전파 때 필요한 값을 저장해둔다.
저장 없이는 역전파가 끊긴다
곱셈 노드가 순전파 때의 x와 y를 잊어버리면 `상류 미분 * 상대 입력값` 계산을 할 수 없다. 그래서 forward는 결과만 내는 함수가 아니라 backward를 위한 기록 단계이기도 하다.
역전파를 한 줄씩 추적하기
사과 쇼핑 예시에서 역전파는 최종 금액 220에서 시작한다. 출발 미분값은 1이다. 이후 곱셈 노드를 지날 때마다 순전파 때의 상대 입력값을 곱한다.
| 역방향 위치 | 상류 미분 | 곱하는 값 | 결과 |
|---|---|---|---|
| 최종금액 -> 사과총액 | 1 | 소비세 1.1 | 1.1 |
| 최종금액 -> 소비세 | 1 | 사과총액 200 | 200 |
| 사과총액 -> 사과가격 | 1.1 | 개수 2 | 2.2 |
| 사과총액 -> 사과개수 | 1.1 | 가격 100 | 110 |
소비세 방향은 왜 거기서 멈추나
소비세는 계산 그래프의 시작 입력 중 하나다. 역전파가 그 입력까지 도착했다면 목표였던 '소비세가 최종 결과에 미치는 영향'을 이미 구한 것이다. 그래서 더 거슬러 올라갈 노드가 없다.
중간 노드
이전 계산으로 더 거슬러 올라간다.
입력 노드
그 변수의 최종 기울기가 계산되면 멈춘다.
목표
모든 입력과 매개변수가 손실에 미친 영향을 구하는 것.
수치 미분과 비교하면 왜 빠른가
수치 미분은 매개변수 하나를 조금 바꿔보고 손실이 얼마나 변하는지 다시 계산한다. 매개변수가 10,000개면 이런 재계산을 거의 10,000번 해야 한다. 역전파는 순전파로 값을 저장한 뒤, 역방향 한 번으로 모든 매개변수의 기울기를 같이 얻는다.
| 방법 | 기울기 구하는 방식 | 매개변수 많을 때 |
|---|---|---|
| 수치 미분 | 매개변수 하나씩 흔들어 손실 재계산 | 매우 느려짐 |
| 오차역전파 | 저장된 순전파 값으로 미분값을 뒤로 전달 | 한 번의 backward 흐름으로 묶음 |
사과 쇼핑을 코드 변수로 쓰면
그림으로 본 역전파를 코드 변수로 쓰면 더 분명해진다. 핵심은 순전파 때의 상대 입력값이 역전파 때 곱해진다는 점이다.
apple = 100
apple_num = 2
tax = 1.1
apple_total = apple * apple_num
price = apple_total * tax
dprice = 1
dapple_total = dprice * tax # 1.1
dtax = dprice * apple_total # 200
dapple = dapple_total * apple_num # 2.2
dapple_num = dapple_total * apple # 110
print(price)
print(dapple, dapple_num, dtax)
예상 출력
220.00000000000003
2.2 110.00000000000001 200
스스로 점검
- 순전파와 역전파의 방향 차이를 설명할 수 있는가?
- 덧셈 노드와 곱셈 노드의 역전파 규칙을 말할 수 있는가?
- 역전파가 왜 순전파 때의 값을 저장해야 하는지 설명할 수 있는가?
이번 글에서 기억할 것
- 순전파는 입력에서 결과로 가는 계산이다.
- 역전파는 결과에서 입력 방향으로 미분값을 전달하는 계산이다.
- 덧셈 노드는 미분값을 그대로 흘린다.
- 곱셈 노드는 순전파 때의 입력값을 서로 바꿔 곱한다.
- `z=x*y` 같은 곱셈 노드는 x 방향으로는 y를, y 방향으로는 x를 곱해 미분값을 보낸다.
다음 글로 이어지는 질문
다음 글에서는 ReLU, Sigmoid, Affine 같은 실제 신경망 계층에서 역전파가 어떻게 구현되는지 본다.
한 줄 정리: 오차역전파법은 복잡한 미분을 한 번에 푸는 주문이 아니라, 작은 노드의 국소 미분을 연쇄법칙으로 이어붙이는 계산이다.