행렬 곱 역전파에서 전치행렬이 등장하는 이유

학습원리 글 목록

행렬 곱 노드의 순전파가 다음과 같다고 하자.

z=Wx\mathbf{z} = \mathbf{W}\mathbf{x}

이때 역전파에서는 다음 식이 등장한다.

Lx=WTLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}}

처음 보면 왜 W\mathbf{W}가 아니라 WT\mathbf{W}^{T}가 곱해지는지 헷갈릴 수 있다.

핵심은 역전파가 출력 쪽에서 받은 gradient를 입력 쪽으로 전달하는 과정이고, 이때 다변수 연쇄 법칙을 적용하면 야코비안의 전치가 곱해진다는 점이다.

차원 설정

먼저 차원을 다음과 같이 둔다.

WRm×n\mathbf{W} \in \mathbb{R}^{m \times n} xRn\mathbf{x} \in \mathbb{R}^{n} zRm\mathbf{z} \in \mathbb{R}^{m}

그러면 순전파는 다음과 같이 성립한다.

z=Wx\mathbf{z} = \mathbf{W}\mathbf{x}

손실 함수 LL은 최종적으로 z\mathbf{z}에 의존한다고 하자.

L=L(z)L = L(\mathbf{z})

이 글에서는 gradient를 열벡터로 두는 표기를 사용한다.

따라서 출력 쪽에서 전달되는 상류 gradient는 다음 차원을 가진다.

LzRm\frac{\partial L}{\partial \mathbf{z}} \in \mathbb{R}^{m}

입력 x\mathbf{x}에 대한 gradient는 다음 차원이어야 한다.

LxRn\frac{\partial L}{\partial \mathbf{x}} \in \mathbb{R}^{n}

그런데 WT\mathbf{W}^{T}의 차원은 다음과 같다.

WTRn×m\mathbf{W}^{T} \in \mathbb{R}^{n \times m}

따라서 다음 곱은 차원상 자연스럽다.

WTLzRn\mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}} \in \mathbb{R}^{n}

즉, 결과가 정확히 x\mathbf{x}와 같은 차원의 gradient가 된다.

Differential 관점

스칼라 함수 LL에 대해 differential은 다음과 같이 쓸 수 있다.

dL=(Lx)TdxdL = \left( \frac{\partial L}{\partial \mathbf{x}} \right)^T d\mathbf{x}

또는 z\mathbf{z}를 기준으로는 다음과 같이 쓸 수 있다.

dL=(Lz)TdzdL = \left( \frac{\partial L}{\partial \mathbf{z}} \right)^T d\mathbf{z}

이제 순전파 식이 다음과 같으므로,

z=Wx\mathbf{z} = \mathbf{W}\mathbf{x}

x\mathbf{x}가 아주 조금 변할 때 z\mathbf{z}의 변화량은 다음과 같다.

dz=Wdxd\mathbf{z} = \mathbf{W}d\mathbf{x}

이를 dLdL 식에 대입한다.

dL=(Lz)TdzdL = \left( \frac{\partial L}{\partial \mathbf{z}} \right)^T d\mathbf{z} dL=(Lz)TWdxdL = \left( \frac{\partial L}{\partial \mathbf{z}} \right)^T \mathbf{W}d\mathbf{x}

이 식은 아직 z\mathbf{z} 쪽 gradient를 기준으로 쓰여 있다.

하지만 우리가 구하고 싶은 것은 x\mathbf{x}에 대한 gradient이다.

그래서 이 식을 dxd\mathbf{x}에 대한 선형형식으로 다시 정리해야 한다.

전치가 등장하는 지점

현재 식은 다음과 같다.

dL=(Lz)TWdxdL = \left( \frac{\partial L}{\partial \mathbf{z}} \right)^T \mathbf{W}d\mathbf{x}

여기서 전체 값은 스칼라이다.

스칼라는 전치해도 값이 변하지 않는다.

따라서 다음처럼 볼 수 있다.

(Lz)TWdx=(WTLz)Tdx\left( \frac{\partial L}{\partial \mathbf{z}} \right)^T \mathbf{W}d\mathbf{x} = \left( \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}} \right)^T d\mathbf{x}

이 변형은 전치 공식에서 나온다.

(ABC)T=CTBTAT(ABC)^T = C^T B^T A^T

즉,

[(Lz)TWdx]T=(dx)TWTLz\left[ \left( \frac{\partial L}{\partial \mathbf{z}} \right)^T \mathbf{W}d\mathbf{x} \right]^T = (d\mathbf{x})^T \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}}

