
🔭 Overview

DistilBERT is a variant of BERT that Hugging Face researchers released in 2019. It was designed to be lightweight, with on-device AI as the goal. GPT and BERT brought dramatic performance gains to NLP. Their enormous model sizes and compute requirements, however, still made practical, real-world deployment an open problem. Google's original BERT-base-uncased alone has about 110 million parameters.

Applying such a model to a range of business needs calls for at least 8GB of dedicated accelerator memory. Today's personal PCs and servers usually ship with GPUs that have 8GB or more of VRAM, so this requirement is hardly an obstacle there. On-device environments are a different story. Current high-end smartphones such as the Galaxy S24 Ultra and the iPhone 15 Pro have 12GB and 8GB of RAM, respectively. Even that overstates what is available: most on-device platforms use a system-on-chip (SoC) design in which the RAM is shared with the CPU and the operating system, so the accelerator cannot use all of it.

Bringing AI on-device therefore requires drastic model compression, and DistilBERT is the work that served as the starting point. To make language models usable on local devices, the Hugging Face team applied knowledge distillation to an encoder-based language model and substantially reduced its parameter count. According to the paper, DistilBERT is 40% smaller and 60% faster than BERT and retains 97% of its language understanding capabilities.

In short, DistilBERT does not improve on BERT's architecture. It is an effort centered on the pre-training method, specifically on making the model lighter. Any encoder language model can therefore be trained with the DistilBERT recipe, and the original paper used the original BERT architecture. This post likewise skips the BERT architecture and covers only DistilBERT's pre-training method, knowledge distillation.

📲 Knowledge Distillation

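\[\min_{\theta}\sum_{x \in X} \Big( \alpha \mathcal{L}_{\text{KL}}(x, \theta) + \beta \mathcal{L}_{\text{MLM}}(x, \theta) + \gamma \mathcal{L}_{\text{Cos}}(x, \theta) \Big)\]

Here $\theta$ denotes the student's parameters, and $\alpha$, $\beta$, $\gamma$ weight the distillation (KL), MLM and cosine losses.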

DistilBERT adopts a teacher-student architecture. The goal is to transfer the teacher's knowledge to a student model that has comparatively few parameters. The teacher must therefore be a model that has already finished pre-training and whose weights have converged (BERT-base in the paper). The distillation run itself follows RoBERTa's training practices: the NSP objective is dropped, masking is dynamic, and batches are very large (up to 4K examples, using gradient accumulation).

The student, meanwhile, is shrunk to about 60% of the teacher's parameter count. The reduction is applied only to depth (the number of layers). According to the authors, for a fixed parameter budget, changing the width (hidden size) affects computational efficiency less than changing the number of layers. In concrete terms, the paper halves the layer count (12 → 6) and initializes the student by taking one layer out of every two from the teacher. The embedding layer is kept at full size, which is why halving the layers leaves roughly 60% of the parameters rather than 50%.
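The sketch below is a rough check of the "about 60%" figure. It counts the parameters of a BERT-base-sized teacher and a 6-layer student. The sizes are the standard BERT-base-uncased values and are assumptions of this sketch: vocabulary 30,522, hidden size 768, feed-forward size 3,072 and 512 positions. Following the paper, the student also drops the token-type embeddings and the pooler. Both models keep the large embedding matrix, so cutting the layers in half removes much less than half of the parameters.

```python
def encoder_layer_params(d_model: int, d_ff: int) -> int:
    """Parameters of one Transformer encoder layer (weights + biases)."""
    attention = 4 * (d_model * d_model + d_model)               # Q, K, V and output projections
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    layer_norms = 2 * (2 * d_model)                              # two LayerNorms (gamma, beta)
    return attention + feed_forward + layer_norms


def embedding_params(vocab: int, max_pos: int, type_vocab: int, d_model: int) -> int:
    """Token + position (+ token-type) embeddings plus the embedding LayerNorm."""
    return (vocab + max_pos + type_vocab) * d_model + 2 * d_model


# Assumed BERT-base-uncased sizes
D_MODEL, D_FF, VOCAB, MAX_POS = 768, 3072, 30522, 512

teacher = (embedding_params(VOCAB, MAX_POS, 2, D_MODEL)
           + 12 * encoder_layer_params(D_MODEL, D_FF)
           + (D_MODEL * D_MODEL + D_MODEL))                      # pooler
student = (embedding_params(VOCAB, MAX_POS, 0, D_MODEL)          # no token-type embeddings
           + 6 * encoder_layer_params(D_MODEL, D_FF))            # half the layers, no pooler

print(f"teacher: {teacher:,}  student: {student:,}  ratio: {student / teacher:.3f}")
# teacher: 109,482,240  student: 66,362,880  ratio: 0.606
```

These totals agree with the 110M and 66M parameter counts the paper reports for BERT-base and DistilBERT.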

The goal is to transfer as much of the teacher's knowledge as possible, so the student is trained on the same data the teacher converged on. In the paper, this is the English Wikipedia and Toronto Book Corpus concatenation used for the original BERT. The teacher, however, was trained by maximum likelihood, so its softmax output is very likely concentrated on a single token, and this can hurt the student's ability to generalize. A temperature variable $T$ is therefore introduced to flatten the softmax distribution over the logits. With a flatter distribution, the student also learns from the probabilities assigned to tokens other than the argmax. This gives it a richer contextual signal and helps it generalize. Learning from these extra probabilities is described as exploiting dark knowledge. The softmax with temperature $T$ is:

\[\text{softmax}(x_i; T) = \frac{e^{x_i / T}}{\sum_{j} e^{x_j / T}}\]
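The sketch below shows the flattening effect in plain Python, so it runs without PyTorch. The logits are made-up values for a four-token vocabulary. As $T$ increases, the top probability falls and probability mass moves to the other tokens, while the ranking of the tokens stays the same.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Softmax of logits / T. T > 1 flattens the distribution, T < 1 sharpens it."""
    scaled = [x / T for x in logits]
    m = max(scaled)                                   # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [6.0, 2.0, 1.0, 0.5]                         # hypothetical teacher logits
for T in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs))
```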

The distribution is flattened only when $T > 1$. At $T = 1$ the formula is the ordinary softmax, and $T < 1$ sharpens the distribution. The authors pre-trained with $T = 2$. This value is not in the paper, but it appears in the GitHub code. At inference, $T$ is set back to 1. Returning to the equation at the start of this section, DistilBERT's objective is a weighted sum of three losses. The following sections examine each loss in detail.

🪢 Distillation Loss: KL-Divergence Loss

\[\text{KL-Divergence}(P || Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}\]

The distillation loss is a KL-divergence loss, which measures how one probability distribution differs from another. $\text{KL}(P \| Q)$ is zero when $P$ and $Q$ are identical. It grows as $Q$ assigns low probability to outcomes to which $P$ assigns high probability. It is also asymmetric, so the order of the arguments matters. In a KL-divergence loss, $P$ usually denotes the target (ideal) distribution and $Q$ the distribution the model predicts. For DistilBERT, $P$ is the teacher's softmax distribution and $Q$ is the student's. Both are temperature-flattened to capture dark knowledge. The paper calls the teacher's flattened output the soft targets and the student's flattened output the soft predictions. The paper writes this term as a cross-entropy over the soft targets. That cross-entropy differs from the KL divergence only by the entropy of the teacher distribution, which does not depend on the student, so both give the same gradients. The official code uses `nn.KLDivLoss`.
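The sketch below computes the distillation term in plain Python. It also checks that the KL divergence equals the cross-entropy minus the teacher's entropy. In PyTorch, `nn.KLDivLoss` expects the student's log-probabilities as its input and the teacher's probabilities as its target. The logits are hypothetical, and $T = 2$ follows the value from the GitHub code.

```python
import math


def softmax_T(logits, T):
    """Temperature-scaled softmax."""
    m = max(logits)
    exps = [math.exp((x - m) / T) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(P || Q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(p, q):
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


T = 2.0
teacher_logits = [6.0, 2.0, 1.0, 0.5]    # hypothetical values
student_logits = [3.0, 2.5, 1.0, 0.0]    # hypothetical values
p = softmax_T(teacher_logits, T)         # soft targets
q = softmax_T(student_logits, T)         # soft predictions

print(f"KL(P||Q)       = {kl_divergence(p, q):.4f}")
print(f"CE(P,Q) - H(P) = {cross_entropy(p, q) - entropy(p):.4f}")  # same value: H(P) is constant w.r.t. the student
```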

🧑‍🎓 Student Loss: MLM Loss

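\[\mathcal{L}_{\text{MLM}} = - \sum_{i=1}^{N} \sum_{j=1}^{L} \mathbb{1}[m_{ij}] \log \text{softmax}(x_{ij})_{y_{ij}}\]

The symbols are as follows:

- $N$ is the number of sequences and $L$ the sequence length.
- $\mathbb{1}[m_{ij}]$ is 1 only if token $j$ of sequence $i$ was masked.
- $x_{ij}$ is the student's logit vector at that position.
- $y_{ij}$ is the original token id.

The student loss is simply the standard MLM loss. To compute the loss exactly, the student's softmax is not flattened by temperature; the paper calls this the hard prediction. The labels do not come from the teacher either: they are the ordinary masked-token labels used for MLM. In practice, `nn.CrossEntropyLoss(reduction='mean')` divides this sum by the number of masked tokens.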
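The student loss is ordinary cross-entropy restricted to the masked positions. The plain-Python sketch below mirrors what `nn.CrossEntropyLoss(reduction='mean')` does when unmasked positions are labeled -100, PyTorch's default `ignore_index`. This sketch assumes the author's collator uses -100 for unmasked positions. The batch values are hypothetical.

```python
import math

IGNORE_INDEX = -100  # PyTorch CrossEntropyLoss default ignore_index (assumed label for unmasked tokens)


def log_softmax(logits):
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def mlm_loss(token_logits, labels):
    """Mean negative log-likelihood of the original token over masked positions only.
    Assumes at least one position is masked."""
    nll = [-log_softmax(logits)[label]
           for logits, label in zip(token_logits, labels)
           if label != IGNORE_INDEX]
    return sum(nll) / len(nll)


# Hypothetical batch: 3 positions, vocabulary of 3 tokens; only position 1 was masked
token_logits = [[2.0, 0.1, -1.0], [0.0, 0.0, 0.0], [5.0, 1.0, 0.0]]
labels = [IGNORE_INDEX, 2, IGNORE_INDEX]
print(mlm_loss(token_logits, labels))  # uniform logits over 3 tokens -> log(3) ≈ 1.0986
```

The logits at unmasked positions have no effect on the result.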

🌆 Cosine Embedding Loss: Contrastive Loss by Cosine Similarity

\[\mathcal{L}_{\text{Cos}}(x_1, x_2, y) = \begin{cases} 1 - \cos(x_1, x_2), & \text{if } y = 1 \\ \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1 \end{cases}\]

This is a contrastive-style loss on the hidden states that the last encoder layer of the teacher and of the student outputs. Cosine similarity serves as the distance metric. The paper calls it a cosine embedding loss, and it tends to align the directions of the student's and the teacher's hidden-state vectors. Training minimizes the expression above. Every label is set to 1, so only the $y = 1$ branch applies and the loss reduces to $1 - \cos(x_1, x_2)$ for each student/teacher pair. The reason is simple: our goal is to make the student's hidden states as close as possible to the teacher's. The official code first drops padding positions and then flattens the remaining token vectors. The label is therefore a 1-D tensor of ones with one entry per non-padding token, which means at most BS × Seq_len entries.
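The plain-Python sketch below computes the cosine loss with every target set to 1. Padding positions are dropped before the comparison. Without this step, the all-zero teacher vector at the padding position would make the cosine undefined. The 2-dimensional hidden states are hypothetical. The first pair shows that only direction matters: the two vectors point the same way but have different lengths, and the loss is 0.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors (assumes neither is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def cosine_embedding_loss(student_states, teacher_states, is_padding):
    """Mean of 1 - cos(s, t) over non-padding tokens: CosineEmbeddingLoss with every target = 1."""
    losses = [1.0 - cosine(s, t)
              for s, t, pad in zip(student_states, teacher_states, is_padding)
              if not pad]  # padding positions are skipped before comparing
    return sum(losses) / len(losses)


# Hypothetical 2-dimensional hidden states for three tokens; the last one is padding
student_states = [[1.0, 0.0], [0.0, 1.0], [0.3, 0.7]]
teacher_states = [[2.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
is_padding = [False, False, True]
print(cosine_embedding_loss(student_states, teacher_states, is_padding))
```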

👩‍💻 Implementation in PyTorch

I implemented DistilBERT in PyTorch, drawing on both the paper and the officially released code. As expected, the paper's ideas were not hard to follow. The paper does not include a hyperparameter table, though, so I had to consult the released code.

For the code of the full model architecture, please refer to the link here.

👩‍💻 Knowledge Distillation Pipeline

```python
from typing import Any, Dict, Tuple, Union

import torch
import torch.nn as nn
from tqdm.auto import tqdm


def train_val_fn(self, loader_train, model: nn.Module, criterion: Dict[str, nn.Module], optimizer, scheduler) -> Tuple[Any, Union[float, Any]]:
    """ Function for train loop with validation for each batch*N Steps
    DistilBERT has three losses:

        1) distillation loss, calculated by soft targets & soft predictions
            (nn.KLDivLoss(reduction='batchmean'))

        2) student loss, calculated by hard targets & hard predictions
            (nn.CrossEntropyLoss(reduction='mean')), same as pure MLM Loss

        3) cosine similarity loss, calculated by student & teacher last hidden state similarity
            (nn.CosineEmbeddingLoss(reduction='mean')), similar to contrastive loss

    Those 3 losses are summed jointly and then backpropagated to the student model
    """
    scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.amp_scaler)
    model.train()
    for step, batch in enumerate(tqdm(loader_train)):
        optimizer.zero_grad(set_to_none=True)
        inputs = batch['input_ids'].to(self.cfg.device, non_blocking=True)
        labels = batch['labels'].to(self.cfg.device, non_blocking=True)
        padding_mask = batch['padding_mask'].to(self.cfg.device, non_blocking=True)

        # broadcast the token-level padding mask over the hidden dimension
        mask = padding_mask.unsqueeze(-1).expand(-1, -1, self.cfg.dim_model)
        with torch.no_grad():  # the teacher is never updated
            t_hidden_state, soft_target = model.teacher_fw(
                inputs=inputs,
                padding_mask=padding_mask,
                mask=mask
            )  # teacher's temperature-scaled softmax => soft targets

        with torch.cuda.amp.autocast(enabled=self.cfg.amp_scaler):
            s_hidden_state, s_logit, soft_pred, c_labels = model.student_fw(
                inputs=inputs,
                padding_mask=padding_mask,
                mask=mask
            )
            # nn.KLDivLoss expects log-probabilities as input and probabilities as target;
            # scaled by T^2 as in the official distiller.py to keep gradient magnitudes comparable
            d_loss = criterion["KLDivLoss"](soft_pred.log(), soft_target) * self.cfg.temperature ** 2
            s_loss = criterion["CrossEntropyLoss"](s_logit.view(-1, self.cfg.vocab_size), labels.view(-1))  # hard targets = MLM labels
            c_loss = criterion["CosineEmbeddingLoss"](s_hidden_state, t_hidden_state, c_labels)  # every target is 1
            loss = d_loss*self.cfg.alpha_distillation + s_loss*self.cfg.alpha_student + c_loss*self.cfg.alpha_cosine  # weighted sum of the three losses

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
```
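The official distiller.py multiplies the KL term by $T^2$. A quick numerical check shows why. Let $p = \text{softmax}(t/T)$ and $q = \text{softmax}(s/T)$. The gradient of $\text{KL}(p \| q)$ with respect to the student logits $s$ is $(q - p)/T$. The difference $q - p$ itself also shrinks as $T$ grows. Without rescaling, the soft loss would therefore contribute smaller and smaller gradients relative to the MLM loss. The sketch below is a plain-Python finite-difference check with made-up logits, not part of the original code.

```python
import math


def softmax_T(logits, T):
    m = max(logits)
    exps = [math.exp((x - m) / T) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_loss(teacher_logits, student_logits, T):
    """KL(softmax(t/T) || softmax(s/T)) without any T^2 rescaling."""
    p, q = softmax_T(teacher_logits, T), softmax_T(student_logits, T)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def numeric_grad(teacher_logits, student_logits, T, eps=1e-6):
    """Central finite differences of soft_loss with respect to each student logit."""
    grads = []
    for i in range(len(student_logits)):
        plus, minus = list(student_logits), list(student_logits)
        plus[i] += eps
        minus[i] -= eps
        grads.append((soft_loss(teacher_logits, plus, T) - soft_loss(teacher_logits, minus, T)) / (2 * eps))
    return grads


teacher_logits = [6.0, 2.0, 1.0, 0.5]   # hypothetical
student_logits = [3.0, 2.5, 1.0, 0.0]   # hypothetical

results = []
for T in (1.0, 2.0, 4.0):
    p, q = softmax_T(teacher_logits, T), softmax_T(student_logits, T)
    analytic = [(qi - pi) / T for pi, qi in zip(p, q)]     # closed-form gradient (q - p) / T
    numeric = numeric_grad(teacher_logits, student_logits, T)
    max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
    grad_norm = math.sqrt(sum(g * g for g in analytic))
    results.append((T, grad_norm, max_err))
    print(f"T={T}: |grad|={grad_norm:.4f}  |grad|*T^2={grad_norm * T * T:.4f}  max_err={max_err:.1e}")
```

Without rescaling, the gradient norm drops quickly as $T$ increases. The $T^2$ factor brings it back to a scale comparable to the $T = 1$ case.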

👩‍💻 Knowledge Distillation Model

```python
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# CFG, AbstractTask, DistilBERT and freeze() come from the author's codebase


class DistillationKnowledge(nn.Module, AbstractTask):
    """ Custom Task Module for Knowledge Distillation by DistilBERT Style Architecture
    DistilBERT Style Architecture is a Teacher-Student Framework for Knowledge Distillation,

    And it has 3 objective functions:
        1) distillation loss, calculated by soft targets & soft predictions
            (nn.KLDivLoss(reduction='batchmean'))
        2) student loss, calculated by hard targets & hard predictions
            (nn.CrossEntropyLoss(reduction='mean')), same as pure MLM Loss
        3) cosine similarity loss, calculated by student & teacher last hidden state similarity
            (nn.CosineEmbeddingLoss(reduction='mean')), similar to contrastive loss

    References:
        https://arxiv.org/pdf/1910.01108.pdf
        https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/distiller.py
    """
    def __init__(self, cfg: CFG) -> None:
        super(DistillationKnowledge, self).__init__()
        self.cfg = cfg  # the configuration instance passed in, not the CFG class itself
        self.model = DistilBERT(
            self.cfg,
            self.select_model
        )
        self._init_weights(self.model)
        if self.cfg.teacher_load_pretrained:  # for teacher model
            self.model.teacher.load_state_dict(
                torch.load(cfg.checkpoint_dir + cfg.teacher_state_dict),
                strict=False
            )
        if self.cfg.student_load_pretrained:  # for student model
            self.model.student.load_state_dict(
                torch.load(cfg.checkpoint_dir + cfg.student_state_dict),
                strict=True
            )
        if self.cfg.freeze:
            freeze(self.model.teacher)
            freeze(self.model.mlm_head)

        if self.cfg.gradient_checkpoint:
            self.model.gradient_checkpointing_enable()

    def teacher_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        mask: Tensor,
        attention_mask: Tensor = None,
        is_valid: bool = False
    ) -> Tuple[Tensor, Tensor]:
        """ teacher forward pass to make soft target, last_hidden_state for distillation loss """
        temperature = 1.0 if is_valid else self.cfg.temperature
        last_hidden_state, t_logit = self.model.teacher_fw(
            inputs,
            padding_mask,
            attention_mask
        )
        # keep only non-padding positions (mask is assumed True at padding tokens)
        last_hidden_state = torch.masked_select(last_hidden_state, ~mask)
        last_hidden_state = last_hidden_state.view(-1, self.cfg.dim_model)  # [num_tokens, dim_model]
        soft_target = F.softmax(
            t_logit.view(-1, self.cfg.vocab_size) / temperature,  # logits divided by T
            dim=-1
        )  # [bs * seq, vocab_size]
        return last_hidden_state, soft_target

    def student_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        mask: Tensor,
        attention_mask: Tensor = None,
        is_valid: bool = False
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """ student forward pass to make soft prediction, hard prediction for student loss """
        temperature = 1.0 if is_valid else self.cfg.temperature
        last_hidden_state, s_logit = self.model.student_fw(  # the student's own encoder and MLM head
            inputs,
            padding_mask,
            attention_mask
        )
        # keep only non-padding positions (mask is assumed True at padding tokens)
        last_hidden_state = torch.masked_select(last_hidden_state, ~mask)
        last_hidden_state = last_hidden_state.view(-1, self.cfg.dim_model)  # [num_tokens, dim_model]
        c_labels = last_hidden_state.new_ones(last_hidden_state.size(0))  # cosine-loss targets: all 1
        soft_pred = F.softmax(
            s_logit.view(-1, self.cfg.vocab_size) / temperature,  # same temperature as the teacher
            dim=-1
        )  # [bs * seq, vocab_size]
        return last_hidden_state, s_logit, soft_pred, c_labels
```

👩‍💻 DistilBERT Model

```python
from typing import Callable, Tuple

import torch.nn as nn
from torch import Tensor

# CFG, AbstractModel and MLMHead come from the author's codebase


class DistilBERT(nn.Module, AbstractModel):
    """ Main class for DistilBERT Style Model, a Teacher-Student Framework
    for Knowledge Distillation aimed at making large-scale language models lighter. This model has 3 objective functions:

        1) distillation loss, calculated by soft targets & soft predictions
            (nn.KLDivLoss(reduction='batchmean'))

        2) student loss, calculated by hard targets & hard predictions
            (nn.CrossEntropyLoss(reduction='mean')), same as pure MLM Loss

        3) cosine similarity loss, calculated by student & teacher last hidden state similarity
            (nn.CosineEmbeddingLoss(reduction='mean')), similar to contrastive loss

    soft targets & soft predictions mean that the logits are passed through a softmax function with temperature T
    temperature T flattens the softmax distribution to extract "Dark Knowledge" from the teacher model

    hard targets & hard predictions mean that the logits are passed through a softmax function without temperature T
    hard targets are simply the labels returned by the MLM collator, used for calculating cross entropy loss

    cosine similarity loss is calculated from the cosine similarity between student & teacher hidden states
    in the official repo, padding tokens are masked out before computing it, and the target for every token is 1
    cosine similarity values range over [-1, 1]

    you can select any other backbone architecture for the Teacher & Student models
    but in the original paper, BERT is used for both the Teacher & Student models
    and you must use a pretrained model as the Teacher, because its knowledge,
    including its pretrained MLM head, is what gets distilled

    Do not backpropagate gradients into the teacher model!!
    (keep the teacher frozen or run it under the no_grad() context manager)

    Args:
        cfg: configuration.CFG
        model_func: builds a backbone model instance at runtime from config.json

    References:
        https://arxiv.org/pdf/1910.01108.pdf
        https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/distiller.py
    """
    def __init__(self, cfg: CFG, model_func: Callable) -> None:
        super(DistilBERT, self).__init__()
        self.cfg = cfg
        self.teacher = model_func(self.cfg.teacher_num_layers)  # must load a pretrained model
        self.mlm_head = MLMHead(self.cfg)  # must load the pretrained model's MLM head

        self.student = model_func(self.cfg.student_num_layers)
        self.s_mlm_head = MLMHead(self.cfg)

    def teacher_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        attention_mask: Tensor = None,
    ) -> Tuple[Tensor, Tensor]:
        """ forward pass for teacher model
        """
        last_hidden_state, _ = self.teacher(
            inputs,
            padding_mask,
            attention_mask
        )
        t_logit = self.mlm_head(last_hidden_state)  # raw logits, later turned into soft targets
        return last_hidden_state, t_logit

    def student_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        attention_mask: Tensor = None
    ) -> Tuple[Tensor, Tensor]:
        """ forward pass for student model
        """
        last_hidden_state, _ = self.student(
            inputs,
            padding_mask,
            attention_mask
        )
        s_logit = self.s_mlm_head(last_hidden_state)  # raw logits, used for both hard and soft predictions
        return last_hidden_state, s_logit
```
