행렬 곱 역전파에서 전치행렬이 등장하는 이유
학습원리 글 목록 행렬 곱 노드의 순전파가 다음과 같다고 하자.
z=Wx
이때 역전파에서는 다음 식이 등장한다.
∂x∂L=WT∂z∂L
처음 보면 왜 W가 아니라 WT가 곱해지는지 헷갈릴 수 있다.
핵심은 역전파가 출력 쪽에서 받은 gradient를 입력 쪽으로 전달하는 과정이고, 이때 다변수 연쇄 법칙을 적용하면 야코비안의 전치가 곱해진다는 점이다.
차원 설정
먼저 차원을 다음과 같이 둔다.
W∈Rm×n
x∈Rn
z∈Rm
그러면 순전파는 다음과 같이 성립한다.
z=Wx
손실 함수 L은 최종적으로 z에 의존한다고 하자.
L=L(z)
이 글에서는 gradient를 열벡터로 두는 표기를 사용한다.
따라서 출력 쪽에서 전달되는 상류 gradient는 다음 차원을 가진다.
∂z∂L∈Rm
입력 x에 대한 gradient는 다음 차원이어야 한다.
∂x∂L∈Rn
그런데 WT의 차원은 다음과 같다.
WT∈Rn×m
따라서 다음 곱은 차원상 자연스럽다.
WT∂z∂L∈Rn
즉, 결과가 정확히 x와 같은 차원의 gradient가 된다.
Differential 관점
스칼라 함수 L에 대해 differential은 다음과 같이 쓸 수 있다.
dL=(∂x∂L)Tdx
또는 z를 기준으로는 다음과 같이 쓸 수 있다.
dL=(∂z∂L)Tdz
이제 순전파 식이 다음과 같으므로,
z=Wx
x가 아주 조금 변할 때 z의 변화량은 다음과 같다.
dz=Wdx
이를 dL 식에 대입한다.
dL=(∂z∂L)Tdz
dL=(∂z∂L)TWdx
이 식은 아직 z 쪽 gradient를 기준으로 쓰여 있다.
하지만 우리가 구하고 싶은 것은 x에 대한 gradient이다.
그래서 이 식을 dx에 대한 선형형식으로 다시 정리해야 한다.
전치가 등장하는 지점
현재 식은 다음과 같다.
dL=(∂z∂L)TWdx
여기서 전체 값은 스칼라이다.
스칼라는 전치해도 값이 변하지 않는다.
따라서 다음처럼 볼 수 있다.
(∂z∂L)TWdx=(WT∂z∂L)Tdx
이 변형은 전치 공식에서 나온다.
(ABC)T=CTBTAT
즉,
[(∂z∂L)TWdx]T=(dx)TWT∂z∂L
이고, 벡터 내적은 스칼라이므로 다음처럼 다시 쓸 수 있다.
(dx)TWT∂z∂L=(WT∂z∂L)Tdx
결국 dL은 다음과 같이 정리된다.
dL=(WT∂z∂L)Tdx
한편 gradient의 정의에 의해 x 기준 differential은 다음과 같다.
dL=(∂x∂L)Tdx
두 식은 모든 dx에 대해 같은 값을 가져야 한다.
따라서 dx 앞에 붙은 계수 벡터가 같아야 한다.
∂x∂L=WT∂z∂L
이것이 행렬 곱 역전파에서 WT가 등장하는 이유이다.
단순히 dx를 약분해서 얻은 것이 아니라, 모든 dx에 대해 두 선형형식이 같아야 하므로 계수 벡터를 비교한 것이다.
Chain Rule 관점
같은 내용을 다변수 연쇄 법칙으로도 볼 수 있다.
열벡터 gradient 표기에서는 다음 형태를 사용한다.
∂x∂L=(∂x∂z)T∂z∂L
여기서 ∂x∂z는 z를 x로 미분한 야코비안이다.
그런데
z=Wx
이므로,
∂x∂z=W
이다.
따라서 다변수 연쇄 법칙에 의해 다음이 된다.
∂x∂L=WT∂z∂L
즉, WT는 임의로 붙는 것이 아니라 야코비안 W의 전치가 곱해진 결과이다.
원소 단위로 보기
조금 더 직접적으로 보기 위해 2차원 예를 생각해보자.
W=[w11w21w12w22],x=[x1x2]
그러면 출력은 다음과 같다.
z=[z1z2]=[w11x1+w12x2w21x1+w22x2]
출력 쪽 gradient를 다음처럼 두자.
∂z∂L=[g1g2]
그러면 x1에 대한 gradient는 z1, z2를 거쳐 들어오는 영향을 모두 더해야 한다.
∂x1∂L=∂z1∂L∂x1∂z1+∂z2∂L∂x1∂z2
=g1w11+g2w21
x2에 대해서도 마찬가지이다.
∂x2∂L=g1w12+g2w22
따라서 입력 쪽 gradient는 다음과 같다.
∂x∂L=[g1w11+g2w21g1w12+g2w22]
이 식은 다음 행렬 곱과 같다.
[w11w12w21w22][g1g2]
앞의 행렬은 원래 W가 아니라 WT이다.
WT=[w11w12w21w22]
그래서 원소 단위로 계산해도 다음과 같은 결론이 나온다.
∂x∂L=WT∂z∂L
직관
순전파에서는 입력 x가 행렬 W를 통과해 출력 z가 된다.
x⟶z=Wx
역전파에서는 출력 쪽 gradient가 입력 쪽으로 되돌아간다.
이때 각 입력 원소는 여러 출력 원소에 영향을 주었으므로, 그 영향을 다시 모아서 받아야 한다.
이 모으는 과정이 WT를 곱하는 형태로 표현된다.
즉, WT는 단순히 방향이 반대라서 붙는 기호가 아니다.
순전파에서 W가 입력을 출력으로 보낸 방식과 정확히 대응되도록, 출력 쪽 gradient를 입력 쪽 gradient로 다시 모아주는 연산이다.
forward에서 W는 입력 gradient가 아니라 입력값을 출력 방향으로 “퍼뜨리는” 행렬이고, backward에서 WT는 출력 쪽 gradient를 각 입력이 보냈던 경로 기준으로 다시 “수거해서 합산하는” 행렬이다.
가중치에 대한 Gradient
같은 노드에서 가중치 행렬 W에 대한 gradient도 구할 수 있다.
순전파가 다음과 같을 때,
z=Wx
출력 쪽 gradient를 ∂z∂L라고 하면,
∂W∂L=∂z∂LxT
이다.
차원을 보면 더 자연스럽다.
∂z∂L∈Rm
xT∈R1×n
따라서
∂z∂LxT∈Rm×n
이고, 이는 W와 같은 차원이다.
즉, 선형 계층의 기본 역전파는 다음 두 식으로 정리된다.
∂x∂L=WT∂z∂L
∂W∂L=∂z∂LxT
정리
행렬 곱 노드가 다음과 같을 때,
z=Wx
입력 x에 대한 역전파는 다음과 같다.
∂x∂L=WT∂z∂L
핵심 흐름은 다음이다.
dL=(∂z∂L)Tdz
dz=Wdx
따라서,
dL=(∂z∂L)TWdx
이를 dx에 대한 선형형식으로 정리하면,
dL=(WT∂z∂L)Tdx
한편 gradient 정의상,
dL=(∂x∂L)Tdx
이므로,
∂x∂L=WT∂z∂L
이다.
즉, 행렬 곱 역전파에서 전치행렬이 등장하는 이유는 출력 쪽 gradient를 입력 쪽 gradient로 되돌릴 때 야코비안의 전치가 곱해지기 때문이다.
이 과정이 수식으로는 야코비안의 전치 WT를 곱하는 형태로 나타난다.