본문 바로가기
Machine-Learning/NLP (Natural Language Processing)

[NLP] 셀프 어텐션의 작동 원리

by AteN 2022. 9. 26.

셀프 어텐션의 작동 원리 

 

A dog ate the food because it was hungry.

 

이 문장에서 'it'은 'dog'나 'food'를 의미할 수 있다. 하짐나 문장을 자세히 살펴보면 'it'은 'food'가 아닌 'dog'를 의미한다는 것을 쉽게 알 수 있다. 위와 같은 문장이 주어질 경우 모델은 'it'이 'food'가 이닌 'dog'라는 것을 알 수 있을까? 이때 셀프 어텐션이 필요하다 

 

이 문장이 입력되었을 때, 모델은 가장 먼저 단어 'A'의 표현 (representation)을, 그 다음으로 단어 'dog'의 표현을 계산한 다음 'ate'라는 단어의 표현을 계산한다. 각각의 단어를 계산하는 동안 각 단어의 표현들은 문장 안에 있는 다른 모든 단어의 표현과 연결해 단어가 문장 내에서 갖는 의미를 이해한다. 

 

예를 들어 'it'이라는 단어의 표현을 계산하는 동안 모델에서는 'it'이라는 단어의 의미를 이해하기 위해 문장 안에 있는 모든 단어와 'it'이라는 단어를 연결하는 작업을 수행한다 

it 이라는 단어의 표현을 계산하기 위해 it을 문장의 모든 단어와 연결하는 작업을 보여준다. 이와 같은 연결 작업으로 모델은 it이 food가 아닌 dog와 관련이 있다는 것을 학습한다. dog를 잇는 선이 다른 단어보다 두껍게 표시되었다. 이는 주어진 문장 내에서 it이라는 단어가 food가 아닌 dog와 관련이 있다는 것을 보여준다. 

 

셀프 어텐션은 내부적으로 어떤 원리로 작동할까?

입력 문장이 I am good이라고 가정해보자. 이 문장을 기준으로 각 단어의 임베딩 (embedding)을 추출한다. 여기서 임베딩이란 각각의 단어를 표현하는 벡터 값을 의미하며, 임베딩 값은 ㄷ모델 학습 과정에서 같이 학습된다. 

\(x_1\)를 I, \(x_2\)는 am , \(x_3\)는 good에 대한 임베딩 값이라고 하자. 각각의 값을 표현하면 다음과 같다. 

$$ x_1 = [1.76, 2.22, ..., 6.66] $$

$$ x_2 = [7.77, 0.631, ..., 5.35] $$

$$ x_3 = [11.44, 10.10, ..., 3.33] $$

 

이제 입력 문장 'I am good'을 다음과 같이 입력 행렬 X(임베딩 행렬 또는 입력 임베딩)로 표현 할 수 있다.

$$\begin{align*}
I \\
am \\
good \\
\end{align*}\begin{bmatrix}
1.76 & 2.22 & ... & 6.66 \\ 
7.77 & 0.631 & ... & 5.35 \\
11.44 & 10.10 & ... & 3.33\\
\end{bmatrix}$$

 

행렬 X에서 첫 번째 행은 I의 임베딩, 두번째 행은 am의 임베딩, 세 번째 행은 good의 임베딩을 의미한다. 이때 행렬 X의 차원은 [문장 길이 x 임베딩 차원]의 형태가 된다. 위 문장에서 단어의 수 (문장 길이)는 3이고, 임베딩 차원은 512라고 가정하면 입력 행렬 (입력 임베딩)의 차원은 [3 x 512]이 된다. 

 

이제 입력 행렬 X로 부터 \(쿼리_query (Q)\) 행렬, \(키_key (K)\) 행렬, \(벨류_value (V)\) 행렬을 생성한다.

이 행렬은 무엇이고, 왜 필요할까? 이 세가지 행렬은 셀프 어텐션에서 사용된다. 이 행렬들을 어떻게 사용하는지 살펴보자 

 

우선 쿼리, 키, 밸류 행렬을 어떻게 만드는지 알아보자, 행렬을 생성하기 위해서는 \(W^Q\), \(W^K\), \(W^V\)라는 3개의 가중치 행렬 (weight matrix)을 생성한 다음 이 가중치 행렬을 입력 행렬 (X) 에 곱해 Q, K, V 행렬을 생성한다 

 

이때 가중치 행렬 \(W^Q\), \(W^K\), \(W^V\)은 처음에 임의의 값을 가지며, 학습 과정에서 최적값을 얻는다. 학습을 통해 최적의 가중치 행렬값이 생성되면 더욱 정확한 쿼리 값, 키 값, 밸ㄹ 값을 얻게 된다.

아래의 값들에서 볼 수 있듯이 입력 행렬 값에서 가중치 행렬 \(W^Q\), \(W^K\), \(W^V\)을 곱하면 쿼리값, 키 값, 밸류 값을 얻을 수 있다.

 

$$ 쿼리, 키, 밸류의 첫 번째 행인 q_1, k_1, v_1은 단어 'I'에 대한 쿼리, 키, 밸류 벡터를 의미한다 $$

$$ 쿼리, 키, 밸류의 첫 번째 행인 q_2, k_2, v_2은 단어 'am'에 대한 쿼리, 키, 밸류 벡터를 의미한다 $$

$$ 쿼리, 키, 밸류의 첫 번째 행인 q_3, k_3, v_3은 단어 'good'에 대한 쿼리, 키, 밸류 벡터를 의미한다 $$

 

이때 쿼리, 키, 밸류 벡터의 차원이 64라고 가정하면 쿼리, 키, 밸류 행렬의 차원은 [문장 길이 x 64]가 된다. 예제 문장의 세 가지 단어에 대한 쿼리, 키, 밸류 행렬은 차원  [3 x 64]가 된다.

'Machine-Learning > NLP (Natural Language Processing)' 카테고리의 다른 글

[NLP] Transformer-Overview  (0) 2022.10.11
[NLP] Seq2Seq  (0) 2022.10.04
[NLP] LSTM /GRU  (0) 2022.09.29
[NLP] RNN (Recurrent Netural Network)  (0) 2022.09.29
[NLP] NLP의 이해  (0) 2022.09.26

댓글