곱셈 노드의 역전파

학습원리 글 목록

곱셈 노드(Multiplication Node)는 계산 그래프에서 두 값을 곱하는 노드이다.

역전파에서는 출력 쪽에서 전달된 기울기에 대해, 각 입력의 반대쪽 값을 곱해서 기울기를 전달한다.

곱셈 노드

입력이 xx, yy이고 출력이 zz일 때, 곱셈 노드는 다음 연산을 수행한다.

z=xyz = x \cdot y

순전파에서는 입력 xx, yy를 받아 두 값을 곱한 zz를 출력한다.

(x,y)z(x, y) \mapsto z

순전파

순전파는 다음과 같다.

z=xyz = x \cdot y

예를 들어 x=2x=2, y=3y=3이면 출력은 다음과 같다.

z=23=6z = 2 \cdot 3 = 6

즉, 곱셈 노드는 입력 두 개를 받아 하나의 출력 값을 만든다.

역전파

손실 함수를 LL이라고 하자.

출력 zz 쪽에서 곱셈 노드로 전달되는 기울기는 다음과 같다.

Lz\frac{\partial L}{\partial z}

곱셈 노드의 역전파에서는 이 기울기를 각 입력 방향으로 다시 전달해야 한다.

먼저 xx에 대한 기울기는 다음과 같다.

Lx=Lzy\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \cdot y

yy에 대한 기울기는 다음과 같다.

Ly=Lzx\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z} \cdot x

즉, xx 방향으로는 yy가 곱해지고, yy 방향으로는 xx가 곱해진다.

상류 Gradient와 로컬 Gradient

역전파는 기본적으로 다음 형태로 이해할 수 있다.

하류 gradient=상류 gradient×로컬 gradient\text{하류 gradient} = \text{상류 gradient} \times \text{로컬 gradient}

여기서 상류 gradient는 뒤쪽 노드에서 현재 노드로 전달된 기울기이다.

곱셈 노드에서는 출력 zz 쪽에서 전달된 다음 값이 상류 gradient이다.

Lz\frac{\partial L}{\partial z}

로컬 gradient는 현재 노드의 출력이 각 입력에 대해 얼마나 변하는지를 나타내는 값이다.

곱셈 노드의 출력은 다음과 같다.

z=xyz = xy

따라서 각 입력에 대한 로컬 gradient는 다음과 같다.

zx=y\frac{\partial z}{\partial x} = y zy=x\frac{\partial z}{\partial y} = x

결국 xx 방향으로 전달되는 gradient는 다음과 같이 계산된다.

Lx=Lz상류 gradient×zx로컬 gradient=Lzy\frac{\partial L}{\partial x} = \underbrace{\frac{\partial L}{\partial z}}_{\text{상류 gradient}} \times \underbrace{\frac{\partial z}{\partial x}}_{\text{로컬 gradient}} = \frac{\partial L}{\partial z} \cdot y

yy 방향도 마찬가지이다.

Ly=Lz상류 gradient×zy로컬 gradient=Lzx\frac{\partial L}{\partial y} = \underbrace{\frac{\partial L}{\partial z}}_{\text{상류 gradient}} \times \underbrace{\frac{\partial z}{\partial y}}_{\text{로컬 gradient}} = \frac{\partial L}{\partial z} \cdot x

즉, 역전파는 상류에서 흘러온 gradient에 현재 노드의 로컬 gradient를 곱해 앞쪽 노드로 전달하는 과정이다.

왜 반대쪽 입력이 곱해지는가

곱셈 노드의 출력은 다음과 같다.

z=xyz = xy

xx에 대해 미분하면 다음과 같다.

zx=y\frac{\partial z}{\partial x} = y

yy에 대해 미분하면 다음과 같다.

zy=x\frac{\partial z}{\partial y} = x

역전파에서는 연쇄 법칙을 사용한다.

따라서 xx에 대한 손실의 기울기는 다음과 같다.

Lx=Lzzx\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \frac{\partial z}{\partial x}

여기서 zx=y\frac{\partial z}{\partial x}=y이므로 다음과 같이 된다.

Lx=Lzy\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \cdot y

마찬가지로 yy에 대해서는 다음과 같다.

