🔭 Overview

The Transformer is a neural network for natural language processing that Google presented at NIPS in 2017. It was designed to solve the problems of the existing RNN family (LSTM, GRU) and to model, as closely as possible, the way humans understand natural language in mathematical form. The model was originally conceived as a seq2seq architecture with both an encoder and a decoder, and it drew attention by achieving SOTA on various translation tasks. Since then, as you probably know, it has been adopted as the baseline for BERT, GPT, and ViT, and is regarded as a model that marked a turning point in the history of modern deep learning.

With what ideas did the Transformer, which opened the golden age of modern deep learning, solve the problems of the recurrent family? To understand this properly, we first need to go over the problems that conventional recurrent neural network models had.

🤔 Limitation of Recurrent Structure

  • 1) Vanishing gradients that arise by a mechanism unlike human understanding (Activation Function with Backward)
  • 2) Attention on inputs that grow increasingly blurred (Activation Function with Forward)
  • 3) The decoder denoises while looking mostly at the last word (Seq2Seq with Bi-Directional RNN)

📈 1) Vanishing Gradients from a Mechanism Unlike Human Understanding (Activation Function with Backward)

\[h_t = \tanh(x_tW_x + h_{t-1}W_h + b)\]

The hyperbolic tangent, the RNN's activation function, produces $y$ values in [-1, 1] and its derivative has a maximum of 1. Information from earlier time steps therefore has its gradient shrink the further back it lies (the more cells it must pass through), so it contributes very little to learning at later time steps. This is the famous vanishing gradient phenomenon of RNNs. The occurrence of the phenomenon itself is not actually a big problem. The reason vanishing gradients in RNNs become a problem is that they arise through a mechanism different from how humans understand language. Think about how we read: to grasp a word's meaning we sometimes use the context of nearby words, and sometimes the context of a paragraph far away. The imbalance in how much each part influences the current step does not come from relationships between the elements of the sequence or from any other semantic reason; it arises merely from the time step at which the input arrived. That is why the RNN's vanishing gradient is singled out as a cause of poor performance.

In other words, the gradient does not reflect the actual linguistic context; it reflects influence purely as a function of time step. A recurrent structure cannot properly learn cases that require the context of a far-away part of the sequence.
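To make the "distance, not meaning" point concrete, here is a minimal sketch with toy weights and random inputs (my own illustration, not the paper's setup): the gradient flowing back through T tanh cells is scaled by a product of local tanh derivatives, each at most 1, so it shrinks with every extra step it has to cross.

# Toy sketch: gradient decay through stacked tanh cells

import torch

torch.manual_seed(0)
T, d = 30, 8
W_h = torch.randn(d, d) * 0.5          # assumed toy recurrent weight
h = torch.zeros(d)
local_grads = []
for t in range(T):
    pre = torch.randn(d) + h @ W_h     # stand-in for x_t W_x + h_{t-1} W_h + b
    h = torch.tanh(pre)
    local_grads.append(float((1 - h ** 2).mean()))  # tanh'(pre) = 1 - tanh(pre)^2 <= 1

surviving = torch.cumprod(torch.tensor(local_grads), dim=0)
print(surviving[::10])  # the surviving gradient fraction only gets smaller with distance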

So could we solve this by switching the activation to ReLU or GELU? That might fix the vanishing gradient, but the hidden-state values would blow up instead. Both activations are linear over the positive range, and given the RNN structure, which keeps accumulating past information, multiplying it by the weights, and adding the current input, the carried-over information grows as it accumulates and eventually diverges.
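A quick sketch of that blow-up under the same toy recurrence (assumed weights, nothing tuned): with ReLU there is no squashing, so the hidden-state norm tends to keep growing.

# Toy sketch: hidden state divergence when tanh is swapped for ReLU

import torch

torch.manual_seed(0)
T, d = 30, 8
W_h = torch.randn(d, d)                # assumed toy recurrent weight, no careful init
h = torch.zeros(d)
for t in range(T):
    h = torch.relu(torch.randn(d) + h @ W_h)
    if t % 10 == 0:
        print(t, float(h.norm()))      # the norm typically grows by orders of magnitude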

In conclusion, the vanishing gradient phenomenon itself is not the problem; it is blamed for poor performance because the imbalance arises simply as a function of time step rather than from the model reflecting linguistic context in the gradient. This is also referred to as the long-term dependency problem.

✏️ 2) Attention on Increasingly Blurred Inputs (Activation Function with Forward)

Figure: tanh function

We said the hyperbolic tangent produces $y$ values in [-1, 1]. In other words, a cell's output is always confined to that fixed range (setting aside the added weights and bias). The outputs are mapped into a narrow, bounded range, which means most of the input's information is lost and only some refined features are passed forward to the next layer. Look at the graph: once the input exceeds roughly 2.5, the output is already close to 1, so the differences become hard to tell apart. Stack dozens or hundreds of such activations and the original information becomes so blurred and smeared that it is hard to distinguish one instance from another.
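A two-line check of that saturation claim (toy values only, nothing from the paper):

# Toy sketch: tanh saturation squashes differences between inputs

import torch

x = torch.tensor([2.5, 4.0, 8.0])
print(torch.tanh(x))     # all three outputs sit just below 1 — large input gaps nearly vanish

a, b = torch.tensor(0.3), torch.tensor(3.0)
for _ in range(10):
    a, b = torch.tanh(a), torch.tanh(b)
print(float(b - a))      # the gap between two very different inputs keeps shrinking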

🔬 3) The Decoder Denoises While Looking Mostly at the Last Word (Seq2Seq with Bi-Directional RNN)

To understand the meaning of the Korean word "쓰다" (token $t_7$; it can mean "to write", "to use", "to wear", or "to taste bitter"), you need to look at a distant earlier word such as "돈을" (money), "모자를" (a hat), "맛이" (the taste), or "글을" (a text, token $t_1$). But $t_1$ enters $h_7$ already blurred, so the true meaning of $t_7$ is never captured. Worse, if the language were English you would often need to look ahead to get the exact context, yet a vanilla RNN trains in only one direction and cannot reflect the later part of the sentence at all (the nuance of "쓰다" changes depending on the object that follows it). A bi-directional RNN is therefore needed, but it too remains governed by distance, so it is not a fundamental solution.

Meanwhile, the decoder's next-token-prediction performance is entirely at the mercy of the quality of the context vector it receives from the encoder. But, as described above, the context vector produced by a recurrent encoder is not of good quality (later words are disproportionately vivid). So the decoder's translation (next-word prediction) performance cannot be good either.

In the end, the recurrent structure itself has clear limitations and ends up operating in a way that differs from how humans use and understand language. LSTM and GRU alleviated the problem to some extent, but since they are recurrent by birth, they were never a fundamental solution. Let us now see how the Transformer solved the three problems above and earned its current standing.

🌟 Modeling

Figure: Attention Is All You Need

When explaining the vanishing gradient of recurrent structures, we noted that losing information as a function of time step is not how humans understand language. The Transformer therefore focuses on modeling, as closely as possible, the human way of understanding language in mathematical form. Think about what we do to understand a written text. To figure out whether the word "Apple" means the fruit or the brand, we look at the surrounding words in the same sentence, and if that is not enough, we try to grasp the context from the neighboring sentences or even the document as a whole. The Transformer's authors focused on exactly this process, modeled it, and arrived at the famous Self-Attention.

Figure: Word Embedding Space

In other words, Self-Attention is a mathematical expression of which words in the entire input sequence should be attended to in order to understand a token's meaning. More concretely, it is the act of training where in the embedding space each token vector (row vector) of the sequence should be placed.

Figure: Scaled Dot-Product Attention

Now let's see through what ideas the Transformer fixed the drawbacks of recurrent models and claimed the G.O.A.T. seat in deep learning. The model splits broadly into an encoder and a decoder; they differ only in role and in fine structural details, and the essence, that Self-Attention matters most, holds for both modules. We will therefore walk through everything in order starting from the Input Embedding, and for Self-Attention in particular we will examine every sub-block it uses, leaving nothing out.

Figure: Class Diagram

We will build up from the sub-modules and, at the end, interpret the model from an overall structural perspective together with the actual implementation code. Please read the post to the end.

🔬 Input Embedding

\[X_E \in R^{B * S_E * V_E} \\ X_D \in R^{B * S_D * V_D}\]

The Transformer has a seq2seq structure made up of an encoder and a decoder. Its goal is to translate a source language into a target language, so it needs both a source-language sequence and a target-language sequence as input. $X_E$ denotes the encoder's input matrix and $X_D$ the decoder's. Here $B$ is the batch size, $S$ is max_seq, and $V$ is the vocab size of the respective module. The paper does not actually write out a formula for the inputs, so the notation above is my own; I will keep using these symbols in the formulas that follow.

\[W_E \in R^{V_E * d} \\ W_D \in R^{V_D * d} \\\]

Passing these inputs through each module's embedding layer yields the Input Embedding. $d$ denotes the size of the Transformer's hidden layer. So, before being added to the Position Embedding, the Input Embedding that comes out of the embedding layer has the shape given below.

\[X_E \in R^{B*S_E*d} \\ X_D \in R^{B*S_D*d} \\\]

So how is this actually implemented? The Transformer's Input Embedding is defined with an nn.Embedding layer. nn.Linear exists too, so why bother with nn.Embedding?

When building input embeddings in NLP, the vocab predefined by the model's tokenizer is far larger than the number of tokens in any input sequence, so we use nn.Embedding, which works as a data lookup table. What this means is that the whole vocab defined by the tokenizer is projected by nn.Embedding(vocab_size, dim_model) into a lookup table with one row per vocab entry and one column per model dimension; since the tokens I feed in are only a subset of the whole vocab, the layer simply looks up the rows of that table at the token indices I want to embed. That is why nn.Embedding works even though the dimension defined in the layer does not match the dimension of the actual input. Apart from that condition on the input dimension, it does the same thing as nn.Linear, so if the predefined vocab size equaled the number of tokens in the input sequence you could just as well use nn.Linear.
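The lookup-table behavior is easy to verify with a toy example (the sizes below are made up): indexing an nn.Embedding row gives the same numbers as multiplying a one-hot vector with the same weight matrix, which is the sense in which it is "nn.Linear minus the input-dimension condition".

# Toy sketch: nn.Embedding as a lookup table equivalent to one-hot x weight

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, dim_model = 10, 4
emb = nn.Embedding(vocab_size, dim_model)

token_ids = torch.tensor([[1, 7, 3]])               # [batch=1, seq_len=3], a subset of the vocab
lookup = emb(token_ids)                             # [1, 3, 4], plain row lookup

one_hot = F.one_hot(token_ids, vocab_size).float()  # [1, 3, 10]
projected = one_hot @ emb.weight                    # same numbers via a matrix multiply

print(torch.allclose(lookup, projected))            # True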

# Input Embedding Example

import torch.nn as nn
from torch import Tensor


class Transformer(nn.Module):
    def __init__(
        self,
        enc_vocab_size: int,
        dec_vocab_size: int,
        max_seq: int = 512,
        enc_N: int = 6,
        dec_N: int = 6,
        dim_model: int = 512, # latent vector space
        num_heads: int = 8,
        dim_ffn: int = 2048,
        dropout: float = 0.1
    ) -> None:
        super(Transformer, self).__init__()
        self.enc_input_embedding = nn.Embedding(enc_vocab_size, dim_model) # Encoder Input Embedding Layer
        self.dec_input_embedding = nn.Embedding(dec_vocab_size, dim_model) # Decoder Input Embedding Layer
	
    def forward(self, enc_inputs: Tensor, dec_inputs: Tensor, enc_pad_index: int, dec_pad_index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        enc_x, dec_x = self.enc_input_embedding(enc_inputs), self.dec_input_embedding(dec_inputs)

Let's look at the example code above. self.enc_input_embedding and self.dec_input_embedding in __init__ correspond to $W_E$ and $W_D$. Meanwhile, enc_x and dec_x defined in the forward method correspond to $X_E$ and $X_D$ after passing through the embedding layers.

$X_E$ and $X_D$ then flow into the encoder and decoder modules respectively, are added (element-wise matrix sum) to the Absolute Position Embedding, and serve as each module's input.

🔢 Absolute Position Embedding(Encoding)

This maps positional information onto the input sequence. Personally, if asked to pick the most important elements of the Transformer, I would put this part in my top three. As the next part will describe in detail, Self-Attention (a dot product) has the advantage of processing the whole input sequence in parallel at once, but by itself it cannot encode token positions. Unless we supply the position information separately, the model has no way of knowing where in the input sequence the token behind, say, the second row vector of the query matrix is located.

Token order, however, is considered a crucial piece of information in NLP: natural language is anything but permutation-invariant (while the attention operation itself is permutation-equivariant). Intuitively, the order of tokens has an enormous effect on what a sequence means. Compare "철수는 영희를 좋아한다" (Cheolsu likes Younghee) with "영희는 철수를 좋아한다" (Younghee likes Cheolsu): swapping the positions of subject and object flips the meaning entirely.

Figure: Positional Encoding Example

To add position information to the input embedding, the authors propose Positional Encoding. In practice, because of several drawbacks, positional encoding is rarely used in later Transformer descendants; most of them adopt a Position Embedding instead, whose optimal values the model finds through training. Since I also implemented the positional information with a Position Embedding, I will only briefly cover the principle and the drawbacks of the encoding before moving on. The paper itself notes that the two approaches give similar performance.

\[P_E \in R^{B*S_E*d} \\ P_D \in R^{B*S_D*d} \\ P(pos, 2i) = \sin(pos/10000^{2i/d_{model}}) \\ P(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}})\]

The principle is simple: use the periodicity of the sine and cosine functions to express the value of each index's row vector. Elements at even positions of the row vector (even-indexed columns) are filled with the value of \(\sin(pos/10000^{2i/d_{model}})\), and elements at odd positions with \(\cos(pos/10000^{2i/d_{model}})\).
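For reference, here is a minimal sketch (not the code of my repo, which uses the learnable embedding instead) that builds the sinusoidal table exactly as the formulas above describe:

# Toy sketch: building the sinusoidal positional encoding table

import torch

def sinusoidal_encoding(max_seq: int = 50, dim_model: int = 256) -> torch.Tensor:
    pos = torch.arange(max_seq).unsqueeze(1).float()        # [max_seq, 1]
    i = torch.arange(0, dim_model, 2).float()               # even dimension indices 2i
    angle = pos / (10000 ** (i / dim_model))                # [max_seq, dim_model/2]
    pe = torch.zeros(max_seq, dim_model)
    pe[:, 0::2] = torch.sin(angle)                          # even columns get sin
    pe[:, 1::2] = torch.cos(angle)                          # odd columns get cos
    return pe

print(sinusoidal_encoding().shape)  # torch.Size([50, 256]), the same setup as the heat map shown below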

Figure: periodic function graph

The green curve plots \(\sin(pos/10000^{2i/d_{model}})\) and the orange curve \(\cos(pos/10000^{2i/d_{model}})\). Space did not allow showing the full range of variation up to max_seq=512, but as the x-axis grows, the oscillation period of both functions gradually lengthens. The authors argue that this makes it possible to represent each index (row vector) without duplicate values.

Figure: Positional Encoding Result

The figure above visualizes the positional encoding for a sequence of 50 tokens with a 256-dimensional hidden layer. The $x$-axis is the element of the row vector, i.e. the Transformer's hidden dimension, and the $y$-axis is the sequence index (row vector). The differences are hard to see with the naked eye, but every row vector ends up unique (the raw values differ only minutely, yet each token's "sparsity", in the sense defined next, is guaranteed). Since such small differences are hard to judge from a plot, I recommend computing the actual values yourself if you are curious.

By "sparsity" of the row vectors I do not mean sparsity of their individual elements. The first element of the row vectors for token 0, token 4, and token 9 may well be identical. But because we use periodic functions whose period keeps growing, we can expect the other elements (dimensions) to differ, and that is exactly what I define as the sparsity of the row vectors. If every element of the row vectors for token 1 and token 2 were the same, that would violate this principle.


Figure: Compare Performance between Encoding and Embedding

Even if the sparsity of each row vector is guaranteed, Positional Encoding has the drawback of being not trainable and therefore static: every sequence in every batch gets exactly the same position values. Suppose sequences A and B each consist of 512 tokens, where A is made of 5 sentences and B of 12. Will the 11th token of the two sequences play the same syntactic role? In most cases, probably not. One reason order matters in text data is precisely to capture syntactic information, and because Positional Encoding is static it struggles to encode that kind of information. Hence the recent trend of using a Position Embedding, which can hold richer representations.

✏️ Position Embedding

Now let's look at the Position Embedding. It is defined in much the same way as the Input Embedding. First, check the shapes of the input and the weight.

\[P_E \in R^{B*S_E*d} \\ P_D \in R^{B*S_D*d} \\ W_{P_E} \in R^{S_E * d} \\ W_{P_D} \in R^{S_D * d}\]

$P_E, P_D$ denote the inputs to each module's position embedding layer, and $W_{P_E}, W_{P_D}$ are the position embedding layers themselves. Let's see how to implement this in code.

# Absolute Position Embedding Example

class Encoder(nn.Module):
    """
    In this class, encode input sequence and then we stack N EncoderLayer
    First, we define "positional embedding" and then add to input embedding for making "word embedding"
    Second, forward "word embedding" to N EncoderLayer and then get output embedding
    In official paper, they use positional encoding, which is based on sinusoidal function (fixed, not learnable)
    But we use "positional embedding" which is learnable from training
    Args:
        max_seq: maximum sequence length, default 512 from official paper
        N: number of EncoderLayer, default 6 for base model
    """
    def __init__(self, max_seq: int = 512, N: int = 6, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
        super(Encoder, self).__init__()
        self.max_seq = max_seq
        self.scale = torch.sqrt(torch.tensor(dim_model, dtype=torch.float32))  # scale factor for input embedding from official paper
        self.positional_embedding = nn.Embedding(max_seq, dim_model)  # learnable absolute position embedding (W_P)

        ... (omitted) ...
    def forward(self, inputs: Tensor, mask: Tensor) -> tuple[Tensor, Tensor]:
        """
        inputs: embedding from input sequence, shape => [BS, SEQ_LEN, DIM_MODEL]
        mask: mask for Encoder padded token for speeding up to calculate attention score
        """
        layer_output = []
        pos_x = torch.arange(self.max_seq, device=inputs.device).repeat(inputs.shape[0], 1)  # [BS, SEQ_LEN] position indices
        x = self.dropout(
            self.scale * inputs + self.positional_embedding(pos_x)  # scale input embedding, then add positional embedding
        )
        ... (omitted) ...

The code above implements the Transformer's encoder module. pos_x in the forward method is $P_E$, and self.positional_embedding in __init__ corresponds to $W_{P_E}$. The Position Embedding defined this way is added to the Input Embedding to make the Word Embedding, which in turn is used as the input $X$ to each module's linear projection layers.

Note that the Input Embedding and Position Embedding are added. This was the part of the paper I found most puzzling: why add two matrices built from completely different sources instead of concatenating them? With concatenation, the model could have kept the input information and the position information in separate dimensions.

🤔 Why Sum instead of Concatenate

The authors never explicitly say why they use an element-wise sum, so the exact intent is unknown, but my guess is that they were counting on the blessing of dimensionality. The blessing of dimensionality describes the phenomenon that two vectors picked at random in a high-dimensional space are almost always approximately orthogonal. Keep in mind this is a probabilistic statement, not a guarantee. In any case, orthogonal vectors have a dot product close to 0, meaning they barely affect each other. This implies that, within the model's hidden-state space, the subspaces spanned by the Input Embedding vectors and the Position Embedding vectors are very likely to be close to orthogonal as well. So even if we add two matrices built from different sources, they should hardly interfere with each other, and we can expect the model to learn the input information and the position information separately. If that assumption holds, it is far more efficient than concatenating, which would enlarge the model's hidden-state space and cause computational overhead.
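The near-orthogonality claim is easy to check numerically; the sketch below (an illustrative experiment of my own, not from the paper) samples random vector pairs and shows their cosine similarity concentrating around 0 as the dimension grows.

# Toy sketch: random high-dimensional vectors are approximately orthogonal

import torch
import torch.nn.functional as F

torch.manual_seed(0)
for dim in (8, 64, 512):
    a, b = torch.randn(1000, dim), torch.randn(1000, dim)
    cos = F.cosine_similarity(a, b, dim=-1)
    print(f"dim={dim:4d}  mean |cos|={float(cos.abs().mean()):.3f}")  # shrinks toward 0 with dimension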

A proper explanation and proof of the blessing of dimensionality would take quite a bit of space, so I will not go into it here and will cover it in a separate post. I have attached links to good write-ups on the topic, which I recommend reading (link 1, link 2).

🚀 Self-Attention with linear projection

First, think about why the name is self-attention. Attention itself predates this paper; it first appeared in the seq2seq setting. Since seq2seq was devised to improve translation, the decoder's hidden states (the target) were used as queries and the encoder's hidden states as the source of keys and values; in other words, the dot products were taken between hidden states from different sources. Now the word "self" gets attached to that concept, implying that the dot products are taken between hidden states that come from the same source. The dot product is a mathematical measure of how similar two vectors are. So self-attention simply means computing how similar $Q$ (query), $K$ (key), and $V$ (value) built from the same source are to one another.

Figure: self-attention with linear projection

Let's now use a concrete example to understand what $Q$ (query), $K$ (key), and $V$ (value) really are, what "coming from the same source" means, and why the input matrix $X$ is linearly projected to build the $Q$, $K$, $V$ matrices. Since the query/key/value concepts were originally borrowed from Information Retrieval, the example is about search as well.

Suppose you want to know "how to clean an air-conditioner filter" and you google it. The goal is to acquire that knowledge as quickly and accurately as possible. What would you type into the search box? That is the $Q$ (query). You type "how to clean an air-conditioner filter" and get results back; the set of returned results is the $K$ (key). Say you receive 100 blog posts as keys. To find the post that describes exactly how to clean the filter of your Samsung Wind-Free air conditioner, you click through the links one by one, but the exact information is not there, and only around page 4 do you finally find the post you wanted. This process of checking whether each result is the information you want corresponds to taking the dot product of the $Q$ and $K$ matrices. When you are about to clean the unit, you realize you will forget the steps and have to search again every summer, so you bookmark the post. That bookmark is the $V$ (value) matrix.

The whole process took 10 minutes. Ten minutes just to find out how to clean a filter hurt your pride. Wondering how to find the desired information faster (optimizing the loss function), you decide to add your air conditioner's brand name (Samsung Bespoke) to the search terms. Now the information you previously found on page 4 appears at the very bottom of page 1, cutting the time from 10 minutes to a minute and a half. This act of rethinking and revising the search phrase to shorten the search time (reduce the loss) is what the equation "multiply the input $X$ by $W_Q$ to build $Q$" expresses.

A summer later, you switched browsers, your bookmarks were wiped, and you had to search again. You still remembered the query, so you searched exactly as you had a year before, with the phrasing that had given the best result. Oddly, the same result now came back at the very top of page 1. Curious, you read the post slowly and noticed that its title now contained the keyword "Samsung Bespoke air conditioner", which had not been there a year earlier; the blog's owner had added it for SEO. Thanks to that, the time dropped from a minute and a half to 20 seconds. This situation corresponds to the equation "multiply the input $X$ by $W_K$ to build $K$".

From this example we can see that finding the desired information quickly and accurately comes down to a question the answerer can easily understand and an answer that matches the asker's intent. We also learned that good questions and good answers are not finished products from the start; they are achieved through continual effort to shorten the search time. These two insights are why the matrices $Q, K, V$ are defined through linear projection. The dot product that checks whether something is the information I want needs no weight matrix on its own, so there would be nothing to optimize numerically via backpropagation of the loss. By defining $Q, K, V$ through linear projection matrices, optimization by differentiating the loss becomes possible, and we can expect the model to learn on its own the question and answer representations best suited to our objective. And "coming from the same source" refers to the situation in the example where the same input $X$ is used to build all of $Q$, $K$, and $V$.

Now back to the NLP context. The Transformer is a seq2seq model devised to build a good translator: it was made to translate quickly and accurately from a source language to a target language. What does good translation require? 1) accurately grasping the meaning of the source-language sequence, and 2) producing the sequence in the target language that is closest to that meaning. Role 1 falls to the Encoder and role 2 to the Decoder. The encoder learns to understand the source sequence's meaning precisely (expressing it as numbers, i.e. extracting embeddings, in a form suited to translation), while the decoder learns to generate, in the target language, the sentence most similar to the encoder's result. The encoder therefore builds $Q, K, V$ from the source language and the decoder builds them from the target language, exactly in line with the intent behind attaching the word "self" to the name.

Ultimately, the Transformer's performance comes down to how well those linear projection weights are optimized.

When I first read the paper I agreed with the need for the linear projection itself, but I suspected it would be more efficient to share weights rather than split into three matrices and increase the number of parameters to train.

But while rereading the paper for this review, it occurred to me that the effort of asking a good question, the effort of giving a good answer, and the act of extracting exactly the needed information would, as three separate vectors, point in quite different directions; forcing them into a single vector would make learning hard for the model. The example above says as much: it would amount to asking the model to find a single optimum across three different activities, and even if such a spot existed, would a language model find it? If humans struggle to find it, the model surely would too.

📏 Scaled Dot-Product Attention

\[Attention(Q,K,V) = softmax(\frac{Q \cdot K^T}{\sqrt{d_k}})V\]

Next comes the second sub-block of Self-Attention, the Scaled Dot-Product Attention. In fact, we already studied it without noticing in the linear projection part. Recall the example: remember how I said that, to find the information I want in the list of results (keys) returned for my query, I compare the query against the keys? That act of comparison, modeled mathematically, is exactly the Scaled Dot-Product Attention.

Figure: Attention Is All You Need

Scaled Dot-Product Attention is completed in five stages. Let's look at what each stage computes, why, and what insight it carries. The masking stage requires a detailed understanding of how the encoder and decoder behave, so I will explain it when we look at the model from the overall structural perspective.

✖️ Stage 1. Q•K^T Dot-Product

\[Q \cdot K^T\]

To grasp the meaning of a sentence or expression, humans refer to the immediate context and sometimes to words or sequences much farther away; that is, we use all the context inside the given sequence to understand a particular part. Yet not all information contributes equally to the meaning of a particular expression. Think of how you solved the fill-in-the-blank "killer" questions on the English section of the college entrance exam: the details vary from person to person, but generally you skim the whole passage, pick out the one or two sentences or expressions that justify the answer, and choose the option with a similar meaning. In other words, you select the expressions or sentences that matter most for understanding the given passage and give them weight in proportion to their importance.

Figure: Q•K^T Dot Product Visualization

How is this modeled mathematically? With the dot product of the matrices $Q$ and $K^T$. The matrix $Q$ holds the things whose meaning the model must figure out, and $K$ holds the clues needed to figure it out. We said the dot product measures how much two vectors resemble each other, and that degree of resemblance is what corresponds to importance, the weight. The result of the operation is therefore a matrix whose entries are the pairwise resemblances between all tokens in the sequence, expressed as numbers.

Why should the dot-product result carry the meaning of importance? When justifying the element-wise sum of the Input Embedding and Position Embedding, we noted that in high dimensions most vector pairs become orthogonal, so two vectors pointing in a similar direction is itself a very rare event. Rare, infrequent events can be said to be that much more important, which is why the dot-product result is mapped to importance.

Both $Q$ and $K$ are tensors of shape [Batch, Max_Seq, Dim_Head], so the result of the dot product has shape [Batch, Max_Seq, Max_Seq].

🔭 Stage 2. Scale

\[Q \cdot K^T = \begin{bmatrix} 56.8 & 12.1 & 43.5 \\ 30.4 & 100.8 & 24.2 \\ 11.11 & 7.34 & 20.23 \\ \end{bmatrix}\]

Computing $Q \cdot K^T$ for the sentence "I am dog" would give a 3x3 matrix like the one above. Look at the matrix row by row: the values within each row are distributed very unevenly. Pushing scores with this much variance through a softmax shrinks the softmax derivative during backpropagation, slowing learning and potentially causing vanishing gradients. A scale factor is therefore defined to reduce the variance within the row vectors. Which scale factor should we use?

\[\frac{Q \cdot K^T}{\sqrt{d_h}}\]

Large variance among the values within the Dim Head dimension is itself a problem, but that can be handled by applying layernorm to the Input or Position Embedding, so it is not the issue here. Focus instead on the dot product: the larger Dim Head is, the more scalar terms have to be summed. Suppose Dim Head in the example above is 64; then 64 scalar products must be added to obtain the entry in row 1, column 1, and with 512 dimensions that balloons to 512 terms. The more scalar values have to be summed, the larger the variance across the row vectors is likely to become. So, to keep the softmax derivative from shrinking as the dimension scales up, the $Q \cdot K^T$ result is divided by $\sqrt{d_h}$.
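The growth of the variance with the head dimension, and the effect of dividing by $\sqrt{d_h}$, can be checked with a short experiment (random unit-variance entries, toy sizes of my choosing):

# Toy sketch: variance of q·k grows with dimension; dividing by sqrt(d) restores it

import torch

torch.manual_seed(0)
for d in (16, 64, 256):
    q, k = torch.randn(10000, d), torch.randn(10000, d)
    dots = (q * k).sum(dim=-1)               # 10,000 dot products of d-dimensional vectors
    print(f"d={d:3d}  var(q·k)={float(dots.var()):8.1f}  var(q·k/sqrt(d))={float((dots / d ** 0.5).var()):.2f}")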

As an aside, it is because of this scale factor that Self-Attention is also called Scaled Dot-Product Attention.

🎭 Stage 3. Masking

Masking is the layer needed for the encoder's input padding, the decoder's Masked Multi-Head Attention, and the encoder-decoder attention. The latter two can only be understood once we know how the decoder works, so here we will only look at the encoder's masking.

Figure: Encoder Padding Mask

Real text data comes in sequences of varying lengths within a batch. Since computers rely on matrix operations for efficiency, the computation cannot proceed if the batched sequences all have different lengths, so the lengths of all sequences in the batch are unified, and sequences shorter than the reference length are padded by filling in 0s. Padding, indispensable for the matrix math, actually gets in the way when computing the softmax layer. To exclude every padded value completely from the softmax probabilities, we record the indices of the padding tokens at the Input Embedding stage and mask all corresponding elements with -∞.

Masking here is applied only along the column vectors, because the softmax will be computed along the row direction anyway. Masking the padded tokens along the rows as well does no harm, but implementing masking in both directions at once is trickier than it sounds, and it is far more efficient to ignore the padded rows later, at loss-computation time, with the ignore_index option. That option is available as a parameter of nn.CrossEntropyLoss.
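As a small preview, here is a minimal sketch (pad_index and the toy batch below are assumptions of mine, not values from the repo) of how an encoder padding mask can be built so that it broadcasts against the [batch, seq_len, seq_len] score matrix used by the attention function further down:

# Toy sketch: building an encoder padding mask

import torch

pad_index = 0                                     # assumed padding token id
enc_inputs = torch.tensor([[5, 7, 2, 0, 0],       # two padded positions
                           [3, 9, 4, 8, 1]])      # no padding

mask = (enc_inputs != pad_index).unsqueeze(1)     # [batch, 1, seq_len], broadcast over query rows
print(mask.shape)                                 # torch.Size([2, 1, 5])
# masked_fill(mask == 0, -inf) then pushes padded keys to ~0 weight after softmax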

📈 Stage 4. Softmax & Score•V

\[Score = \begin{bmatrix} 0.90 & 0.07 & 0.03 \\ 0.025 & 0.95 & 0.025 \\ 0.21 & 0.03 & 0.76 \end{bmatrix}, \ \ V=\begin{bmatrix} 67.85 & 90 & 91 & \dots \\ 62 & 40 & 50 & \dots \\ 37 & 41 & 20 & \dots \end{bmatrix},\ \ Z = Score \cdot V\] \[z_{1} = Score_{11}V_{1,:} + Score_{12}V_{2,:} + Score_{13}V_{3,:} \\ z_{2} = Score_{21}V_{1,:} + Score_{22}V_{2,:} + Score_{23}V_{3,:} \\ z_{3} = Score_{31}V_{1,:} + Score_{32}V_{2,:} + Score_{33}V_{3,:}\]

The computed similarities (dot-product results, importances, weights), $\frac{Q \cdot K^T}{\sqrt{d_h}}$, are later multiplied with the matrix $V$ and act, for each row vector $Z_n$ (the n-th token), as the weights that express how much attention is paid to every token. But these similarities are unnormalized. For convenience the matrix in the formula above is written as if softmax had already been applied, yet in reality the raw values vary far too much to serve as weights. So each row vector is passed through softmax and converted (normalized) into probabilities that sum to 1, and those are used as the weights on $V$.

Now look at the second set of equations (here $V_{n,:}$ denotes the n-th row vector of $V$). $Score_{11}$, the value 0.90, multiplies the first row vector of $V$, which is the 512-dimensional representation of the token "I". Next, $Score_{12}$ multiplies the second row vector of $V$, and $Score_{13}$ the third.

What does this mean? $Score_{11}$, $Score_{12}$, $Score_{13}$ express numerically how much the tokens "I", "am", and "dog" should be attended to when working out the meaning of the first token "I", that is, how much of each token's meaning should be reflected in the representation of "I". Naturally the weight (similarity, importance) with itself, "I", is highest, so we can expect the row vector of $V$ corresponding to "I" to receive the largest weight. Repeating this weighted sum for every token finally gives $Z_1, \ Z_2, \ Z_3$, the encodings of "I", "am", and "dog".
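Using the toy Score matrix above and the first three columns of V, the weighted sum can be written out directly; each output row z_n really is just the Score-weighted combination of V's row vectors.

# Toy sketch: Z = Score · V as a per-token weighted sum

import torch

score = torch.tensor([[0.90, 0.07, 0.03],
                      [0.025, 0.95, 0.025],
                      [0.21, 0.03, 0.76]])
v = torch.tensor([[67.85, 90., 91.],
                  [62.,   40., 50.],
                  [37.,   41., 20.]])

z = score @ v
print(z[0])                                       # z_1, dominated by the row of V for "I"
print(0.90 * v[0] + 0.07 * v[1] + 0.03 * v[2])    # the same row written out term by term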

👩‍💻 Implementation

That covers all of Scaled Dot-Product Attention. Remember the key point: using the matrices $Q, K, V$ that the model optimizes in the direction of minimal loss, this layer decides, based on similarity, which contexts and expressions deserve more or less focus when understanding a token. Now let's see how to write the actual code. Take another look at the class diagram above before continuing.

# Pytorch Implementation of Scaled Dot-Product Self-Attention

import torch
import torch.nn.functional as F
from torch import Tensor


def scaled_dot_product_attention(q: Tensor, k: Tensor, v: Tensor, dot_scale: Tensor, mask: Tensor = None) -> Tensor:
    """
    Scaled Dot-Product Attention with Masking for Decoder
    Args:
        q: query matrix, shape (batch_size, seq_len, dim_head)
        k: key matrix, shape (batch_size, seq_len, dim_head)
        v: value matrix, shape (batch_size, seq_len, dim_head)
        dot_scale: scale factor for Q•K^T result
        mask: there are three types of mask, mask matrix shape must be same as single attention head
              1) Encoder padded token
              2) Decoder Masked-Self-Attention
              3) Decoder's Encoder-Decoder Attention
    Math:
        A = softmax(q•k^t/sqrt(D_h)), SA(z) = Av
    """
    attention_matrix = torch.matmul(q, k.transpose(-1, -2)) / dot_scale
    if mask is not None:
        attention_matrix = attention_matrix.masked_fill(mask == 0, float('-inf'))
    attention_dist = F.softmax(attention_matrix, dim=-1)
    attention_matrix = torch.matmul(attention_dist, v)
    return attention_matrix

The masking option is written so that it runs whenever at least one of the three situations listed in the docstring applies. The three situations and the concrete masking methods will be introduced when we look at the overall model structure.

The encoder and decoder differ in the inputs they use and in how they mask, but the Scaled Dot-Product Attention operation itself is identical for both. It is therefore implemented as a standalone function outside any class so that the various encoder, decoder, or attention-head objects can all access it easily.

👩‍👩‍👧‍👦 Multi-Head Attention Block

Everything about Self-Attention described so far happens inside a single attention head. In the actual model the same operation runs simultaneously in N-1 other heads as well, and this is Multi-Head Attention.

In the official paper, Transformer-base has a hidden-state dimension of 512. This is split into 8 attention heads of 64 dimensions each, and Self-Attention runs in the 8 heads simultaneously. The results are then concatenated back into a 512-dimensional hidden state and passed through a linear projection layer whose input and output dimensions equal the hidden size, in order to connect and mix what the different heads produced. This is the final Self-Attention output of a single encoder (or decoder) block.

Figure: Multi-Head Attention Result Visualization

Why use multiple heads? To enjoy the effect of collective intelligence. Think about it: the same book draws very different interpretations from different readers, and the model is no different. Multiple heads were meant to let the embedding carry more diverse, richer meanings. If you have done Kaggle, you have probably produced several results with different strategies and then ensembled them all at the end, beating the score of any single strategy; I think a similar effect is intended here. Using several kinds of conv filters in vision to extract diverse feature maps can be seen as a comparable phenomenon.

The figure above is the authors' visualization of Multi-Head Attention. The bands of different colors in the middle indicate where each head attends; for the token "making", the heads are attending to different tokens.

Figure: ViT Multi-Head Attention Result Visualization

The figure above is taken from the Vision Transformer paper (see here for its detailed meaning). Again, the heads in the model's early encoder layers attend to a wide variety of tokens. Moreover, the attention distance gradually converges to a certain level toward the later layers, which can be interpreted as each head figuring out more and more concretely, layer by layer, which tokens it should pay attention to. In the early layers it does not yet know what to do, so it attends to this token and that token indiscriminately.

Hence the Transformer captures global, general information in the bottom layers and local, specific information in the top layers close to the output.

👩‍💻 Implementation

Now let's actually implement it; again, the implementation is in PyTorch.

# Pytorch Implementation of Single Attention Head

class AttentionHead(nn.Module):
    """
    In this class, we implement workflow of single attention head
    Args:
        dim_model: dimension of model's latent vector space, default 512 from official paper
        dim_head: dimension of each attention head, default 64 from official paper (512 / 8)
        dropout: dropout rate, default 0.1
    Math:
        [q,k,v]=z•U_qkv, A = softmax(q•k^t/sqrt(D_h)), SA(z) = Av
    """
    def __init__(self, dim_model: int = 512, dim_head: int = 64, dropout: float = 0.1) -> None:
        super(AttentionHead, self).__init__()
        self.dim_model = dim_model
        self.dim_head = dim_head  # 512 / 8 = 64
        self.dropout = dropout
        self.dot_scale = torch.sqrt(torch.tensor(self.dim_head))
        self.fc_q = nn.Linear(self.dim_model, self.dim_head)  # Linear Projection for Query Matrix
        self.fc_k = nn.Linear(self.dim_model, self.dim_head)  # Linear Projection for Key Matrix
        self.fc_v = nn.Linear(self.dim_model, self.dim_head)  # Linear Projection for Value Matrix

    def forward(self, x: Tensor, mask: Tensor, enc_output: Tensor = None) -> Tensor:
        q, k, v = self.fc_q(x), self.fc_k(x), self.fc_v(x)  # x is previous layer's output
        if enc_output is not None:
            """ For encoder-decoder self-attention """
            k = self.fc_k(enc_output)
            v = self.fc_v(enc_output)
        attention_matrix = scaled_dot_product_attention(q, k, v, self.dot_scale, mask=mask)
        return attention_matrix

Since N identical attention heads are used, the behavior of a single attention head is first wrapped in its own class. That way, the MultiHeadAttention object can simply chain N heads with nn.ModuleList, which makes the implementation much simpler. The single attention head object does the following:

  • 1) Linear Projection by Dimension of Single Attention Head
  • 2) Masking
  • 3) Scaled Dot-Product Attention

Looking at various Transformer implementations on GitHub, the approaches broadly split into abstracting out a single attention head, as I did, or cramming every operation into one MultiHeadAttention class. There is no single right answer, but personally I find the latter inefficient: you would have to create and manage 3*N linear projectors in the class __init__, which is not easy. Of course you can also initialize only 3 linear projectors whose output dimension is Dim_Model rather than Dim_Head and then split the result into N heads, but implementing that dimension split is not exactly easy either. So I recommend the former approach.
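For comparison, here is a minimal sketch of that fused alternative (my own illustration, with masking and dropout omitted for brevity): three dim_model-to-dim_model projections, with the per-head split done by a reshape and a transpose.

# Toy sketch: fused multi-head attention with three dim_model-wide projections

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class FusedMultiHeadAttention(nn.Module):
    def __init__(self, dim_model: int = 512, num_heads: int = 8) -> None:
        super().__init__()
        self.num_heads, self.dim_head = num_heads, dim_model // num_heads
        self.fc_q = nn.Linear(dim_model, dim_model)   # one projection covers all heads
        self.fc_k = nn.Linear(dim_model, dim_model)
        self.fc_v = nn.Linear(dim_model, dim_model)
        self.fc_concat = nn.Linear(dim_model, dim_model)

    def forward(self, x: Tensor) -> Tensor:
        b, s, d = x.shape
        split = lambda t: t.view(b, s, self.num_heads, self.dim_head).transpose(1, 2)
        q, k, v = split(self.fc_q(x)), split(self.fc_k(x)), split(self.fc_v(x))  # [b, h, s, dim_head]
        score = torch.matmul(q, k.transpose(-1, -2)) / self.dim_head ** 0.5
        z = torch.matmul(F.softmax(score, dim=-1), v)         # [b, h, s, dim_head]
        z = z.transpose(1, 2).contiguous().view(b, s, d)      # concat heads back to dim_model
        return self.fc_concat(z)

print(FusedMultiHeadAttention()(torch.randn(2, 10, 512)).shape)  # torch.Size([2, 10, 512])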

The if enc_output is not None: branch in the forward method was added for the decoder's Multi-Head Attention later on. Unlike the encoder, a single decoder block performs the attention operation twice, and the second time the linear projections use values from different sources; the branch was implemented to handle that case.

Below is the PyTorch code implementing MultiHeadAttention.

# Pytorch Implementation of Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """
    In this class, we implement workflow of Multi-Head Self-Attention
    Args:
        dim_model: dimension of model's latent vector space, default 512 from official paper
        num_heads: number of heads in MHSA, default 8 from official paper for Transformer
        dim_head: dimension of each attention head, default 64 from official paper (512 / 8)
        dropout: dropout rate, default 0.1
    Math:
        MSA(z) = [SA1(z); SA2(z); · · · ; SAk(z)]•Umsa
    Reference:
        https://arxiv.org/abs/1706.03762
    """
    def __init__(self, dim_model: int = 512, num_heads: int = 8, dim_head: int = 64, dropout: float = 0.1) -> None:
        super(MultiHeadAttention, self).__init__()
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.dropout = dropout
        self.attention_heads = nn.ModuleList(
            [AttentionHead(self.dim_model, self.dim_head, self.dropout) for _ in range(self.num_heads)]
        )
        self.fc_concat = nn.Linear(self.dim_model, self.dim_model)

    def forward(self, x: Tensor, mask: Tensor, enc_output: Tensor = None) -> Tensor:
        """ x is already passed nn.Layernorm """
        assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'
        attention_output = self.fc_concat(
            torch.cat([head(x, mask, enc_output) for head in self.attention_heads], dim=-1)
        )
        return attention_output

The MultiHeadAttention object concatenates the attention results produced by the individual heads and applies a linear projection to connect and mix them.

🔬 Feed Forward Network

# Pytorch Implementation of FeedForward Network

class FeedForward(nn.Module):
    """
    Class for Feed-Forward Network module in transformer
    In official paper, they use ReLU activation function, but GELU is better for now
    We change ReLU to GELU & add dropout layer
    Args:
        dim_model: dimension of model's latent vector space, default 512
        dim_ffn: dimension of FFN's hidden layer, default 2048 from official paper
        dropout: dropout rate, default 0.1
    Math:
        FeedForward(x) = FeedForward(LN(x))+x
    """
    def __init__(self, dim_model: int = 512, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
        super(FeedForward, self).__init__()
        self.ffn = nn.Sequential(
            nn.Linear(dim_model, dim_ffn),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(dim_ffn, dim_model),
            nn.Dropout(p=dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.ffn(x)

The feed-forward network is the layer used to add non-linearity to the model. The original model uses ReLU, but GELU has been found to help recent Transformer-family models train more stably, so I implemented it with GELU as well. The paper also never mentions dropout here, but since the hidden dimension is blown up that much and then squeezed back down, overfitting seemed a plausible concern, so I added it separately following the ViT paper.

➕ Add & Norm

This refers to the residual connection and layer normalization. They are not wrapped in a separate class; they are implemented as lines inside the EncoderLayer object, so here I will only explain their role and meaning before moving on.

The residual connection, also called a skip connection, adds the input $x$ as it was before a layer to the layer's output $f(x)$, so the input passed to the next layer is $x + f(x)$. Why add it? For stable training. First, one premise to keep in mind: as a model gets deeper, changing the values only a little at each layer tends to give robust and stable results. Intuitively, a model that learns steadily, step by step, should generalize better than one whose outputs jump around from layer to layer. The residual connection therefore assumes that the difference between the input $x$ and the layer's ideal output $H(x)$ is small. Say the input $x$ is 10.0 and $H(x)$ is 10.4: a model with residual connections only needs to learn the 0.4, whereas a model without them has to learn the full 10.4 starting from 0. Which is easier to learn? The former, obviously. Because the model only has to learn the difference between the ideal value and the input, this is called residual learning.

Figure: Layernorm vs Batchnorm

Batchnorm computes the mean and standard deviation per channel (feature) across a mini-batch, whereas Layernorm normalizes by computing the mean and standard deviation across the channels (features) of each individual instance.

For example, suppose a batch of 4 sentences is fed into a model with a hidden size of 512. Each sentence then has 512 elements, and the mean and standard deviation are computed over those. One mean and one standard deviation per sentence, so with 4 sentences we get 8 statistics in total.

So why does the Transformer use Layernorm? In NLP the sequence length is not fixed across batches, so padding or truncation is applied, and padding is the bigger problem (truncation less so). Padding is usually added at the end of a sentence. With Batchnorm, normal tokens sitting near the end of other sequences could have their values distorted by the padding, so Layernorm is used instead of Batchnorm. Batchnorm is also tied to the batch size and cannot be used as-is at test time, which is another reason to use the batch-size-independent Layernorm.
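A short sketch of the contrast (toy tensor, standard PyTorch modules): nn.LayerNorm normalizes over each token's feature dimension independently, so padded positions in other sequences never enter the statistics, whereas nn.BatchNorm1d pools its statistics across the whole batch.

# Toy sketch: LayerNorm statistics are per token, BatchNorm statistics are per batch

import torch
import torch.nn as nn

x = torch.randn(4, 10, 512)                       # [batch, seq_len, dim_model]

y = nn.LayerNorm(512)(x)
print(y.mean(dim=-1)[0, 0], y.std(dim=-1)[0, 0])  # ~0 and ~1 for every single token

z = nn.BatchNorm1d(512)(x.transpose(1, 2)).transpose(1, 2)  # BatchNorm1d wants [batch, channels, seq]
print(z.shape)                                    # same shape, but stats were pooled over the batch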

If you are curious why such normalization is used at all, I have written it up in another post; please refer to it. Briefly, the goal is to let the model, rather than a human, find the optimal trade-off between the model's non-linearity and the gradient magnitude.

📘 EncoderLayer

We have now seen all the ingredients needed to define a single encoder block. Let's put everything together and build one.

# Pytorch Implementation of Single Encoder Block

class EncoderLayer(nn.Module):
    """
    Class for encoder model module in Transformer
    In this class, we stack each encoder_model module (Multi-Head Attention, Residual-Connection, LayerNorm, FFN)
    We apply pre-layernorm, which is different from original paper
    In common sense, pre-layernorm are more effective & stable than post-layernorm
    """
    def __init__(self, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(
            dim_model,
            num_heads,
            int(dim_model / num_heads),
            dropout,
        )
        self.layer_norm1 = nn.LayerNorm(dim_model)
        self.layer_norm2 = nn.LayerNorm(dim_model)
        self.dropout = nn.Dropout(p=dropout)
        self.ffn = FeedForward(
            dim_model,
            dim_ffn,
            dropout,
        )

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        ln_x = self.layer_norm1(x)
        residual_x = self.dropout(self.self_attention(ln_x, mask)) + x

        ln_x = self.layer_norm2(residual_x)
        fx = self.ffn(ln_x) + residual_x
        return fx

์ง€๊ธˆ๊นŒ์ง€์˜ ๋‚ด์šฉ์„ ๊ฐ์ฒด ํ•˜๋‚˜์— ๋ชจ์•„๋‘”๊ฑฐ๋ผ ํŠน๋ณ„ํžˆ ์„ค๋ช…์ด ํ•„์š”ํ•œ ๋ถ€๋ถ„์€ ์—†์ง€๋งŒ, ํ•„์ž๊ฐ€ add & norm์„ ์–ธ์ œ ์‚ฌ์šฉํ–ˆ๋Š”์ง€ ์ฃผ๋ชฉํ•ด๋ณด์ž. ์›๋ณธ ๋…ผ๋ฌธ์€ Multi-Head Attention๊ณผ FeedForward Layer๋ฅผ ํ†ต๊ณผํ•œ ์ดํ›„์— add & norm์„ ํ•˜๋Š” post-layernorm ๋ฐฉ์‹์„ ์ ์šฉํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ํ•„์ž๋Š” ๋‘ ๋ ˆ์ด์–ด ํ†ต๊ณผ ์ด์ „์— ๋ฏธ๋ฆฌ add & norm ์„ ํ•ด์ฃผ๋Š” pre-layernorm ๋ฐฉ์‹์„ ์ฑ„ํƒํ–ˆ๋‹ค.

pre-layernorm vs post-layernorm

์ตœ๊ทผ Transformer๋ฅ˜์˜ ๋ชจ๋ธ์— pre-layernorm์„ ์ ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ข€ ๋” ์•ˆ์ •์ ์ด๊ณ  ํšจ์œจ์ ์ธ ํ•™์Šต์„ ์œ ๋„ํ•  ์ˆ˜ ์žˆ๋‹ค๊ณ  ์‹คํ—˜์„ ํ†ตํ•ด ๋ฐํ˜€์ง€๊ณ  ์žˆ๋‹ค. pre-layernorm ์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ณ„๋‹ค๋ฅธ Gradient Explode ํ˜„์ƒ์ด ํ˜„์ €ํžˆ ์ค„์–ด๋“ค์–ด ๋ณต์žกํ•œ ์Šค์ผ€์ค„๋Ÿฌ(warmup ๊ธฐ๋Šฅ์ด ์žˆ๋Š” ์Šค์ผ€์ค„๋Ÿฌ)๋ฅผ ์‚ฌ์šฉํ•  ํ•„์š”๊ฐ€ ์—†์–ด์ง„๋‹ค๊ณ  ํ•˜๋‹ˆ ์ฐธ๊ณ ํ•˜์ž.

Now all that is left is to stack N of these single Encoder blocks, and the encoder is finally complete.

๐Ÿ“š Encoder

๋“œ๋””์–ด ๋Œ€๋ง์˜ Encoder ๊ฐ์ฒด ๊ตฌํ˜„์„ ์‚ดํŽด๋ณผ ์‹œ๊ฐ„์ด๋‹ค.

# Pytorch Implementation of Encoder(Stacked N EncoderLayer)

class Encoder(nn.Module):
    """
    In this class, we encode the input sequence by stacking N EncoderLayers
    First, we define a "positional embedding" and add it to the input embedding to make the "word embedding"
    Second, we forward the "word embedding" through the N EncoderLayers to get the output embedding
    In the official paper, they use positional encoding, which is based on a sinusoidal function (fixed, not learnable)
    But we use a "positional embedding", which is learnable during training
    Args:
        max_seq: maximum sequence length, default 512 from official paper
        N: number of EncoderLayer, default 6 for base model
    """
    def __init__(self, max_seq: int = 512, N: int = 6, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
        super(Encoder, self).__init__()
        self.max_seq = max_seq
        self.scale = torch.sqrt(torch.tensor(dim_model, dtype=torch.float))  # scale factor for input embedding from official paper
        self.positional_embedding = nn.Embedding(max_seq, dim_model)  # learnable position embedding, one vector per position
        self.num_layers = N
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.dim_ffn = dim_ffn
        self.dropout = nn.Dropout(p=dropout)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(dim_model, num_heads, dim_ffn, dropout) for _ in range(self.num_layers)]
        )
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, inputs: Tensor, mask: Tensor) -> tuple[Tensor, Tensor]:
        """
        inputs: embedding from input sequence, shape => [BS, SEQ_LEN, DIM_MODEL]
        mask: mask for Encoder padded tokens, keeping them out of the attention score
        """
        layer_output = []
        pos_x = torch.arange(self.max_seq, device=inputs.device).unsqueeze(0).repeat(inputs.shape[0], 1)  # [BS, SEQ_LEN] position indices
        x = self.dropout(
            self.scale * inputs + self.positional_embedding(pos_x)
        )
        for layer in self.encoder_layers:
            x = layer(x, mask)
            layer_output.append(x)
        encoded_x = self.layer_norm(x)  # from official paper & code by Google Research
        layer_output = torch.stack(layer_output, dim=0).to(x.device)  # For Weighted Layer Pool: [N, BS, SEQ_LEN, DIM]
        return encoded_x, layer_output

์—ญ์‹œ ์ง€๊ธˆ๊นŒ์ง€ ๋‚ด์šฉ์„ ์ข…ํ•ฉํ•œ ๊ฒƒ๋ฟ์ด๋ผ์„œ ํฌ๊ฒŒ ํŠน์ดํ•œ ๋‚ด์šฉ์€ ์—†๊ณ , ๊ตฌํ˜„์ƒ ๋†“์น˜๊ธฐ ์‰ฌ์šด ๋ถ€๋ถ„๋งŒ ์•Œ๊ณ  ๋„˜์–ด๊ฐ€๋ฉด ๋œ๋‹ค. forward ๋ฉ”์„œ๋“œ์˜ ๋ณ€์ˆ˜ x๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๋Š” ์ฝ”๋“œ ๋ผ์ธ์„ ์ฃผ๋ชฉํ•ด๋ณด์ž. ์ด๊ฒƒ์ด ๋ฐ”๋กœ Input Embedding๊ณผ Position Embedding์„ ๋”ํ•˜๋Š”(ํ–‰๋ ฌ ํ•ฉ) ์—ฐ์‚ฐ์„ ๊ตฌํ˜„ํ•œ ๊ฒƒ์ด๋‹ค. ์ด ๋•Œ ๋†“์น˜๊ธฐ ์‰ฌ์šด ๋ถ€๋ถ„์ด ๋ฐ”๋กœ Input Embedding์— scale factor๋ฅผ ๊ณฑํ•ด์ค€๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์ €์ž์˜ ์ฃผ์žฅ์— ๋”ฐ๋ฅด๋ฉด Input Embedding์—๋งŒ scale factor๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์•ˆ์ •์ ์ธ ํ•™์Šต์— ๋„์›€์ด ๋œ๋‹ค๊ณ  ํ•˜๋‹ˆ ์ฐธ๊ณ ํ•˜์ž.

ํ•œํŽธ, ๋งˆ์ง€๋ง‰ ์ธ์ฝ”๋” ๋ธ”๋Ÿญ์—์„œ ๋‚˜์˜จ ์ž„๋ฒ ๋”ฉ์„ ๋‹ค์‹œ ํ•œ ๋ฒˆ layernorm์— ํ†ต๊ณผํ•˜๋„๋ก ๊ตฌํ˜„ํ–ˆ๋‹ค. ์ด ๋ถ€๋ถ„๋„ ์›๋ณธ ๋…ผ๋ฌธ์— ์žˆ๋Š” ๋‚ด์šฉ์€ ์•„๋‹ˆ๊ณ  ViT์˜ ๋…ผ๋ฌธ ๋‚ด์šฉ์„ ์ฐธ๊ณ ํ•ด ์ถ”๊ฐ€ํ–ˆ๋‹ค.

๐Ÿ“˜ย DecoderLayer

์ด๋ฒˆ์—๋Š” ๋””์ฝ”๋”์— ์‚ฌ์šฉ๋œ ๋ธ”๋Ÿญ์˜ ๋™์ž‘ ๋ฐฉ์‹๊ณผ ์˜๋ฏธ ๊ทธ๋ฆฌ๊ณ  ๊ตฌํ˜„๊นŒ์ง€ ์•Œ์•„๋ณด์ž. ์‚ฌ์‹ค ๋””์ฝ”๋”๋„ ์ง€๊ธˆ๊นŒ์ง€ ๊ณต๋ถ€ํ•œ ๋‚ด์šฉ๊ณผ ํฌ๊ฒŒ ๋‹ค๋ฅธ๊ฒŒ ์—†๋‹ค. ๋‹ค๋งŒ ์ธ์ฝ”๋”์™€๋Š” ๋ชฉ์ ์ด ๋‹ค๋ฅด๊ธฐ ๋•Œ๋ฌธ์— ๋ฐœ์ƒํ•˜๋Š” ๋ฏธ์„ธํ•œ ๋™์ž‘์˜ ์ฐจ์ด์— ์ฃผ๋ชฉํ•ด๋ณด์ž. ๋จผ์ € Single Decoder Block์€ Single Encoder Block๊ณผ ๋‹ค๋ฅด๊ฒŒ Self-Attention์„ ๋‘ ๋ฒˆ ์ˆ˜ํ–‰ํ•œ๋‹ค. ์ง€๊ฒน๊ฒ ์ง€๋งŒ ๋‹ค์‹œ ํ•œ ๋ฒˆ Transformer์˜ ๋ชฉ์ ์„ ์ƒ๊ธฐ์‹œ์ผœ๋ณด์ž. ๋ฐ”๋กœ ๋Œ€์ƒ ์–ธ์–ด๋ฅผ ํƒ€๊ฒŸ ์–ธ์–ด๋กœ ์ž˜ ๋ฒˆ์—ญํ•˜๋Š” ๊ฒƒ์ด์—ˆ๋‹ค. ์ผ๋‹จ ์ธ์ฝ”๋”๋ฅผ ํ†ตํ•ด ๋Œ€์ƒ ์–ธ์–ด๋Š” ์ž˜ ์ดํ•ดํ•˜๊ฒŒ ๋˜์—ˆ๋‹ค. ๊ทธ๋Ÿผ ์ด์ œ ํƒ€๊ฒŸ ์–ธ์–ด๋„ ์ž˜ ์ดํ•ดํ•ด์•ผํ•˜์ง€ ์•Š์€๊ฐ€?? ๊ทธ๋ž˜์„œ ํƒ€๊ฒŸ ์–ธ์–ด๋ฅผ ์ดํ•ดํ•˜๊ธฐ ์œ„ํ•ด Self-Attention์„ ํ•œ ๋ฒˆ, ๊ทธ๋ฆฌ๊ณ  ๋Œ€์ƒ ์–ธ์–ด๋ฅผ ํƒ€๊ฒŸ ์–ธ์–ด๋กœ ๋ฒˆ์—ญํ•˜๊ธฐ ์œ„ํ•ด Self-Attention์„ ํ•œ ๋ฒˆ, ์ด 2๋ฒˆ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ์ฒซ๋ฒˆ์งธ Self-Attention ์„ Masked Multi-Head Attention, ๋‘๋ฒˆ์งธ๋ฅผ Encoder-Decoder Multi-Head Attention์ด๋ผ๊ณ  ๋ถ€๋ฅธ๋‹ค.

๐ŸŽญ Masked Multi-Head Attention
์ธ์ฝ”๋”์˜ Multi-Head Attention์™€ ํ–‰๋ ฌ $Q,K,V$ ์˜ ์ถœ์ฒ˜๊ฐ€ ๋‹ค๋ฅด๋‹ค. ๋””์ฝ”๋”๋Š” ์ถœ์ฒ˜๊ฐ€ ํƒ€๊ฒŸ ์–ธ์–ด์ธ linear projection matrix๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ๋˜ํ•œ ์ธ์ฝ”๋”์™€ ๋‹ค๋ฅด๊ฒŒ ๊ฐœ๋ณ„ ์‹œ์ ์— ๋งž๋Š” ๋งˆ์Šคํ‚น ํ–‰๋ ฌ์ด ํ•„์š”ํ•˜๋‹ค. ๋””์ฝ”๋”์˜ ๊ณผ์—…์€ ๊ฒฐ๊ตญ ๋Œ€์ƒ ์–ธ์–ด๋ฅผ ์ž˜ ์ดํ•ดํ•˜๊ณ  ๊ทธ๊ฒƒ์— ๊ฐ€์žฅ ์ž˜ ๋“ค์–ด๋งž๋Š” ํƒ€๊ฒŸ ์–ธ์–ด ์‹œํ€€์Šค๋ฅผ generateํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ์ฆ‰, Next Token Prediction์„ ํ†ตํ•ด ์‹œํ€€์Šค๋ฅผ ๋งŒ๋“ค์–ด๋‚ด์•ผ ํ•œ๋‹ค. ๊ทธ๋Ÿฐ๋ฐ ํ˜„์žฌ ์‹œ์ ์—์„œ ๋ฏธ๋ž˜ ์‹œ์ ์— ๋””์ฝ”๋”๊ฐ€ ์˜ˆ์ธกํ•ด์•ผํ•  ํ† ํฐ์„ ๋ฏธ๋ฆฌ ์•Œ๊ณ  ์žˆ์œผ๋ฉด ๊ทธ๊ฒƒ์„ ์˜ˆ์ธก์ด๋ผ๊ณ  ํ•  ์ˆ˜ ์žˆ์„๊นŒ?? ๋””์ฝ”๋”๊ฐ€ ํ˜„์žฌ ์‹œ์ ์˜ ํ† ํฐ์„ ์˜ˆ์ธกํ•˜๋Š”๋ฐ ๋ฏธ๋ž˜ ์‹œ์ ์˜ Context๋ฅผ ๋ฐ˜์˜ํ•˜์ง€ ๋ชปํ•˜๋„๋ก ๋ง‰๊ธฐ ์œ„ํ•ด ๋ฏธ๋ฆฌ ๋งˆ์Šคํ‚น ํ–‰๋ ฌ์„ ์ •์˜ํ•ด Word_Embedding์— ์ ์šฉํ•ด์ค€๋‹ค. ์ด๋ ‡๊ฒŒ ๋งˆ์Šคํ‚น์ด ์ ์šฉ๋œ ์ž„๋ฒ ๋”ฉ ํ–‰๋ ฌ์„ ๊ฐ€์ง€๊ณ  linear projection & self-attention์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ด๋ฆ„ ์•ž์— masked๋ฅผ ๋ถ™์ด๊ฒŒ ๋˜์—ˆ๋‹ค.

Decoder Language Modeling Mask

์œ„ ๊ทธ๋ฆผ์€ ๋งˆ์Šคํ‚น์„ ์ ์šฉํ•œ Word_Embedding์˜ ๋ชจ์Šต์ด๋‹ค. ์ฒซ ๋ฒˆ์งธ ์‹œ์ ์—์„œ ๋ชจ๋ธ์€ ์ž๊ธฐ ์ž์‹ ์„ ์ œ์™ธํ•œ ๋‚˜๋จธ์ง€ Context๋ฅผ ์˜ˆ์ธก์— ํ™œ์šฉํ•  ์ˆ˜ ์—†๋‹ค. ๊ทธ๋ž˜์„œ ์ดํ•˜ ๋‚˜๋จธ์ง€ ํ† ํฐ์„ ์ „๋ถ€ ๋งˆ์Šคํ‚น ์ฒ˜๋ฆฌํ•ด์คฌ๋‹ค. ๋‘๋ฒˆ์งธ ์‹œ์ ์—์„œ๋Š” ์ง์ „ ์‹œ์ ์ธ ์ฒซ๋ฒˆ์งธ ํ† ํฐ๊ณผ ์ž๊ธฐ ์ž์‹ ๋งŒ ์ฐธ๊ณ ํ•  ์ˆ˜ ์žˆ๋‹ค. ํ•œํŽธ, ์ด๋ ‡๊ฒŒ ์ง์ „ Context๋งŒ ๊ฐ€์ง€๊ณ  ํ˜„์žฌ ํ† ํฐ์„ ์ถ”๋ก ํ•˜๋Š” ๊ฒƒ์„ Language Modeling์ด๋ผ ๋ถ€๋ฅธ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ๋””์ฝ”๋” ์—ญ์‹œ ์‹œํ€€์Šค์— ํŒจ๋”ฉ ์ฒ˜๋ฆฌ๋ฅผ ํ•ด์ฃผ๊ธฐ ๋•Œ๋ฌธ์— ์ธ์ฝ”๋”์™€ ๋™์ผํ•œ ์›๋ฆฌ๋กœ ๋งŒ๋“  decoder padding mask ์—ญ์‹œ ํ•„์š”ํ•˜๋‹ค.

๋งˆ์Šคํ‚น ํ–‰๋ ฌ ๊ตฌํ˜„์€ ์ตœ์ƒ์œ„ ๊ฐ์ฒด์ธ Transformer์˜ ๋‚ด๋ถ€ ๋ฉ”์„œ๋“œ๋กœ ๋งŒ๋“ค์—ˆ์œผ๋‹ˆ, ๊ทธ ๋•Œ ์ž์„ธํžˆ ์„ค๋ช…ํ•˜๊ฒ ๋‹ค. ์ดํ•˜ ๋‚˜๋จธ์ง€ ๋””ํ…Œ์ผ์€ ์ธ์ฝ”๋”์˜ ๊ฒƒ๊ณผ ๋™์ผํ•˜๋‹ค.

๐Ÿชข Encoder-Decoder Multi-Head Attention
์ธ์ฝ”๋”๋ฅผ ํ†ตํ•ด ์ดํ•ดํ•œ ๋Œ€์ƒ ์–ธ์–ด ์‹œํ€€์Šค์™€ ๋ฐ”๋กœ ์ง์ „ Self-Attention์„ ํ†ตํ•ด ์ดํ•ดํ•œ ํƒ€๊ฒŸ ์–ธ์–ด ์‹œํ€€์Šค๋ฅผ ์„œ๋กœ ๋Œ€์กฐํ•˜๋Š” ๋ ˆ์ด์–ด๋‹ค. ์šฐ๋ฆฌ์˜ ์ง€๊ธˆ ๋ชฉ์ ์€ ํƒ€๊ฒŸ ์–ธ์–ด์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ๋Œ€์ƒ ์–ธ์–ด๋ฅผ ์ฐพ์•„ ๋ฌธ์žฅ์„ ์™„์„ฑํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ๋”ฐ๋ผ์„œ ์–ดํ…์…˜ ๊ณ„์‚ฐ์— ์‚ฌ์šฉ๋  ํ–‰๋ ฌ $Q$ ์˜ ์ถœ์ฒ˜๋Š” ์ง์ „ ๋ ˆ์ด์–ด์ธ Masked Multi-Head Attention ์˜ ๋ฐ˜ํ™˜๊ฐ’์„ ์‚ฌ์šฉํ•˜๊ณ , ํ–‰๋ ฌ $K,V$ ๋Š” ์ธ์ฝ”๋”์˜ ์ตœ์ข… ๋ฐ˜ํ™˜๊ฐ’์„ ์‚ฌ์šฉํ•œ๋‹ค.

ํ•œํŽธ, ์—ฌ๊ธฐ ๋ ˆ์ด์–ด์—๋Š” ๋งˆ์Šคํ‚น ํ–‰๋ ฌ์ด ์„ธ ์ข…๋ฅ˜๋‚˜ ํ•„์š”ํ•˜๋‹ค. ๊ทธ ์ด์œ ๋Š” ์„œ๋กœ ์ถœ์ฒ˜๊ฐ€ ๋‹ค๋ฅธ ๋‘๊ฐ€์ง€ ํ–‰๋ ฌ์„ ๊ณ„์‚ฐ์— ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ์ง€๊ธˆ์€ ์—ฌ์ „ํžˆ ๋””์ฝ”๋”์˜ ์—ญํ• ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ์— ์ง์ „ ๋ ˆ์ด์–ด์—์„œ ์‚ฌ์šฉํ•œ 2๊ฐœ์˜ ๋งˆ์Šคํ‚น ํ–‰๋ ฌ์ด ๊ทธ๋Œ€๋กœ ํ•„์š”ํ•˜๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์ธ์ฝ”๋”์—์„œ ๋„˜์–ด์˜จ ๊ฐ’์„ ์‚ฌ์šฉํ•œ๋‹ค๋Š” ๊ฒƒ์€ ์ธ์ฝ”๋”์˜ ํŒจ๋”ฉ ์—ญ์‹œ ์ฒ˜๋ฆฌ๊ฐ€ ํ•„์š”ํ•˜๋‹ค๋Š” ์˜๋ฏธ๋‹ค. ๋”ฐ๋ผ์„œ lm_mask, dec_pad_mask, enc_pad_mask๊ฐ€ ํ•„์š”ํ•˜๋‹ค. ์—ญ์‹œ ๋งˆ์Šคํ‚น ๊ตฌํ˜„์€ ์ตœ์ƒ์œ„ ๊ฐ์ฒด ์„ค๋ช… ๋•Œ ํ•จ๊ป˜ ์‚ดํŽด๋ณด๊ฒ ๋‹ค.

๐Ÿ‘ฉโ€๐Ÿ’ปย Implementation
์ด์ œ Single Decoder Block์˜ ๊ตฌํ˜„์„ ์‚ดํŽด๋ณด์ž. ์—ญ์‹œ ํŒŒ์ดํ† ์น˜๋กœ ๊ตฌํ˜„ํ–ˆ๋‹ค.

# Pytorch Implementation of Single Decoder Block

class DecoderLayer(nn.Module):
    """
    Class for the decoder block module in Transformer
    In this class, we stack each decoder sub-module (Masked Multi-Head Attention, Encoder-Decoder Attention, Residual Connection, LayerNorm, FFN)
    We apply pre-layernorm, which is different from the original paper
    References:
        https://arxiv.org/abs/1706.03762
    """
    def __init__(self, dim_model: int = 512, num_heads: int = 8, dim_ffn: int = 2048, dropout: float = 0.1) -> None:
        super(DecoderLayer, self).__init__()
        self.masked_attention = MultiHeadAttention(
            dim_model,
            num_heads,
            int(dim_model / num_heads),
            dropout,
        )
        self.enc_dec_attention = MultiHeadAttention(
            dim_model,
            num_heads,
            int(dim_model / num_heads),
            dropout,
        )
        self.layer_norm1 = nn.LayerNorm(dim_model)
        self.layer_norm2 = nn.LayerNorm(dim_model)
        self.layer_norm3 = nn.LayerNorm(dim_model)
        self.dropout = nn.Dropout(p=dropout)  # dropout is not learnable layer

        self.ffn = FeedForward(
            dim_model,
            dim_ffn,
            dropout,
        )

    def forward(self, x: Tensor, dec_mask: Tensor, enc_dec_mask: Tensor, enc_output: Tensor) -> Tensor:
        ln_x = self.layer_norm1(x)
        residual_x = self.dropout(self.masked_attention(ln_x, dec_mask)) + x

        ln_x = self.layer_norm2(residual_x)
        residual_x = self.dropout(self.enc_dec_attention(ln_x, enc_dec_mask, enc_output)) + residual_x  # encoder-decoder (cross) attention, residual from previous sub-layer

        ln_x = self.layer_norm3(residual_x)
        fx = self.ffn(ln_x) + residual_x
        return fx

Self-Attention ๋ ˆ์ด์–ด๊ฐ€ ์ธ์ฝ”๋”๋ณด๋‹ค ํ•˜๋‚˜ ๋” ์ถ”๊ฐ€๋˜์–ด add & norm ์„ ์ด 3๋ฒˆ ํ•ด์ค˜์•ผ ํ•œ๋‹ค๋Š” ๊ฒƒ์„ ์ œ์™ธํ•˜๊ณ ๋Š” ํฌ๊ฒŒ ๊ตฌํ˜„์ƒ์˜ ํŠน์ด์ ์€ ์—†๋‹ค. ๊ทธ์ € ์ง€๊ธˆ๊นŒ์ง€ ์‚ดํŽด๋ณธ ๋ธ”๋Ÿญ์„ ์š”๋ฆฌ์กฐ๋ฆฌ ๋‹ค์‹œ ์Œ“์œผ๋ฉด ๋œ๋‹ค.

๐Ÿ“šย Decoder

Single Decoder Block์„ N๊ฐœ ์Œ“๊ณ  ์ „์ฒด ๋””์ฝ”๋” ๋™์ž‘์„ ์ˆ˜ํ–‰ํ•˜๋Š” Decoder ๊ฐ์ฒด์˜ ๊ตฌํ˜„์„ ์•Œ์•„๋ณด์ž.

# Pytorch Implementation of Decoder(N Stacked Single Decoder Block)

class Decoder(nn.Module):
    """
    In this class, we decode the encoded embedding from the encoder using the outputs (target language, Decoder's input sequence)
    First, we define a "positional embedding" for the Decoder's input sequence,
    and then add it to the Decoder's input sequence to make the "decoder word embedding"
    Second, we forward the "decoder word embedding" through the N DecoderLayers, then pass the result to linear & softmax for the output probability
    Args:
        vocab_size: size of vocabulary for output probability
        max_seq: maximum sequence length, default 512 from official paper
        N: number of DecoderLayer, default 6 for base model
    References:
        https://arxiv.org/abs/1706.03762
    """
    def __init__(
        self,
        vocab_size: int,
        max_seq: int = 512,
        N: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_ffn: int = 2048,
        dropout: float = 0.1
    ) -> None:
        super(Decoder, self).__init__()
        self.max_seq = max_seq
        self.scale = torch.sqrt(torch.tensor(dim_model, dtype=torch.float))  # scale factor for input embedding from official paper
        self.positional_embedding = nn.Embedding(max_seq, dim_model)  # learnable position embedding, one vector per position
        self.num_layers = N
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.dim_ffn = dim_ffn
        self.dropout = nn.Dropout(p=dropout)
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(dim_model, num_heads, dim_ffn, dropout) for _ in range(self.num_layers)]
        )
        self.layer_norm = nn.LayerNorm(dim_model)
        self.fc_out = nn.Linear(dim_model, vocab_size)  # In Pytorch, nn.CrossEntropyLoss already has softmax function

    def forward(self, inputs: Tensor, dec_mask: Tensor, enc_dec_mask: Tensor, enc_output: Tensor) -> tuple[Tensor, Tensor]:
        """
        inputs: embedding from input sequence, shape => [BS, SEQ_LEN, DIM_MODEL]
        dec_mask: combined mask for the Decoder (padding mask * language modeling mask)
        enc_dec_mask: mask for Encoder-Decoder Attention, built from the Encoder's padded tokens
        """
        layer_output = []
        pos_x = torch.arange(self.max_seq, device=inputs.device).unsqueeze(0).repeat(inputs.shape[0], 1)  # [BS, SEQ_LEN] position indices
        x = self.dropout(
            self.scale * inputs + self.positional_embedding(pos_x)
        )
        for layer in self.decoder_layers:
            x = layer(x, dec_mask, enc_dec_mask, enc_output)
            layer_output.append(x)
        decoded_x = self.fc_out(self.layer_norm(x))  # Because of pre-layernorm
        layer_output = torch.stack(layer_output, dim=0).to(x.device)  # For Weighted Layer Pool: [N, BS, SEQ_LEN, DIM]
        return decoded_x, layer_output

Encoder ๊ฐ์ฒด์™€ ๋ชจ๋“  ๋ถ€๋ถ„์ด ๋™์ผํ•˜๋‹ค. ๋””ํ…Œ์ผํ•œ ์„ค์ •๋งŒ ๋””์ฝ”๋”์— ๋งž๊ฒŒ ๋ณ€๊ฒฝ๋˜์—ˆ์„ ๋ฟ์ด๋‹ค. self.fc_out ์— ์ฃผ๋ชฉํ•ด๋ณด์ž. ๋””์ฝ”๋”๋Š” ํ˜„์žฌ ์‹œ์ ์— ๊ฐ€์žฅ ์ ํ•ฉํ•œ ํ† ํฐ์„ ์˜ˆ์ธกํ•ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋””์ฝ”๋”์˜ ์ถœ๋ ฅ๋ถ€๋ถ„์— ๋กœ์ง“ ๊ณ„์‚ฐ์„ ์œ„ํ•œ ๋ ˆ์ด์–ด๊ฐ€ ํ•„์š”ํ•˜๋‹ค. ๊ทธ ์—ญํ• ์„ ํ•˜๋Š” ๊ฒƒ์ด ๋ฐ”๋กœ self.fc_out์ด๋‹ค. ํ•œํŽธ, self.fc_out์˜ ์ถœ๋ ฅ ์ฐจ์›์ด vocab_size์œผ๋กœ ๋˜์–ด์žˆ๋Š”๋ฐ, ๋””์ฝ”๋”๋Š” ๋””์ฝ”๋”๊ฐ€ ๊ฐ€์ง„ ์ „์ฒด vocab ์„ ํ˜„์žฌ ์‹œ์ ์— ์ ํ•ฉํ•œ ํ† ํฐ ํ›„๋ณด๊ตฐ์œผ๋กœ ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.

๐Ÿฆพ Transformer

์ด์ œ ๋Œ€๋ง์˜ ๋งˆ์ง€๋ง‰โ€ฆ ๋ชจ๋ธ์˜ ๊ฐ€์žฅ ์ตœ์ƒ์œ„ ๊ฐ์ฒด์ธ Transformer์— ๋Œ€ํ•ด์„œ ์‚ดํŽด๋ณด์ž. ๊ฐ์ฒด์˜ ๋™์ž‘์€ ์ •๋ฆฌํ•˜๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

  • 1) Make Input Embedding for Encoder & Decoder respectively, Init Encoder & Decoder Class
  • 2) Make 3 types of Masking: Encoder Padding Mask, Decoder LM & Padding Mask, Encoder-Decoder Mask
  • 3) Return Output from Encoder & Decoder

ํŠนํžˆ ๊ณ„์† ๋ฏธ๋ค„์™”๋˜ ๋งˆ์Šคํ‚น ๊ตฌํ˜„์— ๋Œ€ํ•ด์„œ ์‚ดํŽด๋ณด์ž. ๋‚˜๋จธ์ง€๋Š” ์ด๋ฏธ ์•ž์—์„œ ๋งŽ์ด ์„ค๋ช…ํ–ˆ์œผ๋‹ˆ๊นŒ ๋„˜์–ด๊ฐ€๋„๋ก ํ•˜๊ฒ ๋‹ค. ์ผ๋‹จ ๋จผ์ € ์ฝ”๋“œ๋ฅผ ์ฝ์–ด๋ณด์ž. ์ถ”๊ฐ€๋กœ Input Embedding ๊ตฌํ˜„์€ ์‚ฌ์šฉ์ž์˜ vocab ๊ตฌ์ถ• ๋ฐฉ์‹์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง„๋‹ค. ํ•„์ž์˜ ๊ฒฝ์šฐ ๋Œ€์ƒ ์–ธ์–ด์™€ ํƒ€๊ฒŸ ์–ธ์–ด์˜ vocab์„ ๋ถ„๋ฆฌํ•ด ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์„ ๊ฐ€์ •ํ•˜๊ณ  ์ฝ”๋“œ๋ฅผ ๋งŒ๋“ค์–ด ์ž„๋ฒ ๋”ฉ ๋ ˆ์ด์–ด๋ฅผ ๋”ฐ๋กœ ๋”ฐ๋กœ ๊ตฌํ˜„ํ•ด์คฌ๋‹ค. vocab์„ ํ†ตํ•ฉ์œผ๋กœ ๊ตฌ์ถ•ํ•˜์‹œ๋Š” ๋ถ„์ด๋ผ๋ฉด ํ•˜๋‚˜๋งŒ ์ •์˜ํ•ด๋„ ์ƒ๊ด€์—†๋‹ค. ๋Œ€์‹  ๋‚˜์ค‘์— ๋””์ฝ”๋”์˜ ๋กœ์ง“๊ฐ’ ๊ณ„์‚ฐ์„ ์œ„ํ•ด ๊ฐœ๋ณ„ ์–ธ์–ด์˜ ํ† ํฐ ์‚ฌ์ด์ฆˆ๋Š” ์•Œ๊ณ  ์žˆ์–ด์•ผ ํ•  ๊ฒƒ์ด๋‹ค.

# Pytorch Implementation of Transformer

class Transformer(nn.Module):
    """
    Main class for the pure Transformer, PyTorch implementation
    There are two masking methods for padding tokens:
        1) Row & column masking
        2) Column masking only at forward time, row masking when calculating the loss
    The second method is more efficient than the first, which is complex & difficult to implement
    Args:
        enc_vocab_size: size of vocabulary for Encoder Input Sequence
        dec_vocab_size: size of vocabulary for Decoder Input Sequence
        max_seq: maximum sequence length, default 512 from official paper
        enc_N: number of EncoderLayer, default 6 for base model
        dec_N: number of DecoderLayer, default 6 for base model
    Reference:
        https://arxiv.org/abs/1706.03762
    """
    def __init__(
        self,
        enc_vocab_size: int,
        dec_vocab_size: int,
        max_seq: int = 512,
        enc_N: int = 6,
        dec_N: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_ffn: int = 2048,
        dropout: float = 0.1
    ) -> None:
        super(Transformer, self).__init__()
        self.enc_input_embedding = nn.Embedding(enc_vocab_size, dim_model)
        self.dec_input_embedding = nn.Embedding(dec_vocab_size, dim_model)
        self.encoder = Encoder(max_seq, enc_N, dim_model, num_heads, dim_ffn, dropout)
        self.decoder = Decoder(dec_vocab_size, max_seq, dec_N, dim_model, num_heads, dim_ffn, dropout)

    @staticmethod
    def enc_masking(x: Tensor, enc_pad_index: int) -> Tensor:
        """ make masking matrix for Encoder Padding Token """
        enc_mask = (x != enc_pad_index).int().repeat(1, x.shape[-1]).view(x.shape[0], x.shape[-1], x.shape[-1])
        return enc_mask

    @staticmethod
    def dec_masking(x: Tensor, dec_pad_index: int) -> Tensor:
        """ make masking matrix for Decoder Masked Multi-Head Self-Attention """
        pad_mask = (x != dec_pad_index).int().repeat(1, x.shape[-1]).view(x.shape[0], x.shape[-1], x.shape[-1])
        lm_mask = torch.tril(torch.ones(x.shape[0], x.shape[-1], x.shape[-1], device=x.device))
        dec_mask = pad_mask * lm_mask
        return dec_mask

    @staticmethod
    def enc_dec_masking(enc_x: Tensor, dec_x: Tensor, enc_pad_index: int) -> Tensor:
        """ make masking matrix for Encoder-Decoder Multi-Head Self-Attention in Decoder """
        enc_dec_mask = (enc_x != enc_pad_index).int().repeat(1, dec_x.shape[-1]).view(
            enc_x.shape[0], dec_x.shape[-1], enc_x.shape[-1]
        )
        return enc_dec_mask

    def forward(self, enc_inputs: Tensor, dec_inputs: Tensor, enc_pad_index: int, dec_pad_index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        enc_mask = self.enc_masking(enc_inputs, enc_pad_index)  # enc_x.shape[1] == encoder input sequence length
        dec_mask = self.dec_masking(dec_inputs, dec_pad_index)  # dec_x.shape[1] == decoder input sequence length
        enc_dec_mask = self.enc_dec_masking(enc_inputs, dec_inputs, enc_pad_index)

        enc_x, dec_x = self.enc_input_embedding(enc_inputs), self.dec_input_embedding(dec_inputs)

        enc_output, enc_layer_output = self.encoder(enc_x, enc_mask)
        dec_output, dec_layer_output = self.decoder(dec_x, dec_mask, enc_dec_mask, enc_output)
        return enc_output, dec_output, enc_layer_output, dec_layer_output

๋งˆ์Šคํ‚น์˜ ํ•„์š”์„ฑ์ด๋‚˜ ๋™์ž‘ ๋ฐฉ์‹์€ ์ด๋ฏธ ์œ„์—์„œ ๋ชจ๋‘ ์„ค๋ช…ํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— ๊ตฌํ˜„์ƒ ํŠน์ง•๋งŒ ์„ค๋ช…ํ•˜๋ คํ•œ๋‹ค. ์„ธ๊ฐ€์ง€ ๋งˆ์Šคํ‚น ๋ชจ๋‘ ๊ณตํ†ต์ ์œผ๋กœ ๊ตฌํ˜„ ์ฝ”๋“œ ๋ผ์ธ์— .int() ๊ฐ€ ๋“ค์–ด๊ฐ€ ์žˆ๋‹ค. ๊ทธ ์ด์œ ๋Š” $\frac{Qโ€ขK^T}{\sqrt{d_h}}$์— ๋งˆ์Šคํ‚น์„ ์ ์šฉํ•  ๋•Œ torch.masked_fill ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ๋ฌด์Šจ ์ด์œ  ๋•Œ๋ฌธ์ธ์ง€๋Š” ๋ชจ๋ฅด๊ฒ ์œผ๋‚˜ torch.masked_fill์˜ ๊ฒฝ์šฐ ๋งˆ์Šคํ‚น ์กฐ๊ฑด์œผ๋กœ boolean์„ ์ „๋‹ฌํ•˜๋ฉด ๋งˆ์Šคํ‚น์ด ์ œ๋Œ€๋กœ ๊ตฌํ˜„๋˜์ง€ ์•Š๋Š” ํ˜„์ƒ์ด ์žˆ์—ˆ๋‹ค. ํ•œํŽธ, ์ •์ˆ˜๊ฐ’์œผ๋กœ ์กฐ๊ฑด์„ ๊ตฌํ˜„ํ•˜๋ฉด ์˜๋„ํ•œ๋Œ€๋กœ ๊ตฌํ˜„์ด ๋˜๋Š” ๊ฒƒ์„ ํ™•์ธํ–ˆ๋‹ค. ๊ทธ๋ž˜์„œ ํŒจ๋”ฉ์— ํ•ด๋‹น๋˜๋Š” ํ† ํฐ์ด ์œ„์น˜ํ•œ ๊ณณ์˜ ์›์†Œ๊ฐ’์„ ์ •์ˆ˜ํ˜• Binary ๋กœ ๋งŒ๋“ค์–ด์ฃผ๊ธฐ ์œ„ํ•ด int() ๋ฅผ ์‚ฌ์šฉํ•œ ๊ฒƒ์ด๋‹ค.

๐ŸŽญย Decoder Mask
๋””์ฝ”๋”๋Š” ์ด 2๊ฐ€์ง€์˜ ๋งˆ์Šคํ‚น์ด ํ•„์š”ํ•˜๋‹ค๊ณ  ์–ธ๊ธ‰ํ–ˆ์—ˆ๋‹ค. pad_mask์˜ ๊ฒฝ์šฐ๋Š” ์ธ์ฝ”๋”์˜ ๊ฒƒ๊ณผ ๋™์ผํ•œ ์›๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์„ค๋ช…์„ ์ƒ๋žตํ•˜๊ฒ ๋‹ค. lm_mask ์˜ ๊ฒฝ์šฐ๋Š” torch.tril์„ ์ด์šฉํ•ด ํ•˜์‚ผ๊ฐํ–‰๋ ฌ ํ˜•ํƒœ๋กœ ๋งˆ์Šคํ‚น ํ–‰๋ ฌ ์ •์˜๊ฐ€ ์‰ฝ๊ฒŒ ๊ฐ€๋Šฅํ•˜๋‹ค.
ํ•œํŽธ, 2๊ฐœ์˜ ๋งˆ์Šคํ‚น์„ ๋™์‹œ์— ๋””์ฝ”๋” ๊ฐ์ฒด์— ๋„˜๊ธฐ๋Š” ๊ฒƒ์€ ๋งค์šฐ ๋น„ํšจ์œจ์ ์ด๋‹ค. ๋”ฐ๋ผ์„œ pad_mask ์™€ lm_mask์˜ ํ•ฉ์ง‘ํ•ฉ์— ํ•ด๋‹นํ•˜๋Š” ํ–‰๋ ฌ์„ ๋งŒ๋“ค์–ด ์ตœ์ข… ๋””์ฝ”๋”์˜ ๋งˆ์Šคํ‚น์œผ๋กœ ์ „๋‹ฌํ•œ๋‹ค.

๐Ÿ™Œย Encoder-Decoder Mask
์ด๋ฒˆ ๊ฒฝ์šฐ๋Š” ๋งˆ์Šคํ‚น์˜ ํ–‰๋ฐฉํ–ฅ ์ฐจ์›์€ ๋””์ฝ”๋” ์ž…๋ ฅ๊ฐ’์˜ ์‹œํ€€์Šค ๊ธธ์ด, ์—ด๋ฐฉํ–ฅ ์ฐจ์›์€ ์ธ์ฝ”๋” ์ž…๋ ฅ๊ฐ’์˜ ์‹œํ€€์Šค ๊ธธ์ด๋กœ ์„ค์ •ํ•ด์•ผ ํ•œ๋‹ค. ๊ทธ ์ด์œ ๋Š” ๋‹ค๋ฅธ Self-Attention ๋ ˆ์ด์–ด์™€ ๋‹ค๋ฅด๊ฒŒ ์„œ๋กœ ๋‹ค๋ฅธ ์ถœ์ฒ˜๋ฅผ ํ†ตํ•ด ๋งŒ๋“  ํ–‰๋ ฌ์„ ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์— $\frac{Qโ€ขK^T}{\sqrt{d_h}}$์˜ ๋ชจ์–‘์ด ์ •์‚ฌ๊ฐํ–‰๋ ฌ์ด ์•„๋‹ ์ˆ˜๋„ ์žˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์„ ์˜์–ด๋กœ ๋ฐ”๊พธ๋Š” ๊ฒฝ์šฐ๋ฅผ ์ƒ๊ฐํ•ด๋ณด์ž. ๊ฐ™์€ ๋œป์ด ๋‹ด๊ธด ๋ฌธ์žฅ์ด๋ผ๊ณ  ํ•ด์„œ ๋‘ ๋ฌธ์žฅ์˜ ๊ธธ์ด๊ฐ€ ๊ฐ™์€๊ฐ€?? ์•„๋‹ˆ๋‹ค. ์„œ๋กœ ๋‹ค๋ฅธ ์–ธ์–ด๋ผ๋ฉด ๊ฑฐ์˜ ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ ๊ธธ์ด๊ฐ€ ๋‹ค๋ฅผ ๊ฒƒ์ด๋‹ค. ๋”ฐ๋ผ์„œ $\frac{Qโ€ขK^T}{\sqrt{d_h}}$์˜ ํ–‰๋ฐฉํ–ฅ์€ ๋””์ฝ”๋”์˜ ์‹œํ€€์Šค ๊ธธ์ด์— ๋”ฐ๋ฅด๊ณ  ์—ด๋ฐฉํ–ฅ์€ ์ธ์ฝ”๋”์˜ ์‹œํ€€์Šค ๊ธธ์ด์— ๋”ฐ๋ฅด๋„๋ก ๋งˆ์Šคํ‚น ์—ญ์‹œ ๊ตฌํ˜„ํ•ด์ค˜์•ผ ํ•œ๋‹ค.
๊ทธ๋ฆฌ๊ณ  ์ด๋ฒˆ ๋งˆ์Šคํ‚น์„ ๋งŒ๋“œ๋Š” ๋ชฉ์ ์ด ์ธ์ฝ”๋”์˜ ํŒจ๋”ฉ์„ ๋งˆ์Šคํ‚น ์ฒ˜๋ฆฌํ•ด์ฃผ๊ธฐ ์œ„ํ•จ์ด๊ธฐ ๋•Œ๋ฌธ์— enc_pad_index ๋งค๊ฐœ๋ณ€์ˆ˜์—๋Š” ์ธ์ฝ”๋” vocab์—์„œ ์ •์˜ํ•œ pad_token_ID๋ฅผ ์ „๋‹ฌํ•˜๋ฉด ๋œ๋‹ค.

# scaled_dot_product_attention์˜ ์ผ๋ถ€

if mask is not None:
		attention_matrix = attention_matrix.masked_fill(mask == 0, float('-inf'))

์ด๋ ‡๊ฒŒ ๊ตฌํ˜„๋œ ๋งˆ์Šคํ‚น์€ scaled_dot_product_attention ๋ฉ”์„œ๋“œ์— ๊ตฌํ˜„๋œ ์กฐ๊ฑด๋ฌธ์„ ํ†ตํ•ด ๋งˆ์Šคํ‚น ๋Œ€์ƒ์„ -โˆž์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์—ญํ• ์„ ํ•˜๊ฒŒ ๋œ๋‹ค.
