🪢 [DeBERTa-V3] DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing
Overview
DeBERTa-V3, released by Microsoft in 2021, keeps the original DeBERTa architecture intact while adopting ELECTRA's generator-discriminator setup to improve performance over its predecessor. You can think of it as ELECTRA with DeBERTa swapped in for BERT as the backbone model. On top of that, to prevent ELECTRA's tug-of-war effect, the paper proposes a new embedding sharing method called GDES (Gradient-Disentangled Embedding Sharing).
This post covers only GDES, together with its implementation code. If you are curious about ELECTRA or DeBERTa themselves, see the earlier posts; the code for the full architecture can be found via the link here.
🪢 GDES: Gradient-Disentangled Embedding Sharing
In the figure from the paper, (a) is ELECTRA's original weight-sharing scheme and (c) is GDES. The schematic and its description look somewhat complicated, but the idea is very simple.
The generator and discriminator share the word and position embeddings during the forward pass, but not during the backward pass, so the discriminator's training signal never updates the generator's word and position embeddings. Only the generator's MLM training is allowed to update them.
\[E_{D} = \text{sg}(E_{G}) + E_{\Delta}\]
As far as the author can tell, the design seems inspired by skip connections: residual values are added to the generator's embeddings so that the discriminator's embedding matrix gets optimized for RTD. Here, $\text{sg}(\cdot)$ denotes stop-gradient. In other words, the generator's embedding weights are used in discriminator training, but the computation graph is cut at that point, so the discriminator's training signal (the binary classification loss) never reaches the generator's embedding weights.
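To make the formula concrete, here is a minimal sketch (toy sizes; the names `E_G` and `E_Delta` are mine) showing that a loss computed on $E_{D}$ updates only the residual:

```python
import torch

vocab_size, dim = 8, 4
E_G = torch.randn(vocab_size, dim, requires_grad=True)      # generator embedding E_G
E_Delta = torch.zeros(vocab_size, dim, requires_grad=True)  # residual E_Delta, zero-initialized

# E_D = sg(E_G) + E_Delta: Tensor.detach() plays the role of sg()
E_D = E_G.detach() + E_Delta

loss = E_D.pow(2).sum()  # stand-in for the discriminator's RTD loss
loss.backward()

print(E_Delta.grad is not None)  # True: the RTD loss updates the residual
print(E_G.grad is None)          # True: nothing flows back into the generator
```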
How should this idea actually be implemented in a real model? Let's walk through the code below.
👩‍💻 Implementation by PyTorch
Pay attention to the branch on share_embed_method in the ELECTRA module's __init__, and to the share_embedding() method below it.
```python
import torch
import torch.nn as nn
from experiment.models.abstract_model import AbstractModel
from torch import Tensor
from typing import Tuple, Callable
from einops.layers.torch import Rearrange
from experiment.tuner.mlm import MLMHead
from experiment.tuner.sbo import SBOHead
from experiment.tuner.rtd import get_discriminator_input, RTDHead
from configuration import CFG


class ELECTRA(nn.Module, AbstractModel):
    """ If you want to use pure ELECTRA, you should set share_embedding = ES,
    elif you want to use ELECTRA with GDES, you should set share_embedding = GDES.
    GDES is the new embedding sharing method from the DeBERTa-V3 paper

    Args:
        cfg: configuration.CFG
        model_func: make model instance in runtime from config.json

    Var:
        cfg: configuration.CFG
        generator: Generator, which is used for generating replaced tokens for RTD
                   should select backbone model ex) BERT, RoBERTa, DeBERTa, ...
        discriminator: Discriminator, which is used for detecting replaced tokens for RTD
                       should select backbone model ex) BERT, RoBERTa, DeBERTa, ...
        share_embedding: whether or not to share embedding layer (word & pos) between Generator & Discriminator
        self.word_bias: Delta_E in paper
        self.abs_pos_bias: Delta_E in paper
        self.rel_pos_bias: Delta_E in paper

    References:
        https://arxiv.org/pdf/2003.10555.pdf
        https://arxiv.org/pdf/2111.09543.pdf
        https://github.com/google-research/electra
    """
    def __init__(self, cfg: CFG, model_func: Callable) -> None:
        super(ELECTRA, self).__init__()
        self.cfg = cfg
        self.generator = model_func(cfg.generator_num_layers)  # init generator
        self.mlm_head = MLMHead(self.cfg)
        if self.cfg.rtd_masking == 'SpanBoundaryObjective':
            self.mlm_head = SBOHead(
                cfg=self.cfg,
                is_concatenate=self.cfg.is_concatenate,
                max_span_length=self.cfg.max_span_length
            )
        self.discriminator = model_func(cfg.discriminator_num_layers)  # init discriminator
        self.rtd_head = RTDHead(self.cfg)
        self.share_embed_method = self.cfg.share_embed_method  # 'instance', 'ES', 'GDES'
        if self.share_embed_method == 'GDES':
            # E_Delta in the paper: zero-initialized residuals owned by the discriminator
            self.word_bias = nn.Parameter(
                torch.zeros_like(self.discriminator.embeddings.word_embedding.weight, device=self.cfg.device)
            )
            self.abs_pos_bias = nn.Parameter(
                torch.zeros_like(self.discriminator.embeddings.abs_pos_emb.weight, device=self.cfg.device)
            )
            delattr(self.discriminator.embeddings.word_embedding, 'weight')
            self.discriminator.embeddings.word_embedding.register_parameter('_weight', self.word_bias)
            delattr(self.discriminator.embeddings.abs_pos_emb, 'weight')
            self.discriminator.embeddings.abs_pos_emb.register_parameter('_weight', self.abs_pos_bias)
            if self.cfg.model_name == 'DeBERTa':
                self.rel_pos_bias = nn.Parameter(
                    torch.zeros_like(self.discriminator.embeddings.rel_pos_emb.weight, device=self.cfg.device)
                )
                delattr(self.discriminator.embeddings.rel_pos_emb, 'weight')
                self.discriminator.embeddings.rel_pos_emb.register_parameter('_weight', self.rel_pos_bias)
        self.share_embedding()

    def share_embedding(self) -> None:
        def discriminator_hook(module: nn.Module, *inputs):
            if self.share_embed_method == 'instance':  # Instance Sharing
                self.discriminator.embeddings = self.generator.embeddings

            elif self.share_embed_method == 'ES':  # ES (Embedding Sharing)
                self.discriminator.embeddings.word_embedding.weight = self.generator.embeddings.word_embedding.weight
                self.discriminator.embeddings.abs_pos_emb.weight = self.generator.embeddings.abs_pos_emb.weight
                if self.cfg.model_name == 'DeBERTa':
                    self.discriminator.embeddings.rel_pos_emb.weight = self.generator.embeddings.rel_pos_emb.weight

            elif self.share_embed_method == 'GDES':  # GDES (Gradient-Disentangled Embedding Sharing)
                # E_D = sg(E_G) + E_Delta: detach() stops gradients flowing back into the generator
                g_w_emb = self.generator.embeddings.word_embedding
                d_w_emb = self.discriminator.embeddings.word_embedding
                self._set_param(d_w_emb, 'weight', g_w_emb.weight.detach() + d_w_emb._weight)

                g_p_emb = self.generator.embeddings.abs_pos_emb
                d_p_emb = self.discriminator.embeddings.abs_pos_emb
                self._set_param(d_p_emb, 'weight', g_p_emb.weight.detach() + d_p_emb._weight)

                if self.cfg.model_name == 'DeBERTa':
                    g_rp_emb = self.generator.embeddings.rel_pos_emb
                    d_rp_emb = self.discriminator.embeddings.rel_pos_emb
                    self._set_param(d_rp_emb, 'weight', g_rp_emb.weight.detach() + d_rp_emb._weight)

        self.discriminator.register_forward_pre_hook(discriminator_hook)

    @staticmethod
    def _set_param(module, param_name, value):
        module.register_buffer(param_name, value)

    def generator_fw(self, inputs: Tensor, labels: Tensor, padding_mask: Tensor, mask_labels: Tensor = None, attention_mask: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
        g_last_hidden_states, _ = self.generator(
            inputs,
            padding_mask,
            attention_mask
        )
        if self.cfg.rtd_masking == 'MaskedLanguageModel':
            g_logit = self.mlm_head(
                g_last_hidden_states
            )
        elif self.cfg.rtd_masking == 'SpanBoundaryObjective':
            g_logit = self.mlm_head(
                g_last_hidden_states,
                mask_labels
            )
        pred = g_logit.clone().detach()
        d_inputs, d_labels = get_discriminator_input(
            inputs,
            labels,
            pred,
        )
        return g_logit, d_inputs, d_labels

    def discriminator_fw(self, inputs: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tensor:
        d_last_hidden_states, _ = self.discriminator(
            inputs,
            padding_mask,
            attention_mask
        )
        d_logit = self.rtd_head(
            d_last_hidden_states
        )
        return d_logit
```
First, look at the branch inside __init__. It creates word_bias and abs_pos_bias and registers them with register_parameter. These newly created weights, registered under the name _weight as parameters of the discriminator's embedding layers, are exactly $E_{\Delta}$ from the paper.
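As a sanity check, here is a minimal sketch of the delattr / register_parameter trick (a bare nn.Embedding standing in for the post's backbone embedding): after the swap, the only trainable tensor left on the embedding is $E_{\Delta}$.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(8, 4)  # stands in for the discriminator's word embedding
e_delta = nn.Parameter(torch.zeros_like(emb.weight))  # E_Delta, zero-initialized

delattr(emb, 'weight')                       # remove the original weight parameter
emb.register_parameter('_weight', e_delta)   # register E_Delta under the name '_weight'

print([name for name, _ in emb.named_parameters()])  # ['_weight']
```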
Next, look at the share_embedding() method. Tensor.detach() is applied to $E_{G}$ to realize the stop-gradient in the formula. The two weights are then added, and register_buffer() is called so that the result is used in the forward pass while gradients from the backward pass cannot update it. Finally, register_forward_pre_hook() is called. The reason: because $E_{G}$ was detached, the copy of $E_{G}$ sitting in the discriminator's buffer is not the $E_{G}$ freshly updated by the generator's MLM loss at the previous step. The pre-hook therefore rebuilds the buffer every time the discriminator's forward pass starts, so that RTD always runs against the up-to-date $E_{G}$.
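To see the pre-hook mechanics in isolation, here is a small self-contained demo (module and variable names are mine, not from the post's codebase): the hook fires before every forward call and rebuilds the buffer from the current $E_{G}$.

```python
import torch
import torch.nn as nn

class TinyDiscriminator(nn.Module):
    """Stand-in discriminator: its weight lives in a buffer, not a parameter."""
    def __init__(self):
        super().__init__()
        self.register_buffer('weight', torch.zeros(2, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

disc = TinyDiscriminator()
E_G = torch.randn(2, 2, requires_grad=True)  # generator embedding (shared)
E_Delta = nn.Parameter(torch.zeros(2, 2))    # residual owned by the discriminator

def refresh_hook(module: nn.Module, inputs):
    # rebuild E_D = sg(E_G) + E_Delta right before every forward pass
    module.weight = E_G.detach() + E_Delta

disc.register_forward_pre_hook(refresh_hook)

disc(torch.ones(1, 2))                # hook fires: buffer reflects the current E_G
before = disc.weight.mean().item()
with torch.no_grad():
    E_G += 1.0                        # pretend the generator's MLM step updated E_G
disc(torch.ones(1, 2))                # hook fires again and picks up the new E_G
after = disc.weight.mean().item()
print(after - before)                 # ~1.0: the buffer tracked the generator update
```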
🤔 GDES Experiment
To check that GDES is implemented correctly, i.e. that the discriminator's training does not cause interference as the paper claims, I ran one experiment. The idea is this: if GDES works as intended, the MLM loss curve of a standalone encoder should closely track the loss curve of ELECTRA's generator. If the optimization trends diverge, either my implementation is wrong or, contrary to the paper's claim, interference does occur. With DeBERTa as the backbone and all hyperparameters held fixed, I trained each model separately and compared the loss trends over roughly the first 120 steps.
I carelessly left torch.backends.cudnn.deterministic = False while running the experiment, so the generator appears to converge slightly faster, presumably because cudnn worked hard during generator training. The convergence speeds differ a little, but the optimization trend itself is clearly the same.
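For reference, a typical reproducibility setup that would have avoided this; a sketch using standard PyTorch flags, not the post's actual training config:

```python
import random
import numpy as np
import torch

def seed_everything(seed: int = 42) -> None:
    # fix every RNG and force deterministic cudnn kernels
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # the flag left False in this experiment
    torch.backends.cudnn.benchmark = False
```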
Therefore, with GDES, interference does not occur and the tug-of-war effect is prevented. That said, the experiment is not entirely rigorous; I should come up with an experimental setup that can demonstrate this more rigorously in the future.