๐ชขย [DeBERTa] DeBERTa: Decoding-Enhanced BERT with Disentangled-Attention
๐ญย Overview
DeBERTa
๋ 2020๋
Microsoft
๊ฐ ICLR
์์ ๋ฐํํ ์์ฐ์ด ์ฒ๋ฆฌ์ฉ ์ ๊ฒฝ๋ง ๋ชจ๋ธ์ด๋ค. Disentangled Self-Attention
, Enhanced Mask Decoder
๋ผ๋ ๋๊ฐ์ง ์๋ก์ด ํ
ํฌ๋์ BERT
, RoBERTa
์ ์ ์ฉํด ๋น์ SOTA
๋ฅผ ๋ฌ์ฑํ์ผ๋ฉฐ, ํนํ ์์ด์ฒ๋ผ ๋ฌธ์ฅ์์ ์๋ฆฌํ๋ ์์น์ ๋ฐ๋ผ ๋จ์ด์ ์๋ฏธ, ํํ๊ฐ ๊ฒฐ์ ๋๋ ๊ตด์ ์ด ๊ณ์ด์ ๋ํ ์ฑ๋ฅ์ด ์ข์ ๊พธ์คํ ์ฌ๋๋ฐ๊ณ ์๋ ๋ชจ๋ธ์ด๋ค. ๋ํ ์ธ์ฝ๋ฉ ๊ฐ๋ฅํ ์ต๋ ์ํ์ค ๊ธธ์ด๊ฐ 4096
์ผ๋ก ๋งค์ฐ ๊ธด ํธ (DeBERTa-V3-Large
) ์ ์ํด, Kaggle Competition
์์ ์์ฃผ ํ์ฉ๋๋ค. ์ถ์๋์ง 2๋
์ด ๋๋๋ก SuperGLUE
๋์๋ณด๋์์ ๊พธ์คํ ์์๊ถ์ ์ ์งํ๊ณ ์๋ค๋ ์ ๋ DeBERTa
๊ฐ ์ผ๋ง๋ ์ ์ค๊ณ๋ ๋ชจ๋ธ์ธ์ง ์ ์ ์๋ ๋๋ชฉ์ด๋ค.
ํํธ, DeBERTa
์ ์ค๊ณ ์ฒ ํ์ Inductive Bias
๋ค. ๊ฐ๋จํ๊ฒ Inductive Bias
๋, ์ฃผ์ด์ง ๋ฐ์ดํฐ๋ก๋ถํฐ ์ผ๋ฐํ ์ฑ๋ฅ์ ๋์ด๊ธฐ ์ํดย "์
๋ ฅ๋๋ ๋ฐ์ดํฐ๋ ~ ํ ๊ฒ์ด๋ค"
,ย "์ด๋ฐ ํน์ง์ ๊ฐ๊ณ ์์ ๊ฒ์ด๋ค"
์ ๊ฐ์ ๊ฐ์ , ๊ฐ์ค์น, ๊ฐ์ค ๋ฑ์ ๊ธฐ๊ณํ์ต ์๊ณ ๋ฆฌ์ฆ์ ์ ์ฉํ๋ ๊ฒ์ ๋งํ๋ค. ViT
๋
ผ๋ฌธ ๋ฆฌ๋ทฐ์์๋ ๋ฐํ๋ฏ, ํจ์ดํ Self-Attention
์ Inductive Bias
๋ ์ฌ์ค์ ์์ผ๋ฉฐ, ์ ์ฒด Transformer
๊ตฌ์กฐ ๋ ๋ฒจ์์ ๋ด๋ Absolute Position Embedding
์ ์ฌ์ฉํด ํ ํฐ์ ์์น ์ ๋ณด๋ฅผ ๋ชจ๋ธ์ ์ฃผ์
ํด์ฃผ๋ ๊ฒ์ด ๊ทธ๋๋ง ์ฝํ Iniductive Bias
๋ผ๊ณ ๋ณผ ์ ์๋ค. ๋ค๋ฅธ ํฌ์คํ
์์๋ ๋ถ๋ช
Inductive Bias
๊ฐ ์ ๊ธฐ ๋๋ฌธ์ ์์ฐ์ด ์ฒ๋ฆฌ์์ Transformer
๊ฐ ์ฑ๊ณต์ ๊ฑฐ๋ ์ ์๋ค๊ณ ํด๋๊ณ ์ด๊ฒ ์ง๊ธ ์์ ๋ง์ ๋ค์ง๋๋ค๊ณ ์๊ฐํ ์ ์๋ค. ํ์ง๋ง Self-Attention
๊ณผ Absolute Position Embedding
์ ์๋ฏธ๋ฅผ ๋ค์ ํ ๋ฒ ์๊ธฐํด๋ณด๋ฉด, Inductive Bias
์ถ๊ฐ๋ฅผ ์ฃผ์ฅํ๋ ์ ์๋ค์ ์๊ฐ์ด ๊ฝค๋ ํฉ๋ฆฌ์ ์ด์์์ ์ ์ ์๊ฒ ๋๋ค. ๊ตฌ์ฒด์ ์ธ ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ํ์
ํ๊ธฐ ์ ์ ๋จผ์ Inductive Bias
์ถ๊ฐ๊ฐ ์ ํ์ํ๋ฉฐ, ์ด๋ ํ ๊ฐ์ ์ด ํ์ํ์ง ์์๋ณด์.
๐ชขย Inducitve Bias in DeBERTa
Absolute Position + Relative Position
์ ๋ชจ๋ ํ์ฉํด ํ๋ถํ๊ณ ๊น์ ์๋ฒ ๋ฉ ์ถ์ถ๋จ์ด์ ๋ฐ์ ์์
์๋ฒ ๋ฉ๊ณผ๋จ์ด ๋ถํฌ ๊ฐ์ค
์๋ฒ ๋ฉ์ ๋ชจ๋ ์ถ์ถํ๋ ๊ฒ์ ๋ชฉ์ ์ผ๋ก ์ค๊ณ
๋ณธ ๋ ผ๋ฌธ ์ด๋ก์๋ ๋ค์๊ณผ ๊ฐ์ ๋ฌธ์ฅ์ด ์์ ๋์ด ์๋ค.
motivated by the observation that the attention weight of a word pair depends on not only their contents but their relative positions. For example, the dependency between the words โdeepโ and โlearningโ is much stronger when they occur next to each other than when they occur in different sentences.
์์ ๋ ๋ฌธ์ฅ์ด DeBERTa
์ Inducitve Bias
๋ฅผ ๊ฐ์ฅ ์ ์ค๋ช
ํ๊ณ ์๋ค๊ณ ์๊ฐํ๋ค. ์ ์๊ฐ ์ถ๊ฐ๋ฅผ ์ฃผ์ฅํ๋ Inductive Bias
๋, relative position
์ ๋ณด๋ผ๋ ๊ฒ๊ณผ ๊ธฐ์กด ๋ชจ๋ธ๋ง์ผ๋ก๋ relative position
์ด ์ฃผ๋ ๋ฌธ๋งฅ ์ ๋ณด ํฌ์ฐฉ์ด ๋ถ๊ฐ๋ฅํ๋ค๋ ์ฌ์ค์ ์ ์ ์๋ค.
๊ทธ๋ ๋ค๋ฉด relative position
๊ฐ ์ ๊ณตํ๋ ๋ฌธ๋งฅ ์ ๋ณด๊ฐ ๋๋์ฒด ๋ญ๊ธธ๋ ๊ธฐ์กด์ ๋ฐฉ์์ผ๋ก๋ ํฌ์ฐฉ์ด ๋ถ๊ฐ๋ฅํ๋ค๋ ๊ฒ์ผ๊น?? ์์ฐ์ด์์ ํฌ์ฐฉ ๊ฐ๋ฅํ ๋ฌธ๋งฅ๋ค์ ์ข
๋ฅ์ ๊ธฐ์กด์ ๋ชจ๋ธ๋ง ๋ฐฉ์์ ๋ํ ์ ๋ฆฌ๋ถํฐ ํด๋ณด์. ์ฌ๊ธฐ์ ๋งํ๋ ๊ธฐ์กด ๋ฐฉ์์ด๋, ํจ์ดํ Self-Attention
๊ณผ Absolute Position Embedding
์ ์ฌ์ฉํ๋ Transformer-Encoder-Base
๋ชจ๋ธ(BERT
, RoBERTa
)์ ๋ปํ๋ค. ์ด๋ฒ ํฌ์คํ
์์๋ BERT
๋ฅผ ๊ธฐ์ค์ผ๋ก ์ค๋ช
ํ๊ฒ ๋ค.
๐ Types of Embedding
๋จผ์ ํ์กดํ๋ ๋ชจ๋ ์๋ฒ ๋ฉ(๋ฒกํฐ์ ๋ฌธ๋งฅ์ ์ฃผ์
ํ๋
)๊ธฐ๋ฒ๋ค์ ์ ๋ฆฌํด๋ณด์. ๋ค์๊ณผ ๊ฐ์ด 3๊ฐ์ง ์นดํ
๊ณ ๋ฆฌ๋ก ๋ถ๋ฅ๊ฐ ๊ฐ๋ฅํ๋ค.
- 1) ๋จ์ด์ ๋น๋์: ์ํ์ค์์ ์ฌ์ฉ๋ ํ ํฐ๋ค์ ๋น๋์๋ฅผ ์ธก์ (
Bag of words
) - 2) ๋จ์ด์ ๋ฐ์ ์์:
corpus
๋ด๋ถ์ ํน์ sequence
๋ฑ์ฅ ๋น๋๋ฅผ ์นด์ดํธ(N-Gram
), ์ฃผ์ด์ง ์ํ์ค๋ฅผ ๊ฐ์ง๊ณ ๋ค์ ์์ ์ ๋ฑ์ฅํ ํ ํฐ์ ๋ง์ถ๋ ๋ฐฉ์(LM
) - 3) ๋จ์ด ๋ถํฌ ๊ฐ์ค : ๋จ์ด์ ์๋ฏธ๋ ์ฃผ๋ณ ๋ฌธ๋งฅ์ ์ํด ๊ฒฐ์ ๋๋ค๋ ๊ฐ์ , ์ด๋ค ๋จ์ด ์์ด ์์ฃผ ๊ฐ์ด ๋ฑ์ฅํ๋์ง ์นด์ดํธํด
PMI
๋ฅผ ์ธก์ ํ๋ ๋ฐฉ์(Word2Vec
)
๊ธฐ์กด์ ๋ชจ๋ธ๋ง ๋ฐฉ์์ ์ด๋์ ํฌํจ๋ ๊น?? BERT
๋ ๋๋ถ๋ฅ ์ ์ ๊ฒฝ๋ง์ ํฌํจ๋๊ณ , Language Modeling
์ ํตํด ์ํ์ค๋ฅผ ํ์ตํ๋ค๋ ์ ๊ทธ๋ฆฌ๊ณ Self-Attention
๊ณผ Absolute Position Embedding
์ ์ฌ์ฉํ๋ค๋ ์ ์์ 2๋ฒ, ๋จ์ด์ ๋ฐ์ ์์
์ ํฌํจ๋๋ค๊ณ ๋ณผ ์ ์๋ค. Absolute Position Embedding
๊ณผ Self-Attention
์ ์ฌ์ฉ์ด ํจ์ดํ BERT
๊ฐ ๋ถ๋ฅ์ 2๋ฒ์ด๋ผ๋ ์ฌ์ค์ ๋ท๋ฐ์นจํ๋ ์ฆ๊ฑฐ๋ผ๋ ์ ์์ ์์ํ ์ ์๋ค. ํ์ง๋ง ์ ์๊ฐํด๋ณด์.
Absolute Position Embedding
์ ์ฃผ์ด์ง ์ํ์ค์ ๊ธธ์ด๋ฅผ ์ธก์ ํ ๋ค, ๋์ด๋ ์์ ๊ทธ๋๋ก forward
ํ๊ฒ 0
๋ถํฐ ๊ธธ์ด-1
์ ๋ฒํธ๋ฅผ ๊ฐ๋ณ ํ ํฐ์ ํ ๋นํ๋ค. ๋ค์ ๋งํด, ๋จ์ด๊ฐ ์ํ์ค์์ ๋ฐ์ํ ์์๋ฅผ ์ํ์ ์ผ๋ก ํํํด ๋ชจ๋ธ์ ์ฃผ์
ํ๋ค๋ ์๋ฏธ๊ฐ ๋๋ค. Self-Attention
์ Absolute Position Embedding
์ ๋ณด๊ฐ ์ฃผ์
๋ ์ํ์ค ์ ์ฒด๋ฅผ ํ ๋ฒ์ ๋ณ๋ ฌ ์ฒ๋ฆฌํ๋ค. ๋ฐ๋ผ์ ์ถฉ๋ถํ BERT
๊ฐ์ Self-Attention
, Absolute Position Embedding
๊ธฐ๋ฐ ๋ชจ๋ธ์ 2๋ฒ์ ๋ถ๋ฅํ ์ ์๊ฒ ๋ค.
ํํธ, ํน์๋ "BERT๋ MLM ์ ์ฌ์ฉํ๋๋ฐ Language Modeling์ ํ๋ค๊ณ ํ๋๊ฒ ๋ง๋์"
๋ผ๊ณ ๋งํ ์ ์๋ค. ํ์ง๋ง MLM
์ญ์ ๋๋ถ๋ฅ ์ Language Modeling
๊ธฐ๋ฒ์ ์ํ๋ค. ๋ค๋ง, Bi-Directional
ํ๊ฒ ๋ฌธ๋งฅ์ ํ์
ํ๊ณ LM
์ ํ๋๊น ์ ๋ง ์๋ฐํ ๋ฐ์ง๋ฉด 3๋ฒ์ ์์ฑ๋ ์กฐ๊ธ์ ์๋ค๊ณ ๋ณด๋๊ฒ ๋ฌด๋ฆฌ๋ ์๋๋ผ ์๊ฐํ๋ค. MLM
์ฌ์ฉ์ผ๋ก ๋ ๋ง์ ์ ๋ณด๋ฅผ ํฌ์ฐฉํด ์๋ฒ ๋ฉ์ ๋ง๋ค๊ธฐ ๋๋ฌธ์ ์ด๊ธฐ BERT
๊ฐ GPT
๋ณด๋ค NLU
์์ ์๋์ ์ผ๋ก ๊ฐ์ ์ ๊ฐ์ก๋ ๊ฒ ์๋๊น ์ถ๋ค.
๐ข Relative Position Embedding
์ด์ Relative Position Embedding
์ด ๋ฌด์์ด๊ณ , ๋๋์ฒด ์ด๋ค ๋ฌธ๋งฅ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๋ค๋ ๊ฒ์ธ์ง ์์๋ณด์. Relative Position Embedding
์ด๋, ์ํ์ค ๋ด๋ถ ํ ํฐ ์ฌ์ด์ ์์น ๊ด๊ณ ํํ์ ํตํด ํ ํฐ ์ฌ์ด์ relation
์ pairwise
ํ๊ฒ ํ์ตํ๋ ์์น ์๋ฒ ๋ฉ ๊ธฐ๋ฒ์ ๋งํ๋ค. ์ผ๋ฐ์ ์ผ๋ก ์๋ ์์น ๊ด๊ณ๋ ์๋ก ๋ค๋ฅธ ๋ ํ ํฐ์ ์ํ์ค ์ธ๋ฑ์ค ๊ฐ์ ์ฐจ๋ฅผ ์ด์ฉํด ๋ํ๋ธ๋ค. ํฌ์ฐฉํ๋ ๋ฌธ๋งฅ ์ ๋ณด๋ ์์์ ํจ๊นจ ์ค๋ช
ํ๊ฒ ๋ค. ๋ฅ๋ฌ๋์ด๋ผ๋ ๋จ์ด๋ ์์ด๋ก Deep Learning
์ด๋ค. ๋ ๋จ์ด๋ฅผ ํฉ์ณ๋๊ณ ๋ณด๋ฉด ์ ๊ฒฝ๋ง์ ์ฌ์ฉํ๋ ๋จธ์ ๋ฌ๋ ๊ธฐ๋ฒ์ ํ ์ข
๋ฅ
๋ผ๋ ์๋ฏธ๋ฅผ ๊ฐ๊ฒ ์ง๋ง, ๋ฐ๋ก ๋ฐ๋ก ๋ณด๋ฉด ๊น์
, ๋ฐฐ์
์ด๋ผ๋ ๊ฐ๋ณ์ ์ธ ์๋ฏธ๋ก ๋๋๋ค.
1) The Deep Learning is the Best Technique in Computer Science
2) Iโm learning how to swim in the deep ocean
Deep
๊ณผ Learning
์ ์๋์ ์ธ ๊ฑฐ๋ฆฌ์ ์ฃผ๋ชฉํ๋ฉด์ ๋ ๋ฌธ์ฅ์ ํด์ํด๋ณด์. ์ฒซ ๋ฒ์งธ ๋ฌธ์ฅ์์ ๋ ๋จ์ด๋ ์ด์ํ๊ฒ ์์นํด ์ ๊ฒฝ๋ง์ ์ฌ์ฉํ๋ ๋จธ์ ๋ฌ๋ ๊ธฐ๋ฒ์ ํ ์ข
๋ฅ
๋ผ๋ ์๋ฏธ๋ฅผ ๋ง๋ค์ด๋ด๊ณ ์๋ค. ํํธ ๋ ๋ฒ์งธ ๋ฌธ์ฅ์์ ๋ ๋จ์ด๋ ๋์ด์ฐ๊ธฐ ๊ธฐ์ค 5๊ฐ์ ํ ํฐ๋งํผ ๋จ์ด์ ธ ์์นํด ๊ฐ๊ฐ ๋ฐฐ์
, ๊น์
์ด๋ผ๋ ์๋ฏธ๋ฅผ ๋ง๋ค์ด ๋ด๊ณ ์๋ค. ์ด์ฒ๋ผ ๊ฐ๋ณ ํ ํฐ ์ฌ์ด์ ์์น ๊ด๊ณ์ ๋ฐ๋ผ์ ํ์๋๋ ๋ฌธ๋งฅ์ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๋ ค๋ ์๋๋ก ์ค๊ณ๋ ๊ธฐ๋ฒ์ด ๋ฐ๋ก Relative Position Embedding
์ด๋ค.
pairwise
ํ๊ฒ relation
์ ํฌ์ฐฉํ๋ค๋ ์ ์ผ๋ก ๋ณด์ skip-gram
์ negative sampling
๊ณผ ๋งค์ฐ ์ ์ฌํ ๋๋์ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ ๊ฒ์ด๋ผ๊ณ ์์๋๋ฉฐ ์นดํ
๊ณ ๋ฆฌ ๋ถ๋ฅ์ 3๋ฒ, ๋จ์ด ๋ถํฌ ๊ฐ์ค
์ ํฌํจ์ํฌ ์ ์์ ๊ฒ ๊ฐ๋ค. (ํ์์ ๊ฐ์ธ์ ์ธ ์๊ฒฌ์ด๋ ์ด ๋ถ๋ถ์ ๋ํ ๋ค๋ฅธ ์๊ฒฌ์ด ์๋ค๋ฉด ๊ผญ ๋๊ธ์ ์ ์ด์ฃผ์๋ฉด ๊ฐ์ฌํ๊ฒ ์ต๋๋น๐ฅฐ).
์ ์์๋ง์ผ๋ก๋ ์๋ ์์น ์๋ฒ ๋ฉ ๊ฐ๋ ์ด ์๋ฟ์ง ์์ ์ ์๋ค. ๊ทธ๋ ๋ค๋ฉด ์์ ๋งํฌ๋ฅผ ๋จผ์ ์ฝ๊ณ ์ค์. (๋งํฌ1)
Relative Position Embedding
์ ์ค์ ์ด๋ป๊ฒ ์ฝ๋๋ก ๊ตฌํํ๋์ง, ๋ณธ ๋
ผ๋ฌธ์์๋ ์์น ๊ด๊ณ๋ฅผ ์ด๋ป๊ฒ ์ ์ํ๋์ง Absolute Position Embedding
์ ๋น๊ต๋ฅผ ํตํด ์์๋ณด์. ๋ค์๊ณผ ๊ฐ์ ๋ ๊ฐ์ ๋ฌธ์ฅ์ด ์์ ๋, ๊ฐ๋ณ ์์น ์๋ฒ ๋ฉ ๋ฐฉ์์ด ๋ฌธ์ฅ์ ์์น ์ ๋ณด๋ฅผ ์ธ์ฝ๋ฉํ๋ ๊ณผ์ ์ ํ์ด์ฌ ์ฝ๋๋ก ์์ฑํด๋ดค๋ค. ํจ๊ป ์ดํด๋ณด์.
A) I love studying deep learning so much
B) I love deep cheeze burguer so much
# Absolute Position Embedding
>>> max_length = 7
>>> position_embedding = nn.Embedding(7, 512) # [max_seq, dim_model]
>>> pos_x = position_embedding(torch.arange(max_length))
>>> pos_x, pos_x.shape
(tensor([[ 0.4027, 0.9331, 1.0556, ..., -1.7370, 0.7799, 1.9851], # A,B์ 0๋ฒ ํ ํฐ: I
[-0.2206, 2.1024, -0.6055, ..., -1.1342, 1.3956, 0.9017], # A,B์ 1๋ฒ ํ ํฐ: love
[-0.9560, -0.0426, -1.8587, ..., -0.9406, -0.1467, 0.1762], # A,B์ 2๋ฒ ํ ํฐ: studying, deep
..., # A,B์ 3๋ฒ ํ ํฐ: deep, cheeze
[ 0.5999, 0.5235, -0.3445, ..., 1.9020, -1.5003, 0.7535], # A,B์ 4๋ฒ ํ ํฐ: learning, burger
[ 0.0688, 0.5867, -0.0340, ..., 0.8547, -0.9196, 1.1193], # A,B์ 5๋ฒ ํ ํฐ: so
[-0.0751, -0.4133, 0.0256, ..., 0.0788, 1.4665, 0.8196]], # A,B์ 6๋ฒ ํ ํฐ: much
grad_fn=<EmbeddingBackward0>),
torch.Size([7, 512]))
Absolute Position Embedding
์ ์ฃผ์ ๋ฌธ๋งฅ์ ์๊ด์์ด ๊ฐ์ ์์น์ ํ ํฐ์ด๋ผ๋ฉด ๊ฐ์ ํฌ์ง์
๊ฐ์ผ๋ก ์ธ์ฝ๋ฉํ๊ธฐ ๋๋ฌธ์ 512
๊ฐ์ ์์๋ก ๊ตฌ์ฑ๋ ํ๋ฒกํฐ๋ค์ ์ธ๋ฑ์ค๋ฅผ ์ค์ ๋ฌธ์ฅ์์ ํ ํฐ์ ๋ฑ์ฅ ์์์ ๋งตํํด์ฃผ๋ ๋ฐฉ์์ผ๋ก ์์น ์ ๋ณด๋ฅผ ํํํ๋ค. ์๋ฅผ ๋ค๋ฉด, ๋ฌธ์ฅ์์ ๊ฐ์ฅ ๋จผ์ ๋ฑ์ฅํ๋ 0
๋ฒ ํ ํฐ์ 0
๋ฒ์งธ ํ๋ฒกํฐ
๋ฅผ ๋ฐฐ์ ํ๊ณ ๊ฐ์ฅ ๋ง์ง๋ง์ ๋ฑ์ฅํ๋ N-1
๋ฒ์งธ ํ ํฐ์ N-1
๋ฒ์งธ ํ๋ฒกํฐ
๋ฅผ ์์น ์ ๋ณด๊ฐ์ผ๋ก ๊ฐ๋ ๋ฐฉ์์ด๋ค. ์ ์ฒด ์ํ์ค ๊ด์ ์์ ๊ฐ๋ณ ํ ํฐ์ ๋ฒํธ๋ฅผ ๋ถ์ฌํ๊ธฐ ๋๋ฌธ์ syntactical
ํ ์ ๋ณด๋ฅผ ๋ชจ๋ธ๋ง ํด์ฃผ๊ธฐ ์ ํฉํ๋ค๋ ์ฅ์ ์ด ์๋ค.
Absolute Position Embedding
์ ์ผ๋ฐ์ ์ผ๋ก Input Embedding
๊ณผ ํ๋ ฌํฉ ์ฐ์ฐ์ ํตํด Word Embedding
์ผ๋ก ๋ง๋ค์ด ์ธ์ฝ๋์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ค.
์๋ ์ฝ๋๋ ์ ์๊ฐ ๋
ผ๋ฌธ์์ ์ ์ํ DeBERTa
์ Relative Position Embedding
๊ตฌํ์ ํ์ดํ ์น๋ก ์ฎ๊ธด ๊ฒ์ด๋ค. Relative Position Embedding
์ ์ ๋ ์์น์ ๋นํด ๊ฝค๋ ๋ณต์กํ ๊ณผ์ ์ ๊ฑฐ์ณ์ผ ํ๊ธฐ ๋๋ฌธ์ ์ฝ๋ ์ญ์ ๊ธด ํธ์ด๋ค. ํ๋ ํ๋ ์ฒ์ฒํ ์ดํด๋ณด์.
# Relative Position Embedding
>>> position_embedding = nn.Embedding(2*max_length, dim_model)
>>> x, p_x = torch.randn(max_length, dim_model), position_embedding(torch.arange(2*max_length))
>>> fc_q, fc_kr = nn.Linear(dim_model, dim_head), nn.Linear(dim_model, dim_head)
>>> q, kr = fc_q(x), fc_kr(p_x) # [batch, max_length, dim_head], [batch, 2*max_length, dim_head]
>>> tmp_c2p = torch.matmul(q, kr.transpose(-1, -2))
>>> tmp_c2p, tmp_c2p.shape
(tensor([[ 2.8118, 0.8449, -0.6240, -0.6516, 3.4009, 1.8296, 0.8304, 1.0164,
3.5664, -1.4208, -2.0821, 1.5752, -0.9469, -7.1767],
[-2.1907, -3.2801, -2.0628, 0.4443, 2.2272, -5.6653, -4.6036, 1.4134,
-1.1742, -0.3361, -0.4586, -1.1827, 1.0878, -2.5657],
[-4.8952, -1.5330, 0.0251, 3.5001, 4.1619, 1.7408, -0.5100, -3.4616,
-1.6101, -1.8741, 1.1404, 4.9860, -2.5350, 1.0999],
[-3.3437, 4.2276, 0.4509, -1.8911, -1.1069, 0.9540, 1.2045, 2.2194,
-2.6509, -1.4076, 5.1599, 1.6591, 3.8764, 2.5126],
[ 0.8164, -1.9171, 0.8217, 1.3953, 1.6260, 3.8104, -1.0303, -2.1631,
3.9008, 0.5856, -1.6212, 1.7220, 2.7997, -1.8802],
[ 3.4473, 0.9721, 3.9137, -3.2055, 0.6963, 1.2761, -0.2266, -3.7274,
-1.4928, -1.9257, -5.4422, -1.8544, 1.8749, -3.4923],
[ 2.6639, -1.4392, -3.8818, -1.4120, 1.7542, -0.8774, -3.0795, -1.2156,
-1.0852, 3.7825, -3.5581, -3.6989, -2.6705, -1.2262]],
grad_fn=<MmBackward0>),
torch.Size([7, 14]))
>>> max_seq, max_pos = 7, max_seq * 2
>>> q_index, k_index = torch.arange(max_seq), torch.arange(max_seq)
>>> q_index, k_index
(tensor([0, 1, 2, 3, 4, 5, 6]), tensor([0, 1, 2, 3, 4, 5, 6]))
>>> tmp_pos = q_index.view(-1, 1) - k_index.view(1, -1)
>>> rel_pos_matrix = tmp_pos + max_relative_position
>>> rel_pos_matrix
tensor([[ 7, 6, 5, 4, 3, 2, 1],
[ 8, 7, 6, 5, 4, 3, 2],
[ 9, 8, 7, 6, 5, 4, 3],
[10, 9, 8, 7, 6, 5, 4],
[11, 10, 9, 8, 7, 6, 5],
[12, 11, 10, 9, 8, 7, 6],
[13, 12, 11, 10, 9, 8, 7]])
>>> rel_pos_matrix = torch.clamp(rel_pos_matrix, 0, max_pos - 1).repeat(10, 1, 1)
>>> tmp_c2p = tmp_c2p.repeat(10, 1, 1)
>>> rel_pos_matrix, rel_pos_matrix.shape, tmp_c2p.shape
(tensor([[[ 7, 6, 5, 4, 3, 2, 1],
[ 8, 7, 6, 5, 4, 3, 2],
[ 9, 8, 7, 6, 5, 4, 3],
[10, 9, 8, 7, 6, 5, 4],
[11, 10, 9, 8, 7, 6, 5],
[12, 11, 10, 9, 8, 7, 6],
[13, 12, 11, 10, 9, 8, 7]],
torch.Size([10, 7, 14]),
torch.Size([10, 7, 14]))
>>> outputs = torch.gather(tmp_c2p, dim=-1, index=rel_pos_matrix)
>>> outputs, outputs.shape
(tensor([[[ 1.0164, 0.8304, 1.8296, 3.4009, -0.6516, -0.6240, 0.8449],
[-1.1742, 1.4134, -4.6036, -5.6653, 2.2272, 0.4443, -2.0628],
[-1.8741, -1.6101, -3.4616, -0.5100, 1.7408, 4.1619, 3.5001],
[ 5.1599, -1.4076, -2.6509, 2.2194, 1.2045, 0.9540, -1.1069],
[ 1.7220, -1.6212, 0.5856, 3.9008, -2.1631, -1.0303, 3.8104],
[ 1.8749, -1.8544, -5.4422, -1.9257, -1.4928, -3.7274, -0.2266],
[-1.2262, -2.6705, -3.6989, -3.5581, 3.7825, -1.0852, -1.2156]],
.....
[[ 1.0164, 0.8304, 1.8296, 3.4009, -0.6516, -0.6240, 0.8449],
[-1.1742, 1.4134, -4.6036, -5.6653, 2.2272, 0.4443, -2.0628],
[-1.8741, -1.6101, -3.4616, -0.5100, 1.7408, 4.1619, 3.5001],
[ 5.1599, -1.4076, -2.6509, 2.2194, 1.2045, 0.9540, -1.1069],
[ 1.7220, -1.6212, 0.5856, 3.9008, -2.1631, -1.0303, 3.8104],
[ 1.8749, -1.8544, -5.4422, -1.9257, -1.4928, -3.7274, -0.2266],
[-1.2262, -2.6705, -3.6989, -3.5581, 3.7825, -1.0852, -1.2156]]],
grad_fn=<GatherBackward0>),
torch.Size([10, 7, 7]))
์ผ๋จ ์ ๋ ์์น์ ๋์ผํ๊ฒ nn.Embedding
์ ์ฌ์ฉํด ์๋ฒ ๋ฉ ๋ฃฉ์
ํ
์ด๋ธ(๋ ์ด์ด)๋ฅผ ์ ์ํ์ง๋ง, ์
๋ ฅ ์ฐจ์์ด ๋ค๋ฅด๋ค. ์ ๋ ์์น ์๋ฒ ๋ฉ์ forward
ํ๊ฒ ์์น๊ฐ์ ๋งตํํด์ผ ํ๋ ๋ฐ๋ฉด์ ์๋ ์์น ์๋ฒ ๋ฉ ๋ฐฉ์์ Bi-Directional
ํ ๋งตํ์ ํด์ผ ํด์, ๊ธฐ์กด max_length
๊ฐ์ ๋ ๋ฐฐ๋ฅผ ์
๋ ฅ ์ฐจ์(max_pos
)์ผ๋ก ์ฌ์ฉํ๋ค. ์๋ฅผ ๋ค์ด 0
๋ฒ ํ ํฐ๊ณผ ๋๋จธ์ง ํ ํฐ ์ฌ์ด์ ์์น ๊ด๊ณ๋ฅผ ํํํด์ผ ํ๋ ์ํฉ์ด๋ค. ๊ทธ๋ ๋ค๋ฉด ์ฐ๋ฆฌ๋ 0
๋ฒ ํ ํฐ๊ณผ ๋๋จธ์ง ํ ํฐ๊ณผ์ ์์น ๊ด๊ณ๋ฅผ [0, -1, -2, -3, -4, -5, -6]
์ผ๋ก ์ธ์ฝ๋ฉํ ์ ์๋ค.
๋ฐ๋๋ก ๋ง์ง๋ง 6
๋ฒ ํ ํฐ๊ณผ ๋๋จธ์ง ํ ํฐ ์ฌ์ด์ ์์น ๊ด๊ณ๋ฅผ ํํํ๋ ๊ฒฝ์ฐ๋ผ๋ฉด ์ด๋ป๊ฒ ๋ ๊น?? [6, 5, 4, 3, 2, 1, 0]
์ผ๋ก ์ธ์ฝ๋ฉ ๋ ๊ฒ์ด๋ค. ๋ค์ ๋งํด, ์์น ์๋ฒ ๋ฉ ์์ ๊ฐ์ [-max_seq:max_seq]
์ฌ์ด์์ ์ ์๋๋ค๋ ๊ฒ์ด๋ค. ๊ทธ๋ฌ๋ ์์๊ฐ์ ๋ฒ์๋ฅผ ๊ทธ๋๋ก ์ฌ์ฉํ ์๋ ์๋ค. ์ด์ ๋ ํ์ด์ฌ์ ๋ฆฌ์คํธ, ํ
์ ๊ฐ์ ๋ฐฐ์ดํ ์๋ฃ๊ตฌ์กฐ๋ ์์ด ์๋ ์ ์๋ฅผ ์ธ๋ฑ์ค๋ก ํ์ฉํด์ผ forward
ํ๊ฒ ์์์ ์ ๊ทผํ ์ ์๊ธฐ ๋๋ฌธ์ด๋ค. ์ผ๋ฐ์ ์ผ๋ก ๋ฐฐ์ด ํํ์ ์๋ฃํ์ ๋ชจ๋ ์ธ๋ฑ์ค 0
๋ถํฐ N-1
๊น์ง ์์ฐจ์ ์ผ๋ก ๋งตํ๋๋ค. ๊ทธ๋์ ์๋ํ๋๋ก ํ ํฐ์ ์ ๊ทผํ๋ ค๋ฉด ์ญ์ ํ ํฐ์ ์ธ๋ฑ์ค๋ฅผ forward
ํ ํํ๋ก ๋ง๋ค์ด์ค์ผ ํ๋ค.
๋ฐ๋ผ์ ๊ธฐ์กด [-max_seq:max_seq]
์ max_seq
๋ฅผ ๋ํด์ค [0:2*max_seq]
(2 * max_seq
)์ ์์ ๊ฐ์ ๋ฒ์๋ก ์ฌ์ฉํ๊ฒ ๋๋ค. ์ฌ๊ธฐ๊น์ง๊ฐ ํต์์ ์ผ๋ก ๋งํ๋ Relative Position Embedding
์ ํด๋นํ๋ค. ์ ์ฝ๋์์ผ๋ก๋ rel_pos_matrix
๋ฅผ ๋ง๋ ๋ถ๋ถ์ ํด๋นํ๋ค.
์ด์ ๋ถํฐ ์ ์๊ฐ ์ฃผ์ฅํ๋ ์์น ๊ด๊ณ ํํ ๋ฐฉ์์ ๋ํด ์์๋ณด์. ์ผ๋ฐ์ ์ธ Relative Position Embedding
๊ณผ ๊ฑฐ์ ์ ์ฌํ์ง๋ง, rel_pos_matrix
๋ด๋ถ ์์ ๊ฐ์ด ์์๊ฐ ๋๊ฑฐ๋ max_pos
์ ์ด๊ณผํ๋ ๊ฒฝ์ฐ๋ฅผ ์ฒ๋ฆฌ ํด์ฃผ๊ธฐ ์ํด ํ์ฒ๋ฆฌ ๊ณผ์ ์ ๋์
ํด ์ฌ์ฉํ๋ค. ์์ธ ์ํฉ์ max_seq > 1/2 * max_pos(==k)
์ผ ๋ ๋ฐ์ํ๋ค. official repo
์ ์ฝ๋๋ฅผ ๋ณด๋ฉด max_seq
์ k
๋ฅผ ์ผ์น์์ผ ๋ชจ๋ธ๋ง ํ๊ธฐ ๋๋ฌธ์ ํ์ธํ๋ ํ๋ ์ํฉ์ด๋ผ๋ฉด ์ด๊ฒ์ ๋ชฐ๋ผ๋ ์๊ด์๊ฒ ์ง๋ง, ํ๋ ํ๋ ๋ชจ๋ธ์ ์ง์ ๋ง๋๋ ์
์ฅ์ด๋ผ๋ฉด ์์ธ ์ํฉ์ ๋ฐ๋์ ๊ธฐ์ตํ์.
ํํธ, ์ด๋ฌํ ์ธ์ฝ๋ฉ ๋ฐฉ์์ word2vec
์ window size
๋์
๊ณผ ๋น์ทํ ์๋ฆฌ(์๋ฏธ๋ ์ฃผ๋ณ ๋ฌธ๋งฅ์ ์ํด ๊ฒฐ์
)๋ผ๊ณ ์๊ฐํ๋ฉด ๋๋๋ฐ, ์๋์ฐ ์ฌ์ด์ฆ ๋ฒ์์์ ๋ฒ์ด๋ ํ ํฐ๋ค์ ์ฃผ๋ณ ๋ฌธ๋งฅ์ผ๋ก ์ธ์ ํ์ง ์๊ฒ ๋ค๋(negative sample
) ์๋๋ฅผ ๊ฐ๊ณ ์๋ค. ์ค์ ๊ตฌํ์ ํ
์ ๋ด๋ถ ์์๊ฐ์ ๋ฒ์๋ฅผ ์ฌ์ฉ์ ์ง์ ๋ฒ์๋ก ์ ํํ ์ ์๋ torch.clamp
๋ฅผ ์ฌ์ฉํ๋ฉด 1
์ค๋ก ๊น๋ํ๊ฒ ๋ง๋ค ์ ์์ผ๋ ์ฐธ๊ณ ํ์.
torch.clamp
๊น์ง ์ ์ฉํ๊ณ ๋ ์ต์ข
๊ฒฐ๊ณผ๋ฅผ ์ดํด๋ณด์. ํ๋ฐฑํฐ, ์ด๋ฒกํฐ ๋ชจ๋ [0:2*max_seq]
์ฌ์ด์์ ์ ์๋๊ณ ์์ผ๋ฉฐ, ๊ฐ๋ณ ๋ฐฉํฅ ๋ฒกํฐ ์์์ ์ต๋๊ฐ๊ณผ ์ต์๊ฐ์ ์ฐจ์ด๊ฐ ํญ์ k
๋ก ์ ์ง ๋๋ค. ์๋๋๋ก ์ ํํ ์๋์ฐ ์ฌ์ด์ฆ๋งํผ์ ์ฃผ๋ณ ๋งฅ๋ฝ์ ๋ฐ์ํด ์๋ฒ ๋ฉ์ ํ์ฑํ๊ณ ์์์ ์ ์ ์๋ค.
์ ๋ฆฌํ๋ฉด, Relative Position Embedding
๋ ์ ๋ ์์น ๋ฐฉ์์ฒ๋ผ ์๋ฒ ๋ฉ ๋ฃฉ์
ํ
์ด๋ธ์ ๋ง๋ค๋, ์ฌ์ฉ์๊ฐ ์ง์ ํ ์๋์ฐ ์ฌ์ด์ฆ์ ํด๋นํ๋ ํ ํฐ์ ์๋ฒ ๋ฉ ๊ฐ๋ง ์ถ์ถํด ์๋ก์ด ํ๋ฒกํฐ๋ฅผ ์ฌ๋ฌ ๊ฐ ๋ง๋ค์ด ๋ด๋ ๊ธฐ๋ฒ์ด๋ผ๊ณ ํ ์ ์๋ค. ์ด ๋ ํ๋ฒกํฐ๋ ๋์ ํ ํฐ๊ณผ ๊ทธ ๋๋จธ์ง ํ ํฐ ์ฌ์ด์ ์์น ๋ณํ์ ๋ฐ๋ผ ๋ฐ์ํ๋ ํ์์ ์ธ ๋งฅ๋ฝ ์ ๋ณด๋ฅผ ๋ด๊ณ ์๋ค.
๐ค Word Context vs Relative Position vs Absolute Position
์ง๊ธ๊น์ง Relative Position Embedding
์ด ๋ฌด์์ด๊ณ , ๋๋์ฒด ์ด๋ค ๋ฌธ๋งฅ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ๋ค๋ ๊ฒ์ธ์ง ์์๋ดค๋ค. ํ์์ ์ค๋ช
์ด ๋งค๋๋ฝ์ง ๋ชปํ๊ธฐ๋ ํ๊ณ ์์๋ฅผ ํ
์คํธ๋ก ๋ค๊ณ ์์ด์ ์ง๊ด์ ์ผ๋ก word context
๋ ๋ฌด์์ธ์ง, Position
์ ๋ณด์๋ ๋ญ๊ฐ ๋ค๋ฅธ์ง, ๋ ๊ฐ์ง Position
์ ๋ณด๋ ๋ญ๊ฐ ์ด๋ป๊ฒ ๋ค๋ฅธ์ง ์๋ฟ์ง ์๋ ๋ถ๋ค์ด ๋ง์ผ์ค ๊ฒ ๊ฐ๋ค. ๊ทธ๋์ ์ต๋ํ ์ง๊ด์ ์ธ ์์๋ฅผ ํตํด ์ธ๊ฐ์ง ์ ๋ณด์ ์ฐจ์ด์ ์ ์ค๋ช
ํด๋ณด๋ ค ํ๋ค. (ํ์ ๋ณธ์ธ์ด ํ๊ฐ๋ ค์ ์ฐ๋ ๊ฑด ๋น๋ฐ์ด๋ค)
์ฌ๋ 5๋ช
์ด ๊ณตํญ ์ฒดํฌ์ธ์ ์ํด ์ ์๋ค. ๋ชจ๋ ์ผ์ชฝ์ ๋ณด๊ณ ์๋ ๊ฒ์ ๋ณด์ ์ผ์ชฝ์ ํค๊ฐ ์ ์ผ ์์ ์ฌ์๊ฐ ๊ฐ์ฅ ์์ค์ด๋ผ๊ณ ๋ณผ ์ ์๊ฒ ๋ค. ์ฐ๋ฆฌ๋ ์ค ์์๋ ์์๋๋ก 5๋ช
์ ์ฌ๋์๊ฒ ๋ฒํธ๋ฅผ ๋ถ์ฌํ ๊ฒ์ด๋ค. ํธ์์ 0๋ฒ๋ถํฐ ์์ํด 4๋ฒ๊น์ง ๋ฒํธ๋ฅผ ์ฃผ๊ฒ ๋ค. 1๋ฒ์ ํด๋นํ๋ ์ฌ๋์ ๋๊ตฌ์ธ๊ฐ?? ๋ฐ๋ก ์ค์ 2๋ฒ์งธ์ ์์๋ ์ฌ์๋ค. ๊ทธ๋ผ 2๋ฒ์ ํด๋นํ๋ ์ฌ๋์ ๋๊ตฌ์ธ๊ฐ?? ์ฌ์ง ์ ์ค์ ๊ฐ์ฅ ์ค๊ฐ์ ์๋ ๋จ์๊ฐ 2๋ฒ์ด๋ค. ์ด๋ ๊ฒ ๊ทธ๋ฃน ๋จ์(์ ์ฒด ์ค)์์ ๊ฐ๊ฐ์ธ์ ์ผ๋ จ์ ๋ฒํธ๋ฅผ ๋ถ์ฌํด ์์น๋ฅผ ํํํ๋ ๋ฐฉ๋ฒ์ด ๋ฐ๋ก Absolute Position Embedding
์ด๋ค.
ํํธ, ๋ค์ 2๋ฒ ์ฌ๋์๊ฒ ์ฃผ๋ชฉํด๋ณด์. ์ฐ๋ฆฌ๋ 2๋ฒ ๋จ์๋ฅผ ์ ์ฒด ์ค์์ ๊ฐ์ด๋ฐ ์์นํ ์ฌ๋์ด ์๋๋ผ, ๊ฒ์ ์ ์๋ณต๊ณผ ๊ตฌ๋๋ฅผ ์ ๊ณ ์์ ์ฅ ๋ฌด์ธ๊ฐ๋ฅผ ์์ํ๊ณ ์๋ ์ฌ๋์ด๋ผ๊ณ ํํํ ์๋ ์๋ค. ์ด๊ฒ์ด ๋ฐ๋ก ํ ํฐ์ ์๋ฏธ ์ ๋ณด๋ฅผ ๋ด์ word context
์ ํด๋นํ๋ค.
๋ง์ง๋ง์ผ๋ก Relative Position Embedding
๋ฐฉ์์ผ๋ก 2๋ฒ ๋จ์๋ฅผ ํํํด๋ณด์. ์ค๋ฅธ์์ผ๋ก๋ ์ปคํผ๋ฅผ ๋ค๊ณ ๋ค๋ฅธ ์์ผ๋ก๋ ์บ๋ฆฌ์ด๋ฅผ ์ก๊ณ ์์ผ๋ฉฐ ๊ฒ์ ์ ํ์ดํ๊ณผ ๋ฒ ์ด์ง์ ๋ฐ์ง๋ฅผ ์
์ 1๋ฒ ์ฌ์์ ๋ค์ ์๋ ์ฌ๋, ํ์ ์๋ณต๊ณผ ๊ฒ์ ๋ฟํ
์๊ฒฝ์ ์ฐ๊ณ ํ ์์๋ ์บ๋ฆฌ์ด๋ฅผ ์ก๊ณ ์๋ 4๋ฒ ์ฌ์์ ์์ ์๋ ์ฌ๋, ๊ฒ์ ์ ์์ผ๊ณผ ์ฒญ๋ฐ์ง๋ฅผ ์
๊ณ ํ ์์๋ ํ์ ์ฝํธ๋ฅผ ๋ค๊ณ ์๋ ์ค์ ๋งจ ์ ์ฌ์๋ก๋ถํฐ 2๋ฒ์งธ ๋ค์ ์์๋ ์ฌ๋, ํฑ์์ผ์ด ๊ธธ๊ณ ๋จธ๋ฆฌ๊ฐ ๊ธด ํธ์ด๋ฉฐ ํ๋์ ๊ฐ๋๊ฑด์ ์
๊ณ ์ด๋ก์๊ณผ ๊ฒ์ ์์ด ํผํฉ๋ ๊ฐ๋ฐฉ์ ์ผ์ชฝ์ผ๋ก ๋ฉ๊ณ ์๋ ๋จ์๋ก๋ถํฐ 2๋ฒ์งธ ์์ ์๋ ์ฌ๋.
์ด์ฒ๋ผ ํํํ๋๊ฒ ๋ฐ๋ก Relative Position Embedding
์ ๋์๋๋ค๊ณ ๋ณผ ์ ์๋ค. ์ด์ ์ ์์๋ฅผ ์์ฐ์ด ์ฒ๋ฆฌ์ ๊ทธ๋๋ก ๋์
์์ผ๋ณด๋ฉด ์ดํด๊ฐ ํ๊ฒฐ ์์ํ ๊ฒ์ด๋ค.
๐ค DeBERTa Inductive Bias
๊ฒฐ๊ตญ DeBERTa
๋ ๋๊ฐ์ง ์์น ์ ๋ณด ํฌ์ฐฉ ๋ฐฉ์์ ์ ์ ํ ์์ด์ ๋ชจ๋ธ์ด ๋์ฑ ํ๋ถํ ์๋ฒ ๋ฉ์ ๊ฐ๋๋ก ํ๋ ค๋ ์๋๋ก ์ค๊ณ ๋์๋ค. ๋ํ ์ฐ๋ฆฌ๋ ์ด๋ฏธ ๋ชจ๋ธ์ด ๋ค์ํ ๋งฅ๋ฝ ์ ๋ณด๋ฅผ ํฌ์ฐฉํ ์๋ก NLU Task
์์ ๋ ๋์ ์ฑ๋ฅ์ ๊ธฐ๋กํ๋ค๋ ์ฌ์ค์ BERT
์ GPT
์ฌ๋ก์์ ์ ์ ์์๋ค. ๋ฐ๋ผ์ Relative Position Embedding
์ ์ถ๊ฐํ์ฌ ๋จ์ด์ ๋ฐ์ ์์
๋ฅผ ํฌ์ฐฉํ๋ ๋ชจ๋ธ์ ๋จ์ด ๋ถํฌ ๊ฐ์ค
์ ์ธ ํน์ง์ ๋ํด์ฃผ๋ ค๋ ์ ์์ ์์ด๋์ด๋ ๋งค์ฐ ํ๋นํ๋ค๊ณ ๋ณผ ์ ์๊ฒ ๋ค.
์ด์ ๊ด๊ฑด์ โ๋๊ฐ์ง ์์น ์ ๋ณด๋ฅผ ์ด๋ค ๋ฐฉ์์ผ๋ก ์ถ์ถํ๊ณ ์์ด์ค ๊ฒ์ธ๊ฐโ
ํ๋ ๋ฌผ์์ ๋ตํ๋ ๊ฒ์ด๋ค. ์ ์๋ ๋ฌผ์์ ๋ตํ๊ธฐ ์ํด Disentangled Self-Attention
๊ณผ Enhanced Mask Decoder
๋ผ๋ ์๋ก์ด ๊ธฐ๋ฒ ๋๊ฐ์ง๋ฅผ ์ ์ํ๋ค. ์ ์๋ ๋จ์ด ๋ถํฌ ๊ฐ์ค
์ ํด๋น๋๋ ๋งฅ๋ฝ ์ ๋ณด๋ฅผ ์ถ์ถํ๊ธฐ ์ํ ๊ธฐ๋ฒ์ด๊ณ , ํ์๋ ๋จ์ด ๋ฐ์ ์์
์ ํฌํจ๋๋ ์๋ฒ ๋ฉ์ ๋ชจ๋ธ์ ์ฃผ์
ํ๊ธฐ ์ํด ์ค๊ณ๋์๋ค. ๋ชจ๋ธ๋ง ํํธ์์๋ ๋๊ฐ์ง ์๋ก์ด ๊ธฐ๋ฒ์ ๋ํด์ ์์ธํ ์ดํด๋ณธ ๋ค, ๋ชจ๋ธ์ ์ฝ๋๋ก ๋น๋ํ๋ ๊ณผ์ ์ ์ค๋ช
ํ๋ ค ํ๋ค.
์ฝ๋๋ ๋
ผ๋ฌธ์ ๋ด์ฉ๊ณผ microsoft์ ๊ณต์ git repo๋ฅผ ์ฐธ๊ณ ํด ๋ง๋ค์์์ ๋ฐํ๋ค. ๋ค๋ง, ๋
ผ๋ฌธ์์ ๋ชจ๋ธ ๊ตฌํ๊ณผ ๊ด๋ จํด ์ธ๋ถ์ ์ธ ๋ด์ฉ์ ์๋น์ ์๋ตํ๊ณ ์์ผ๋ฉฐ, repo์ ๊ณต๊ฐ๋ ์ฝ๋๋ hard coding๋์ด ๊ทธ ์๋๋ฅผ ์ ํํ๊ฒ ํ์
ํ๋๋ฐ ๋ง์ ์ด๋ ค์์ด ์์๋ค. ๊ทธ๋์ ์ด๋ ์ ๋๋ ํ์์ ์ฃผ๊ด์ ์ธ ์๊ฐ์ด ๋ฐ์๋ ์ฝ๋๋ผ๋ ์ ์ ๋ฏธ๋ฆฌ ๋ฐํ๋ค.
๐ย Modeling
- 1) Disentangled Self-Attention Encoder Block for
Relative Position Embedding
- 2) Enhanced Mask Decoder for
Absolute Position Embedding
DeBERTa
์ ์ ๋ฐ์ ์ธ ๊ตฌ์กฐ๋ ์ผ๋ฐ์ ์ธ BERT
, RoBERTa
์ ํฌ๊ฒ ๋ค๋ฅธ ์ ์ด ์๋ค. ๋ค๋ง, ๋ชจ๋ธ์ ์ด๋ฐ๋ถ Input Embedding
์์ Absolute Position
์ ๋ณด๋ฅผ ์ถ๊ฐํ๋ ๋ถ๋ถ์ด ํ๋ฐ๋ถ Enhanced Mask Decoder
๋ผ ๋ถ๋ฅด๋ ์ธ์ฝ๋ ๋ธ๋ก์ผ๋ก ์ฎ๊ฒจ๊ฐ ๊ฒ๊ณผ Disentangled Self-Attention
์ ์ํด ๊ฐ๋ณ ์ธ์ฝ๋ ๋ธ๋ก๋ง๋ค ์๋ ์์น ์ ๋ณด๋ฅผ ์ถ์ฒ๋ก ํ๋ linear projection
๋ ์ด์ด๊ฐ ์ถ๊ฐ๋์์์ ๋ช
์ฌํ์. ๋ํ, DeBERTa
์ pre-train
์ RoBERTa
์ฒ๋ผ NSP
๋ฅผ ์ญ์ ํ๊ณ MLM
๋ง ์ฌ์ฉํ ์ ๋ ๊ธฐ์ตํ์.
DeBERTa Class Diagram
์ ์๋ฃ๋ ํ์๊ฐ ๊ตฌํํ DeBERTa
์ ๊ตฌ์กฐ๋ฅผ ํํํ ๊ทธ๋ฆผ์ด๋ค. ์ฝ๋ ๋ฆฌ๋ทฐ์ ์ฐธ๊ณ ํ์๋ฉด ์ข์ ๊ฒ ๊ฐ๋ค ์ฒจ๋ถํ๋ค. ๊ฐ์ฅ ์ค์ํ Disentangled-Attention
๊ณผ EMD
๋ถํฐ ์ดํด๋ณธ ๋ค, ๋๋จธ์ง ๊ฐ์ฒด์ ๋ํด์ ์ดํด๋ณด์.
๐ชขย Disentangled Self-Attention
\[\tilde{A_{ij}} = Q_i^cโขK_j^{cT} + Q_i^cโขK_{โ(i,j)}^{rT} + K_j^cโขQ_{โ(i,j)}^{rT} \\
Attention(Q_c,K_c,V_c,Q_r,K_r) = softmax(\frac{\tilde{A}}{\sqrt{3d_h}})*V_c\]
Disentangled Self-Attention
์ ์ ์๊ฐ ํจ์ดํ Input Embedding
์ ๋ณด์ Relative Position
์ ๋ณด๋ฅผ ํตํฉ์ํค๊ธฐ ์ํด ๊ณ ์ํ ๋ณํ Self-Attention
๊ธฐ๋ฒ์ด๋ค. ๊ธฐ์กด์ Self-Attention
๊ณผ ๋ค๋ฅด๊ฒ Position Embedding
์ Input Embedding
์ ๋ํ์ง ์๊ณ ๋ฐ๋ก ์ฌ์ฉํ๋ค. ์ฆ, ๊ฐ์ $d_h$ ๊ณต๊ฐ์ Input Embedding
๊ณผ Relative Position
์ด๋ผ๋ ์๋ก ๋ค๋ฅธ ๋ ๋ฒกํฐ๋ฅผ ๋งตํํ๊ณ ๊ทธ ๊ด๊ณ์ฑ์ ํ์
ํด๋ณด๊ฒ ๋ค๋ ๋ป์ด๋ค.
Input
๊ณผ Position
์ ๋ณด๋ฅผ ์๋ก ์ฃผ์ฒด์ ์ธ ์
์ฅ์์ ํ ๋ฒ์ฉ ๋ด์ ํ๋ค๊ณ ํด์ Disentangled
๋ผ๋ ์ด๋ฆ์ด ๋ถ๊ฒ ๋์๋ค. Transformer-XL
, XLNet
์ ์ ์๋ Cross-Attention
๊ณผ ๋งค์ฐ ์ ์ฌํ ๊ฐ๋
์ด๋ค. ์ฒซ๋ฒ์งธ ์์์์ ๊ฐ์ฅ ๋ง์ง๋ง ํญ์ ์ ์ธํ๋ฉด Cross-Attention
๊ณผ ํฌ์ฐฉํ๋ ์ ๋ณด๊ฐ ๋์ผํ๋ค๊ณ ์ ์ ์ญ์ ๋ฐํ๊ณ ์์ผ๋ ์ฐธ๊ณ ํ์.
Disentangled Self-Attention
์ ์ด 5๊ฐ์ง linear projection matrix
๋ฅผ ์ฌ์ฉํ๋ค. Input Embedding
์ ์ถ์ฒ๋ก ํ๋ $Q^c, k^c, V^c$, ๊ทธ๋ฆฌ๊ณ Position Embedding
์ ์ถ์ฒ๋ก ํ๋ $Q^r, K^r$์ด๋ค. ์ฒจ์ $c,r$์ ๊ฐ๊ฐ content
, relative
์ ์ฝ์๋ก ํ๋ ฌ์ ์ถ์ฒ๋ฅผ ๋ปํ๋ค. ํํธ ํ๋ ฌ ์๋ ์ฒจ์์ ์จ์๋ $i,j$๋ ๊ฐ๊ฐ ํ์ฌ ์ดํ
์
๋์ ํ ํฐ์ ์ธ๋ฑ์ค์ ๊ทธ ๋๋จธ์ง ํ ํฐ์ ์ธ๋ฑ์ค๋ฅผ ๊ฐ๋ฆฌํจ๋ค. ๊ทธ๋์ $\tilde{A_{ij}}$๋ [NxN]
ํฌ๊ธฐ ํ๋ ฌ(๊ธฐ์กด ์ดํ
์
์์ ์ฟผ๋ฆฌ์ ํค์ ๋ด์ ๊ฒฐ๊ณผ์ ํด๋น
)์ $i$๋ฒ์งธ ํ๋ฐฑํฐ์ $j$๋ฒ์งธ ์์์ ๊ฐ์ ์๋ฏธํ๋ค. Input Embedding
์ ๋ณด์ Relative Position
์ ๋ณด๋ฅผ ๋ฐ๋ก ๋ฐ๋ก ๊ด๋ฆฌํ๊ธฐ ๋๋ฌธ์ ์ฐ๋ฆฌ๊ฐ ๊ธฐ์กด์ ์๊ณ ์๋ Self-Attention
๊ณผ๋ ์ฌ๋ญ ๋ค๋ฅธ ์์์ด๋ค. ์ด์ ๋ถํฐ ์์์ ํญ ํ๋ํ๋์ ์๋ฏธ๋ฅผ ๊ตฌ์ฒด์ ์ธ ์์์ ํจ๊ผ ํํค์ณ๋ณด์.
โบ๏ธย c2c matrix
content2content
์ ์ฝ์๋ก ์ฒซ๋ฒ์งธ ์์ ์ฐ๋ณ์ ์ฒซ๋ฒ์งธ ํญ์ ๊ฐ๋ฆฌํค๋ ๋ง์ด๋ค. ์ด๋ฆ์ ์๋ฏธ๋ ๋ด์ ์ ์ฌ์ฉํ๋ ๋ ํ๋ ฌ์ ์ถ์ฒ๊ฐ ๋ชจ๋ Input Embedding
์ด๋ผ๋ ์ฌ์ค์ ๋ดํฌํ๊ณ ์๋ค. ๊ธฐ์กด์ ์๊ณ ์๋ Self-Attention
์ ๋๋ฒ์งธ ๋จ๊ณ์ธ $QโขK^T$์ ๊ฑฐ์ ๋์ผํ ์๋ฏธ๋ฅผ ๋ด๊ณ ์๋ ํญ์ด๋ผ๊ณ ์๊ฐํ๋ฉด ๋ ๊ฒ ๊ฐ๋ค. ์์ ํ ๊ฐ๋ค๊ณ ํ ์ ์๋ ์ด์ ๋ Absolute Position
์ ๋ณด๊ฐ ๋น ์ง์ฑ๋ก ๋ด์ ํ๊ธฐ ๋๋ฌธ์ด๋ค.
๋ฐ๋ผ์ ์ฐ์ฐ์ ์๋ฏธ ์ญ์ ์ฐ๋ฆฌ๊ฐ ๊ธฐ์กด์ ์๊ณ ์๋ ๋ฐ์ ๋์ผํ๋ค. ํน์ ํ๋ ฌ $Q,K,V$์ Self-Attention
๊ฐ ๋ดํฌํ๋ ์๋ฏธ์ ๋ํด ์์ธํ ๊ถ๊ธํ์ ๋ถ์ด๋ผ๋ฉด ํ์๊ฐ ์์ฑํ Transformer๋
ผ๋ฌธ ๋ฆฌ๋ทฐ๋ฅผ ๋ณด๊ณ ์ค์๊ธธ ๋ฐ๋๋ค. ๊ทธ๋๋ ์ด์ฐจํผ ๋ค์ ๋จ์ ๋๊ฐ์ ํญ์ ์ค๋ช
ํ๋ ค๋ฉด ์ด์ฐจํผ ์์๋ฅผ ๋ค์ด์ผ ํ๊ธฐ ๋๋ฌธ์ c2c
ํญ์์๋ถํฐ ์์ํด๋ณด๋ ค ํ๋ค.
๋น์ ์ ์ค๋ ์ ๋ ๋ฐฅ์ผ๋ก ์ฐจ๋๋ฐ์ด ๋์ฅ ์ฐ๊ฐ, ์ผ๊ฒน์ด ๊ทธ๋ฆฌ๊ณ ํ์์ผ๋ก ๊ตฌ์ด ๋ฌ๊ฑ์ ๋จน๊ณ ์ถ๋ค. ์ง์ ์ฌ๋ฃ๊ฐ ํ๋๋ ์์ง๋ง ๋งํธ์ ๊ฐ๊ธฐ ๊ท์ฐฎ์ผ๋ ํ์ํ ์์์ฌ๋ฅผ ๋จํธ์๊ฒ ์ฌ์ค๋ผ๊ณ ์ํฌ ์๊ฐ์ด๋ค. ๋น์ ์ ๊ทธ๋์ ํ์ํ ์ฌ๋ฃ ๋ฆฌ์คํธ๋ฅผ ์ ๊ณ ์๋ค. ๊ทธ๋ ๋ค๋ฉด ํ์ํ ์ฌ๋ฃ๋ฅผ ์ด๋ค ์์ผ๋ก ํํํด์ ์ ์ด์ค์ผ ๋จํธ์ด ๊ฐ์ฅ ๋น ๋ฅด๊ณ ์ ํํ๊ฒ ํ์ํ ๋ชจ๋ ์์์ฌ๋ฅผ ์ฌ์ฌ ์ ์์๊น??
์ด๊ฒ์ ๊ณ ๋ฏผํ๋๊ฒ ๋ฐ๋ก ํ๋ ฌ $Q^c$์ linear projector
์ธ $W_{Q^c}$์ ์ญํ ์ด๋ค.์๋ฅผ ๋ค์ด ๊ฐ์ ์๋ค๋ฆฌ์ด์ด๋ผ๋ ๊ตฌ์ด์ฉ์ด ์๊ณ ์ฐ๊ฐ์ฉ์ด ์๋ค. ๋ฌ๊ฑ๋ ๊ตฌ์ด ๋ฌ๊ฑ์ด ์๊ณ ๋ ๋ฌ๊ฑ์ด ์๋ค. ์ ํํ ์ฉ๋๋ฅผ ์ ์ด์ฃผ๋๊ฒ ๋จํธ ์
์ฅ์์๋ ์๋ด์ ์๋๋๋ก ์ ํํ๊ฒ ์ฅ์ ๋ณด๊ธฐ ํจ์ฌ ํธํ ๊ฒ์ด๋ค.
ํํธ, ๋ด์ ์ ๋ณธ๋ ํ๋ผ๋ฏธํฐ๊ฐ ํ์ํ ์ฐ์ฐ์ ์๋๋ผ์ ์ค์ ์์คํจ์ ์ค์ฐจ ์ญ์ ์ ํตํด ์ต์ ํ(ํ์ต)๋๋ ๋์์ ๋ฐ๋ก $W_{Q^c}$๊ฐ ๋๋ค. ๋จํธ์ด ์ฅ์ ๋น ๋ฅด๊ณ ์ ํํ๊ฒ ๋ณด๋๋ฐ ๊ณผ์ฐ ๋น์ ์ด ์ ์ด์ค ๋ฆฌ์คํธ๋ง ์ํฅ์ ๋ฏธ์น ๊น??
์๋๋ค. ๋น์ ์ด ์ด๋ค ์์์ ์ํด ์ด๋ค ์ฌ๋ฃ๊ฐ ํ์ํ์ง ๊ทธ ์๋๋ฅผ ์ ์ ์ด์ฃผ๋ ๊ฒ๋ ์ค์ํ์ง๋ง ์ค์ ๋งํธ์ ์ ํ ์๋ ์ํ๋ช
๊ณผ ์ํ์ค๋ช
์ญ์ ์ค์ํ๋ค. ์ข ์ต์ง์ค๋ฌ์ด ์์์ฒ๋ผ ๋ณด์ด๊ธด ํ์ง๋ง ๋ฌ๊ฑ์ ๊ฒฝ์ฐ ์ก์์ผ๋ก๋ง ๋ณด๋ฉด ์ด๊ฒ์ด ๊ตฌ์ด ๋ฌ๊ฑ์ธ์ง ๋ ๋ฌ๊ฑ์ธ์ง ๊ตฌ๋ถํ ์ ์๋ค. ๊ทธ๋ฐ๋ฐ ๋งํธ์ ๋ณ๋ค๋ฅธ ์ค๋ช
์์ด ์ํ๋ช
์ผ๋ก โ๋ฌ๊ฑโ
์ด๋ผ๊ณ ๋ง ์ ํ์๋ค ์๊ฐํด๋ณด์.
์๋ฌด๋ฆฌ ๋น์ ์ด ์ข์ ํ๋ ฌ $Q^c$๋ฅผ ํํํด์ค๋ ๋จํธ์ด ๋ ๋ฌ๊ฑ์ ์ฌ์ฌ ํ๋ฅ ์ด ๊ฝค๋ ๋์ ๊ฒ์ด๋ค. ์ด๋ ๊ฒ ๋งํธ์ ์ ํ์๋ ์ํ๋ช
๊ณผ ์ํ์ค๋ช
์ด ๋ฐ๋ก ํ๋ ฌ $K^c$์ ๋์๋๋ค. ๊ทธ๋ฆฌ๊ณ ๋ฌผ๊ฑด์ ์ฌ๊ธฐ ์ํด ๋น์ ์ด ์ ์ด์ค ์์์ฌ ๋ฆฌ์คํธ์ ๋งค์ฅ์ ์ ํ ์ํ๋ช
๊ณผ ์ํ์ค๋ช
์ ๋์กฐํ๋ฉฐ ์ด๊ฒ์ด ์๋์ ๋ง๋ ์ํ์ธ์ง ๋ฐ์ ธ๋ณด๋ ์์
์ด ๋ฐ๋ก $Q_i^cโขK_j^{cT}$, c2c matrix
๊ฐ ๋๋ค.
๋ค๋ง, ์ ์ญ ์ดํ
์
์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋ฌ๊ฑ์ ์ฌ๊ธฐ ์ํด ๋งค์ฅ์ ์๋ ๋ชจ๋ ์ํ๊ณผ ๋์กฐ๋ฅผ ํ๋ค๊ณ ์๊ฐํ๋ฉด ๋๋ค. ํนํ ๊ธฐ์กด ์ ์ญ ์ดํ
์
์ $Q_i^cโขK_j^{cT}$์ ๊ฒฝ์ฐ ๋ชจ๋ ์ํ๊ณผ ๋์กฐํ๋ ๊ณผ์ ์์ ๋์กฐ๊ตฐ์ด ๋งค์ฅ์ ์ ์๋ ์์น, ์นดํ
๊ณ ๋ฆฌ ๋ถ๋ฅ์ ์ด๋ ์ฝ๋์ ์ํ๋์ง ๋ฑ์ ์์น ์ ๋ณด๋ฅผ ํ๊บผ๋ฒ์ ๊ณ ๋ คํ์ง๋ง, ์ฐ๋ฆฌ์(c2c matrix
) ๊ฒฝ์ฐ ์ฌ๊ธฐ์ ์ด๋ฐ ์์น ์ ๋ณด๋ฅผ ์ ํ ๊ณ ๋ คํ์ง ์๊ณ ๋ค์ ๋ ๊ฐ์ ํญ์์ ๋ฐ๋ก ๊ณ ๋ คํ๋ค.
์ ๋ฆฌํ๋ฉด, c2c
๋ ๋งค์ฅ์ ์ง์ด๋ ์์์ฌ์ ์ํ๋ช
๋ฐ ์ค๋ช
๋ง ๊ฐ์ง๊ณ ๋ด๊ฐ ์ฌ์ผ ํ๋ ์์ฌ๋ฃ์ธ์ง ์๋์ง ํ๋จํ๋ ์์
์ ์ํ์ ์ผ๋ก ๋ชจ๋ธ๋ง ํ๋ค๊ณ ๋ณผ ์ ์๊ฒ ๋ค. ์์ฐ์ด ์ฒ๋ฆฌ ๋งฅ๋ฝ์์ ๋ฐ๋ผ๋ณด๋ฉด, ํน์ ํ ํฐ์ ์๋ฏธ๋ฅผ ์๊ธฐ ์ํด์ syntactical
ํ ์ ๋ณด์์ด ์์ํ๊ฒ ๋๋จธ์ง ๋ค๋ฅธ ํ ํฐ๋ค์ ์๋ฏธ๋ฅผ ๊ฐ์คํฉ์ผ๋ก ๋ฐ์ํ๋ ํ์์ ๋์๋๋ค.
๐๏ธย c2p matrix
content2position
์ ์ฝ์๋ก ์์ ์ฐ๋ณ์ ๋๋ฒ์งธ ํญ, $Q_i^cโขK_{โ(i,j)}^{rT}$๋ฅผ ๊ฐ๋ฆฌํจ๋ค. c2c
๋์๋ ๋ค๋ฅด๊ฒ ์๋ก ์ถ์ฒ๊ฐ ๋ค๋ฅธ ๋ ํ๋ ฌ์ ์ฌ์ฉํด c2p
๋ผ๋ ์ด๋ฆ์ ๋ถ์๋ค. ๋ด์ ๋์์ ์ฟผ๋ฆฌ๋ Input Embedding
์ผ๋ก๋ถํฐ ๋ง๋ ํ๋ ฌ $Q_i^c$, ํค๋ Position Embedding
์ผ๋ก๋ถํฐ ๋ง๋ ํ๋ ฌ $K_{โ(i,j)}^{rT}$ ์ ์ฌ์ฉํ๋ค. word context
์ relative position
์ ์๋ก ๋์กฐํ๋ค๋ ๊ฒ์ด ๋ฌด์จ ์๋ฏธ๋ฅผ ๊ฐ๋์ง ์ง๊ด์ ์ผ๋ก ์๊ธฐ ํ๋๋ ์ฅ๋ณด๊ธฐ ์์๋ฅผ ํตํด ์ดํดํด๋ณด์.
๊ตฌ์ด ๋ฌ๊ฑ๊ณผ ๋ ๋ฌ๊ฑ์ ์์๋ฅผ ๋ค๋ฉด์ ์ํ๋ช
๊ณผ ์ค๋ช
์ด ์ฅ๋ณด๊ธฐ์ ์ค์ํ ์ํฅ์ ๋ฏธ์น๋ค๊ณ ์ธ๊ธํ๋ค. ํ์ง๋ง ์ํ๋ช
๊ณผ ์ค๋ช
์ด ์ฌ์ ํ ๋จ์ โ๋ฌ๊ฑโ
์ผ๋ก ์ ํ ์์ด๋ ์ฐ๋ฆฌ๋ ์ด๊ฒ์ ๊ตฌ๋ถํด ๋ผ ๋ฐฉ๋ฒ์ด ์๋ค. ๋ฐ๋ก ์ฃผ๋ณ์ ์ง์ด๋ ์ํ์ด ๋ฌด์์ธ์ง ์ดํด๋ณด๋ ๊ฒ์ด๋ค. โ๋ฌ๊ฑโ
๋ฐ๋ก ์์ ์ฐ์ , ์น์ฆ, ์์ , ์ ์ก๊ณผ ๊ฐ์ ์ ์ ์ํ๋ฅ๊ฐ ๋ฐฐ์น๋์ด ์๋ค๊ณ ๊ฐ์ ํด๋ณด์. ์ฐ๋ฆฌ๋ ์ฐ๋ฆฌ ๋ ์์ ์๋ โ๋ฌ๊ฑโ
์ด ๋ ๋ฌ๊ฑ์ด๋ผ๊ณ ๊ธฐ๋ํด ๋ด์งํ๋ค. ๋ง์ฝ โ๋ฌ๊ฑโ
์์ ์ฅํฌ, ๋ง๋ฆฐ ์ค์ง์ด, ์กํฌ, ๊ณผ์ ๊ฐ์ ๊ฐ์๋ฅ ์ํ๋ค์ด ๋ฐฐ์น๋์ด ์๋ค๋ฉด ์ด๋จ๊น?? ๊ทธ๋ผ ์ด โ๋ฌ๊ฑโ
์ ์ถฉ๋ถํ ๊ตฌ์ด ๋ฌ๊ฑ์ด๋ผ๊ณ ํด์ํด๋ณผ ์ ์๋ค. ์ด์ฒ๋ผ ์ฃผ์์ ์ด๋ค ๋ค๋ฅธ ์ํ๋ค์ด ๋ฐฐ์น ๋์ด ์๋๊ฐ๋ฅผ ํตํด ์ฐ๋ฆฌ๊ฐ ์ฌ๋ ค๋ ๋ฌผ๊ฑด์ด ๋ง๋์ง ๋์กฐํด๋ณด๋ ํ์๊ฐ ๋ฐ๋ก c2p
์ ๋์๋๋ค. ๊ทธ๋ ๋ค๋ฉด ์ฃผ์์ ์ด๋ค ๋ค๋ฅธ ์ํ๋ค์ด ๋ฐฐ์น ๋์ด ์๋๊ฐ ์ ๋ณด๋ฅผ ๋ชจ์ ๋์ ๊ฒ์ด ๋ฐ๋ก $K_{โ(i,j)}^{rT}$๊ฐ ๋๋ค.
๐ฌย p2c matrix
Disentangled Self-Attention
์ด ์ฌํ ๋ค๋ฅธ ์ดํ
์
๊ธฐ๋ฒ๋ค๊ณผ ๊ฐ์ฅ ์ฐจ๋ณํ๋๋ ๋ถ๋ถ์ด๋ค. ์ ์๊ฐ ๋
ผ๋ฌธ์์ ๊ฐ์ฅ ๊ฐ์กฐํ๋ ๋ถ๋ถ์ด๊ธฐ๋ ํ๋ค. ์ฌ์ค ๊ทธ๋ฐ ๊ฒ์น๊ณ ๋ ๋
ผ๋ฌธ ์ ์ค๋ช
์ด ์๋นํ ๋ถ์น์ ํด ์ดํดํ๊ธฐ ์ฐธ ๋ํดํ ๊ฐ๋
์ด๋ค. ์ด๊ฑฐ ์ค๋ช
ํ๊ณ ์ถ์ด์ ์ฅ๋ณด๊ธฐ ์์๋ฅผ ์๊ฐํด๋ด๊ฒ ๋์๋ค. ๋ค์ ๋จํธ์๊ฒ ์ค ์ฅ๋ณด๊ธฐ ๋ฆฌ์คํธ๋ฅผ ์์ฑํ๋ ์์ ์ผ๋ก ๋์๊ฐ๋ณด์.
์ค๋ ์ ๋ ๋ฉ๋ด๋ ์ฐจ๋๋ฐ์ด ๋์ฅ์ฐ๊ฐ์ ๊ตฌ์ด ์ผ๊ฒน์ด์ด๋ค. ๋จผ์ ์ฐจ๋๋ฐ์ด ๋์ฅ์ฐ๊ฐ๋ฅผ ๋ง๋ค๋ ค๋ฉด ์ด๋ค ์ฌ๋ฃ๊ฐ ํ์ํ ๊น?? ์ฐจ๋๋ฐ์ด, ๋์ฅ, ์ฒญ์๊ณ ์ถ, ์ํ, ๋ค์ง ๋ง๋, ํธ๋ฐ๊ณผ ๊ฐ์ ์์์ฌ๊ฐ ํ์ํ ๊ฒ์ด๋ค. ๊ทธ๋ฆฌ๊ณ ์ผ๊ฒน์ด์ ํ์ํ ์ฌ๋ฃ๋ฅผ ์๊ฐํด๋ณด์. ์์ผ๊ฒน์ด๊ณผ ์ก๋ด๋ฅผ ์์ ๋๋ฐ ํ์ํ ํ์ถ์ ์๊ธ ๊ทธ๋ฆฌ๊ณ ๊ตฌ์ ๋จน์ ํต๋ง๋์ด ํ์ํ๋ค๊ณ ๋น์ ์ ์๊ฐํ๋ค. ๊ทธ๋ผ ์ด์ ์ด๊ฒ์ ๋ฐํ์ผ๋ก ๋ฆฌ์คํธ๋ฅผ ์์ฑํ ๊ฒ์ด๋ค. ์ด๋ค ์์ผ๋ก ๋ฆฌ์คํธ๋ฅผ ์์ฑํ๋๊ฒ ๊ฐ์ฅ ์ต์ ์ผ๊น??
c2c
, c2p
์์์ ํจ๊ป ์๊ฐํด๋ณด๋ฉด ์ ์ ์๋ค. c2c
์์๋ ๊ฐ์ ์ฌ๋ฃ๋ผ๋ ๊ทธ ์ฉ๋์ ๋ฐ๋ผ์ ์ฌ์ผํ ํ๋ชฉ์ด ๋ฌ๋ผ์ง๋ค๊ณ ์ธ๊ธํ ๋ฐ์๋ค. c2p
์์๋ ์ ํํ ์ค๋ช
์ด ์์ด๋ ์ฃผ๋ณ์ ๋์ด๋ ํ๋ชฉ๋ค์ ๋ณด๋ฉด์ ์ด๋ค ์ํ์ธ์ง ์ ์ถ๊ฐ ๊ฐ๋ฅํ๋ค๊ณ ํ๋ค. ์ด๊ฒ์ ํฉ์ณ๋ณด์. ๋ง์ฝ ๋น์ ์ด ์๋์ ๊ฐ์ ์์๋ก ๋ฆฌ์คํธ๋ฅผ ์ ์๋ค๊ณ ๊ฐ์ ํด๋ณด๊ฒ ๋ค.
# ์ฅ๋ณด๊ธฐ ๋ฆฌ์คํธ ์์1
์ฐจ๋๋ฐ์ด, ๋์ฅ, ๋ง๋, ์ฒญ์๊ณ ์ถ, ์ํ, ํธ๋ฐ, ์ผ๊ฒน์ด, ํ์ถ, ์๊ธ
์๊น ํ์ํ ํ๋ชฉ์ ๋์ดํ์ ๋ ๋ถ๋ช ํ ๋ค์ง ๋ง๋๊ณผ ํต๋ง๋์ ๋์์ ์๊ฐํ์๋ค. ๊ทผ๋ฐ ์์ฒ๋ผ ๋ฆฌ์คํธ๋ฅผ ์์ฑํด์ ๋จํธ์๊ฒ ์คฌ๋ค๋ฉด ๋จํธ์ ์ด๋ค ๋ง๋์ ์ฌ์ฌ๊น?? ๋น์ฐํ ์ฐจ๋๋ฐ์ด์ ๋์ฅ ๊ทธ๋ฆฌ๊ณ ์ํ ์ฌ์ด์ ๋ง๋์ด ์์นํ ๊ฒ์ ๋ณด๊ณ ๋จํธ์ ๊ตญ๋ฌผ์ฉ ๋ง๋์ด ํ์ํ๊ตฌ๋ ์ถ์ด์ ๋ค์ง ๋ง๋์ ์ฌ์ฌ ๊ฒ์ด๋ค.
๊ทธ๋ ๋ค๋ฉด ๋ฐ๋๋ก ๋น์ ์ด ์๋์ฒ๋ผ ๋ฆฌ์คํธ๋ฅผ ์์ฑํ๋ค๊ณ ์๊ฐํด๋ณด์.
# ์ฅ๋ณด๊ธฐ ๋ฆฌ์คํธ ์์2
์ฐจ๋๋ฐ์ด, ๋์ฅ, ์ฒญ์๊ณ ์ถ, ์ํ, ํธ๋ฐ, ์ผ๊ฒน์ด, ๋ง๋, ํ์ถ, ์๊ธ
์ด๋ฒ์๋ ์ผ๊ฒน์ด ๊ตฌ์ธ ๋, ๊ฐ์ด ๊ตฌ์๋จน์ ํต๋ง๋์ด ํ์ํ๊ตฌ๋๋ฅผ ๋จํธ์ด ๋๋ ์ ์์ ๊ฒ์ด๋ค. ํํธ ์๋์ ๊ฐ์ ์ํฉ์ด๋ผ๋ฉด ์ด๋จ๊น??
# ์ฅ๋ณด๊ธฐ ๋ฆฌ์คํธ ์์3
์ฐจ๋๋ฐ์ด, ๋์ฅ, ๋ง๋, ์ฒญ์๊ณ ์ถ, ์ํ, ํธ๋ฐ, ์ผ๊ฒน์ด, ๋ง๋, ํ์ถ, ์๊ธ
์กฐ๊ธ ์ผ์ค๊ฐ ์๋ ๋จํธ์ด๋ผ๋ฉด ๋์ฅ์ฐ๊ฐ ๊ตญ๋ฌผ์ฉ ๋ค์ง๋ง๋๊ณผ ์ผ๊ฒน์ด ๊ตฌ์ด์ฉ ํต๋ง๋์ด ๋์์ ํ์ํ๊ตฌ๋๋ผ๊ณ ์ ์ถํ๊ณ ๋งค์ฅ์์ ๋ค์ง๋ง๋, ํต๋ง๋์ด๋ผ ์จ์๋ ํ๋ชฉ์ ์ฐพ์์ ๋ ๋ค ์ฌ์ฌ ๊ฒ์ด๋ค. ๋ฌผ๋ก ์ผ์ค์๋ ์๋ด๋ผ๋ฉด ์ ์ด์ ์ ๋ ๊ฒ ์ ๋งคํ๊ฒ ๋ง๋
์ด๋ผ๊ณ 2๋ฒ ์์ ๊ณ ๋ค์ง๋ง๋
, ํต๋ง๋
์ด๋ผ๊ณ ์ฉ๋๋ฅผ ํจ๊ป ์ ์ด์คฌ๊ฒ ์ง๋ง ๋ง์ด๋ค.
์ด๋ฌํ ์ผ๋ จ์ ์ํฉ์ด ๋ฐ๋ก p2c
์ ๋์๋๋ค. ๊ทธ๋ ๋ค๋ฉด ์๋ด๊ฐ ์ ์ด์ค ๋ฆฌ์คํธ์์ ์ฃผ๋ณ์ ์์นํ ํ๋ชฉ๋ค์ ๋ฐ๋ผ์ ํฌ์ฐฉ๋๋ ๋์ ํ๋ชฉ์ ์ฉ๋๋ ์ฐ์์, ์๋ฏธ ๋ฑ์ด ๋ฐ๋ก ํ๋ ฌ $Q_{โ(i,j)}^{rT}$๊ฐ ๋๋ค.
โ๏ธย DeBERTa Scale Factor
์ฒ์์ ๋์ดํ ์์์ ๋ค์ ๋ณด๋ฉด DeBERTa
์ scale factor
๋ ๊ธฐ์กด Self-Attention
๊ณผ ๋ค๋ฅด๊ฒ $\sqrt{3d_h}$๋ฅผ ์ฌ์ฉํ๋ค. ์ด์ ๊ฐ ๋ญ๊น?? ๊ธฐ์กด ๋ฐฉ์์ softmax layer
์ ์ ๋ฌํ๋ ํ๋ ฌ์ ์ข
๋ฅ๊ฐ $QโขK^T$ ํ ๊ฐ๋ค. DeBERTa
์ ๊ฒฝ์ฐ๋ 3๊ฐ๋ฅผ ์ ๋ฌํ๊ฒ ๋๋ค. ๊ทธ๋์ $d_h$์์ 3์ ๊ณฑํด์ค ๊ฒ์ด๋ค. official repo์ ์ฝ๋๋ฅผ ํ์ธํด๋ณด๋ฉด ํ์คํ ์ ์ ์๋๋ฐ, ์ดํ
์
์ ์ฌ์ฉํ๋ ํ๋ ฌ ์ข
๋ฅ์ ๊ฐ์๋ฅผ $d_h$์์ ๊ณฑํด์ค๋ค. ์๋๋ repo
์ ์ฌ๋ผ์ ์๋ ์ฝ๋์ ์ผ๋ถ๋ฅผ ๋ฐ์ทํ ๊ฒ์ด๋ค.
# official Disentangled Self-Attention by microsoft from official repo
...์ค๋ต...
def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
if query_states is None:
query_states = hidden_states
query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads).float()
key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads).float()
value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
rel_att = None
# Take the dot product between "query" and "key" to get the raw attention scores.
scale_factor = 1
if 'c2p' in self.pos_att_type:
scale_factor += 1
if 'p2c' in self.pos_att_type:
scale_factor += 1
if 'p2p' in self.pos_att_type:
scale_factor += 1
๐ฉโ๐ปย Implementation
์ด๋ ๊ฒ Disentangled Self-Attention
์ ๋ํ ๋ชจ๋ ๋ด์ฉ์ ์ดํด๋ดค๋ค. ์ค์ ๊ตฌํ์ ์ด๋ป๊ฒ ํด์ผ ํ๋์ง ํ์๊ฐ ์์ฑํ ํ์ดํ ์น ์ฝ๋์ ํจ๊ป ์์๋ณด์.
# Pytorch Implementation of DeBERTa Disentangled Self-Attention
def build_relative_position(x_size: int) -> Tensor:
""" Build Relative Position Matrix for Disentangled Self-attention in DeBERTa
Args:
x_size: sequence length of query matrix
Reference:
https://arxiv.org/abs/2006.03654
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/da_utils.py#L29
"""
x_index, y_index = torch.arange(x_size, device="cuda"), torch.arange(x_size, device="cuda") # same as rel_pos in official repo
rel_pos = x_index.view(-1, 1) - y_index.view(1, -1)
return rel_pos
def disentangled_attention(
q: Tensor,
k: Tensor,
v: Tensor,
qr: Tensor,
kr: Tensor,
attention_dropout: torch.nn.Dropout,
padding_mask: Tensor = None,
attention_mask: Tensor = None
) -> Tensor:
""" Disentangled Self-attention for DeBERTa, same role as Module "DisentangledSelfAttention" in official Repo
Args:
q: content query matrix, shape (batch_size, seq_len, dim_head)
k: content key matrix, shape (batch_size, seq_len, dim_head)
v: content value matrix, shape (batch_size, seq_len, dim_head)
qr: position query matrix, shape (batch_size, 2*max_relative_position, dim_head), r means relative position
kr: position key matrix, shape (batch_size, 2*max_relative_position, dim_head), r means relative position
attention_dropout: dropout for attention matrix, default rate is 0.1 from official paper
padding_mask: mask for attention matrix for MLM
attention_mask: mask for attention matrix for CLM
Math:
c2c = torch.matmul(q, k.transpose(-1, -2)) # A_c2c
c2p = torch.gather(torch.matmul(q, kr.transpose(-1 z, -2)), dim=-1, index=c2p_pos)
p2c = torch.gather(torch.matmul(qr, k.transpose(-1, -2)), dim=-2, index=c2p_pos)
attention Matrix = c2c + c2p + p2c
A = softmax(attention Matrix/sqrt(3*D_h)), SA(z) = Av
Notes:
dot_scale(range 1 ~ 3): scale factor for QโขK^T result, sqrt(3*dim_head) from official paper by microsoft,
3 means that use full attention matrix(c2c, c2p, p2c), same as number of using what kind of matrix
default 1, c2c is always used and c2p & p2c is optional
References:
https://arxiv.org/pdf/1803.02155
https://arxiv.org/abs/2006.03654
https://arxiv.org/abs/2111.09543
https://arxiv.org/abs/1901.02860
https://arxiv.org/abs/1906.08237
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/disentangled_attention.py
"""
BS, NUM_HEADS, SEQ_LEN, DIM_HEADS = q.shape
_, MAX_REL_POS, _, _ = kr.shape
scale_factor = 1
c2c = torch.matmul(q, k) # A_c2c, q: (BS, NUM_HEADS, SEQ_LEN, DIM_HEADS), k: (BS, NUM_HEADS, DIM_HEADS, SEQ_LEN)
c2p_att = torch.matmul(q, kr.permute(0, 2, 3, 1).contiguous())
c2p_pos = build_relative_position(SEQ_LEN) + MAX_REL_POS / 2 # same as rel_pos in official repo
c2p_pos = torch.clamp(c2p_pos, 0, MAX_REL_POS - 1).repeat(BS, NUM_HEADS, 1, 1).long()
c2p = torch.gather(c2p_att, dim=-1, index=c2p_pos)
if c2p is not None:
scale_factor += 1
p2c_att = torch.matmul(qr, k) # qr: (BS, NUM_HEADS, SEQ_LEN, DIM_HEADS), k: (BS, NUM_HEADS, DIM_HEADS, SEQ_LEN)
p2c = torch.gather(p2c_att, dim=-2, index=c2p_pos) # same as torch.gather(kโขqr^t, dim=-1, index=c2p_pos)
if p2c is not None:
scale_factor += 1
dot_scale = torch.sqrt(torch.tensor(scale_factor * DIM_HEADS)) # from official paper by microsoft
attention_matrix = (c2c + c2p + p2c) / dot_scale # attention Matrix = A_c2c + A_c2r + A_r2c
if padding_mask is not None:
padding_mask = padding_mask.unsqueeze(1).unsqueeze(2) # for broadcasting: shape (BS, 1, 1, SEQ_LEN)
attention_matrix = attention_matrix.masked_fill(padding_mask == 1, float('-inf')) # Padding Token Masking
attention_dist = attention_dropout(
F.softmax(attention_matrix, dim=-1)
)
attention_matrix = torch.matmul(attention_dist, v).permute(0, 2, 1, 3).reshape(-1, SEQ_LEN, NUM_HEADS*DIM_HEADS).contiguous()
return attention_matrix
p2c
๋ฅผ ๊ตฌํ๋ ๊ณผ์ ์ ์ฝ๋๋ผ์ธ์ ์ฃผ๋ชฉํด๋ณด์. ๋
ผ๋ฌธ์ ๊ธฐ์ฌ๋ ์์($K_j^cโขQ_{โ(i,j)}^{rT}$)๊ณผ ๋ค๋ฅด๊ฒ, ์ฟผ๋ฆฌ์ ํค์ ์์๋ฅผ ๋ค์ง์๋ค. ๊ทธ๋์ torch.gather
์ ์ฐจ์ ๋งค๊ฐ๋ณ์ dim
๋ฅผ c2p
์ ์ํฉ๊ณผ ๋ค๋ฅด๊ฒ -2
๋ก ์ด๊ธฐํํ๊ฒ ๋์๋ค. ๋ด์ ํ๋ ํญ์ ์์๋ฅผ ๋ค์ง์ ๊ฒ์ผ๋ก ์ธํด ์ฐ๋ฆฌ๊ฐ ์ถ์ถํ๊ณ ์ถ์ ๋์ ๊ฐ์ธ ์๋ ์์น ์๋ฒ ๋ฉ์ด -2
๋ฒ์งธ ์ฐจ์์ ์์น ํ๊ฒ ๋๊ธฐ ๋๋ฌธ์ด๋ค.
๐ทย Enhanced Mask Decoder
DeBERTa
์ ์ค๊ณ ๋ชฉ์ ์ 2๊ฐ์ง ์์น ์ ๋ณด๋ฅผ ์ ์ ํ ์์ด์ ์ต๋ํ ํ๋ถํ ์๋ฒ ๋ฉ์ ๋ง๋๋ ๊ฒ์ด๋ผ๊ณ ํ๋ค. ์๋ ์์น ์๋ฒ ๋ฉ์ Disentangled Self-Attention
์ ํตํด ํฌ์ฐฉํ๋ค๋ ๊ฒ์ ์ด์ ์์๋ค. ๊ทธ๋ผ ์ ๋ ์์น ์๋ฒ ๋ฉ์ ์ด๋ค ์์ผ๋ก ๋ชจ๋ธ๋งํด์ค์ผ ํ ๊น?? ๊ทธ ๋ฌผ์์ ๋ต์ ๋ฐ๋ก EMD
๋ผ ๋ถ๋ฆฌ๋ Enhanced Mask Decoder
์ ์๋ค. EMD
์ ์๋ฆฌ์ ๋ํด ๊ณต๋ถํ๊ธฐ ์ ์ ์ ์ ๋ ์์น ์๋ฒ ๋ฉ์ด NLU
์ ํ์ํ์ง ์ง๊ณ ๋์ด๊ฐ์.
์ ๋ฌธ์ฅ์ ์ ์๊ฐ ๋
ผ๋ฌธ์์ Absolute Position Embedding
์ ํ์์ฑ์ ์ญ์คํ ๋ ์ฌ์ฉํ ์์ ๋ฌธ์ฅ์ด๋ค. ๊ณผ์ฐ ์๋ ์์น ์๋ฒ ๋ฉ๋ง ์ฌ์ฉํด์ store
์ mall
์ ์ฐจ์ด๋ฅผ ์ ๊ตฌ๋ณํ ์ ์์๊น ์๊ฐํด๋ณด์. ์์ ์ฐ๋ฆฌ๋ ์๋ ์์น ์๋ฒ ๋ฉ์ ๋์ ํ ํฐ๊ณผ ๊ทธ ๋๋จธ์ง ํ ํฐ ์ฌ์ด์ ์์น ๋ณํ์ ๋ฐ๋ผ ๋ฐ์ํ๋ ํ์์ ์ธ ๋งฅ๋ฝ ์ ๋ณด๋ฅผ ๋ด์ ํ๋ ฌ์ด๋ผ๊ณ ์ ์ํ ๋ฐ ์๋ค. ๋ค์ ๋งํด, ๋์ ํ ํฐ์ ์๋ฏธ๋ฅผ ์ฃผ๋ณ์ ์ด๋ค context
๊ฐ ์๋์ง ํ์
ํด ํตํด ์ดํดํด๋ณด๊ฒ ๋ค๋ ๊ฒ์ด๋ค.
์์ ๋ฌธ์ฅ์ ๋ค์ ๋ณด์. ๋ ๋์ ๋จ์ด ๋ชจ๋ ์ฃผ์์ ๋น์ทํ ์๋ฏธ๋ฅผ ๊ฐ๋ ๋จ์ด๋ค์ด ์์นํด ์๋ค. ์ด๋ฐ ๊ฒฝ์ฐ ์๋ ์์น ์๋ฒ ๋ฉ๋ง์ผ๋ก๋ ์ํ์ค ๋ด๋ถ์์ store
์ mall
์ ์๋ฏธ ์ฐจ์ด๋ฅผ ๋ชจ๋ธ์ด ๋ช
ํํ๊ฒ ์ดํดํ๊ธฐ ๋งค์ฐ ์ด๋ ค์ธ ๊ฒ์ด๋ค. ํ์ฌ ์ํฉ์์ ๋ ๋จ์ด์ ๋์์ค ์ฐจ์ด๋ ๊ฒฐ๊ตญ ๋ฌธ์ฅ์ ์ฃผ์ด๋ ๋ชฉ์ ์ด๋ ํ๋ syntactical
ํ ์ ๋ณด์ ์ํด์ ๊ฒฐ์ ๋๋ค. syntactical
ํ ์ ๋ณด์ ํ์์ฑ์ ๋ฐ๋ก ์ ๋ ์์น ์๋ฒ ๋ฉ์ด NLU
์ ๊ผญ ํ์ํ ์ด์ ์ ๋์๋๋ค.
Enhanced Mask Decoder Overview
๐คย why named decoder
ํ์๋ ์ฒ์ ๋
ผ๋ฌธ์ ์ฝ์์ ๋ Decoder
๋ผ๋ ์ด๋ฆ์ ๋ณด๋ฉด์ ์ฐธ ์์ํ๋ค. ๋ถ๋ช
Only-Encoder
๋ชจ๋ธ๋ก ์๊ณ ์๋๋ฐ ์ด์ฐํ์ฌ ์ด๋ฆ์ ๋์ฝ๋๊ฐ ๋ถ๋ ๋ชจ๋์ด ์๋ ๊ฒ์ธ๊ฐ. ๊ทธ๋ ๋ค๊ณ ์ด๋ฆ์ ์ ๋ ๊ฒ ๋ถ์ธ ์๋๋ฅผ ์ค๋ช
ํ๋ ๊ฒ๋ ์๋๋ค. ๊ทธ๋์ ํ์๊ฐ ์ค์ค๋ก ์ถ์ธกํด๋ดค๋ค.
DeBERTa
๋ pre-train task
๋ก MLM
์ ์ฌ์ฉํ๋ค. MLM
์ด ๋ฌด์์ธ๊ฐ?? ๋ฐ๋ก ๋ง์คํน๋ ์๋ฆฌ์ ์ ์ ํ ํ ํฐ์ ์ฐพ๋ ๋น์นธ ์ฑ์ฐ๊ธฐ ๋ฌธ์ ๋ค. ์๋ฏธ๊ถ์์๋ ์ด๊ฒ์ denoising
ํ๋ค๊ณ ํํํ๊ธฐ๋ ํ๋๋ฐ, Absolute Position Embedding
์ด ๋ฐ๋ก ์ด denoising
์ ์ง๋ํ ์ํฅ๋ ฅ์ ๋ฏธ์น๋ค๋ ์ธ๊ธ์ ๋
ผ๋ฌธ์์ ์ฐพ์๋ณผ ์ ์๋ค. ๋ฐ๋ผ์ denoising
์ฑ๋ฅ์ ํฐ ์ํฅ์ ์ฃผ๋ Absolute Position Embedding
์ ํ์ฉํ๋ค๊ณ ํด์ ์ด๋ฆ์ decoder
๋ฅผ ๋ถ์์ง ์์๋ ์์ํด๋ณธ๋ค.
๋
ผ๋ฌธ์ ๊ฐ์ด ์ค๋ฆฐ ๊ทธ๋ฆผ์ ํตํด์๋ ์ถ์ธก์ด ๊ฐ๋ฅํ๋ค. EMD
์ ๊ตฌ์กฐ๋ฅผ ์ค๋ช
ํ๋ฉด์ ์์ BERT์ ๋ชจ์๋๋ ํจ๊ป ์ ๊ณตํ๋๋ฐ, BERT
์๋ Decoder
๊ฐ ์ ํ ์๋ค. ๊ทธ๋ฐ๋ฐ๋ ์ด๋ฆ์ BERT decoding layer
๋ผ๊ณ ๋ถ๋ฅด๋ ๊ฒ๋ณด๋ฉด ํ์์ ์ถ์ธก์ ์ข ๋ ์ ๋น์ฑ์ ๋ถ์ฌํ๋ ๊ฒ ๊ฐ๋ค.
(+ ์ถ๊ฐ) offical repo code์์๋ EMD
๊ฐ ์ฐ๋ฆฌ๊ฐ ์๋ ๊ทธ Encoder
๋ฅผ ์ฌ์ฉํ๋ค๋ ์ฌ์ค์ ํ์ธํ ์ ์๋ค.
๐คทโโ๏ธย How to add Absolute Position
์ด์ EMD
๊ฐ ๋ฌด์์ด๋ฉฐ, Absolute Position
์ ์ด๋ป๊ฒ ๋ชจ๋ธ์ ์ถ๊ฐํ๋์ง ์์๋ณด์. EMD
๋ MLM
์ฑ๋ฅ์ ๋์ด๊ธฐ ์ํด ๊ณ ์๋ ๊ตฌ์กฐ๋ค. ๊ทธ๋์ ํ ํฐ ์์ธก์ ์ํ feedforward & softmax
๋ ์ด์ด ์ง์ ์ ์๋๋ค. ๋ช๊ฐ์ EMD
๋ฅผ ์์ ๊ฒ์ธ์ง๋ ํ์ดํผํ๋ฆฌ๋ฏธํฐ์ด๋ฉฐ, ์ ์์ ์คํ ๊ฒฐ๊ณผ 2
๊ฐ ์ฌ์ฉํ๋๊ฒ ๊ฐ์ฅ ํจ์จ์ ์ด๋ผ๊ณ ํ๋ค. ์๋กญ๊ฒ ์ธ์ฝ๋ ๋ธ๋ญ์ ์์ง ์๊ณ Disentangled-Attention
๋ ์ด์ด์ ๊ฐ์ฅ ๋ง์ง๋ง ์ธ์ฝ๋ ๋ธ๋ญ๊ณผ ๊ฐ์ค์น๋ฅผ ๊ณต์ ํ๋ ํํ๋ก ๊ตฌํํ๋ค.
# EMD Implementation Example
class EnhancedMaskDecoder(nn.Module):
def __init__(self, encoder: list[nn.ModuleList], dim_model: int = 1024) -> None:
super(EnhancedMaskDecoder, self).__init__()
self.emd_layers = encoder
class DeBERTa(nn.Module):
def __init__(self, vocab_size: int, max_seq: 512, N: int = 24, N_EMD: int = 2, dim_model: int = 1024, num_heads: int = 16, dim_ffn: int = 4096, dropout: float = 0.1) -> None:
# Init Sub-Blocks & Modules
self.encoder = DeBERTaEncoder(
self.max_seq,
self.N,
self.dim_model,
self.num_heads,
self.dim_ffn,
self.dropout_prob
)
self.emd_layers = [self.encoder.encoder_layers[-1] for _ in range(self.N_EMD)]. # weight share
self.emd_encoder = EnhancedMaskDecoder(self.emd_layers, self.dim_model)
๋ฐ๋ผ์ N_EMD=2
๋ก ์ค์ ํ๋ค๋ ๊ฒ์ ๊ฒฐ๊ตญ, Disentangled-Attention
๋ ์ด์ด์ ๊ฐ์ฅ ๋ง์ง๋ง ์ธ์ฝ๋ ๋ธ๋ญ์ 2๊ฐ ๋ ์๋ ๊ฒ๊ณผ ๋์น๋ค. ๋์ ์ธ์ฝ๋์ linear projection
๋ ์ด์ด์ ์
๋ ฅ๊ฐ์ด ๋ค๋ฅด๋ค. Disentangled-Attention
์ ํ๋ ฌ $Q^c, K^c, V^c$๋ ์ด์ ๋ธ๋ญ์ hidden_states
๊ฐ์ธ ํ๋ ฌ $H$๋ฅผ ์
๋ ฅ์ผ๋ก, ํ๋ ฌ $Q^r, K^r$์ ๋ ์ด์ด์ ์์น์ ์๊ด์์ด ๋ชจ๋ ๊ฐ์ ๊ฐ์ Relative Position Embedding
์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ค.
๋ฐ๋ฉด, EMD
๋งจ ์ฒ์ ์ธ์ฝ๋ ๋ธ๋ญ์ ํ๋ ฌ $Q^c$๋ ๋ฐ๋ก ์ง์ ๋ธ๋ญ์ hidden_states
์ Absolute Position Embedding
์ ๋ํ ๊ฐ์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ค. ์ดํ ๋๋จธ์ง ๋ธ๋ญ์๋ Disentangled-Attention
์ ๋ง์ฐฌ๊ฐ์ง๋ก ์ด์ ๋ธ๋ญ์ hidden_states
๋ฅผ ์ฌ์ฉํ๋ค. ํ๋ ฌ $K^c, V^c$๋ ๋ธ๋ญ ์์์ ์๊ด์์ด ์ด์ ๋ธ๋ญ์ hidden_states
๋ง ๊ฐ์ง๊ณ linear projection
์ ์ํํ๋ค. ๊ทธ๋ฆฌ๊ณ ํ๋ ฌ $Q^r, K^r$ ์ญ์ ๊ฐ์ ๊ฐ์ Relative Position Embedding
์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ค.
์ฌ์ค ํ์๋ ๋
ผ๋ฌธ๋ง ์ฝ์์ ๋ EMD
๋ Relative Position
์ ๋ณด๋ฅผ ์ฃผ์
ํด Disengtanled-Attention
์ ์ํํ๋ค๊ณ ์ ํ ์๊ฐํ์ง ๋ชปํ๋ค. ์ด๋ ๋
ผ๋ฌธ์ ์ค๋ช
์ด ์๋นํ ๋ถ์น์ ํ ๋๋ถ์ธ๋ฐ, ๋
ผ๋ฌธ์ ์ด์ ๊ด๋ จํด์ ์์ธํ ์ค๋ช
๋ ์๊ณ Absolute Position Embedding
์ ์ฌ์ฉํ๋ ๋ ์ด์ด๋ผ์ ๋น์ฐํ ์ผ๋ฐ์ ์ธ Self-Attention
์ ์ฌ์ฉํ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ๋ ๊ฒ์ด๋ค.
ํ์๋ ์ฌ๊ธฐ์ Absolute Position
์ด ์ ํ์ํ์ง๋ ์๊ฒ ๊ณ ๊ทธ๋์ ํ๋ ฌํฉ์ผ๋ก ๋ํด์ ์ดํ
์
์ ์ํํ๋ ๊ฒ๋ ์ ์๊ฒ ๋๋ฐ ์ ๊ตณ์ด ๊ฐ์ฅ ๋ง์ง๋ง ๋ ์ด์ด์์ ์ด๊ฑธ ํ ๊น?? ํ๋ ์๋ฌธ์ด ๋ค์๋ค. ์ผ๋ฐ Self-Attention
์ฒ๋ผ ๋งจ ์ฒ์์ ๋ํ๊ณ ์์ํ๋ฉด ์๋ ๊น??
์ ์์ ์คํ์ ๋ฐ๋ฅด๋ฉด Absolute Position
์ ์ฒ์์ ์ถ๊ฐํ๋ ๊ฒ๋ณด๋ค EMD
์ฒ๋ผ ๊ฐ์ฅ ๋ง์ง๋ง์ ๋ํด์ฃผ๋๊ฒ ์ฑ๋ฅ์ด ๋ ์ข์๋ค๊ณ ํ๋ค. ๊ทธ ์ด์ ๋ก Absolute Position
๋ฅผ ์ด๋ฐ์ ์ถ๊ฐํ๋ฉด ๋ชจ๋ธ์ด Relative Position
์ ํ์ตํ๋๋ฐ ๋ฐฉํด๊ฐ ๋๋ ๊ฒ ๊ฐ๋ค๋ ์ถ์ธก์ ํจ๊ป ์์ ํ๊ณ ์๋ค. ๊ทธ๋ ๋ค๋ฉด ์ ๋ฐฉํด๊ฐ ๋๋ ๊ฒ์ผ๊น??
ํ์์ ๋ํผ์
์ด์ง๋ง ์ด๊ฒ ์ญ์ blessing of dimensionality
์์ ํ์๋ ๋ฌธ์ ๋ผ๊ณ ์๊ฐํ๋ค. ์ผ๋จ ์ฉ์ด์ ๋ป๋ถํฐ ์์๋ณด์. blessing of dimensionality
๋, ๊ณ ์ฐจ์ ๊ณต๊ฐ์์ ๋ฌด์์๋ก ์๋ก ๋ค๋ฅธ ๋ฒกํฐ ๋๊ฐ๋ฅผ ์ ํํ๋ฉด ๋ ๋ฒกํฐ๋ ๊ฑฐ์ ๋๋ถ๋ถ approximate orthogonality
๋ฅผ ๊ฐ๋ ํ์์ ์ค๋ช
ํ๋ ์ฉ์ด๋ค. ๋ฌด์กฐ๊ฑด ์ฑ๋ฆฝํ๋ ์ฑ์ง์ ์๋๊ณ ํ๋ฅ ๋ก ์ ์ธ ์ ๊ทผ์ด๋ผ๋ ๊ฒ์ ๋ช
์ฌํ์. ์๋ฌดํผ ์ง๊ตํ๋ ๋ ๋ฒกํฐ๋ ๋ด์ ๊ฐ์ด 0์ ์๋ ดํ๋ค. ์ฆ, ๋ ๋ฒกํฐ๋ ์๋ก์๊ฒ ์ํฅ์ ๋ฏธ์น์ง ๋ชปํ๋ค๋ ๊ฒ์ด๋ค.
์ด๊ฒ์ Transformer
์์ Input Embedding
๊ณผ Absolute Position Embedding
์ ํ๋ ฌํฉ์ผ๋ก ๋ํด๋ ์ข์ ํ์ต ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์๋ ์ด์ ๊ฐ ๋๋ค. ๋ค์ ๋งํด์, hidden states space
์์ Input Embedding
๊ณผ Absolute Position Embedding
์ญ์ ๊ฐ๋ณ ๋ฒกํฐ๊ฐ span
ํ๋ ๋ถ๋ถ ๊ณต๊ฐ ๋ผ๋ฆฌ๋ ์๋ก ์ง๊ตํ ๊ฐ๋ฅ์ฑ์ด ๋งค์ฐ ๋๋ค๋ ๊ฒ์ ์๋ฏธํ๋ค. ๋ฐ๋ผ์ ์๋ก ๋ค๋ฅธ ์ถ์ฒ๋ฅผ ํตํด ๋ง๋ค์ด์ง ๋ ํ๋ ฌ์ ๋ํด๋ ์๋ก์๊ฒ ์ํฅ์ ๋ฏธ์น์ง ๋ชปํ ๊ฒ์ด๊ณ ๊ทธ๋ก ์ธํด ๋ชจ๋ธ์ด Input
๊ณผ Position
์ ๋ณด๋ฅผ ๋ฐ๋ก ์ ํ์ตํ ์ ์์ ๊ฒ์ด๋ผ ๊ธฐ๋ํด๋ณผ ์ ์๋ค.
hidden states vector space example
์ด์ ๋ค์ DeBERTa
๊ฒฝ์ฐ๋ก ๋์์๋ณด์. ์ ๊ทธ๋ฆผ์ ํ๋์ ์ง์ ์ Input Embedding
, ๋นจ๊ฐ์ ์ง์ ์ Absolute Position Embedding
, ์ผ์ชฝ์ ๋ณด๋ผ์ ์ง์ ์ Relative Position Embedding
์ด๋ผ๊ณ ๊ฐ์ ํ์. blessing of dimensionality
์ ์ํด word con text
์ ๋ณด(ํ๋ ์ง์ )์ position
์ ๋ณด(๋นจ๊ฐ, ๋ณด๋ผ ์ง์ )๋ ๊ทธ๋ฆผ์ฒ๋ผ ์๋ก ๊ทผ์ฌ ์ง๊ตํ ๊ฐ๋ฅ์ฑ์ด ๋งค์ฐ ๋๋ค. ์ฌ๊ธฐ๋ถํฐ ํ์์ ๋ํผ์
์ด ๋ค์ด๊ฐ๋๋ฐ, ๋ณด๋ผ์ ์ง์ ๊ณผ ๋นจ๊ฐ์ ์ง์ ์ ์ฑ๊ฒฉ์ด ์ข ๋ค๋ฅด์ง๋ง ๊ฒฐ๊ตญ ๋ ๋ค ์ํ์ค์ position
์ ๋ณด๋ฅผ ๋ํ๋ธ๋ค๋ ์ ์์ ๋ฟ๋ฆฌ๋ ๊ฐ๋ค๊ณ ๋ณผ ์ ์๋ค. ๋ฐ๋ผ์ ์ค์ hidden states
๊ณต๊ฐ์์ ์ด๋ค ์์ผ๋ก ๋งตํ๋ ์ง๋ ์ ๋ชจ๋ฅด๊ฒ ์ง๋ง, ์๋ก ์ง๊ตํ๋ ํํ๋ ์๋ ๊ฒ์ด๋ผ ์ถ์ธกํ ์ ์๋ค.
๊ทธ๋ ๋ค๋ฉด Absolute Position
์ ๋ชจ๋ธ ๊ทน์ด๋ฐ์ ๋ํด์ค๋ค๊ณ ์๊ฐํด๋ณด์. ์ธ์ฝ๋์ ๋ค์ด๊ฐ๋ ํ๋ ฌ์ ๊ฒฐ๊ตญ ์ ๊ทธ๋ฆผ์ ์ด๋ก์ ์ง์ ์ผ๋ก ํํ๋ ๊ฒ์ด๋ค. ํ๋์ ์ง์ ๊ณผ ๋นจ๊ฐ์ ์ง์ ์ด ๊ทผ์ฌ ์ง๊ตํ๋ค๋ ๊ฐ์ ํ์ ๋ ๋ฐฑํฐ์ ํฉ์ ๋ ๋ฒกํฐ์ 45๋ ์ ๋ ๋๋ ๊ณณ์ ์์นํ๊ฒ(์ด๋ก์ ์ง์ ) ๋ ๊ฒ์ด๋ค. ๊ทธ๋ ๋ค๋ฉด ๋ณด๋ผ์ ์ง์ ๊ณผ ์ด๋ก์ ์ง์ ์ ๊ด๊ณ ์ญ์ ๊ทผ์ฌ ์ง๊ต์์ ์๋ก ๊ฐ์ญํ๋ ํํ๋ก ๋ณํํ๋ค. ๋ฐ์ EMD
๋ฅผ ๊ทน์ด๋ฐ์ ์ฌ์ฉํ๋ฉด ๊ฐ์ญ์ด ๋ฐ์ํด ๋ชจ๋ธ์ด Relative Position
์ ๋ณด๋ฅผ ์ ๋๋ก ํ์ตํ์ง ๋ชปํ ๊ฒ์ด๋ค.
๐ฉโ๐ปย Implementation
# Pytorch Implementation of DeBERTa Enhanced Mask Decoder
class EnhancedMaskDecoder(nn.Module):
"""
Class for Enhanced Mask Decoder module in DeBERTa, which is used for Masked Language Model (Pretrain Task)
Word 'Decoder' means that denoise masked token by predicting masked token
In official paper & repo, they might use 2 EMD layers for MLM Task
And this layer's key & value input is output from last disentangled self-attention encoder layer,
Also, all of them can share parameters and this layer also do disentangled self-attention
In official repo, they implement this layer so hard coding that we can't understand directly & easily
So, we implement this layer with our own style, as closely as possible to paper statement
Notes:
Also we temporarily implement only extract token embedding, not calculating logit, loss for MLM Task yet
MLM Task will be implemented ASAP
Args:
encoder: list of nn.ModuleList, which is (N_EMD * last encoder layer) from DeBERTaEncoder
References:
https://arxiv.org/abs/2006.03654
https://arxiv.org/abs/2111.09543
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/apps/models/masked_language_model.py
"""
def __init__(self, encoder: list[nn.ModuleList], dim_model: int = 1024) -> None:
super(EnhancedMaskDecoder, self).__init__()
self.emd_layers = encoder
self.layer_norm = nn.LayerNorm(dim_model)
def emd_context_layer(self, hidden_states, abs_pos_emb, rel_pos_emb, mask):
outputs = []
query_states = hidden_states + abs_pos_emb # "I" in official paper,
for emd_layer in self.emd_layers:
query_states = emd_layer(x=hidden_states, pos_x=rel_pos_emb, mask=mask, emd=query_states)
outputs.append(query_states)
last_hidden_state = self.layer_norm(query_states) # because of applying pre-layer norm
hidden_states = torch.stack(outputs, dim=0).to(hidden_states.device)
return last_hidden_state, hidden_states
def forward(self, hidden_states: Tensor, abs_pos_emb: Tensor, rel_pos_emb, mask: Tensor) -> tuple[Tensor, Tensor]:
"""
hidden_states: output from last disentangled self-attention encoder layer
abs_pos_emb: absolute position embedding
rel_pos_emb: relative position embedding
"""
last_hidden_state, emd_hidden_states = self.emd_context_layer(hidden_states, abs_pos_emb, rel_pos_emb, mask)
return last_hidden_state, emd_hidden_states
emd_context_layer
๋ฉ์๋์์ Absolute Position
์ ๋ณด๋ฅผ ์ถ๊ฐํด์ฃผ๋ ๋ถ๋ถ์ ์ ์ธํ๋ฉด ์ผ๋ฐ Encoder
๊ฐ์ฒด์ ๋์๊ณผ ๋์ผํ๋ค. ๋ํ DeBERTa๋ ๋ชจ๋ ๋ ์ด์ด๊ฐ ๊ฐ์ ์์ ์ forward pass ๋, ๋์ผํ ๊ฐ์ค์น์ Relative Position Embedding
์ ์ฌ์ฉํด์ผ ํ๋๋ฐ, EMD
์ญ์ ์์ธ๋ ์๋๊ธฐ ๋๋ฌธ์ ๋ฐ๋์ ์ต์์ ๊ฐ์ฒด์์ ์ด๊ธฐํํ Relative Position Embedding
์ ๋๊ฐ์ด ๋งค๊ฐ๋ณ์๋ก ์ ๋ฌํด์ค์ผ ํ๋ค.
๊ทธ๋ฆฌ๊ณ ๋ง์ง๋ง์ผ๋ก ๊ฐ์ฒด์์ ์ฌ์ฉํ๋ emd_layers
๋ ๋ชจ๋ Disentangled-Attention
๋ ์ด์ด์ ๊ฐ์ฅ ๋ง์ง๋ง ์ธ์ฝ๋๋ผ๋ ์ฌ์ค์ ์์ง ๋ง์.
๐ฉโ๐ฉโ๐งโ๐ฆย Multi-Head Attention
์ด์ ๋๋จธ์ง ๋ธ๋ญ๋ค์ ๋ํด์ ์ดํด๋ณด๊ฒ ๋ค. ์๋ฆฌ๋ ์๋ฏธ๋ ์ด๋ฏธ Transformer
๋ฆฌ๋ทฐ์์ ๋ชจ๋ ์ดํด๋ดค๊ธฐ ๋๋ฌธ์ ์๋ตํ๊ณ , ๊ตฌํ์ ํน์ด์ ์ ๋ํด์๋ง ์ธ๊ธํ๋ ค๊ณ ํ๋ค. ๋จผ์ Single-Head Atttention
์ฝ๋๋ฅผ ๋ณด์.
# Pytorch Implementation of Single Attention Head
class MultiHeadAttention(nn.Module):
""" In this class, we implement workflow of Multi-Head Self-attention for DeBERTa-Large
This class has same role as Module "BertAttention" in official Repo (bert.py)
In official repo, they use post-layer norm, but we use pre-layer norm which is more stable & efficient for training
Args:
dim_model: dimension of model's latent vector space, default 1024 from official paper
num_attention_heads: number of heads in MHSA, default 16 from official paper for Transformer
dim_head: dimension of each attention head, default 64 from official paper (1024 / 16)
attention_dropout_prob: dropout rate, default 0.1
Math:
attention Matrix = c2c + c2p + p2c
A = softmax(attention Matrix/sqrt(3*D_h)), SA(z) = Av
Reference:
https://arxiv.org/abs/1706.03762
https://arxiv.org/abs/2006.03654
"""
def __init__(self, dim_model: int = 1024, num_attention_heads: int = 12, dim_head: int = 64,
attention_dropout_prob: float = 0.1) -> None:
super(MultiHeadAttention, self).__init__()
self.dim_model = dim_model
self.num_attention_heads = num_attention_heads
self.dim_head = dim_head
self.fc_q = nn.Linear(self.dim_model, self.dim_model)
self.fc_k = nn.Linear(self.dim_model, self.dim_model)
self.fc_v = nn.Linear(self.dim_model, self.dim_model)
self.fc_qr = nn.Linear(self.dim_model, self.dim_model) # projector for Relative Position Query matrix
self.fc_kr = nn.Linear(self.dim_model, self.dim_model) # projector for Relative Position Key matrix
self.fc_concat = nn.Linear(self.dim_model, self.dim_model)
self.attention = disentangled_attention
self.attention_dropout = nn.Dropout(p=attention_dropout_prob)
def forward(self, x: Tensor, rel_pos_emb: Tensor, padding_mask: Tensor, attention_mask: Tensor = None,
emd: Tensor = None) -> Tensor:
""" x is already passed nn.Layernorm """
assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'
# size: bs, seq, nums head, dim head, linear projection
q = self.fc_q(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
k = self.fc_k(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 3, 1).contiguous()
v = self.fc_v(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
qr = self.fc_qr(rel_pos_emb).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
kr = self.fc_kr(rel_pos_emb).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head)
if emd is not None:
q = self.fc_q(emd).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
attention_matrix = self.attention(
q,
k,
v,
qr,
kr,
self.attention_dropout,
padding_mask,
attention_mask
)
attention_output = self.fc_concat(attention_matrix)
return attention_output
๋์ ์์ฒด๋ ๋์ผํ์ง๋ง, ์๋ ์์น ์ ๋ณด์ ๋ํ linear projection
๋ ์ด์ด๊ฐ ์ถ๊ฐ ๋์๋ค. ๊ทธ๋ฆฌ๊ณ Enhanced Mask Decoder
๋ฅผ ์ํด forward
๋ฉ์๋์ ์กฐ๊ฑด๋ฌธ์ ํ์ฉํ์ฌ Decoding
ํ๋ ์์ ์๋ hidden_states + absolute position embedding
์ผ๋ก ํ๋ ฌ $Q^c$๋ฅผ ํํํ๊ฒ ๊ตฌํํ๋ค. ์ด๋ ๊ฒ ๊ตฌํํ๋ฉด EMD
๋ฅผ ์ํด ๋ฐ๋ก AttentionHead
๋ฅผ ๊ตฌํํ ํ์๊ฐ ์์ด์ ์ฝ๋ ๊ฐ์ํ๊ฐ ๋๋ค.
MultiHeadAttention
๊ฐ์ฒด๋ ๋จ์ผ AttentionHead
๊ฐ์ฒด๋ฅผ ํธ์ถํ ๋ rel_pos_emb
๋ฅผ ๋งค๊ฐ๋ณ์๋ก ์ ๋ฌํด์ผ ํ๋ค๋ ์ ๋ง ๊ธฐ์ตํ๋ฉด ๋๋ค.
๐ฌย Feed Forward Network
# Pytorch Implementation of FeedForward Network
class FeedForward(nn.Module):
"""
Class for Feed-Forward Network module in transformer
In official paper, they use ReLU activation function, but GELU is better for now
We change ReLU to GELU & add dropout layer
Args:
dim_model: dimension of model's latent vector space, default 512
dim_ffn: dimension of FFN's hidden layer, default 2048 from official paper
dropout: dropout rate, default 0.1
Math:
FeedForward(x) = FeedForward(LN(x))+x
"""
def __init__(self, dim_model: int = 512, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
super(FeedForward, self).__init__()
self.ffn = nn.Sequential(
nn.Linear(dim_model, dim_ffn),
nn.GELU(),
nn.Dropout(p=dropout),
nn.Linear(dim_ffn, dim_model),
nn.Dropout(p=dropout),
)
def forward(self, x: Tensor) -> Tensor:
return self.ffn(x)
์ญ์ ๊ธฐ์กด Transformer
, BERT
์ ๋ค๋ฅธ๊ฒ ์๋ค.
๐ย DeBERTaEncoderLayer
# Pytorch Implementation of DeBERTaEncoderLayer(single Disentangled-Attention Encoder Block)
class DeBERTaEncoderLayer(nn.Module):
"""
Class for encoder model module in DeBERTa-Large
In this class, we stack each encoder_model module (Multi-Head Attention, Residual-Connection, LayerNorm, FFN)
This class has same role as Module "BertEncoder" in official Repo (bert.py)
In official repo, they use post-layer norm, but we use pre-layer norm which is more stable & efficient for training
References:
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py
"""
def __init__(self, dim_model: int = 1024, num_heads: int = 16, dim_ffn: int = 4096, dropout: float = 0.1) -> None:
super(DeBERTaEncoderLayer, self).__init__()
self.self_attention = MultiHeadAttention(
dim_model,
num_heads,
int(dim_model / num_heads),
dropout,
)
self.layer_norm1 = nn.LayerNorm(dim_model)
self.layer_norm2 = nn.LayerNorm(dim_model)
self.dropout = nn.Dropout(p=dropout)
self.ffn = FeedForward(
dim_model,
dim_ffn,
dropout,
)
def forward(self, x: Tensor, pos_x: torch.nn.Embedding, mask: Tensor, emd: Tensor = None) -> Tensor:
""" rel_pos_emb is fixed for all layer in same forward pass time """
ln_x, ln_pos_x = self.layer_norm1(x), self.layer_norm1(pos_x) # pre-layer norm, weight share
residual_x = self.dropout(self.self_attention(ln_x, ln_pos_x, mask, emd)) + x
ln_x = self.layer_norm2(residual_x)
fx = self.ffn(ln_x) + residual_x
return fx
official code์ ๋ค๋ฅด๊ฒ pre-layernorm
์ ์ฌ์ฉํด ๊ตฌํํ๋ค. pre-layernorm
์ ๋ํด ๊ถ๊ธํ๋ค๋ฉด ์ฌ๊ธฐ๋ฅผ ํด๋ฆญํด ํ์ธํด๋ณด์.
๐ย DeBERTaEncoder
# Pytorch Implementation of DeBERTaEncoderr(N stacked DeBERTaEncoderLayer)
class DeBERTaEncoder(nn.Module, AbstractModel):
""" In this class, 1) encode input sequence, 2) make relative position embedding, 3) stack num_layers DeBERTaEncoderLayer
This class's forward output is not integrated with EMD Layer's output
Output have ONLY result of disentangled self-attention
All ops order is from official paper & repo by microsoft, but ops operating is slightly different,
Because they use custom ops, e.g. XDropout, XSoftmax, ..., we just apply pure pytorch ops
Args:
max_seq: maximum sequence length, named "max_position_embedding" in official repo, default 512, in official paper, this value is called 'k'
num_layers: number of EncoderLayer, default 24 for large model
Notes:
self.rel_pos_emb: P in paper, this matrix is fixed during forward pass in same time, all layer & all module must share this layer from official paper
References:
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/ops.py
"""
def __init__(
self,
cfg: CFG,
max_seq: int = 512,
num_layers: int = 24,
dim_model: int = 1024,
num_attention_heads: int = 16,
dim_ffn: int = 4096,
layer_norm_eps: float = 0.02,
attention_dropout_prob: float = 0.1,
hidden_dropout_prob: float = 0.1,
gradient_checkpointing: bool = False
) -> None:
super(DeBERTaEncoder, self).__init__()
self.cfg = cfg
self.max_seq = max_seq
self.num_layers = num_layers
self.dim_model = dim_model
self.num_attention_heads = num_attention_heads
self.dim_ffn = dim_ffn
self.hidden_dropout = nn.Dropout(p=hidden_dropout_prob) # dropout is not learnable
self.layer = nn.ModuleList(
[DeBERTaEncoderLayer(dim_model, num_attention_heads, dim_ffn, layer_norm_eps, attention_dropout_prob, hidden_dropout_prob) for _ in range(self.num_layers)]
)
self.layer_norm = nn.LayerNorm(dim_model, eps=layer_norm_eps) # for final-Encoder output
self.gradient_checkpointing = gradient_checkpointing
def forward(self, inputs: Tensor, rel_pos_emb: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tuple[Tensor, Tensor]:
"""
Args:
inputs: embedding from input sequence
rel_pos_emb: relative position embedding
padding_mask: mask for Encoder padded token for speeding up to calculate attention score or MLM
attention_mask: mask for CLM
"""
layer_output = []
x, pos_x = inputs, rel_pos_emb # x is same as word_embeddings or embeddings in official repo
for layer in self.layer:
if self.gradient_checkpointing and self.cfg.train:
x = self._gradient_checkpointing_func(
layer.__call__, # same as __forward__ call, torch reference recommend to use __call__ instead of forward
x,
pos_x,
padding_mask,
attention_mask
)
else:
x = layer(
x,
pos_x,
padding_mask,
attention_mask
)
layer_output.append(x)
last_hidden_state = self.layer_norm(x) # because of applying pre-layer norm
hidden_states = torch.stack(layer_output, dim=0).to(x.device) # shape: [num_layers, BS, SEQ_LEN, DIM_Model]
return last_hidden_state, hidden_states
EMD
์ ๋ง์ฐฌ๊ฐ์ง๋ก ๋ ์ด์ด์ ์์น์ ์๊ด์์ด ๊ฐ์ ์์ ์๋ ๋ชจ๋ ๋์ผํ Relative Position Embedding
์ ์ฌ์ฉํด linear projection
ํ๋๋ก ๊ตฌํํด์ฃผ๋ ๊ฒ์ด ์ค์ ํฌ์ธํธ๋ค. forward
๋ฉ์๋๋ฅผ ํ์ธํ์!
๐คย DeBERTa
# Pytorch Implementation of DeBERTa
class DeBERTa(nn.Module, AbstractModel):
""" Main class for DeBERTa, having all of sub-blocks & modules such as Disentangled Self-attention, DeBERTaEncoder, EMD
Init Scale of DeBERTa Hyper-Parameters, Embedding Layer, Encoder Blocks, EMD Blocks
And then make 3-types of Embedding Layer, Word Embedding, Absolute Position Embedding, Relative Position Embedding
Args:
cfg: configuration.CFG
num_layers: number of EncoderLayer, default 12 for base model
this value must be init by user for objective task
if you select electra, you should set num_layers twice (generator, discriminator)
Var:
vocab_size: size of vocab in DeBERTa's Native Tokenizer
max_seq: maximum sequence length
max_rel_pos: max_seq x2 for build relative position embedding
num_layers: number of Disentangled-Encoder layers
num_attention_heads: number of attention heads
num_emd: number of EMD layers
dim_model: dimension of model
num_attention_heads: number of heads in multi-head attention
dim_ffn: dimension of feed-forward network, same as intermediate size in official repo
hidden_dropout_prob: dropout rate for embedding, hidden layer
attention_probs_dropout_prob: dropout rate for attention
References:
https://arxiv.org/abs/2006.03654
https://arxiv.org/abs/2111.09543
https://github.com/microsoft/DeBERTa/blob/master/experiments/language_model/deberta_xxlarge.json
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/config.py
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/deberta.py
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/disentangled_attention.py
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/apps/models/masked_language_model.py
"""
def __init__(self, cfg: CFG, num_layers: int = 12) -> None:
super(DeBERTa, self).__init__()
# Init Scale of DeBERTa Module
self.cfg = cfg
self.vocab_size = cfg.vocab_size
self.max_seq = cfg.max_seq
self.max_rel_pos = 2 * self.max_seq
self.num_layers = num_layers
self.num_attention_heads = cfg.num_attention_heads
self.num_emd = cfg.num_emd
self.dim_model = cfg.dim_model
self.dim_ffn = cfg.dim_ffn
self.layer_norm_eps = cfg.layer_norm_eps
self.hidden_dropout_prob = cfg.hidden_dropout_prob
self.attention_dropout_prob = cfg.attention_probs_dropout_prob
self.gradient_checkpointing = cfg.gradient_checkpoint
# Init Embedding Layer
self.embeddings = Embedding(cfg)
# Init Encoder Blocks & Modules
self.encoder = DeBERTaEncoder(
self.cfg,
self.max_seq,
self.num_layers,
self.dim_model,
self.num_attention_heads,
self.dim_ffn,
self.layer_norm_eps,
self.attention_dropout_prob,
self.hidden_dropout_prob,
self.gradient_checkpointing
)
self.emd_layers = [self.encoder.layer[-1] for _ in range(self.num_emd)]
self.emd_encoder = EnhancedMaskDecoder(
self.cfg,
self.emd_layers,
self.dim_model,
self.layer_norm_eps,
self.gradient_checkpointing
)
def forward(self, inputs: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tuple[Tensor, Tensor]:
"""
Args:
inputs: input sequence, shape (batch_size, sequence)
padding_mask: padding mask for MLM or padding token
attention_mask: attention mask for CLM, default None
"""
assert inputs.ndim == 2, f'Expected (batch, sequence) got {inputs.shape}'
word_embeddings, rel_pos_emb, abs_pos_emb = self.embeddings(inputs)
last_hidden_state, hidden_states = self.encoder(word_embeddings, rel_pos_emb, padding_mask, attention_mask)
emd_hidden_states = hidden_states[-self.cfg.num_emd]
emd_last_hidden_state, emd_hidden_states = self.emd_encoder(emd_hidden_states, abs_pos_emb, rel_pos_emb, padding_mask, attention_mask)
return emd_last_hidden_state, emd_hidden_states
Leave a comment