
🔭 Overview

DeBERTa is a neural network model for natural language processing introduced by Microsoft in 2020 (published at ICLR 2021). By applying two new techniques, Disentangled Self-Attention and the Enhanced Mask Decoder, on top of BERT and RoBERTa, it achieved SOTA at the time, and it remains popular because it performs particularly well on inflectional languages such as English, where a word's meaning and form depend on its position in the sentence. In addition, since its maximum encodable sequence length is very long at 4096 (DeBERTa-V3-Large), it is frequently used in Kaggle competitions. The fact that it has stayed near the top of the SuperGLUE leaderboard for more than two years after release also shows how well-designed DeBERTa is.

Meanwhile, DeBERTa's design philosophy is inductive bias. Simply put, an inductive bias is a set of assumptions, weights, or hypotheses injected into a machine learning algorithm to improve generalization on the given data, such as "the input data will behave like ~" or "the data will have such-and-such characteristics." As discussed in the ViT paper review, pure Self-Attention has virtually no inductive bias, and even at the level of the full Transformer architecture, injecting token position information through the Absolute Position Embedding is about the only, and a rather weak, inductive bias. In other posts I clearly argued that Transformers succeeded in NLP precisely because they have little inductive bias, so this may look like I am now reversing that claim. However, once we recall what Self-Attention and Absolute Position Embedding actually mean, the authors' argument for adding an inductive bias turns out to be quite reasonable. Before looking at the concrete model architecture, let's first examine why an additional inductive bias is needed and what kind of assumption it encodes.

🪢 Inductive Bias in DeBERTa

  • Use both Absolute Position and Relative Position to extract richer, deeper embeddings
  • Designed to extract both word-order embeddings and distributional-hypothesis embeddings

The abstract of the paper contains the following sentences.

motivated by the observation that the attention weight of a word pair depends on not only their contents but their relative positions. For example, the dependency between the words โ€œdeepโ€ and โ€œlearningโ€ is much stronger when they occur next to each other than when they occur in different sentences.

์œ„์˜ ๋‘ ๋ฌธ์žฅ์ด DeBERTa์˜ Inducitve Bias ๋ฅผ ๊ฐ€์žฅ ์ž˜ ์„ค๋ช…ํ•˜๊ณ  ์žˆ๋‹ค๊ณ  ์ƒ๊ฐํ•œ๋‹ค. ์ €์ž๊ฐ€ ์ถ”๊ฐ€๋ฅผ ์ฃผ์žฅํ•˜๋Š” Inductive Bias๋ž€, relative position ์ •๋ณด๋ผ๋Š” ๊ฒƒ๊ณผ ๊ธฐ์กด ๋ชจ๋ธ๋ง์œผ๋กœ๋Š” relative position์ด ์ฃผ๋Š” ๋ฌธ๋งฅ ์ •๋ณด ํฌ์ฐฉ์ด ๋ถˆ๊ฐ€๋Šฅํ•˜๋‹ค๋Š” ์‚ฌ์‹ค์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค.

Then what exactly is the contextual information provided by relative position that the existing approach cannot capture? Let's start by organizing the kinds of context that can be captured in natural language and the existing modeling approaches. Here, "existing approach" means Transformer-encoder-based models (BERT, RoBERTa) that use pure Self-Attention and Absolute Position Embedding. In this post I will use BERT as the reference.

📚 Types of Embedding

First, let's organize the existing embedding techniques (methods for injecting context into vectors). They can be grouped into the following three categories.

  • 1) Word frequency: measure how often each token occurs in a sequence (Bag of Words)
  • 2) Word order: count how often particular sequences appear in a corpus (N-gram), or predict the token that comes next given a sequence (LM)
  • 3) Distributional hypothesis: the assumption that a word's meaning is determined by its surrounding context; count which word pairs co-occur frequently and measure PMI (Word2Vec)

Where does the existing modeling approach belong? BERT is, broadly speaking, a neural network, and given that it learns sequences through language modeling and uses Self-Attention with Absolute Position Embedding, it can be placed in category 2, word order. It may seem puzzling that the use of Absolute Position Embedding and Self-Attention is the evidence that plain BERT falls under category 2. But think it through.

Absolute Position Embedding measures the length of the given sequence and then assigns the numbers 0 through length-1 to the individual tokens, forward, in exactly the order they appear. In other words, it mathematically expresses the order in which words occur in the sequence and injects that into the model. Self-Attention then processes the entire sequence, with the absolute position information injected, in parallel at once. So it is entirely reasonable to classify Self-Attention plus Absolute-Position models like BERT under category 2.

Meanwhile, someone might ask, "BERT uses MLM, so is it right to say it does Language Modeling?" But MLM also belongs, at the top level, to the family of language modeling techniques. Since it grasps context bidirectionally while doing LM, strictly speaking it is not a stretch to say it also has a bit of the category-3 property. Because MLM captures more information when building embeddings, I suspect this is why early BERT had a relative advantage over GPT on NLU.
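As a minimal sketch of what that fill-in-the-blank objective looks like (my own toy example with made-up token ids and a 15% masking rate, not DeBERTa's actual pre-training code):

# MLM in one breath: hide ~15% of the tokens, then train the model to recover them

>>> import torch
>>> input_ids = torch.randint(5, 30000, (1, 10))   # pretend these are real token ids
>>> labels = input_ids.clone()
>>> mask = torch.rand(input_ids.shape) < 0.15      # ~15% of positions get masked
>>> input_ids[mask] = 4                            # 4 = hypothetical [MASK] token id
>>> labels[~mask] = -100                           # ignore unmasked positions in the loss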

🔢 Relative Position Embedding

Now let's look at what Relative Position Embedding is and what kind of contextual information it actually captures. Relative Position Embedding is a positional embedding technique that learns the relations between tokens pairwise by representing the positional relationships among the tokens inside a sequence. The relative positional relationship is usually expressed as the difference between the sequence indices of two tokens. I will explain the contextual information it captures with an example. The Korean word for deep learning is, in English, "Deep Learning." Taken together, the two words mean a family of machine learning techniques that uses neural networks, but taken separately they split into the individual meanings "deep" and "learning."

  • 1) The Deep Learning is the Best Technique in Computer Science
  • 2) Iโ€™m learning how to swim in the deep ocean

Let's interpret the two sentences while paying attention to the relative distance between Deep and Learning. In the first sentence the two words sit next to each other and together produce the meaning "a kind of machine learning technique that uses neural networks." In the second sentence the two words are separated by five intervening whitespace tokens and each carries its own meaning, "learning" and "deep." A technique designed with the intent of capturing the contextual information that derives from the positional relationship between individual tokens is exactly Relative Position Embedding.
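A tiny illustration of that index difference (whitespace tokenization, so the offsets are only approximate):

# relative offset between "deep" and "learning" in the two example sentences

>>> s1 = "The Deep Learning is the Best Technique in Computer Science".split()
>>> s2 = "I'm learning how to swim in the deep ocean".split()
>>> s1.index("Learning") - s1.index("Deep"), s2.index("deep") - s2.index("learning")
(1, 6)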

Given that it captures relations pairwise, I expect it to capture information that feels very similar to skip-gram with negative sampling, so I think it can be filed under category 3, the distributional hypothesis. (This is my personal opinion, so if you see it differently, I would really appreciate a comment 🥰)

If the example above alone does not make the concept of relative position embedding click, read the linked post first. (Link 1)

Let's look at how Relative Position Embedding is actually implemented in code and how this paper defines positional relationships, by comparing it against Absolute Position Embedding. Given the two sentences below, I wrote Python code showing how each positional embedding scheme encodes the position information of a sentence. Let's walk through it together.

  • A) I love studying deep learning so much
  • B) I love deep cheeze burger so much
# Absolute Position Embedding

>>> import torch
>>> import torch.nn as nn
>>> max_length = 7
>>> position_embedding = nn.Embedding(7, 512) # [max_seq, dim_model]
>>> pos_x = position_embedding(torch.arange(max_length))
>>> pos_x, pos_x.shape
(tensor([[ 0.4027,  0.9331,  1.0556,  ..., -1.7370,  0.7799,  1.9851],  # token 0 in A, B: I
         [-0.2206,  2.1024, -0.6055,  ..., -1.1342,  1.3956,  0.9017],  # token 1 in A, B: love
         [-0.9560, -0.0426, -1.8587,  ..., -0.9406, -0.1467,  0.1762],  # token 2 in A, B: studying, deep
         ...,                                                           # token 3 in A, B: deep, cheeze
         [ 0.5999,  0.5235, -0.3445,  ...,  1.9020, -1.5003,  0.7535],  # token 4 in A, B: learning, burger
         [ 0.0688,  0.5867, -0.0340,  ...,  0.8547, -0.9196,  1.1193],  # token 5 in A, B: so
         [-0.0751, -0.4133,  0.0256,  ...,  0.0788,  1.4665,  0.8196]], # token 6 in A, B: much
        grad_fn=<EmbeddingBackward0>),
 torch.Size([7, 512]))

Because Absolute Position Embedding encodes tokens at the same position with the same position value regardless of the surrounding context, it expresses position information by mapping the indices of row vectors of 512 elements onto the order in which tokens appear in the actual sentence. For example, token 0, the first to appear in the sentence, is assigned the 0th row vector, and the last token, token N-1, gets the (N-1)-th row vector as its position value. Since it numbers individual tokens from the perspective of the whole sequence, it has the advantage of being well suited to modeling syntactical information.

The Absolute Position Embedding is usually added to the Input Embedding through an element-wise matrix sum to produce the Word Embedding that is fed into the encoder.
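As a minimal sketch of that sum (reusing the position_embedding defined above; word_embedding here is a hypothetical token-embedding table I add only for illustration):

# Input Embedding + Absolute Position Embedding -> encoder input

>>> vocab_size = 30000
>>> word_embedding = nn.Embedding(vocab_size, 512)
>>> input_ids = torch.randint(0, vocab_size, (max_length,))                        # token ids of one sentence
>>> x = word_embedding(input_ids) + position_embedding(torch.arange(max_length))   # [7, 512], fed to the encoder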

The code below is my PyTorch port of DeBERTa's Relative Position Embedding as presented in the paper. Relative position embedding goes through a considerably more complex process than the absolute kind, so the code is on the long side as well. Let's go through it slowly, one step at a time.

# Relative Position Embedding
>>> dim_model, dim_head = 512, 64  # not defined above, so set them here
>>> position_embedding = nn.Embedding(2*max_length, dim_model)
>>> x, p_x = torch.randn(max_length, dim_model), position_embedding(torch.arange(2*max_length))
>>> fc_q, fc_kr = nn.Linear(dim_model, dim_head), nn.Linear(dim_model, dim_head)
>>> q, kr = fc_q(x), fc_kr(p_x) # [batch, max_length, dim_head], [batch, 2*max_length, dim_head]

>>> tmp_c2p = torch.matmul(q, kr.transpose(-1, -2))
>>> tmp_c2p, tmp_c2p.shape
(tensor([[ 2.8118,  0.8449, -0.6240, -0.6516,  3.4009,  1.8296,  0.8304,  1.0164,
           3.5664, -1.4208, -2.0821,  1.5752, -0.9469, -7.1767],
         [-2.1907, -3.2801, -2.0628,  0.4443,  2.2272, -5.6653, -4.6036,  1.4134,
          -1.1742, -0.3361, -0.4586, -1.1827,  1.0878, -2.5657],
         [-4.8952, -1.5330,  0.0251,  3.5001,  4.1619,  1.7408, -0.5100, -3.4616,
          -1.6101, -1.8741,  1.1404,  4.9860, -2.5350,  1.0999],
         [-3.3437,  4.2276,  0.4509, -1.8911, -1.1069,  0.9540,  1.2045,  2.2194,
          -2.6509, -1.4076,  5.1599,  1.6591,  3.8764,  2.5126],
         [ 0.8164, -1.9171,  0.8217,  1.3953,  1.6260,  3.8104, -1.0303, -2.1631,
           3.9008,  0.5856, -1.6212,  1.7220,  2.7997, -1.8802],
         [ 3.4473,  0.9721,  3.9137, -3.2055,  0.6963,  1.2761, -0.2266, -3.7274,
          -1.4928, -1.9257, -5.4422, -1.8544,  1.8749, -3.4923],
         [ 2.6639, -1.4392, -3.8818, -1.4120,  1.7542, -0.8774, -3.0795, -1.2156,
          -1.0852,  3.7825, -3.5581, -3.6989, -2.6705, -1.2262]],
        grad_fn=<MmBackward0>),
 torch.Size([7, 14]))

>>> max_seq, max_pos = 7, 14  # max_pos = 2 * max_seq
>>> q_index, k_index = torch.arange(max_seq), torch.arange(max_seq)
>>> q_index, k_index
(tensor([0, 1, 2, 3, 4, 5, 6]), tensor([0, 1, 2, 3, 4, 5, 6]))

>>> tmp_pos = q_index.view(-1, 1) - k_index.view(1, -1)
>>> rel_pos_matrix = tmp_pos + max_seq  # here k == max_seq, shifting [-6, 6] into [1, 13]
>>> rel_pos_matrix
tensor([[ 7,  6,  5,  4,  3,  2,  1],
        [ 8,  7,  6,  5,  4,  3,  2],
        [ 9,  8,  7,  6,  5,  4,  3],
        [10,  9,  8,  7,  6,  5,  4],
        [11, 10,  9,  8,  7,  6,  5],
        [12, 11, 10,  9,  8,  7,  6],
        [13, 12, 11, 10,  9,  8,  7]])

>>> rel_pos_matrix = torch.clamp(rel_pos_matrix, 0, max_pos - 1).repeat(10, 1, 1)
>>> tmp_c2p = tmp_c2p.repeat(10, 1, 1)
>>> rel_pos_matrix, rel_pos_matrix.shape, tmp_c2p.shape 
(tensor([[[ 7,  6,  5,  4,  3,  2,  1],
          [ 8,  7,  6,  5,  4,  3,  2],
          [ 9,  8,  7,  6,  5,  4,  3],
          [10,  9,  8,  7,  6,  5,  4],
          [11, 10,  9,  8,  7,  6,  5],
          [12, 11, 10,  9,  8,  7,  6],
          [13, 12, 11, 10,  9,  8,  7]],
torch.Size([10, 7, 14]),
torch.Size([10, 7, 14]))

>>> outputs = torch.gather(tmp_c2p, dim=-1, index=rel_pos_matrix)
>>> outputs, outputs.shape
(tensor([[[ 1.0164,  0.8304,  1.8296,  3.4009, -0.6516, -0.6240,  0.8449],
          [-1.1742,  1.4134, -4.6036, -5.6653,  2.2272,  0.4443, -2.0628],
          [-1.8741, -1.6101, -3.4616, -0.5100,  1.7408,  4.1619,  3.5001],
          [ 5.1599, -1.4076, -2.6509,  2.2194,  1.2045,  0.9540, -1.1069],
          [ 1.7220, -1.6212,  0.5856,  3.9008, -2.1631, -1.0303,  3.8104],
          [ 1.8749, -1.8544, -5.4422, -1.9257, -1.4928, -3.7274, -0.2266],
          [-1.2262, -2.6705, -3.6989, -3.5581,  3.7825, -1.0852, -1.2156]],
          .....
          [[ 1.0164,  0.8304,  1.8296,  3.4009, -0.6516, -0.6240,  0.8449],
          [-1.1742,  1.4134, -4.6036, -5.6653,  2.2272,  0.4443, -2.0628],
          [-1.8741, -1.6101, -3.4616, -0.5100,  1.7408,  4.1619,  3.5001],
          [ 5.1599, -1.4076, -2.6509,  2.2194,  1.2045,  0.9540, -1.1069],
          [ 1.7220, -1.6212,  0.5856,  3.9008, -2.1631, -1.0303,  3.8104],
          [ 1.8749, -1.8544, -5.4422, -1.9257, -1.4928, -3.7274, -0.2266],
          [-1.2262, -2.6705, -3.6989, -3.5581,  3.7825, -1.0852, -1.2156]]],
        grad_fn=<GatherBackward0>),
 torch.Size([10, 7, 7]))

First, just like the absolute position case, we define an embedding lookup table (layer) with nn.Embedding, but the input dimension is different. Absolute position embedding only needs to map positions forward, whereas relative position embedding must map them bidirectionally, so twice the original max_length is used as the input dimension (max_pos). For example, suppose we need to express the positional relationship between token 0 and the remaining tokens. We can encode that relationship as [0, -1, -2, -3, -4, -5, -6].

Conversely, what happens when we express the positional relationship between the last token, token 6, and the rest? It would be encoded as [6, 5, 4, 3, 2, 1, 0]. In other words, the position embedding values are defined over [-max_seq, max_seq]. However, we cannot use that range as-is, because array-like data structures in Python, such as lists and tensors, need non-negative integers as indices for straightforward element access. Array-like types are all indexed sequentially from 0 to N-1, so to access tokens as intended, we have to shift the token indices into that non-negative, forward form as well.

Therefore we add max_seq to the original [-max_seq, max_seq] and use [0, 2*max_seq) as the range of values. Up to this point is what is conventionally called Relative Position Embedding; in the code above it corresponds to the part that builds rel_pos_matrix.

\[\delta(i,j)= \begin{cases} \ 0 & {(i - j \le -k)} \\ \ 2k-1 & {(i - j \ge k)} \\ \ i - j + k & {(\text{otherwise})} \\ \end{cases}\]

Now let's look at the way the authors propose to represent positional relationships. It is almost the same as the usual relative position embedding, but a post-processing step is introduced to handle the cases where elements of rel_pos_matrix become negative or exceed max_pos. That exceptional case arises when max_seq > 1/2 * max_pos (== k). Since the official repo ties max_seq and k together when building the model, you can get away without knowing this if you are only fine-tuning, but if you are building the model yourself piece by piece, be sure to remember this edge case.
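A minimal sketch of that piecewise mapping as plain Python (my own toy function; here k plays the role of max_seq from the snippet above):

# relative position bucket delta(i, j): clamp i - j into [0, 2k - 1]

>>> def delta(i: int, j: int, k: int) -> int:
...     if i - j <= -k:
...         return 0
...     if i - j >= k:
...         return 2 * k - 1
...     return i - j + k
...
>>> delta(0, 6, 7), delta(6, 0, 7), delta(3, 3, 7)
(1, 13, 7)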

Meanwhile, you can think of this encoding scheme as working on a principle similar to the window size in word2vec (meaning is determined by surrounding context): tokens that fall outside the window are not treated as surrounding context (negative samples). As for the implementation, torch.clamp, which restricts tensor values to a user-specified range, lets you do it cleanly in a single line, so keep that in mind.

Let's look at the final result after applying torch.clamp. Both the row and column vectors are defined within [0, 2*max_seq), and the difference between the maximum and minimum element of each direction vector is always kept at k. As intended, the embedding reflects exactly a window-sized amount of surrounding context.

To summarize, Relative Position Embedding builds an embedding lookup table just like the absolute position scheme, but extracts only the embedding values of the tokens that fall within the user-specified window, producing several new row vectors. Each of these row vectors carries the derived contextual information that arises from the positional changes between the target token and the remaining tokens.

🤔 Word Context vs Relative Position vs Absolute Position

People standing in line

So far we have looked at what Relative Position Embedding is and what kind of contextual information it is supposed to capture. Since my explanation is not the smoothest and the examples are purely textual, I suspect many readers still do not intuitively see what word context is, how it differs from position information, and how the two kinds of position information differ from each other. So I will try to explain the difference between the three kinds of information with the most intuitive example I can. (It's a secret that I'm writing this partly because I keep confusing them myself.)

Five people are standing in line for airport check-in. Since everyone is looking to the left, we can tell that the shortest woman on the left is at the front of the line. We will give the five people numbers in the order they are standing; for convenience, let's number them 0 through 4. Who is number 1? The woman standing second in line. Then who is number 2? The man right in the middle of the line in the photo. Expressing position by assigning a running number to each individual at the group level (the whole line) is exactly Absolute Position Embedding.

Now, let's focus on person number 2 again. Instead of describing him as the person in the middle of the whole line, we could also describe him as the person in a black suit and dress shoes who is staring at something in his hand. This is what corresponds to word context, the information about the token's own meaning.

Finally, let's describe man number 2 in the style of Relative Position Embedding: the person behind woman number 1, who holds a coffee in her right hand and a carrier in the other and wears black high heels and beige pants; the person in front of woman number 4, who wears a gray suit and black horn-rimmed glasses and holds a carrier in one hand; the person standing two places behind the woman at the very front of the line, who wears a black jacket and jeans and carries a gray coat in one hand; and the person two places in front of the man with the long beard and longish hair, who wears a blue cardigan and carries a green-and-black bag over his left shoulder.

Describing someone this way is what corresponds to Relative Position Embedding. If you now map this example directly onto natural language processing, the ideas should be much easier to digest.

🤔 DeBERTa Inductive Bias

In the end, DeBERTa was designed with the intent of mixing the two ways of capturing position information so that the model ends up with richer embeddings. We also already know from the BERT and GPT cases that the more diverse the contextual information a model captures, the better it performs on NLU tasks. So the authors' idea of adding Relative Position Embedding, giving a model that captures word order an additional distributional-hypothesis flavor, is very well founded.

Now the key is to answer the question, "how exactly do we extract and mix the two kinds of position information?" To answer it, the authors propose two new techniques: Disentangled Self-Attention and the Enhanced Mask Decoder. The former is a technique for extracting the contextual information that corresponds to the distributional hypothesis, while the latter is designed to inject the embedding that belongs to word order into the model. In the modeling part we will look at the two new techniques in detail and then walk through building the model in code.
The code was written with reference to the paper and Microsoft's official git repo. Note, however, that the paper leaves out a good deal of implementation detail, and the code released in the repo is hard-coded to the point where it was quite difficult to figure out its exact intent. So keep in mind that, to some extent, the code below reflects my own subjective interpretation.

🌟 Modeling

DeBERTa Model Structure

  • 1) Disentangled Self-Attention Encoder Block for Relative Position Embedding
  • 2) Enhanced Mask Decoder for Absolute Position Embedding

The overall structure of DeBERTa is not very different from vanilla BERT or RoBERTa. Keep in mind, however, that the step that adds the Absolute Position information at the Input Embedding stage at the front of the model has been moved into the encoder blocks at the back, called the Enhanced Mask Decoder, and that each encoder block gains an extra linear projection layer whose source is the relative position information, for Disentangled Self-Attention. Also remember that, like RoBERTa, DeBERTa drops NSP and pre-trains with MLM only.

DeBERTa Class Diagram

The figure above shows the structure of the DeBERTa I implemented; I attached it as a reference for the code review. Let's look at the most important pieces, Disentangled-Attention and EMD, first, and then go over the remaining objects.

🪢 Disentangled Self-Attention

\[\tilde{A}_{ij} = Q_i^c\cdot K_j^{cT} + Q_i^c\cdot K_{\delta(i,j)}^{rT} + K_j^c\cdot Q_{\delta(i,j)}^{rT} \\ Attention(Q_c,K_c,V_c,Q_r,K_r) = softmax(\frac{\tilde{A}}{\sqrt{3d_h}})\cdot V_c\]

Disentangled Self-Attention is a modified Self-Attention technique the authors devised to integrate the pure Input Embedding information with the Relative Position information. Unlike conventional Self-Attention, the Position Embedding is not added to the Input Embedding; it is kept separate. In other words, the idea is to map two different vectors, the Input Embedding and the Relative Position, into the same $d_h$ space and then work out the relationship between them.

The name Disentangled comes from the fact that the Input information and the Position information each take the leading role once in the dot products. It is a concept very similar to the cross-attention presented in Transformer-XL and XLNet. The authors themselves note that, excluding the very last term of the first equation, the information captured is the same as in that cross-attention, so keep that in mind.

Disentangled Self-Attention uses a total of five linear projection matrices: $Q^c, K^c, V^c$, whose source is the Input Embedding, and $Q^r, K^r$, whose source is the Position Embedding. The superscripts $c, r$ stand for content and relative respectively, indicating where each matrix comes from. The subscripts $i, j$ denote the index of the token currently being attended from and the index of the other token. So $\tilde{A}_{ij}$ is the value of the $j$-th element of the $i$-th row vector of an [NxN] matrix (the analogue of the query-key dot product in standard attention). Because the Input Embedding information and the Relative Position information are managed separately, the formula is somewhat different from the Self-Attention we already know. From here on, let's dig into the meaning of each term with a concrete example.

☺️ c2c matrix
Short for content2content, this refers to the first term on the right-hand side of the first equation. The name implies that both matrices used in the dot product have the Input Embedding as their source. You can think of it as carrying almost the same meaning as $Q\cdot K^T$, the second step of the Self-Attention we already know. The reason it is not exactly the same is that the dot product is taken with the Absolute Position information left out.

So the meaning of the operation is likewise the same as what we already know. If you are curious about the matrices $Q, K, V$ and what Self-Attention really encodes, please read my Transformer paper review first. Still, since the remaining two terms need the example anyway, I will start building it from the c2c term.

Tonight you want beef-brisket doenjang jjigae and samgyeopsal for dinner, with roasted eggs for dessert. There are no ingredients at home, but going to the store is a hassle, so you plan to ask your husband to buy the groceries. So you are writing a list of the ingredients you need. How should you describe each item so that your husband can buy everything you need as quickly and accurately as possible?

Agonizing over this is exactly the role of the matrix $Q^c$ and its linear projector $W_{Q^c}$. For example, the same cut of pork shoulder comes in a version for grilling and a version for stew. Eggs come roasted and raw. Writing down the exact intended use makes it far easier for the husband to shop exactly as his wife intended.

Meanwhile, a dot product is not an operation that needs parameters by itself, so what actually gets optimized (learned) through backpropagation of the loss is $W_{Q^c}$. But is the list you wrote really the only thing that determines how quickly and accurately your husband shops?

No. Writing down clearly which ingredient you need for which dish matters, but the product names and descriptions actually written in the store matter just as much. It may look like a contrived example, but with eggs you cannot tell by eye alone whether they are roasted or raw. Now suppose the store labels them simply as "eggs" with no further description.

No matter how good a matrix $Q^c$ you express, the odds that your husband comes home with raw eggs will be fairly high. These product names and descriptions written in the store correspond to the matrix $K^c$. And the act of comparing the grocery list you wrote against the product names and descriptions in the store, checking whether each item matches your intent, is exactly $Q_i^c\cdot K_j^{cT}$, the c2c matrix.

Note, however, that because global attention is used, you should picture comparing against every product in the store just to buy eggs. In particular, in the $Q_i^c\cdot K_j^{cT}$ of ordinary global attention, the comparison also takes positional information into account all at once, such as where in the store each candidate is displayed and which corner or category it belongs to; in our c2c matrix, that positional information is not considered here at all and is handled separately by the two remaining terms.

To summarize, c2c mathematically models the act of judging whether an item on the shelf is the ingredient I need using only its product name and description. Seen from the NLP perspective, it corresponds to reflecting the meanings of all the other tokens as a weighted sum, purely, without any syntactical information, in order to understand the meaning of a particular token.
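A tiny, self-contained sketch of just this term (toy shapes I picked for illustration; only the content projections are involved, no position information):

# c2c term only: content query x content key, no position embeddings involved

>>> x = torch.randn(7, 512)                       # toy input embeddings, [seq_len, dim_model]
>>> fc_q, fc_k = nn.Linear(512, 64), nn.Linear(512, 64)
>>> q_c, k_c = fc_q(x), fc_k(x)                   # both projections come from the input embedding
>>> c2c = torch.matmul(q_c, k_c.transpose(-1, -2))
>>> c2c.shape                                     # [seq_len, seq_len] raw Q^c . K^cT scores
torch.Size([7, 7])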

🗂️ c2p matrix

\[c2p = Q_i^c\cdot K_{\delta(i,j)}^{rT}\]

Short for content2position, this refers to the second term on the right-hand side, $Q_i^c\cdot K_{\delta(i,j)}^{rT}$. Unlike c2c, the two matrices used here come from different sources, hence the name c2p. The query of the dot product is $Q_i^c$, built from the Input Embedding, and the key is $K_{\delta(i,j)}^{rT}$, built from the Position Embedding. It is hard to grasp intuitively what it means to compare word context against relative position, so let's make sense of it through the grocery-shopping example.

When giving the roasted-egg versus raw-egg example, I said that product names and descriptions have an important effect on shopping. But even if the label still just says "eggs", we have another way to tell them apart: look at what is displayed around them. Suppose that right next to the "eggs" there are fresh foods such as milk, cheese, fish, and fresh meat. Then it is only natural to expect that the "eggs" in front of us are raw. What if, instead, the "eggs" are surrounded by snack-type products such as dried filefish, dried squid, beef jerky, and cookies? Then these "eggs" can quite plausibly be read as roasted eggs. Checking which other products are placed around an item to verify whether it is the thing we mean to buy is exactly what corresponds to c2p. And the collection of information about which other products are placed nearby is the matrix $K_{\delta(i,j)}^{rT}$.

🔬 p2c matrix

\[p2c = K_j^c\cdot Q_{\delta(i,j)}^{rT}\]

This is the part that most distinguishes Disentangled Self-Attention from other attention techniques, and it is also the part the authors emphasize most in the paper. That said, the paper's explanation is rather unhelpful, which makes it a genuinely difficult concept to understand; wanting to explain it is what led me to come up with the grocery-shopping example in the first place. Let's go back to the moment you were writing the shopping list for your husband.

Tonight's menu is beef-brisket doenjang jjigae and grilled samgyeopsal. What ingredients does the stew need? Beef brisket, doenjang, cheongyang chili, onion, minced garlic, and zucchini. Now think about what the samgyeopsal needs: you decided you need fresh pork belly, pepper and salt to get rid of the gamey smell, and whole garlic to grill alongside it. Based on this you are now going to write the list. What is the optimal way to write it?

Thinking about it together with the c2c and c2p examples makes it clear. In c2c we noted that even for the same ingredient, what you should buy differs by its intended use. In c2p we said that even without a precise description, you can infer what a product is from the items listed around it. Now combine the two. Suppose you wrote the list in the following order.

# Grocery list example 1

beef brisket, doenjang, garlic, cheongyang chili, onion, zucchini, samgyeopsal, pepper, salt

When you listed the needed items earlier, you clearly had both minced garlic and whole garlic in mind. But if you hand your husband a list written like the one above, which garlic will he buy? Naturally, seeing garlic sitting between the beef brisket, the doenjang, and the onion, he will think "ah, she needs garlic for the broth" and buy minced garlic.

Now suppose, instead, that you wrote the list like this.

# Grocery list example 2

beef brisket, doenjang, cheongyang chili, onion, zucchini, samgyeopsal, garlic, pepper, salt

This time your husband will sense that you need whole garlic to grill together with the samgyeopsal. And what about the situation below?

# Grocery list example 3

beef brisket, doenjang, garlic, cheongyang chili, onion, zucchini, samgyeopsal, garlic, pepper, salt

A reasonably perceptive husband would infer that both minced garlic for the stew broth and whole garlic for grilling are needed at the same time, find the items labeled "minced garlic" and "whole garlic" at the store, and buy both. Of course, a perceptive wife would not have written the ambiguous "garlic" twice in the first place and would have written "minced garlic" and "whole garlic" along with their uses instead.

This whole series of situations corresponds to p2c. The use, role, and meaning of a target item as captured from the items placed around it in the list the wife wrote is exactly the matrix $Q_{\delta(i,j)}^{rT}$.

⚒️ DeBERTa Scale Factor
Looking back at the equation listed at the beginning, DeBERTa's scale factor is $\sqrt{3d_h}$, unlike standard Self-Attention. Why? In the standard scheme only one kind of matrix, $Q\cdot K^T$, is passed to the softmax layer; DeBERTa passes three. That is why $d_h$ is multiplied by 3. You can see this clearly in the official repo's code: the number of matrix types used in the attention is multiplied in front of $d_h$. Below is an excerpt of the code from the repo.

# official Disentangled Self-Attention by microsoft from official repo

# ... (snipped) ...
def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
    if query_states is None:
        query_states = hidden_states
    query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads).float()
    key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads).float()
    value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
    
    rel_att = None
    # Take the dot product between "query" and "key" to get the raw attention scores.
    scale_factor = 1
    if 'c2p' in self.pos_att_type:
        scale_factor += 1
    if 'p2c' in self.pos_att_type:
        scale_factor += 1
    if 'p2p' in self.pos_att_type:
        scale_factor += 1
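The excerpt stops right before the scaling itself; conceptually, the accumulated scale_factor just ends up inside the square root of the denominator. A minimal sketch of that step (toy scores, not the repo's exact lines):

# scale factor = number of score matrices used (c2c + c2p + p2c -> 3)

>>> import math
>>> dim_head, scale_factor = 64, 3
>>> raw_scores = torch.randn(7, 7)                                  # stand-in for c2c + c2p + p2c
>>> attention_probs = torch.softmax(raw_scores / math.sqrt(dim_head * scale_factor), dim=-1)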

👩‍💻 Implementation
That covers everything about Disentangled Self-Attention. Now let's see how to actually implement it, together with the PyTorch code I wrote.

# Pytorch Implementation of DeBERTa Disentangled Self-Attention

import torch
import torch.nn.functional as F
from torch import Tensor


def build_relative_position(x_size: int) -> Tensor:
    """ Build Relative Position Matrix for Disentangled Self-attention in DeBERTa

    Args:
        x_size: sequence length of query matrix

    Reference:
        https://arxiv.org/abs/2006.03654
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/da_utils.py#L29
    """
    x_index, y_index = torch.arange(x_size, device="cuda"), torch.arange(x_size, device="cuda")  # same as rel_pos in official repo
    rel_pos = x_index.view(-1, 1) - y_index.view(1, -1)
    return rel_pos


def disentangled_attention(
        q: Tensor,
        k: Tensor,
        v: Tensor,
        qr: Tensor,
        kr: Tensor,
        attention_dropout: torch.nn.Dropout,
        padding_mask: Tensor = None,
        attention_mask: Tensor = None
) -> Tensor:
    """ Disentangled Self-attention for DeBERTa, same role as Module "DisentangledSelfAttention" in official Repo

    Args:
        q: content query matrix, shape (batch_size, seq_len, dim_head)
        k: content key matrix, shape (batch_size, seq_len, dim_head)
        v: content value matrix, shape (batch_size, seq_len, dim_head)
        qr: position query matrix, shape (batch_size, 2*max_relative_position, dim_head), r means relative position
        kr: position key matrix, shape (batch_size, 2*max_relative_position, dim_head), r means relative position
        attention_dropout: dropout for attention matrix, default rate is 0.1 from official paper
        padding_mask: mask for attention matrix for MLM
        attention_mask: mask for attention matrix for CLM

    Math:
        c2c = torch.matmul(q, k.transpose(-1, -2))  # A_c2c
        c2p = torch.gather(torch.matmul(q, kr.transpose(-1, -2)), dim=-1, index=c2p_pos)
        p2c = torch.gather(torch.matmul(qr, k.transpose(-1, -2)), dim=-2, index=c2p_pos)
        attention Matrix = c2c + c2p + p2c
        A = softmax(attention Matrix/sqrt(3*D_h)), SA(z) = Av

    Notes:
        dot_scale(range 1 ~ 3): scale factor for Qโ€ขK^T result, sqrt(3*dim_head) from official paper by microsoft,
        3 means that use full attention matrix(c2c, c2p, p2c), same as number of using what kind of matrix
        default 1, c2c is always used and c2p & p2c is optional

    References:
        https://arxiv.org/pdf/1803.02155
        https://arxiv.org/abs/2006.03654
        https://arxiv.org/abs/2111.09543
        https://arxiv.org/abs/1901.02860
        https://arxiv.org/abs/1906.08237
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/disentangled_attention.py
    """
    BS, NUM_HEADS, SEQ_LEN, DIM_HEADS = q.shape
    _, MAX_REL_POS, _, _ = kr.shape

    scale_factor = 1
    c2c = torch.matmul(q, k)  # A_c2c, q: (BS, NUM_HEADS, SEQ_LEN, DIM_HEADS), k: (BS, NUM_HEADS, DIM_HEADS, SEQ_LEN)
    c2p_att = torch.matmul(q, kr.permute(0, 2, 3, 1).contiguous())

    c2p_pos = build_relative_position(SEQ_LEN) + MAX_REL_POS / 2  # same as rel_pos in official repo
    c2p_pos = torch.clamp(c2p_pos, 0, MAX_REL_POS - 1).repeat(BS, NUM_HEADS, 1, 1).long()
    c2p = torch.gather(c2p_att, dim=-1, index=c2p_pos)
    if c2p is not None:
        scale_factor += 1

    p2c_att = torch.matmul(qr, k)  # qr: (BS, NUM_HEADS, SEQ_LEN, DIM_HEADS), k: (BS, NUM_HEADS, DIM_HEADS, SEQ_LEN)
    p2c = torch.gather(p2c_att, dim=-2, index=c2p_pos)  # same as torch.gather(kโ€ขqr^t, dim=-1, index=c2p_pos)
    if p2c is not None:
        scale_factor += 1

    dot_scale = torch.sqrt(torch.tensor(scale_factor * DIM_HEADS))  # from official paper by microsoft
    attention_matrix = (c2c + c2p + p2c) / dot_scale  # attention Matrix = A_c2c + A_c2r + A_r2c
    if padding_mask is not None:
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # for broadcasting: shape (BS, 1, 1, SEQ_LEN)
        attention_matrix = attention_matrix.masked_fill(padding_mask == 1, float('-inf'))  # Padding Token Masking

    attention_dist = attention_dropout(
        F.softmax(attention_matrix, dim=-1)
    )
    attention_matrix = torch.matmul(attention_dist, v).permute(0, 2, 1, 3).reshape(-1, SEQ_LEN, NUM_HEADS*DIM_HEADS).contiguous()
    return attention_matrix

Pay attention to the code lines that compute p2c. Unlike the formula written in the paper ($K_j^c\cdot Q_{\delta(i,j)}^{rT}$), the order of the query and the key is flipped here. That is why the dim argument of torch.gather is initialized to -2, unlike in the c2p case: because the order of the two factors in the dot product is reversed, the relative position embedding we want to extract ends up sitting along the -2nd dimension.
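A tiny toy example of why the gather dimension flips (a 3x3 score matrix and a hypothetical index matrix, just to show the indexing direction):

# torch.gather along dim=-1 picks within each row (c2p); along dim=-2 it picks within each column (p2c)

>>> scores = torch.arange(9.).view(3, 3)
>>> idx = torch.tensor([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
>>> torch.gather(scores, dim=-1, index=idx)
tensor([[0., 1., 2.],
        [4., 5., 3.],
        [8., 6., 7.]])
>>> torch.gather(scores, dim=-2, index=idx)
tensor([[0., 4., 8.],
        [3., 7., 2.],
        [6., 1., 5.]])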

😷 Enhanced Mask Decoder

We said that DeBERTa's design goal is to mix the two kinds of position information appropriately and build the richest possible embedding. We now know that the relative position embedding is captured through Disentangled Self-Attention. Then how should the absolute position embedding be modeled in? The answer lies in the Enhanced Mask Decoder, abbreviated EMD. Before studying how EMD works, let's first settle why absolute position embedding is needed for NLU at all.

a new "store" opened beside the new "mall"

The sentence above is the example the authors use in the paper when arguing for the necessity of Absolute Position Embedding. Think about whether relative position embedding alone can really distinguish store from mall here. Earlier we defined the relative position embedding as a matrix carrying the derived contextual information that arises from positional changes between the target token and the remaining tokens. In other words, it tries to understand the target token's meaning by figuring out what context sits around it.

Look at the example sentence again. Both target words are surrounded by words with similar meanings. In such a case it will be very hard for the model to clearly understand the difference in meaning between store and mall within the sequence using relative position embedding alone. In this situation, the difference in nuance between the two words is ultimately determined by syntactical information: which one is the subject of the sentence and which one is the object. This need for syntactical information is exactly why absolute position embedding is indispensable for NLU.

Enhanced Mask Decoder Overview

🤔 why named decoder
When I first read the paper I found the name Decoder quite puzzling. DeBERTa is, as far as I knew, an encoder-only model, so why does it have a module with "decoder" in its name? And the paper does not explain the intent behind the naming either. So I tried to guess it myself.

DeBERTa uses MLM as its pre-training task. What is MLM? It is a fill-in-the-blank problem: find the appropriate token for each masked position. In some circles this is also described as denoising, and the paper mentions that the Absolute Position Embedding has a significant impact on this denoising. So my guess is that the module got "decoder" in its name because it is the part that makes use of the Absolute Position Embedding, which strongly affects denoising performance.

The figure included in the paper allows the same inference. While explaining the structure of EMD, the authors also show a diagram of BERT next to it, and BERT has no decoder at all; yet they still call it the "BERT decoding layer", which seems to lend my guess a bit more legitimacy.

(+ Addendum) The official repo code also confirms that EMD uses the ordinary Encoder that we already know.

🤷‍♂️ How to add Absolute Position

DeBERTa Model Structure

Now let's look at what EMD is and how the Absolute Position is added to the model. EMD is a structure devised to raise MLM performance, so it is stacked right before the feedforward & softmax layers used for token prediction. How many EMD layers to stack is a hyperparameter; according to the authors' experiments, using 2 is the most efficient. Rather than stacking brand-new encoder blocks, it is implemented so that the EMD layers share weights with the very last encoder block of the Disentangled-Attention stack.

# EMD Implementation Example

class EnhancedMaskDecoder(nn.Module):
    def __init__(self, encoder: list[nn.ModuleList], dim_model: int = 1024) -> None:
        super(EnhancedMaskDecoder, self).__init__()
        self.emd_layers = encoder

class DeBERTa(nn.Module):
    def __init__(self, vocab_size: int, max_seq: int = 512, N: int = 24, N_EMD: int = 2, dim_model: int = 1024, num_heads: int = 16, dim_ffn: int = 4096, dropout: float = 0.1) -> None:
        super(DeBERTa, self).__init__()
        # Init Sub-Blocks & Modules
        self.max_seq, self.N, self.N_EMD = max_seq, N, N_EMD
        self.dim_model, self.num_heads, self.dim_ffn, self.dropout_prob = dim_model, num_heads, dim_ffn, dropout
        self.encoder = DeBERTaEncoder(
            self.max_seq,
            self.N,
            self.dim_model,
            self.num_heads,
            self.dim_ffn,
            self.dropout_prob
        )
        self.emd_layers = [self.encoder.encoder_layers[-1] for _ in range(self.N_EMD)]  # weight share
        self.emd_encoder = EnhancedMaskDecoder(self.emd_layers, self.dim_model)

๋”ฐ๋ผ์„œ N_EMD=2 ๋กœ ์„ค์ •ํ•œ๋‹ค๋Š” ๊ฒƒ์€ ๊ฒฐ๊ตญ, Disentangled-Attention ๋ ˆ์ด์–ด์˜ ๊ฐ€์žฅ ๋งˆ์ง€๋ง‰ ์ธ์ฝ”๋” ๋ธ”๋Ÿญ์„ 2๊ฐœ ๋” ์Œ“๋Š” ๊ฒƒ๊ณผ ๋™์น˜๋‹ค. ๋Œ€์‹  ์ธ์ฝ”๋”์˜ linear projection ๋ ˆ์ด์–ด์˜ ์ž…๋ ฅ๊ฐ’์ด ๋‹ค๋ฅด๋‹ค. Disentangled-Attention ์˜ ํ–‰๋ ฌ $Q^c, K^c, V^c$๋Š” ์ด์ „ ๋ธ”๋Ÿญ์˜ hidden_states ๊ฐ’์ธ ํ–‰๋ ฌ $H$๋ฅผ ์ž…๋ ฅ์œผ๋กœ, ํ–‰๋ ฌ $Q^r, K^r$์€ ๋ ˆ์ด์–ด์˜ ์œ„์น˜์— ์ƒ๊ด€์—†์ด ๋ชจ๋‘ ๊ฐ™์€ ๊ฐ’์˜ Relative Position Embedding ์„ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•œ๋‹ค.

๋ฐ˜๋ฉด, EMD ๋งจ ์ฒ˜์Œ ์ธ์ฝ”๋” ๋ธ”๋Ÿญ์˜ ํ–‰๋ ฌ $Q^c$๋Š” ๋ฐ”๋กœ ์ง์ „ ๋ธ”๋Ÿญ์˜ hidden_states ์— Absolute Position Embedding์„ ๋”ํ•œ ๊ฐ’์„ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•œ๋‹ค. ์ดํ›„ ๋‚˜๋จธ์ง€ ๋ธ”๋Ÿญ์—๋Š” Disentangled-Attention ์™€ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ์ด์ „ ๋ธ”๋Ÿญ์˜ hidden_states ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ํ–‰๋ ฌ $K^c, V^c$๋Š” ๋ธ”๋Ÿญ ์ˆœ์„œ์— ์ƒ๊ด€์—†์ด ์ด์ „ ๋ธ”๋Ÿญ์˜ hidden_states ๋งŒ ๊ฐ€์ง€๊ณ  linear projection์„ ์ˆ˜ํ–‰ํ•œ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ํ–‰๋ ฌ $Q^r, K^r$ ์—ญ์‹œ ๊ฐ™์€ ๊ฐ’์˜ Relative Position Embedding ์„ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•œ๋‹ค.

์‚ฌ์‹ค ํ•„์ž๋Š” ๋…ผ๋ฌธ๋งŒ ์ฝ์—ˆ์„ ๋•Œ EMD๋„ Relative Position ์ •๋ณด๋ฅผ ์ฃผ์ž…ํ•ด Disengtanled-Attention์„ ์ˆ˜ํ–‰ํ•œ๋‹ค๊ณ  ์ „ํ˜€ ์ƒ๊ฐํ•˜์ง€ ๋ชปํ–ˆ๋‹ค. ์ด๋Š” ๋…ผ๋ฌธ์˜ ์„ค๋ช…์ด ์ƒ๋‹นํžˆ ๋ถˆ์นœ์ ˆํ•œ ๋•๋ถ„์ธ๋ฐ, ๋…ผ๋ฌธ์— ์ด์™€ ๊ด€๋ จํ•ด์„œ ์ž์„ธํ•œ ์„ค๋ช…๋„ ์—†๊ณ  Absolute Position Embedding์„ ์‚ฌ์šฉํ•˜๋Š” ๋ ˆ์ด์–ด๋ผ์„œ ๋‹น์—ฐํžˆ ์ผ๋ฐ˜์ ์ธ Self-Attention์„ ์‚ฌ์šฉํ•  ๊ฒƒ์ด๋ผ๊ณ  ์ƒ๊ฐํ–ˆ๋˜ ๊ฒƒ์ด๋‹ค.

ํ•„์ž๋Š” ์—ฌ๊ธฐ์„œ Absolute Position์ด ์™œ ํ•„์š”ํ•œ์ง€๋„ ์•Œ๊ฒ ๊ณ  ๊ทธ๋ž˜์„œ ํ–‰๋ ฌํ•ฉ์œผ๋กœ ๋”ํ•ด์„œ ์–ดํ…์…˜์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ๋„ ์ž˜ ์•Œ๊ฒ ๋Š”๋ฐ ์™œ ๊ตณ์ด ๊ฐ€์žฅ ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด์—์„œ ์ด๊ฑธ ํ• ๊นŒ?? ํ•˜๋Š” ์˜๋ฌธ์ด ๋“ค์—ˆ๋‹ค. ์ผ๋ฐ˜ Self-Attention์ฒ˜๋Ÿผ ๋งจ ์ฒ˜์Œ์— ๋”ํ•˜๊ณ  ์‹œ์ž‘ํ•˜๋ฉด ์•ˆ๋ ๊นŒ??

์ €์ž์˜ ์‹คํ—˜์— ๋”ฐ๋ฅด๋ฉด Absolute Position ์„ ์ฒ˜์Œ์— ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒƒ๋ณด๋‹ค EMD์ฒ˜๋Ÿผ ๊ฐ€์žฅ ๋งˆ์ง€๋ง‰์— ๋”ํ•ด์ฃผ๋Š”๊ฒŒ ์„ฑ๋Šฅ์ด ๋” ์ข‹์•˜๋‹ค๊ณ  ํ•œ๋‹ค. ๊ทธ ์ด์œ ๋กœ Absolute Position๋ฅผ ์ดˆ๋ฐ˜์— ์ถ”๊ฐ€ํ•˜๋ฉด ๋ชจ๋ธ์ด Relative Position์„ ํ•™์Šตํ•˜๋Š”๋ฐ ๋ฐฉํ•ด๊ฐ€ ๋˜๋Š” ๊ฒƒ ๊ฐ™๋‹ค๋Š” ์ถ”์ธก์„ ํ•จ๊ป˜ ์„œ์ˆ ํ•˜๊ณ  ์žˆ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด ์™œ ๋ฐฉํ•ด๊ฐ€ ๋˜๋Š” ๊ฒƒ์ผ๊นŒ??

ํ•„์ž์˜ ๋‡Œํ”ผ์…œ์ด์ง€๋งŒ ์ด๊ฒƒ ์—ญ์‹œ blessing of dimensionality ์—์„œ ํŒŒ์ƒ๋œ ๋ฌธ์ œ๋ผ๊ณ  ์ƒ๊ฐํ•œ๋‹ค. ์ผ๋‹จ ์šฉ์–ด์˜ ๋œป๋ถ€ํ„ฐ ์•Œ์•„๋ณด์ž. blessing of dimensionality ๋ž€, ๊ณ ์ฐจ์› ๊ณต๊ฐ„์—์„œ ๋ฌด์ž‘์œ„๋กœ ์„œ๋กœ ๋‹ค๋ฅธ ๋ฒกํ„ฐ ๋‘๊ฐœ๋ฅผ ์„ ํƒํ•˜๋ฉด ๋‘ ๋ฒกํ„ฐ๋Š” ๊ฑฐ์˜ ๋Œ€๋ถ€๋ถ„ approximate orthogonality๋ฅผ ๊ฐ–๋Š” ํ˜„์ƒ์„ ์„ค๋ช…ํ•˜๋Š” ์šฉ์–ด๋‹ค. ๋ฌด์กฐ๊ฑด ์„ฑ๋ฆฝํ•˜๋Š” ์„ฑ์งˆ์€ ์•„๋‹ˆ๊ณ  ํ™•๋ฅ ๋ก ์ ์ธ ์ ‘๊ทผ์ด๋ผ๋Š” ๊ฒƒ์„ ๋ช…์‹ฌํ•˜์ž. ์•„๋ฌดํŠผ ์ง๊ตํ•˜๋Š” ๋‘ ๋ฒกํ„ฐ๋Š” ๋‚ด์ ๊ฐ’์ด 0์— ์ˆ˜๋ ดํ•œ๋‹ค. ์ฆ‰, ๋‘ ๋ฒกํ„ฐ๋Š” ์„œ๋กœ์—๊ฒŒ ์˜ํ–ฅ์„ ๋ฏธ์น˜์ง€ ๋ชปํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค.

์ด๊ฒƒ์€ Transformer์—์„œ Input Embedding๊ณผ Absolute Position Embedding์„ ํ–‰๋ ฌํ•ฉ์œผ๋กœ ๋”ํ•ด๋„ ์ข‹์€ ํ•™์Šต ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ๋Š” ์ด์œ ๊ฐ€ ๋œ๋‹ค. ๋‹ค์‹œ ๋งํ•ด์„œ, hidden states space ์—์„œ Input Embedding ๊ณผ Absolute Position Embedding ์—ญ์‹œ ๊ฐœ๋ณ„ ๋ฒกํ„ฐ๊ฐ€ span ํ•˜๋Š” ๋ถ€๋ถ„ ๊ณต๊ฐ„ ๋ผ๋ฆฌ๋Š” ์„œ๋กœ ์ง๊ตํ•  ๊ฐ€๋Šฅ์„ฑ์ด ๋งค์šฐ ๋†’๋‹ค๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•œ๋‹ค. ๋”ฐ๋ผ์„œ ์„œ๋กœ ๋‹ค๋ฅธ ์ถœ์ฒ˜๋ฅผ ํ†ตํ•ด ๋งŒ๋“ค์–ด์ง„ ๋‘ ํ–‰๋ ฌ์„ ๋”ํ•ด๋„ ์„œ๋กœ์—๊ฒŒ ์˜ํ–ฅ์„ ๋ฏธ์น˜์ง€ ๋ชปํ•  ๊ฒƒ์ด๊ณ  ๊ทธ๋กœ ์ธํ•ด ๋ชจ๋ธ์ด Input๊ณผ Position ์ •๋ณด๋ฅผ ๋”ฐ๋กœ ์ž˜ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์„ ๊ฒƒ์ด๋ผ ๊ธฐ๋Œ€ํ•ด๋ณผ ์ˆ˜ ์žˆ๋‹ค.
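A quick numerical sanity check of that claim (random Gaussian vectors, so the exact value changes every run, but it stays near zero):

# two random 1024-dim vectors are nearly orthogonal with high probability

>>> a, b = torch.randn(1024), torch.randn(1024)
>>> torch.nn.functional.cosine_similarity(a, b, dim=0).abs()   # typically on the order of 1/sqrt(1024) ~ 0.03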

hidden states vector space example

Now let's come back to the DeBERTa case. In the figure above, assume the blue line is the Input Embedding, the red line the Absolute Position Embedding, and the purple line on the left the Relative Position Embedding. By the blessing of dimensionality, the word context information (blue line) and the position information (red and purple lines) are very likely to be approximately orthogonal, as drawn. Here is where my speculation comes in: the purple and red lines have somewhat different characters, but in the end both represent the sequence's position information, so they can be said to share the same root. Therefore, although I do not know exactly how they would be mapped in the actual hidden-states space, we can guess that they would not be orthogonal to each other.

Now suppose we add the Absolute Position at the very beginning of the model. The matrix entering the encoder would then be represented by the green line in the figure above. Under the assumption that the blue and red lines are approximately orthogonal, their sum will land roughly 45 degrees between the two vectors (the green line). Then the relationship between the purple line and the green line also changes from approximate orthogonality to mutual interference. So if the absolute position is added at the very beginning instead of in the EMD at the end, interference arises and the model will not be able to learn the relative position information properly.

👩‍💻 Implementation

# Pytorch Implementation of DeBERTa Enhanced Mask Decoder

class EnhancedMaskDecoder(nn.Module):
    """
    Class for Enhanced Mask Decoder module in DeBERTa, which is used for Masked Language Model (Pretrain Task)
    The word 'Decoder' refers to denoising: the masked tokens are recovered by predicting them
    In the official paper & repo, 2 EMD layers are used for the MLM task
    This layer's key & value inputs are the outputs of the last disentangled self-attention encoder layer,
    all EMD layers can share parameters with that encoder layer, and this layer also performs disentangled self-attention
    The official repo implements this layer with so much hard coding that it is difficult to follow directly,
    so we implement it in our own style, staying as close as possible to the paper's description
    Notes:
        Also we temporarily implement only extract token embedding, not calculating logit, loss for MLM Task yet
        MLM Task will be implemented ASAP
    Args:
        encoder: list of nn.ModuleList, which is (N_EMD * last encoder layer) from DeBERTaEncoder
    References:
        https://arxiv.org/abs/2006.03654
        https://arxiv.org/abs/2111.09543
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/apps/models/masked_language_model.py
    """
    def __init__(self, encoder: list[nn.ModuleList], dim_model: int = 1024) -> None:
        super(EnhancedMaskDecoder, self).__init__()
        self.emd_layers = encoder
        self.layer_norm = nn.LayerNorm(dim_model)

    def emd_context_layer(self, hidden_states, abs_pos_emb, rel_pos_emb, mask):
        outputs = []
        query_states = hidden_states + abs_pos_emb  # "I" in official paper,
        for emd_layer in self.emd_layers:
            query_states = emd_layer(x=hidden_states, pos_x=rel_pos_emb, mask=mask, emd=query_states)
            outputs.append(query_states)
        last_hidden_state = self.layer_norm(query_states)  # because of applying pre-layer norm
        hidden_states = torch.stack(outputs, dim=0).to(hidden_states.device)
        return last_hidden_state, hidden_states

    def forward(self, hidden_states: Tensor, abs_pos_emb: Tensor, rel_pos_emb, mask: Tensor) -> tuple[Tensor, Tensor]:
        """
        hidden_states: output from last disentangled self-attention encoder layer
        abs_pos_emb: absolute position embedding
        rel_pos_emb: relative position embedding
        """
        last_hidden_state, emd_hidden_states = self.emd_context_layer(hidden_states, abs_pos_emb, rel_pos_emb, mask)
        return last_hidden_state, emd_hidden_states

Except for the part of the emd_context_layer method that adds the Absolute Position information, its behavior is identical to an ordinary Encoder object. Also, DeBERTa requires every layer to use the same (shared) Relative Position Embedding within the same forward pass, and EMD is no exception, so the Relative Position Embedding initialized in the top-level object must be passed down as an argument exactly as it is.

And finally, do not forget that the emd_layers used inside this object are all the very last encoder block of the Disentangled-Attention stack.

👩‍👩‍👧‍👦 Multi-Head Attention

Now let's go over the remaining blocks. Their principles and meaning were already covered in the Transformer review, so I will skip those and only mention the peculiarities of the implementation. First, let's look at the Multi-Head Attention code.

# Pytorch Implementation of Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """ In this class, we implement workflow of Multi-Head Self-attention for DeBERTa-Large
    This class has same role as Module "BertAttention" in official Repo (bert.py)
    In official repo, they use post-layer norm, but we use pre-layer norm which is more stable & efficient for training

    Args:
        dim_model: dimension of model's latent vector space, default 1024 from official paper
        num_attention_heads: number of heads in MHSA, default 16 from official paper for Transformer
        dim_head: dimension of each attention head, default 64 from official paper (1024 / 16)
        attention_dropout_prob: dropout rate, default 0.1

    Math:
        attention Matrix = c2c + c2p + p2c
        A = softmax(attention Matrix/sqrt(3*D_h)), SA(z) = Av

    Reference:
        https://arxiv.org/abs/1706.03762
        https://arxiv.org/abs/2006.03654
    """

    def __init__(self, dim_model: int = 1024, num_attention_heads: int = 16, dim_head: int = 64,
                 attention_dropout_prob: float = 0.1) -> None:
        super(MultiHeadAttention, self).__init__()
        self.dim_model = dim_model
        self.num_attention_heads = num_attention_heads
        self.dim_head = dim_head
        self.fc_q = nn.Linear(self.dim_model, self.dim_model)
        self.fc_k = nn.Linear(self.dim_model, self.dim_model)
        self.fc_v = nn.Linear(self.dim_model, self.dim_model)
        self.fc_qr = nn.Linear(self.dim_model, self.dim_model)  # projector for Relative Position Query matrix
        self.fc_kr = nn.Linear(self.dim_model, self.dim_model)  # projector for Relative Position Key matrix
        self.fc_concat = nn.Linear(self.dim_model, self.dim_model)
        self.attention = disentangled_attention
        self.attention_dropout = nn.Dropout(p=attention_dropout_prob)

    def forward(self, x: Tensor, rel_pos_emb: Tensor, padding_mask: Tensor, attention_mask: Tensor = None,
                emd: Tensor = None) -> Tensor:
        """ x is already passed nn.Layernorm """
        assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'

        # size: bs, seq, nums head, dim head, linear projection
        q = self.fc_q(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
        k = self.fc_k(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 3, 1).contiguous()
        v = self.fc_v(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
        qr = self.fc_qr(rel_pos_emb).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
        kr = self.fc_kr(rel_pos_emb).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head)
        if emd is not None:
            q = self.fc_q(emd).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()

        attention_matrix = self.attention(
            q,
            k,
            v,
            qr,
            kr,
            self.attention_dropout,
            padding_mask,
            attention_mask
        )
        attention_output = self.fc_concat(attention_matrix)
        return attention_output

The behavior itself is unchanged, but linear projection layers for the relative position information have been added. In addition, for the Enhanced Mask Decoder, the forward method uses a conditional so that at decoding time the query matrix $Q^c$ is built from hidden_states + absolute position embedding. Implemented this way, there is no need to write a separate attention module just for EMD, which keeps the code simpler.

All you need to remember is that the MultiHeadAttention object must pass rel_pos_emb along as an argument when it invokes the disentangled attention routine.
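
For reference, a call into this module might look like the sketch below. The shapes, the random rel_pos_emb tensor, and the boolean padding-mask convention are illustrative assumptions, and the disentangled_attention function from the earlier section must be in scope.

# Usage sketch for MultiHeadAttention (illustrative shapes, not the actual training pipeline)
import torch

mha = MultiHeadAttention(dim_model=1024, num_attention_heads=16, dim_head=64)

x = torch.rand(2, 512, 1024)               # (batch, seq, dim_model), already layer-normalized by the caller
rel_pos_emb = torch.rand(2, 512, 1024)     # stand-in for the shared relative position embedding
padding_mask = torch.zeros(2, 512).bool()  # assumed convention: True marks padded positions

out = mha(x, rel_pos_emb, padding_mask)    # (2, 512, 1024)

# EMD path: the caller builds Q^c from hidden_states + absolute position embedding and passes it as `emd`
emd_query = x + torch.rand(2, 512, 1024)   # stand-in for "hidden_states + absolute position embedding"
out_emd = mha(x, rel_pos_emb, padding_mask, emd=emd_query)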

🔬 Feed Forward Network

# Pytorch Implementation of FeedForward Network

class FeedForward(nn.Module):
    """
    Class for Feed-Forward Network module in transformer
    In official paper, they use ReLU activation function, but GELU is better for now
    We change ReLU to GELU & add dropout layer
    Args:
        dim_model: dimension of model's latent vector space, default 512
        dim_ffn: dimension of FFN's hidden layer, default 2048 from official paper
        dropout: dropout rate, default 0.1
    Math:
        FeedForward(x) = FeedForward(LN(x))+x
    """
    def __init__(self, dim_model: int = 512, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
        super(FeedForward, self).__init__()
        self.ffn = nn.Sequential(
            nn.Linear(dim_model, dim_ffn),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(dim_ffn, dim_model),
            nn.Dropout(p=dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.ffn(x)

Once again, nothing here differs from the original Transformer or BERT.

📘 DeBERTaEncoderLayer

# Pytorch Implementation of DeBERTaEncoderLayer(single Disentangled-Attention Encoder Block)

class DeBERTaEncoderLayer(nn.Module):
    """
    Class for encoder model module in DeBERTa-Large
    In this class, we stack each encoder_model module (Multi-Head Attention, Residual-Connection, LayerNorm, FFN)
    This class has same role as Module "BertEncoder" in official Repo (bert.py)
    In official repo, they use post-layer norm, but we use pre-layer norm which is more stable & efficient for training
    References:
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py
    """
    def __init__(self, dim_model: int = 1024, num_attention_heads: int = 16, dim_ffn: int = 4096,
                 layer_norm_eps: float = 1e-7, attention_dropout_prob: float = 0.1, hidden_dropout_prob: float = 0.1) -> None:
        super(DeBERTaEncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(
            dim_model,
            num_attention_heads,
            int(dim_model / num_attention_heads),  # dim_head = 1024 / 16 = 64
            attention_dropout_prob,
        )
        self.layer_norm1 = nn.LayerNorm(dim_model, eps=layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(dim_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(p=hidden_dropout_prob)
        self.ffn = FeedForward(
            dim_model,
            dim_ffn,
            hidden_dropout_prob,
        )

    def forward(self, x: Tensor, pos_x: Tensor, padding_mask: Tensor, attention_mask: Tensor = None, emd: Tensor = None) -> Tensor:
        """ pos_x (rel_pos_emb) is fixed for all layers within the same forward pass """
        ln_x, ln_pos_x = self.layer_norm1(x), self.layer_norm1(pos_x)  # pre-layer norm, weight share
        residual_x = self.dropout(self.self_attention(ln_x, ln_pos_x, padding_mask, attention_mask, emd)) + x

        ln_x = self.layer_norm2(residual_x)
        fx = self.ffn(ln_x) + residual_x
        return fx

Unlike the official code, this implementation uses pre-layernorm. If you are curious about pre-layernorm, click here to check it out.
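
The difference between the two orderings is only where LayerNorm sits relative to the residual connection; the sketch below is a schematic comparison, not code from the official repo.

# Schematic: post-LN (original Transformer / official DeBERTa repo) vs. pre-LN (this implementation)
import torch.nn as nn
from torch import Tensor

def post_ln_block(x: Tensor, sub_layer: nn.Module, norm: nn.LayerNorm) -> Tensor:
    return norm(x + sub_layer(x))   # normalize AFTER the residual add

def pre_ln_block(x: Tensor, sub_layer: nn.Module, norm: nn.LayerNorm) -> Tensor:
    return x + sub_layer(norm(x))   # normalize BEFORE the sub-layer; the residual path stays an identity

Keeping the residual path free of normalization is what tends to make pre-LN training more stable for deep stacks, and it is also why DeBERTaEncoder applies one final nn.LayerNorm to last_hidden_state.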

📚 DeBERTaEncoder

# Pytorch Implementation of DeBERTaEncoder (N stacked DeBERTaEncoderLayer)

class DeBERTaEncoder(nn.Module, AbstractModel):
    """ In this class, 1) encode input sequence, 2) make relative position embedding, 3) stack num_layers DeBERTaEncoderLayer
    This class's forward output is not integrated with EMD Layer's output

    Output has ONLY the result of disentangled self-attention

    The order of operations follows the official paper & repo by Microsoft, but the ops themselves differ slightly,
    because the official repo uses custom ops (e.g. XDropout, XSoftmax, ...) while we apply pure PyTorch ops

    Args:
        max_seq: maximum sequence length, named "max_position_embedding" in the official repo, default 512 (denoted 'k' in the paper)
        num_layers: number of EncoderLayer, default 24 for large model

    Notes:
        self.rel_pos_emb: P in the paper; this matrix is fixed within a single forward pass, and every layer & module must share it, per the official paper

    References:
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/ops.py
    """
    def __init__(
            self,
            cfg: CFG,
            max_seq: int = 512,
            num_layers: int = 24,
            dim_model: int = 1024,
            num_attention_heads: int = 16,
            dim_ffn: int = 4096,
            layer_norm_eps: float = 1e-7,  # value from the official DeBERTa config
            attention_dropout_prob: float = 0.1,
            hidden_dropout_prob: float = 0.1,
            gradient_checkpointing: bool = False
    ) -> None:
        super(DeBERTaEncoder, self).__init__()
        self.cfg = cfg
        self.max_seq = max_seq
        self.num_layers = num_layers
        self.dim_model = dim_model
        self.num_attention_heads = num_attention_heads
        self.dim_ffn = dim_ffn
        self.hidden_dropout = nn.Dropout(p=hidden_dropout_prob)  # dropout is not learnable
        self.layer = nn.ModuleList(
            [DeBERTaEncoderLayer(dim_model, num_attention_heads, dim_ffn, layer_norm_eps, attention_dropout_prob, hidden_dropout_prob) for _ in range(self.num_layers)]
        )
        self.layer_norm = nn.LayerNorm(dim_model, eps=layer_norm_eps)  # for final-Encoder output
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, inputs: Tensor, rel_pos_emb: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tuple[Tensor, Tensor]:
        """
        Args:
            inputs: embedding from input sequence
            rel_pos_emb: relative position embedding
            padding_mask: mask for Encoder padded token for speeding up to calculate attention score or MLM
            attention_mask: mask for CLM
        """
        layer_output = []
        x, pos_x = inputs, rel_pos_emb  # x is same as word_embeddings or embeddings in official repo
        for layer in self.layer:
            if self.gradient_checkpointing and self.cfg.train:
                x = self._gradient_checkpointing_func(
                    layer.__call__,  # equivalent to calling the module itself; PyTorch recommends __call__ over invoking forward directly
                    x,
                    pos_x,
                    padding_mask,
                    attention_mask
                )
            else:
                x = layer(
                    x,
                    pos_x,
                    padding_mask,
                    attention_mask
                )
            layer_output.append(x)
        last_hidden_state = self.layer_norm(x)  # because of applying pre-layer norm
        hidden_states = torch.stack(layer_output, dim=0).to(x.device)  # shape: [num_layers, BS, SEQ_LEN, DIM_Model]
        return last_hidden_state, hidden_states

Just like with EMD, the key point is that within the same forward pass every layer, regardless of its depth, performs its linear projection on the exact same Relative Position Embedding. Check the forward method!
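
As a quick sanity check, a standalone call to the encoder could look like this; the cfg object and the random tensors are stand-ins (in the full model they come from CFG and the Embedding module).

# Usage sketch: one rel_pos_emb tensor is shared by all 24 DeBERTaEncoderLayer blocks in a single forward pass
import torch

encoder = DeBERTaEncoder(cfg, max_seq=512, num_layers=24, dim_model=1024, num_attention_heads=16, dim_ffn=4096)

word_embeddings = torch.rand(2, 512, 1024)   # stand-in for the output of the word embedding layer
rel_pos_emb = torch.rand(2, 512, 1024)       # built once, then passed unchanged to every layer
padding_mask = torch.zeros(2, 512).bool()    # assumed convention: True marks padded positions

last_hidden_state, hidden_states = encoder(word_embeddings, rel_pos_emb, padding_mask)
print(last_hidden_state.shape)   # torch.Size([2, 512, 1024])
print(hidden_states.shape)       # torch.Size([24, 2, 512, 1024]); one slice per encoder layer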

🤖 DeBERTa

# Pytorch Implementation of DeBERTa
class DeBERTa(nn.Module, AbstractModel):
    """ Main class for DeBERTa, having all of sub-blocks & modules such as Disentangled Self-attention, DeBERTaEncoder, EMD
    Init Scale of DeBERTa Hyper-Parameters, Embedding Layer, Encoder Blocks, EMD Blocks
    And then make 3-types of Embedding Layer, Word Embedding, Absolute Position Embedding, Relative Position Embedding

    Args:
        cfg: configuration.CFG
        num_layers: number of EncoderLayer, default 12 for the base model
        this value must be set by the user to match the target task
        if you select ELECTRA-style pre-training, you need to set num_layers twice (generator, discriminator)

    Var:
        vocab_size: size of vocab in DeBERTa's Native Tokenizer
        max_seq: maximum sequence length
        max_rel_pos: max_seq x2 for build relative position embedding
        num_layers: number of Disentangled-Encoder layers
        num_attention_heads: number of attention heads
        num_emd: number of EMD layers
        dim_model: dimension of model
        dim_ffn: dimension of feed-forward network, same as intermediate size in official repo
        hidden_dropout_prob: dropout rate for embedding, hidden layer
        attention_probs_dropout_prob: dropout rate for attention

    References:
        https://arxiv.org/abs/2006.03654
        https://arxiv.org/abs/2111.09543
        https://github.com/microsoft/DeBERTa/blob/master/experiments/language_model/deberta_xxlarge.json
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/config.py
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/deberta.py
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/disentangled_attention.py
        https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/apps/models/masked_language_model.py
    """
    def __init__(self, cfg: CFG, num_layers: int = 12) -> None:
        super(DeBERTa, self).__init__()
        # Init Scale of DeBERTa Module
        self.cfg = cfg
        self.vocab_size = cfg.vocab_size
        self.max_seq = cfg.max_seq
        self.max_rel_pos = 2 * self.max_seq
        self.num_layers = num_layers
        self.num_attention_heads = cfg.num_attention_heads
        self.num_emd = cfg.num_emd
        self.dim_model = cfg.dim_model
        self.dim_ffn = cfg.dim_ffn
        self.layer_norm_eps = cfg.layer_norm_eps
        self.hidden_dropout_prob = cfg.hidden_dropout_prob
        self.attention_dropout_prob = cfg.attention_probs_dropout_prob
        self.gradient_checkpointing = cfg.gradient_checkpoint

        # Init Embedding Layer
        self.embeddings = Embedding(cfg)

        # Init Encoder Blocks & Modules
        self.encoder = DeBERTaEncoder(
            self.cfg,
            self.max_seq,
            self.num_layers,
            self.dim_model,
            self.num_attention_heads,
            self.dim_ffn,
            self.layer_norm_eps,
            self.attention_dropout_prob,
            self.hidden_dropout_prob,
            self.gradient_checkpointing
        )
        self.emd_layers = [self.encoder.layer[-1] for _ in range(self.num_emd)]
        self.emd_encoder = EnhancedMaskDecoder(
            self.cfg,
            self.emd_layers,
            self.dim_model,
            self.layer_norm_eps,
            self.gradient_checkpointing
        )

    def forward(self, inputs: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tuple[Tensor, Tensor]:
        """
        Args:
            inputs: input sequence, shape (batch_size, sequence)
            padding_mask: padding mask for MLM or padding token
            attention_mask: attention mask for CLM, default None
        """
        assert inputs.ndim == 2, f'Expected (batch, sequence) got {inputs.shape}'

        word_embeddings, rel_pos_emb, abs_pos_emb = self.embeddings(inputs)
        last_hidden_state, hidden_states = self.encoder(word_embeddings, rel_pos_emb, padding_mask, attention_mask)

        emd_hidden_states = hidden_states[-self.cfg.num_emd]
        emd_last_hidden_state, emd_hidden_states = self.emd_encoder(emd_hidden_states, abs_pos_emb, rel_pos_emb, padding_mask, attention_mask)
        return emd_last_hidden_state, emd_hidden_states
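
Finally, an end-to-end sketch of a forward pass. The CFG fields listed in the comment are assumptions inferred from the attributes read in __init__, and tokenization is omitted.

# End-to-end usage sketch (CFG fields are inferred from the attributes accessed in __init__)
import torch

cfg = CFG()   # assumed to expose vocab_size, max_seq, dim_model, num_attention_heads, num_emd, dim_ffn,
              # layer_norm_eps, hidden_dropout_prob, attention_probs_dropout_prob, gradient_checkpoint, ...
model = DeBERTa(cfg, num_layers=12)

input_ids = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq))   # already tokenized & padded sequences
padding_mask = torch.zeros(2, cfg.max_seq).bool()                # assumed convention: True marks padding

emd_last_hidden_state, emd_hidden_states = model(input_ids, padding_mask)
# emd_last_hidden_state: (batch, max_seq, dim_model) → fed to the MLM head during pre-training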
