๐คย [Transformer] Attention Is All You Need
๐ญย Overview
Transformer
๋ 2017๋
Google์ด NIPS์์ ๋ฐํํ ์์ฐ์ด ์ฒ๋ฆฌ์ฉ ์ ๊ฒฝ๋ง์ผ๋ก ๊ธฐ์กด RNN
๊ณ์ด(LSTM, GRU) ์ ๊ฒฝ๋ง์ด ๊ฐ์ง ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ ์ต๋ํ ์ธ๊ฐ์ ์์ฐ์ด ์ดํด ๋ฐฉ์์ ์ํ์ ์ผ๋ก ๋ชจ๋ธ๋ง ํ๋ ค๋ ์๋๋ก ์ค๊ณ ๋์๋ค. ์ด ๋ชจ๋ธ์ ์ด๊ธฐ Encoder-Decoder
๋ฅผ ๋ชจ๋ ๊ฐ์ถ seq2seq
ํํ๋ก ๊ณ ์ ๋์์ผ๋ฉฐ, ๋ค์ํ ๋ฒ์ญ ํ
์คํฌ์์ SOTA
๋ฅผ ๋ฌ์ฑํด ์ฃผ๋ชฉ์ ๋ฐ์๋ค. ์ดํ์๋ ์ฌ๋ฌ๋ถ๋ ์ ์์๋ ๊ฒ์ฒ๋ผ BERT
, GPT
, ViT
์ ๋ฒ ์ด์ค ๋ผ์ธ์ผ๋ก ์ฑํ ๋๋ฉฐ, ํ๋ ๋ฅ๋ฌ๋ ์ญ์ฌ์ ํ ํ์ ๊ทธ์ ๋ชจ๋ธ๋ก ํ๊ฐ ๋ฐ๊ณ ์๋ค.
ํ๋ ๋ฅ๋ฌ๋์ ์ ์ฑ๊ธฐ๋ฅผ ์ด์ด์ค Transformer
๋ ์ด๋ค ์์ด๋์ด๋ก ๊ธฐ์กด Recurrent
๊ณ์ด์ด ๊ฐ์ก๋ ๋ฌธ์ ๋ค์ ํด๊ฒฐํ์๊น?? ์ด๊ฒ์ ์ ๋๋ก ์ดํดํ๋ ค๋ฉด ๋จผ์ ๊ธฐ์กด ์ํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ๋ค์ด ๊ฐ์ก๋ ๋ฌธ์ ๋ถํฐ ์ง๊ณ ๋์ด๊ฐ ํ์๊ฐ ์๋ค.
๐คย Limitation of Recurrent Structure
- 1) ์ธ๊ฐ๊ณผ ๋ค๋ฅธ ๋ฉ์ปค๋์ฆ์ Vanishing Gradient ๋ฐ์ (Activation Function with Backward)
- 2) ์ ์ ํ๋ ค์ง๋ Inputs์ Attention (Activation Function with Forward)
- 3) ๋์ฝ๋๊ฐ ๊ฐ์ฅ ๋ง์ง๋ง ๋จ์ด๋ง ์ด์ฌํ ๋ณด๊ณ
denoising
์ํ (Seq2Seq with Bi-Directional RNN)
๐ย 1) ์ธ๊ฐ๊ณผ ๋ค๋ฅธ ๋ฉ์ปค๋์ฆ์ Vanishing Gradient ๋ฐ์ (Activation Function with Backward)
RNN
์ ํ์ฑ ํจ์์ธ Hyperbolic Tangent
๋ $y$๊ฐ์ด [-1, 1]
์ฌ์ด์์ ์ ์๋๋ฉฐ ๊ธฐ์ธ๊ธฐ์ ์ต๋๊ฐ์ 1์ด๋ค. ๋ฐ๋ผ์ ์ด์ ์์ ์ ๋ณด๋ ์์ ์ด ์ง๋๋ฉด ์ง๋ ์๋ก (๋ ๋ง์ ์
์ ํต๊ณผํ ์๋ก) ๊ทธ๋ผ๋์ธํธ ๊ฐ์ด ์์์ ธ ๋ฏธ๋ ์์ ์ ํ์ต์ ๋งค์ฐ ์์ ์ํฅ๋ ฅ์ ๊ฐ๊ฒ ๋๋ค. ์ด๊ฒ์ด ๋ฐ๋ก ๊ทธ ์ ๋ช
ํ RNN
์ Vanishing Gradient
ํ์์ด๋ค. ์ฌ์ค ํ์์ ๋ฐ์ ์์ฒด๋ ๊ทธ๋ ๊ฒ ํฐ ๋ฌธ์ ๊ฐ ๋์ง ์๋๋ค. RNN
์์ ๋ฐ์ํ๋ Vanishing Gradient
๊ฐ ๋ฌธ์ ๊ฐ ๋๋ ์ด์ ๋ ๋ฐ๋ก ์ธ๊ฐ์ด ์์ฐ์ด๋ฅผ ์ดํดํ๋ ๋ฉ์ปค๋์ฆ๊ณผ ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก ํ์์ด ๋ฐ์ํ๊ธฐ ๋๋ฌธ์ด๋ค. ์ฐ๋ฆฌ๊ฐ ๊ธ์ ์ฝ๋ ๊ณผ์ ์ ์ ๋ ์ฌ๋ ค ๋ณด์. ์ด๋ค ๋จ์ด์ ์๋ฏธ๋ฅผ ์๊ธฐ ์ํด ๊ฐ๊น์ด ์ฃผ๋ณ ๋จ์ด์ ๋ฌธ๋งฅ์ ํ์ฉํ ๋๋ ์์ง๋ง, ์ ๋ฉ๋ฆฌ ๋จ์ด์ง ๋ฌธ๋จ์ ๋ฌธ๋งฅ์ ํ์ฉํ ๋๋ ์๋ค. ์ด์ฒ๋ผ ๋จ์ด ํน์ ์ํ์ค๋ฅผ ๊ตฌ์ฑํ๋ ์์ ์ฌ์ด์ ๊ด๊ณ์ฑ
์ด๋ ์ด๋ค ๋ค๋ฅธ ์๋ฏธ๋ก ์ ์ธ ์ด์
๋ก ๋ถ๊ท ํ
ํ๊ฒ ํ์ฌ ์์ ์ ํ์ต์ ์ํฅ๋ ฅ์ ๊ฐ๊ฒ ๋๋๊ฒ ์๋๋ผ, ๋จ์ ์
๋ ฅ ์์
๋๋ฌธ์ ๋ถ๊ท ํ์ด ๋ฐ์ํ๊ธฐ ๋๋ฌธ์ RNN
์ Vanishing Gradient
๊ฐ ๋ฎ์ ์ฑ๋ฅ์ ์์ธ์ผ๋ก ์ง๋ชฉ๋๋ ๊ฒ์ด๋ค.
๋ค์ ๋งํด, ์ค์ ์์ฐ์ด์ ๋ฌธ๋งฅ์ ํ์
ํด ๊ทธ๋ผ๋์ธํธ์ ๋ฐ์ํ๋๊ฒ ์๋๋ผ ๋จ์ํ ์์ ์ ๋ฐ๋ผ์ ๊ทธ ์ํฅ๋ ฅ์ ๋ฐ์ํ๊ฒ ๋๋ค๋ ๊ฒ์ด๋ค. ๋ฉ๋ฆฌ ๋จ์ด์ง ์ํ์ค์ ๋ฌธ๋งฅ์ด ํ์ํ ๊ฒฝ์ฐ๋ฅผ Recurrent
๊ตฌ์กฐ๋ ์ ํํ ํ์ตํ ์ ์๋ค.
๊ทธ๋ ๋ค๋ฉด ํ์ฑ ํจ์๋ฅผ relu
ํน์ gelu
๋ฅผ ์ฌ์ฉํ๋ฉด ์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์์๊น? Vanishing Graident
๋ฌธ์ ๋ ํด๊ฒฐํ ์๋ ์์ผ๋ hidden_state
๊ฐ์ด ๋ฐ์ฐํ ๊ฒ์ด๋ค. ๊ทธ ์ด์ ๋ ๋ ํ์ฑ ํจ์ ๋ชจ๋ ์์ ๊ตฌ๊ฐ์์ ์ ํ์ธ๋ฐ, ์ด์ ์ ๋ณด๋ฅผ ๋์ ํด์ ๊ฐ์ค์น์ ๊ณฑํ๊ณ ํ์ฌ ์
๋ ฅ๊ฐ์ ๋ํ๋ RNN
์ ๊ตฌ์กฐ๋ฅผ ์๊ฐํด๋ณด๋ฉด ๋์ด์ค๋ ์ด์ ์ ๋ณด๋ ๋์ ๋๋ฉด์ ์ ์ ์ปค์ง ๊ฒ์ด๊ณ ๊ทธ๋ฌ๋ค ๊ฒฐ๊ตญ ๋ฐ์ฐํ๊ฒ ๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก Vanishing Gradient
ํ์ ์์ฒด๊ฐ ๋ฌธ์ ๋ ์๋์ง๋ง ๋ชจ๋ธ์ด ์์ฐ์ด์ ๋ฌธ๋งฅ์ ํ์
ํด ๊ทธ๋ผ๋์ธํธ์ ๋ฐ์ํ๋๊ฒ ์๋๋ผ ๋จ์ํ ์์ ์ ๋ฐ๋ผ์ ๋ถ๊ท ํํ๊ฒ ๋ฐ์ํ๊ธฐ ๋๋ฌธ์ ๋ฎ์ ์ฑ๋ฅ์ ์์ธ์ผ๋ก ์ง๋ชฉ ๋ฐ๋ ๊ฒ์ด๋ค. ์ด๊ฒ์ long-term dependency
๋ผ๊ณ ๋ถ๋ฅด๊ธฐ๋ ํ๋ค.
โ๏ธย 2) ์ ์ ํ๋ ค์ง๋ Inputs์ Attention (Activation Function with Forward)
tanh function
Hyperbolic Tangent
์ $y$๊ฐ์ด [-1, 1]
์ฌ์ด์์ ์ ์๋๋ค๊ณ ํ๋ค. ๋ค์ ๋งํด ์
์ ์ถ๋ ฅ๊ฐ์ด ํญ์ ์ผ์ ๋ฒ์๊ฐ( [-1,1]
)์ผ๋ก ์ ํ(๊ฐ์ค์น, ํธํฅ ๋ํ๋ ๊ฒ์ ์ผ๋จ ์ ์ธ) ๋๋ค๋ ๊ฒ์ด๋ค. ๋ฐ๋ผ์ ํ์ ๋ ์ข์ ๋ฒ์์ ์ถ๋ ฅ๊ฐ๋ค์ด ๋งตํ๋๋๋ฐ, ์ด๋ ๊ฒฐ๊ตญ ์
๋ ฅ๊ฐ์ ์ ๋ณด๋ ๋๋ถ๋ถ ์์ค๋ ์ฑ ์ผ๋ถ ํน์ง๋ง ์ ์ ๋์ด ์ถ๋ ฅ๋๊ณ ๋ค์ ๋ ์ด์ด๋ก forward
๋จ์ ์๋ฏธํ๋ค. ๊ทธ๋ํ๋ฅผ ํ ๋ฒ ์ดํด๋ณด์. ํนํ Inputs
๊ฐ์ด 2.5 ์ด์์ธ ๊ฒฝ์ฐ๋ถํฐ๋ ์ถ๋ ฅ๊ฐ์ด ๊ฑฐ์ 1์ ์๋ ดํด ๊ทธ ์ฐจ์ด๋ฅผ ์ง๊ด์ ์ผ๋ก ํ์
ํ๊ธฐ ํ๋ค๋ค. ์ด๋ฌํ ํ์ฑํจ์๊ฐ ์์ญ๊ฐ, ์๋ฐฑ๊ฐ ์์ธ๋ค๋ฉด ๊ฒฐ๊ตญ ์๋ณธ ์ ๋ณด๋ ๋งค์ฐ ํ๋ ค์ง๊ณ ๋ญ๊ฐ์ ธ์ ๋ค๋ฅธ ์ธ์คํด์ค์ ๊ตฌ๋ณ์ด ํ๋ค์ด ์ง ๊ฒ์ด๋ค.
๐ฌย 3) ๋์ฝ๋๊ฐ ๊ฐ์ฅ ๋ง์ง๋ง ๋จ์ด๋ง ์ด์ฌํ ๋ณด๊ณ denoising ์ํ (Seq2Seq with Bi-Directional RNN)
โ์ฐ๋คโ
($t_7$)๋ผ๋ ๋จ์ด์ ๋ป์ ์ดํดํ๋ ค๋ฉด โ๋์โ
, โ๋ชจ์๋ฅผโ
, โ๋ง์ดโ
, โ๊ธ์โ
($t_1$)๊ณผ ๊ฐ์ด ๋ฉ๋ฆฌ ์๋ ์ ๋จ์ด๋ฅผ ๋ด์ผ ์ ์ ์๋๋ฐ, $h_7$ ์๋ $t_1$์ด ํ๋ ค์ง ์ฑ๋ก ๋ค์ด๊ฐ ์์ด์ $t_7$์ ์ ๋๋ก ๋ ์๋ฏธ๋ฅผ ํฌ์ฐฉํ์ง ๋ชปํ๋ค. ์ฌ์ง์ด ์ธ์ด๊ฐ ์์ด๋ผ๋ฉด ๋ค๋ฅผ ๋ด์ผ ์ ํํ ๋ฌธ๋งฅ์ ์ ์ ์๋๋ฐ Vanilla RNN
์ ๋จ๋ฐฉํฅ์ผ๋ก๋ง ํ์ต์ ํ๊ฒ ๋์ด ๋ฌธ์ฅ์ ๋ท๋ถ๋ถ ๋ฌธ๋งฅ์ ๋ฐ์์กฐ์ฐจ(๋ค์ ์์นํ ๋ชฉ์ ์ด์ ๋ฐ๋ผ์ ์ฐ๋ค๋ผ๋ ๋จ์ด์ ๋์์ค๋ ๋ฌ๋ผ์ง) ํ ์ ์๋ค. ๊ทธ๋์ Bi-directional RNN
์จ์ผํ๋๋ฐ, ์ด๊ฒ๋ ์ญ์๋ ์ฌ์ ํ โ๊ฑฐ๋ฆฌโ
์ ์ํฅ ๋ฐ๋๋ค๋ ๊ฑด ๋ณํ์ง ์๊ธฐ ๋๋ฌธ์ ๊ทผ๋ณธ์ ์ธ ํด๊ฒฐ์ฑ
์ด๋ผ ๋ณผ ์ ์๋ค.
ํํธ, ๋์ฝ๋์ Next Token Prediction
์ฑ๋ฅ์ ๋ฌด์กฐ๊ฑด ์ธ์ฝ๋๋ก๋ถํฐ ๋ฐ๋ Context Vector
์ ํ์ง์ ๋ฐ๋ผ ์ข์ง์ฐ์ง ๋๋ค. ๊ทธ๋ฌ๋ Recurrent ๊ตฌ์กฐ์ ์ธ์ฝ๋๋ก๋ถํฐ ๋์จ Context Vector๋ ์์ ์์ ํ ๊ฒ์ฒ๋ผ ์ข์ ํ์ง(๋ค์ชฝ ๋จ์ด๊ฐ ์๋์ ์ผ๋ก ์ ๋ช
ํจ)์ด ์๋๋ค. ๋ฐ๋ผ์ ๋์ฝ๋์ ๋ฒ์ญ(๋ค์ ๋จ์ด ์์ธก) ์ฑ๋ฅ ์ญ์ ์ข์๋ฆฌ๊ฐ ์๋ค.
๊ฒฐ๊ตญ Recurrent
๊ตฌ์กฐ ์์ฒด์ ๋ช
ํํ ํ๊ณ๊ฐ ์กด์ฌํ์ฌ ์ธ๊ฐ์ด ์์ฐ์ด๋ฅผ ์ฌ์ฉํ๊ณ ์ดํดํ๋ ๋งฅ๋ฝ๊ณผ ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก ๋์ํ๊ฒ ๋์๋ค. LSTM
, GRU
์ ์ ์์ผ๋ก ์ด๋ ์ ๋ ๋ฌธ์ ๋ฅผ ์ํ ์์ผฐ์ผ๋, ์์์ ์์ ํ๋ฏ์ด ํ์์ด Recurrent Structure
์ ๊ฐ์ง๊ธฐ ๋๋ฌธ์ ๊ทผ๋ณธ์ ์ธ ํด๊ฒฐ์ฑ
์ด ๋์ง๋ ๋ชปํ๋ค. ๊ทธ๋ ๋ค๋ฉด ์ด์ Transformer
๊ฐ ์ด๋ป๊ฒ ์์ ์์ ํ 3๊ฐ์ง ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ ํ์ฌ์ ์์์ ๊ฐ๊ฒ ๋์๋์ง ์์๋ณด์.
๐ย Modeling
์์ Recurrent
๊ตฌ์กฐ์ Vanishing Gradient
์ ์ค๋ช
ํ๋ฉด์ ์์ ์ ๋ฐ๋ผ ์ ๋ณด๋ฅผ ์์คํ๊ฒ ๋๋ ํ์์ ์ธ๊ฐ์ ์์ฐ์ด ์ดํด ๋ฐฉ์์ด ์๋๋ผ๋ ์ ์ ์ธ๊ธํ ์ ์๋ค. ๋ฐ๋ผ์ Transformer
๋ ์ต๋ํ ์ธ๊ฐ์ ์์ฐ์ด ์ดํด ๋ฐฉ์์ ์ํ์ ์ผ๋ก ๋ชจ๋ธ๋ง ํ๋ ๊ฒ์ ์ด์ ์ ๋ง์ท๋ค. ์ฐ๋ฆฌ๊ฐ ์ฐ์ฌ์ง ๊ธ์ ์ดํดํ๊ธฐ ์ํด ํ๋ ํ๋๋ค์ ๋ ์ฌ๋ ค ๋ณด์. โAppleโ
์ด๋ ๋จ์ด๊ฐ ์ฌ๊ณผ๋ฅผ ๋งํ๋ ๊ฒ์ธ์ง, ๋ธ๋๋ ์ ํ์ ์ง์นญํ๋ ๊ฒ์ธ์ง ํ์
ํ๊ธฐ ์ํด ๊ฐ์ ๋ฌธ์ฅ์ ์ํ ์ฃผ๋ณ ๋จ์ด๋ฅผ ์ดํผ๊ธฐ๋ ํ๊ณ ๊ทธ๋๋ ํ์
ํ๊ธฐ ํ๋ค๋ค๋ฉด ์๋ค ๋ฌธ์ฅ, ๋์๊ฐ ๋ฌธ์ ์ ์ฒด ๋ ๋ฒจ์์ ๋งฅ๋ฝ์ ํ์
ํ๊ธฐ ์ํด ๋
ธ๋ ฅํ๋ค. Transformer
์ฐ๊ตฌ์ง์ ๋ฐ๋ก ์ด ๊ณผ์ ์ ์ฃผ๋ชฉํ์ผ๋ฉฐ ์ด๊ฒ์ ๋ชจ๋ธ๋งํ์ฌ ๊ทธ ์ ๋ช
ํ Self-Attention
์ ๊ณ ์ํด๋ธ๋ค.
๋ค์ ๋งํด Self-Attention
์ ํ ํฐ์ ์๋ฏธ๋ฅผ ์ดํดํ๊ธฐ ์ํด ์ ์ฒด ์
๋ ฅ ์ํ์ค
์ค์์ ์ด๋ค ๋จ์ด์ ์ฃผ๋ชฉํด์ผํ ์ง๋ฅผ ์ํ์ ์ผ๋ก ํํํ ๊ฒ์ด๋ผ ๋ณผ ์ ์๋ค. ์ข ๋ ๊ตฌ์ฒด์ ์ผ๋ก๋ ์ํ์ค์ ์ํ ์ฌ๋ฌ ํ ํฐ ๋ฒกํฐ(ํ๋ฐฑํฐ)๋ฅผ ์๋ฒ ๋ฉ ๊ณต๊ฐ ์ด๋์ ๋ฐฐ์นํ ๊ฒ์ธ๊ฐ์ ๋ํด ํ๋ จํ๋ ํ์๋ค.
๊ทธ๋ ๋ค๋ฉด ์ด์ ๋ถํฐ Transformer
๊ฐ ์ด๋ค ์์ด๋ฐ์ด์
์ ํตํด ๊ธฐ์กด ์ํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ์ ๋จ์ ์ ํด๊ฒฐํ๊ณ ๋ฅ๋ฌ๋๊ณ์ G.O.A.T
์๋ฆฌ๋ฅผ ์ฐจ์งํ๋์ง ์์๋ณด์. ๋ชจ๋ธ์ ํฌ๊ฒ ์ธ์ฝ๋์ ๋์ฝ๋ ๋ถ๋ถ์ผ๋ก ๋๋๋๋ฐ, ํ๋ ์ญํ ๊ณผ ๋ฏธ์ธํ ๊ตฌ์กฐ์์ ์ฐจ์ด๋ง ์์๋ฟ ๋ ๋ชจ๋ ๋ชจ๋ Self-Attention
์ด ์ ์ผ ์ค์ํ๋ค๋ ๋ณธ์ง์ ๋ณํ์ง ์๋๋ค. ๋ฐ๋ผ์ Input Embedding
๋ถํฐ ์ฐจ๋ก๋๋ก ์ดํด๋ณด๋, Self-Attention
์ ํน๋ณํ ์ฌ์ฉ๋ ํ์ ๋ธ๋ญ ๋จ์๋ฅผ ๋น ์ง ์์ด, ์ธ์ธํ๊ฒ ์ดํด๋ณผ ๊ฒ์ด๋ค.
Class Diagram
์ด๋ ๊ฒ ํ์ ๋ชจ๋์ ๋ํ ์ค๋ช ๋ถํฐ ์์ ๋๊ฐ ๋ง์ง๋ง์๋ ์ค์ ๊ตฌํ ์ฝ๋์ ํจ๊ป ์ ์ฒด์ ์ธ ๊ตฌ์กฐ ์ธก๋ฉด์์๋ ๋ชจ๋ธ์ ํด์ํด๋ณผ ๊ฒ์ด๋ค. ๋๊น์ง ํฌ์คํ ์ ์ฝ์ด์ฃผ์๊ธธ ๋ฐ๋๋ค.
๐ฌ Input Embedding
\[X_E \in R^{B * S_E * V_E} \\
X_D \in R^{B * S_D * V_D}\]
Transformer
๋ ์ธ์ฝ๋์ ๋์ฝ๋๋ก ์ด๋ค์ง seq2seq
๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์๋ค. ์ฆ, ๋์ ์ธ์ด๋ฅผ ํ๊ฒ ์ธ์ด๋ก ๋ฒ์ญํ๋๋ฐ ๋ชฉ์ ์ ๋๊ณ ์๊ธฐ ๋๋ฌธ์ ์
๋ ฅ์ผ๋ก ๋์ ์ธ์ด ์ํ์ค์ ํ๊ฒ ์ธ์ด ์ํ์ค ๋ชจ๋ ํ์ํ๋ค. $X_E$๋ ์ธ์ฝ๋
์ ์
๋ ฅ ํ๋ ฌ์ ๋ํ๋ด๊ณ , $X_D$๋ ๋์ฝ๋
์ ์
๋ ฅ ํ๋ ฌ์ ์๋ฏธํ๋ค. ์ด ๋, $B$๋ batch size
, $S$๋ max_seq
, $V$๋ ๊ฐ๋ณ ๋ชจ๋์ด ๊ฐ์ง Vocab
์ ์ฌ์ด์ฆ๋ฅผ ๊ฐ๋ฆฌํจ๋ค. ์ ์์์ ์ฌ์ค ๋
ผ๋ฌธ์ ์
๋ ฅ์ ๋ํ ์์์ด ๋ฐ๋ก ์์ ๋์ด ์์ง ์์, ํ์๊ฐ ์ง์ ๋ง๋ ๊ฒ์ด๋ค. ์์ผ๋ก๋ ํด๋น ๊ธฐํธ๋ฅผ ์ด์ฉํด ์์์ ํํํ ์์ ์ด๋ ์ฐธ๊ณ ๋ฐ๋๋ค.
์ด๋ ๊ฒ ์ ์๋ ์
๋ ฅ๊ฐ์ ๊ฐ๋ณ ๋ชจ๋์ ์๋ฒ ๋ฉ ๋ ์ด์ด์ ํต๊ณผ ์ํจ ๊ฒฐ๊ณผ๋ฌผ์ด ๋ฐ๋ก Input Embedding
์ด ๋๋ค. $d$๋ Transformer
๋ชจ๋ธ์ ์๋์ธต์ ํฌ๊ธฐ๋ฅผ ์๋ฏธํ๋ค. ๋ฐ๋ผ์ Position Embedding
๊ณผ ๋ํด์ง๊ธฐ ์ , ์๋ฒ ๋ฉ ๋ ์ด์ด๋ฅผ ํต๊ณผํ Input Embedding
์ ๋ชจ์์ ์๋ ์์๊ณผ ๊ฐ๋ค.
๊ทธ๋ ๋ค๋ฉด ์ค์ ๊ตฌํ์ ์ด๋ป๊ฒ ํ ๊น?? Transformer
์ Input Embedding
์ nn.Embedding
์ผ๋ก ๋ ์ด์ด๋ฅผ ์ ์ํด ์ฌ์ฉํ๋ค. nn.Linear
๋ ์๋๋ฐ ์ ๊ตณ์ด nn.Embedding
์ ์ฌ์ฉํ๋ ๊ฒ์ผ๊น??
์์ฐ์ด ์ฒ๋ฆฌ์์ ์
๋ ฅ ์๋ฒ ๋ฉ์ ๋ง๋ค๋๋ ๋ชจ๋ธ์ ํ ํฌ๋์ด์ ์ ์ํด ์ฌ์ ์ ์๋ vocab
์ ์ฌ์ด์ฆ๊ฐ ์
๋ ฅ ์ํ์ค์ ์ํ ํ ํฐ ๊ฐ์๋ณด๋ค ํจ์ฌ ํฌ๊ธฐ ๋๋ฌธ์ ๋ฐ์ดํฐ ๋ฃฉ์
ํ
์ด๋ธ ๋ฐฉ์์ย nn.Embedding
ย ์ ์ฌ์ฉํ๊ฒ ๋๋ค. ์ด๊ฒ ๋ฌด์จ ๋ง์ด๋๋ฉด, ํ ํฌ๋์ด์ ์ ์ํด ์ฌ์ ์ ์ ์๋ย vocab
ย ์ ์ฒด๊ฐย nn.Embedding(vocab_size, dim_model)
๋ก ํฌ์ ๋์ด ๊ฐ๋ก๋ vocab
์ฌ์ด์ฆ, ์ธ๋ก๋ ๋ชจ๋ธ์ ์ฐจ์ ํฌ๊ธฐ์ ํด๋นํ๋ ๋ฃฉ์
ํ
์ด๋ธ์ด ์์ฑ๋๊ณ , ๋ด๊ฐ ์
๋ ฅํ ํ ํฐ๋ค์ ์ ์ฒดย vocab
์ ์ผ๋ถ๋ถ์ผํ
๋ ์ ์ฒด ์๋ฒ ๋ฉ ๋ฃฉ์
ํ
์ด๋ธ์์ ๋ด๊ฐ ์๋ฒ ๋ฉํ๊ณ ์ถ์ ํ ํฐ๋ค์ ์ธ๋ฑ์ค๋ง ์์๋ธ๋ค๋ ๊ฒ์ด๋ค. ๊ทธ๋์ย nn.Embedding
ย ์ ๋ ์ด์ด์ ์ ์๋ ์ฐจ์๊ณผ ์ค์ ์
๋ ฅ ๋ฐ์ดํฐ์ ์ฐจ์์ด ๋ง์ง ์์๋ ํจ์๊ฐ ๋์ํ๊ฒ ๋๋ค. nn.Linear
์ ์
๋ ฅ ์ฐจ์์ ๋ํ ์กฐ๊ฑด ๋นผ๊ณ ๋ ๋์ผํ ๋์์ ์ํํ๊ธฐ ๋๋ฌธ์ ์ฌ์ ์ ์๋ vocab
์ฌ์ด์ฆ์ ์
๋ ฅ ์ํ์ค์ ํ ํฐ ๊ฐ์๊ฐ ๊ฐ๋ค๋ฉด nn.Linear
๋ฅผ ์ฌ์ฉํด๋ ๋ฌด๋ฐฉํ๋ค.
# Input Embedding Example
class Transformer(nn.Module):
def __init__(
self,
enc_vocab_size: int,
dec_vocab_size: int,
max_seq: int = 512,
enc_N: int = 6,
dec_N: int = 6,
dim_model: int = 512, # latent vector space
num_heads: int = 8,
dim_ffn: int = 2048,
dropout: float = 0.1
) -> None:
super(Transformer, self).__init__()
self.enc_input_embedding = nn.Embedding(enc_vocab_size, dim_model) # Encoder Input Embedding Layer
self.dec_input_embedding = nn.Embedding(dec_vocab_size, dim_model) # Decoder Input Embedding Layer
def forward(self, enc_inputs: Tensor, dec_inputs: Tensor, enc_pad_index: int, dec_pad_index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
enc_x, dec_x = self.enc_input_embedding(enc_inputs), self.dec_input_embedding(dec_inputs)
์์ ์์ ์ฝ๋๋ฅผ ํจ๊ป ์ดํด๋ณด์. __init__
์ self.enc_input_embedding
, self._dec_input_embedding
์ด ๋ฐ๋ก $W_E, W_D$์ ๋์๋๋ค. ํํธ forward
๋ฉ์๋์ ์ ์๋ enc_x
, dec_x
๋ ์๋ฒ ๋ฉ ๋ ์ด์ด๋ฅผ ๊ฑฐ์น๊ณ ๋์จ $X_E, X_D$์ ํด๋น๋๋ค.
ํํธ, $X_E, X_D$์ ๊ฐ๊ฐ ์ธ์ฝ๋, ๋์ฝ๋ ๋ชจ๋๋ก ํ๋ฌ ๋ค์ด๊ฐ Absolute Position Embedding
๊ณผ ๋ํด์ง(ํ๋ ฌ ํฉ) ๋ค, ๊ฐ๋ณ ๋ชจ๋์ ์
๋ ฅ๊ฐ์ผ๋ก ํ์ฉ๋๋ค.
๐ขย Absolute Position Embedding(Encoding)
์
๋ ฅ ์ํ์ค์ ์์น ์ ๋ณด๋ฅผ ๋งตํํด์ฃผ๋ ์ญํ ์ ํ๋ค. ํ์๋ ๊ฐ์ธ์ ์ผ๋ก Transformer
์์ ๊ฐ์ฅ ์ค์ํ ์์๋ฅผ ๋ฝ์ผ๋ผ๊ณ ํ๋ฉด ์ธ ์๊ฐ๋ฝ ์์ ๋ค์ด๊ฐ๋ ํํธ๋ผ๊ณ ์๊ฐํ๋ค. ๋ค์ ํํธ์์ ์์ธํ ๊ธฐ์ ํ๊ฒ ์ง๋ง, Self-Attention(๋ด์ )
์ ์
๋ ฅ ์ํ์ค๋ฅผ ๋ณ๋ ฌ๋ก ํ๊บผ๋ฒ์ ์ฒ๋ฆฌํ ์ ์๋ค๋ ์ฅ์ ์ ๊ฐ๊ณ ์์ง๋ง, ๊ทธ ์์ฒด๋ก๋ ํ ํฐ์ ์์น ์ ๋ณด๋ฅผ ์ธ์ฝ๋ฉํ ์ ์๋ค. ์ฐ๋ฆฌ๊ฐ ๋ฐ๋ก ์์น ์ ๋ณด๋ฅผ ์๋ ค์ฃผ์ง ์๋ ์ด์ ์ฟผ๋ฆฌ ํ๋ ฌ์ 2๋ฒ์งธ ํ๋ฒกํฐ๊ฐ ์
๋ ฅ ์ํ์ค์์ ๋ช ๋ฒ์งธ ์์นํ ํ ํฐ์ธ์ง ๋ชจ๋ธ์ ์ ๊ธธ์ด ์๋ค.
๊ทธ๋ฐ๋ฐ, ํ
์คํธ๋ Permutation Equivariant
ํ Bias
๊ฐ ์๊ธฐ ๋๋ฌธ์ ํ ํฐ์ ์์น ์ ๋ณด๋ NLP
์์ ๋งค์ฐ ์ค์ํ ์์๋ก ๊ผฝํ๋ค. ์ง๊ด์ ์ผ๋ก๋ ํ ํฐ์ ์์๋ ์ํ์ค๊ฐ ๋ดํฌํ๋ ์๋ฏธ์ ์ง๋ํ ์ํฅ์ ๋ผ์น๋ค๋ ๊ฒ์ ์ ์ ์๋ค. ์๋ฅผ ๋ค์ด โ์ฒ ์๋ ์ํฌ๋ฅผ ์ข์ํ๋คโ
๋ผ๋ ๋ฌธ์ฅ๊ณผ โ์ํฌ๋ ์ฒ ์๋ฅผ ์ข์ํ๋คโ
๋ผ๋ ๋ฌธ์ฅ์ ์๋ฏธ๊ฐ ๊ฐ์๊ฐ ์๊ฐํด๋ณด์. ์ฃผ์ด์ ๋ชฉ์ ์ด ์์น๊ฐ ๋ฐ๋๋ฉด์ ์ ๋ฐ๋์ ๋ป์ด ๋์ด๋ฒ๋ฆฐ๋ค.
๋ฐ๋ผ์ ์ ์๋ ์
๋ ฅ ์
๋ฒ ๋ฉ์ ์์น ์ ๋ณด๋ฅผ ์ถ๊ฐํ๊ณ ์ Position Encoding
์ ์ ์ํ๋ค. ์ฌ์ค Position Encoding
์ ์ฌ๋ฌ ๋จ์ ๋๋ฌธ์ ํ๋ Transformer
ํ์ ๋ชจ๋ธ์์๋ ์ ์ฌ์ฉ๋์ง ์๋ ์ถ์ธ๋ค. ๋์ ๋ชจ๋ธ์ด ํ์ต์ ํตํด ์ต์ ๊ฐ์ ์ฐพ์์ฃผ๋ Position Embedding
๋ฐฉ์์ ๋๋ถ๋ถ ์ฐจ์ฉํ๊ณ ์๋ค. ํ์ ์ญ์ Position Embedding
์ ์ฌ์ฉํด ์์น ์๋ฒ ๋ฉ์ ๊ตฌํํ๊ธฐ ๋๋ฌธ์ ์๋ฆฌ์ ๋จ์ ์ ๋ํด์๋ง ๊ฐ๋จํ ์๊ฐํ๊ณ ๋์ด๊ฐ๋ ค ํ๋ค. ๋ํ ์ ์ ์ญ์ ๋
ผ๋ฌธ์์ ๋ ๋ฐฉ์ ์ค ์ด๋ ๊ฒ์ ์จ๋ ๋น์ทํ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋ค๊ณ ์ธ๊ธํ๊ณ ์๋ค.
์๋ฆฌ๋ ๋งค์ฐ ๊ฐ๋จํ๋ค. ์ฌ์ธํจ์์ ์ฝ์ฌ์ธ ํจ์์ ์ฃผ๊ธฐ์ฑ์ ์ด์ฉํด ๊ฐ๋ณ ์ธ๋ฑ์ค์ ํ๋ฒกํฐ ๊ฐ์ ํํํ๋ ๊ฒ์ด๋ค. ํ๋ฒกํฐ์ ์์ ์ค์์ ์ง์๋ฒ์งธ ์ธ๋ฑ์ค์ ์์นํ ์์๋ (์ง์๋ฒ์งธ ์ด๋ฒกํฐ) \(sin(pos/\overset{}{10000_{}^{2i/dmodel}})\) ์ ํจ์ซ๊ฐ์ ์ด์ฉํด ์ฑ์๋ฃ๊ณ , ํ์๋ฒ์งธ ์์๋ \(cos(pos/\overset{}{10000_{}^{2i/dmodel}})\)๋ฅผ ์ด์ฉํด ์ฑ์๋ฃ๋๋ค.
periodic function graph
์ด๋ก์ ๊ทธ๋ํ๋ \(sin(pos/\overset{}{10000_{}^{2i/dmodel}})\), ์ฃผํฉ์ ๊ทธ๋ํ๋ \(cos(pos/\overset{}{10000_{}^{2i/dmodel}})\)๋ฅผ ์๊ฐํํ๋ค. ์ง๋ฉด์ ์ ํ์ผ๋ก max_seq=512
๋งํผ์ ๋ณํ๋์ ๋ด์ง๋ ๋ชปํ์ง๋ง, x์ถ์ด ์ปค์ง์๋ก ๋ ํจ์ ๋ชจ๋ ์ง๋ ์ฃผ๊ธฐ๊ฐ ์กฐ๊ธ์ฉ ์ปค์ง๋ ์์์ ๋ณด์ฌ์ค๋ค. ๋ฐ๋ผ์ ๊ฐ๋ณ ์ธ๋ฑ์ค(ํ๋ฒกํฐ)๋ฅผ ์ค๋ณต๋๋ ๊ฐ ์์ด ํํํ๋ ๊ฒ์ด ๊ฐ๋ฅํ๋ค๊ณ ์ ์๋ ์ฃผ์ฅํ๋ค.
์ ๊ทธ๋ฆผ์ ํ ํฐ 50
๊ฐ, ์๋์ธต์ด 256์ฐจ์์ผ๋ก ๊ตฌ์ฑ๋ ์ํ์ค์ ๋ํด Positional Encoding
ํ ๊ฒฐ๊ณผ๋ฅผ ์๊ฐํํ ์๋ฃ๋ค. ๊ทธ๋ํ์ $x$์ถ์ ํ๋ฒกํฐ์ ์์
์ด์ Transformer
์ ์๋ ๋ฒกํฐ ์ฐจ์์ ๊ฐ๋ฆฌํค๊ณ , $y$์ถ์ ์ํ์ค์ ์ธ๋ฑ์ค
(ํ๋ฒกํฐ)๋ฅผ ์๋ฏธํ๋ค. ์ก์์ผ๋ก ์ ํํ๊ฒ ์ฐจ์ด๋ฅผ ์ธ์ํ๊ธฐ ์ฝ์ง๋ ์์ง๋ง, ํ๋ฒกํฐ๊ฐ ๋ชจ๋ ์ ๋ํฌํ๊ฒ ํํ๋๋ค๋ ์ฌ์ค(์ง์ ์ค์๊ฐ์ ํ์ธํด๋ณด๋ฉด ์ ๋ง ๋ฏธ์ธํ ์ฐจ์ด์ง๋ง ๊ฐ๋ณ ํ ํฐ์ ํฌ์์ฑ์ด ๋ณด์ฅ)์ ์ ์ ์๋ค. ์์ ์ฐจ์ด๋ฅผ ์๊ฐํ ์๋ฃ๋ก ํ์
ํ๊ธฐ๋ ์ฝ์ง ์๊ธฐ ๋๋ฌธ์ ์ง์ง ๊ทธ๋ฐ๊ฐ ๊ถ๊ธํ์ ๋ถ๋ค์ ์ง์ ์ค์๊ฐ์ ๊ตฌํด๋ณด๋ ๊ฒ์ ์ถ์ฒ๋๋ฆฐ๋ค.
์ฌ๊ธฐ์ ํ๋ฒกํฐ์ ํฌ์์ฑ์ด๋ ๊ฐ๋ณ ํ๋ฒกํฐ ์์์ ํฌ์์ฑ์ ๋งํ๋๊ฒ ์๋๋ค. 0๋ฒ ํ ํฐ, 4๋ฒ ํ ํฐ, 9๋ฒ ํ ํฐ์ ํ๋ฒกํฐ 1๋ฒ์งธ ์์์ ๊ฐ์ ๊ฐ์ ์ ์๋ค. ํ์ง๋ง ์ง๋ ์ฃผ๊ธฐ๊ฐ ๊ฐ์๋ก ์ปค์ง๋ ์ฃผ๊ธฐํจ์๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋ค๋ฅธ ์์(์ฐจ์)๊ฐ์ ๋ค๋ฅผ ๊ฒ์ด๋ผ ๊ธฐ๋ํ ์ ์๋๋ฐ, ๋ฐ๋ก ์ด๊ฒ์ ํ๋ฒกํฐ์ ํฌ์์ฑ์ด๋ผ๊ณ ์ ์ํ๋ ๊ฒ์ด๋ค. ๋ง์ฝ 1๋ฒ ํ ํฐ๊ณผ 2๋ฒ ํ ํฐ์ ๋ชจ๋ ํ๋ฒกํฐ ์์๊ฐ์ด ๊ฐ๋ค๋ฉด ๊ทธ๊ฒ์ ํฌ์์ฑ ์์น์ ์๋ฐฐ๋๋ ์ํฉ์ด๋ค.
Compare Performance between Encoding and Embedding
๋น๋ก ๊ฐ๋ณ ํ๋ฒกํฐ์ ํฌ์์ฑ์ด ๋ณด์ฅ๋๋ค๊ณ ํด๋ Position Encoding
์ not trainable
ํด์ static
ํ๋ค๋ ๋จ์ ์ด ์๋ค. ๋ชจ๋ ๋ฐฐ์น์ ์ํ์ค๊ฐ ๋์ผํ ์์น ์ ๋ณด๊ฐ์ ๊ฐ๊ฒ ๋๋ค๋ ๊ฒ์ด๋ค. 512
๊ฐ์ ํ ํฐ์ผ๋ก ๊ตฌ์ฑ๋ ์ํ์ค A์ B๊ฐ ์๋ค๊ณ ๊ฐ์ ํด๋ณด์. ์ด ๋ ์ํ์ค A๋ ๋ฌธ์ฅ 5
๊ฐ๋ก ๊ตฌ์ฑ ๋์ด ์๊ณ , B๋ ๋ฌธ์ฅ 12
๊ฐ๋ก ๋ง๋ค์ด์ก๋ค. ๋ ์ํ์ค์ 11
๋ฒ์งธ ํ ํฐ์ ๋ฌธ์ฅ ์ฑ๋ถ์ ๊ณผ์ฐ ๊ฐ์๊น?? ์๋ง๋ ๋๋ถ๋ถ์ ๊ฒฝ์ฐ์ ๋ค๋ฅผ ๊ฒ์ด๋ค. ํ
์คํธ ๋ฐ์ดํฐ์์ ์์ ์ ๋ณด๊ฐ ์ค์ํ ์ด์ ์ค ํ๋๋ ๋ฐ๋ก syntactical
ํ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๊ธฐ ์ํจ์ด๋ค. Position Encoding
์ static
ํ๊ธฐ ๋๋ฌธ์ ์ด๋ฌํ ํ์
์ ์ ๋ณด๋ฅผ ์ธ์ฝ๋ฉ ํ๊ธฐ ์ฝ์ง ์๋ค. ๊ทธ๋์ ์ข ๋ ํ๋ถํ ํํ์ ๋ด์ ์ ์๋ Position Embedding
์ ์ฌ์ฉํ๋ ๊ฒ์ด ์ต๊ทผ ์ถ์ธ๋ค.
โ๏ธ Position Embedding
๊ทธ๋ ๋ค๋ฉด ์ด์ Position Embedding
์ ๋ํด ์์๋ณด์. Position Embedding
์ Input Embedding
์ ์ ์ํ ๋ฐฉ์๊ณผ ๊ฑฐ์ ์ ์ฌํ๋ค. ๋จผ์ ์
๋ ฅ๊ฐ๊ณผ weight
์ ๋ชจ์๋ถํฐ ํ์ธํด๋ณด์.
$P_E, P_D$๋ ๊ฐ๋ณ ๋ชจ๋์ ์์น ์๋ฒ ๋ฉ ๋ ์ด์ด ์ ๋ ฅ์ ๊ฐ๋ฆฌํค๋ฉฐ, $W_{P_E}, W_{P_D}$๊ฐ ๊ฐ๋ณ ๋ชจ๋์ ์์น ์๋ฒ ๋ฉ ๋ ์ด์ด๊ฐ ๋๋ค. ์ด์ ์ด๊ฒ์ ์ฝ๋๋ก ์ด๋ป๊ฒ ๊ตฌํํ๋์ง ์ดํด๋ณด์.
# Absolute Position Embedding Example
class Encoder(nn.Module):
"""
In this class, encode input sequence and then we stack N EncoderLayer
First, we define "positional embedding" and then add to input embedding for making "word embedding"
Second, forward "word embedding" to N EncoderLayer and then get output embedding
In official paper, they use positional encoding, which is base on sinusoidal function(fixed, not learnable)
But we use "positional embedding" which is learnable from training
Args:
max_seq: maximum sequence length, default 512 from official paper
N: number of EncoderLayer, default 6 for base model
"""
def __init__(self, max_seq: 512, N: int = 6, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
super(Encoder, self).__init__()
self.max_seq = max_seq
self.scale = torch.sqrt(torch.Tensor(dim_model)) # scale factor for input embedding from official paper
self.positional_embedding = nn.Embedding(max_seq, dim_model) # add 1 for cls token
... ์ค๋ต ...
def forward(self, inputs: Tensor, mask: Tensor) -> tuple[Tensor, Tensor]:
"""
inputs: embedding from input sequence, shape => [BS, SEQ_LEN, DIM_MODEL]
mask: mask for Encoder padded token for speeding up to calculate attention score
"""
layer_output = []
pos_x = torch.arange(self.max_seq).repeat(inputs.shape[0]).to(inputs)
x = self.dropout(
self.scale * inputs + self.positional_embedding(pos_x) # layernorm ์ ์ฉํ๊ณ
)
... ์ค๋ต ...
์ ์ฝ๋๋ Transformer
์ ์ธ์ฝ๋ ๋ชจ๋์ ๊ตฌํํ ๊ฒ์ด๋ค. ๊ทธ๋์ forward
๋ฉ์๋์ pos_x
๊ฐ ๋ฐ๋ก $P_E$๊ฐ ๋๋ฉฐ, __init__
์ self.positional_embedding
์ด ๋ฐ๋ก $W_{P_E}$์ ๋์๋๋ค. ์ด๋ ๊ฒ ์ ์ํ Position Embedding
์ Input Embedding
๊ณผ ๋ํด์ Word Embedding
์ ๋ง๋ ๋ค. Word Embedding
์ ๋ค์ ๊ฐ๋ณ ๋ชจ๋์ linear projection
๋ ์ด์ด์ ๋ํ ์
๋ ฅ $X$๋ก ์ฌ์ฉ ๋๋ค.
ํํธ, Input Embedding
๊ณผ Position Embedding
์ ๋ํ๋ค๋ ๊ฒ์ ์ฃผ๋ชฉํด๋ณด์. ํ์๋ ๋ณธ ๋
ผ๋ฌธ์ ๋ณด๋ฉฐ ๊ฐ์ฅ ์๋ฌธ์ด ๋ค์๋ ๋ถ๋ถ์ด๋ค. ๋๋์ฒด ์ ์์ ํ ์๋ก ๋ค๋ฅธ ์ถ์ฒ์์ ๋ง๋ค์ด์ง ํ๋ ฌ ๋๊ฐ๋ฅผ concat
ํ์ง ์๊ณ ๋ํด์ ์ฌ์ฉํ์๊น?? concat
์ ์ด์ฉํ๋ฉด Input
๊ณผ Position
์ ๋ณด๋ฅผ ์๋ก ๋ค๋ฅธ ์ฐจ์์ ๋๊ณ ํ์ตํ๋๊ฒ ๊ฐ๋ฅํ์ํ
๋ฐ ๋ง์ด๋ค.
๐คย Why Sum instead of Concatenate
ํ๋ ฌํฉ์ ์ฌ์ฉํ๋ ์ด์ ์ ๋ํด ์ ์๊ฐ ํน๋ณํ ์ธ๊ธํ์ง๋ ์์์ ๋๋ฌธ์ ์ ํํ ์๋๋ฅผ ์ ์ ์์ง๋ง, ์ถ์ธกํ๊ฑด๋ฐ blessing of dimensionality
ํจ๊ณผ๋ฅผ ์๋ํ์ง ์์๋ ์ถ๋ค. blessing of dimensionality
๋, ๊ณ ์ฐจ์ ๊ณต๊ฐ์์ ๋ฌด์์๋ก ์๋ก ๋ค๋ฅธ ๋ฒกํฐ ๋๊ฐ๋ฅผ ์ ํํ๋ฉด ๋ ๋ฒกํฐ๋ ๊ฑฐ์ ๋๋ถ๋ถ approximate orthogonality
๋ฅผ ๊ฐ๋ ํ์์ ์ค๋ช
ํ๋ ์ฉ์ด๋ค. ๋ฌด์กฐ๊ฑด ์ฑ๋ฆฝํ๋ ์ฑ์ง์ ์๋๊ณ ํ๋ฅ ๋ก ์ ์ธ ์ ๊ทผ์ด๋ผ๋ ๊ฒ์ ๋ช
์ฌํ์. ์๋ฌดํผ ์ง๊ตํ๋ ๋ ๋ฒกํฐ๋ ๋ด์ ๊ฐ์ด 0์ ์๋ ดํ๋ค. ์ฆ, ๋ ๋ฒกํฐ๋ ์๋ก์๊ฒ ์ํฅ์ ๋ฏธ์น์ง ๋ชปํ๋ค๋ ๊ฒ์ด๋ค. ์ด๊ฒ์ ์ ์ฒด ๋ชจ๋ธ์ hidden states space
์์ Input Embedding
๊ณผ Position Embedding
์ญ์ ๊ฐ๋ณ ๋ฒกํฐ๊ฐ span
ํ๋ ๋ถ๋ถ ๊ณต๊ฐ ๋ผ๋ฆฌ๋ ์๋ก ์ง๊ตํ ๊ฐ๋ฅ์ฑ์ด ๋งค์ฐ ๋๋ค๋ ๊ฒ์ ์๋ฏธํ๋ค. ๋ฐ๋ผ์ ์๋ก ๋ค๋ฅธ ์ถ์ฒ๋ฅผ ํตํด ๋ง๋ค์ด์ง ๋ ํ๋ ฌ์ ๋ํด๋ ์๋ก์๊ฒ ์ํฅ์ ๋ฏธ์น์ง ๋ชปํ ๊ฒ์ด๊ณ ๊ทธ๋ก ์ธํด ๋ชจ๋ธ์ด Input
๊ณผ Position
์ ๋ณด๋ฅผ ๋ฐ๋ก ์ ํ์ตํ ์ ์์ ๊ฒ์ด๋ผ ๊ธฐ๋ํด๋ณผ ์ ์๋ค. ๊ฐ์ ๋๋ก๋ง ๋๋ค๋ฉด, concat
์ ์ฌ์ฉํด ๋ชจ๋ธ์ hidden states space
๋ฅผ ๋๋ ค Computational Overhead
๋ฅผ ์ ๋ฐํ๋ ๊ฒ๋ณด๋ค ํจ์ฌ ํจ์จ์ ์ด๋ผ๊ณ ๋ณผ ์ ์๊ฒ ๋ค.
ํํธ blessing of dimensionality
์ ๋ํ ์ค๋ช
๊ณผ ์ฆ๋ช
์ ๊ฝค๋ ๋ง์ ๋ด์ฉ์ด ํ์ํด ์ฌ๊ธฐ์๋ ์์ธํ ๋ค๋ฃจ์ง ์๊ณ , ๋ค๋ฅธ ํฌ์คํธ์์ ๋ฐ๋ก ๋ค๋ฃจ๊ฒ ๋ค. ๊ด๋ จํ์ฌ ์ข์ ๋ด์ฉ์ ๋ด๊ณ ์๋ ๊ธ์ ๋งํฌ๋ฅผ ๊ฐ์ด ์ฒจ๋ถํ์ผ๋ ์ฝ์ด๋ณด์ค ๊ฒ์ ๊ถํ๋ค(๋งํฌ1, ๋งํฌ2).
๐ Self-Attention with linear projection
์ ์ด๋ฆ์ด self-attention
์ผ๊น ๋จผ์ ๊ณ ๋ฏผํด๋ณด์. ์ฌ์ค attention
๊ฐ๋
์ ๋ณธ ๋
ผ๋ฌธ์ด ๋ฐํ๋๊ธฐ ์ด์ ๋ถํฐ ์ฌ์ฉ๋๋ ๊ฐ๋
์ด๋ค. attention
์ seq2seq
๊ตฌ์กฐ์์ ์ฒ์ ๋์๋๋ฐ, seq2seq
์ ๋ฒ์ญ ์ฑ๋ฅ์ ๋์ด๋ ๊ฒ์ ๋ชฉ์ ์ผ๋ก ๊ณ ์๋ ๊ตฌ์กฐ๋ผ์, ๋ชฉํ์ธ ๋์ฝ๋์ hidden_states
๊ฐ์ ์ฟผ๋ฆฌ๋ก, ์ธ์ฝ๋์ hidden_states
๋ฅผ ํค, ๋ฒจ๋ฅ์ ์ถ์ฒ๋ก ์ฌ์ฉํ๋ค. ์ฆ, ์๋ก ๋ค๋ฅธ ์ถ์ฒ์์ ๋์จ hidden_states
์ ์ฌ์ฉํด ๋ด์ ์ฐ์ฐ์ ์ํํ๋ ๊ฒ์ด๋ค. ์ด๋ฐ ๊ฐ๋
์ ์ด์ โself"
๋ผ๋ ์ด๋ฆ์ด ๋ถ์๋ค. ๊ฒฐ๊ตญ ๊ฐ์ ์ถ์ฒ์์ ๋์จ hidden_states
๋ฅผ ๋ด์ ํ๊ฒ ๋ค๋ ์๋ฏธ๋ฅผ ๋ดํฌํ๊ณ ์๋ ๊ฒ์ด๋ค. ๋ด์ ์ ๋ ๋ฒกํฐ์ โ๋ฎ์ ์ ๋โ
๋ฅผ ์ํ์ ์ผ๋ก ๊ณ์ฐํ๋ค. ๋ฐ๋ผ์ self-attention
์ด๋ ๊ฐ๋จํ๊ฒ, ๊ฐ์ ์ถ์ฒ์์ ๋ง๋ค์ด์ง $Q$(์ฟผ๋ฆฌ), $K$(ํค), $V$(๋ฒจ๋ฅ)๊ฐ ์๋ก ์ผ๋ง๋ ๋ฎ์๋์ง
๊ณ์ฐํด๋ณด๊ฒ ๋ค๋ ๊ฒ์ด๋ค.
self-attention with linear projection
๊ทธ๋ ๋ค๋ฉด ์ด์ $Q$(์ฟผ๋ฆฌ), $K$(ํค), $V$(๋ฒจ๋ฅ)์ ์ ์ฒด, ๊ฐ์ ์ถ์ฒ์์ ๋์๋ค๋ ๋ง์ ์๋ฏธ ๊ทธ๋ฆฌ๊ณ ์
๋ ฅ ํ๋ ฌ $X$๋ฅผ linear projection
ํ์ฌ $Q$(์ฟผ๋ฆฌ), $K$(ํค), $V$(๋ฒจ๋ฅ) ํ๋ ฌ์ ๋ง๋๋ ์ด์ ๋ฅผ ๊ตฌ์ฒด์ ์ธ ์์๋ฅผ ํตํด ์ดํดํด๋ณด์. ์ถ๊ฐ๋ก $Q$(์ฟผ๋ฆฌ), $K$(ํค), $V$(๋ฒจ๋ฅ) ๊ฐ๋
์ Information Retrieval
์์ ๋จผ์ ํ์๋ ๊ฐ๋
์ด๋ผ์ ์์ ์ญ์ ์ ๋ณด ๊ฒ์๊ณผ ๊ด๋ จ๋ ๊ฒ์ผ๋ก ์ค๋นํ๋ค.
๋น์ ์ด ๋ง์ฝ โ์์ด์ปจ ํํฐ ์ฒญ์ํ๋ ๋ฐฉ๋ฒโ
์ด ๊ถ๊ธํด ๊ตฌ๊ธ์ ๊ฒ์ํ๋ ์ํฉ์ด๋ผ๊ณ ๊ฐ์ ํด๋ณด๊ฒ ๋ค. ๋ชฉํ๋ ๊ฐ์ฅ ๋น ๋ฅด๊ณ ์ ํํ๊ฒ ๋ด๊ฐ ์ํ๋ ํํฐ ์ฒญ์ ๋ฐฉ๋ฒ์ ๋ํ ์ง์์ ํ๋ํ๋ ๊ฒ์ด๋ค. ๊ทธ๋ ๋ค๋ฉด ๋น์ ์ ๋ญ๋ผ๊ณ ๊ตฌ๊ธ ๊ฒ์์ฐฝ์ ๊ฒ์ํ ๊ฒ์ธ๊ฐ??
์ด๊ฒ์ด ๋ฐ๋ก $Q$(์ฟผ๋ฆฌ)์ ํด๋นํ๋ค. ๋น์ ์ ๊ฒ์์ฐฝ์ โ์์ด์ปจ ํํฐ ์ฒญ์ํ๋ ๋ฐฉ๋ฒโ
์ ์
๋ ฅํด ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ ๋ฐ์๋ค. ๋ฐํ ๋ฐ์ ๊ฒฐ๊ณผ๋ฌผ์ ์งํฉ์ด ๋ฐ๋ก $K$(ํค)๊ฐ ๋๋ค. ๋น์ ์ ์ด 100๊ฐ์ ๋ธ๋ก๊ทธ ๊ฒ์๋ฌผ์ ํค ๊ฐ์ผ๋ก ๋ฐ์๋ค. ๊ทธ๋์ ๋น์ ์ด ์ฌ์ฉํ๋ ์ผ์ฑ ๋ฌดํ ์์ด์ปจ์ ํํฐ ์ฒญ์๋ฒ์ด ์ ํํ ์ ํ ๊ฒ์๋ฌผ์ ์ฐพ๊ธฐ ์ํด ํ๋ ํ๋ ๋งํฌ๋ฅผ ํ๊ณ ๋ค์ด๊ฐ ๋ณด์๋ค. ํ์ง๋ง ์ ํํ๊ฒ ์ํ๋ ์ ๋ณด๊ฐ ์์ด์ ๊ณ์ ์ฐพ๋ค๋ณด๋ ๊ฒฐ๊ตญ 4ํ์ด์ง ์ฏค์์ ์ํ๋ ์ ๋ณด๊ฐ ๋ด๊ธด ๊ฒ์๋ฌผ์ ์ฐพ์ ์ ์์๋ค. ์ด๋ ๊ฒ ๋ด๊ฐ ์ํ๋ ์ ๋ณด์ธ์ง ์๋์ง ๋์กฐํ๋ ๊ณผ์ ์ด ๋ฐ๋ก $Q$(์ฟผ๋ฆฌ)์ $K$(ํค) ํ๋ ฌ์ ๋ด์
ํ๋ ํ์๊ฐ ๋๋ค. ๊ณง๋ฐ๋ก ์์ด์ปจ ์ฒญ์๋ฅผ ํ๋ ค๊ณ ๋ณด๋, ๋ฐฉ๋ฒ์ ๊น๋จน์ด์ ๋งค๋
์ฌ๋ฆ๋ง๋ค ๊ฒ์์ ํด์ผํ ๊ฒ ๊ฐ์ ํด๋น ๊ฒ์๋ฌผ์ ๋ถ๋งํฌ์ ์ ์ฅํด๋์๋ค. ์ฌ๊ธฐ์ ๋ถ๋งํฌ๊ฐ ๋ฐ๋ก $V$(๋ฒจ๋ฅ) ํ๋ ฌ์ด ๋๋ค.
์ด ๋ชจ๋ ๊ณผ์ ์ 10๋ถ์ด ๊ฑธ๋ ธ๋ค. ๊ฒจ์ฐ ํํฐ ์ฒญ์ ๋ฐฉ๋ฒ์ ์ฐพ๋๋ฐ 10๋ถ์ด๋ผ๋ ๋น์ ์ ์์กด์ฌ์ด ์ํ๋ค. ๋ ๋นจ๋ฆฌ ์ํ๋ ์ ๋ณด(์์ค ํจ์ ์ต์ ํ)
๋ฅผ ์ฐพ์ ์ ์๋ ๋ฐฉ๋ฒ์ด ์์๊น ๊ณ ๋ฏผํด๋ณด๋ค๊ฐ ๋น์ ์ด ์ฌ์ฉํ๋ ์์ด์ปจ ๋ธ๋๋๋ช
(์ผ์ฑ Bespoke ์์ด์ปจ)์ ๊ฒ์์ด์ ์ถ๊ฐํ๊ธฐ๋ก ํ๋ค
. ๊ทธ๋ฌ๋๋ 1ํ์ด์ง ์ตํ๋จ์์ ์๊น 4ํ์ด์ง์์ ์ฐพ์ ์ ๋ณด๋ฅผ ๊ณง๋ฐ๋ก ์ฐพ์ ์ ์์๋ค. ๊ทธ ๋๋ถ์ ์๊ฐ์ 10๋ถ
์์ 1๋ถ 30์ด
๋ก ๋จ์ถ์ํฌ ์ ์์๋ค. ์ด๋ ๊ฒ ๊ฒ์ ์๊ฐ์ ๋จ์ถ(์์ค ์ค์ด๊ธฐ)ํ๊ธฐ ์ํด ๋ ๋์ ๊ฒ์ ํํ์ ๊ณ ๋ฏผํ๊ณ ์์ ํ๋ ํ์๊ฐ ๋ฐ๋ก ์
๋ ฅ $X$์ $W_{Q}$๋ฅผ ๊ณฑํด ํ๋ ฌ $Q$ ์ ๋ง๋๋ ์์์ผ๋ก ํํ๋๋ค.
1๋
๋ค ์ฌ๋ฆ, ๋น์ ์ ๋ธ๋ผ์ฐ์ ๋ฅผ ๋ฐ๊พผ ํ์ ๋ถ๋งํฌ๊ฐ ์ด๊ธฐํ ๋์ด ๋ค์ ํ ๋ฒ ๊ฒ์์ ํด์ผ ํ๋ค. ํ์ง๋ง ์ฌ์ ํ ๊ฒ์์ด๋ ๊ธฐ์ตํ๊ณ ์์ด์, 1๋
์ ์ต์ ์ ๊ฒฐ๊ณผ๋ฅผ ์ป์๋ ๊ทธ๋๋ก ๋ค์ ๊ฒ์์ ํ๋ค. ๋ถ๋ช
๋๊ฐ์ด ๊ฒ์์ ํ๋๋ฐ ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ 1ํ์ด์ง ์ต์๋จ์์ ๋ฐํ๋๊ณ ์์๋ค. ๋น์ ์ ์ด๊ฒ ์ด๋ป๊ฒ ๋ ์ผ์ธ์ง ๊ถ๊ธํด ํฌ์คํธ๋ฅผ ์ฒ์ฒํ ๋ณด๋ ์ค, ์ ๋ชฉ์ 1๋
์ ์๋ ์๋ ์ผ์ฑ Bespoke ์์ด์ปจ
์ด๋ผ๋ ํค์๋๊ฐ ํฌํจ ๋์ด ์์๋ค. ๊ฒ์๋ฌผ์ ์ฃผ์ธ์ฅ์ด SEO ์ต์ ํ
๋ฅผ ์ํด ์ถ๊ฐํ๋ ๊ฒ์ด์๋ค. ๋๋ถ์ ๋น์ ์ ์์ ์๊ฐ์ 1๋ถ 30์ด
์์ 20์ด
๋ก ์ค์ผ ์ ์์๋ค. ์ด๋ฐ ์ํฉ์ด ๋ฐ๋ก ์
๋ ฅ $X$์ $W_{K}$๋ฅผ ๊ณฑํด ํ๋ ฌ $K$ ๋ฅผ ๋ง๋๋ ์์์ ๋์๋๋ค.
์ฐ๋ฆฌ๋ ์ ์์๋ฅผ ํตํด ์ํ๋ ์ ๋ณด๋ฅผ ๋น ๋ฅด๊ณ ์ ํํ๊ฒ ์ฐพ๋ ํ์๋, ๋ต๋ณ์๊ฐ ์ดํดํ๊ธฐ ์ข์ ์ง๋ฌธ๊ณผ ์ง๋ฌธ์์ ์ง๋ฌธ ์๋์ ๋ถํฉํ๋ ์ข์ ๋ต๋ณ์ผ๋ก ์์ฑ๋๋ค๋ ๊ฒ์ ์ ์ ์์๋ค. ๋ฟ๋ง ์๋๋ผ, ์ข์ ์ง๋ฌธ๊ณผ ์ข์ ๋ต๋ณ์ด๋ผ๋ ๊ฒ์ ์ฒ์๋ถํฐ ์์ฑ๋๋๊ฒ ์๋๋ผ ๊ฒ์ ์๊ฐ์ ๋จ์ถํ๋ ค๋ ๋์์๋ ๋
ธ๋ ฅ์ ํตํด ์ฑ์ทจ๋๋ค๋ ๊ฒ ์ญ์ ๊นจ์ฐ์ณค๋ค. ๋๊ฐ์ง ์ธ์ฌ์ดํธ๊ฐ ๋ฐ๋ก linear projection
์ผ๋ก ํ๋ ฌ $Q, K,V$์ ์ ์ํ ์ด์ ๋ค. ๋ด๊ฐ ์ํ๋ ์ ๋ณด์ธ์ง ์๋์ง ๋์กฐํ๋ ๋ด์ ์ฐ์ฐ์ ์ํํ๋๋ฐ ๊ฐ์ค์น ํ๋ ฌ์ด ํ์ ์๊ธฐ ๋๋ฌธ์ ์์คํจ์์ ์ค์ฐจ ์ญ์ ์ ํ์ฉํ ์์น ์ต์ ํ๋ฅผ ์ํํ ์ ์๋ค. ๊ทธ๋์ ์์คํจ์ ๋ฏธ๋ถ์ ์ํ ์ต์ ํ๊ฐ ๊ฐ๋ฅํ๋๋ก linear projection matrix
๋ฅผ ํ์ฉํด ํ๋ ฌ $Q, K,V$๋ฅผ ์ ์ํด์ค ๊ฒ์ด๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๋ชจ๋ธ์ด ์ฐ๋ฆฌ์ ๋ชฉ์ ์ ๊ฐ์ฅ ์ ํฉํ ์ง๋ฌธ๊ณผ ๋ต๋ณ์ ์์์ ํํ ํด์ค ๊ฒ์ด๋ผ ๊ธฐ๋ํ ์ ์๊ฒ ๋๋ค. ํํธ, ๊ฐ์ ์ถ์ฒ์์ ๋์๋ค๋ ๋ง์ ๋ฐฉ๊ธ ์์์์ ํ๋ ฌ $Q, K,V$๋ฅผ ๋ง๋๋๋ฐ ๋์ผํ๊ฒ ์
๋ ฅ $X$๋ฅผ ์ฌ์ฉ ๊ฒ๊ณผ ๊ฐ์ ์ํฉ์ ์๋ฏธํ๋ค.
์ด์ ๋ค์ ์์ฐ์ด ์ฒ๋ฆฌ ๋งฅ๋ฝ์ผ๋ก ๋์์๋ณด์. Transformer
๋ ์ข์ ๋ฒ์ญ๊ธฐ๋ฅผ ๋ง๋ค๊ธฐ ์ํด ๊ณ ์๋ seq2seq
๊ตฌ์กฐ์ ๋ชจ๋ธ์ด๋ค. ์ฆ, ๋น ๋ฅด๊ณ ์ ํํ๊ฒ ๋์ ์ธ์ด์์ ํ๊ฒ ์ธ์ด๋ก ๋ฒ์ญํ๋ ๊ฒ์ ๋ชฉํ๋ฅผ ๋๊ณ ๋ง๋ค์ด์ก๋ค๋ ๊ฒ์ด๋ค. ๋ฒ์ญ์ ์ํ๊ธฐ ์ํด์๋ ์ด๋ป๊ฒ ํด์ผ ํ ๊น?? 1) ๋์ ์ธ์ด๋ก ์ฐ์ธ ์ํ์ค์ ์๋ฏธ๋ฅผ ์ ํํ๊ฒ ํ์
ํด์ผ ํ๊ณ , 2) ํ์
ํ ์๋ฏธ์ ๊ฐ์ฅ ์ ์ฌํ ์ํ์ค๋ฅผ ํ๊ฒ ์ธ์ด๋ก ๋ง๋ค์ด ๋ด์ผ ํ๋ค. ๊ทธ๋์ 1๋ฒ์ ์ญํ ์ Encoder๊ฐ ๊ทธ๋ฆฌ๊ณ 2๋ฒ์ Decoder๊ฐ ๋งก๊ฒ ๋๋ค
. ์ธ์ฝ๋๋ ๊ฒฐ๊ตญ (๋ฒ์ญํ๋๋ฐ ์ ํฉํ ํํ๋ก) ๋์ ์ธ์ด ์ํ์ค์ ์๋ฏธ๋ฅผ ์ ํํ ์ดํดํ๋ ๋ฐฉํฅ(์ซ์๋ก ํํ, ์๋ฒ ๋ฉ ์ถ์ถ)์ผ๋ก ํ์ต์ ์ํํ๊ฒ ๋๋ฉฐ, ๋์ฝ๋๋ ์ธ์ฝ๋์ ํ์ต ๊ฒฐ๊ณผ์ ๊ฐ์ฅ ์ ์ฌํ ๋ฌธ์ฅ์ ํ๊ฒ ์ธ์ด๋ก ์์ฑํด๋ด๋ ๊ณผ์ ์ ๋ฐฐ์ฐ๊ฒ ๋๋ค. ๋ฐ๋ผ์ ์ธ์ฝ๋๋ ๋์ ์ธ์ด๋ฅผ ์ถ์ฒ๋ก, ๋์ฝ๋๋ ํ๊ฒ ์ธ์ด๋ฅผ ์ถ์ฒ๋ก ํ๋ ฌ $Q, K,V$๋ฅผ ๋ง๋ ๋ค. ์ ํํ self
๋ผ๋ ๋จ์ด๋ฅผ ์ด๋ฆ์ ๊ฐ๋ค ๋ถ์ธ ์๋์ ์ผ๋งฅ์ํตํ๋ ๋ชจ์ต์ด๋ค.
๊ฒฐ๊ตญ Transformer
์ ์ฑ๋ฅ์ ์ข์ง์ฐ์ง ํ๋ ๊ฒ์ ๋๊ฐ ์ผ๋ง๋ ๋ linear projection weight
์ ์ ์ต์ ํ ํ๋๊ฐ์ ๋ฌ๋ ธ๋ค๊ณ ๋ณผ ์ ์๋ค.
ํํธ ํ์๋ ์ฒ์ ์ด ๋
ผ๋ฌธ์ ์ฝ์์ ๋ linear projection
์์ฒด์ ํ์์ฑ์ ๊ณต๊ฐํ์ผ๋, ๊ตณ์ด 3๊ฐ์ ํ๋ ฌ๋ก ๋๋ ์ train
์์ผ์ผ ํ๋ param
์ซ์๋ฅผ ๋๋ฆฌ๋ ๊ฒ๋ณด๋ค๋ weight share
ํ๋ ํํ๋ก ๋ง๋๋๊ฒ ๋ ํจ์จ์ ์ผ ๊ฒ ๊ฐ๋ค๋ ์ถ์ธก์ ํ์๋ค.
๊ทธ๋ฌ๋ ์ด๋ฒ ๋ฆฌ๋ทฐ๋ฅผ ์ํด ๋ค์ ๋ ผ๋ฌธ์ ์ฝ๋ ์ค, ์ข์ ์ง๋ฌธ์ ํ๊ธฐ ์ํ ๋ ธ๋ ฅ๊ณผ ์ข์ ๋ต๋ณ์ ํ๊ธฐ ์ํ ๋ ธ๋ ฅ, ๊ทธ๋ฆฌ๊ณ ํ์ํ ์ ๋ณด๋ฅผ ์ ํํ ์ถ์ถํด๋ด๋ ํ์๋ฅผ ๊ฐ๊ฐ ์๋ก ๋ค๋ฅธ 3๊ฐ์ ๋ฒกํฐ๋ก ํํํ์ ๋ ๋ฒกํฐ๋ค์ด ๊ฐ์ง๋ ๋ฐฉํฅ์ฑ์ด ์๋ก ๋ค๋ฅผํ ๋ฐ ๊ทธ๊ฒ์ ํ๋์ ๋ฒกํฐ๋ก ํํํ๋ ค๋ฉด ๋ชจ๋ธ์ด ํ์ต์ ํ๊ธฐ ํ๋ค ๊ฒ ๊ฐ๋ค๋ ์๊ฐ์ด ๋ค์๋ค. ๋ฐฉ๊ธ ์์์ ๋ ์์๋ง ๋ด๋ ๊ทธ๋ ๋ค. ์๋ก ๋ค๋ฅธ 3๊ฐ์ ํ์ ์ฌ์ด์ ์ต์ ์ง์ ์ ์ฐพ์ผ๋ผ๋ ๊ฒ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ฐ ๊ทธ๋ฐ ์คํ์ด ์๋ค๊ณ ํด๋ ์ธ์ด ๋ชจ๋ธ์ด ์ ์ฐพ์ ์ ์์๊น?? ์ธ๊ฐ๋ ์ฐพ๊ธฐ ํ๋ ๊ฒ์ ๋ชจ๋ธ์ด ์ ์ฐพ์๋ฆฌ๊ฐ ์๋ค.
๐ย Scaled Dot-Product Attention
\[Attention(Q,K,V) = softmax(\frac{QยทK^T}{\sqrt{d_k}})V\]
์ด๋ฒ์๋ Self-Attention
์ ๋ ๋ฒ์งธ ํ์ ๋ธ๋ญ์ธ Scaled Dot-Product Attention
์ฐจ๋ก๋ค. ์ฌ์ค ์ฐ๋ฆฌ๋ Linear Projection
ํํธ์์ ์ด๋ฏธ ์ฐ๋ฆฌ๋ ๋ชจ๋ฅด๊ฒ Scaled Dot-Product Attention
์ ๋ํด ๊ณต๋ถํ๋ค. ์์๋ฅผ ๋ค์ ํ ๋ฒ ์๊ธฐ์์ผ๋ณด์. ์ง์๋ฅผ ํตํด ์ป์ ๊ฒฐ๊ณผ ๋ฆฌ์คํธ(ํค)์์ ๋ด๊ฐ ์ํ๋ ์ ๋ณด๋ฅผ ์ฐพ๊ธฐ ์ํด ์ฟผ๋ฆฌ์ ํค๋ฅผ ๋์กฐํ๋ค๊ณ ํ๋ ๊ฒ ๊ธฐ์ต๋๋๊ฐ?? ๋ฐ๋ก ๊ทธ ๋์กฐํ๋ ํ์๋ฅผ ์ํ์ ์ผ๋ก ๋ชจ๋ธ๋งํ ๊ฒ์ด ๋ฐ๋ก Scaled Dot-Product Attention
์ ํด๋นํ๋ค.
Scaled Dot-Product Attention
์ ์ด 5๋จ๊ณ๋ฅผ ๊ฑฐ์ณ ์์ฑ๋๋ค. ๋จ๊ณ๋ง๋ค ์ด๋ค ์ฐ์ฐ์ ์ ํ๋์ง ๊ทธ๋ฆฌ๊ณ ๋ฌด์จ ์ธ์ฌ์ดํธ๊ฐ ๋ด๊ฒจ ์๋์ง ์์๋ณด์. ์ด ์ค์์ ๋ง์คํน ๋จ๊ณ๋ ์ธ์ฝ๋์ ๋์ฝ๋์ ๋์์ ์์ธํ ์์์ผํ๊ธฐ ๋๋ฌธ์ ์ ์ฒด์ ์ธ ๊ตฌ์กฐ ๊ด์ ์์ ๋ชจ๋ธ์ ๋ฐ๋ผ๋ณผ ๋ ํจ๊ป ์ค๋ช
ํ๋๋ก ํ๊ฒ ๋ค.
โ๏ธย Stage 1. QโขK^T Dot-Product
์ธ๊ฐ์ ๋ฌธ์ฅ์ด๋ ์ด๋ค ํํ์ ์๋ฏธ๋ฅผ ํ์
ํ๋๋ฐ ๋ฐ๋ก ์ฃผ๋ณ ๋งฅ๋ฝ์ ์ฐธ๊ณ ํ๊ฑฐ๋, ๋ ๋ฉ๋ฆฌ ๋จ์ด์ง ๊ณณ์ ๋จ์ดโข์ํ์ค๋ฅผ ์ด์ฉํ๊ธฐ๋ ํ๋ค. ์ฆ, ์ฃผ์ด์ง ์ํ์ค ๋ด๋ถ์ ๋ชจ๋ ๋งฅ๋ฝ์ ์ด์ฉํด ํน์ ๋ถ๋ถ์ ์๋ฏธ๋ฅผ ์ดํดํ๋ค๋ ๊ฒ์ด๋ค. ๊ทธ๋ ๋ค๊ณ ๋ชจ๋ ์ ๋ณด๊ฐ ๋์ผํ๊ฒ ํน์ ํํ์ ์๋ฏธ์ ์ํฅ์ ๋ฏธ์น๋ ๊ฒ์ ๋ ์๋๋ฐ, ์๋ฅ ์์ด์ ํฌ๋ฌ ๋ฌธํญ์ผ๋ก ๋ฑ์ฅํ๋ ๋น์นธ ์ฑ์ฐ๊ธฐ ๋ฌธ์ ๋ฅผ ์ด๋ป๊ฒ ํ์๋ ๋ ์ฌ๋ ค๋ณด์. ๋ํ
์ผํ ํ์ด ๋ฐฉ์์๋ ์ฌ๋๋ง๋ค ์ฐจ์ด๊ฐ ์๊ฒ ์ง๋ง, ์ผ๋ฐ์ ์ผ๋ก ์ง๋ฌธ์ ๋ชจ๋ ํ์ด ๋ณด๋ ๋น์นธ์ ๋ค์ด๊ฐ ์ ๋ต์ ๊ทผ๊ฑฐ๊ฐ ๋๋ ํน์ ๋ฌธ์ฅ ํน์ ํํ 1~2๊ฐ๋ฅผ ์ฐพ์๋ด์ด ๋น์ทํ ์๋ฏธ๋ฅผ ์ง๋ ์ ์ง๋ฅผ ๊ณจ๋ผ ๋ด๋ ๋ฐฉ์์ ์ฌ์ฉํ๋ค. ๋ค์ ๋งํด, ์ฃผ์ด์ง ์ ์ฒด ๋จ๋ฝ์์ ์๋ฏธ๋ฅผ ์ดํดํ๋๋ฐ ์ค์ํ ์ญํ ์ ํ๋ ํํ์ด๋ ๋ฌธ์ฅ์ ๊ณจ๋ผ๋ด์ด ์ค์๋
๋งํผ ๊ฐ์ค์น
๋ฅผ ์ฃผ๊ฒ ๋ค๋ ๊ฒ์ด๋ค.
QโขK^T Dot Product Visualization
๊ทธ๋ ๋ค๋ฉด ์ด๊ฒ์ ์ด๋ป๊ฒ ์ํ์ ์ผ๋ก ๋ชจ๋ธ๋งํ์๊น?? ๋ฐ๋ก ํ๋ ฌ $Q$์ $K^T$์ ๋ด์
์ ํ์ฉํ๋ค. ํ๋ ฌ $Q$๋ ๋ชจ๋ธ์ด ์๋ฏธ๋ฅผ ํ์
ํด์ผ ํ๋ ๋์์ด ๋ด๊ฒจ ์๊ณ , ํ๋ ฌ $K$์๋ ์๋ฏธ ํ์
์ ํ์ํ ๋จ์๋ค์ด ๋ด๊ฒจ์๋ค. ๋ด์ ์ ๋ ๋ฒกํฐ์ ์๋ก โ๋ฎ์ ์ ๋โ
๋ฅผ ์๋ฏธํ๋ค๊ณ ํ๋ค. โ๋ฎ์ ์ ๋โ
๊ฐ ๋ฐ๋ก ์ค์๋โข๊ฐ์ค์น
์ ๋์๋๋ค. ๋ฐ๋ผ์ ์ฐ์ฐ ๊ฒฐ๊ณผ๋ ์ ์ฒด ์ํ์ค์ ์ํ ํ ํฐ๋ค ์ฌ์ด์ โ๋ฎ์ ์ ๋โ
๊ฐ ์์น๋ก ๋ณํ๋์ด ํ๋ ฌ์ ๋ด๊ธด๋ค.
์ ๋ด์ ๊ฒฐ๊ณผ
๊ฐ ์ค์๋
์ ๊ฐ์ ์๋ฏธ๋ฅผ ๊ฐ๊ฒ ๋๋ ๊ฒ์ผ๊น?? ์๊น Input Embedding
๊ณผ Position Embedding
์ ํ๋ ฌํฉ ํ๋ ๊ฒ์ ๋ํ ๋น์์ฑ์ ์ค๋ช
ํ๋ฉด์ ๊ณ ์ฐจ์์ผ๋ก ๊ฐ์๋ก ๋๋ถ๋ถ์ ๋ฒกํฐ ์์ ์ง๊ต์ฑ
์ ๊ฐ๊ฒ ๋๋ค๊ณ ์ธ๊ธํ ๋ฐ ์๋ค. ๊ทธ๋์ ๋ ๋ฒกํฐ๊ฐ ๋น์ทํ ๋ฐฉํฅ์ฑ์ ๊ฐ๋๋ค๋ ๊ฒ ์์ฒด๊ฐ ๋งค์ฐ ๋๋ฌธ์ผ์ด๋ค. ํฌ๊ทํ๊ณ ๋๋ฌธ ์ฌ๊ฑด์ ๊ทธ๋งํผ ์ค์ํ๋ค๊ณ ๋งํ ์ ์๊ธฐ ๋๋ฌธ์ ๋ด์ ๊ฒฐ๊ณผ
๋ฅผ ์ค์๋
์ ๋งตํํ๋ ๊ฒ์ด๋ค.
ํํธ, ํ๋ ฌ $Q,K$ ๋ชจ๋ ์ฐจ์์ด [Batch, Max_Seq, Dim_Head]
์ธ ํ
์๋ผ์ ๋ด์ ํ ๊ฒฐ๊ณผ์ ๋ชจ์์ [Batch, Max_Seq, Max_Seq]
์ด ๋ ๊ฒ์ด๋ค.
๐ญย Stage 2. Scale
โI am dogโ
๋ผ๋ ๋ฌธ์ฅ์ $QโขK^T$ํ๋ฉด ์์ ๊ฐ์ 3x3
์ง๋ฆฌ ํ๋ ฌ์ด ๋์ฌ ๊ฒ์ด๋ค. ํ๋ ฌ์ ํ๋ฒกํฐ๋ก ๋ฐ๋ผ๋ณด์. ํ ์ฌ์ด์ ๊ฐ์ ๋ถํฌ๊ฐ ๊ณ ๋ฅด์ง ๋ชปํ๋ค๋ ๊ฒ์ ์ ์ ์๋ค. ์ด๋ ๊ฒ ๋ถ์ฐ์ด ํฐ ์ํ๋ก softmax
์ ํต๊ณผ์ํค๊ฒ ๋๋ฉด ์ญ์ ํ ๊ณผ์ ์์ softmax
์ ๋ฏธ๋ถ๊ฐ์ด ์ค์ด ๋ค์ด ํ์ต ์๋๊ฐ ๋๋ ค์ง๊ณ ๋์๊ฐ vanishing gradient
ํ์์ด ๋ฐ์ํ ์ ์๋ค. ๋ฐ๋ผ์ ํ๋ฒกํฐ ์ฌ์ด์ ๋ถ์ฐ์ ์ค์ฌ์ฃผ๊ธฐ ์ํด์ Scale Factor
๋ฅผ ์ ์ํ๊ฒ ๋๋ค. ๊ทธ๋ ๋ค๋ฉด ์ด๋ค Scale Factor
๋ฅผ ์จ์ผํ ๊น??
์ ์ด์ Dim Head
์ฐจ์์ ์ํ ๊ฐ๋ค์ ๋ถ์ฐ์ด ํฐ ๊ฒ๋ ๋ฌธ์ ๊ฐ ๋์ง๋ง ์ด๊ฒ์ Input Embedding
์ด๋ Position Embedding
์ layernorm
์ ์ ์ฉํ๋ฉด ํด๊ฒฐํ ์ ์๊ธฐ ๋๋ฌธ์ ๋
ผ์ ๋์์ด ์๋๋ค. ๊ทธ๊ฒ๋ณด๋ค๋ ๋ด์ ๊ณผ์ ์ ์ฃผ๋ชฉํด๋ณด์. ์ฐ๋ฆฌ๋ ๋ด์ ์ ํ๋ค๋ณด๋ฉด Dim Head
์ ์ฐจ์์ด ์ปค์ง์๋ก ๋ํด์ค์ผ ํ๋ ์ค์นผ๋ผ ๊ฐ์ ๊ฐ์๊ฐ ๋์ด๋๊ฒ ๋๋ค๋ ์ฌ์ค์ ์ ์ ์๋ค. ๋ง์ฝ ์์์ ์์๋ก ๋ ์์์ Dim Head
๊ฐ 64๋ผ๊ณ ๊ฐ์ ํด๋ณด์. ๊ทธ๋ผ ์ฐ๋ฆฌ๋ 1ํ 1์ด์ ๊ฐ์ ์ป๊ธฐ ์ํด 64๊ฐ์ ์ค์นผ๋ผ ๊ฐ์ ๋ํด์ค์ผ ํ๋ค. ๋ง์ฝ 512
์ฐจ์์ด๋ผ๋ฉด 512
๊ฐ๋ก ๋ถ์ด๋๋ค. ๋ํด์ค์ผ ํ๋ ์ค์นผ๋ผ ๊ฐ์ด ๋ง์์ง๋ค๋ฉด ํ๋ฒกํฐ ๋ผ๋ฆฌ์ ๋ถ์ฐ์ด ์ปค์ง ์ฐ๋ ค๊ฐ ์๋ค. ๋ฐ๋ผ์ ์ฐจ์ ํฌ๊ธฐ์ ์ค์ผ์ผ์ ๋ฐ๋ผ softmax
์ ๋ฏธ๋ถ๊ฐ์ด ์ค์ด๋๋ ๊ฒ์ ๋ฐฉ์งํ๊ธฐ ์ํด $QโขK^T$๊ฒฐ๊ณผ์ $\sqrt{d_h}$๋ฅผ ๋๋ ์ค๋ค.
์ฌ๋ด์ผ๋ก ์ด๋ฌํ scale factor
์ ์กด์ฌ ๋๋ฌธ์ Self-Attention
์ Scaled Dot-Product Attention
์ด๋ผ๊ณ ๋ถ๋ฅด๊ธฐ๋ ํ๋ค.
๐ญย Stage 3. masking
๋ง์คํน์ ์ธ์ฝ๋ Input Padding
, ๋์ฝ๋ Masked Multi-Head Attention
, ์ธ์ฝ๋-๋์ฝ๋ Self-Attention
์ ์ํด ํ์ํ ๊ณ์ธต์ด๋ค. ๋ค์ ๋๊ฐ๋ ๋์ฝ๋์ ๋์์ ์์์ผ ์ดํด๊ฐ ๊ฐ๋ฅํ๊ธฐ ๋๋ฌธ์ ์ฌ๊ธฐ์๋ ์ธ์ฝ๋์ ๋ง์คํน์ ๋ํด์๋ง ์์๋ณด์.
์ค์ ํ
์คํธ ๋ฐ์ดํฐ๋ ๋ฐฐ์น๋ ์ํ์ค๋ง๋ค ๊ทธ ๊ธธ์ด๊ฐ ์ ๊ฐ๊ฐ์ด๋ค. ํจ์จ์ฑ์ ์ํด ํ๋ ฌ์ ์ฌ์ฉํ๋ ์ปดํจํฐ ์ฐ์ฐ ํน์ฑ์ ๋ฐฐ์น๋ ์ํ์ค์ ๊ธธ์ด๊ฐ ๋ชจ๋ ๋ค๋ฅด๋ค๋ฉด ์ฐ์ฐ์ ์งํํ ์๊ฐ ์๋ค. ๋ฐ๋ผ์ ๋ฐฐ์น ๋ด๋ถ์ ๋ชจ๋ ์ํ์ค์ ๊ธธ์ด๋ฅผ ํต์ผํด์ฃผ๋ ์์
์ ํ๊ฒ ๋๋๋ฐ, ์ด ๋ ๊ธฐ์ค ๊ธธ์ด๋ณด๋ค ์งง์ ์ํ์ค์ ๋ํด์๋ 0
๊ฐ์ ์ฑ์๋ฃ๋ padding
์์
์ ํ๋ค. ํ๋ ฌ ์ฐ์ฐ์๋ ๊ผญ ํ์ํ๋ padding
์ ์คํ๋ ค softmax
๋ ์ด์ด๋ฅผ ๊ณ์ฐํ ๋ ๋ฐฉํด๊ฐ ๋๋ค. ๋ฐ๋ผ์ ๋ชจ๋ padding
๊ฐ์ softmax
์ ํ๋ฅ ๊ณ์ฐ์์ ์์ ํ ์ ์ธ์ํค๊ธฐ ์ํด Input Embedding
์์ padding token
์ ์ธ๋ฑ์ค๋ฅผ ์ ์ฅํ๊ณ ํด๋น๋๋ ๋ชจ๋ ์์๋ฅผ -โ
๋ก ๋ง์คํนํ๋ ๊ณผ์ ์ด ํ์ํ๋ค.
์ด ๋ ๋ง์คํน ์ฒ๋ฆฌ๋ ์ด๋ฒกํฐ์๋ง ์ ์ฉํ๋ค. ๊ทธ ์ด์ ๋ ๋ฐ๋ก softmax
๊ณ์ฐ์ ์ด์ฐจํผ ํ๋ฒกํฐ ๋ฐฉํฅ์ผ๋ก๋ง ํ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ด๋ค. ํ๋ฒกํฐ ๋ฐฉํฅ์ padding token
์๋ ๋์ผํ๊ฒ ๋ง์คํน ์ ์ฉํ๋ ๊ฒ์ ์๊ด ์์ผ๋ ์ด๋ฒกํฐ์ ํ๋ฒกํฐ ๋์์ ๋ง์คํน ์ ์ฉํ๋ ๋์์ ๊ตฌํํ๋ ๊ฒ์ ์๊ฐ๋ณด๋ค ๋ง์ด ๊น๋ค๋ก์ฐ๋ฉฐ, ๋์ค์ ์์ค๊ฐ ๊ณ์ฐํ๋ ๋จ๊ณ์์ ignore_index
์ต์
์ ์ฌ์ฉํด ํ๋ฒกํฐ์ padding token
์ ๋ฌด์ํ๋ ๊ฒ์ด ํจ์ฌ ํจ์จ์ ์ด๋ค. ํํธ, ignore_index
์ต์
์ nn.CrossEntropyLoss
์ ๋งค๊ฐ๋ณ์๋ก ๊ตฌํ ๋์ด ์๋ค.
๐ย Stage 4. Softmax & ScoreโขV
๊ณ์ฐ๋ ์ ์ฌ๋(๋ด์ ๊ฒฐ๊ณผ, ์ค์๋, ๊ฐ์ค์น)
, $\frac{QโขK^T}{\sqrt{d_h}}$๋ ์ดํ์ ํ๋ ฌ $V$์ ๋ค์ ๊ณฑํด์ ธ ํ๋ฒกํฐ $Z_n$(n๋ฒ์งธ ํ ํฐ)์์ ํ ํฐ์ ๋ํ ์ดํ
์
์ ๋๋ฅผ ๋ํ๋ด๋ ๊ฐ์ค์น
์ ์ญํ ์ ํ๊ฒ ๋๋ค. ๊ทธ๋ฌ๋ ๊ณ์ฐ๋ ์ ์ฌ๋๋ ๋น์ ๊ทํ๋ ํํ๋ค. ์์์๋ ํธ์์ ์ด๋ฏธ softmax
๋ฅผ ์ ์ฉํ ํํ์ ํ๋ ฌ์ ์ ์์ง๋ง, ์ค์ ๋ก๋ ์์๊ฐ์ ๋ถ์ฐ์ด ๋๋ฌด ์ปค์ ๊ฐ์ค์น๋ก๋ ์ฐ๊ธฐ ํ๋ ์์ค์ด๋ค. ๋ฐ๋ผ์ ํ๋ฒกํฐ ๋จ์๋ก softmax
์ ํต๊ณผ์์ผ ๊ฒฐ๊ณผ์ ํฉ์ด 1์ธ ํ๋ฅ ๊ฐ์ผ๋ก ๋ณํ(์ ๊ทํ)
ํด ํ๋ ฌ $V$์ ๊ฐ์ค์น๋ก ์ฌ์ฉํ๋ค.
์ด์ ๋๋ฒ์งธ ์์์ ๋ณด์. $Score_{11}$์ ํด๋นํ๋ 0.90
๊ฐ ํ๋ ฌ $V$์ ์ฒซ๋ฒ์งธ ํ๋ฒกํฐ์ ๊ณฑํด์ง๊ณ ์๋ค. ํ๋ ฌ $V$์ ์ฒซ๋ฒ์งธ ํ๋ฒกํฐ๋ ํ ํฐ โIโ
๋ฅผ 512
์ฐจ์์ผ๋ก ํํํ ๊ฒ์ด๋ค. ๊ทธ ๋ค์ $Score_{12}$๋ ํ๋ ฌ $V$์ ๋๋ฒ์งธ ํ๋ฒกํฐ์, $Score_{13}$์ ํ๋ ฌ $V$์ ์ธ๋ฒ์งธ ํ๋ฒกํฐ์ ๊ฐ๊ฐ ๊ณฑํด์ง๋ค.
์ด ํ์์ ์๋ฏธ๋ ๋ฌด์์ผ๊น?? $Score_{11}$, $Score_{12}$, $Score_{13}$์ ๋ชจ๋ ์ฒซ๋ฒ์งธ ํ ํฐ์ธ โIโ
์ ์๋ฏธ๋ฅผ ํ์
ํ๋๋ฐ โIโ
, โamโ
, โdogโ
๋ฅผ ์ด๋ ์ ๋๋ก ์ดํ
์
ํด์ผ ํ๋์ง, ์ฆ โIโ
์ ์๋ฏธ๋ฅผ ํํํ๋๋ฐ ์ธ ํ ํฐ์ ์๋ฏธ๋ฅผ ์ด๋ ์ ๋ ๋ฐ์ํ ์ง ์์น๋ก ํํํ ๊ฒ์ด๋ค. ๋น์ฐํ ์๊ธฐ ์์ ์ธ โIโ
์ ๊ฐ์ค์น(์ ์ฌ๋, ์ค์๋)
๊ฐ ๊ฐ์ฅ ๋๊ธฐ ๋๋ฌธ์ ํ๋ ฌ $V$์์ โIโ
์ ํด๋นํ๋ ํ๋ฒกํฐ ๊ฐ์ค์น์ ๊ฐ์ฅ ํฐ ๊ฐ์ด ๋ค์ด๊ฐ๋ค๊ณ ์๊ฐํด๋ณผ ์ ์๋ค. ์ด๋ ๊ฒ ๊ฐ ํ ํฐ๋ง๋ค ๊ฐ์คํฉ์ ๋ฐ๋ณตํด์ฃผ๋ฉด ์ต์ข
์ ์ผ๋ก โIโ
, โamโ
, โdogโ
์ ์ธ์ฝ๋ฉํ $Z_1, \ Z_2, \ Z_3$ ๊ฐ์ ์ป์ ์ ์๋ค.
๐ฉโ๐ปย Implementation
์ด๋ ๊ฒ Scaled Dot-Product Attention
์ ๋ชจ๋ ์ดํด๋ณด์๋ค. ํด๋น ๋ ์ด์ด๋ ๋ชจ๋ธ์ด ์์ค๊ฐ์ด ๊ฐ์ฅ ์์์ง๋ ๋ฐฉํฅ์ผ๋ก ์ต์ ํํ ํ๋ ฌ $Q, K, V$ ์ ์ด์ฉํด, ํ ํฐ์ ์๋ฏธ๋ฅผ ์ดํดํ๋๋ฐ ์ด๋ค ๋งฅ๋ฝ๊ณผ ํํ์ ์ข ๋ ์ง์คํ๊ณ ๋ ์ง์คํด์ผ ํ๋์ง๋ฅผ ์ ์ฌ๋๋ฅผ ๊ธฐ์ค์ผ๋ก ํ๋จํ๋ค๋ ๊ฒ์ ๊ผญ ๊ธฐ์ตํ์. ๊ทธ๋ ๋ค๋ฉด ์ค์ ์ฝ๋๋ ์ด๋ป๊ฒ ์์ฑ ํด์ผํ๋์ง ํจ๊ป ์์๋ณด์. ์๋จ์ class diagram
์ ๋ค์ ํ ๋ฒ ๋ณด๊ณ ๋์์ค์.
# Pytorch Implementation of Scaled Dot-Product Self-Attention
def scaled_dot_product_attention(q: Tensor, k: Tensor, v: Tensor, dot_scale: Tensor, mask: Tensor = None) -> Tensor:
"""
Scaled Dot-Product Attention with Masking for Decoder
Args:
q: query matrix, shape (batch_size, seq_len, dim_head)
k: key matrix, shape (batch_size, seq_len, dim_head)
v: value matrix, shape (batch_size, seq_len, dim_head)
dot_scale: scale factor for QโขK^T result
mask: there are three types of mask, mask matrix shape must be same as single attention head
1) Encoder padded token
2) Decoder Masked-Self-Attention
3) Decoder's Encoder-Decoder Attention
Math:
A = softmax(qโขk^t/sqrt(D_h)), SA(z) = Av
"""
attention_matrix = torch.matmul(q, k.transpose(-1, -2)) / dot_scale
if mask is not None:
attention_matrix = attention_matrix.masked_fill(mask == 0, float('-inf'))
attention_dist = F.softmax(attention_matrix, dim=-1)
attention_matrix = torch.matmul(attention_dist, v)
return attention_matrix
๋ง์คํน ์ต์ ์ ๊ฒฝ์ฐ ์ฃผ์์ ์ ๋ฆฌ๋ 3๊ฐ์ง ์ํฉ ์ค์์ ํ ๊ฐ ์ด์์ ํด๋น๋๋ฉด ์คํ๋๋๋ก ์ฝ๋๋ฅผ ์์ฑํ๋ค. 3๊ฐ์ง ์ํฉ๊ณผ ๊ตฌ์ฒด์ ์ธ ๋ง์คํน ๋ฐฉ๋ฒ์ ๋ํด์๋ ์ ์ฒด ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๋ณด๋ ๋ ์๊ฐํ๋๋ก ํ๊ฒ ๋ค.
ํํธ, ์ธ์ฝ๋๋ ๋์ฝ๋๋ ๋ชจ๋ ์ฌ์ฉํ๋ ์
๋ ฅ๊ณผ ๋ง์คํน ๋ฐฉ์์ ์ฐจ์ด๋ ์์ง๋ง, Scaled Dot-Product Attention
์ฐ์ฐ ์์ฒด๋ ๋์ผํ ๊ฒ์ ์ฌ์ฉํ๋ค. ๋ฐ๋ผ์ ์ฌ๋ฌ๊ฐ์ ์ธ์ฝ๋๋ ๋์ฝ๋ ๊ฐ์ฒด๋ค ํน์ ์ดํ
์
ํด๋ ๊ฐ์ฒด๋ค์ด ๋ชจ๋ ์ฝ๊ฒ ์ฐ์ฐ์ ์ ๊ทผํ ์ ์๊ฒ ํด๋์ค ์ธ๋ถ์ ๋ฉ์๋ ํํ๋ก ๊ตฌํํ๊ฒ ๋์๋ค.
๐ฉโ๐ฉโ๐งโ๐ฆย Multi-Head Attention Block
์ง๊ธ๊น์ง ์ดํด๋ณธ Self-Attention
์ ๋์์ ๋ชจ๋ ํ ๊ฐ์ Attention-Head
์์ ์ผ์ด๋๋ ์ผ์ ์์ ํ ๊ฒ์ด๋ค. ์ฌ์ค ์ค์ ๋ชจ๋ธ์์๋ ๊ฐ์ ๋์์ด N-1
๊ฐ์ ๋ค๋ฅธ ํด๋์์ ๋์์ ์ผ์ด๋๋๋ฐ, ์ด๊ฒ์ด ๋ฐ๋ก Multi-Head Attention
์ด๋ค.
Official Paper
๊ธฐ์ค์ผ๋ก Transformer-base
์ hidden states
์ฐจ์์ 512
์ด๋ค. ์ด๊ฒ์ ๊ฐ๋น 64
์ฐจ์์ ๊ฐ๋ 8
๊ฐ์ Attention-Head
๋ก ์ชผ๊ฐ ๋ค, 8๊ฐ์ Attention-Head
์์ ๋์์ Self-Attention
์ ์ํํ๋ค. ์ดํ ๊ฒฐ๊ณผ๋ฅผ concat
ํ์ฌ ๋ค์ hidden states
๋ฅผ 512
๋ก ๋ง๋ ๋ค, ์ฌ๋ฌ ํด๋์์ ๋ง๋ ๊ฒฐ๊ณผ๋ฅผ ์ฐ๊ฒฐํ๊ณ ์์ด์ฃผ๊ธฐ ์ํด ์
์ถ๋ ฅ ์ฐจ์์ด hidden states
์ ๋์ผํ linear projection layer
์ ํต๊ณผ์ํจ๋ค. ์ด๊ฒ์ด ์ธ์ฝ๋(ํน์ ๋์ฝ๋) ๋ธ๋ญ ํ ๊ฐ์ ์ต์ข
Self-Attention
๊ฒฐ๊ณผ๊ฐ ๋๋ค.
Multi-Head Attention Result Visualization
๊ทธ๋ผ ์ ์ด๋ ๊ฒ ์ฌ๋ฌ ํด๋๋ฅผ ์ฌ์ฉํ์๊น?? ๋ฐ๋ก ์ง๋จ์ง์ฑ์ ํจ๊ณผ๋ฅผ ๋๋ฆฌ๊ธฐ ์ํจ์ด๋ค. ์๊ฐํด๋ณด์. ์ฑ ํ๋๋ฅผ ์ฝ์ด๋ ์ฌ๋๋ง๋ค ์ ๋ง ๋ค์ํ ํด์์ด ๋์จ๋ค. ๋ชจ๋ธ๋ ๋ง์ฐฌ๊ฐ์ง๋ค. ์ฌ๋ฌ ํด๋๋ฅผ ์ฌ์ฉํด์ ์ข ๋ ๋ค์ํ๊ณ ํ๋ถํ ์๋ฏธ๋ฅผ ์๋ฒ ๋ฉ์ ๋ด๊ณ ์ถ์๋ ๊ฒ์ด๋ค. Kaggle์ ํด๋ณด์ ๋ ์๋ผ๋ฉด, ์ฌ๋ฌ ์ ๋ต์ ์ฌ์ฉํด ์ฌ๋ฌ ๊ฐ์ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํ ๋ค, ๋ง์ง๋ง์ ๋ชจ๋ ์์๋ธํ๋ฉด ์ ๋ต ํ๋ ํ๋์ ๊ฒฐ๊ณผ๋ณด๋ค ๋ ๋์ ์ฑ์ ์ ์ป์ด๋ณธ ๊ฒฝํ์ด ์์ ๊ฒ์ด๋ค. ์ด๊ฒ๋ ๋น์ทํ ํจ๊ณผ๋ฅผ ์๋ํ๋ค๊ณ ์๊ฐํ๋ค. Vision์์ Conv Filter๋ฅผ ์ฌ๋ฌ ์ข ๋ฅ ์ฌ์ฉํด ๋ค์ํ Feature Map์ ์ถ์ถํ๋ ๊ฒ๋ ๋น์ทํ ํ์์ด๋ผ ๋ณผ ์ ์๊ฒ ๋ค.
์ ๊ทธ๋ฆผ์ ์ ์๊ฐ ์ ์ํ Multi-Head Attention
์ ์๊ฐํ ๊ฒฐ๊ณผ๋ค. ์ค๊ฐ์ ์๋ ์ฌ๋ฌ ์๊น์ ๋ ๋ ๊ฐ๋ณ ํด๋๊ฐ ์ดํ
์
ํ๋ ๋ฐฉํฅ์ ๊ฐ๋ฆฌํจ๋ค. ํ ํฐ โmakingโ
์ ๋ํด์ ํด๋๋ค์ด ์๋ก ๋ค๋ฅธ ํ ํฐ์ ์ดํ
์
ํ๊ณ ์๋ค.
ViT Multi-Head Attention Result Visualization
์ ๊ทธ๋ฆผ์ Vision Transformer ๋
ผ๋ฌธ์์ ๋ฐ์ทํ ๊ทธ๋ฆผ(๊ทธ๋ฆผ์ ์์ธํ ์๋ฏธ๋ ์ฌ๊ธฐ์)์ด๋ค. ์ญ์ ๋ง์ฐฌ๊ฐ์ง๋ก ๋ชจ๋ธ์ ์ด๋ฐ๋ถ ์ธ์ฝ๋์ ์ํ Multi-Head๋ค์ด ์๋ก ๋ค์ํ ํ ํฐ์ ์ดํ
์
์ ํ๊ณ ์์์ ์ ์ ์๋ค. ์ถ๊ฐ๋ก ํ๋ฐ์ผ๋ก ๊ฐ์๋ก ์ ์ Attention Distance
๊ฐ ์ผ์ ํ ์์ค์ ์๋ ดํ๋ ๋ชจ์ต์ ๋ณผ ์ ์๋๋ฐ, ์ด๊ฒ์ ๋ ์ด์ด๋ฅผ ํต๊ณผํ ์๋ก ๊ฐ๋ณ ํด๋๊ฐ ์์ ์ด ์ด๋ค ํ ํฐ์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ฌ์ผํ ์ง ๊ตฌ์ฒด์ ์ผ๋ก ์์๊ฐ๋ ๊ณผ์ ์ด๋ผ๊ณ ํด์ํ ์ ์๋ค. ์ด๋ฐ๋ถ์๋ ์ด์ฐํ ๋ฐ๋ฅผ ๋ชฐ๋ผ์ ์ดํ ํฐ ์ ํ ํฐ์ ์ฃ๋ค ์ดํ
์
ํ๋ ๊ฒ์ด๋ค.
๊ทธ๋์ Transformer
๋ Bottom Layer
์์๋ Global
ํ๊ณ General
ํ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๊ณ , Output
๊ณผ ๊ฐ๊น์ด Top Layer
์์๋ Local
ํ๊ณ Specific
ํ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๋ค.
๐ฉโ๐ปย Implementation
์ด์ ๊ตฌํ์ ์ค์ ๋ก ๊ตฌํ์ ํด๋ณด์. ์ญ์ ๊ตฌํ์ ํ์ดํ ์น๋ก ์งํํ๋ค.
# Pytorch Implemenation of Single Attention Head
class AttentionHead(nn.Module):
"""
In this class, we implement workflow of single attention head
Args:
dim_model: dimension of model's latent vector space, default 512 from official paper
dim_head: dimension of each attention head, default 64 from official paper (512 / 8)
dropout: dropout rate, default 0.1
Math:
[q,k,v]=zโขU_qkv, A = softmax(qโขk^t/sqrt(D_h)), SA(z) = Av
"""
def __init__(self, dim_model: int = 512, dim_head: int = 64, dropout: float = 0.1) -> None:
super(AttentionHead, self).__init__()
self.dim_model = dim_model
self.dim_head = dim_head # 512 / 8 = 64
self.dropout = dropout
self.dot_scale = torch.sqrt(torch.tensor(self.dim_head))
self.fc_q = nn.Linear(self.dim_model, self.dim_head) # Linear Projection for Query Matrix
self.fc_k = nn.Linear(self.dim_model, self.dim_head) # Linear Projection for Key Matrix
self.fc_v = nn.Linear(self.dim_model, self.dim_head) # Linear Projection for Value Matrix
def forward(self, x: Tensor, mask: Tensor, enc_output: Tensor = None) -> Tensor:
q, k, v = self.fc_q(x), self.fc_k(x), self.fc_v(x) # x is previous layer's output
if enc_output is not None:
""" For encoder-decoder self-attention """
k = self.fc_k(enc_output)
v = self.fc_v(enc_output)
attention_matrix = scaled_dot_product_attention(q, k, v, self.dot_scale, mask=mask)
return attention_matrix
๋๊ฐ์ Attention-Head
๋ฅผ N
๊ฐ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋จผ์ Single Attention Head
์ ๋์์ ๋ฐ๋ก ๊ฐ์ฒด๋ก ๋ง๋ค์๋ค. ์ด๋ ๊ฒ ํ๋ฉด MultiHeadAttention
๊ฐ์ฒด์์ nn.ModuleList
๋ฅผ ์ฌ์ฉํด N
๊ฐ์ ํด๋๋ฅผ ์ด์ด๋ถ์ผ ์ ์์ด์ ๊ตฌํ์ด ํจ์ฌ ๊ฐํธํด์ง๊ธฐ ๋๋ฌธ์ด๋ค. Single Attention Head
๊ฐ์ฒด๊ฐ ํ๋ ์ผ์ ๋ค์๊ณผ ๊ฐ๋ค.
- 1) Linear Projection by Dimension of Single Attention Head
- 2) Maksing
- 3) Scaled Dot-Product Attention
ํํธ, ์ฌ๋ฌ Transformer
๊ตฌํ Git Repo๋ฅผ ์ดํด๋ณด๋ฉด ๊ตฌํ ๋ฐฉ๋ฒ์ ํฌ๊ฒ ํ์์ฒ๋ผ Single Attention Head
๋ฅผ ์ถ์ํํ๊ฑฐ๋ MultiHeadAttention
๊ฐ์ฒด ํ๋์ ๋ชจ๋ ๋์์ ๋๋ ค๋ฃ๋ ๋ฐฉ์์ผ๋ก ๋๋๋ ๊ฒ ๊ฐ๋ค. ์ฌ์ค ๊ตฌํ์ ์ ๋ต์ ์์ง๋ง ๊ฐ์ธ์ ์ผ๋ก ํ์์ ๋ฐฉ์์ ๋นํจ์จ์ ์ด๋ผ ์๊ฐํ๋ค. ์ ๋ ๊ฒ ๊ตฌํํ๋ฉด 3*N
๊ฐ์ linear projector
๋ฅผ ํด๋์ค __init__
์ ๋ง๋ค๊ณ ๊ด๋ฆฌํด์ค์ผ ํ๋๋ฐ ์ฝ์ง ์์ ๊ฒ์ด๋ค. ๋ฌผ๋ก 3
๊ฐ์ linear projector
๋ง ์ด๊ธฐํํด์ ์ฌ์ฉํ๊ณ ๋์ ์ถ๋ ฅ ์ฐจ์์ Dim_Head
๊ฐ ์๋ Dim_Model
๋ก ๊ตฌํํ ๋ค, N
๊ฐ๋ก ์ฐจ์์ ๋ถํ ํ๋ ๋ฐฉ๋ฒ๋ ์๋ค. ํ์ง๋ง ์ฐจ์์ ์ชผ๊ฐ๋ ๋์์ ๊ตฌํํ๋ ๊ฒ๋ ์ฌ์ค ์ฝ์ง ์๋ค. ๊ทธ๋์ ํ์๋ ์ ์์ ๋ฐฉ์์ ์ถ์ฒํ๋ค.
ํํธ, forward
๋ฉ์๋์ if enc_output is not None:
๋ถ๋ถ์ ์ถํ์ ๋์ฝ๋์์ Multi-Head Attention
์ ๊ตฌํํ๊ธฐ ์ํด ์ถ๊ฐํ ์ฝ๋๋ค. ๋์ฝ๋๋ ์ธ์ฝ๋์ ๋ค๋ฅด๊ฒ ํ๋์ ๋์ฝ๋ ๋ธ๋ญ์์ Self-Attention
๋์์ ๋๋ฒํ๋๋ฐ, ๋๋ฒ์งธ ๋์์ ์๋ก ๋ค๋ฅธ ์ถ์ฒ์ ๊ฐ์ ์ด์ฉํด linear projection
์ ์ํํ๋ค. ๋ฐ๋ผ์ ๊ทธ ๊ฒฝ์ฐ๋ฅผ ์ฒ๋ฆฌํด์ฃผ๊ธฐ ์ํด ๊ตฌํํ๊ฒ ๋์๋ค.
์๋๋ MultiHeadAttention
์ ๊ตฌํํ ํ์ดํ ์น ์ฝ๋๋ค.
# Pytorch Implemenation of Single Attention Head
class MultiHeadAttention(nn.Module):
"""
In this class, we implement workflow of Multi-Head Self-Attention
Args:
dim_model: dimension of model's latent vector space, default 512 from official paper
num_heads: number of heads in MHSA, default 8 from official paper for Transformer
dim_head: dimension of each attention head, default 64 from official paper (512 / 8)
dropout: dropout rate, default 0.1
Math:
MSA(z) = [SA1(z); SA2(z); ยท ยท ยท ; SAk(z)]โขUmsa
Reference:
https://arxiv.org/abs/1706.03762
"""
def __init__(self, dim_model: int = 512, num_heads: int = 8, dim_head: int = 64, dropout: float = 0.1) -> None:
super(MultiHeadAttention, self).__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_head = dim_head
self.dropout = dropout
self.attention_heads = nn.ModuleList(
[AttentionHead(self.dim_model, self.dim_head, self.dropout) for _ in range(self.num_heads)]
)
self.fc_concat = nn.Linear(self.dim_model, self.dim_model)
def forward(self, x: Tensor, mask: Tensor, enc_output: Tensor = None) -> Tensor:
""" x is already passed nn.Layernorm """
assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'
attention_output = self.fc_concat(
torch.cat([head(x, mask, enc_output) for head in self.attention_heads], dim=-1)
)
return attention_output
MultiHeadAttention
๊ฐ์ฒด๋ ๊ฐ๋ณ ํด๋๋ค์ด ๋์ถํ ์ดํ
์
๊ฒฐ๊ณผ๋ฅผ concat
ํ๊ณ ๊ทธ๊ฒ์ connect & mix
ํ๋ ค๊ณ linear projection
์ ์ํํ๋ค.
๐ฌ Feed Forward Network
# Pytorch Implementation of FeedForward Network
class FeedForward(nn.Module):
"""
Class for Feed-Forward Network module in transformer
In official paper, they use ReLU activation function, but GELU is better for now
We change ReLU to GELU & add dropout layer
Args:
dim_model: dimension of model's latent vector space, default 512
dim_ffn: dimension of FFN's hidden layer, default 2048 from official paper
dropout: dropout rate, default 0.1
Math:
FeedForward(x) = FeedForward(LN(x))+x
"""
def __init__(self, dim_model: int = 512, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
super(FeedForward, self).__init__()
self.ffn = nn.Sequential(
nn.Linear(dim_model, dim_ffn),
nn.GELU(),
nn.Dropout(p=dropout),
nn.Linear(dim_ffn, dim_model),
nn.Dropout(p=dropout),
)
def forward(self, x: Tensor) -> Tensor:
return self.ffn(x)
ํผ๋ ํฌ์๋๋ ๋ชจ๋ธ์ non-linearity
๋ฅผ ์ถ๊ฐํ๊ธฐ ์ํด ์ฌ์ฉํ๋ ๋ ์ด์ด๋ค. ์๋ณธ ๋ชจ๋ธ์ ReLU
๋ฅผ ์ฌ์ฉํ์ง๋ง ์ต๊ทผ Transformer
๋ฅ ๋ชจ๋ธ์๋ GeLU
๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข ๋ ์์ ์ ์ธ ํ์ต์ ํ๋๋ฐ ๋์์ด ๋๋ค๊ณ ๋ฐํ์ ธ, ํ์ ์ญ์ GeLU
๋ฅผ ์ฌ์ฉํด ๊ตฌํํ๋ค. ๋ํ ๋
ผ๋ฌธ์๋ dropout
์ ๋ํ ์ธ๊ธ์ด ์ ํ ์๋๋ฐ, ์๋์ธต์ ์ฐจ์์ ์ ๋ ๊ฒ ํฌ๊ฒ ํค์ ๋ค ์ค์ด๋๋ฐ overfitting
์ด์๊ฐ ์์ ๊ฒ ๊ฐ์์ ViT
๋
ผ๋ฌธ์ ์ฐธ๊ณ ํด ๋ฐ๋ก ์ถ๊ฐํด์คฌ๋ค.
โย Add & Norm
Residual Connection
๊ณผ Layernorm
์ ์๋ฏธํ๋ค. ๋ฐ๋ก ๊ฐ์ฒด๋ฅผ ๋ง๋ค์ด์ ์ฌ์ฉํ์ง๋ ์๊ณ , EncoderLayer
๊ฐ์ฒด์ ๋ผ์ธ์ผ๋ก ์ถ๊ฐํด ๊ตฌํํ๊ธฐ ๋๋ฌธ์ ์ฌ๊ธฐ์๋ ์ญํ ๊ณผ ์๋ฏธ๋ง ์ค๋ช
ํ๊ณ ๋์ด๊ฐ๊ฒ ๋ค.
๋จผ์ Skip-Connection
์ผ๋ก๋ ๋ถ๋ฆฌ๋ Residual Connection
์ ์ด๋ค ๋ ์ด์ด๋ฅผ ํต๊ณผํ๊ธฐ ์ , ์
๋ ฅ $x$ ๋ฅผ ๋ ์ด์ด๋ฅผ ํต๊ณผํ๊ณ ๋์จ ๊ฒฐ๊ณผ๊ฐ $fx$ ์ ๋ํด์ค๋ค. ๋ฐ๋ผ์ ๋ค์ ๋ ์ด์ด์ ํต๊ณผ๋๋ ์
๋ ฅ๊ฐ์ $x+fx$ ๊ฐ ๋๋ค. ์ ์ด๋ ๊ฒ ๋ํด์ค๊น?? ๋ฐ๋ก ๋ชจ๋ธ์ ์์ ์ ์ธ ํ์ต์ ์ํด์๋ค. ์ผ๋จ ๊ทธ์ ์ ๋ช
์ฌํ๊ณ ๊ฐ์ผํ ์ ์ ๊ฐ ํ๋ ์๋ค. ๋ชจ๋ธ์ ๋ ์ด์ด๊ฐ ๊น์ด์ง์๋ก ๋ ์ด์ด๋ง๋ค ๊ฐ์ ์กฐ๊ธ์ฉ ๋ฐ๊ฟ๋๊ฐ๋ ๊ฒ์ด Robust
ํ๊ณ Stable
ํ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํ ์ ์๋ค๋ ๊ฒ์ด๋ค. ์ง๊ด์ ์ผ๋ก ๋ ์ด์ด๋ง๋ค ๊ฒฐ๊ณผ๊ฐ ๋๋ฐ๊ธฐํ๋ ๋ชจ๋ธ๋ณด๋ค ์์ ์ ์ผ๋ก ์ฐจ๊ทผ์ฐจ๊ทผ ํ์ตํด๋๊ฐ๋ ๋ชจ๋ธ์ ์ผ๋ฐํ ์ฑ๋ฅ์ด ๋ ์ข์ ๊ฒ์ด๋ผ๊ณ ์ถ์ธกํด๋ณผ ์ ์๋ค. ๊ทธ๋์ Residual Connection
์ ์
๋ ฅ $x$ ์ ๋ ์ด์ด์ ์ด์์ ์ธ ์ถ๋ ฅ๊ฐ $H(x)$ ์ ์ฐจ์ด๊ฐ ํฌ์ง ์์์ ๊ฐ์ ํ๋ค. ๋ง์ฝ, ์
๋ ฅ $X$ ๋ฅผ 10.0
, $H(x)$ ๋ฅผ 10.4
๋ผ๊ณ ํด๋ณด์. ๊ทธ๋ผ Residual Connection
์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ 0.4
์ ๋ํด์๋ง ํ์ต์ ํ๋ฉด ๋๋ค. ํํธ ์ด๊ฒ์ ์ฌ์ฉํ์ง ์๋ ๋ชจ๋ธ์ 0์์๋ถํฐ ์์ํด ๋ฌด๋ ค 10.4
๋ฅผ ํ์ตํด์ผ ํ๋ค. ์ด๋ค ๋ชจ๋ธ์ด ํ์ตํ๊ธฐ ์ฌ์ธ๊น?? ๋น์ฐํ ์ ์์ผ ๊ฒ์ด๋ค. ์ด๋ ๊ฒ ๋ชจ๋ธ์ด ์ด์์ ์ธ ๊ฐ๊ณผ ์
๋ ฅ์ ์ฐจ์ด๋ง ํ์ตํ๋ฉด ๋๊ธฐ ๋๋ฌธ์ ์ด๊ฒ์ ์์ฐจ ํ์ต
์ด๋ผ๊ณ ๋ถ๋ฅด๋ ๊ฒ์ด๋ค.
Batchnorm
์ โMini-Batchโ
๋จ์๋ฅผ Channel(Feature)
๋ณ๋ก ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ๊ตฌํ๋ค๋ฉด, Layernorm
์ Channel(Feature)
๋จ์๋ฅผ ๊ฐ๋ณ ์ธ์คํด์ค
๋ณ๋ก ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ๊ตฌํ์ฌ ์ ๊ทํํ๋ ๋ฐฉ์์ด๋ค.
์๋ฅผ ๋ค์ด ๋ฐฐ์น๋ก 4๊ฐ์ ๋ฌธ์ฅ์ ์๋์ธต์ ์ฌ์ด์ฆ๊ฐ 512
์ธ ๋ชจ๋ธ์ ์
๋ ฅํด์คฌ๋ค๊ณ ์๊ฐํด๋ณด์. ๊ทธ๋ผ 4๊ฐ์ ๋ฌธ์ฅ์ ๊ฐ๊ฐ 512
๊ฐ์ ์์๋ฅผ ๊ฐ๊ฒ ๋๋๋ฐ, ์ด๊ฒ์ ๋ํ ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ๊ตฌํ๋ค๋ ๊ฒ์ด๋ค. ํ ๊ฐ์ ๋ฌธ์ฅ๋น ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ 1๊ฐ์ฉ ๊ตฌํด์, 4๊ฐ์ ๋ฌธ์ฅ์ด๋๊น ์ด 8๊ฐ๊ฐ ๋์ค๊ฒ ๋ค.
๊ทธ๋ ๋ค๋ฉด ์ Transformer
๋ Layernorm
์ ์ฌ์ฉํ์๊น?? ์์ฐ์ด ์ฒ๋ฆฌ๋ ๋ฐฐ์น๋ง๋ค ์ํ์ค์ ๊ธธ์ด๊ฐ ๊ณ ์ ๋์ด ์์ง ์์ ํจ๋ฉ์ด๋ ์ ์ญ์ ์ํํ๋ค. ์ ์ญ๋ณด๋ค๋ ํจ๋ฉ์ด ๋ฌธ์ ๊ฐ ๋๋ค. ํจ๋ฉ์ ์ผ๋ฐ์ ์ผ๋ก ๋ฌธ์ฅ์ ๋๋ถ๋ถ์ ํด์ค๋ค. ์ฌ๊ธฐ์ Batchnorm
์ ์ฌ์ฉํ๋ฉด ๋์ชฝ์ ์์นํ ๋ค๋ฅธ ์ํ์ค์ ์ํ ์ ์์ ์ธ ํ ํฐ๋ค์ ํจ๋ฉ์ ์ํด ๊ฐ์ด ์๊ณก๋ ๊ฐ๋ฅ์ฑ์ด ์๋ค. ๊ทธ๋์ Batchnorm
๋์ Layernorm
์ ์ฌ์ฉํ๋ค. ๋ํ Batchnorm
์ ๋ฐฐ์น ํฌ๊ธฐ์ ์ข
์์ ์ด๋ผ์ ํ
์คํธ ์ํฉ์์๋ ๊ทธ๋๋ก ์ฌ์ฉํ ์ ์๋ค. ๋ฐ๋ผ์ ๋ฐฐ์น ์ฌ์ด์ฆ์ ๋
๋ฆฝ์ ์ธ Layernorm
์ ์ฌ์ฉํ๊ธฐ๋ ํ๋ค.
ํํธ ์ด๋ฌํ ์ ๊ทํ๋ฅผ ์ ์ฌ์ฉํ๋์ง ๊ถ๊ธํ์๋ค๋ฉด ๋ค๋ฅธ ํฌ์คํธ์ ์ ๋ฆฌ๋ฅผ ํด๋์ผ๋ ์ฐธ๊ณ ํ์๊ธธ ๋ฐ๋๋ค. ๊ฐ๋จํ๊ฒ๋ง ์ธ๊ธํ๋ฉด, ๋ชจ๋ธ์ ๋น์ ํ์ฑ
๊ณผ ๊ทธ๋ผ๋์ธํธ ํฌ๊ธฐ ์ฌ์ด์ ์ต์ ์ Trade-Off
๋ฅผ ์ธ๊ฐ์ด ์๋ ๋ชจ๋ธ๋ณด๊ณ ์ฐพ๊ฒ ๋ง๋๋๊ฒ ๋ชฉ์ ์ด๋ผ ๋ณผ ์ ์๋ค.
๐ย EncoderLayer
์ด์ Single Encoder Block
์ ์ ์ํ๊ธฐ์ ํ์ํ ๋ชจ๋ ์ฌ๋ฃ๋ฅผ ์ดํด๋ดค๋ค. ์ง๊ธ๊น์ง์ ๋ด์ฉ์ ์ข
ํฉํด ํ ๊ฐ์ ์ธ์ฝ๋ ๋ธ๋ญ์ ๋ง๋ค์ด๋ณด์.
# Pytorch Implementation of Single Encoder Block
class EncoderLayer(nn.Module):
"""
Class for encoder model module in Transformer
In this class, we stack each encoder_model module (Multi-Head Attention, Residual-Connection, LayerNorm, FFN)
We apply pre-layernorm, which is different from original paper
In common sense, pre-layernorm are more effective & stable than post-layernorm
"""
def __init__(self, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
super(EncoderLayer, self).__init__()
self.self_attention = MultiHeadAttention(
dim_model,
num_heads,
int(dim_model / num_heads),
dropout,
)
self.layer_norm1 = nn.LayerNorm(dim_model)
self.layer_norm2 = nn.LayerNorm(dim_model)
self.dropout = nn.Dropout(p=dropout)
self.ffn = FeedForward(
dim_model,
dim_ffn,
dropout,
)
def forward(self, x: Tensor, mask: Tensor) -> Tensor:
ln_x = self.layer_norm1(x)
residual_x = self.dropout(self.self_attention(ln_x, mask)) + x
ln_x = self.layer_norm2(residual_x)
fx = self.ffn(ln_x) + residual_x
return fx
์ง๊ธ๊น์ง์ ๋ด์ฉ์ ๊ฐ์ฒด ํ๋์ ๋ชจ์๋๊ฑฐ๋ผ ํน๋ณํ ์ค๋ช
์ด ํ์ํ ๋ถ๋ถ์ ์์ง๋ง, ํ์๊ฐ add & norm
์ ์ธ์ ์ฌ์ฉํ๋์ง ์ฃผ๋ชฉํด๋ณด์. ์๋ณธ ๋
ผ๋ฌธ์ Multi-Head Attention
๊ณผ FeedForward
Layer
๋ฅผ ํต๊ณผํ ์ดํ์ add & norm
์ ํ๋ post-layernorm
๋ฐฉ์์ ์ ์ฉํ๋ค. ํ์ง๋ง ํ์๋ ๋ ๋ ์ด์ด ํต๊ณผ ์ด์ ์ ๋ฏธ๋ฆฌ add & norm
์ ํด์ฃผ๋ pre-layernorm
๋ฐฉ์์ ์ฑํํ๋ค.
pre-layernorm vs post-layernorm
์ต๊ทผ Transformer
๋ฅ์ ๋ชจ๋ธ์ pre-layernorm
์ ์ ์ฉํ๋ ๊ฒ์ด ์ข ๋ ์์ ์ ์ด๊ณ ํจ์จ์ ์ธ ํ์ต์ ์ ๋ํ ์ ์๋ค๊ณ ์คํ์ ํตํด ๋ฐํ์ง๊ณ ์๋ค. pre-layernorm
์ ์ฌ์ฉํ๋ฉด ๋ณ๋ค๋ฅธ Gradient Explode
ํ์์ด ํ์ ํ ์ค์ด๋ค์ด ๋ณต์กํ ์ค์ผ์ค๋ฌ(warmup
๊ธฐ๋ฅ์ด ์๋ ์ค์ผ์ค๋ฌ)๋ฅผ ์ฌ์ฉํ ํ์๊ฐ ์์ด์ง๋ค๊ณ ํ๋ ์ฐธ๊ณ ํ์.
์ด๋ ๊ฒ ๊ตฌํํ Single Encoder Block
์ ์ด์ N๊ฐ ์๊ธฐ๋ง ํ๋ฉด ๋๋์ด ์ธ์ฝ๋๋ฅผ ์์ฑํ ์ ์๊ฒ ๋๋ค.
๐ Encoder
๋๋์ด ๋๋ง์ Encoder
๊ฐ์ฒด ๊ตฌํ์ ์ดํด๋ณผ ์๊ฐ์ด๋ค.
# Pytorch Implementation of Encoder(Stacked N EncoderLayer)
class Encoder(nn.Module):
"""
In this class, encode input sequence and then we stack N EncoderLayer
First, we define "positional embedding" and then add to input embedding for making "word embedding"
Second, forward "word embedding" to N EncoderLayer and then get output embedding
In official paper, they use positional encoding, which is base on sinusoidal function(fixed, not learnable)
But we use "positional embedding" which is learnable from training
Args:
max_seq: maximum sequence length, default 512 from official paper
N: number of EncoderLayer, default 6 for base model
"""
def __init__(self, max_seq: 512, N: int = 6, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
super(Encoder, self).__init__()
self.max_seq = max_seq
self.scale = torch.sqrt(torch.Tensor(dim_model)) # scale factor for input embedding from official paper
self.positional_embedding = nn.Embedding(max_seq, dim_model) # add 1 for cls token
self.num_layers = N
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_ffn = dim_ffn
self.dropout = nn.Dropout(p=dropout)
self.encoder_layers = nn.ModuleList(
[EncoderLayer(dim_model, num_heads, dim_ffn, dropout) for _ in range(self.num_layers)]
)
self.layer_norm = nn.LayerNorm(dim_model)
def forward(self, inputs: Tensor, mask: Tensor) -> tuple[Tensor, Tensor]:
"""
inputs: embedding from input sequence, shape => [BS, SEQ_LEN, DIM_MODEL]
mask: mask for Encoder padded token for speeding up to calculate attention score
"""
layer_output = []
pos_x = torch.arange(self.max_seq).repeat(inputs.shape[0]).to(inputs)
x = self.dropout(
self.scale * inputs + self.positional_embedding(pos_x)
)
for layer in self.encoder_layers:
x = layer(x, mask)
layer_output.append(x)
encoded_x = self.layer_norm(x) # from official paper & code by Google Research
layer_output = torch.stack(layer_output, dim=0).to(x.device) # For Weighted Layer Pool: [N, BS, SEQ_LEN, DIM]
return encoded_x, layer_output
์ญ์ ์ง๊ธ๊น์ง ๋ด์ฉ์ ์ข
ํฉํ ๊ฒ๋ฟ์ด๋ผ์ ํฌ๊ฒ ํน์ดํ ๋ด์ฉ์ ์๊ณ , ๊ตฌํ์ ๋์น๊ธฐ ์ฌ์ด ๋ถ๋ถ๋ง ์๊ณ ๋์ด๊ฐ๋ฉด ๋๋ค. forward
๋ฉ์๋์ ๋ณ์ x
๋ฅผ ์ด๊ธฐํํ๋ ์ฝ๋ ๋ผ์ธ์ ์ฃผ๋ชฉํด๋ณด์. ์ด๊ฒ์ด ๋ฐ๋ก Input Embedding
๊ณผ Position Embedding
์ ๋ํ๋(ํ๋ ฌ ํฉ) ์ฐ์ฐ์ ๊ตฌํํ ๊ฒ์ด๋ค. ์ด ๋ ๋์น๊ธฐ ์ฌ์ด ๋ถ๋ถ์ด ๋ฐ๋ก Input Embedding
์ scale factor
๋ฅผ ๊ณฑํด์ค๋ค๋ ๊ฒ์ด๋ค. ์ ์์ ์ฃผ์ฅ์ ๋ฐ๋ฅด๋ฉด Input Embedding
์๋ง scale factor
๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์์ ์ ์ธ ํ์ต์ ๋์์ด ๋๋ค๊ณ ํ๋ ์ฐธ๊ณ ํ์.
ํํธ, ๋ง์ง๋ง ์ธ์ฝ๋ ๋ธ๋ญ์์ ๋์จ ์๋ฒ ๋ฉ์ ๋ค์ ํ ๋ฒ layernorm
์ ํต๊ณผํ๋๋ก ๊ตฌํํ๋ค. ์ด ๋ถ๋ถ๋ ์๋ณธ ๋
ผ๋ฌธ์ ์๋ ๋ด์ฉ์ ์๋๊ณ ViT
์ ๋
ผ๋ฌธ ๋ด์ฉ์ ์ฐธ๊ณ ํด ์ถ๊ฐํ๋ค.
๐ย DecoderLayer
์ด๋ฒ์๋ ๋์ฝ๋์ ์ฌ์ฉ๋ ๋ธ๋ญ์ ๋์ ๋ฐฉ์๊ณผ ์๋ฏธ ๊ทธ๋ฆฌ๊ณ ๊ตฌํ๊น์ง ์์๋ณด์. ์ฌ์ค ๋์ฝ๋๋ ์ง๊ธ๊น์ง ๊ณต๋ถํ ๋ด์ฉ๊ณผ ํฌ๊ฒ ๋ค๋ฅธ๊ฒ ์๋ค. ๋ค๋ง ์ธ์ฝ๋์๋ ๋ชฉ์ ์ด ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ ๋ฐ์ํ๋ ๋ฏธ์ธํ ๋์์ ์ฐจ์ด์ ์ฃผ๋ชฉํด๋ณด์. ๋จผ์ Single Decoder Block
์ Single Encoder Block
๊ณผ ๋ค๋ฅด๊ฒ Self-Attention
์ ๋ ๋ฒ ์ํํ๋ค. ์ง๊ฒน๊ฒ ์ง๋ง ๋ค์ ํ ๋ฒ Transformer์ ๋ชฉ์ ์ ์๊ธฐ์์ผ๋ณด์. ๋ฐ๋ก ๋์ ์ธ์ด๋ฅผ ํ๊ฒ ์ธ์ด๋ก ์ ๋ฒ์ญํ๋ ๊ฒ์ด์๋ค. ์ผ๋จ ์ธ์ฝ๋๋ฅผ ํตํด ๋์ ์ธ์ด๋ ์ ์ดํดํ๊ฒ ๋์๋ค. ๊ทธ๋ผ ์ด์ ํ๊ฒ ์ธ์ด๋ ์ ์ดํดํด์ผํ์ง ์์๊ฐ?? ๊ทธ๋์ ํ๊ฒ ์ธ์ด๋ฅผ ์ดํดํ๊ธฐ ์ํด Self-Attention
์ ํ ๋ฒ, ๊ทธ๋ฆฌ๊ณ ๋์ ์ธ์ด๋ฅผ ํ๊ฒ ์ธ์ด๋ก ๋ฒ์ญํ๊ธฐ ์ํด Self-Attention
์ ํ ๋ฒ, ์ด 2๋ฒ ์ํํ๋ ๊ฒ์ด๋ค. ์ฒซ๋ฒ์งธ Self-Attention
์ Masked Multi-Head Attention
, ๋๋ฒ์งธ๋ฅผ Encoder-Decoder Multi-Head Attention
์ด๋ผ๊ณ ๋ถ๋ฅธ๋ค.
๐ญ Masked Multi-Head Attention
์ธ์ฝ๋์ Multi-Head Attention์
ํ๋ ฌ $Q,K,V$ ์ ์ถ์ฒ๊ฐ ๋ค๋ฅด๋ค. ๋์ฝ๋๋ ์ถ์ฒ๊ฐ ํ๊ฒ ์ธ์ด์ธ linear projection matrix
๋ฅผ ์ฌ์ฉํ๋ค. ๋ํ ์ธ์ฝ๋์ ๋ค๋ฅด๊ฒ ๊ฐ๋ณ ์์ ์ ๋ง๋ ๋ง์คํน ํ๋ ฌ์ด ํ์ํ๋ค. ๋์ฝ๋์ ๊ณผ์
์ ๊ฒฐ๊ตญ ๋์ ์ธ์ด๋ฅผ ์ ์ดํดํ๊ณ ๊ทธ๊ฒ์ ๊ฐ์ฅ ์ ๋ค์ด๋ง๋ ํ๊ฒ ์ธ์ด ์ํ์ค๋ฅผ generate
ํ๋ ๊ฒ์ด๋ค. ์ฆ, Next Token Prediction
์ ํตํด ์ํ์ค๋ฅผ ๋ง๋ค์ด๋ด์ผ ํ๋ค. ๊ทธ๋ฐ๋ฐ ํ์ฌ ์์ ์์ ๋ฏธ๋ ์์ ์ ๋์ฝ๋๊ฐ ์์ธกํด์ผํ ํ ํฐ์ ๋ฏธ๋ฆฌ ์๊ณ ์์ผ๋ฉด ๊ทธ๊ฒ์ ์์ธก์ด๋ผ๊ณ ํ ์ ์์๊น?? ๋์ฝ๋๊ฐ ํ์ฌ ์์ ์ ํ ํฐ์ ์์ธกํ๋๋ฐ ๋ฏธ๋ ์์ ์ Context
๋ฅผ ๋ฐ์ํ์ง ๋ชปํ๋๋ก ๋ง๊ธฐ ์ํด ๋ฏธ๋ฆฌ ๋ง์คํน ํ๋ ฌ์ ์ ์ํด Word_Embedding
์ ์ ์ฉํด์ค๋ค. ์ด๋ ๊ฒ ๋ง์คํน์ด ์ ์ฉ๋ ์๋ฒ ๋ฉ ํ๋ ฌ์ ๊ฐ์ง๊ณ linear projection & self-attention
์ ์ํํ๊ธฐ ๋๋ฌธ์ ์ด๋ฆ ์์ masked
๋ฅผ ๋ถ์ด๊ฒ ๋์๋ค.
Decoder Language Modeling Mask
์ ๊ทธ๋ฆผ์ ๋ง์คํน์ ์ ์ฉํ Word_Embedding
์ ๋ชจ์ต์ด๋ค. ์ฒซ ๋ฒ์งธ ์์ ์์ ๋ชจ๋ธ์ ์๊ธฐ ์์ ์ ์ ์ธํ ๋๋จธ์ง Context
๋ฅผ ์์ธก์ ํ์ฉํ ์ ์๋ค. ๊ทธ๋์ ์ดํ ๋๋จธ์ง ํ ํฐ์ ์ ๋ถ ๋ง์คํน ์ฒ๋ฆฌํด์คฌ๋ค. ๋๋ฒ์งธ ์์ ์์๋ ์ง์ ์์ ์ธ ์ฒซ๋ฒ์งธ ํ ํฐ๊ณผ ์๊ธฐ ์์ ๋ง ์ฐธ๊ณ ํ ์ ์๋ค. ํํธ, ์ด๋ ๊ฒ ์ง์ Context
๋ง ๊ฐ์ง๊ณ ํ์ฌ ํ ํฐ์ ์ถ๋ก ํ๋ ๊ฒ์ Language Modeling
์ด๋ผ ๋ถ๋ฅธ๋ค. ๊ทธ๋ฆฌ๊ณ ๋ง์ฐฌ๊ฐ์ง๋ก ๋์ฝ๋ ์ญ์ ์ํ์ค์ ํจ๋ฉ ์ฒ๋ฆฌ๋ฅผ ํด์ฃผ๊ธฐ ๋๋ฌธ์ ์ธ์ฝ๋์ ๋์ผํ ์๋ฆฌ๋ก ๋ง๋ decoder padding mask
์ญ์ ํ์ํ๋ค.
๋ง์คํน ํ๋ ฌ ๊ตฌํ์ ์ต์์ ๊ฐ์ฒด์ธ Transformer
์ ๋ด๋ถ ๋ฉ์๋๋ก ๋ง๋ค์์ผ๋, ๊ทธ ๋ ์์ธํ ์ค๋ช
ํ๊ฒ ๋ค. ์ดํ ๋๋จธ์ง ๋ํ
์ผ์ ์ธ์ฝ๋์ ๊ฒ๊ณผ ๋์ผํ๋ค.
๐ชข Encoder-Decoder Multi-Head Attention
์ธ์ฝ๋๋ฅผ ํตํด ์ดํดํ ๋์ ์ธ์ด ์ํ์ค์ ๋ฐ๋ก ์ง์ Self-Attention
์ ํตํด ์ดํดํ ํ๊ฒ ์ธ์ด ์ํ์ค๋ฅผ ์๋ก ๋์กฐํ๋ ๋ ์ด์ด๋ค. ์ฐ๋ฆฌ์ ์ง๊ธ ๋ชฉ์ ์ ํ๊ฒ ์ธ์ด
์ ๊ฐ์ฅ ์ ์ฌํ ๋์ ์ธ์ด
๋ฅผ ์ฐพ์ ๋ฌธ์ฅ์ ์์ฑํ๋ ๊ฒ์ด๋ค. ๋ฐ๋ผ์ ์ดํ
์
๊ณ์ฐ์ ์ฌ์ฉ๋ ํ๋ ฌ $Q$ ์ ์ถ์ฒ๋ ์ง์ ๋ ์ด์ด์ธ Masked Multi-Head Attention
์ ๋ฐํ๊ฐ์ ์ฌ์ฉํ๊ณ , ํ๋ ฌ $K,V$ ๋ ์ธ์ฝ๋์ ์ต์ข
๋ฐํ๊ฐ์ ์ฌ์ฉํ๋ค.
ํํธ, ์ฌ๊ธฐ ๋ ์ด์ด์๋ ๋ง์คํน ํ๋ ฌ์ด ์ธ ์ข
๋ฅ๋ ํ์ํ๋ค. ๊ทธ ์ด์ ๋ ์๋ก ์ถ์ฒ๊ฐ ๋ค๋ฅธ ๋๊ฐ์ง ํ๋ ฌ์ ๊ณ์ฐ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ด๋ค. ์ง๊ธ์ ์ฌ์ ํ ๋์ฝ๋์ ์ญํ ์ ์ํํ๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ ์ง์ ๋ ์ด์ด์์ ์ฌ์ฉํ 2๊ฐ์ ๋ง์คํน ํ๋ ฌ์ด ๊ทธ๋๋ก ํ์ํ๋ค. ๊ทธ๋ฆฌ๊ณ ์ธ์ฝ๋์์ ๋์ด์จ ๊ฐ์ ์ฌ์ฉํ๋ค๋ ๊ฒ์ ์ธ์ฝ๋์ ํจ๋ฉ ์ญ์ ์ฒ๋ฆฌ๊ฐ ํ์ํ๋ค๋ ์๋ฏธ๋ค. ๋ฐ๋ผ์ lm_mask
, dec_pad_mask
, enc_pad_mask
๊ฐ ํ์ํ๋ค. ์ญ์ ๋ง์คํน ๊ตฌํ์ ์ต์์ ๊ฐ์ฒด ์ค๋ช
๋ ํจ๊ป ์ดํด๋ณด๊ฒ ๋ค.
๐ฉโ๐ปย Implementation
์ด์ Single Decoder Block
์ ๊ตฌํ์ ์ดํด๋ณด์. ์ญ์ ํ์ดํ ์น๋ก ๊ตฌํํ๋ค.
# Pytorch Implementation of Single Decoder Block
class DecoderLayer(nn.Module):
"""
Class for decoder model module in Transformer
In this class, we stack each decoder_model module (Masked Multi-Head Attention, Residual-Connection, LayerNorm, FFN)
We apply pre-layernorm, which is different from original paper
References:
https://arxiv.org/abs/1706.03762
"""
def __init__(self, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
super(DecoderLayer, self).__init__()
self.masked_attention = MultiHeadAttention(
dim_model,
num_heads,
int(dim_model / num_heads),
dropout,
)
self.enc_dec_attention = MultiHeadAttention(
dim_model,
num_heads,
int(dim_model / num_heads),
dropout,
)
self.layer_norm1 = nn.LayerNorm(dim_model)
self.layer_norm2 = nn.LayerNorm(dim_model)
self.layer_norm3 = nn.LayerNorm(dim_model)
self.dropout = nn.Dropout(p=dropout) # dropout is not learnable layer
self.ffn = FeedForward(
dim_model,
dim_ffn,
dropout,
)
def forward(self, x: Tensor, dec_mask: Tensor, enc_dec_mask: Tensor, enc_output: Tensor) -> Tensor:
ln_x = self.layer_norm1(x)
residual_x = self.dropout(self.masked_attention(ln_x, dec_mask)) + x
ln_x = self.layer_norm2(residual_x)
residual_x = self.dropout(self.enc_dec_attention(ln_x, enc_dec_mask, enc_output)) + x # for enc_dec self-attention
ln_x = self.layer_norm3(residual_x)
fx = self.ffn(ln_x) + residual_x
return fx
Self-Attention
๋ ์ด์ด๊ฐ ์ธ์ฝ๋๋ณด๋ค ํ๋ ๋ ์ถ๊ฐ๋์ด add & norm
์ ์ด 3๋ฒ ํด์ค์ผ ํ๋ค๋ ๊ฒ์ ์ ์ธํ๊ณ ๋ ํฌ๊ฒ ๊ตฌํ์์ ํน์ด์ ์ ์๋ค. ๊ทธ์ ์ง๊ธ๊น์ง ์ดํด๋ณธ ๋ธ๋ญ์ ์๋ฆฌ์กฐ๋ฆฌ ๋ค์ ์์ผ๋ฉด ๋๋ค.
๐ย Decoder
Single Decoder Block
์ N
๊ฐ ์๊ณ ์ ์ฒด ๋์ฝ๋ ๋์์ ์ํํ๋ Decoder
๊ฐ์ฒด์ ๊ตฌํ์ ์์๋ณด์.
# Pytorch Implementation of Decoder(N Stacked Single Decoder Block)
class Decoder(nn.Module):
"""
In this class, decode encoded embedding from encoder by outputs (target language, Decoder's Input Sequence)
First, we define "positional embedding" for Decoder's Input Sequence,
and then add them to Decoder's Input Sequence for making "decoder word embedding"
Second, forward "decoder word embedding" to N DecoderLayer and then pass to linear & softmax for OutPut Probability
Args:
vocab_size: size of vocabulary for output probability
max_seq: maximum sequence length, default 512 from official paper
N: number of EncoderLayer, default 6 for base model
References:
https://arxiv.org/abs/1706.03762
"""
def __init__(
self,
vocab_size: int,
max_seq: int = 512,
N: int = 6,
dim_model: int = 512,
num_heads: int = 8,
dim_ffn: int = 2048,
dropout: float = 0.1
) -> None:
super(Decoder, self).__init__()
self.max_seq = max_seq
self.scale = torch.sqrt(torch.Tensor(dim_model)) # scale factor for input embedding from official paper
self.positional_embedding = nn.Embedding(max_seq, dim_model) # add 1 for cls token
self.num_layers = N
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_ffn = dim_ffn
self.dropout = nn.Dropout(p=dropout)
self.decoder_layers = nn.ModuleList(
[DecoderLayer(dim_model, num_heads, dim_ffn, dropout) for _ in range(self.num_layers)]
)
self.layer_norm = nn.LayerNorm(dim_model)
self.fc_out = nn.Linear(dim_model, vocab_size) # In Pytorch, nn.CrossEntropyLoss already has softmax function
def forward(self, inputs: Tensor, dec_mask: Tensor, enc_dec_mask: Tensor, enc_output: Tensor) -> tuple[Tensor, Tensor]:
"""
inputs: embedding from input sequence, shape => [BS, SEQ_LEN, DIM_MODEL]
dec_mask: mask for Decoder padded token for Language Modeling
enc_dec_mask: mask for Encoder-Decoder Self-Attention, from encoder padded token
"""
layer_output = []
pos_x = torch.arange(self.max_seq).repeat(inputs.shape[0]).to(inputs)
x = self.dropout(
self.scale * inputs + self.positional_embedding(pos_x)
)
for layer in self.decoder_layers:
x = layer(x, dec_mask, enc_dec_mask, enc_output)
layer_output.append(x)
decoded_x = self.fc_out(self.layer_norm(x)) # Because of pre-layernorm
layer_output = torch.stack(layer_output, dim=0).to(x.device) # For Weighted Layer Pool: [N, BS, SEQ_LEN, DIM]
return decoded_x, layer_output
Encoder
๊ฐ์ฒด์ ๋ชจ๋ ๋ถ๋ถ์ด ๋์ผํ๋ค. ๋ํ
์ผํ ์ค์ ๋ง ๋์ฝ๋์ ๋ง๊ฒ ๋ณ๊ฒฝ๋์์ ๋ฟ์ด๋ค. self.fc_out
์ ์ฃผ๋ชฉํด๋ณด์. ๋์ฝ๋๋ ํ์ฌ ์์ ์ ๊ฐ์ฅ ์ ํฉํ ํ ํฐ์ ์์ธกํด์ผ ํ๊ธฐ ๋๋ฌธ์ ๋์ฝ๋์ ์ถ๋ ฅ๋ถ๋ถ์ ๋ก์ง ๊ณ์ฐ์ ์ํ ๋ ์ด์ด๊ฐ ํ์ํ๋ค. ๊ทธ ์ญํ ์ ํ๋ ๊ฒ์ด ๋ฐ๋ก self.fc_out
์ด๋ค. ํํธ, self.fc_out
์ ์ถ๋ ฅ ์ฐจ์์ด vocab_size
์ผ๋ก ๋์ด์๋๋ฐ, ๋์ฝ๋๋ ๋์ฝ๋๊ฐ ๊ฐ์ง ์ ์ฒด vocab
์ ํ์ฌ ์์ ์ ์ ํฉํ ํ ํฐ ํ๋ณด๊ตฐ์ผ๋ก ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ด๋ค.
๐ฆพ Transformer
์ด์ ๋๋ง์ ๋ง์ง๋งโฆ ๋ชจ๋ธ์ ๊ฐ์ฅ ์ต์์ ๊ฐ์ฒด์ธ Transformer
์ ๋ํด์ ์ดํด๋ณด์. ๊ฐ์ฒด์ ๋์์ ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- 1) Make
Input Embedding
for Encoder & Decoder respectively, InitEncoder & Decoder
Class - 2) Make 3 types of Masking:
Encoder Padding Mask
,Decoder LM & Padding Mask
,Encoder-Decoder Mask
- 3) Return
Output
from Encoder & Decoder
ํนํ ๊ณ์ ๋ฏธ๋ค์๋ ๋ง์คํน ๊ตฌํ์ ๋ํด์ ์ดํด๋ณด์. ๋๋จธ์ง๋ ์ด๋ฏธ ์์์ ๋ง์ด ์ค๋ช
ํ์ผ๋๊น ๋์ด๊ฐ๋๋ก ํ๊ฒ ๋ค. ์ผ๋จ ๋จผ์ ์ฝ๋๋ฅผ ์ฝ์ด๋ณด์. ์ถ๊ฐ๋ก Input Embedding
๊ตฌํ์ ์ฌ์ฉ์์ vocab
๊ตฌ์ถ ๋ฐฉ์์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ค. ํ์์ ๊ฒฝ์ฐ ๋์ ์ธ์ด์ ํ๊ฒ ์ธ์ด์ vocab
์ ๋ถ๋ฆฌํด ์ฌ์ฉํ๋ ๊ฒ์ ๊ฐ์ ํ๊ณ ์ฝ๋๋ฅผ ๋ง๋ค์ด ์๋ฒ ๋ฉ ๋ ์ด์ด๋ฅผ ๋ฐ๋ก ๋ฐ๋ก ๊ตฌํํด์คฌ๋ค. vocab
์ ํตํฉ์ผ๋ก ๊ตฌ์ถํ์๋ ๋ถ์ด๋ผ๋ฉด ํ๋๋ง ์ ์ํด๋ ์๊ด์๋ค. ๋์ ๋์ค์ ๋์ฝ๋์ ๋ก์ง๊ฐ ๊ณ์ฐ์ ์ํด ๊ฐ๋ณ ์ธ์ด์ ํ ํฐ ์ฌ์ด์ฆ๋ ์๊ณ ์์ด์ผ ํ ๊ฒ์ด๋ค.
# Pytorch Implementation of Transformer
class Transformer(nn.Module):
"""
Main class for Pure Transformer, Pytorch implementation
There are two Masking Method for padding token
1) Row & Column masking
2) Column masking only at forward time, Row masking at calculating loss time
second method is more efficient than first method, first method is complex & difficult to implement
Args:
enc_vocab_size: size of vocabulary for Encoder Input Sequence
dec_vocab_size: size of vocabulary for Decoder Input Sequence
max_seq: maximum sequence length, default 512 from official paper
enc_N: number of EncoderLayer, default 6 for base model
dec_N: number of DecoderLayer, default 6 for base model
Reference:
https://arxiv.org/abs/1706.03762
"""
def __init__(
self,
enc_vocab_size: int,
dec_vocab_size: int,
max_seq: int = 512,
enc_N: int = 6,
dec_N: int = 6,
dim_model: int = 512,
num_heads: int = 8,
dim_ffn: int = 2048,
dropout: float = 0.1
) -> None:
super(Transformer, self).__init__()
self.enc_input_embedding = nn.Embedding(enc_vocab_size, dim_model)
self.dec_input_embedding = nn.Embedding(dec_vocab_size, dim_model)
self.encoder = Encoder(max_seq, enc_N, dim_model, num_heads, dim_ffn, dropout)
self.decoder = Decoder(dec_vocab_size, max_seq, dec_N, dim_model, num_heads, dim_ffn, dropout)
@staticmethod
def enc_masking(x: Tensor, enc_pad_index: int) -> Tensor:
""" make masking matrix for Encoder Padding Token """
enc_mask = (x != enc_pad_index).int().repeat(1, x.shape[-1]).view(x.shape[0], x.shape[-1], x.shape[-1])
return enc_mask
@staticmethod
def dec_masking(x: Tensor, dec_pad_index: int) -> Tensor:
""" make masking matrix for Decoder Masked Multi-Head Self-Attention """
pad_mask = (x != dec_pad_index).int().repeat(1, x.shape[-1]).view(x.shape[0], x.shape[-1], x.shape[-1])
lm_mask = torch.tril(torch.ones(x.shape[0], x.shape[-1], x.shape[-1]))
dec_mask = pad_mask * lm_mask
return dec_mask
@staticmethod
def enc_dec_masking(enc_x: Tensor, dec_x: Tensor, enc_pad_index: int) -> Tensor:
""" make masking matrix for Encoder-Decoder Multi-Head Self-Attention in Decoder """
enc_dec_mask = (enc_x != enc_pad_index).int().repeat(1, dec_x.shape[-1]).view(
enc_x.shape[0], dec_x.shape[-1], enc_x.shape[-1]
)
return enc_dec_mask
def forward(self, enc_inputs: Tensor, dec_inputs: Tensor, enc_pad_index: int, dec_pad_index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
enc_mask = self.enc_masking(enc_inputs, enc_pad_index) # enc_x.shape[1] == encoder input sequence length
dec_mask = self.dec_masking(dec_inputs, dec_pad_index) # dec_x.shape[1] == decoder input sequence length
enc_dec_mask = self.enc_dec_masking(enc_inputs, dec_inputs, enc_pad_index)
enc_x, dec_x = self.enc_input_embedding(enc_inputs), self.dec_input_embedding(dec_inputs)
enc_output, enc_layer_output = self.encoder(enc_x, enc_mask)
dec_output, dec_layer_output = self.decoder(dec_x, dec_mask, enc_dec_mask, enc_output)
return enc_output, dec_output, enc_layer_output, dec_layer_output
๋ง์คํน์ ํ์์ฑ์ด๋ ๋์ ๋ฐฉ์์ ์ด๋ฏธ ์์์ ๋ชจ๋ ์ค๋ช
ํ๊ธฐ ๋๋ฌธ์ ๊ตฌํ์ ํน์ง๋ง ์ค๋ช
ํ๋ คํ๋ค. ์ธ๊ฐ์ง ๋ง์คํน ๋ชจ๋ ๊ณตํต์ ์ผ๋ก ๊ตฌํ ์ฝ๋ ๋ผ์ธ์ .int()
๊ฐ ๋ค์ด๊ฐ ์๋ค. ๊ทธ ์ด์ ๋ $\frac{QโขK^T}{\sqrt{d_h}}$์ ๋ง์คํน์ ์ ์ฉํ ๋ torch.masked_fill
๋ฉ์๋๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ด๋ค. ๋ฌด์จ ์ด์ ๋๋ฌธ์ธ์ง๋ ๋ชจ๋ฅด๊ฒ ์ผ๋ torch.masked_fill
์ ๊ฒฝ์ฐ ๋ง์คํน ์กฐ๊ฑด์ผ๋ก boolean
์ ์ ๋ฌํ๋ฉด ๋ง์คํน์ด ์ ๋๋ก ๊ตฌํ๋์ง ์๋ ํ์์ด ์์๋ค. ํํธ, ์ ์๊ฐ์ผ๋ก ์กฐ๊ฑด์ ๊ตฌํํ๋ฉด ์๋ํ๋๋ก ๊ตฌํ์ด ๋๋ ๊ฒ์ ํ์ธํ๋ค. ๊ทธ๋์ ํจ๋ฉ์ ํด๋น๋๋ ํ ํฐ์ด ์์นํ ๊ณณ์ ์์๊ฐ์ ์ ์ํ Binary
๋ก ๋ง๋ค์ด์ฃผ๊ธฐ ์ํด int()
๋ฅผ ์ฌ์ฉํ ๊ฒ์ด๋ค.
๐ญย Decoder Mask
๋์ฝ๋๋ ์ด 2๊ฐ์ง์ ๋ง์คํน์ด ํ์ํ๋ค๊ณ ์ธ๊ธํ์๋ค. pad_mask
์ ๊ฒฝ์ฐ๋ ์ธ์ฝ๋์ ๊ฒ๊ณผ ๋์ผํ ์๋ฆฌ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ์ค๋ช
์ ์๋ตํ๊ฒ ๋ค. lm_mask
์ ๊ฒฝ์ฐ๋ torch.tril
์ ์ด์ฉํด ํ์ผ๊ฐํ๋ ฌ ํํ๋ก ๋ง์คํน ํ๋ ฌ ์ ์๊ฐ ์ฝ๊ฒ ๊ฐ๋ฅํ๋ค.
ํํธ, 2๊ฐ์ ๋ง์คํน์ ๋์์ ๋์ฝ๋ ๊ฐ์ฒด์ ๋๊ธฐ๋ ๊ฒ์ ๋งค์ฐ ๋นํจ์จ์ ์ด๋ค. ๋ฐ๋ผ์ pad_mask
์ lm_mask
์ ํฉ์งํฉ์ ํด๋นํ๋ ํ๋ ฌ์ ๋ง๋ค์ด ์ต์ข
๋์ฝ๋์ ๋ง์คํน์ผ๋ก ์ ๋ฌํ๋ค.
๐ย Encoder-Decoder Mask
์ด๋ฒ ๊ฒฝ์ฐ๋ ๋ง์คํน์ ํ๋ฐฉํฅ ์ฐจ์์ ๋์ฝ๋ ์
๋ ฅ๊ฐ์ ์ํ์ค ๊ธธ์ด, ์ด๋ฐฉํฅ ์ฐจ์์ ์ธ์ฝ๋ ์
๋ ฅ๊ฐ์ ์ํ์ค ๊ธธ์ด๋ก ์ค์ ํด์ผ ํ๋ค. ๊ทธ ์ด์ ๋ ๋ค๋ฅธ Self-Attention
๋ ์ด์ด์ ๋ค๋ฅด๊ฒ ์๋ก ๋ค๋ฅธ ์ถ์ฒ๋ฅผ ํตํด ๋ง๋ ํ๋ ฌ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ $\frac{QโขK^T}{\sqrt{d_h}}$์ ๋ชจ์์ด ์ ์ฌ๊ฐํ๋ ฌ์ด ์๋ ์๋ ์๋ค. ์๋ฅผ ๋ค์ด ํ๊ตญ์ด ๋ฌธ์ฅ์ ์์ด๋ก ๋ฐ๊พธ๋ ๊ฒฝ์ฐ๋ฅผ ์๊ฐํด๋ณด์. ๊ฐ์ ๋ป์ด ๋ด๊ธด ๋ฌธ์ฅ์ด๋ผ๊ณ ํด์ ๋ ๋ฌธ์ฅ์ ๊ธธ์ด๊ฐ ๊ฐ์๊ฐ?? ์๋๋ค. ์๋ก ๋ค๋ฅธ ์ธ์ด๋ผ๋ฉด ๊ฑฐ์ ๋๋ถ๋ถ์ ๊ฒฝ์ฐ ๊ธธ์ด๊ฐ ๋ค๋ฅผ ๊ฒ์ด๋ค. ๋ฐ๋ผ์ $\frac{QโขK^T}{\sqrt{d_h}}$์ ํ๋ฐฉํฅ์ ๋์ฝ๋์ ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ฅด๊ณ ์ด๋ฐฉํฅ์ ์ธ์ฝ๋์ ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ฅด๋๋ก ๋ง์คํน ์ญ์ ๊ตฌํํด์ค์ผ ํ๋ค.
๊ทธ๋ฆฌ๊ณ ์ด๋ฒ ๋ง์คํน์ ๋ง๋๋ ๋ชฉ์ ์ด ์ธ์ฝ๋์ ํจ๋ฉ์ ๋ง์คํน ์ฒ๋ฆฌํด์ฃผ๊ธฐ ์ํจ์ด๊ธฐ ๋๋ฌธ์ enc_pad_index
๋งค๊ฐ๋ณ์์๋ ์ธ์ฝ๋ vocab
์์ ์ ์ํ pad_token_ID
๋ฅผ ์ ๋ฌํ๋ฉด ๋๋ค.
# scaled_dot_product_attention์ ์ผ๋ถ
if mask is not None:
attention_matrix = attention_matrix.masked_fill(mask == 0, float('-inf'))
์ด๋ ๊ฒ ๊ตฌํ๋ ๋ง์คํน์ scaled_dot_product_attention
๋ฉ์๋์ ๊ตฌํ๋ ์กฐ๊ฑด๋ฌธ์ ํตํด ๋ง์คํน ๋์์ -โ์ผ๋ก ๋ณํํ๋ ์ญํ ์ ํ๊ฒ ๋๋ค.
Leave a comment