본문 바로가기
Machine-Learning/NLP (Natural Language Processing)

[NLP] 멀티 헤드 어텐션(multi-head attention) 원리

by AteN 2022. 11. 6.

멀티 헤드 어텐션 원리 

 

어텐션을 사용할 때 헤드 한 개만 사용한 형태가 아닌 헤드 여러개를 사용한 어텐션 구조도 사용할 수 있다. 

어텐션 행렬이 아닌 다중 어텐션 행렬을 계산해보자. 

 

All is well이라는 문장으로 예를 들어보자. 단어 well의 셀프 어텐션을 계산한다고 하자. 계산 결과 다음과 같은 결과를 얻었다고 가정하자

 

위의 그림에서 볼 수 있듯이 단어 'well' 의 셀프 벡터 값은 가중치를 적용한 각 단어의 벡터 값의 합임을 알 수 있다. 

단어 'well'의 벡터값은 단어 'All' 이 가장 우세하게 작용하는 것을 알 수 있다. 즉, 단어 'well'의 벡터값은 단어 'All'의 벡터값에 0.6을 곱한 값과, 단어 'well'의 벡터값에 0.4를 곱한 결과를 합한 것이다. 이는 단어 'All'이 벡턱밧이 60% 반영되고, 단어 'well'의 벡터값이 40% 반영된 것으로 해석할 수 있다. 따라서 well의 벡턱밧은 단어 All의 영향을 가장 크다고 볼 수 있다 

 

하지만 문장 내에서 단어의 의미가 모호한 경우 역시 발생할 수 있다. 다음 문장을 예를 들어보자 

$$ A \; dog \; ate \; food \; because \; it \; was \; hungry. $$

 

단어 it의 벡터값은 단어 dog에 대한 벡터값으로만 구성된다. 즉, it의 벡터값은 단어 dog가 가장 우세하게 적용한다. 이때 단어 it의 의미는 dog 또는 food가 될 수 있는데 위의 결과는 단어의 의미가 잘 연결된 경우이다. 

 

따라서 문장 내에서 모호한 의미를 가진 단어가 있을 경우에 앞의 예와 같이 적잘한 의미를 가진 단어의 벡터값이 잘 할당되었을 경우에는 문장의 의미를 이해하는데 좋은 영향을 줄 수 있다. 하지만 반대의 경우, 즉 의미가 맞지 않는 단어의 벡터값이 높을 경우에는 문장의 의미를 잘 못 해석될 수 있다. 

 

그래서 어테션 결과의 정확도를 높이기 위해서 단일 헤드 어텐션 행렬 (single head attention) 이 아닌 멀티 헤드 어테션(multi-head attention)을 사용한 후 그 결과값을 더하는 형태로 진행한다. 이와 같은 방법을 사용하는 데는 단일 헤드 어텐션을 사용하는 것보다 멀티헤드 어텐션을 사용하면 좀더 정확하게 의미를 이해할 수 있다는 가정이 있다. 

 

멀티헤드 어텐션

 멀티헤드 어텐션은 2개의 행렬 \(Z_1\), \(Z_2\) 를 계산한다고 하자.

 

먼저  \(Z_1\) 값을 구한다 

 

쿼리  \((Q_1)\), 키 \((K_1)\), 밸류 \((V_1)\) 행렬을 생성한다. 그 다음으로 3개의 가중치 행렬 \(W^Q_1\), \(W^K_1\), \(W^V_1\) 을 생성하고, 마지막으로 입력 행렬 (X)에  가중치 행렬 \(W^Q_1\), \(W^K_1\), \(W^V_1\)을 각각 곱해 쿼리  \((Q_1)\), 키 \((K_1)\), 밸류 \((V_1)\) 행렬을 생성한다 

 

이때 어텐션  행렬 \(Z_1\) 은 다음과 같이 계산된다. 

$$ Z_1 = softmax\left ( \frac{Q_1K_1^T}{\sqrt{d_k}} \right )V_1 $$

 

두번째 어텐션 행렬  \(Z_2\)는 

쿼리  \((Q_2)\), 키 \((K_2)\), 밸류 \((V_2)\) 행렬을 추가로 생성한다. 그 다음으로 3개의 가중치 행렬 \(W^Q_2\), \(W^K_2\), \(W^V_2\) 을 생성하고, 마지막으로 입력 행렬 (X)에  가중치 행렬 \(W^Q_1\), \(W^K_1\), \(W^V_1\)을 각각 곱해 쿼리  \((Q_2)\), 키 \((K_2)\), 밸류 \((V_2)\) 행렬을 생성한다 

$$ Z_2 = softmax\left ( \frac{Q_2K_2^T}{\sqrt{d_k}} \right )V_2 $$

 

마찬가지로 h개의 어텐션 행렬을 구할 수 있다. 8개의 어텐션 행렬 \(Z_1\) ~ \(Z_8\)을 구한다고 하면 해당 행렬을 계산한 후에 그 결과 (어텐션 헤드)를 연결 (concatenate)한 후 새로운 가중치 행렬  \(W^Q\)를 곱하면 최종적으로 우리가 원하는 행렬곱을 구할 수 있다.

$$ multi-head \; attention = concatenate(Z_1, Z_2, ..., Z_8)W_0 $$

댓글