이고, 벡터 내적은 스칼라이므로 다음처럼 다시 쓸 수 있다.

(dx)TWTLz=(WTLz)Tdx(d\mathbf{x})^T \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}} = \left( \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}} \right)^T d\mathbf{x}

결국 dLdL은 다음과 같이 정리된다.

dL=(WTLz)TdxdL = \left( \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}} \right)^T d\mathbf{x}

한편 gradient의 정의에 의해 x\mathbf{x} 기준 differential은 다음과 같다.

dL=(Lx)TdxdL = \left( \frac{\partial L}{\partial \mathbf{x}} \right)^T d\mathbf{x}

두 식은 모든 dxd\mathbf{x}에 대해 같은 값을 가져야 한다.

따라서 dxd\mathbf{x} 앞에 붙은 계수 벡터가 같아야 한다.

Lx=WTLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}}

이것이 행렬 곱 역전파에서 WT\mathbf{W}^{T}가 등장하는 이유이다.

단순히 dxd\mathbf{x}를 약분해서 얻은 것이 아니라, 모든 dxd\mathbf{x}에 대해 두 선형형식이 같아야 하므로 계수 벡터를 비교한 것이다.

Chain Rule 관점

같은 내용을 다변수 연쇄 법칙으로도 볼 수 있다.

열벡터 gradient 표기에서는 다음 형태를 사용한다.

Lx=(zx)TLz\frac{\partial L}{\partial \mathbf{x}} = \left( \frac{\partial \mathbf{z}}{\partial \mathbf{x}} \right)^T \frac{\partial L}{\partial \mathbf{z}}

여기서 zx\frac{\partial \mathbf{z}}{\partial \mathbf{x}}z\mathbf{z}x\mathbf{x}로 미분한 야코비안이다.

그런데

z=Wx\mathbf{z} = \mathbf{W}\mathbf{x}

이므로,

zx=W\frac{\partial \mathbf{z}}{\partial \mathbf{x}} = \mathbf{W}

이다.

따라서 다변수 연쇄 법칙에 의해 다음이 된다.

Lx=WTLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}}

즉, WT\mathbf{W}^{T}는 임의로 붙는 것이 아니라 야코비안 W\mathbf{W}의 전치가 곱해진 결과이다.

원소 단위로 보기

조금 더 직접적으로 보기 위해 2차원 예를 생각해보자.

W=[w11w12w21w22],x=[x1x2]\mathbf{W} = \begin{bmatrix} w_{11} & w_{12} \\ w_{21} & w_{22} \end{bmatrix}, \qquad \mathbf{x} = \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}

그러면 출력은 다음과 같다.

z=[z1z2]=[w11x1+w12x2w21x1+w22x2]\mathbf{z} = \begin{bmatrix} z_1 \\ z_2 \end{bmatrix} = \begin{bmatrix} w_{11}x_1 + w_{12}x_2 \\ w_{21}x_1 + w_{22}x_2 \end{bmatrix}

출력 쪽 gradient를 다음처럼 두자.

Lz=[g1g2]\frac{\partial L}{\partial \mathbf{z}} = \begin{bmatrix} g_1 \\ g_2 \end{bmatrix}

그러면 x1x_1에 대한 gradient는 z1z_1, z2z_2를 거쳐 들어오는 영향을 모두 더해야 한다.

Lx1=Lz1z1x1+Lz2z2x1\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial z_1} \frac{\partial z_1}{\partial x_1} + \frac{\partial L}{\partial z_2} \frac{\partial z_2}{\partial x_1} =g1w11+g2w21= g_1w_{11} + g_2w_{21}

x2x_2에 대해서도 마찬가지이다.

Lx2=g1w12+g2w22\frac{\partial L}{\partial x_2} = g_1w_{12} + g_2w_{22}

따라서 입력 쪽 gradient는 다음과 같다.

Lx=[g1w11+g2w21g1w12+g2w22]\frac{\partial L}{\partial \mathbf{x}} = \begin{bmatrix} g_1w_{11} + g_2w_{21} \\ g_1w_{12} + g_2w_{22} \end{bmatrix}

이 식은 다음 행렬 곱과 같다.

[w11w21w12w22][g1g2]\begin{bmatrix} w_{11} & w_{21} \\ w_{12} & w_{22} \end{bmatrix} \begin{bmatrix} g_1 \\ g_2 \end{bmatrix}

