멀티 헤드 어텐션 원리
어텐션을 사용할 때 헤드 한 개만 사용한 형태가 아닌 헤드 여러개를 사용한 어텐션 구조도 사용할 수 있다.
어텐션 행렬이 아닌 다중 어텐션 행렬을 계산해보자.
All is well이라는 문장으로 예를 들어보자. 단어 well의 셀프 어텐션을 계산한다고 하자. 계산 결과 다음과 같은 결과를 얻었다고 가정하자
위의 그림에서 볼 수 있듯이 단어 'well' 의 셀프 벡터 값은 가중치를 적용한 각 단어의 벡터 값의 합임을 알 수 있다.
단어 'well'의 벡터값은 단어 'All' 이 가장 우세하게 작용하는 것을 알 수 있다. 즉, 단어 'well'의 벡터값은 단어 'All'의 벡터값에 0.6을 곱한 값과, 단어 'well'의 벡터값에 0.4를 곱한 결과를 합한 것이다. 이는 단어 'All'이 벡턱밧이 60% 반영되고, 단어 'well'의 벡터값이 40% 반영된 것으로 해석할 수 있다. 따라서 well의 벡턱밧은 단어 All의 영향을 가장 크다고 볼 수 있다
하지만 문장 내에서 단어의 의미가 모호한 경우 역시 발생할 수 있다. 다음 문장을 예를 들어보자
$$ A \; dog \; ate \; food \; because \; it \; was \; hungry. $$
단어 it의 벡터값은 단어 dog에 대한 벡터값으로만 구성된다. 즉, it의 벡터값은 단어 dog가 가장 우세하게 적용한다. 이때 단어 it의 의미는 dog 또는 food가 될 수 있는데 위의 결과는 단어의 의미가 잘 연결된 경우이다.
따라서 문장 내에서 모호한 의미를 가진 단어가 있을 경우에 앞의 예와 같이 적잘한 의미를 가진 단어의 벡터값이 잘 할당되었을 경우에는 문장의 의미를 이해하는데 좋은 영향을 줄 수 있다. 하지만 반대의 경우, 즉 의미가 맞지 않는 단어의 벡터값이 높을 경우에는 문장의 의미를 잘 못 해석될 수 있다.
그래서 어테션 결과의 정확도를 높이기 위해서 단일 헤드 어텐션 행렬 (single head attention) 이 아닌 멀티 헤드 어테션(multi-head attention)을 사용한 후 그 결과값을 더하는 형태로 진행한다. 이와 같은 방법을 사용하는 데는 단일 헤드 어텐션을 사용하는 것보다 멀티헤드 어텐션을 사용하면 좀더 정확하게 의미를 이해할 수 있다는 가정이 있다.
멀티헤드 어텐션
멀티헤드 어텐션은 2개의 행렬 \(Z_1\), \(Z_2\) 를 계산한다고 하자.
먼저 \(Z_1\) 값을 구한다
쿼리 \((Q_1)\), 키 \((K_1)\), 밸류 \((V_1)\) 행렬을 생성한다. 그 다음으로 3개의 가중치 행렬 \(W^Q_1\), \(W^K_1\), \(W^V_1\) 을 생성하고, 마지막으로 입력 행렬 (X)에 가중치 행렬 \(W^Q_1\), \(W^K_1\), \(W^V_1\)을 각각 곱해 쿼리 \((Q_1)\), 키 \((K_1)\), 밸류 \((V_1)\) 행렬을 생성한다
이때 어텐션 행렬 \(Z_1\) 은 다음과 같이 계산된다.
$$ Z_1 = softmax\left ( \frac{Q_1K_1^T}{\sqrt{d_k}} \right )V_1 $$
두번째 어텐션 행렬 \(Z_2\)는
쿼리 \((Q_2)\), 키 \((K_2)\), 밸류 \((V_2)\) 행렬을 추가로 생성한다. 그 다음으로 3개의 가중치 행렬 \(W^Q_2\), \(W^K_2\), \(W^V_2\) 을 생성하고, 마지막으로 입력 행렬 (X)에 가중치 행렬 \(W^Q_1\), \(W^K_1\), \(W^V_1\)을 각각 곱해 쿼리 \((Q_2)\), 키 \((K_2)\), 밸류 \((V_2)\) 행렬을 생성한다
$$ Z_2 = softmax\left ( \frac{Q_2K_2^T}{\sqrt{d_k}} \right )V_2 $$
마찬가지로 h개의 어텐션 행렬을 구할 수 있다. 8개의 어텐션 행렬 \(Z_1\) ~ \(Z_8\)을 구한다고 하면 해당 행렬을 계산한 후에 그 결과 (어텐션 헤드)를 연결 (concatenate)한 후 새로운 가중치 행렬 \(W^Q\)를 곱하면 최종적으로 우리가 원하는 행렬곱을 구할 수 있다.
$$ multi-head \; attention = concatenate(Z_1, Z_2, ..., Z_8)W_0 $$
'Machine-Learning > NLP (Natural Language Processing)' 카테고리의 다른 글
[NLP] BERT 의 이해 (0) | 2022.11.19 |
---|---|
[NLP] Transforemr decoder 이해 (0) | 2022.11.15 |
[NLP] Transforemr - 피드포워드, add와 Norm (0) | 2022.11.06 |
[NLP] Transformer - positional encoding, self-attention (0) | 2022.10.11 |
[NLP] Transformer-Overview (0) | 2022.10.11 |
댓글