한빛미디어의 <밑바닥부터 시작하는 딥러닝 2>를 요약 정리한 글이다.
RNN 또는 LSTM 기반 언어 모델에서 문장을 생성할 때는 단어 ID 하나가 입력으로 들어가고, 모델은 다음 단어의 확률분포를 출력한다.
이 글은 문장 생성 과정에서 각 계층을 통과할 때 텐서 형상이 어떻게 변하는지 정리한 것이다.
전체 흐름
문장 생성의 한 시점 흐름은 다음과 같다.
word id
→ Embedding
→ LSTM
→ Affine
→ Softmax
→ word probability distribution
→ sampling
학습 시에는 여러 배치와 여러 시점을 한 번에 처리하지만, 문장 생성 시에는 보통 단어 ID 하나를 넣고 다음 단어 하나를 샘플링한다.
따라서 추론 시에는 배치 크기가 1인 경우가 많다.
1. 입력 임베딩
입력은 단어 ID이다.
Embedding 계층은 단어 ID를 임베딩 벡터로 변환한다.
미니배치 크기를 N, 임베딩 차원을 D라고 하면 한 시점의 입력 임베딩은 다음 형상을 가진다.
xt∈RN×D
여기서 각 기호의 의미는 다음과 같다.
- N: 배치 크기이다.
- D: 임베딩 차원이다.
문장 생성에서 단어 하나만 넣는다면 보통 N=1이다.
2. LSTM 내부 아핀 변환
LSTM 내부에서는 네 개의 게이트를 한 번의 큰 아핀 변환으로 계산한다.
A=xtWx+ht−1Wh+b
각 텐서의 형상은 다음과 같다.
xt∈RN×D
Wx∈RD×4H
ht−1∈RN×H
Wh∈RH×4H
b∈R4H
따라서 결과 A의 형상은 다음과 같다.
A∈RN×4H
4H가 되는 이유는 LSTM이 네 개의 값을 동시에 계산하기 때문이다.
- forget gate ft
- cell candidate gt
- input gate it
- output gate ot
즉, A는 열 방향으로 네 부분으로 나뉜다.
A[:, :H] → f
A[:, H:2H] → g
A[:, 2H:3H] → i
A[:, 3H:] → o
3. LSTM 셀 상태 업데이트
LSTM의 셀 상태는 다음 식으로 갱신된다.
ct=ft⊙ct−1+it⊙gt
은닉 상태는 다음과 같이 계산된다.
ht=ot⊙tanh(ct)
여기서 ⊙는 원소별 곱이다.
각 게이트와 셀 상태, 은닉 상태는 모두 같은 은닉 차원 H를 가진다.
ft,gt,it,ot,ct,ht∈RN×H
따라서 LSTM 한 시점의 출력인 은닉 상태는 다음 형상이다.
ht∈RN×H
4. 출력층
출력층에서는 은닉 상태 ht를 단어 사전 크기 V의 score로 변환한다.
st=htWs+bs
각 텐서의 형상은 다음과 같다.
ht∈RN×H
Ws∈RH×V
bs∈RV
따라서 출력 score의 형상은 다음과 같다.
st∈RN×V
여기서 V는 vocabulary size, 즉 단어 사전 크기이다.
각 행은 해당 배치 샘플에 대한 전체 단어 점수 벡터이다.
5. Softmax와 확률분포
softmax는 score를 확률분포로 변환한다.
pt=softmax(st)
확률분포의 형상은 score와 같다.
pt∈RN×V
각 배치마다 단어 사전 크기 V만큼의 확률분포가 만들어진다.
문장 생성에서는 이 확률분포에서 다음 단어 ID를 샘플링한다.
sampled_id = np.random.choice(len(p), p=p)
추론 시 형상
학습 시에는 보통 여러 문장과 여러 시점을 한 번에 처리한다.
하지만 문장 생성 시에는 현재 단어 하나만 넣고 다음 단어 하나를 예측하는 방식으로 진행된다.
따라서 입력 단어 ID는 다음처럼 만들어진다.
x = np.array(x).reshape(1, 1)
이때 형상은 다음과 같다.
x∈R1×1
여기서 첫 번째 1은 batch size이고, 두 번째 1은 time size이다.
모델의 predict()를 통과하면 다음 단어에 대한 score가 나온다.
보통 이 score는 다음처럼 단어 사전 크기 V에 해당하는 값을 포함한다.
s∈R1×1×V
또는 구현에 따라 다음처럼 볼 수도 있다.
s∈R1×V
이후 flatten()을 적용하면 배치 차원과 시간 차원이 제거된다.
p = softmax(score.flatten())
결과는 다음 형상이다.
p∈RV
즉, p의 각 인덱스는 단어 ID와 1:1로 대응된다.
따라서 np.random.choice로 선택된 인덱스는 곧 다음 입력으로 사용할 단어 ID이다.
정리
문장 생성에서 한 시점의 핵심 형상은 다음과 같다.
xt∈RN×D
A∈RN×4H
ht∈RN×H
st∈RN×V
pt∈RN×V
추론에서는 보통 N=1이고, softmax 결과를 flatten하여 길이 V의 확률분포로 만든 뒤 다음 단어 ID를 샘플링한다.