๐ฎ [ELECTRA] Pre-training Text Encoders as Discriminators Rather Than Generators
๐ญย Overview
ELECTRA
๋ 2020๋
Google์์ ์ฒ์ ๋ฐํํ ๋ชจ๋ธ๋ก, GAN(Generative Adversarial Networks) Style ์ํคํ
์ฒ๋ฅผ NLP์ ์ ์ฉํ ๊ฒ์ด ํน์ง์ด๋ค. ์๋ก์ด ๊ตฌ์กฐ ์ฐจ์ฉ์ ๋ง์ถฐ์ RTD(Replace Token Dection)
Task๋ฅผ ๊ณ ์์ ์ฌ์ ํ์ต์ผ๋ก ์ฌ์ฉํ๋ค. ๋ชจ๋ ์์ด๋์ด๋ ๊ธฐ์กด MLM(Masked Language Model)์ ์ฌ์ ํ์ต ๋ฐฉ๋ฒ๋ก ์ผ๋ก ์ฌ์ฉํ๋ ์ธ์ฝ๋ ์ธ์ด ๋ชจ๋ธ(BERT ๊ณ์ด)์ ๋จ์ ์ผ๋ก๋ถํฐ ์ถ๋ฐํ๋ค.
[MLM ๋จ์ ]
- 1) ์ฌ์ ํ์ต๊ณผ ํ์ธํ๋ ์ฌ์ด ๋ถ์ผ์น
- ํ์ธํ๋ ๋ Masking Task๊ฐ ์์
- 2) ์ฐ์ฐ๋ ๋๋น ํ์ต๋์ ์ ์ํธ
- ์ ์ฒด ์ํ์ค์ 15%๋ง ๋ง์คํน ํ์ฉ(15%๋ง ํ์ต)
- ์ ์ญ ์ดํ
์
์ ์๊ณต๊ฐ ๋ณต์ก๋ ๊ณ ๋ คํ๋ฉด ์๋นํ ๋นํจ์จ์ ์ธ ์์น
- ์ํ์ค๊ธธ์ด ** 2์ ๋ณต์ก๋
- Vocab Size๋งํผ์ ์ฐจ์์ ๊ฐ๋ ์ํํธ๋งฅ์ค ๊ณ์ฐ ๋ฐ๋ณต
๊ทธ๋์ MLM์ ํ์ฉํ๋, ํ์ธํ๋๊ณผ ๊ดด๋ฆฌ๋ ํฌ์ง ์์ ๋ชฉ์ ํจ์๋ฅผ ์ค๊ณํจ์ผ๋ก์ ์ ๋ ฅ๋ ์ ์ฒด ์ํ์ค์ ๋ํด์ ๋ชจ๋ธ์ด ํ์ตํ์ฌ ์ฐ์ฐ๋ ๋๋น ํ์ต๋์ ๋๋ฆฌ๊ณ ์ ํ๋๊ฒ ๋ฐ๋ก ELECTRA ๋ชจ๋ธ์ด๋ค.
์ ๋ฆฌํ์๋ฉด, ELECTRA ๋ชจ๋ธ์ ๊ธฐ์กด BERT์ ๊ตฌ์กฐ์ ์ธก๋ฉด ๊ฐ์ ์ด ์๋, ์ฌ์ ํ์ต ๋ฐฉ๋ฒ์ ๋ํ ๊ฐ์ ์๋๋ผ๊ณ ๋ณผ ์ ์๋ค. ๋ฐ๋ผ์ ์ด๋ค ๋ชจ๋ธ์ด๋๋ผ๋, ์ธ์ฝ๋ ์ธ์ด ๋ชจ๋ธ์ด๋ผ๋ฉด ๋ชจ๋ ELECTRA ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ๊ธฐ์กด ๋ ผ๋ฌธ์์๋ ์๋ณธ BERT ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๋ค. ๊ทธ๋์ ๋ณธ ํฌ์คํ ์์๋ BERT์ ๋ํ ์ค๋ช ์์ด RTD์ ๋ํด์๋ง ๋ค๋ฃจ๋ ค๊ณ ํ๋ค.
๐ฎย RTD: New Pre-train Task
RTD์ ์์ด๋์ด๋ ๊ฐ๋จํ๋ค. ์์ฑ์(Generator)๊ฐ ์ถ๋ ฅ์ผ๋ก ๋ด๋์ ํ ํฐ ์ํ์ค์ ๋ํด์ ํ๋ณ์(Discriminator)๊ฐ ๊ฐ๋ณ ํ ํฐ๋ค์ด ์๋ณธ์ธ์ง ์๋์ง๋ฅผ ํ์ (์ด์ง ๋ถ๋ฅ)ํ๋๋ก ๋ง๋ ๋ค. ์์ฑ์๋ ๊ธฐ์กด์ MLM์ ๊ทธ๋๋ก ์ํํ๊ณ , ํ๋ณ์๋ ์์ฑ์์ ์์ธก์ ๋ํด ์ง์ง์ธ์ง ๊ฐ์ง์ธ์ง ๋ถ๋ฅํ๋ ์์ด๋ค.
์ ๊ทธ๋ฆผ์ ์์๋ก ์ดํด๋ณด์. ๋ชจ๋ธ์ ์
๋ ฅ์ผ๋ก the chef cooked the meal
๋ผ๋ ์ํ์ค ์ค๋ค. ๊ทธ๋ฌ๋ฉด MLM ๊ท์น์ ๋ฐ๋ผ์ 15%์ ํ ํฐ์ด ๋ฌด์์๋ก ์ ํ๋๋ค. ๊ทธ๋์ the
, cooked
๊ฐ ๋ง์คํน ๋์๋ค. ์ด์ ์์ฑ์๋ ๋ง์คํน ํ ํฐ์ ๋ํด the
, ate
๋ผ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ด๋๋๋ค. ๊ทธ๋์ ์ต์ข
์ ์ผ๋ก ์์ฑ์๊ฐ ๋ฐํํ๋ ์ํ์ค๋ the chef ate the meal
์ด ๋๋ค. ์ด์ ์์ฑ์๊ฐ ๋ฐํํ ์ํ์ค๋ฅผ ํ๋ณ์์ ์
๋ ฅ์ผ๋ก ๋์
ํ๋ค. ํ๋ณ์๋ ๊ฐ๋ณ ํ ํฐ๋ค์ด ์๋ณธ์ธ์ง ์๋์ง๋ฅผ ํ์ ํด ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๋ค.
์ด๋ฌํ ๊ตฌ์กฐ ๋ฐ ์ฌ์ ํ์ต ๋ฐฉ์์ ์ฅ์ ์ ํ๋ณ์๊ฐ MLM ํ์ต์ ๋ฐ๋ฅธ ์ง์์ ์์ฑ์๋ก๋ถํฐ ์ ์ ๋ฐ๋ ๋์์ ์ ์ฒด ์ํ์ค์ ๋ํด์ ํ์ตํ ๊ธฐํ๊ฐ ์๊ธด๋ค๋ ๊ฒ์ด๋ค. ์ํ์ค ๋ด๋ถ ๋ชจ๋ ํ ํฐ์ ๋ํด์ ์์ธก์ ์ํํ๊ณ ์์ค์ ๊ณ์ฐํ ์ ์๊ธฐ ๋๋ฌธ์ ๊ฐ์ ํฌ๊ธฐ์ ์ํ์ค๋ฅผ ์ฌ์ฉํด๋ ๊ธฐ์กด MLM ๋๋น ๋ ํ๋ถํ ๋ฌธ๋งฅ ์ ๋ณด๋ฅผ ๋ชจ๋ธ์ด ํฌ์ฐฉํ ์ ์๊ฒ ๋๋ค. ๋ํ ํ๋ณ์๋ฅผ ํ์ธํ๋์ BackBone์ผ๋ก ์ฌ์ฉํ๋ฉด, ํ๋ณ์์ ์ฌ์ ํ์ต์ ๊ฒฐ๊ตญ ๋ง์คํน ์์ด ๋ชจ๋ ์ํ์ค๋ฅผ ํ์ฉํ ์ด์ง ๋ถ๋ฅ๋ผ๊ณ ๋ณผ ์ ์๊ธฐ ๋๋ฌธ์, ์ฌ์ ํ์ต๊ณผ ํ์ธํ๋ ์ฌ์ด์ ๊ดด๋ฆฌ๋ ์๋นํ ๋ง์ด ์ค์ด๋ค๊ฒ ๋๋ค.
๐ย Architecture
์ ์๋ ์์ ๊ฐ์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ์ผ๋ก, ์์ฑ์์ width (์๋์ธต) ํฌ๊ธฐ๊ฐ ํ๋ณ์๋ณด๋ค ์๋๋ก ๋ชจ๋ธ ํฌ๊ธฐ๋ฅผ ์ธํ ํ๋๊ฒ ๊ฐ์ฅ ํจ์จ์ ์ด๋ผ๊ณ ์ฃผ์ฅํ๋ค. ์ ์๋ ๊ทธ๋ํ๋ ์์ฑ์์ ํ๋ณ์์ ํฌ๊ธฐ ๋ณํ ๋๋น ํ์ธํ๋ ์ฑ๋ฅ์ ์ถ์ด๋ฅผ ๋ํ๋ธ๋ค. ์์ฑ์์ width ํฌ๊ธฐ๊ฐ 256, ํ๋ณ์์ width ํฌ๊ธฐ๊ฐ 768์ผ ๋ ๊ฐ์ฅ ์ ์๊ฐ ๋๋ค. depth(๋ ์ด์ด ๊ฐ์)์ ๋ํ ์ธ๊ธ์ ๋ฐ๋ก ์์ง๋ง, ์ ์์ ์ํด ๊ณต๊ฐ๋ Hyper-Param ํ ์ด๋ธ์ ๋ณด๋ฉด ์๋์ธต์ ํฌ๊ธฐ๋ง ์ค์ด๊ณ , ๋ ์ด์ด ๊ฐ์๋ ์์ฑ์์ ํ๋ณ์๊ฐ ๊ฐ์ ๊ฒ์ผ๋ก ์ถ์ ๋๋ค.
์ถ๊ฐ๋ก, ์์ฑ์์ ํ๋ณ์๊ฐ ์๋ฒ ๋ฉ ์ธต์ ์๋ก ๊ณต์ ํ๋๊ฒ ๊ฐ์ฅ ๋์ ์ฑ๋ฅ์ ๋ธ๋ค๊ณ ์ฃผ์ฅํ๋ค. ์ค๋ฅธ์ชฝ ๊ทธ๋ํ ์ถ์ด๋ฅผ ๋ณด๋ฉด ๊ฐ์ ์ฐ์ฐ๋์ด๋ผ๋ฉด, ์๋ฒ ๋ฉ ๊ณต์ (ํ๋์ ์ค์ ) ๋ฐฉ์์ด ๊ฐ์ฅ ๋์ ํ์ธํ๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋ค๋ ๊ฒ์ ์ ์ ์๋ค. ๋ฐ๋ผ์ ๋จ์ด ์๋ฒ ๋ฉ, ์ ๋ ์์น ์๋ฒ ๋ฉ์ ์๋ก ๊ณต์ ํ๋๋ก ์ค๊ณํ๋ค. ๋์ ์์ฑ์ ์๋์ธต์ ํฌ๊ธฐ๊ฐ ๋ ์์๊ฒ ์ ๋ฆฌํ๋ค๊ณ ์ธ๊ธํ๊ธฐ ๋๋ฌธ์, ์ด๊ฒ์ ์ค์ ๋ก ๊ตฌํํ๋ ค๋ฉด ์๋ฒ ๋ฉ ์ธต์ผ๋ก๋ถํฐ ๋์จ ๊ฒฐ๊ณผ๊ฐ์ ์์ฑ์์ ์๋์ธต ์ฐจ์์ผ๋ก ์ ํ ํฌ์ํด์ค์ผ ํ๋ค. ๊ทธ๋์ ์์ฑ์์ ์๋ฒ ๋ฉ ์ธต๊ณผ ์ธ์ฝ๋ ์ฌ์ด์ linear layer๊ฐ ์ถ๊ฐ๋์ด์ผ ํ๋ค.
\[\min_{\theta_G, \theta_D}\sum_{x \in X} \mathcal{L}_{\text{MLM}}(x, \theta_G) + \lambda \mathcal{L}_{\text{Disc}}(x, \theta_D)\]๋ฐ๋ผ์, ์ง๊ธ๊น์ง ์ดํด๋ณธ ๋ชจ๋ ๋ด์ฉ์ ์ข ํฉํด๋ณด๋ฉด ELECTRA์ ๋ชฉ์ ํจ์๋ ๋ค์ ์์๊ณผ ๊ฐ๋ค. ์์ฑ์์ MLM ์์ค๊ณผ ํ๋ณ์์ ์ด์ง ๋ถ๋ฅ ์์ค์ ๋ํด์ ๋ชจ๋ธ์ ์ค์ฐจ ์ญ์ ํด์ฃผ๋ฉด ๋๋๋ฐ, ํน์ดํ ์ ์ ํ๋ณ์์ ์์ค์ ์์๊ฐ์ธ ๋๋ค๊ฐ ๊ณฑํด์ง๋ค๋ ๊ฒ์ด๋ค. ์ค์ ๋ชจ๋ธ์ ๊ตฌํํ๊ณ ์ฌ์ ํ์ต์ ํด๋ณด๋ฉด, ๋ฐ์ดํฐ์ ์์ด๋ ๋ชจ๋ธ ํฌ๊ธฐ ํน์ ์ข ๋ฅ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๊ฒ ์ง๋ง ๋ ์์ค ์ฌ์ด์ ์ค์ผ์ผ์ ์ฐจ์ด๊ฐ 10๋ฐฐ์ ๋ ์ฐจ์ด ๋๊ฒ ๋๋ค. ๋ ์์ค์ ์ค์ผ์ผ์ ๋ง์ถฐ์ฃผ๋ ๋์์, ์๋ฒ ๋ฉ์ธต์ ํ์ต์ด ํ๋ณ์์ ์์ค์ ์ค์ด๋๋ฐ ๋ ์ง์คํ๋๋ก ๋ง๋ค๊ธฐ ์ํด ๋์ ํ ๊ฒ์ผ๋ก ์ถ์ ๋๋ค. ๋ ผ๋ฌธ๊ณผ ์ฝ๋๋ฅผ ๋ณด๋ฉด ์ ์๋ $\lambda=50$ ์ผ๋ก ๋๊ณ ํ์ตํ๊ณ ์๋ค.
๐ฉโ๐ปย Implementation by Pytorch
๋ ผ๋ฌธ์ ๋ด์ฉ๊ณผ ์ ์๊ฐ ์ง์ ๊ณต๊ฐํ ์ฝ๋๋ฅผ ์ข ํฉํ์ฌ ํ์ดํ ์น๋ก ELECTRA๋ฅผ ๊ตฌํํด๋ดค๋ค. ๋ ๊ฐ์ ์๋ก ๋ค๋ฅธ ๋ชจ๋ธ์ ๊ฐ์ ์คํญ์์ ํ์ต์์ผ์ผ ํ๊ธฐ ๋๋ฌธ์, ์ ์๋ ๋ด์ฉ์ ๋นํด ์ค์ ๊ตฌํ์ ๋งค์ฐ ๊น๋ค๋ก์ด ํธ์ด์๋ค. ๋ณธ ํฌ์คํ ์์๋ ELECTRA ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๋น๋กฏํด RTD ํ์ต ํ์ดํ๋ผ์ธ ๊ตฌ์ฑ์ ํ์์ ์ธ ์์ ๋ช ๊ฐ์ง์ ๋ํด์๋ง ์ค๋ช ํ๋ ค ํ๋ค. ์ ์ฒด ๊ตฌ์กฐ์ ๋ํ ์ฝ๋๋ ์ฌ๊ธฐ ๋งํฌ๋ฅผ ํตํด ์ฐธ๊ณ ๋ถํ๋๋ฆฐ๋ค.
ELECTRA์ ์ฌ์ ํ์ต์ธ RTD์ ํ์ต ํ์ดํ๋ผ์ธ์ ๊ตฌํํ ์ฝ๋๋ฅผ ๋ณธ ๋ค, ์ธ๋ถ ๊ตฌ์ฑ ์์๋ค์ ๋ํด์ ์ดํด๋ณด์.
๐ RTD trainer method
def train_val_fn(self, loader_train, model: nn.Module, criterion: nn.Module, optimizer, scheduler) -> Tuple[Any, Union[float, Any]]:
scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.amp_scaler)
model.train()
for step, batch in enumerate(tqdm(loader_train)):
optimizer.zero_grad(set_to_none=True)
inputs = batch['input_ids'].to(self.cfg.device, non_blocking=True)
labels = batch['labels'].to(self.cfg.device, non_blocking=True)
padding_mask = batch['padding_mask'].to(self.cfg.device, non_blocking=True)
mask_labels = None
if self.cfg.rtd_masking == 'SpanBoundaryObjective':
mask_labels = batch['mask_labels'].to(self.cfg.device, non_blocking=True)
with torch.cuda.amp.autocast(enabled=self.cfg.amp_scaler):
g_logit, d_inputs, d_labels = model.generator_fw(
inputs,
labels,
padding_mask,
mask_labels
)
d_logit = model.discriminator_fw(
d_inputs,
padding_mask
)
g_loss = criterion(g_logit.view(-1, self.cfg.vocab_size), labels.view(-1))
d_loss = criterion(d_logit.view(-1, 2), d_labels)
loss = g_loss + d_loss*self.cfg.discriminator_lambda
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
scheduler.step()
๋ฐ์ดํฐ๋ก๋๋ก๋ถํฐ ๋ฐ์ ์ ๋ ฅ๋ค์ ์์ฑ์์ ๋ฃ๊ณ MLM ์์ธก ๊ฒฐ๊ณผ, RTD ์ํ์ ์ํด ํ์ํ ์๋ก์ด ๋ผ๋ฒจ๊ฐ์ ๋ฐํ ๋ฐ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ด๊ฒ์ ๋ค์ ํ๋ณ์์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ณ , ํ๋ณ์์ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ๋ฐ์ ์๋ก ๋ค๋ฅธ ๋ ๋ชจ๋ธ์ ๋ํ ๊ฐ์ค ์์คํฉ์ฐ์ ๊ตฌํ ๋ค, ์ตํฐ๋ง์ด์ ์ ๋ณด๋ด๊ณ ์ต์ ํ๋ฅผ ์ํํ๋ค. ์ด ๋, ์ฒ์์ ๋ฐ์ดํฐ๋ก๋๊ฐ ๋ฐํํ๋ ์ ๋ ฅ ์ํ์ค์ ๋ผ๋ฒจ์ MLM์ ๊ทธ๊ฒ๊ณผ ๋์ผํ๋ค,
๊ตฌํํ๋ฉด์ ๊ฐ์ฅ ์ด๋ ค์ ๋๊ฒ, ์ตํฐ๋ง์ด์ ๋ฐ ์ค์ผ์ค๋ฌ์ ๊ตฌ์ฑ์ด์๋ค. ๋ ๊ฐ์ ๋ชจ๋ธ์ ๊ฐ์ ์คํญ์์ ํ์ต์ํค๋ ๊ฒฝํ์ด ์ฒ์์ด๋ผ์ ์ฒ์์ ๋ชจ๋ธ ๊ฐ์๋งํผ ์ตํฐ๋ง์ด์ ์ ์ค์ผ์ค๋ฌ ๊ฐ์ฒด๋ฅผ ๋ง๋ค์ด์ค์ผ ํ๋ค๊ณ ์๊ฐํ๋ค. ํนํ ๋ ๋ชจ๋ธ์ ์ค์ผ์ผ์ด ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ ์๋ก ๋ค๋ฅธ ์ตํฐ๋ง์ด์ , ์ค์ผ์ค๋ฌ ๋์ ์ผ๋ก ๊ฐ๊ธฐ ๋ค๋ฅธ ํ์ต๋ฅ ์ ์ ์ฉํ๋๊ฒ ์ ํํ ๊ฒ์ด๋ผ ์๊ฐํ๋ค.
ํ์ง๋ง, ์ตํฐ๋ง์ด์ ๋ฅผ ๋ ๊ฐ ์ฌ์ฉํ๋ ๊ฒ์ ๋งค์ฐ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฐจ์งํ ๋ฟ๋๋ฌ ๋ ผ๋ฌธ์์ ๊ณต๊ฐํ ํ์ดํผํ๋ผ๋ฏธํฐ ํ ์ด๋ธ์ ๋ณด๋ฉด ๋ ๋ชจ๋ธ์ ๊ฐ์ ํ์ต๋ฅ ์ ์ ์ฉํ๊ณ ์๋ ๊ฒ์ ์ ์ ์์๋ค. ๋ฐ๋ผ์ ๊ทธ์ ๋ง๊ฒ ๊ฐ์ ์ตํฐ๋ง์ด์ , ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํด ๋์์ ๋ ๋ชจ๋ธ์ด ํ์ต๋๋๋ก ํ์ดํ๋ผ์ธ์ ๋ง๋ค๊ฒ ๋์๋ค.
์ถ๊ฐ๋ก, ๊ณต๊ฐ๋ ์คํผ์ ์ฝ๋ ์ญ์ ๋จ์ผ ์ตํฐ๋ง์ด์ ๋ฐ ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ํ์ธํ๋ค.
๐ ELECTRA Module
import torch
import torch.nn as nn
from experiment.models.abstract_model import AbstractModel
from torch import Tensor
from typing import Tuple, Callable
from einops.layers.torch import Rearrange
from experiment.tuner.mlm import MLMHead
from experiment.tuner.sbo import SBOHead
from experiment.tuner.rtd import get_discriminator_input, RTDHead
from configuration import CFG
class ELECTRA(nn.Module, AbstractModel):
def __init__(self, cfg: CFG, model_func: Callable) -> None:
super(ELECTRA, self).__init__()
self.cfg = cfg
self.generator = model_func(cfg.generator_num_layers) # init generator
self.mlm_head = MLMHead(self.cfg)
if self.cfg.rtd_masking == 'SpanBoundaryObjective':
self.mlm_head = SBOHead(
cfg=self.cfg,
is_concatenate=self.cfg.is_concatenate,
max_span_length=self.cfg.max_span_length
)
self.discriminator = model_func(cfg.discriminator_num_layers) # init generator
self.rtd_head = RTDHead(self.cfg)
self.share_embed_method = self.cfg.share_embed_method # instance, es, gdes
self.share_embedding()
def share_embedding(self) -> None:
def discriminator_hook(module: nn.Module, *inputs):
if self.share_embed_method == 'instance': # Instance Sharing
self.discriminator.embeddings = self.generator.embeddings
elif self.share_embed_method == 'ES': # ES (Embedding Sharing)
self.discriminator.embeddings.word_embedding.weight = self.generator.embeddings.word_embedding.weight
self.discriminator.embeddings.abs_pos_emb.weight = self.generator.embeddings.abs_pos_emb.weight
self.discriminator.register_forward_pre_hook(discriminator_hook)
def generator_fw(
self,
inputs: Tensor,
labels: Tensor,
padding_mask: Tensor,
mask_labels: Tensor = None,
attention_mask: Tensor = None
) -> Tuple[Tensor, Tensor, Tensor]:
g_last_hidden_states, _ = self.generator(
inputs,
padding_mask,
attention_mask
)
if self.cfg.rtd_masking == 'MaskedLanguageModel':
g_logit = self.mlm_head(
g_last_hidden_states
)
elif self.cfg.rtd_masking == 'SpanBoundaryObjective':
g_logit = self.mlm_head(
g_last_hidden_states,
mask_labels
)
pred = g_logit.clone().detach()
d_inputs, d_labels = get_discriminator_input(
inputs,
labels,
pred,
)
return g_logit, d_inputs, d_labels
def discriminator_fw(
self,
inputs: Tensor,
padding_mask: Tensor,
attention_mask: Tensor = None
) -> Tensor:
d_last_hidden_states, _ = self.discriminator(
inputs,
padding_mask,
attention_mask
)
d_logit = self.rtd_head(
d_last_hidden_states
)
return d_logit
ELECTRA ๋ชจ๋ธ ๊ฐ์ฒด๋ ํฌ๊ฒ ์๋ฐฐ๋ฉ ๋ ์ด์ด ๊ณต์ , ์์ฑ์ ํฌ์๋, ํ๋ณ์ ํฌ์๋ ํํธ๋ก ๋๋๋ค. ๋จผ์ ์๋ฒ ๋ฉ ๋ ์ด์ด ๊ณต์ ๋ ํฌ๊ฒ ๋ ๊ฐ์ง ๋ฐฉ์์ผ๋ก ๊ตฌํ ๊ฐ๋ฅํ๋ค. ํ๋๋ ์๋ฒ ๋ฉ ๋ ์ด์ด ์ธ์คํด์ค ์์ฒด๋ฅผ ๊ณต์ ํ๋ ๋ฐฉ์์ผ๋ก, ์์ฑ์์ ํ๋ณ์์ ์ค์ผ์ผ์ด ๋์ผํ ๋ ์ฌ์ฉํ ์ ์๋ค. ๋๋จธ์ง๋ ๋จ์ด ์๋ฒ ๋ฉ, ํฌ์ง์ ์๋ฒ ๋ฉ์ ๊ฐ์ค์น ํ๋ ฌ๋ง ๊ณต์ ํ๋ ๋ฐฉ์์ผ๋ก, ์๋ก ์ค์ผ์ผ์ด ๋ฌ๋ผ๋ ์ฌ์ฉํ ์ ์๋ค. ๋ ผ๋ฌธ์์ ์ ์ํ๋ ๊ฐ์ฅ ํจ์จ์ ์ธ ๋ฐฉ๋ฒ์ ํ์์ด๋ฉฐ, ํ๋ณ์์ ์๋ฒ ๋ฉ ํ๋ ฌ์ด ์์ฑ์์ ์๋ฒ ๋ฉ ํ๋ ฌ์ ์ฃผ์๋ฅผ ๊ฐ๋ฆฌํค๋๋ก ํจ์ผ๋ก์ ๊ตฌํ ๊ฐ๋ฅํ๋ค.
๐ RTD Input Making
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple
from configuration import CFG
def get_discriminator_input(inputs: Tensor, labels: Tensor, pred: Tensor) -> Tuple[Tensor, Tensor]:
""" Post Processing for Replaced Token Detection Task
1) get index of the highest probability of [MASK] token in pred tensor
2) convert [MASK] token to prediction token
3) make label for Discriminator
Args:
inputs: pure inputs from tokenizing by tokenizer
labels: labels for masked language modeling
pred: prediction tensor from Generator
returns:
d_inputs: torch.Tensor, shape of [Batch, Sequence], for Discriminator inputs
d_labels: torch.Tensor, shape of [Sequence], for Discriminator labels
"""
# 1) flatten pred to 2D Tensor
d_inputs, d_labels = inputs.clone().detach().view(-1), None # detach to prevent back-propagation
flat_pred, flat_label = pred.view(-1, pred.size(-1)), labels.view(-1) # (batch * sequence, vocab_size)
# 2) get index of the highest probability of [MASK] token
pred_token_idx, mlm_mask_idx = flat_pred.argmax(dim=-1), torch.where(flat_label != -100)
pred_tokens = torch.index_select(pred_token_idx, 0, mlm_mask_idx[0])
# 3) convert [MASK] token to prediction token
d_inputs[mlm_mask_idx[0]] = pred_tokens
# 4) make label for Discriminator
original_tokens = inputs.clone().detach().view(-1)
original_tokens[mlm_mask_idx[0]] = flat_label[mlm_mask_idx[0]]
d_labels = torch.eq(original_tokens, d_inputs).long()
d_inputs = d_inputs.view(pred.size(0), -1) # covert to [batch, sequence]
return d_inputs, d_labels
์ด์ ๋ง์ง๋ง์ผ๋ก ํ๋ณ์์ ์ ๋ ฅ์ ๋ง๋๋ ์๊ณ ๋ฆฌ์ฆ์ ๋ํ ์ฝ๋๋ฅผ ๋ณด์. ์๊ณ ๋ฆฌ์ฆ์ ๋ค์๊ณผ ๊ฐ๋ค.
- 1) ๊ฐ๋ณ ๋ง์คํน ํ ํฐ์ ๋ํ ์์ธก ํ ํฐ ๊ตฌํ๊ธฐ
- ๋ก์ง์ ์ค์ ํ ํฐ ์ธ๋ฑ์ค๋ก ๋ณํ
- 2) ๋ชจ๋ ๋ง์คํน ๋ถ๋ถ์ ์์ธก ํ ํฐ๋ค๋ก ๋์ฒด
- 3) ๊ธฐ์กด ์
๋ ฅ๊ณผ 2๋ฒ์ผ๋ก ๋ง๋ค์ด์ง ์ํ์ค ๋น๊ตํด ๋ผ๋ฒจ ์์ฑ
- ์๋ก ๊ฐ์ผ๋ฉด 0
- ์๋ก ๋ค๋ฅด๋ฉด 1 ์ด๋ ๊ฒ ๋ง๋ค์ด์ง ์๋ก์ด ์ ๋ ฅ ์ํ์ค์ ๋ผ๋ฒจ์ ELECTRA ๋ชจ๋ธ ์ธ์คํด์ค์ ํ๋ณ์ ํฌ์๋ ๋ฉ์๋์ ์ธ์๋ก ์ ๋ฌํ๋ฉด ๋๋ค.
๐ Future Work (์ฝ๊ณ ๊ตฌํํ๋ฉด์ ๋๋์ & ๊ฐ์ ๋ฐฉํฅ)
์ด๋ ๊ฒ ELECTRA ๋ชจ๋ธ์ ๋ํ ๊ตฌํ๊น์ง ์ดํด๋ดค๋ค. ๋
ผ๋ฌธ์ ์ฝ๊ณ ๊ตฌํํ๋ฉด์ ๊ฐ์ฅ ์๋ฌธ์ค๋ฌ์ ๋ ๋ถ๋ถ์ ์๋ฒ ๋ฉ ๊ณต์ ๋ฐฉ๋ฒ์ด์๋ค. ์ํ์ ์ผ๋ก ์๋ฐํ๊ฒ ๊ณ์ฐํ๊ณ ๋ฐ์ ธ๋ณด์ง ๋ชปํ์ง๋ง, ์ง๊ด์ ์ผ๋ก๋ ์์ฑ์์ MLM๊ณผ ํ๋ณ์์ RTD๋ ์๋ก ์ฑ๊ฒฉ์ด ์๋นํ ๋ค๋ฅธ ์ฌ์ ํ์ต ๋ฐฉ๋ฒ๋ก ์ด๋ผ๋ ๊ฒ์ ์ ์ ์๋ค. ๊ทธ๋ ๋ค๋ฉด ๋จ์ํ ๋จ์ด, ํฌ์ง์
์๋ฒ ๋ฉ์ ๊ณต์ ํ๋ ๊ฒฝ์ฐ ํ์ต ๋ฐฉํฅ์ฑ์ด ๋ฌ๋ผ์ ๊ฐ์ญ์ด ๋ฐ์ํ๊ณ ๋ชจ๋ธ์ด ์๋ ดํ์ง ๋ชปํ ์ฌ์ง๊ฐ ์๊ธด๋ค. ์ด๋ฌํ ์ค๋ค๋ฆฌ๊ธฐ ํ์(tag-of-war)
์ ์ด๋ป๊ฒ ํด๊ฒฐํ ์ ์์๊น์ ๋ํ ๊ณ ๋ฏผ์ด ๋ ํ์ํ๋ค๊ณ ์๊ฐํ๋ค.
๊ทธ๋์ ๋ค์ ํฌ์คํ ์์๋ ์ด๋ฌํ ์ค๋ค๋ฆฌ๊ธฐ ํ์์ ํด๊ฒฐํ๊ณ ์ํ ๋ ผ๋ฌธ์ธ <DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing>์ ๋ฆฌ๋ทฐํด๋ณด๊ณ ์ ํ๋ค.
Leave a comment