🔭 Overview
DistilBERT is a BERT variant released by Hugging Face researchers in 2019, a model focused on compression with on-device AI as its goal. Although NLP performance improved dramatically after GPT and BERT appeared, their enormous model sizes and compute requirements meant that real-world usability remained an open problem. Even the original BERT-base-uncased released by Google has roughly 110 million parameters.
Applying such a model to a range of business requirements calls for at least 8GB of accelerator-dedicated RAM. For today's personal PCs or servers this is not much of a problem, since GPUs with 8GB or more of VRAM are common, but on-device environments are a different story. The latest high-end smartphones, the Galaxy S24 Ultra and the iPhone 15 Pro, ship with 12GB and 8GB of RAM respectively. Even then, most on-device environments use an SoC design, so a dedicated accelerator cannot have all of that RAM to itself.
Bringing AI on-device therefore requires drastic model compression, and the work that became the starting point for it is DistilBERT. To make language models usable in local-device environments, the Hugging Face researchers used knowledge distillation and succeeded in sharply reducing the parameter count of an encoder-based language model.
In short, DistilBERT is not an architectural improvement over BERT but an attempt focused on the pretraining method, specifically on compression. Any encoder language model can therefore be used with the DistilBERT scheme; the original paper used the original BERT architecture. This post skips the BERT architecture itself and covers only Knowledge Distillation, the pretraining methodology of DistilBERT.
Knowledge Distillation
\[\min_{\theta}\sum_{x \in X} \alpha \mathcal{L}_{\text{KL}}(x, \theta) + \beta \mathcal{L}_{\text{MLM}}(x, \theta) + \gamma \mathcal{L}_{\text{Cos}}(x, \theta)\]
DistilBERT borrows the Teacher-Student architecture, with the goal of transferring the Teacher's knowledge to a Student model with a relatively small parameter count. The Teacher must therefore be a model that has already finished pretraining and holds converged weights. In addition, only the Teacher's architecture follows the original BERT; its pretraining should follow the RoBERTa recipe (NSP removed, dynamic masking applied).
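Meanwhile, the Student model is shrunk to about 60% of the Teacher's parameter count. The reduction is applied only to the model's depth (number of layers): according to the authors, shrinking the width (hidden size) has a smaller effect on computational efficiency than changing the number of layers. Concretely, the paper halves the number of encoder layers (12 to 6 for BERT-base); the parameter count only drops to about 60% because the embedding table is kept at full size.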
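As a rough sanity check of the 60% figure, the sketch below counts parameters for a BERT-base-style encoder and for a 6-layer student without token-type embeddings or pooler. The hyperparameters are the published BERT-base values (vocabulary 30522, hidden size 768, feed-forward size 3072, 512 positions). The helper name `encoder_params` is made up for illustration, and the MLM head is ignored.

```python
def encoder_params(num_layers, hidden=768, ffn=3072, vocab=30522,
                   max_pos=512, token_types=2, pooler=True):
    """Count weights and biases of a BERT-style encoder (MLM head excluded)."""
    layer_norm = 2 * hidden  # scale + bias
    embeddings = (vocab + max_pos + token_types) * hidden + layer_norm
    # Q, K, V and output projections, followed by a LayerNorm
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    # two linear layers of the feed-forward block, followed by a LayerNorm
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden) + layer_norm
    pooler_params = hidden * hidden + hidden if pooler else 0
    return embeddings + num_layers * (attention + feed_forward) + pooler_params


teacher = encoder_params(num_layers=12)
# assumption: student drops token-type embeddings and the pooler, keeps 6 layers
student = encoder_params(num_layers=6, token_types=0, pooler=False)
print(f"teacher: {teacher:,}  student: {student:,}  ratio: {student / teacher:.3f}")
```

Under these assumptions the counts come out at roughly 109M and 66M, so halving the depth leaves about 60% of the parameters because the embedding table is untouched.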
To transfer as much of the Teacher's knowledge as possible, the training data should be the same set the Teacher was trained to convergence on. Because the Teacher has already been trained by maximum likelihood (MLE), its output distribution is very likely concentrated on a single token, which can hurt the Student's ability to generalize. A temperature variable $T$ is therefore introduced to flatten the softmax distribution over the logits. This lets the Student also learn from tokens other than the argmax(), which helps it pick up richer context and generalize better. This is described as exploiting dark knowledge. The softmax with temperature $T$ is:
\[p_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}\]
Here $z_i$ is the logit for token $i$. $T$ must be set above 1 for any flattening to happen ($T = 1$ gives the ordinary softmax). The authors pretrained with $T = 2$ (not stated in the paper, but present in the GitHub code). Now look again at the equation at the top of this section: the DistilBERT objective is a weighted sum of three losses. The sections that follow go through each loss in detail.
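To make the flattening effect concrete, here is a minimal pure-Python sketch of a temperature softmax applied to toy logits. The function name and the logit values are illustrative only and are not taken from the paper or the official code.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits / T; a larger T gives a flatter distribution."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [4.0, 1.0, 0.5]  # toy teacher logits with one dominant token
sharp = softmax_with_temperature(logits, temperature=1.0)
soft = softmax_with_temperature(logits, temperature=2.0)
print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

With $T = 2$ the dominant token keeps the largest probability, but the remaining tokens receive noticeably more mass, which is the signal the Student is meant to learn from.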
Distillation Loss: KL-Divergence Loss
\[\text{KL-Divergence}(P || Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}\]
The KL-Divergence Loss used as the distillation loss is one of the measures of the difference between two probability distributions. It expresses the gap between distributions P and Q: the more the individual probabilities differ, the larger the sum and therefore the loss. Conversely, if the individual probabilities differ little, the two distributions are similar and the loss is small. In a KL-Divergence Loss, $P$ usually denotes the ideal distribution and $Q$ the distribution predicted by the model. For DistilBERT, the Teacher's softmax distribution takes the place of $P$ and the Student's softmax distribution takes the place of $Q$. Both distributions are the temperature-flattened softmax outputs, so that the dark knowledge is captured. In knowledge-distillation terminology, the flattened Teacher prediction is called the soft label and the flattened Student prediction is called the soft prediction.
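The following sketch evaluates the KL formula above on temperature-softened teacher and student distributions. It is plain Python on toy logits with assumed helper names, not the training code from this post; in PyTorch, `nn.KLDivLoss` expects the student side as log-probabilities.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature softmax, computed stably by subtracting the max."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(P || Q) = sum_i P(i) * log(P(i) / Q(i)); P is the teacher, Q the student."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


T = 2.0                                        # temperature, as used in the post
soft_target = softmax([4.0, 1.0, 0.5], T)      # toy teacher logits
soft_pred = softmax([2.0, 1.5, 0.1], T)        # toy student logits
distillation_loss = kl_divergence(soft_target, soft_pred)
print(round(distillation_loss, 4))
```

The loss is zero only when the Student reproduces the Teacher's softened distribution exactly, and it grows as the two diverge.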
Student Loss: MLM Loss
\[\mathcal{L}_{\text{MLM}} = - \sum_{i=1}^{N} \sum_{j=1}^{L} \mathbb{1}_{m_{ij}} \log \text{softmax}(x_{ij})\]
The student loss is simply the ordinary MLM loss. To compute the loss value accurately, no flattening is applied to the Student's softmax distribution; this unflattened output is called the hard prediction. The labels likewise do not come from the Teacher: they are the masking labels normally used for MLM.
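As an illustration of the hard-prediction loss, the sketch below averages cross-entropy over masked positions only. It assumes the usual convention that unmasked positions carry the label -100 (the default `ignore_index` of `nn.CrossEntropyLoss`); the function name and toy values are made up.

```python
import math

IGNORE_INDEX = -100  # positions with this label are not part of the MLM loss


def mlm_loss(logits, labels):
    """Mean cross-entropy over masked positions (labels != IGNORE_INDEX)."""
    losses = []
    for token_logits, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # token was not masked, so it contributes nothing
        peak = max(token_logits)
        log_norm = peak + math.log(sum(math.exp(z - peak) for z in token_logits))
        losses.append(log_norm - token_logits[label])  # -log softmax(logits)[label]
    return sum(losses) / len(losses) if losses else 0.0


# three token positions over a toy vocabulary of size 3; only the last two are masked
toy_logits = [[2.0, 0.1, 0.1], [0.3, 0.2, 3.0], [1.0, 1.0, 1.0]]
toy_labels = [IGNORE_INDEX, 2, 0]
print(round(mlm_loss(toy_logits, toy_labels), 4))
```

No temperature appears here: the logits go through a plain softmax, matching the hard prediction described above.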
Cosine Embedding Loss: Contrastive Loss by cosine similarity
\[\mathcal{L}_{\text{COS}}(x,y) = \begin{cases} 1 - \cos(x_1, x_2), & \text{if } y = 1 \\ \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1 \end{cases}\]
This is a contrastive loss on the hidden states output by the last encoder layer of the Teacher and of the Student. The distance metric is cosine similarity, which is presumably why the paper defines it as a cosine embedding loss. The goal is to minimize the expression above. The labels have shape [BS, Seq_len], with every element set to 1. The reason is simple: our goal is to make the Student's hidden states as similar as possible to the Teacher's.
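A small plain-Python sketch of the piecewise formula above, for a single pair of vectors. The helper names are illustrative; in the actual training code this role is played by `nn.CosineEmbeddingLoss`.

```python
import math


def cosine_similarity(x1, x2, eps=1e-8):
    """cos(x1, x2), with a small floor on the denominator to avoid division by zero."""
    dot = sum(a * b for a, b in zip(x1, x2))
    norm1 = math.sqrt(sum(a * a for a in x1))
    norm2 = math.sqrt(sum(b * b for b in x2))
    return dot / max(norm1 * norm2, eps)


def cosine_embedding_loss(x1, x2, y, margin=0.0):
    """y = 1 pulls the vectors together, y = -1 pushes them apart up to the margin."""
    cos = cosine_similarity(x1, x2)
    if y == 1:
        return 1.0 - cos
    return max(0.0, cos - margin)


student_vec = [0.2, 0.9, -0.4]  # toy hidden state of one token (student)
teacher_vec = [0.1, 1.0, -0.5]  # toy hidden state of the same token (teacher)
print(round(cosine_embedding_loss(student_vec, teacher_vec, y=1), 4))
```

Since distillation only ever wants the two hidden states aligned, the label is always 1 and only the first branch is exercised.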
👩‍💻 Implementation in PyTorch
Combining the paper with the officially released code, I implemented DistilBERT in PyTorch. The ideas in the paper were not hard to understand, but the paper does not include a hyperparameter table, so I had no choice but to consult the released code.
For the code of the full model architecture, please refer to the linked repository.
👩‍💻 Knowledge Distillation Pipeline
from typing import Any, Dict, Tuple, Union

import torch
import torch.nn as nn
from tqdm.auto import tqdm


def train_val_fn(self, loader_train, model: nn.Module, criterion: Dict[str, nn.Module], optimizer, scheduler) -> Tuple[Any, Union[float, Any]]:
    """ Function for train loop with validation for each batch*N Steps
    DistilBERT has three losses:
        1) distillation loss, calculated by soft targets & soft predictions
            (nn.KLDivLoss(reduction='batchmean'))
        2) student loss, calculated by hard targets & hard predictions
            (nn.CrossEntropyLoss(reduction='mean')), same as pure MLM Loss
        3) cosine similarity loss, calculated by student & teacher hidden state similarity
            (nn.CosineEmbeddingLoss(reduction='mean')), similar to contrastive loss
    Those 3 losses are combined as a weighted sum and backpropagated to the student model only
    """
    scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.amp_scaler)
    model.train()
    for step, batch in enumerate(tqdm(loader_train)):
        optimizer.zero_grad(set_to_none=True)
        inputs = batch['input_ids'].to(self.cfg.device, non_blocking=True)
        labels = batch['labels'].to(self.cfg.device, non_blocking=True)
        padding_mask = batch['padding_mask'].to(self.cfg.device, non_blocking=True)
        mask = padding_mask.unsqueeze(-1).expand(-1, -1, self.cfg.dim_model)  # for hidden states dim

        # teacher is only used for inference, so no gradient is tracked
        with torch.no_grad():
            t_hidden_state, soft_target = model.teacher_fw(
                inputs=inputs,
                padding_mask=padding_mask,
                mask=mask
            )  # teacher's temperature-softened prediction => soft target

        with torch.cuda.amp.autocast(enabled=self.cfg.amp_scaler):
            s_hidden_state, s_logit, soft_pred, c_labels = model.student_fw(
                inputs=inputs,
                padding_mask=padding_mask,
                mask=mask
            )
            d_loss = criterion["KLDivLoss"](soft_pred.log(), soft_target)  # nn.KLDivLoss expects log-probabilities as input
            s_loss = criterion["CrossEntropyLoss"](s_logit.view(-1, self.cfg.vocab_size), labels.view(-1))  # nn.CrossEntropyLoss
            c_loss = criterion["CosineEmbeddingLoss"](s_hidden_state, t_hidden_state, c_labels)  # nn.CosineEmbeddingLoss
            loss = d_loss*self.cfg.alpha_distillation + s_loss*self.cfg.alpha_student + c_loss*self.cfg.alpha_cosine  # linear combination loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
👩‍💻 Knowledge Distillation Model
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# CFG, AbstractTask, DistilBERT and freeze come from the author's own project modules


class DistillationKnowledge(nn.Module, AbstractTask):
    """ Custom Task Module for Knowledge Distillation by DistilBERT Style Architecture
    DistilBERT Style Architecture is a Teacher-Student Framework for Knowledge Distillation,
    with 3 objective functions:
        1) distillation loss, calculated by soft targets & soft predictions
            (nn.KLDivLoss(reduction='batchmean'))
        2) student loss, calculated by hard targets & hard predictions
            (nn.CrossEntropyLoss(reduction='mean')), same as pure MLM Loss
        3) cosine similarity loss, calculated by student & teacher hidden state similarity
            (nn.CosineEmbeddingLoss(reduction='mean')), similar to contrastive loss
    References:
        https://arxiv.org/pdf/1910.01108.pdf
        https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/distiller.py
    """
    def __init__(self, cfg: CFG) -> None:
        super(DistillationKnowledge, self).__init__()
        self.cfg = cfg  # fixed: the original assigned the CFG class instead of the instance
        self.model = DistilBERT(
            self.cfg,
            self.select_model
        )
        self._init_weights(self.model)
        if self.cfg.teacher_load_pretrained:  # for teacher model
            self.model.teacher.load_state_dict(
                torch.load(cfg.checkpoint_dir + cfg.teacher_state_dict),
                strict=False
            )
        if self.cfg.student_load_pretrained:  # for student model
            self.model.student.load_state_dict(
                torch.load(cfg.checkpoint_dir + cfg.student_state_dict),
                strict=True
            )
        if self.cfg.freeze:  # teacher and its MLM head must not receive gradients
            freeze(self.model.teacher)
            freeze(self.model.mlm_head)
        if self.cfg.gradient_checkpoint:
            self.model.gradient_checkpointing_enable()
    def teacher_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        mask: Tensor,
        attention_mask: Tensor = None,
        is_valid: bool = False
    ) -> Tuple[Tensor, Tensor]:
        """ teacher forward pass to make soft target, last_hidden_state for distillation loss """
        # 1) make soft target (no flattening at validation time)
        temperature = 1.0 if is_valid else self.cfg.temperature
        last_hidden_state, t_logit = self.model.teacher_fw(
            inputs,
            padding_mask,
            attention_mask
        )
        last_hidden_state = torch.masked_select(last_hidden_state, ~mask)  # keep only non-padding positions
        last_hidden_state = last_hidden_state.view(-1, self.cfg.dim_model)  # flatten last_hidden_state
        # fixed: logits are divided by T, matching softmax(z / T); the original divided by T**2
        soft_target = F.softmax(
            t_logit.view(-1, self.cfg.vocab_size) / temperature,  # flatten softmax distribution
            dim=-1
        )  # [bs * seq, vocab_size]
        return last_hidden_state, soft_target

    def student_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        mask: Tensor,
        attention_mask: Tensor = None,
        is_valid: bool = False
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """ student forward pass to make soft prediction, hard prediction for student loss """
        temperature = 1.0 if is_valid else self.cfg.temperature
        # fixed: the original called self.model.teacher_fw here, so the student was never run
        last_hidden_state, s_logit = self.model.student_fw(
            inputs,
            padding_mask,
            attention_mask
        )
        last_hidden_state = torch.masked_select(last_hidden_state, ~mask)  # keep only non-padding positions
        last_hidden_state = last_hidden_state.view(-1, self.cfg.dim_model)  # flatten last_hidden_state
        c_labels = last_hidden_state.new_ones(last_hidden_state.size(0))  # target 1 for every token: "make them similar"
        soft_pred = F.softmax(
            s_logit.view(-1, self.cfg.vocab_size) / temperature,  # same temperature as the teacher
            dim=-1
        )
        return last_hidden_state, s_logit, soft_pred, c_labels
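The two forward methods drop padding positions before the cosine loss and build an all-ones label per remaining token. The sketch below mirrors that step in plain Python on nested lists. It assumes, as the `~mask` in the code suggests, that the padding mask is True at padding positions; the function name is hypothetical.

```python
def select_non_padding(hidden_states, padding_mask):
    """Flatten [batch][seq][dim] hidden states into a list of token vectors, skipping padding.

    padding_mask[b][t] is True where token t of sample b is padding (assumed convention).
    """
    return [
        vector
        for sequence, mask_row in zip(hidden_states, padding_mask)
        for vector, is_pad in zip(sequence, mask_row)
        if not is_pad
    ]


# batch of 2 sequences, length 3, hidden size 2; the last token of each sequence is padding
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]],
    [[0.5, 0.6], [0.7, 0.8], [0.0, 0.0]],
]
pad = [[False, False, True], [False, False, True]]

tokens = select_non_padding(hidden, pad)
cosine_labels = [1] * len(tokens)  # one target per kept token, all equal to 1
print(len(tokens), cosine_labels)
```

Filtering this way keeps padding vectors from influencing the cosine loss, and applying the same mask to Teacher and Student keeps their token vectors paired one to one.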
👩‍💻 DistilBERT Model
from typing import Callable, Tuple

import torch.nn as nn
from torch import Tensor

# CFG, AbstractModel and MLMHead come from the author's own project modules


class DistilBERT(nn.Module, AbstractModel):
    """ Main class for DistilBERT Style Model, Teacher-Student Framework
    for Knowledge Distillation, aiming to make large language models lighter. This model has 3 objective functions:
        1) distillation loss, calculated by soft targets & soft predictions
            (nn.KLDivLoss(reduction='batchmean'))
        2) student loss, calculated by hard targets & hard predictions
            (nn.CrossEntropyLoss(reduction='mean')), same as pure MLM Loss
        3) cosine similarity loss, calculated by student & teacher hidden state similarity
            (nn.CosineEmbeddingLoss(reduction='mean')), similar to contrastive loss
    soft targets & soft predictions mean that logits are passed through a softmax with temperature T
    temperature T flattens the softmax distribution to expose "Dark Knowledge" from the teacher model
    hard targets & hard predictions mean that logits are passed through a softmax without temperature T
    hard targets are simply the labels returned by the MLM collator, used for the cross entropy loss
    cosine similarity loss is calculated from the cosine similarity between student & teacher hidden states
    in the official repo, padding tokens are masked out before calculating cosine similarity, and the target is 1
    cosine similarity values are in the range [-1, 1]
    you can select any other backbone architecture for the Teacher & Student models,
    but in the original paper BERT is used for both Teacher and Student
    you must select a pretrained model for the Teacher, because it is the source of the distilled knowledge,
    and it must contain a pretrained mlm head
    Do not pass gradients backward to the teacher model!!
    (teacher model must be frozen, or registered as a buffer, or run under the no_grad() context manager)
    Args:
        cfg: configuration.CFG
        model_func: make model instance in runtime from config.json
    References:
        https://arxiv.org/pdf/1910.01108.pdf
        https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/distiller.py
    """
    def __init__(self, cfg: CFG, model_func: Callable) -> None:
        super(DistilBERT, self).__init__()
        self.cfg = cfg
        self.teacher = model_func(self.cfg.teacher_num_layers)  # must load a pretrained model containing an mlm head
        self.mlm_head = MLMHead(self.cfg)  # must load the pretrained model's mlm head
        self.student = model_func(self.cfg.student_num_layers)
        self.s_mlm_head = MLMHead(self.cfg)

    def teacher_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        attention_mask: Tensor = None,
    ) -> Tuple[Tensor, Tensor]:
        """ forward pass for teacher model
        """
        last_hidden_state, _ = self.teacher(
            inputs,
            padding_mask,
            attention_mask
        )
        t_logit = self.mlm_head(last_hidden_state)  # hard logit => softened later with temperature
        return last_hidden_state, t_logit

    def student_fw(
        self,
        inputs: Tensor,
        padding_mask: Tensor,
        attention_mask: Tensor = None
    ) -> Tuple[Tensor, Tensor]:
        """ forward pass for student model
        """
        last_hidden_state, _ = self.student(
            inputs,
            padding_mask,
            attention_mask
        )
        s_logit = self.s_mlm_head(last_hidden_state)  # hard logit => softened later with temperature
        return last_hidden_state, s_logit