🔭 Overview

RoFormer is a variant of the Transformer published in 2021 that proposes RoPE (Rotary Position Embedding), a new way of capturing positional information. It has been drawing renewed attention because well-known open-source LLMs (GPT-Neo, LLaMA) have adopted it as their positional encoding scheme. Before digging into RoPE itself, let's briefly go over the research landscape and the concept of positional information.

🤔 Absolute Position vs Relative Position

A big reason the Transformer succeeded is that it processes the entire sequence in parallel at once while injecting the order of the sequence by adding a position matrix to the input. Research on this topic has largely split into two camps: Absolute Position and Relative Position.

Absolute Position measures the length of the given sequence and assigns each token a number from 0 to length-1, in the order the tokens appear. In other words, it expresses the order in which each word occurs in the sequence mathematically and injects it into the model.
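
As a minimal sketch (PyTorch, with a made-up toy sentence), absolute position ids are nothing more than the running index of each token:

import torch

# toy whitespace-tokenized sequence (hypothetical example)
tokens = ["the", "deep", "learning", "model", "is", "fast"]
position_ids = torch.arange(len(tokens))   # forward numbering: 0 .. len-1
print(position_ids)                        # tensor([0, 1, 2, 3, 4, 5])
# each id is then mapped to a position vector (learned or sinusoidal) and added to the token embedding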

Relative Position, on the other hand, is a positional embedding technique that learns pairwise relations between tokens by representing the positional relationship between tokens inside the sequence. The relative relationship is usually expressed as the difference between the sequence indices of two tokens. To see what kind of contextual information this captures, let's use an example borrowed from the DeBERTa paper. Take the phrase Deep Learning: put together, the two words mean a kind of machine learning technique that uses neural networks, but taken separately they just carry the individual meanings deep and learning.

  • 1) The Deep Learning is the Best Technique in Computer Science
  • 2) Iโ€™m learning how to swim in the deep ocean

Now interpret the two sentences while paying attention to the relative distance between Deep and Learning. In the first sentence the two words sit next to each other and together produce the meaning a kind of machine learning technique that uses neural networks. In the second sentence they are separated by five whitespace tokens and carry their individual meanings, learning and deep. Relative Position Embedding is designed precisely to capture the contextual information that arises from these positional relationships between individual tokens.
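
To make the "difference of sequence indices" concrete, here is a tiny sketch (PyTorch, second example sentence assumed to be whitespace-tokenized) of the pairwise relative-distance matrix that relative-position schemes are built on:

import torch

# 2) I'm(0) learning(1) how(2) to(3) swim(4) in(5) the(6) deep(7) ocean(8)
pos = torch.arange(9)
rel_dist = pos[None, :] - pos[:, None]   # rel_dist[i, j] = j - i
print(rel_dist[1, 7])                    # tensor(6): index difference between "learning" and "deep" (five tokens in between)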

🤔 Word Context vs Relative Position vs Absolute Position

[Image: people standing in line]

So far we've covered what Relative Position Embedding is and what kind of contextual information it is supposed to capture. My explanation so far has been text-only, so it may still be hard to intuit what word context actually is, how it differs from position information, and how the two kinds of position information differ from each other. Let's go through a more concrete example that contrasts the three.

Five people are standing in line for airport check-in. Everyone is facing left, so the short woman on the far left is at the front of the line. We will number the five people in the order they are standing, from 0 to 4 for convenience. Who is number 1? The woman standing second in line. Who is number 2? The man in the middle of the line in the photo. Assigning each individual a running number within the group (the whole line) to express position is exactly what Absolute Position Embedding does.

Now look at person number 2 again. Instead of calling him the person in the middle of the line, we could describe him as the man in a black suit and dress shoes, staring at something in his hand. That description corresponds to word context, the semantic content of a token.

Finally, let's describe man number 2 the Relative Position Embedding way: the person behind woman number 1, who holds a coffee in her right hand, grips a suitcase with the other, and wears black high heels and beige pants; the person in front of woman number 4, who wears a gray suit and black horn-rimmed glasses and holds a suitcase in one hand; the person standing two places behind the woman at the very front of the line, who wears a black jacket and jeans and carries a gray coat in one hand; the person two places in front of the man with the long beard and longish hair, who wears a blue cardigan and carries a green-and-black bag on his left shoulder.

์ด์ฒ˜๋Ÿผ ํ‘œํ˜„ํ•˜๋Š”๊ฒŒ ๋ฐ”๋กœ Relative Position Embedding์— ๋Œ€์‘๋œ๋‹ค๊ณ  ๋ณผ ์ˆ˜ ์žˆ๋‹ค. ์ด์ œ ์œ„์น˜ ์ž„๋ฒ ๋”ฉ์— ๋Œ€ํ•ด์„œ ์‚ดํŽด๋ดค์œผ๋‹ˆ, ๋…ผ๋ฌธ์—์„œ ์ œ์‹œํ•˜๋Š” ๋‚ด์šฉ์— ๋Œ€ํ•ด์„œ ์•Œ์•„๋ณด์ž.

🗂️ Previous Work: Relative Position Embedding

To give away the conclusion: RoPE captures relative position. Before introducing their own method, the authors therefore review how previous work captured relative position. Let's take a quick look.

\[q^T_m k_n = x^T_m W^T_q W_k x_n + x^T_m W^T_q W_k p_n + p^T_m W^T_q W_k x_n + p^T_m W^T_q W_k p_n \ \ \ (1)\]

\[q^T_m k_n = x^T_m W^T_q W_k x_n + x^T_m W^T_q W_k \tilde{p}_{m-n} + \tilde{p}^T_{m-n} W^T_q W_k x_n \ \ \ (2)\]

Equation (1) is the attention-score expansion presented in the Transformer-XL paper: separate terms carrying positional information are created and multiplied with the terms corresponding to the query and key. Equation (2) is the disentangled attention proposed for DeBERTa. Its composition differs from (1), but the idea is the same: extra terms are introduced to carry positional information, multiplied with the query or key, and everything is then summed to produce the attention matrix.

In short, previous work builds separate position matrices to capture relative position, multiplies them here and there, and then sums everything to form the attention matrix. The common problem with these methods is that they add trainable parameters, which makes the model bigger and training slower.
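
As a rough illustration of that overhead (a generic learned relative-position table in the spirit of these approaches, not the exact parameterization of either paper), the extra trainable weights grow with the maximum relative distance and the head dimension, and are then typically replicated per layer:

import torch.nn as nn

max_rel_dist, dim_head = 512, 64
# one learned vector per relative offset in [-max_rel_dist, +max_rel_dist]
rel_pos_table = nn.Embedding(2 * max_rel_dist + 1, dim_head)
print(sum(p.numel() for p in rel_pos_table.parameters()))  # 65600 extra weights, before any per-layer copies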

🎡 RoPE

\[f_{q,k}(x_m, m)= \left( \begin{array}{cc}\cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta)\end{array} \right) \left( \begin{array}{cc}W^{(11)}_{q,k} & W^{(12)}_{q,k} \\ W^{(21)}_{q,k} & W^{(22)}_{q,k} \end{array} \right) \left( \begin{array}{c}x_m^{(1)} \\ x_m^{(2)} \end{array} \right)\]

The left-hand side denotes the query or key vector, obtained by linearly projecting the word embedding, with Rotary Position Embedding applied. The right-hand side looks fairly intimidating, but it is actually very simple: multiply the projected query/key vector by the odd-looking matrix on the left. That matrix is the rotation (transformation) matrix you may have brushed past in an undergraduate linear algebra course. $m$ is the index of the $m$-th token, so whatever $\theta$ turns out to be, the idea is to rotate the given embedding vector by an angle that depends on the token index. The example above is a toy vector with a 2-dimensional hidden size; before extending it to the dimensions used in real models (384, 512, 768, ...), let's figure out what $\theta$ actually is.
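
Before that, a quick numeric sketch of the 2-D case (toy angle and vector assumed): the projected query is simply rotated by an angle proportional to its token index m, and its length never changes.

import math
import torch

theta = 0.5                      # toy angle; the real theta_i values are defined right below
q = torch.tensor([1.0, 0.0])     # stands in for the 2-D projected query W_q x_m
for m in range(3):               # token index m = 0, 1, 2
    rot = torch.tensor([[math.cos(m * theta), -math.sin(m * theta)],
                        [math.sin(m * theta),  math.cos(m * theta)]])
    q_m = rot @ q                # rotate by m * theta: direction changes, norm stays 1.0
    print(m, q_m, q_m.norm())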

\[\Theta = \left\{ \theta_i = 10000^{ -{2(i-1)}/{d}}, \quad i \in \left[1, 2, \ldots, \frac{d}{2}\right] \right\}\]

$\theta$์˜ ์ •์ฒด๋Š” ๋ฐ”๋กœ ์ฃผ๊ธฐํ•จ์ˆ˜ ์˜€๋‹ค. ํ“จ์–ดํ•œ ํŠธ๋žœ์Šคํฌ๋จธ์—์„œ Absolute Position Encoding์„ ์œ„ํ•ด Sinusoidal ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•œ ๊ฒƒ๊ณผ ๊ฐ™์€ ์ด์น˜๋ผ๊ณ  ์ƒ๊ฐํ•˜๋ฉด ๋œ๋‹ค. ์ฆ‰ $\theta$๋Š” word embedding ๋ฒกํ„ฐ๊ฐ€ ๊ฐ€์ง„ ์€๋‹‰์ธต ์ฐจ์› ๋ฐฉํ–ฅ ์ธ๋ฑ์Šค์— ๋”ฐ๋ผ์„œ ๋‹ฌ๋ผ์ง„๋‹ค. ์—ฌ๊ธฐ์— ์‹œํ€€์Šค ๊ธธ์ด ์ฐจ์› ๋ฐฉํ–ฅ์˜ ์ธ๋ฑ์Šค ๊ฐ’์„ ๋”ฐ๋กœ ๊ณฑํ•ด์ฃผ๊ธฐ ๋•Œ๋ฌธ์— ๊ทธ ์œ ์ผ์„ฑ์„ ๋ณด์žฅํ•  ์ˆ˜ ์žˆ๋‹ค.

That's all the machinery we need to understand RoPE in full. Let's now extend it to real dimensions.

\[f_{q,k}(x_m, m) = R^d_{\Theta,m} W_{q,k} x_m \ \ \ (3)\]

The matrix $R^d_{\Theta,m}$ refers to the following matrix:

\[R^d_{\Theta,m} = \begin{bmatrix} \cos(m\theta_1) & -\sin(m\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ \sin(m\theta_1) & \cos(m\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos(m\theta_2) & -\sin(m\theta_2) & \cdots & 0 & 0 \\ 0 & 0 & \sin(m\theta_2) & \cos(m\theta_2) & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos(m\theta_{d/2}) & -\sin(m\theta_{d/2}) \\ 0 & 0 & 0 & 0 & \cdots & \sin(m\theta_{d/2}) & \cos(m\theta_{d/2}) \end{bmatrix} \ \ \ (4)\]

ํ† ํฐ์˜ ์ธ๋ฑ์Šค์™€ ๋ชจ๋ธ์˜ ์€๋‹‰์ฐจ์› ์ธ๋ฑ์Šค์— ๋”ฐ๋ผ์„œ ํ–‰๋ ฌ์˜ ์›์†Œ๊ฐ’์ด ๊ฒฐ์ •๋จ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค. ์ด์ œ ๋‹ค์‹œ (3)๋ฒˆ ์ˆ˜์‹์˜ ์˜๋ฏธ๋ฅผ ์ƒ๊ฐํ•ด๋ณด์ž. ๋‹จ์–ด ์ž„๋ฒ ๋”ฉ์„ ์ฟผ๋ฆฌ, ํ‚ค ํ–‰๋ ฌ๋กœ ์„ ํ˜• ํˆฌ์˜ํ•œ ๋’ค (4)๋ฒˆ ์ˆ˜์‹์„ ๊ณฑํ•œ๋‹ค. ์ˆœ์ˆ˜ํ•œ ํšŒ์ „ํ–‰๋ ฌ์„ ์ฟผ๋ฆฌ, ํ‚ค ๋ฒกํ„ฐ์— ๊ณฑํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ฒกํ„ฐ์˜ ํฌ๊ธฐ๋ฅผ ์œ ์ง€ํ•œ์ฑ„, ๋ฐฉํ–ฅ๋งŒ ๋ฐ”๊ฟ”์ค„ ์ˆ˜ ์žˆ๋‹ค๋Š” ์žฅ์ ์ด ์žˆ๋‹ค.

์ด์ „์˜ ์—ฐ๊ตฌ๋“ค์€ ํฌ์ง€์…˜ ์ •๋ณด๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋Š” ํ–‰๋ ฌ์„ ๋‹จ์–ด ๋ฒกํ„ฐ์— ๋”ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ฒกํ„ฐ์˜ ๋ฐฉํ–ฅ์€ ๋ฌผ๋ก  ํฌ๊ธฐ ์—ญ์‹œ ์™œ๊ณก๋œ๋‹ค. ๋ฌผ๋ก  ๋‹จ์–ด ๋ฒกํ„ฐ์™€ ํฌ์ง€์…˜ ๋ฒกํ„ฐ๊ฐ€ ์„œ๋กœ ์„ฑ๊ฒฉ์ด ๋‹ค๋ฅธ ์ •๋ณด๋ผ๋Š” ์ ์„ ๊ณ ๋ คํ•˜๋ฉด ๋ชจ๋ธ์˜ ์€๋‹‰์ธต์ฒ˜๋Ÿผ ๊ณ ์ฐจ์› ๊ณต๊ฐ„์—์„œ ์„œ๋กœ ์ง๊ตํ•  ํ™•๋ฅ ์ด ๋งค์šฐ ๋†’๊ธฐ ๋•Œ๋ฌธ์—, ์„œ๋กœ ํ•™์Šต์— ์˜ํ–ฅ์„ ๋ฏธ์น  ๊ฐ€๋Šฅ์„ฑ์€ ๋‚ฎ๋‹ค. ํ•˜์ง€๋งŒ ํ™•๋ฅ ์ ์ธ ์ ‘๊ทผ์ผ ๋ฟ๋”๋Ÿฌ, ๋‹จ์–ด ๋ฒกํ„ฐ์˜ ํฌ๊ธฐ๊ฐ€ ์™œ๊ณก๋œ๋‹ค๋Š” ์ ์ด ์ธต์„ ๊ฑฐ๋“ญํ• ์ˆ˜๋ก ์˜ํ–ฅ์„ ๋ฏธ์น ์ง€ ์•Œ ์ˆ˜ ์—†๋‹ค.

Another advantage of RoPE is that multiplication alone is enough to encode relative position information. Most earlier work picked either absolute or relative position and injected only that one into the word embeddings. DeBERTa alone tried to add absolute position near the task layers (toward the end of the stack) to compensate for the weaknesses of a purely relative scheme. DeBERTa does perform remarkably well on many benchmarks, but adding absolute position only near the last layers never struck me as particularly natural. RoPE, on the other hand, encodes both absolute and relative position simply by multiplying with a rotation matrix. How is that possible?

First, RoPE multiplies the linearly projected query and key matrices by their rotation matrices. In doing so, a different positional value, determined by each token's index, is already multiplied into the word embedding, which takes care of the absolute position information. Then, as usual, the dot product between query and key is computed. Writing that dot product out in terms of the word embedding, the linear projection, and the rotation matrix gives:

\[q^T_m k_n = (R^d_{\Theta,m} W_{q} x_m)^T (R^d_{\Theta,n} W_{k} x_n) \ \ \ (5)\]

์ˆ˜์‹์„ ์ „๊ฐœํ•˜๋ฉด ์ž์—ฐ์Šค๋ ˆ,

\[q^T_m k_n = x^T_m W^T_q R^d_{\Theta,n-m} W_k x_n \ \ \ (6)\]

which is equation (6). Each entry of the matrix $R^d_{\Theta,n-m}$ has the form

\[\cos(m\theta_1)\cos(n\theta_1) + \sin(m\theta_1)\sin(n\theta_1) = \cos((n-m)\theta_1)\]

ํ† ํฐ ์ธ๋ฑ์Šค๋ฅผ ์˜๋ฏธํ•˜๋Š” $m,n$์— ๋Œ€ํ•œ ์ˆ˜์‹์œผ๋กœ ํ‘œํ˜„๋œ๋‹ค. ๋”ฐ๋ผ์„œ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ƒ๋Œ€ ์œ„์น˜๋ฅผ ํฌ์ฐฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋œ๋‹ค. ์ƒ๋‹นํžˆ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์„œ๋กœ ๋‹ค๋ฅธ ๋‘ ์œ„์น˜ ์ •๋ณด๋ฅผ ์ธ์ฝ”๋”ฉํ•˜๋Š”๊ฒŒ ๊ฐ€๋Šฅํ•˜๋ฉฐ, ์ถ”๊ฐ€์ ์œผ๋กœ ๋‹ค๋ฅธ ํ•ญ์„ ๋งŒ๋“ค์–ด ์–ดํ…์…˜ ํ–‰๋ ฌ์„ ๊ณ„์‚ฐํ•˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ข€ ๋” ํšจ์œจ์ ์œผ๋กœ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜๋‹ค.

ํ•œํŽธ, ํ† ํฐ์˜ ์ƒ๋Œ€ ์œ„์น˜๋ฅผ ํฌ์ฐฉํ•˜๋Š” ๋ฐฉ์‹์€ ์ž์‹ ๊ณผ ์ƒ๋Œ€์  ๊ฑฐ๋ฆฌ๊ฐ€ ๋ฉ€์–ด์งˆ์ˆ˜๋ก ์˜๋ฏธ์  ์—ฐ๊ด€์„ฑ์ด๋‚˜ ๊ด€๊ณ„์„ฑ์ด ๋–จ์–ด์ง„๋‹ค๋Š” ์ ์„ ์ „์ œ๋กœ ํ•œ๋‹ค. ์ฆ‰, ์„œ๋กœ ๊ฑฐ๋ฆฌ๊ฐ€ ๋จผ ํ† ํฐ์ผ์ˆ˜๋ก ์ฟผ๋ฆฌ์™€ ํ‚ค๋ฒกํ„ฐ์˜ ๋‚ด์ ๊ฐ’์ด 0์— ๊ฐ€๊นŒ์›Œ์ ธ์•ผ ํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์ €์ž ์—ญ์‹œ ์ด์ ์„ ์–ธ๊ธ‰ํ•˜๋ฉฐ RoPE ๋ฐฉ์‹์ด Long-Term Decay ์†์„ฑ์„ ๊ฐ–๊ณ  ์žˆ๋‹ค๊ณ  ์ฃผ์žฅํ•œ๋‹ค.

[Figure: Long-Term Decay]

The appendix even gives a mathematical proof, but I haven't fully worked through the derivation yet and will add the details later. The paper never properly explains what the "relative upper bound" in the plot is (my guess is that it is a measure of semantic relatedness, probably based on the query-key inner product), but the plot clearly shows the quantity decreasing as the relative distance grows.

Finally, the paper notes that implementing RoPE literally in the form of equations (4) and (5) is computationally inefficient, so the appendix presents a more efficient way to compute it:

\[R^d_{\Theta,m}x = \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_{d} \end{bmatrix} \otimes \begin{bmatrix} \cos(m\theta_1) \\ \cos(m\theta_1) \\ \cos(m\theta_2) \\ \cos(m\theta_2) \\ \vdots \\ \cos(m\theta_{d/2}) \\ \cos(m\theta_{d/2}) \end{bmatrix} + \begin{bmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_{d} \\ x_{d-1} \end{bmatrix} \otimes \begin{bmatrix} \sin(m\theta_1) \\ \sin(m\theta_1) \\ \sin(m\theta_2) \\ \sin(m\theta_2) \\ \vdots \\ \sin(m\theta_{d/2}) \\ \sin(m\theta_{d/2}) \end{bmatrix}\]
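
A short sketch (toy d = 8, single position m = 5) confirming that this element-wise form produces exactly the same result as multiplying by the full sparse rotation matrix of equation (4):

import torch

d, m = 8, 5
i = torch.arange(1, d // 2 + 1, dtype=torch.float)
angles = m * 10000.0 ** (-2 * (i - 1) / d)                 # m * theta_i

# full sparse rotation matrix of equation (4)
R = torch.zeros(d, d)
for k, a in enumerate(angles):
    c, s = torch.cos(a), torch.sin(a)
    R[2 * k, 2 * k], R[2 * k, 2 * k + 1] = c, -s
    R[2 * k + 1, 2 * k], R[2 * k + 1, 2 * k + 1] = s, c

x = torch.randn(d)

# element-wise form from the appendix
cos = torch.cos(angles).repeat_interleave(2)               # (cos m*theta_1, cos m*theta_1, cos m*theta_2, ...)
sin = torch.sin(angles).repeat_interleave(2)
x_rot = torch.stack([-x[1::2], x[0::2]], dim=-1).flatten() # (-x_2, x_1, -x_4, x_3, ...)

print(torch.allclose(R @ x, x * cos + x_rot * sin))        # True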

์ˆ˜์‹ (4), (5)๋ฒˆ ํ˜•ํƒœ ๊ทธ๋Œ€๋กœ ๊ตฌํ˜„ํ•˜๋ ค๋ฉด, ํฌ๊ธฐ๊ฐ€ [seq_len, dim_head, dim_head]์ธ ํ…์„œ๋ฅผ ๊ณ„์† ๊ฐ€์ง€๊ณ  ์žˆ์–ด์•ผ ํ•œ๋‹ค. ์ด๋Š” ์ƒ๋‹นํžˆ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋‚ญ๋น„ํ•˜๊ฒŒ ๋œ๋‹ค. ์•„๋ž˜ ๊ทธ๋ฆผ์€ ํ•„์ž๊ฐ€ (4), (5)๋ฒˆ ํ˜•ํƒœ ๊ทธ๋Œ€๋กœ ๊ตฌํ˜„ํ•œ ๋’ค, MLM ํ•™์Šต์„ ๋Œ๋ฆฌ๋˜ ๋ชจ์Šต์ด๋‹ค.

[Image: body version result]

The estimated training time is 11 hours 40 minutes. The main reason is that the memory taken up by \(R^d_{\Theta,m}x\) grew so large that it could no longer fit on the GPU all at once, forcing a per-batch loop that multiplies RoPE into each query and key individually. Implementing RoPE the way the appendix suggests instead gives:

[Image: appendix version result]

The estimated time drops dramatically to about 4 hours. This version also needs only a tensor of shape [seq_len, dim_head] for $R^d_{\Theta,m}$, so it uses far less memory than before, and since there is no longer any need to loop over the batch dimension, training time shrinks accordingly.
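
A back-of-the-envelope comparison of the two representations (assuming max_seq = 512, dim_head = 64, fp32; these numbers are illustrative, not my exact training config):

seq_len, dim_head, fp32_bytes = 512, 64, 4

naive = seq_len * dim_head * dim_head * fp32_bytes   # one full rotation matrix per position
elementwise = 2 * seq_len * dim_head * fp32_bytes    # just the cos / sin tables
print(naive / 2 ** 20, "MiB")                        # 8.0 MiB, before broadcasting over batch and heads
print(elementwise / 2 ** 20, "MiB")                  # 0.25 MiB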

๐Ÿ“ RoPE with linear attention

์ €์ž๋Š” ํ“จ์–ดํ•œ full attention ๋Œ€์‹  <Transformers are RNNs: Fast Autoregressive Transformers with linear attention> ๋…ผ๋ฌธ์—์„œ ์ œ์‹œ๋œ linear attention ์„ ์‚ฌ์šฉํ–ˆ๋‹ค๊ณ  ๋ฐํžˆ๊ณ  ์žˆ๋‹ค.

However, linear attention is a scheme suited to causal language modeling in decoders and is not a great fit for an NLU encoder. That paper also reports benchmark results only for NLG tasks. And when I implemented it myself and ran MLM (see the linked experiment results), accuracy came out considerably lower. To be fair, the method was designed with RNN-style, step-by-step processing of the Transformer in mind, so dropping linear attention into an encoder model like BERT may simply be a mismatch from the start. In any case, Hugging Face's roformer code also integrates RoPE into full attention rather than linear attention, so I implemented my model with full attention as well.
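
For reference, here is a minimal sketch of what non-causal (encoder-style) linear attention with the elu(x) + 1 feature map looks like; the shapes follow the attention code below, but this is an illustration, not necessarily identical to the linear_attention function referenced in my experiments.

import torch
import torch.nn.functional as F

def linear_attention_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """ Non-causal linear attention: softmax(QK^T)V is replaced by phi(Q)(phi(K)^T V) with phi(x) = elu(x) + 1
    q, k, v: [batch_size, num_head, seq_len, dim_head]
    """
    q, k = F.elu(q) + 1, F.elu(k) + 1                                 # positive feature map phi(.)
    kv = torch.einsum('bhnd,bhne->bhde', k, v)                        # phi(K)^T V, summed over the sequence
    z = 1.0 / (torch.einsum('bhnd,bhd->bhn', q, k.sum(dim=2)) + eps)  # normalizer phi(Q) phi(K)^T 1
    return torch.einsum('bhnd,bhde,bhn->bhne', q, kv, z)

q = k = v = torch.randn(2, 4, 16, 8)                                  # toy rotary q, k and plain v
print(linear_attention_sketch(q, k, v).shape)                         # torch.Size([2, 4, 16, 8])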

👩‍💻 Implementation by Pytorch

๋…ผ๋ฌธ์˜ ๋‚ด์šฉ๊ณผ ์˜คํ”ผ์…œ๋กœ ๊ณต๊ฐœ๋œ ์ฝ”๋“œ๋ฅผ ์ข…ํ•ฉํ•˜์—ฌ ํŒŒ์ดํ† ์น˜๋กœ Roformer๋ฅผ ๊ตฌํ˜„ํ•ด๋ดค๋‹ค. ๋‹ค๋งŒ, linear attention ๋Œ€์‹  full attention์„ ์‚ฌ์šฉํ–ˆ๊ณ  ์˜ค์ง ์ธ์ฝ”๋” ๋ถ€๋ถ„๋งŒ ๊ตฌํ˜„ํ–ˆ์Œ์„ ๋ฐํžŒ๋‹ค.

I also have my own from-scratch RoPE implementation, but I failed to get the GPU computation properly optimized, so the code here follows Hugging Face's implementation instead; I'll attach my own RoPE code when I find the time. This post covers only how to implement RoPE itself and skips the rest of the model; see the link here for the full model code.

🎡 Rotary Position Embedding

Pay attention to position_enc inside _init_weight(). It takes the position and the dimension as inputs and builds position_enc from them, and this is the core of RoPE: that line computes exactly the angles $m\theta_i$.

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple

from configuration import CFG  # the project's config object (assumed import path; see "cfg: configuration.py" in the docstring below)


class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
    """ This module produces sinusoidal positional embeddings of any length
    Original Source code from Huggingface's RoFormer model, which is the most optimized way to create positional embedding

    Args:
        max_seq: max sequence length of model
        dim_head: dimension of each attention head's hidden states

    Returns:
        Tensor -> torch.Size([seq_len, dim_head])

    References:
        https://arxiv.org/abs/2104.09864  # RoFormer: Enhanced Transformer with Rotary Position Embedding
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py#L323
    """

    def __init__(self, max_seq: int, dim_head: int) -> None:
        super().__init__(max_seq, dim_head)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )  # m * theta
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, seq_len: int, past_key_values_length: int = 0) -> Tensor:
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)


class Embedding(nn.Module):
    """ Class module for Roformer Embedding, word embedding & rotary positional encoding
    This module has option => whether or not to use ALBERT Style Factorized Embedding

    Args:
        cfg: configuration.py

    References:
        https://arxiv.org/abs/1706.03762
        https://arxiv.org/pdf/1810.04805.pdf
        https://arxiv.org/abs/2006.16236
        https://arxiv.org/abs/2104.09864  # RoFormer: Enhanced Transformer with Rotary Position Embedding
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py
        https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
    """
    def __init__(self, cfg: CFG) -> None:
        super(Embedding, self).__init__()
        self.cfg = cfg
        self.batch_size = cfg.batch_size
        self.max_seq = cfg.max_seq
        self.dim_model = cfg.dim_model
        self.word_embedding = nn.Embedding(len(cfg.tokenizer), cfg.dim_model)
        self.layer_norm1 = nn.LayerNorm(cfg.dim_model, eps=cfg.layer_norm_eps)  # for word embedding
        self.hidden_dropout = nn.Dropout(p=cfg.hidden_dropout_prob)
        self.rotary_pos_encoding = RoFormerSinusoidalPositionalEmbedding(
            cfg.max_seq,
            cfg.dim_model // cfg.num_attention_heads
        )

        # ALBERT Style Factorized Embedding
        if self.cfg.is_mf_embedding:
            self.word_embedding = nn.Embedding(len(cfg.tokenizer), int(cfg.dim_model/6))
            self.projector = nn.Linear(int(cfg.dim_model/6), cfg.dim_model)  # project to original hidden dim

    def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
        if self.cfg.is_mf_embedding:
            word_embeddings = self.hidden_dropout(
                self.layer_norm1(self.projector(self.word_embedding(inputs)))
            )
        else:
            word_embeddings = self.hidden_dropout(
                self.layer_norm1(self.word_embedding(inputs))
            )
        rotary_pos_enc = self.rotary_pos_encoding(inputs.shape[1])
        return word_embeddings, rotary_pos_enc
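
A quick shape check for the rotary table above (the max_seq and dim_head values here are illustrative, not my training configuration):

rope = RoFormerSinusoidalPositionalEmbedding(max_seq=512, dim_head=64)
sinusoidal_pos = rope(seq_len=128)
print(sinusoidal_pos.shape)   # torch.Size([128, 64]); the first 32 dims hold sin(m*theta_i), the last 32 hold cos(m*theta_i)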

🔨 Integrated RoPE into Full Attention (scaled dot-product attention)

Full attention with RoPE is implemented in the following order. First, linearly project the word embeddings into the query, key, and value matrices. Then pass the query and key matrices to apply_rotary_position_embeddings() so RoPE can be multiplied in; keep in mind that the value matrix must remain the plain linear projection of the word embeddings. apply_rotary_position_embeddings() returns the query and key matrices with RoPE applied. Everything after that is identical to plain full attention.

์ธ์ž๋กœ ๋“ค์–ด๊ฐ€๋Š” ํ…์„œ๋“ค์˜ ๋ชจ์–‘์€ ์ฃผ์„์„ ์ฐธ๊ณ  ๋ฐ”๋ž€๋‹ค.

def apply_rotary_position_embeddings(sinusoidal_pos: Tensor, query_layer: Tensor, key_layer: Tensor, value_layer: Tensor = None):
    """ Apply rotary position encoding to query, key layer
    Original Source code from Huggingface's RoFormer model, which is the most optimized way to create positional embedding

    You can find mathematical proof in official paper's Appendix

    Args:
        sinusoidal_pos: sinusoidal positional encoding, shape [batch(None), num_dim(None), seq_len, dim_head]
        query_layer: query matrix, shape (batch_size, num_head, seq_len, dim_head)
        key_layer: key matrix, shape (batch_size, num_head, seq_len, dim_head)
        value_layer: value matrix, shape (batch_size, num_head, seq_len, dim_head)

    References:
        https://arxiv.org/abs/2104.09864  # RoFormer: Enhanced Transformer with Rotary Position Embedding
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py#L323
    """
    sin, cos = sinusoidal_pos.chunk(2, dim=-1)  # first half of the last dim holds sin(m*theta_i), second half cos(m*theta_i) (see _init_weight above)
    sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)

    cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
    rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
        query_layer
    )

    # element-wise form from the paper's appendix (see the equation earlier in the post)
    query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
    rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
    key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos

    if value_layer is not None:  # the official implementation does not rotate value_layer
        rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
            value_layer
        )
        value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
        return query_layer, key_layer, value_layer
    return query_layer, key_layer

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        dim_model: int = 1024,
        num_attention_heads: int = 16,
        dim_head: int = 64,
        kernel: str = 'softmax',
        attention_dropout_prob: float = 0.1
    ) -> None:
        super(MultiHeadAttention, self).__init__()
        self.dim_model = dim_model
        self.num_attention_heads = num_attention_heads
        self.dim_head = dim_head
        self.fc_q = nn.Linear(self.dim_model, self.dim_model)
        self.fc_k = nn.Linear(self.dim_model, self.dim_model)
        self.fc_v = nn.Linear(self.dim_model, self.dim_model)
        self.fc_concat = nn.Linear(self.dim_model, self.dim_model)
        self.apply_rope = apply_rotary_position_embeddings
        self.attention = scaled_dot_product_attention if kernel == 'softmax' else linear_attention
        self.attention_dropout = nn.Dropout(p=attention_dropout_prob)
        self.dot_scale = torch.sqrt(torch.tensor(self.dim_head, dtype=torch.float32))
        self.kernel = kernel
        self.eps = 1e-6
    
    def forward(self, x: Tensor, rotary_pos_enc: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tensor:
        """ x is already passed nn.Layernorm, already multiplied with rotary position encoding """
        assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'

        # linear projection, then reshape/permute to [batch_size, num_head, seq_len, dim_head]
        q = self.fc_q(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
        k = self.fc_k(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
        v = self.fc_v(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()

        # multiply rotary position encoding into the projected query, key
        rotary_q, rotary_k = self.apply_rope(rotary_pos_enc, q, k)

        attention_matrix = None
        if self.kernel == 'elu':
            attention_matrix = self.attention(
                rotary_q,
                rotary_k,
                v,
                self.kernel,
                self.eps,
                self.attention_dropout,
                padding_mask,
                attention_mask
            )
        elif self.kernel == 'softmax':  # pure self-attention
            attention_matrix = self.attention(
                rotary_q,
                rotary_k,
                v,
                self.dot_scale,
                self.attention_dropout,
                padding_mask,
                attention_mask
            )

        attention_output = self.fc_concat(attention_matrix)
        return attention_output
