๐ก [Roformer] RoFormer: Enhanced Transformer with Rotary Position Embedding
๐ญย Overview
Roformer
๋ 2021๋
์ ๋ฐํ๋ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์ ๋ณํ์ผ๋ก, RoPE(Rotary Position Embedding)
์ด๋ผ๋ ์๋ก์ด ์์น ์ ๋ณด ํฌ์ฐฉ ๋ฐฉ์์ ์ ์ํ๋ค. ๊ทผ๋ ์ ๋ช
ํ ์คํ์์ค LLM ๋ชจ๋ธ๋ค(GPT-Neo, LLaMA)์ ์์น ์ ๋ณด ํฌ์ฐฉ ๋ฐฉ์์ผ๋ก ์ฑํ ๋์ด ์ฃผ๋ชฉ์ ๋ฐ๊ณ ์๋ค. RoPE
๊ธฐ๋ฒ์ ๋ํด ์ดํด๋ณด๊ธฐ ์ ์ ์ผ๋จ, ๊ด๋ จ ๋ถ์ผ์ ์ฐ๊ตฌ ๋ํฅ ๋ฐ ์์น ์ ๋ณด์ ๊ฐ๋
์ ๋ํด ๊ฐ๋จํ๊ฒ ์ดํด๋ณด๊ณ ๋์ด๊ฐ๋ ค ํ๋ค.
๐ค Absolute Position vs Relative Position
ํธ๋์คํฌ๋จธ๊ฐ ์ฑ๊ณต์ ๊ฑฐ๋ ์ ์์๋ ์ด์ ๋ ์ ์ฒด ์ํ์ค๋ฅผ ๋ณ๋ ฌ์ ์ผ๋ก ํ ๋ฒ์ ์ฒ๋ฆฌํ๋, ์ํ์ค ๋ฐ์ ์์ ์ ๋ณด๋ฅผ ํ๋ ฌํฉ ๋ฐฉ์์ผ๋ก ์ธ์ฝ๋ฉํด์คฌ๊ธฐ ๋๋ฌธ์ด๋ค. ์ด ๋ถ์ผ์ ๋ํ ์ฐ๊ตฌ ๋ํฅ์ ํฌ๊ฒ Absolute Position
, Relative Position
๋ฐฉ์์ผ๋ก ๋ถํ๋๋ค.
Absolute Position
์ ์ฃผ์ด์ง ์ํ์ค์ ๊ธธ์ด๋ฅผ ์ธก์ ํ ๋ค, ๋์ด๋ ์์ ๊ทธ๋๋กย forward
ํ๊ฒย 0
๋ถํฐย ๊ธธ์ด-1
์ ๋ฒํธ๋ฅผ ๊ฐ๋ณ ํ ํฐ์ ํ ๋นํ๋ค. ๋ค์ ๋งํด, ๋จ์ด๊ฐ ์ํ์ค์์ ๋ฐ์ํ ์์๋ฅผ ์ํ์ ์ผ๋ก ํํํด ๋ชจ๋ธ์ ์ฃผ์
ํ๋ค๋ ์๋ฏธ๊ฐ ๋๋ค.
ํํธ, Relative Position
์ ์ํ์ค ๋ด๋ถ ํ ํฐ ์ฌ์ด์ ์์น ๊ด๊ณ ํํ์ ํตํด ํ ํฐ ์ฌ์ด์ย relation
์ย pairwise
ํ๊ฒ ํ์ตํ๋ ์์น ์๋ฒ ๋ฉ ๊ธฐ๋ฒ์ ๋งํ๋ค. ์ผ๋ฐ์ ์ผ๋ก ์๋ ์์น ๊ด๊ณ๋ ์๋ก ๋ค๋ฅธ ๋ ํ ํฐ์ ์ํ์ค ์ธ๋ฑ์ค ๊ฐ์ ์ฐจ๋ฅผ ์ด์ฉํด ๋ํ๋ธ๋ค. ํฌ์ฐฉํ๋ ๋ฌธ๋งฅ ์ ๋ณด๋ ์์์ ํจ๊นจ ์ค๋ช
ํ๊ฒ ๋ค. ์์๋ ์์ DeBERTa ๋
ผ๋ฌธ์์ ๋์๋ ๊ฒ์ ํ์ฉํ๋ค. ๋ฅ๋ฌ๋์ด๋ผ๋ ๋จ์ด๋ ์์ด๋กย Deep Learning
ย ์ด๋ค. ๋ ๋จ์ด๋ฅผ ํฉ์ณ๋๊ณ ๋ณด๋ฉดย ์ ๊ฒฝ๋ง์ ์ฌ์ฉํ๋ ๋จธ์ ๋ฌ๋ ๊ธฐ๋ฒ์ ํ ์ข
๋ฅ
๋ผ๋ ์๋ฏธ๋ฅผ ๊ฐ๊ฒ ์ง๋ง, ๋ฐ๋ก ๋ฐ๋ก ๋ณด๋ฉดย ๊น์
,ย ๋ฐฐ์
์ด๋ผ๋ ๊ฐ๋ณ์ ์ธ ์๋ฏธ๋ก ๋๋๋ค.
1) The Deep Learning is the Best Technique in Computer Science
2) Iโm learning how to swim in the deep ocean
Deep
๊ณผย Learning
์ ์๋์ ์ธ ๊ฑฐ๋ฆฌ์ ์ฃผ๋ชฉํ๋ฉด์ ๋ ๋ฌธ์ฅ์ ํด์ํด๋ณด์. ์ฒซ ๋ฒ์งธ ๋ฌธ์ฅ์์ ๋ ๋จ์ด๋ ์ด์ํ๊ฒ ์์นํดย ์ ๊ฒฝ๋ง์ ์ฌ์ฉํ๋ ๋จธ์ ๋ฌ๋ ๊ธฐ๋ฒ์ ํ ์ข
๋ฅ
ย ๋ผ๋ ์๋ฏธ๋ฅผ ๋ง๋ค์ด๋ด๊ณ ์๋ค. ํํธ ๋ ๋ฒ์งธ ๋ฌธ์ฅ์์ ๋ ๋จ์ด๋ ๋์ด์ฐ๊ธฐ ๊ธฐ์ค 5๊ฐ์ ํ ํฐ๋งํผ ๋จ์ด์ ธ ์์นํด ๊ฐ๊ฐย ๋ฐฐ์
,ย ๊น์
ย ์ด๋ผ๋ ์๋ฏธ๋ฅผ ๋ง๋ค์ด ๋ด๊ณ ์๋ค. ์ด์ฒ๋ผ ๊ฐ๋ณ ํ ํฐ ์ฌ์ด์ ์์น ๊ด๊ณ์ ๋ฐ๋ผ์ ํ์๋๋ ๋ฌธ๋งฅ์ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๋ ค๋ ์๋๋ก ์ค๊ณ๋ ๊ธฐ๋ฒ์ด ๋ฐ๋กย Relative Position Embedding
ย ์ด๋ค.
๐ค Word Context vs Relative Position vs Absolute Position
์ง๊ธ๊น์ง Relative Position Embedding
์ด ๋ฌด์์ด๊ณ , ๋๋์ฒด ์ด๋ค ๋ฌธ๋งฅ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๋ค๋ ๊ฒ์ธ์ง ์์๋ดค๋ค. ํ์์ ์ค๋ช
์ด ๋งค๋๋ฝ์ง ๋ชปํ๊ธฐ๋ ํ๊ณ ์์๋ฅผ ํ
์คํธ๋ก ๋ค๊ณ ์์ด์ ์ง๊ด์ ์ผ๋ก word context
๋ ๋ฌด์์ธ์ง, Position
์ ๋ณด์๋ ๋ญ๊ฐ ๋ค๋ฅธ์ง, ๋ ๊ฐ์ง Position
์ ๋ณด๋ ๋ญ๊ฐ ์ด๋ป๊ฒ ๋ค๋ฅธ์ง ์๋ฟ์ง ์๋ ๋ถ๋ค์ด ๋ง์ผ์ค ๊ฒ ๊ฐ๋ค. ๊ทธ๋์ ์ต๋ํ ์ง๊ด์ ์ธ ์์๋ฅผ ํตํด ์ธ๊ฐ์ง ์ ๋ณด์ ์ฐจ์ด์ ์ ์ค๋ช
ํด๋ณด๋ ค ํ๋ค.
์ฌ๋ 5๋ช
์ด ๊ณตํญ ์ฒดํฌ์ธ์ ์ํด ์ ์๋ค. ๋ชจ๋ ์ผ์ชฝ์ ๋ณด๊ณ ์๋ ๊ฒ์ ๋ณด์ ์ผ์ชฝ์ ํค๊ฐ ์ ์ผ ์์ ์ฌ์๊ฐ ๊ฐ์ฅ ์์ค์ด๋ผ๊ณ ๋ณผ ์ ์๊ฒ ๋ค. ์ฐ๋ฆฌ๋ ์ค ์์๋ ์์๋๋ก 5๋ช
์ ์ฌ๋์๊ฒ ๋ฒํธ๋ฅผ ๋ถ์ฌํ ๊ฒ์ด๋ค. ํธ์์ 0๋ฒ๋ถํฐ ์์ํด 4๋ฒ๊น์ง ๋ฒํธ๋ฅผ ์ฃผ๊ฒ ๋ค. 1๋ฒ์ ํด๋นํ๋ ์ฌ๋์ ๋๊ตฌ์ธ๊ฐ?? ๋ฐ๋ก ์ค์ 2๋ฒ์งธ์ ์์๋ ์ฌ์๋ค. ๊ทธ๋ผ 2๋ฒ์ ํด๋นํ๋ ์ฌ๋์ ๋๊ตฌ์ธ๊ฐ?? ์ฌ์ง ์ ์ค์ ๊ฐ์ฅ ์ค๊ฐ์ ์๋ ๋จ์๊ฐ 2๋ฒ์ด๋ค. ์ด๋ ๊ฒ ๊ทธ๋ฃน ๋จ์(์ ์ฒด ์ค)์์ ๊ฐ๊ฐ์ธ์ ์ผ๋ จ์ ๋ฒํธ๋ฅผ ๋ถ์ฌํด ์์น๋ฅผ ํํํ๋ ๋ฐฉ๋ฒ์ด ๋ฐ๋ก Absolute Position Embedding
์ด๋ค.
ํํธ, ๋ค์ 2๋ฒ ์ฌ๋์๊ฒ ์ฃผ๋ชฉํด๋ณด์. ์ฐ๋ฆฌ๋ 2๋ฒ ๋จ์๋ฅผ ์ ์ฒด ์ค์์ ๊ฐ์ด๋ฐ ์์นํ ์ฌ๋์ด ์๋๋ผ, ๊ฒ์ ์ ์๋ณต๊ณผ ๊ตฌ๋๋ฅผ ์ ๊ณ ์์ ์ฅ ๋ฌด์ธ๊ฐ๋ฅผ ์์ํ๊ณ ์๋ ์ฌ๋์ด๋ผ๊ณ ํํํ ์๋ ์๋ค. ์ด๊ฒ์ด ๋ฐ๋ก ํ ํฐ์ ์๋ฏธ ์ ๋ณด๋ฅผ ๋ด์ word context
์ ํด๋นํ๋ค.
๋ง์ง๋ง์ผ๋ก Relative Position Embedding
๋ฐฉ์์ผ๋ก 2๋ฒ ๋จ์๋ฅผ ํํํด๋ณด์. ์ค๋ฅธ์์ผ๋ก๋ ์ปคํผ๋ฅผ ๋ค๊ณ ๋ค๋ฅธ ์์ผ๋ก๋ ์บ๋ฆฌ์ด๋ฅผ ์ก๊ณ ์์ผ๋ฉฐ ๊ฒ์ ์ ํ์ดํ๊ณผ ๋ฒ ์ด์ง์ ๋ฐ์ง๋ฅผ ์
์ 1๋ฒ ์ฌ์์ ๋ค์ ์๋ ์ฌ๋, ํ์ ์๋ณต๊ณผ ๊ฒ์ ๋ฟํ
์๊ฒฝ์ ์ฐ๊ณ ํ ์์๋ ์บ๋ฆฌ์ด๋ฅผ ์ก๊ณ ์๋ 4๋ฒ ์ฌ์์ ์์ ์๋ ์ฌ๋, ๊ฒ์ ์ ์์ผ๊ณผ ์ฒญ๋ฐ์ง๋ฅผ ์
๊ณ ํ ์์๋ ํ์ ์ฝํธ๋ฅผ ๋ค๊ณ ์๋ ์ค์ ๋งจ ์ ์ฌ์๋ก๋ถํฐ 2๋ฒ์งธ ๋ค์ ์์๋ ์ฌ๋, ํฑ์์ผ์ด ๊ธธ๊ณ ๋จธ๋ฆฌ๊ฐ ๊ธด ํธ์ด๋ฉฐ ํ๋์ ๊ฐ๋๊ฑด์ ์
๊ณ ์ด๋ก์๊ณผ ๊ฒ์ ์์ด ํผํฉ๋ ๊ฐ๋ฐฉ์ ์ผ์ชฝ์ผ๋ก ๋ฉ๊ณ ์๋ ๋จ์๋ก๋ถํฐ 2๋ฒ์งธ ์์ ์๋ ์ฌ๋.
์ด์ฒ๋ผ ํํํ๋๊ฒ ๋ฐ๋ก Relative Position Embedding
์ ๋์๋๋ค๊ณ ๋ณผ ์ ์๋ค. ์ด์ ์์น ์๋ฒ ๋ฉ์ ๋ํด์ ์ดํด๋ดค์ผ๋, ๋
ผ๋ฌธ์์ ์ ์ํ๋ ๋ด์ฉ์ ๋ํด์ ์์๋ณด์.
๐๏ธ Previous Work: Relative Position Embedding
๋ฏธ๋ฆฌ ๋งํ์๋ฉด, RoPE
๋ ์์น ์ ๋ณด ์ค์์ ์๋ ์์น๋ฅผ ํฌ์ฐฉํ๋ค. ๊ทธ๋์ ์ ์๋ ๊ทธ๋ค์ ๋ฐฉ๋ฒ๋ก ์ ์๊ฐํ๊ธฐ ์ ์ ๋จผ์ , ์ด์ ์ฐ๊ตฌ๋ค์ ์๋ ์์น ํฌ์ฐฉ ๋ฐฉ์์ ๋ํด์ ์๊ฐํ๊ณ ์๋ค. ๊ฐ๋จํ ์ดํด๋ณด์.
(1)๋ฒ ์์์ Transformer-XL
๋
ผ๋ฌธ์์ ์ ์๋ Cross Attention
์์์ด๋ค. ์์น ์ ๋ณด๋ฅผ ๋ด์๋ด๋ ํญ์ ๋ฐ๋ก ๋ง๋ค๊ณ ์ฟผ๋ฆฌ, ํค์ ๋์๋๋ ํญ๊ณผ ๊ณฑํ๊ณ ์๋ค. (2)๋ฒ ์์์ DeBERTa
๋ชจ๋ธ์์ ์ ์๋ Disentangled Attention
์ด๋ค. (1)๊ณผ ๊ตฌ์ฑ์ ์ฐจ์ด๋ ์์ง๋ง ์ญ์, ์์น ์ ๋ณด๋ฅผ ๋ด์๋ด๋ ํญ์ ์ต์ง๋ก ๋ง๋ค๊ณ ๊ทธ๊ฒ๋ค์ ์ฟผ๋ฆฌ ํน์ ํค์ ๊ณฑํ์ฌ ์์น ์ ๋ณด๋ฅผ ๋ด์๋ธ ๋ค, ๋ชจ๋ ํฉํ์ฌ ์ดํ
์
ํ๋ ฌ์ ๋ง๋ค์ด ๋ด๊ณ ์๋ค.
์ ๋ฆฌํ๋ฉด, ๊ธฐ์กด ์ฐ๊ตฌ๋ค์ ์๋ ์์น๋ฅผ ํฌ์ฐฉํ๊ธฐ ์ํด ๋ณ๋์ ํฌ์ง์ ํ๋ ฌ์ ๋ง๋ค๊ณ , ์ด๋ฆฌ์ ๋ฆฌ ๊ณฑํ๊ณ , ๋ค์ ๊ทธ๊ฒ๋ค์ ๋ชจ๋ ํฉํ์ฌ ์ดํ ์ ํ๋ ฌ์ ๋ง๋ค๊ณ ์๋ ๊ฒ์ด๋ค. ๊ธฐ์กด ์ฐ๊ตฌ๋ค์ด ์ ์ํ๋ ๋ฐฉ๋ฒ๋ก ๋ค์ ๊ณตํต๋ ๋ฌธ์ ๋ ํ์ตํด์ผ ํ ํ๋ผ๋ฏธํฐ ์๊ฐ ๋์ด๋ ๋ชจ๋ธ ์ฌ์ด์ฆ๋ ์ปค์ง๊ณ , ํ์ต์๊ฐ๋ ๋์ด๋๋ค๋ ๊ฒ์ด๋ค.
๐กย RoPE
\[f_{q,k}(x_m, m)= \left( \begin{array}{cc}\cos(m\theta) & \sin(m\theta) \\-\sin(m\theta) & \cos(m\theta)\end{array} \right)
\left( \begin{array}{cc}W^{(11)}_{q,k} & W^{(12)}_{q,k} \\W^{(21)}_{q,k} & W^{(22)}_{q,k} \end{array} \right)
\left( \begin{array}{cc}x_m^{(1)} \\x_m^{(2)} \end{array} \right)\]
๋ฑ์์ ์ข๋ณ์ word embedding
์ ์ ํ ํฌ์ ์์ผ ์ป์ query
, key
๋ฒกํฐ์ Rotary Position Embedding
๊ฐ์ ์ถ๊ฐํ ๊ฒฐ๊ณผ ๊ฐ์ ๋ปํ๋ค. ์ฐ๋ณ์ ์์์ด ์๋นํ ๋ณต์กํด ๋ณด์ด๋, ์ค์์ ๋งค์ฐ ๊ฐ๋จํ๋ค. ์ ํ ํฌ์์ผ๋ก ์ป์ query
, key
๋ฒกํฐ์ ์ข์ธก์ ๊ดด๋ํ๊ฒ ์๊ธด ํ๋ ฌ์ ๊ณฑํด์ฃผ๊ฒ ๋ค๋ ๊ฒ์ด๋ค. ์ข์ธก์ ํ๋ ฌ์ ๋ํ๊ต ์ ํ๋์ ์๊ฐ์ ์ค์น๋ฏ ์ง๋๊ฐ๋ Transformation Matrix(ํ์ ํ๋ ฌ)
์ด๋ค. $m$์ $m$-th ํ ํฐ์ ์๋ฏธํ๋๋ฐ, ์ธํ๊ฐ ๋ญ์ง๋ ๋ชจ๋ฅด๊ฒ ์ง๋ง ์ผ๋จ ํ ํฐ์ ์ธ๋ฑ์ค ๊ฐ์ ๋ฐ๋ผ์, ์ฃผ์ด์ง ์๋ ์๋ฒ ๋ฉ ๋ฒกํฐ๋ฅผ ํ์ ์ํค๊ฒ ๋ค๋ ๊ฒ์ด๋ค. ์ง๊ธ ์ดํด๋ณธ ์์๋ ์๋์ธต ํฌ๊ธฐ๊ฐ 2์ฐจ์์ธ ๋จ์ํ ๋ฒกํฐ์๋ค. ์ค์ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ์ฐจ์(384, 512, 768, โฆ)์ผ๋ก ํ์ฅํ๊ธฐ ์ ์ ์ธํ์ ์ ์ฒด์ ๋ํด ์์๋ณด์.
$\theta$์ ์ ์ฒด๋ ๋ฐ๋ก ์ฃผ๊ธฐํจ์ ์๋ค. ํจ์ดํ ํธ๋์คํฌ๋จธ์์ Absolute Position Encoding
์ ์ํด Sinusoidal
ํจ์๋ฅผ ์ฌ์ฉํ ๊ฒ๊ณผ ๊ฐ์ ์ด์น๋ผ๊ณ ์๊ฐํ๋ฉด ๋๋ค. ์ฆ $\theta$๋ word embedding
๋ฒกํฐ๊ฐ ๊ฐ์ง ์๋์ธต ์ฐจ์ ๋ฐฉํฅ ์ธ๋ฑ์ค์ ๋ฐ๋ผ์ ๋ฌ๋ผ์ง๋ค. ์ฌ๊ธฐ์ ์ํ์ค ๊ธธ์ด ์ฐจ์ ๋ฐฉํฅ์ ์ธ๋ฑ์ค ๊ฐ์ ๋ฐ๋ก ๊ณฑํด์ฃผ๊ธฐ ๋๋ฌธ์ ๊ทธ ์ ์ผ์ฑ์ ๋ณด์ฅํ ์ ์๋ค.
์ด์ ์ ์ฒด RoPE๋ฅผ ์ดํดํ๋๋ฐ ํ์ํ ์ฌ๋ฃ ์ค๋น๋ ๋ชจ๋ ๋๋ฌ๋ค. ์ด์ ์ค์ ์ฐจ์์ผ๋ก ํ์ฅํด๋ณด์.
\[fq,k(x_m,m)=R^d_{ฮ,m}W_{q,k}x_m \\\]ํ๋ ฌ $R^d_{ฮ,m}$์ ์๋์ ๊ฐ์ ํ๋ ฌ์ ๋งํ๋๋ฐ,
\[R^d_{ฮ,m} = \begin{bmatrix} \cos(m\theta_1) & -\sin(m\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ \sin(m\theta_1) & \cos(m\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos(m\theta_2) & -\sin(m\theta_2) & \cdots & 0 & 0 \\ 0 & 0 & \sin(m\theta_2) & \cos(m\theta_2) & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos(m\theta_{d/2}) & -\sin(m\theta_{d/2}) \\ 0 & 0 & 0 & 0 & \cdots & \sin(m\theta_{d/2}) & \cos(m\theta_{d/2}) \end{bmatrix}\]ํ ํฐ์ ์ธ๋ฑ์ค์ ๋ชจ๋ธ์ ์๋์ฐจ์ ์ธ๋ฑ์ค์ ๋ฐ๋ผ์ ํ๋ ฌ์ ์์๊ฐ์ด ๊ฒฐ์ ๋จ์ ์ ์ ์๋ค. ์ด์ ๋ค์ (3)๋ฒ ์์์ ์๋ฏธ๋ฅผ ์๊ฐํด๋ณด์. ๋จ์ด ์๋ฒ ๋ฉ์ ์ฟผ๋ฆฌ, ํค ํ๋ ฌ๋ก ์ ํ ํฌ์ํ ๋ค (4)๋ฒ ์์์ ๊ณฑํ๋ค. ์์ํ ํ์ ํ๋ ฌ์ ์ฟผ๋ฆฌ, ํค ๋ฒกํฐ์ ๊ณฑํ๊ธฐ ๋๋ฌธ์ ๋ฒกํฐ์ ํฌ๊ธฐ๋ฅผ ์ ์งํ์ฑ, ๋ฐฉํฅ๋ง ๋ฐ๊ฟ์ค ์ ์๋ค๋ ์ฅ์ ์ด ์๋ค.
์ด์ ์ ์ฐ๊ตฌ๋ค์ ํฌ์ง์ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ ํ๋ ฌ์ ๋จ์ด ๋ฒกํฐ์ ๋ํ๊ธฐ ๋๋ฌธ์ ๋ฒกํฐ์ ๋ฐฉํฅ์ ๋ฌผ๋ก ํฌ๊ธฐ ์ญ์ ์๊ณก๋๋ค. ๋ฌผ๋ก ๋จ์ด ๋ฒกํฐ์ ํฌ์ง์ ๋ฒกํฐ๊ฐ ์๋ก ์ฑ๊ฒฉ์ด ๋ค๋ฅธ ์ ๋ณด๋ผ๋ ์ ์ ๊ณ ๋ คํ๋ฉด ๋ชจ๋ธ์ ์๋์ธต์ฒ๋ผ ๊ณ ์ฐจ์ ๊ณต๊ฐ์์ ์๋ก ์ง๊ตํ ํ๋ฅ ์ด ๋งค์ฐ ๋๊ธฐ ๋๋ฌธ์, ์๋ก ํ์ต์ ์ํฅ์ ๋ฏธ์น ๊ฐ๋ฅ์ฑ์ ๋ฎ๋ค. ํ์ง๋ง ํ๋ฅ ์ ์ธ ์ ๊ทผ์ผ ๋ฟ๋๋ฌ, ๋จ์ด ๋ฒกํฐ์ ํฌ๊ธฐ๊ฐ ์๊ณก๋๋ค๋ ์ ์ด ์ธต์ ๊ฑฐ๋ญํ ์๋ก ์ํฅ์ ๋ฏธ์น ์ง ์ ์ ์๋ค.
RoPE ๋ฐฉ์์ ๋๋ค๋ฅธ ์ฅ์ ์ ๊ณฑํ๋ ๊ฒ๋ง์ผ๋ก๋, ์๋ ์์น ์ ๋ณด๋ฅผ ์ธ์ฝ๋ฉ ํด์ค ์ ์๋ค๋ ์ ์ด๋ค. ์ด์ ์ฐ๊ตฌ๋ค์ ๋๋ถ๋ถ ์ ๋ ์์น ํน์ ์๋ ์์น ํ๋๋ง์ ์ ํํด ๋จ์ด ์๋ฒ ๋ฉ์ ์ ๋ณด๋ฅผ ์ถ๊ฐํด์ฃผ๋ ๊ฒฝ์ฐ๊ฐ ๋๋ค์ ์๋ค. DeBERTa์ ๊ฒฝ์ฐ์๋ง, Task ๋ ์ด์ด ๊ทผ์ฒ(๋ ์ด์ด ํ๋ฐ๋ถ)์ ๊ฐ์ ์ ๋ ์์น๋ฅผ ๋ํด ์๋ ์์น๊ฐ ๊ฐ๋ ๋จ์ ์ ๋ณด์ํ๋ ค๋ ์๋๋ฅผ ํ๋ค. DeBERTa๊ฐ ์ฌ๋ฌ ๋ฐฉ๋ฉด์์ ์๋นํ ์ข์ ์ฑ๋ฅ์ ๊ฑฐ๋ฌ์ ๊ทธ๋ ์ง, ๋ง์ง๋ง ๋ ์ด์ด ๊ทผ์ฒ์ ๊ฐ์ ์ ๋ ์์น๋ฅผ ๋ํด์ฃผ๋๊ฒ ์ฌ์ค ์์ฐ์ค๋ฝ๋ค๊ณ ์๊ฐ๋์ง๋ ์๋๋ค. ๊ทธ๋ฐ๋ฐ RoPE๋ ํ์ ํ๋ ฌ์ ๊ณฑํ๋ ๊ฒ๋ง์ผ๋ก๋ ์ ๋ ์์น์ ์๋ ์์น ๋ชจ๋ ์ธ์ฝ๋ฉ์ด ๊ฐ๋ฅํ๋ค. ์ด๋ป๊ฒ ๊ทธ๋ด๊น??
์ผ๋จ RoPE ์ ํ ํฌ์๋ ์ฟผ๋ฆฌ, ํค ํ๋ ฌ์ ๊ฐ๊ฐ ํ์ ํ๋ ฌ์ ๊ณฑํ๋ค. ๊ณฑํ๋ ๊ณผ์ ์์ ์ด๋ฏธ ํ ํฐ์ ์ธ๋ฑ์ค ๊ฐ์ ๋ฐ๋ผ์ ์๋ก ๋ค๋ฅธ ํฌ์ง์ ๊ฐ์ด ๋จ์ด ์๋ฒ ๋ฉ์ ๊ณฑํด์ง๊ฒ ๋๋ค. ์ด๊ฒ์ผ๋ก ์ผ๋จ ์ ๋ ์์น ์ ๋ณด๋ฅผ ์ถ๊ฐํด์ค ์ ์๋ค. ๊ทธ๋ฆฌ๊ณ ์ ์๋ค์ํผ, ์ฟผ๋ฆฌ์ ํค์ ๋ด์ ์ ์ํํ๋ค. ์ฟผ๋ฆฌ์ ํค์ ๋ด์ ์ ๊ฐ๊ฐ ๋จ์ด ์๋ฒ ๋ฉ, ์ ํ ํฌ์, ํ์ ํ๋ ฌ ํญ์ผ๋ก ๋๋ ์ ์์ ํ์ด ์ฐ๋ฉด ์๋์ ๊ฐ๋ค.
\[q^T_mk_n=(R^d_{ฮ,m}W_{q}x_m)^T(R^d_{ฮ,n}W_{k}x_n) \ \ \ (5)\]์์์ ์ ๊ฐํ๋ฉด ์์ฐ์ค๋ ,
\[x^TW_qR^d_{ฮ,n-m}W_kx_n \ \ \ (6)\](6)๋ฒ ์์์ฒ๋ผ ๋๋ค. ํ๋ ฌ $R^d_{ฮ,n-m}$์ ์์๋ ์๋์ฒ๋ผ,
\[\cos(m\theta_1)*\cos(n\theta_1) - \sin(m\theta_1)*\sin(n\theta_1) \\\]ํ ํฐ ์ธ๋ฑ์ค๋ฅผ ์๋ฏธํ๋ $m,n$์ ๋ํ ์์์ผ๋ก ํํ๋๋ค. ๋ฐ๋ผ์ ์์ฐ์ค๋ฝ๊ฒ ์๋ ์์น๋ฅผ ํฌ์ฐฉํ ์ ์๊ฒ ๋๋ค. ์๋นํ ์์ฐ์ค๋ฝ๊ฒ ์๋ก ๋ค๋ฅธ ๋ ์์น ์ ๋ณด๋ฅผ ์ธ์ฝ๋ฉํ๋๊ฒ ๊ฐ๋ฅํ๋ฉฐ, ์ถ๊ฐ์ ์ผ๋ก ๋ค๋ฅธ ํญ์ ๋ง๋ค์ด ์ดํ ์ ํ๋ ฌ์ ๊ณ์ฐํ์ง ์๊ธฐ ๋๋ฌธ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ข ๋ ํจ์จ์ ์ผ๋ก ์ฌ์ฉ ๊ฐ๋ฅํ๋ค.
ํํธ, ํ ํฐ์ ์๋ ์์น๋ฅผ ํฌ์ฐฉํ๋ ๋ฐฉ์์ ์์ ๊ณผ ์๋์ ๊ฑฐ๋ฆฌ๊ฐ ๋ฉ์ด์ง์๋ก ์๋ฏธ์ ์ฐ๊ด์ฑ์ด๋ ๊ด๊ณ์ฑ์ด ๋จ์ด์ง๋ค๋ ์ ์ ์ ์ ๋ก ํ๋ค. ์ฆ, ์๋ก ๊ฑฐ๋ฆฌ๊ฐ ๋จผ ํ ํฐ์ผ์๋ก ์ฟผ๋ฆฌ์ ํค๋ฒกํฐ์ ๋ด์ ๊ฐ์ด 0์ ๊ฐ๊น์์ ธ์ผ ํ๋ค๋ ๊ฒ์ด๋ค. ์ ์ ์ญ์ ์ด์ ์ ์ธ๊ธํ๋ฉฐ RoPE
๋ฐฉ์์ด Long-Term Decay
์์ฑ์ ๊ฐ๊ณ ์๋ค๊ณ ์ฃผ์ฅํ๋ค.
Appendix
์์ ์ํ์ ์ผ๋ก ์ฆ๋ช
๊น์ง ์ ์ํ๊ณ ์์ผ๋, ํ์์ ์ํ ์ค๋ ฅ์ด ์์์ ์ ์๋ ๊ณผ์ ์ด ์ดํด๊ฐ ๊ฐ์ง ์๋๋ค. ์ถํ์ ๊ด๋ จ ๋ด์ฉ์ ์ถ๊ฐํ๋๋ก ํ๊ฒ ๋ค. ์ผ๋จ Relative Upper Bound๊ฐ ์ ํํ ๋ฌด์์ ๋งํ๋์ง ๋ชจ๋ฅด๊ฒ ์ง๋ง(๋
ผ๋ฌธ์ ์ ๋๋ก ์ธ๊ธ x, ์ถ์ธกํ๊ฑด๋ฐ, ์๋ฏธ์ ์ฐ๊ด์ฑ์ ๋ํ๋ด๋ ์งํ ๊ฐ์, ์๋ง ๋ด์ ๊ฐ์ผ๋ก ์ถ์ ), ์ ์๋ ๊ทธ๋ํ๋ฅผ ๋ณด๋ฉด ์๋ก ์๋์ ๊ฑฐ๋ฆฌ๊ฐ ๋ฉ์ด์ง์๋ก ํด๋น ์งํ๊ฐ ํ์ฐํ ๊ฐ์ํ๋ ์ถ์ธ๋ฅผ ๋ณด์ธ๋ค.
๋ง์ง๋ง์ผ๋ก ๋
ผ๋ฌธ์์ ๋ฐํ๊ธธ (4), (5)๋ฒ ์์์ ํํ๋ก RoPE๋ฅผ ๋ง๋๋ ๊ฒ์ ์ฐ์ฐ ํจ์จ์ด ๋จ์ด์ง๋ค๊ณ ํ๋ค. ๊ทธ๋์ Appendix
์์ ํจ์จ์ ์ผ๋ก ์ฐ์ฐํ๋ ์์์ ๋ค์ ์ ์ํ๊ณ ์๋ค.
์์ (4), (5)๋ฒ ํํ ๊ทธ๋๋ก ๊ตฌํํ๋ ค๋ฉด, ํฌ๊ธฐ๊ฐ [seq_len, dim_head, dim_head]
์ธ ํ
์๋ฅผ ๊ณ์ ๊ฐ์ง๊ณ ์์ด์ผ ํ๋ค. ์ด๋ ์๋นํ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ญ๋นํ๊ฒ ๋๋ค. ์๋ ๊ทธ๋ฆผ์ ํ์๊ฐ (4), (5)๋ฒ ํํ ๊ทธ๋๋ก ๊ตฌํํ ๋ค, MLM ํ์ต์ ๋๋ฆฌ๋ ๋ชจ์ต์ด๋ค.
[body ver result]
11์๊ฐ 40๋ถ์ผ๋ก ํ๋ จ ์๊ฐ์ด ์์ธก๋๋๊ฑธ ๋ณผ ์ ์๋ค. ๋ฌผ๋ก , ์ด๋ฌํ ๊ฒฐ๊ณผ๊ฐ ๋์จ ์ด์ ๋ \(R^d_{ฮ,m}x\)์ด ์ฐจ์งํ๋ ๋ฉ๋ชจ๋ฆฌ ํฌ๊ธฐ๊ฐ ์ปค์ง๋ฉด์, GPU ์์ ํ ๋ฒ์ ์ฌ๋ฆด ์๊ฐ ์์ด์ ธ ๋ฐฐ์น๋ง๋ค ๋ฃจํ๋ฅผ ๋๋ ค์ RoPE๋ฅผ ๊ฐ๋ณ ์ฟผ๋ฆฌ, ํค์ ๊ณฑํด์ฃผ๋ ๋ฐฉ์์ ์ ํํ๊ธฐ ๋๋ฌธ์ด๋ค. ์ด์ Appendix์์ ์ ์ํ ๋ฐฉ๋ฒ๋๋ก RoPE๋ฅผ ๊ตฌํํ๋ฉด,
[appendix ver result]
์ด๋ ๊ฒ 4์๊ฐ์ผ๋ก ์๊ฐ์ด ๋๋ผ๋งํฑํ๊ฒ ์ค์ด๋ค์๋ค. ์ด ๋ฐฉ๋ฒ์ ๋ํ $R^d_{ฮ,m}$๋ฅผ [seq_len, dim_head]
ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ํ
์๋ฅผ ์ฌ์ฉํ๋ฉด ๋๊ธฐ ๋๋ฌธ์, ์ด์ ๋ฐฉ์๋ณด๋ค ํจ์ฌ ๋ฉ๋ชจ๋ฆฌ๋ ๋ ์ฐจ์งํ๋ค. ์ด ๋ฐฉ์์ ๋ฐฐ์น ์ฐจ์์ผ๋ก ๋ฃจํ๋ฅผ ๋๋ฆด ํ์๊ฐ ์์ด์ ธ ํ๋ จ์๊ฐ๋ ๋ํญ ๋จ์ถ๋๋ ๊ฒ์ด๋ค.
๐ RoPE with linear attention
์ ์๋ ํจ์ดํ full attention
๋์ <Transformers are RNNs: Fast Autoregressive Transformers with linear attention> ๋
ผ๋ฌธ์์ ์ ์๋ linear attention
์ ์ฌ์ฉํ๋ค๊ณ ๋ฐํ๊ณ ์๋ค.
ํ์ง๋ง, linear attention
์ ๊ฒฝ์ฐ ๋์ฝ๋์ CLM
์ํ์ ์ด์ธ๋ฆฌ๋ ๋ฐฉ์์ผ๋ก, NLU
๋ฅผ ์ํ ์ธ์ฝ๋์๋ ์ ํฉํ์ง ์๋ค. ํด๋น ๋
ผ๋ฌธ์์๋ ๋ชจ๋ธ์ ๋ฒค์น๋งํฌ ๊ฒฐ๊ณผ๋ฅผ ๋ชจ๋ NLG
์ ๋ํด์๋ง ์ ์ํ๋ค. ๊ทธ๋ฆฌ๊ณ ํ์๊ฐ ์ง์ ๊ตฌํํด MLM
์ ์ํํด๋ณธ ๊ฒฐ๊ณผ(์คํ ๊ฒฐ๊ณผ ๋งํฌ) ์ ํ๋๊ฐ ์๋นํ ๋ฎ๊ฒ ๋์ค๋ ๊ฒ์ ์ ์ ์๋ค. ๋ฌผ๋ก ์ ์ด์ ํด๋น ๋ฐฉ์์ ํธ๋์คํฌ๋จธ๋ฅผ RNN
์ฒ๋ผ ์๊ฐ ์ฐจ์์ ๋ํด์ ํ์ตํ๋ ๊ฒฝ์ฐ๋ฅผ ์์ ํ๊ณ ๋ง๋ค์๊ธฐ ๋๋ฌธ์ linear attention
์ BERT ๊ฐ์ ์ธ์ฝ๋ ๋ชจ๋ธ์ ๊ทธ๋๋ก ์ฌ์ฉํ๋๊ฒ ์ ์ด์ ์ ๋ง์ ์ ์๋ค. ํ์ง๋ง ํ๊น
ํ์ด์ค์ roformer
์ฝ๋๋ฅผ ๋ณด๋ฉด ์ญ์, linear attention
๋์ full attention
์ RoPE
๋ฅผ ํตํฉํ๋ ๋ฐฉ์์ผ๋ก ๊ตฌํํ๋ค. ๋ฐ๋ผ์ ํ์ ์ญ์ full attention
์ ๊ธฐ์ค์ผ๋ก ๋ชจ๋ธ์ ๊ตฌํํ์์ ๋ฐํ๋ค.
๐ฉโ๐ปย Implementation by Pytorch
๋
ผ๋ฌธ์ ๋ด์ฉ๊ณผ ์คํผ์
๋ก ๊ณต๊ฐ๋ ์ฝ๋๋ฅผ ์ข
ํฉํ์ฌ ํ์ดํ ์น๋ก Roformer
๋ฅผ ๊ตฌํํด๋ดค๋ค. ๋ค๋ง, linear attention
๋์ full attention
์ ์ฌ์ฉํ๊ณ ์ค์ง ์ธ์ฝ๋ ๋ถ๋ถ๋ง ๊ตฌํํ์์ ๋ฐํ๋ค.
ํํธ, ํ์๊ฐ ์ง์ ๊ตฌํํ RoPE๋ฅผ ์ฝ๋๋ ์์ผ๋, GPU ์ฐ์ฐ ์ต์ ํ๊น์ง๋ ์คํจํด ๋์ ํ๊น ํ์ด์ค์ ๊ตฌํ์ฒด๋ฅผ ์ฐธ๊ณ ํ์์ ๋ฐํ๋ค. ์๊ฐ์ด ๋ ๋, ์ง์ ๊ตฌํํ๋ RoPE ์ฝ๋๋ ํจ๊ผ ์ฒจ๋ถํ๊ฒ ๋ค. ๊ทธ๋ฆฌ๊ณ ์ด๋ฒ ํฌ์คํ ์์๋ RoPE๋ฅผ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ๋ํด์๋ง ๋ค๋ฃจ๊ณ , ๋๋จธ์ง ๊ตฌํ์ ๋ํ ์ค๋ช ์ ์๋ตํ๋ ค ํ๋ค. ์ ์ฒด ๋ชจ๋ธ ๊ตฌ์กฐ ๋ํ ์ฝ๋๋ ์ฌ๊ธฐ ๋งํฌ๋ฅผ ํตํด ์ฐธ๊ณ ๋ฐ๋๋ค.
๐ก Rotary Position Embedding
_init_weight()
์ position_enc
๋ฅผ ์ฃผ๋ชฉํด๋ณด์. position_enc
๋ position
๊ณผ dim
์ ์ธ์๋ก ๋ฐ์ position
๊ณผ dim
์ ๋ฐ๋ผ์ position_enc
๋ฅผ ๋ง๋ค์ด๋ด๋๋ฐ, ์ด๊ฒ์ด ๋ฐ๋ก RoPE
์ ํต์ฌ์ด๋ค. ํด๋น ์ฝ๋ ๋ผ์ธ์ด ์ ํํ๊ฒ $m\theta_d$์ ๊ณ์ฐํ๊ฒ ๋๋ค.
class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
""" This module produces sinusoidal positional embeddings of any length
Original Source code from Huggingface's RoFormer model, which is the most optimized way to create positional embedding
Args:
max_seq: max sequence length of model
dim_head: dimension of each attention head's hidden states
Returns:
Tensor -> torch.Size([seq_len, dim_head])
References:
https://arxiv.org/abs/2104.09864 # RoFormer: Enhanced Transformer with Rotary Position Embedding
https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py#L323
"""
def __init__(self, max_seq: int, dim_head: int) -> None:
super().__init__(max_seq, dim_head)
self.weight = self._init_weight(self.weight)
@staticmethod
def _init_weight(out: nn.Parameter) -> nn.Parameter:
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
"""
n_pos, dim = out.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
) # m * theta
out.requires_grad = False # set early to avoid an error in pytorch-1.8+
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
return out
@torch.no_grad()
def forward(self, seq_len: int, past_key_values_length: int = 0) -> Tensor:
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
return super().forward(positions)
class Embedding(nn.Module):
""" Class module for Roformer Embedding, word embedding & rotary positional encoding
This module has option => whether or not to use ALBERT Style Factorized Embedding
Args:
cfg: configuration.py
References:
https://arxiv.org/abs/1706.03762
https://arxiv.org/pdf/1810.04805.pdf
https://arxiv.org/abs/2006.16236
https://arxiv.org/abs/2104.09864 # RoFormer: Enhanced Transformer with Rotary Position Embedding
https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py
https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""
def __init__(self, cfg: CFG) -> None:
super(Embedding, self).__init__()
self.cfg = cfg
self.batch_size = cfg.batch_size
self.max_seq = cfg.max_seq
self.dim_model = cfg.dim_model
self.word_embedding = nn.Embedding(len(cfg.tokenizer), cfg.dim_model)
self.layer_norm1 = nn.LayerNorm(cfg.dim_model, eps=cfg.layer_norm_eps) # for word embedding
self.hidden_dropout = nn.Dropout(p=cfg.hidden_dropout_prob)
self.rotary_pos_encoding = RoFormerSinusoidalPositionalEmbedding(
cfg.max_seq,
cfg.dim_model // cfg.num_attention_heads
)
# ALBERT Style Factorized Embedding
if self.cfg.is_mf_embedding:
self.word_embedding = nn.Embedding(len(cfg.tokenizer), int(cfg.dim_model/6))
self.projector = nn.Linear(int(cfg.dim_model/6), cfg.dim_model) # project to original hidden dim
def forward(self, inputs: Tensor) -> Tuple[nn.Embedding, Tensor]:
if self.cfg.is_mf_embedding:
word_embeddings = self.hidden_dropout(
self.layer_norm1(self.projector(self.word_embedding(inputs)))
)
else:
word_embeddings = self.hidden_dropout(
self.layer_norm1(self.word_embedding(inputs))
)
rotary_pos_enc = self.rotary_pos_encoding(inputs.shape[1])
return word_embeddings, rotary_pos_enc
๐จ Integrated RoPE into Full Attention(scaled dot-product attention)
RoPE
๋ฅผ ์ ์ฉํ๋ Full Attention
์ ๊ตฌํ ์์๋ ๋ค์๊ณผ ๊ฐ๋ค. ๋จผ์ , ๋จ์ด ์๋ฒ ๋ฉ์ ์ฟผ๋ฆฌ, ํค, ๋ฒจ๋ฅ ํ๋ ฌ๋ก ์ ํ ํฌ์ํ๋ค. ์ด ๋ RoPE๋ฅผ ๊ณฑํด์ฃผ๊ธฐ ์ํค apply_rotary_position_embeddings()
์ ์ธ์๋ก ์ฟผ๋ฆฌ, ํค ํ๋ ฌ์ ์ ๋ฌํ๋ค. ์ด ๋ ๋ฐ๋์ ๋ฒจ๋ฅ ํ๋ ฌ์ ๋จ์ด ์๋ฒ ๋ฉ์ผ๋ก๋ถํฐ ์ ํ ํฌ์๋ ์ํ๋ฅผ ์ ์งํด์ผํจ์ ๊ธฐ์ตํ์. apply_rotary_position_embeddings()
๋ RoPE
๊ฐ ๊ณฑํด์ง ์ฟผ๋ฆฌ, ํค ํ๋ ฌ์ ๋ฐํํ๋ค. ์ดํ ๊ณผ์ ์ ํจ์ดํ full attention
๊ณผ ๋์ผํ๋ค.
์ธ์๋ก ๋ค์ด๊ฐ๋ ํ ์๋ค์ ๋ชจ์์ ์ฃผ์์ ์ฐธ๊ณ ๋ฐ๋๋ค.
def apply_rotary_position_embeddings(sinusoidal_pos: Tensor, query_layer: Tensor, key_layer: Tensor, value_layer: Tensor = None):
""" Apply rotary position encoding to query, key layer
Original Source code from Huggingface's RoFormer model, which is the most optimized way to create positional embedding
You can find mathematical proof in official paper's Appendix
Args:
sinusoidal_pos: sinusoidal positional encoding, shape [batch(None), num_dim(None), seq_len, dim_head]
query_layer: query matrix, shape (batch_size, num_head, seq_len, dim_head)
key_layer: key matrix, shape (batch_size, num_head, seq_len, dim_head)
value_layer: value matrix, shape (batch_size, num_head, seq_len, dim_head)
References:
https://arxiv.org/abs/2104.09864 # RoFormer: Enhanced Transformer with Rotary Position Embedding
https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py#L323
"""
sin, cos = sinusoidal_pos.chunk(2, dim=-1) # select two element of index values
sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
query_layer
)
# mathematical expression from Appendix in official repo
query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
if value_layer is not None: # In official, they don't use value_layer
rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
value_layer
)
value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
return query_layer, key_layer, value_layer
return query_layer, key_layer
class MultiHeadAttention(nn.Module):
def __init__(
self,
dim_model: int = 1024,
num_attention_heads: int = 16,
dim_head: int = 64,
kernel: str = 'softmax',
attention_dropout_prob: float = 0.1
) -> None:
super(MultiHeadAttention, self).__init__()
self.dim_model = dim_model
self.num_attention_heads = num_attention_heads
self.dim_head = dim_head
self.fc_q = nn.Linear(self.dim_model, self.dim_model)
self.fc_k = nn.Linear(self.dim_model, self.dim_model)
self.fc_v = nn.Linear(self.dim_model, self.dim_model)
self.fc_concat = nn.Linear(self.dim_model, self.dim_model)
self.apply_rope = apply_rotary_position_embeddings
self.attention = scaled_dot_product_attention if kernel == 'softmax' else linear_attention
self.attention_dropout = nn.Dropout(p=attention_dropout_prob)
self.dot_scale = torch.sqrt(torch.tensor(self.dim_head, dtype=torch.float32))
self.kernel = kernel
self.eps = 1e-6
def forward(self, x: Tensor, rotary_pos_enc: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tensor:
""" x is already passed nn.Layernorm, already multiplied with rotary position encoding """
assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'
# size: bs, seq, nums head, dim head, linear projection
q = self.fc_q(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
k = self.fc_k(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
v = self.fc_v(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
# multiple word embedding, rotary position encoding
rotary_q, rotary_k = self.apply_rope(rotary_pos_enc, q, k)
attention_matrix = None
if self.kernel == 'elu':
attention_matrix = self.attention(
rotary_q,
rotary_k,
v,
self.kernel,
self.eps,
self.attention_dropout,
padding_mask,
attention_mask
)
elif self.kernel == 'softmax': # pure self-attention
attention_matrix = self.attention(
rotary_q,
rotary_k,
v,
self.dot_scale,
self.attention_dropout,
padding_mask,
attention_mask
)
attention_output = self.fc_concat(attention_matrix)
return attention_output
Leave a comment