Ly=Lzzy=Lzx\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z} \frac{\partial z}{\partial y} = \frac{\partial L}{\partial z} \cdot x

결국 곱셈 노드는 역전파 시 출력 쪽 기울기에 반대쪽 입력 값을 곱해 전달한다.

행렬 기반 표현

딥러닝에서는 스칼라 하나끼리 곱하는 경우보다 벡터나 행렬 단위로 계산하는 경우가 많다.

먼저 벡터 입력의 곱셈을 생각해보자.

벡터끼리 곱한다고 할 때는 보통 같은 위치의 원소끼리 곱하는 원소별 곱(element-wise product)을 의미한다.

z=xy\mathbf{z} = \mathbf{x} \odot \mathbf{y}

여기서 \odot는 원소별 곱을 의미한다.

벡터 x\mathbf{x}, y\mathbf{y}가 다음과 같다고 하자.

x=[x1x2xn],y=[y1y2yn]\mathbf{x} = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{bmatrix}, \qquad \mathbf{y} = \begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_n \end{bmatrix}

그러면 출력 z\mathbf{z}는 다음과 같다.

z=[z1z2zn]=[x1y1x2y2xnyn]\mathbf{z} = \begin{bmatrix} z_1 \\ z_2 \\ \vdots \\ z_n \end{bmatrix} = \begin{bmatrix} x_1y_1 \\ x_2y_2 \\ \vdots \\ x_ny_n \end{bmatrix}

즉, 각 원소는 다음과 같이 계산된다.

zi=xiyiz_i = x_i y_i

스칼라 곱셈 z=xyz=xy가 각 위치마다 독립적으로 반복된다고 보면 된다.

원소별 곱의 야코비안

벡터 출력 z\mathbf{z}를 벡터 입력 x\mathbf{x}에 대해 미분하면 야코비안 행렬이 나온다.

Jz,x=zxJ_{\mathbf{z}, \mathbf{x}} = \frac{\partial \mathbf{z}}{\partial \mathbf{x}}

이 행렬의 (i,j)(i, j)번째 원소는 다음 값이다.

zixj\frac{\partial z_i}{\partial x_j}

그런데 zi=xiyiz_i = x_i y_i이므로 ziz_ixix_i에만 의존한다.

xjx_jxix_i와 다른 원소라면 ziz_i에는 영향을 주지 않는다.

따라서 다음이 성립한다.