앞의 행렬은 원래 W\mathbf{W}가 아니라 WT\mathbf{W}^{T}이다.

WT=[w11w21w12w22]\mathbf{W}^{T} = \begin{bmatrix} w_{11} & w_{21} \\ w_{12} & w_{22} \end{bmatrix}

그래서 원소 단위로 계산해도 다음과 같은 결론이 나온다.

Lx=WTLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}}

직관

순전파에서는 입력 x\mathbf{x}가 행렬 W\mathbf{W}를 통과해 출력 z\mathbf{z}가 된다.

xz=Wx\mathbf{x} \longrightarrow \mathbf{z} = \mathbf{W}\mathbf{x}

역전파에서는 출력 쪽 gradient가 입력 쪽으로 되돌아간다.

이때 각 입력 원소는 여러 출력 원소에 영향을 주었으므로, 그 영향을 다시 모아서 받아야 한다.

이 모으는 과정이 WT\mathbf{W}^{T}를 곱하는 형태로 표현된다.

즉, WT\mathbf{W}^{T}는 단순히 방향이 반대라서 붙는 기호가 아니다.

순전파에서 W\mathbf{W}가 입력을 출력으로 보낸 방식과 정확히 대응되도록, 출력 쪽 gradient를 입력 쪽 gradient로 다시 모아주는 연산이다.

forward에서 W\mathbf{W}는 입력 gradient가 아니라 입력값을 출력 방향으로 “퍼뜨리는” 행렬이고, backward에서 WT\mathbf{W}^{T}는 출력 쪽 gradient를 각 입력이 보냈던 경로 기준으로 다시 “수거해서 합산하는” 행렬이다.

가중치에 대한 Gradient

같은 노드에서 가중치 행렬 W\mathbf{W}에 대한 gradient도 구할 수 있다.

순전파가 다음과 같을 때,

z=Wx\mathbf{z} = \mathbf{W}\mathbf{x}

출력 쪽 gradient를 Lz\frac{\partial L}{\partial \mathbf{z}}라고 하면,

LW=LzxT\frac{\partial L}{\partial \mathbf{W}} = \frac{\partial L}{\partial \mathbf{z}} \mathbf{x}^{T}

이다.

차원을 보면 더 자연스럽다.

LzRm\frac{\partial L}{\partial \mathbf{z}} \in \mathbb{R}^{m} xTR1×n\mathbf{x}^{T} \in \mathbb{R}^{1 \times n}

따라서

LzxTRm×n\frac{\partial L}{\partial \mathbf{z}} \mathbf{x}^{T} \in \mathbb{R}^{m \times n}

이고, 이는 W\mathbf{W}와 같은 차원이다.

즉, 선형 계층의 기본 역전파는 다음 두 식으로 정리된다.

Lx=WTLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}} LW=LzxT\frac{\partial L}{\partial \mathbf{W}} = \frac{\partial L}{\partial \mathbf{z}} \mathbf{x}^{T}

정리

행렬 곱 노드가 다음과 같을 때,

z=Wx\mathbf{z} = \mathbf{W}\mathbf{x}

입력 x\mathbf{x}에 대한 역전파는 다음과 같다.

Lx=WTLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}}

핵심 흐름은 다음이다.

dL=(Lz)TdzdL = \left( \frac{\partial L}{\partial \mathbf{z}} \right)^T d\mathbf{z} dz=Wdxd\mathbf{z} = \mathbf{W}d\mathbf{x}

따라서,

dL=(Lz)TWdxdL = \left( \frac{\partial L}{\partial \mathbf{z}} \right)^T \mathbf{W}d\mathbf{x}

이를 dxd\mathbf{x}에 대한 선형형식으로 정리하면,

dL=(WTLz)TdxdL = \left( \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}} \right)^T d\mathbf{x}

한편 gradient 정의상,

dL=(Lx)TdxdL = \left( \frac{\partial L}{\partial \mathbf{x}} \right)^T d\mathbf{x}

이므로,

Lx=WTLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{T} \frac{\partial L}{\partial \mathbf{z}}

이다.

즉, 행렬 곱 역전파에서 전치행렬이 등장하는 이유는 출력 쪽 gradient를 입력 쪽 gradient로 되돌릴 때 야코비안의 전치가 곱해지기 때문이다.
이 과정이 수식으로는 야코비안의 전치 WT\mathbf{W}^{T}를 곱하는 형태로 나타난다.