
🔭 Overview

ELECTRA is a model first published by Google in 2020, notable for applying a GAN (Generative Adversarial Networks)-style architecture to NLP. To fit this new structure, the authors devised RTD (Replaced Token Detection) as the pre-training task. The whole idea starts from the drawbacks of encoder language models (the BERT family) that use MLM (Masked Language Modeling) as their pre-training method.

[MLM ๋‹จ์ ]

  • 1) ์‚ฌ์ „ํ•™์Šต๊ณผ ํŒŒ์ธํŠœ๋‹ ์‚ฌ์ด ๋ถˆ์ผ์น˜
    • ํŒŒ์ธํŠœ๋‹ ๋•Œ Masking Task๊ฐ€ ์—†์Œ
  • 2) ์—ฐ์‚ฐ๋Ÿ‰ ๋Œ€๋น„ ํ•™์Šต๋Ÿ‰์€ ์ ์€ํŽธ
    • ์ „์ฒด ์‹œํ€€์Šค์˜ 15%๋งŒ ๋งˆ์Šคํ‚น ํ™œ์šฉ(15%๋งŒ ํ•™์Šต)
    • ์ „์—ญ ์–ดํ…์…˜์˜ ์‹œ๊ณต๊ฐ„ ๋ณต์žก๋„ ๊ณ ๋ คํ•˜๋ฉด ์ƒ๋‹นํžˆ ๋น„ํšจ์œจ์ ์ธ ์ˆ˜์น˜
      • ์‹œํ€€์Šค๊ธธ์ด ** 2์˜ ๋ณต์žก๋„
      • Vocab Size๋งŒํผ์˜ ์ฐจ์›์„ ๊ฐ–๋Š” ์†Œํ”„ํŠธ๋งฅ์Šค ๊ณ„์‚ฐ ๋ฐ˜๋ณต

๊ทธ๋ž˜์„œ MLM์€ ํ™œ์šฉํ•˜๋˜, ํŒŒ์ธํŠœ๋‹๊ณผ ๊ดด๋ฆฌ๋Š” ํฌ์ง€ ์•Š์€ ๋ชฉ์ ํ•จ์ˆ˜๋ฅผ ์„ค๊ณ„ํ•จ์œผ๋กœ์„œ ์ž…๋ ฅ๋œ ์ „์ฒด ์‹œํ€€์Šค์— ๋Œ€ํ•ด์„œ ๋ชจ๋ธ์ด ํ•™์Šตํ•˜์—ฌ ์—ฐ์‚ฐ๋Ÿ‰ ๋Œ€๋น„ ํ•™์Šต๋Ÿ‰์„ ๋Š˜๋ฆฌ๊ณ ์ž ํ–ˆ๋˜๊ฒŒ ๋ฐ”๋กœ ELECTRA ๋ชจ๋ธ์ด๋‹ค.

์ •๋ฆฌํ•˜์ž๋ฉด, ELECTRA ๋ชจ๋ธ์€ ๊ธฐ์กด BERT์˜ ๊ตฌ์กฐ์  ์ธก๋ฉด ๊ฐœ์„ ์ด ์•„๋‹Œ, ์‚ฌ์ „ํ•™์Šต ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ๊ฐœ์„  ์‹œ๋„๋ผ๊ณ  ๋ณผ ์ˆ˜ ์žˆ๋‹ค. ๋”ฐ๋ผ์„œ ์–ด๋–ค ๋ชจ๋ธ์ด๋”๋ผ๋„, ์ธ์ฝ”๋” ์–ธ์–ด ๋ชจ๋ธ์ด๋ผ๋ฉด ๋ชจ๋‘ ELECTRA ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๊ธฐ์กด ๋…ผ๋ฌธ์—์„œ๋Š” ์›๋ณธ BERT ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ–ˆ๋‹ค. ๊ทธ๋ž˜์„œ ๋ณธ ํฌ์ŠคํŒ…์—์„œ๋„ BERT์— ๋Œ€ํ•œ ์„ค๋ช… ์—†์ด RTD์— ๋Œ€ํ•ด์„œ๋งŒ ๋‹ค๋ฃจ๋ ค๊ณ  ํ•œ๋‹ค.

👮 RTD: A New Pre-training Task

RTD Task

RTD์˜ ์•„์ด๋””์–ด๋Š” ๊ฐ„๋‹จํ•˜๋‹ค. ์ƒ์„ฑ์ž(Generator)๊ฐ€ ์ถœ๋ ฅ์œผ๋กœ ๋‚ด๋†“์€ ํ† ํฐ ์‹œํ€€์Šค์— ๋Œ€ํ•ด์„œ ํŒ๋ณ„์ž(Discriminator)๊ฐ€ ๊ฐœ๋ณ„ ํ† ํฐ๋“ค์ด ์›๋ณธ์ธ์ง€ ์•„๋‹Œ์ง€๋ฅผ ํŒ์ •(์ด์ง„ ๋ถ„๋ฅ˜)ํ•˜๋„๋ก ๋งŒ๋“ ๋‹ค. ์ƒ์„ฑ์ž๋Š” ๊ธฐ์กด์˜ MLM์„ ๊ทธ๋Œ€๋กœ ์ˆ˜ํ–‰ํ•˜๊ณ , ํŒ๋ณ„์ž๋Š” ์ƒ์„ฑ์ž์˜ ์˜ˆ์ธก์— ๋Œ€ํ•ด ์ง„์งœ์ธ์ง€ ๊ฐ€์งœ์ธ์ง€ ๋ถ„๋ฅ˜ํ•˜๋Š” ์‹์ด๋‹ค.

์œ„ ๊ทธ๋ฆผ์„ ์˜ˆ์‹œ๋กœ ์‚ดํŽด๋ณด์ž. ๋ชจ๋ธ์— ์ž…๋ ฅ์œผ๋กœ the chef cooked the meal๋ผ๋Š” ์‹œํ€€์Šค ์ค€๋‹ค. ๊ทธ๋Ÿฌ๋ฉด MLM ๊ทœ์น™์— ๋”ฐ๋ผ์„œ 15%์˜ ํ† ํฐ์ด ๋ฌด์ž‘์œ„๋กœ ์„ ํƒ๋œ๋‹ค. ๊ทธ๋ž˜์„œ the, cooked๊ฐ€ ๋งˆ์Šคํ‚น ๋˜์—ˆ๋‹ค. ์ด์ œ ์ƒ์„ฑ์ž๋Š” ๋งˆ์Šคํ‚น ํ† ํฐ์— ๋Œ€ํ•ด the, ate๋ผ๋Š” ๊ฒฐ๊ณผ๋ฅผ ๋‚ด๋†“๋Š”๋‹ค. ๊ทธ๋ž˜์„œ ์ตœ์ข…์ ์œผ๋กœ ์ƒ์„ฑ์ž๊ฐ€ ๋ฐ˜ํ™˜ํ•˜๋Š” ์‹œํ€€์Šค๋Š” the chef ate the meal์ด ๋œ๋‹ค. ์ด์ œ ์ƒ์„ฑ์ž๊ฐ€ ๋ฐ˜ํ™˜ํ•œ ์‹œํ€€์Šค๋ฅผ ํŒ๋ณ„์ž์— ์ž…๋ ฅ์œผ๋กœ ๋Œ€์ž…ํ•œ๋‹ค. ํŒ๋ณ„์ž๋Š” ๊ฐœ๋ณ„ ํ† ํฐ๋“ค์ด ์›๋ณธ์ธ์ง€ ์•„๋‹Œ์ง€๋ฅผ ํŒ์ •ํ•ด ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•œ๋‹ค.

The advantage of this structure and pre-training scheme is that the discriminator inherits the knowledge from MLM training via the generator while also getting the chance to learn from the entire sequence. Because predictions and losses can be computed for every token in the sequence, the model captures richer contextual information than plain MLM does for the same sequence size. Moreover, if the discriminator is used as the backbone for fine-tuning, its pre-training amounts to binary classification over the whole sequence without any masking, so the gap between pre-training and fine-tuning shrinks considerably.

🌟 Architecture

Model Architecture

์ €์ž๋Š” ์œ„์™€ ๊ฐ™์€ ์‹คํ—˜ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ, ์ƒ์„ฑ์ž์˜ width (์€๋‹‰์ธต) ํฌ๊ธฐ๊ฐ€ ํŒ๋ณ„์ž๋ณด๋‹ค ์ž‘๋„๋ก ๋ชจ๋ธ ํฌ๊ธฐ๋ฅผ ์„ธํŒ…ํ•˜๋Š”๊ฒŒ ๊ฐ€์žฅ ํšจ์œจ์ ์ด๋ผ๊ณ  ์ฃผ์žฅํ•œ๋‹ค. ์ œ์‹œ๋œ ๊ทธ๋ž˜ํ”„๋Š” ์ƒ์„ฑ์ž์™€ ํŒ๋ณ„์ž์˜ ํฌ๊ธฐ ๋ณ€ํ™” ๋Œ€๋น„ ํŒŒ์ธํŠœ๋‹ ์„ฑ๋Šฅ์˜ ์ถ”์ด๋ฅผ ๋‚˜ํƒ€๋‚ธ๋‹ค. ์ƒ์„ฑ์ž์˜ width ํฌ๊ธฐ๊ฐ€ 256, ํŒ๋ณ„์ž์˜ width ํฌ๊ธฐ๊ฐ€ 768์ผ ๋•Œ ๊ฐ€์žฅ ์ ์ˆ˜๊ฐ€ ๋†’๋‹ค. depth(๋ ˆ์ด์–ด ๊ฐœ์ˆ˜)์— ๋Œ€ํ•œ ์–ธ๊ธ‰์€ ๋”ฐ๋กœ ์—†์ง€๋งŒ, ์ €์ž์— ์˜ํ•ด ๊ณต๊ฐœ๋œ Hyper-Param ํ…Œ์ด๋ธ”์„ ๋ณด๋ฉด ์€๋‹‰์ธต์˜ ํฌ๊ธฐ๋งŒ ์ค„์ด๊ณ , ๋ ˆ์ด์–ด ๊ฐœ์ˆ˜๋Š” ์ƒ์„ฑ์ž์™€ ํŒ๋ณ„์ž๊ฐ€ ๊ฐ™์€ ๊ฒƒ์œผ๋กœ ์ถ”์ •๋œ๋‹ค.

In addition, the authors claim that sharing the embedding layer between the generator and the discriminator gives the best performance. The right-hand graph shows that, at equal compute, embedding sharing (the blue solid line) yields the highest fine-tuning performance. The model is therefore designed so that the word embeddings and absolute position embeddings are shared. But since a smaller generator hidden size was found advantageous, actually implementing this requires linearly projecting the embedding layer's output down to the generator's hidden dimension, so a linear layer must be added between the generator's embedding layer and its encoder.
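A minimal sketch of this shared-embedding-plus-projection layout, assuming the 256/768 widths from the experiment above (all other sizes are illustrative):

```python
import torch
import torch.nn as nn

d_disc, d_gen, vocab_size, max_len = 768, 256, 30000, 512

word_emb = nn.Embedding(vocab_size, d_disc)   # shared with the discriminator
pos_emb = nn.Embedding(max_len, d_disc)       # shared absolute position embedding
proj = nn.Linear(d_disc, d_gen)               # generator-only down-projection

ids = torch.randint(0, vocab_size, (2, 16))   # toy batch of 2 sequences, length 16
h = proj(word_emb(ids) + pos_emb(torch.arange(16)))  # input to the generator encoder
print(h.shape)  # torch.Size([2, 16, 256])
```

The discriminator consumes the 768-dimensional embeddings directly; only the generator needs the extra linear layer.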

\[\min_{\theta_G, \theta_D}\sum_{x \in X} \mathcal{L}_{\text{MLM}}(x, \theta_G) + \lambda \mathcal{L}_{\text{Disc}}(x, \theta_D)\]

๋”ฐ๋ผ์„œ, ์ง€๊ธˆ๊นŒ์ง€ ์‚ดํŽด๋ณธ ๋ชจ๋“  ๋‚ด์šฉ์„ ์ข…ํ•ฉํ•ด๋ณด๋ฉด ELECTRA์˜ ๋ชฉ์ ํ•จ์ˆ˜๋Š” ๋‹ค์Œ ์ˆ˜์‹๊ณผ ๊ฐ™๋‹ค. ์ƒ์„ฑ์ž์˜ MLM ์†์‹ค๊ณผ ํŒ๋ณ„์ž์˜ ์ด์ง„ ๋ถ„๋ฅ˜ ์†์‹ค์„ ๋”ํ•ด์„œ ๋ชจ๋ธ์— ์˜ค์ฐจ ์—ญ์ „ํ•ด์ฃผ๋ฉด ๋˜๋Š”๋ฐ, ํŠน์ดํ•œ ์ ์€ ํŒ๋ณ„์ž์˜ ์†์‹ค์— ์ƒ์ˆ˜๊ฐ’์ธ ๋žŒ๋‹ค๊ฐ€ ๊ณฑํ•ด์ง„๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์‹ค์ œ ๋ชจ๋ธ์„ ๊ตฌํ˜„ํ•˜๊ณ  ์‚ฌ์ „ํ•™์Šต์„ ํ•ด๋ณด๋ฉด, ๋ฐ์ดํ„ฐ์˜ ์–‘์ด๋‚˜ ๋ชจ๋ธ ํฌ๊ธฐ ํ˜น์€ ์ข…๋ฅ˜์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง€๊ฒ ์ง€๋งŒ ๋‘ ์†์‹ค ์‚ฌ์ด์˜ ์Šค์ผ€์ผ์˜ ์ฐจ์ด๊ฐ€ 10๋ฐฐ์ •๋„ ์ฐจ์ด ๋‚˜๊ฒŒ ๋œ๋‹ค. ๋‘ ์†์‹ค์˜ ์Šค์ผ€์ผ์„ ๋งž์ถฐ์ฃผ๋Š” ๋™์‹œ์—, ์ž„๋ฒ ๋”ฉ์ธต์˜ ํ•™์Šต์ด ํŒ๋ณ„์ž์˜ ์†์‹ค์„ ์ค„์ด๋Š”๋ฐ ๋” ์ง‘์ค‘ํ•˜๋„๋ก ๋งŒ๋“ค๊ธฐ ์œ„ํ•ด ๋„์ž…ํ•œ ๊ฒƒ์œผ๋กœ ์ถ”์ •๋œ๋‹ค. ๋…ผ๋ฌธ๊ณผ ์ฝ”๋“œ๋ฅผ ๋ณด๋ฉด ์ €์ž๋Š” $\lambda=50$ ์œผ๋กœ ๋‘๊ณ  ํ•™์Šตํ•˜๊ณ  ์žˆ๋‹ค.

👩‍💻 Implementation in PyTorch

๋…ผ๋ฌธ์˜ ๋‚ด์šฉ๊ณผ ์ €์ž๊ฐ€ ์ง์ ‘ ๊ณต๊ฐœํ•œ ์ฝ”๋“œ๋ฅผ ์ข…ํ•ฉํ•˜์—ฌ ํŒŒ์ดํ† ์น˜๋กœ ELECTRA๋ฅผ ๊ตฌํ˜„ํ•ด๋ดค๋‹ค. ๋‘ ๊ฐœ์˜ ์„œ๋กœ ๋‹ค๋ฅธ ๋ชจ๋ธ์„ ๊ฐ™์€ ์Šคํƒญ์—์„œ ํ•™์Šต์‹œ์ผœ์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์—, ์ œ์‹œ๋œ ๋‚ด์šฉ์— ๋น„ํ•ด ์‹ค์ œ ๊ตฌํ˜„์€ ๋งค์šฐ ๊นŒ๋‹ค๋กœ์šด ํŽธ์ด์—ˆ๋‹ค. ๋ณธ ํฌ์ŠคํŒ…์—์„œ๋Š” ELECTRA ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๋น„๋กฏํ•ด RTD ํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์„ฑ์— ํ•„์ˆ˜์ ์ธ ์š”์†Œ ๋ช‡ ๊ฐ€์ง€์— ๋Œ€ํ•ด์„œ๋งŒ ์„ค๋ช…ํ•˜๋ ค ํ•œ๋‹ค. ์ „์ฒด ๊ตฌ์กฐ์— ๋Œ€ํ•œ ์ฝ”๋“œ๋Š” ์—ฌ๊ธฐ ๋งํฌ๋ฅผ ํ†ตํ•ด ์ฐธ๊ณ  ๋ถ€ํƒ๋“œ๋ฆฐ๋‹ค.

ELECTRA์˜ ์‚ฌ์ „ ํ•™์Šต์ธ RTD์˜ ํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ์„ ๊ตฌํ˜„ํ•œ ์ฝ”๋“œ๋ฅผ ๋ณธ ๋’ค, ์„ธ๋ถ€ ๊ตฌ์„ฑ ์š”์†Œ๋“ค์— ๋Œ€ํ•ด์„œ ์‚ดํŽด๋ณด์ž.

🌆 RTD trainer method

def train_val_fn(self, loader_train, model: nn.Module, criterion: nn.Module, optimizer, scheduler) -> Tuple[Any, Union[float, Any]]:
    scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.amp_scaler)
    model.train()
    for step, batch in enumerate(tqdm(loader_train)):
        optimizer.zero_grad(set_to_none=True)
        inputs = batch['input_ids'].to(self.cfg.device, non_blocking=True)
        labels = batch['labels'].to(self.cfg.device, non_blocking=True)
        padding_mask = batch['padding_mask'].to(self.cfg.device, non_blocking=True)

        mask_labels = None
        if self.cfg.rtd_masking == 'SpanBoundaryObjective':
            mask_labels = batch['mask_labels'].to(self.cfg.device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=self.cfg.amp_scaler):
            # generator forward: MLM logits plus the discriminator's inputs/labels
            g_logit, d_inputs, d_labels = model.generator_fw(
                inputs,
                labels,
                padding_mask,
                mask_labels
            )
            # discriminator forward: per-token binary logits over the corrupted sequence
            d_logit = model.discriminator_fw(
                d_inputs,
                padding_mask
            )
            g_loss = criterion(g_logit.view(-1, self.cfg.vocab_size), labels.view(-1))
            d_loss = criterion(d_logit.view(-1, 2), d_labels)
            loss = g_loss + d_loss * self.cfg.discriminator_lambda  # weighted joint loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

๋ฐ์ดํ„ฐ๋กœ๋”๋กœ๋ถ€ํ„ฐ ๋ฐ›์€ ์ž…๋ ฅ๋“ค์„ ์ƒ์„ฑ์ž์— ๋„ฃ๊ณ  MLM ์˜ˆ์ธก ๊ฒฐ๊ณผ, RTD ์ˆ˜ํ–‰์„ ์œ„ํ•ด ํ•„์š”ํ•œ ์ƒˆ๋กœ์šด ๋ผ๋ฒจ๊ฐ’์„ ๋ฐ˜ํ™˜ ๋ฐ›๋Š”๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์ด๊ฒƒ์„ ๋‹ค์‹œ ํŒ๋ณ„์ž์˜ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•˜๊ณ , ํŒ๋ณ„์ž์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฐ˜ํ™˜๋ฐ›์•„ ์„œ๋กœ ๋‹ค๋ฅธ ๋‘ ๋ชจ๋ธ์— ๋Œ€ํ•œ ๊ฐ€์ค‘ ์†์‹คํ•ฉ์‚ฐ์„ ๊ตฌํ•œ ๋’ค, ์˜ตํ‹ฐ๋งˆ์ด์ €์— ๋ณด๋‚ด๊ณ  ์ตœ์ ํ™”๋ฅผ ์ˆ˜ํ–‰ํ•œ๋‹ค. ์ด ๋•Œ, ์ฒ˜์Œ์— ๋ฐ์ดํ„ฐ๋กœ๋”๊ฐ€ ๋ฐ˜ํ™˜ํ•˜๋Š” ์ž…๋ ฅ ์‹œํ€€์Šค์™€ ๋ผ๋ฒจ์€ MLM์˜ ๊ทธ๊ฒƒ๊ณผ ๋™์ผํ•˜๋‹ค,

The hardest part of the implementation was configuring the optimizer and scheduler. Since this was my first time training two models in the same step, I initially assumed I would need one optimizer and one scheduler object per model. In particular, because the two models differ in scale, I thought it would be more accurate to use separate optimizers and schedulers so that each model gets its own learning rate.

However, using two optimizers consumes a great deal of memory, and the hyperparameter table published in the paper shows the same learning rate applied to both models. I therefore built the pipeline so that both models are trained simultaneously with a single shared optimizer and scheduler.

I also confirmed that the released official code uses a single optimizer and scheduler.
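In PyTorch, this single-optimizer setup can be sketched like so; the two linear modules and the learning rate are stand-ins, not the actual encoders:

```python
import torch
import torch.nn as nn

# stand-ins for the generator and discriminator encoders
generator = nn.Linear(16, 16)
discriminator = nn.Linear(16, 2)

# one optimizer instance covering both parameter sets, one shared learning rate
optimizer = torch.optim.AdamW(
    list(generator.parameters()) + list(discriminator.parameters()),
    lr=2e-4,
)
print(len(optimizer.param_groups[0]['params']))  # 4 tensors: two weights, two biases
```

Both models then step together: a single `loss.backward()` on the joint loss populates gradients for every parameter in the group, and one `optimizer.step()` updates them all.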

🌆 ELECTRA Module

import torch
import torch.nn as nn
from experiment.models.abstract_model import AbstractModel
from torch import Tensor
from typing import Tuple, Callable
from einops.layers.torch import Rearrange
from experiment.tuner.mlm import MLMHead
from experiment.tuner.sbo import SBOHead
from experiment.tuner.rtd import get_discriminator_input, RTDHead
from configuration import CFG

class ELECTRA(nn.Module, AbstractModel):
    def __init__(self, cfg: CFG, model_func: Callable) -> None:
        super(ELECTRA, self).__init__()
        self.cfg = cfg
        self.generator = model_func(cfg.generator_num_layers)  # init generator
        self.mlm_head = MLMHead(self.cfg)
        if self.cfg.rtd_masking == 'SpanBoundaryObjective':
            self.mlm_head = SBOHead(
                cfg=self.cfg,
                is_concatenate=self.cfg.is_concatenate,
                max_span_length=self.cfg.max_span_length
            )
        self.discriminator = model_func(cfg.discriminator_num_layers)  # init discriminator
        self.rtd_head = RTDHead(self.cfg)
        self.share_embed_method = self.cfg.share_embed_method  # instance, es, gdes
        self.share_embedding()

    def share_embedding(self) -> None:
        def discriminator_hook(module: nn.Module, *inputs):
            if self.share_embed_method == 'instance':  # Instance Sharing
                self.discriminator.embeddings = self.generator.embeddings

            elif self.share_embed_method == 'ES':  # ES (Embedding Sharing)
                self.discriminator.embeddings.word_embedding.weight = self.generator.embeddings.word_embedding.weight
                self.discriminator.embeddings.abs_pos_emb.weight = self.generator.embeddings.abs_pos_emb.weight
        self.discriminator.register_forward_pre_hook(discriminator_hook)

    def generator_fw(
            self,
            inputs: Tensor,
            labels: Tensor,
            padding_mask: Tensor,
            mask_labels: Tensor = None,
            attention_mask: Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        g_last_hidden_states, _ = self.generator(
            inputs,
            padding_mask,
            attention_mask
        )
        if self.cfg.rtd_masking == 'MaskedLanguageModel':
            g_logit = self.mlm_head(
                g_last_hidden_states
            )
        elif self.cfg.rtd_masking == 'SpanBoundaryObjective':
            g_logit = self.mlm_head(
                g_last_hidden_states,
                mask_labels
            )
        pred = g_logit.clone().detach()
        d_inputs, d_labels = get_discriminator_input(
            inputs,
            labels,
            pred,
        )
        return g_logit, d_inputs, d_labels

    def discriminator_fw(
            self,
            inputs: Tensor,
            padding_mask: Tensor,
            attention_mask: Tensor = None
    ) -> Tensor:
        d_last_hidden_states, _ = self.discriminator(
            inputs,
            padding_mask,
            attention_mask
        )
        d_logit = self.rtd_head(
            d_last_hidden_states
        )
        return d_logit

ELECTRA ๋ชจ๋ธ ๊ฐ์ฒด๋Š” ํฌ๊ฒŒ ์ž„๋ฐฐ๋”ฉ ๋ ˆ์ด์–ด ๊ณต์œ , ์ƒ์„ฑ์ž ํฌ์›Œ๋“œ, ํŒ๋ณ„์ž ํฌ์›Œ๋“œ ํŒŒํŠธ๋กœ ๋‚˜๋‰œ๋‹ค. ๋จผ์ € ์ž„๋ฒ ๋”ฉ ๋ ˆ์ด์–ด ๊ณต์œ ๋Š” ํฌ๊ฒŒ ๋‘ ๊ฐ€์ง€ ๋ฐฉ์‹์œผ๋กœ ๊ตฌํ˜„ ๊ฐ€๋Šฅํ•˜๋‹ค. ํ•˜๋‚˜๋Š” ์ž„๋ฒ ๋”ฉ ๋ ˆ์ด์–ด ์ธ์Šคํ„ด์Šค ์ž์ฒด๋ฅผ ๊ณต์œ ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ, ์ƒ์„ฑ์ž์™€ ํŒ๋ณ„์ž์˜ ์Šค์ผ€์ผ์ด ๋™์ผํ•  ๋•Œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค. ๋‚˜๋จธ์ง€๋Š” ๋‹จ์–ด ์ž„๋ฒ ๋”ฉ, ํฌ์ง€์…˜ ์ž„๋ฒ ๋”ฉ์˜ ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ๋งŒ ๊ณต์œ ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ, ์„œ๋กœ ์Šค์ผ€์ผ์ด ๋‹ฌ๋ผ๋„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค. ๋…ผ๋ฌธ์—์„œ ์ œ์‹œํ•˜๋Š” ๊ฐ€์žฅ ํšจ์œจ์ ์ธ ๋ฐฉ๋ฒ•์€ ํ›„์ž์ด๋ฉฐ, ํŒ๋ณ„์ž์˜ ์ž„๋ฒ ๋”ฉ ํ–‰๋ ฌ์ด ์ƒ์„ฑ์ž์˜ ์ž„๋ฒ ๋”ฉ ํ–‰๋ ฌ์˜ ์ฃผ์†Œ๋ฅผ ๊ฐ€๋ฆฌํ‚ค๋„๋ก ํ•จ์œผ๋กœ์„œ ๊ตฌํ˜„ ๊ฐ€๋Šฅํ•˜๋‹ค.

🌆 RTD Input Making

import torch
from torch import Tensor
from typing import Tuple

def get_discriminator_input(inputs: Tensor, labels: Tensor, pred: Tensor) -> Tuple[Tensor, Tensor]:
    """ Post-processing for the Replaced Token Detection task
    1) get the index of the highest-probability token for each [MASK] position in pred
    2) convert each [MASK] token to its predicted token
    3) make labels for the Discriminator (0 = original, 1 = replaced)

    Args:
        inputs: pure inputs from tokenizing by tokenizer
        labels: labels for masked language modeling
        pred: prediction tensor from Generator

    Returns:
        d_inputs: torch.Tensor, shape [Batch, Sequence], Discriminator inputs
        d_labels: torch.Tensor, shape [Batch * Sequence], Discriminator labels
    """
    # 1) flatten pred to a 2D tensor
    d_inputs = inputs.clone().detach().view(-1)  # detach to prevent back-propagation
    flat_pred, flat_label = pred.view(-1, pred.size(-1)), labels.view(-1)  # (batch * sequence, vocab_size)

    # 2) get the index of the highest-probability token at each [MASK] position
    pred_token_idx, mlm_mask_idx = flat_pred.argmax(dim=-1), torch.where(flat_label != -100)
    pred_tokens = torch.index_select(pred_token_idx, 0, mlm_mask_idx[0])

    # 3) convert each [MASK] token to its predicted token
    d_inputs[mlm_mask_idx[0]] = pred_tokens

    # 4) make labels for the Discriminator: 1 where the token was replaced, 0 where original
    original_tokens = inputs.clone().detach().view(-1)
    original_tokens[mlm_mask_idx[0]] = flat_label[mlm_mask_idx[0]]
    d_labels = torch.ne(original_tokens, d_inputs).long()
    d_inputs = d_inputs.view(pred.size(0), -1)  # convert to [batch, sequence]
    return d_inputs, d_labels

Finally, let's look at the code above, which builds the discriminator's input. The algorithm is as follows.

  • 1) Find the predicted token for each masked token
    • Convert the logits into actual token indices
  • 2) Replace every masked position with its predicted token
  • 3) Compare the original input with the sequence built in step 2 to create the labels
    • 0 if the tokens match
    • 1 if they differ

The new input sequence and labels built this way are then passed as arguments to the discriminator forward method of the ELECTRA model instance.
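The three steps above, mirrored at the tensor level on a toy example; the token ids, the assumed [MASK] id, and the vocabulary size of 200 are all made up for illustration:

```python
import torch

# toy batch: one sequence of 4 tokens; position 2 held [MASK] (its label != -100)
inputs = torch.tensor([[11, 42, 103, 7]])          # 103 = assumed [MASK] id
labels = torch.tensor([[-100, -100, 55, -100]])    # the original token there was 55
pred = torch.zeros(1, 4, 200)                      # generator logits, vocab of 200
pred[0, 2, 77] = 5.0                               # top prediction at the mask: token 77

flat_label = labels.view(-1)
mask_idx = torch.where(flat_label != -100)[0]
d_inputs = inputs.clone().view(-1)
d_inputs[mask_idx] = pred.view(-1, 200).argmax(dim=-1)[mask_idx]  # steps 1-2

original = torch.where(flat_label != -100, flat_label, d_inputs)  # restore originals
d_labels = torch.ne(original, d_inputs).long()                    # step 3: 0 same, 1 differ
print(d_inputs.tolist())  # [11, 42, 77, 7]
print(d_labels.tolist())  # [0, 0, 1, 0]
```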

🌟 Future Work (impressions & directions for improvement from reading and implementing)

That covers the implementation of ELECTRA. The part I found most questionable while reading and implementing the paper was the embedding-sharing method. Although I have not worked this out with mathematical rigor, intuitively the generator's MLM and the discriminator's RTD are pre-training tasks of quite different character. Simply sharing the word and position embeddings therefore risks interference: the two tasks pull the embeddings in different directions, and the model may fail to converge. I think more thought is needed on how to resolve this tug-of-war phenomenon.

๊ทธ๋ž˜์„œ ๋‹ค์Œ ํฌ์ŠคํŒ…์—์„œ๋Š” ์ด๋Ÿฌํ•œ ์ค„๋‹ค๋ฆฌ๊ธฐ ํ˜„์ƒ์„ ํ•ด๊ฒฐํ•˜๊ณ ์žํ•œ ๋…ผ๋ฌธ์ธ <DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing>์„ ๋ฆฌ๋ทฐํ•ด๋ณด๊ณ ์ž ํ•œ๋‹ค.