zixj={yi,i=j0,ij\frac{\partial z_i}{\partial x_j} = \begin{cases} y_i, & i = j \\ 0, & i \neq j \end{cases}

즉, x\mathbf{x}에 대한 야코비안은 대각행렬이다.

zx=[y1000y2000yn]=diag(y)\frac{\partial \mathbf{z}}{\partial \mathbf{x}} = \begin{bmatrix} y_1 & 0 & \cdots & 0 \\ 0 & y_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & y_n \end{bmatrix} = \operatorname{diag}(\mathbf{y})

마찬가지로 y\mathbf{y}에 대한 야코비안은 다음과 같다.

ziyj={xi,i=j0,ij\frac{\partial z_i}{\partial y_j} = \begin{cases} x_i, & i = j \\ 0, & i \neq j \end{cases}

따라서,

zy=[x1000x2000xn]=diag(x)\frac{\partial \mathbf{z}}{\partial \mathbf{y}} = \begin{bmatrix} x_1 & 0 & \cdots & 0 \\ 0 & x_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & x_n \end{bmatrix} = \operatorname{diag}(\mathbf{x})

이다.

원소별 곱의 역전파

손실 함수가 z\mathbf{z}를 통해 계산된다고 하자.

L=L(z)L = L(\mathbf{z})

출력 쪽에서 내려온 상류 gradient를 다음과 같이 둔다.

Lz=[Lz1Lz2Lzn]\frac{\partial L}{\partial \mathbf{z}} = \begin{bmatrix} \frac{\partial L}{\partial z_1} \\ \frac{\partial L}{\partial z_2} \\ \vdots \\ \frac{\partial L}{\partial z_n} \end{bmatrix}

다변수 연쇄 법칙에 의해 x\mathbf{x} 방향의 gradient는 다음과 같다.

Lx=(zx)TLz\frac{\partial L}{\partial \mathbf{x}} = \left( \frac{\partial \mathbf{z}}{\partial \mathbf{x}} \right)^T \frac{\partial L}{\partial \mathbf{z}}

여기에 앞에서 구한 야코비안을 대입하면 다음과 같다.

Lx=diag(y)TLz\frac{\partial L}{\partial \mathbf{x}} = \operatorname{diag}(\mathbf{y})^T \frac{\partial L}{\partial \mathbf{z}}

대각행렬은 전치해도 같으므로,

Lx=diag(y)Lz\frac{\partial L}{\partial \mathbf{x}} = \operatorname{diag}(\mathbf{y}) \frac{\partial L}{\partial \mathbf{z}}

이다.

원소별로 보면 다음과 같다.

Lxi=Lziyi\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial z_i} y_i

결국 벡터 형태로는 다음처럼 쓸 수 있다.

Lx=Lzy\frac{\partial L}{\partial \mathbf{x}} = \frac{\partial L}{\partial \mathbf{z}} \odot \mathbf{y}

y\mathbf{y} 방향도 같은 방식이다.

Ly=(zy)TLz\frac{\partial L}{\partial \mathbf{y}} = \left( \frac{\partial \mathbf{z}}{\partial \mathbf{y}} \right)^T \frac{\partial L}{\partial \mathbf{z}} Ly=diag(x)Lz\frac{\partial L}{\partial \mathbf{y}} = \operatorname{diag}(\mathbf{x}) \frac{\partial L}{\partial \mathbf{z}}

따라서 벡터 형태로는 다음과 같다.

Ly=Lzx\frac{\partial L}{\partial \mathbf{y}} = \frac{\partial L}{\partial \mathbf{z}} \odot \mathbf{x}

즉, 스칼라에서

dx=dzy,dy=dzxdx = dz \cdot y, \qquad dy = dz \cdot x

였던 것이 벡터에서는 원소별로 확장된다.

dx=dzyd\mathbf{x} = d\mathbf{z} \odot \mathbf{y} dy=dzxd\mathbf{y} = d\mathbf{z} \odot \mathbf{x}

곱셈 노드의 핵심은 그대로이다.

상류 gradient에 현재 노드의 로컬 gradient를 곱해 앞쪽으로 전달한다.

다만 벡터에서는 이 계산이 각 원소마다 독립적으로 일어나기 때문에, 결과가 원소별 곱 형태로 나타난다.

행렬 곱셈의 경우

원소별 곱과 달리, 일반적인 행렬 곱셈에서는 한 출력 원소가 여러 입력 원소의 합으로 만들어진다.

그래서 역전파 식도 단순한 원소별 곱이 아니라 전치 행렬을 사용한 형태로 나타난다.

예를 들어 다음과 같은 선형 변환을 생각할 수 있다.

z=Wx\mathbf{z} = \mathbf{W}\mathbf{x}

이때 출력 쪽에서 전달되는 기울기를 다음과 같이 두자.

Lz\frac{\partial L}{\partial \mathbf{z}}

그러면 입력 벡터 x\mathbf{x}에 대한 기울기는 다음과 같다.

Lx=WLz\frac{\partial L}{\partial \mathbf{x}} = \mathbf{W}^{\top} \frac{\partial L}{\partial \mathbf{z}}

가중치 행렬 W\mathbf{W}에 대한 기울기는 다음과 같다.

LW=Lzx\frac{\partial L}{\partial \mathbf{W}} = \frac{\partial L}{\partial \mathbf{z}} \mathbf{x}^{\top}

행렬 곱셈에서도 핵심은 같다.

역전파에서는 출력의 기울기를 각 입력 방향으로 전달하되, 해당 입력이 출력에 미친 영향을 곱해서 전달한다.

요약

곱셈 노드는 순전파에서 두 입력을 곱한다.

z=xyz = x \cdot y

역전파에서는 출력 쪽 기울기에 반대쪽 입력 값을 곱해 전달한다.

Lx=Lzy\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \cdot y Ly=Lzx\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z} \cdot x

즉, 곱셈 노드의 역전파는 “서로의 값을 바꿔 곱해서 전달한다”고 이해할 수 있다.

#딥러닝 #딥러닝/기초/역전파