πŸ”­ Overview

DeBERTa-V3, released by Microsoft in 2021, keeps the original DeBERTa architecture as-is while adopting ELECTRA's Generator-Discriminator framework, improving performance over its predecessor. Think of it as ELECTRA with DeBERTa used as the backbone model instead of BERT. On top of that, to prevent ELECTRA's Tug-of-War phenomenon, the authors propose a new embedding sharing technique called GDES (Gradient Disentangled Embedding Sharing).

이번 ν¬μŠ€νŒ…μ—μ„œλŠ” κ΅¬ν˜„ μ½”λ“œμ™€ ν•¨κ»˜ GDES에 λŒ€ν•΄μ„œλ§Œ μ‚΄νŽ΄λ³΄λ € ν•œλ‹€. ELECTRA, DeBERTa에 λŒ€ν•΄ κΆκΈˆν•˜λ‹€λ©΄ 이전 ν¬μŠ€νŒ…μ„, 전체 ꡬ쑰에 λŒ€ν•œ μ½”λ“œλŠ” μ—¬κΈ° 링크λ₯Ό 톡해 확인 κ°€λŠ₯ν•˜λ‹€.

πŸͺ’ GDES: Gradient Disentangled Embedding Sharing

Figure: GDES

Panel (a) in the figure is ELECTRA's original weight sharing scheme, and panel (c) is GDES. The diagram and its description look somewhat complicated, but the idea is very simple.

μƒμ„±μžμ™€ νŒλ³„μžκ°€ μ„œλ‘œ ν¬μ›Œλ“œ 패슀 μ‹œμ μ—λŠ” 단어, μœ„μΉ˜ μž„λ² λ”©μ„ κ³΅μœ ν•˜λ˜, λ°±μ›Œλ“œ 패슀 μ‹œμ μ—μ„œλŠ” κ³΅μœ λ˜μ§€ λͺ»ν•˜λ„둝 ν•˜μ—¬, νŒλ³„μžμ˜ ν•™μŠ΅ 결과에 μ˜ν•΄ μƒμ„±μžμ˜ 단어 μž„λ² λ”©, μœ„μΉ˜ μž„λ² λ”©μ΄ μ—…λ°μ΄νŠΈ λ˜μ§€ λͺ»ν•˜λ„둝 ν•˜μ§€λŠ” 것이닀. 였직 μƒμ„±μžμ˜ MLM ν•™μŠ΅μ— μ˜ν•΄μ„œλ§Œ 단어 및 μœ„μΉ˜ μž„λ² λ”©μ΄ μ—…λ°μ΄νŠΈ λ˜μ–΄μ•Ό ν•œλ‹€.

\[E_{D} = \text{sg}(E_{G}) + E_{\Delta}\]

ν•„μžκ°€ μΆ”μ •ν•˜κΈ°λ‘œλŠ” Skip-Connectionμ—μ„œ μ˜κ°μ„ 받지 μ•Šμ•˜λ‚˜ 싢은 이 μˆ˜μ‹μ€, μƒμ„±μžμ˜ μž„λ² λ”©μ— μž”μ°¨κ°’λ“€μ„ 더해 νŒλ³„μžμ˜ μž„λ² λ”© 행렬이 RTD에 μ΅œμ ν™” λ˜λ„λ‘ 섀계 λ˜μ—ˆλ‹€. μ—¬κΈ°μ„œ sg() λŠ” stop gradientλ₯Ό μ˜λ―Έν•œλ‹€. λ‹€μ‹œ 말해, μƒμ„±μžμ˜ μž„λ² λ”© κ°€μ€‘μΉ˜λ₯Ό νŒλ³„μž ν•™μŠ΅μ— μ‚¬μš©ν•˜λ˜, ν•΄λ‹Ή μ‹œμ μ—μ„œλŠ” 계산 κ·Έλž˜ν”„ μž‘μ„±μ„ μ€‘λ‹¨μ‹œμΌœ νŒλ³„μžμ˜ ν•™μŠ΅ κ²°κ³Ό(이진 λΆ„λ₯˜ 손싀)κ°€ μƒμ„±μžμ˜ μž„λ² λ”© κ°€μ€‘μΉ˜μ— 영ν–₯을 λ―ΈμΉ˜μ§€ λͺ»ν•˜λ„둝 ν•œ 것이닀.

μ΄λŸ¬ν•œ μ•„μ΄λ””μ–΄λŠ” μ‹€μ œλ‘œ μ–΄λ–»κ²Œ μ½”λ“œλ‘œ κ΅¬ν˜„ν•΄μ•Όν• κΉŒ, μ•„λž˜ μ½”λ“œμ™€ ν•¨κ»˜ μ‚΄νŽ΄λ³΄μž.

πŸ‘©β€πŸ’»Β Implementation by Pytorch

Pay attention to the branch in the ELECTRA module's __init__ that depends on share_embed_method, and to the share_embedding() method below.

import torch
import torch.nn as nn
from experiment.models.abstract_model import AbstractModel
from torch import Tensor
from typing import Tuple, Callable
from einops.layers.torch import Rearrange
from experiment.tuner.mlm import MLMHead
from experiment.tuner.sbo import SBOHead
from experiment.tuner.rtd import get_discriminator_input, RTDHead
from configuration import CFG


class ELECTRA(nn.Module, AbstractModel):
    """ If you want to use pure ELECTRA, you should set share_embedding = ES
    elif you want to use ELECTRA with GDES, you should set share_embedding = GDES
    GDES is new approach of embedding sharing method from DeBERTa-V3 paper

    Args:
        cfg: configuration.CFG
        model_func: make model instance in runtime from config.json

    Var:
        cfg: configuration.CFG
        generator: Generator, which is used for generating replaced tokens for RTD
                   should select backbone model ex) BERT, RoBERTa, DeBERTa, ...
        discriminator: Discriminator, which is used for detecting replaced tokens for RTD
                       should select backbone model ex) BERT, RoBERTa, DeBERTa, ...
        share_embedding: whether or not to share embedding layer (word & pos) between Generator & Discriminator
        self.word_bias: E_Delta in the paper (word embedding residual)
        self.abs_pos_bias: E_Delta in the paper (absolute position embedding residual)
        self.rel_pos_bias: E_Delta in the paper (relative position embedding residual)

    References:
        https://arxiv.org/pdf/2003.10555.pdf
        https://arxiv.org/pdf/2111.09543.pdf
        https://github.com/google-research/electra
    """
    def __init__(self, cfg: CFG, model_func: Callable) -> None:
        super(ELECTRA, self).__init__()
        self.cfg = cfg
        self.generator = model_func(cfg.generator_num_layers)  # init generator
        self.mlm_head = MLMHead(self.cfg)
        if self.cfg.rtd_masking == 'SpanBoundaryObjective':
            self.mlm_head = SBOHead(
                cfg=self.cfg,
                is_concatenate=self.cfg.is_concatenate,
                max_span_length=self.cfg.max_span_length
            )
        self.discriminator = model_func(cfg.discriminator_num_layers)  # init discriminator
        self.rtd_head = RTDHead(self.cfg)
        self.share_embed_method = self.cfg.share_embed_method  # 'instance', 'ES', or 'GDES'
        if self.share_embed_method == 'GDES':
            self.word_bias = nn.Parameter(
                torch.zeros_like(self.discriminator.embeddings.word_embedding.weight, device=self.cfg.device)
            )
            self.abs_pos_bias = nn.Parameter(
                torch.zeros_like(self.discriminator.embeddings.abs_pos_emb.weight, device=self.cfg.device)
            )
            delattr(self.discriminator.embeddings.word_embedding, 'weight')
            self.discriminator.embeddings.word_embedding.register_parameter('_weight', self.word_bias)

            delattr(self.discriminator.embeddings.abs_pos_emb, 'weight')
            self.discriminator.embeddings.abs_pos_emb.register_parameter('_weight', self.abs_pos_bias)

            if self.cfg.model_name == 'DeBERTa':
                self.rel_pos_bias = nn.Parameter(
                    torch.zeros_like(self.discriminator.embeddings.rel_pos_emb.weight, device=self.cfg.device)
                )
                delattr(self.discriminator.embeddings.rel_pos_emb, 'weight')
                self.discriminator.embeddings.rel_pos_emb.register_parameter('_weight', self.rel_pos_bias)
        self.share_embedding()

    def share_embedding(self) -> None:
        def discriminator_hook(module: nn.Module, *inputs):
            if self.share_embed_method == 'instance':  # Instance Sharing
                self.discriminator.embeddings = self.generator.embeddings

            elif self.share_embed_method == 'ES':  # ES (Embedding Sharing)
                self.discriminator.embeddings.word_embedding.weight = self.generator.embeddings.word_embedding.weight
                self.discriminator.embeddings.abs_pos_emb.weight = self.generator.embeddings.abs_pos_emb.weight
                if self.cfg.model_name == 'DeBERTa':
                    self.discriminator.embeddings.rel_pos_emb.weight = self.generator.embeddings.rel_pos_emb.weight

            elif self.share_embed_method == 'GDES':  # GDES (Generator Discriminator Embedding Sharing)
                g_w_emb = self.generator.embeddings.word_embedding
                d_w_emb = self.discriminator.embeddings.word_embedding
                self._set_param(d_w_emb, 'weight', g_w_emb.weight.detach() + d_w_emb._weight)
                g_p_emb = self.generator.embeddings.abs_pos_emb
                d_p_emb = self.discriminator.embeddings.abs_pos_emb
                self._set_param(d_p_emb, 'weight', g_p_emb.weight.detach() + d_p_emb._weight)

                if self.cfg.model_name == 'DeBERTa':
                    g_rp_emb = self.generator.embeddings.rel_pos_emb
                    d_rp_emb = self.discriminator.embeddings.rel_pos_emb
                    self._set_param(d_rp_emb, 'weight', g_rp_emb.weight.detach() + d_rp_emb._weight)
        self.discriminator.register_forward_pre_hook(discriminator_hook)

    @staticmethod
    def _set_param(module, param_name, value):
        module.register_buffer(param_name, value)

    def generator_fw(self, inputs: Tensor, labels: Tensor, padding_mask: Tensor, mask_labels: Tensor = None, attention_mask: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
        g_last_hidden_states, _ = self.generator(
            inputs,
            padding_mask,
            attention_mask
        )
        if self.cfg.rtd_masking == 'MaskedLanguageModel':
            g_logit = self.mlm_head(
                g_last_hidden_states
            )
        elif self.cfg.rtd_masking == 'SpanBoundaryObjective':
            g_logit = self.mlm_head(
                g_last_hidden_states,
                mask_labels
            )
        pred = g_logit.clone().detach()
        d_inputs, d_labels = get_discriminator_input(
            inputs,
            labels,
            pred,
        )
        return g_logit, d_inputs, d_labels

    def discriminator_fw(self, inputs: Tensor, padding_mask: Tensor, attention_mask: Tensor = None) -> Tensor:
        d_last_hidden_states, _ = self.discriminator(
            inputs,
            padding_mask,
            attention_mask
        )
        d_logit = self.rtd_head(
            d_last_hidden_states
        )
        return d_logit

First, look at the branch in __init__. It creates word_bias and abs_pos_bias (plus rel_pos_bias when the backbone is DeBERTa) and registers them with register_parameter. These newly created weights, registered under the name _weight as parameters of the discriminator's embedding modules, are exactly $E_{\Delta}$.
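
Seen in isolation, the delattr / register_parameter trick works as follows. This is a simplified sketch on a bare nn.Embedding, mirroring the word_bias branch above rather than reproducing the project code:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)  # stands in for discriminator.embeddings.word_embedding

# E_Delta: a zero-initialized residual that will be the only trainable part
word_bias = nn.Parameter(torch.zeros_like(emb.weight))

# drop the original trainable 'weight' and register the residual as '_weight'
delattr(emb, 'weight')
emb.register_parameter('_weight', word_bias)

print(hasattr(emb, 'weight'))     # False: calling emb(...) would now fail,
                                  # until share_embedding()'s hook rebuilds 'weight'
print(emb._weight.requires_grad)  # True: E_Delta is trainable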

λ‹€μŒ share_embedding() λ©”μ„œλ“œλ₯Ό 보자. $E_{G}$ 에 torch.detach()λ₯Ό μ‚¬μš©ν•΄ μˆ˜μ‹μ˜ stop gradient 효과λ₯Ό μ μš©ν•œλ‹€. 그리고 두 κ°€μ€‘μΉ˜λ₯Ό λ”ν•˜κ³ , torch.register_bufferλ₯Ό ν˜ΈμΆœν•΄ ν¬μ›Œλ“œ νŒ¨μŠ€μ— ν™œμš©μ€ λ˜μ§€λ§Œ λ°±μ›Œλ“œ νŒ¨μŠ€μ— κ·ΈλΌλ””μ–ΈνŠΈκ°€ ν•΄λ‹Ή κ°€μ€‘μΉ˜λ₯Ό μ—…λ°μ΄νŠΈ ν•˜μ§€ λͺ»ν•˜λ„둝 μ„€μ •ν•œλ‹€. 그리고 λ§ˆμ§€λ§‰μ— torch.register_forward_pre_hook을 ν˜ΈμΆœν•˜λŠ”λ°, κ·Έ μ΄μœ λŠ” $E_{G}$ 에 torch.detach() λ₯Ό μ‚¬μš©ν–ˆκΈ° λ•Œλ¬Έμ— ν˜„μž¬ νŒλ³„μžμ˜ 버퍼에 μžˆλŠ” $E_{G}$ λŠ” 이전 μ‹œμ μ˜ μƒμ„±μž MLM 손싀에 μ˜ν•΄ μƒˆλ‘­κ²Œ μ—…λ°μ΄νŠΈ $E_{G}$ κ°€ μ•„λ‹ˆλ‹€. λ”°λΌμ„œ 맀번 νŒλ³„μžμ˜ ν¬μ›Œλ“œ νŒ¨μŠ€κ°€ 호좜(μ‹œμž‘)λ˜λŠ” μ‹œμ μ— μ—…λ°μ΄νŠΈ 된 $E_{G}$ λ₯Ό λ°˜μ˜ν•΄ RTDλ₯Ό μˆ˜ν–‰ν•  수 μžˆλ„λ‘ ν•˜κΈ° μœ„ν•΄ register_forward_pre_hook λ₯Ό μ‚¬μš©ν–ˆλ‹€.

πŸ€” GDES Experiment

To check whether GDES is implemented correctly, and whether, as the paper claims, the discriminator's training causes no interference, I ran one experiment. The idea is this: if GDES is implemented as intended, the MLM loss trajectory of a plain encoder model and the loss trajectory of ELECTRA's generator should follow a similar pattern. If the optimization trends differ, either I implemented it wrong, or, contrary to the authors' claim, interference does occur. With DeBERTa as the backbone, I trained each setup. Fixing all hyperparameters, I compared the loss trajectories over the first 120 training steps.

Figure: DeBERTa MLM Result

Figure: GDES Result

I forgot to change torch.backends.cudnn.deterministic and ran the experiment with it set to False, so the generator shows a slightly faster convergence pattern; cuDNN presumably worked hard during generator training. The convergence speeds differ a little, but the optimization trend itself is clearly the same.
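
For a rerun, the nondeterminism could be pinned down up front. The sketch below uses the stock PyTorch reproducibility flags; it is not taken from the original experiment script:

import random
import numpy as np
import torch

def set_deterministic(seed: int = 42) -> None:
    """Fix every seed and cuDNN's kernel selection so two runs are comparable step by step."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # the flag left at False in the run above
    torch.backends.cudnn.benchmark = False     # disable autotuning, which picks kernels nondeterministically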

λ”°λΌμ„œ GDESλ₯Ό μ‚¬μš©ν•˜λ©΄ 간섭이 λ°œμƒν•˜μ§€ μ•Šμ•„ Tug-of-War ν˜„μƒμ„ 방지할 수 μžˆλ‹€. λ‹€λ§Œ, μ‹€ν—˜μ΄ λ‹€μ†Œ μ—„λ°€ν•˜μ§€ λͺ»ν•œ 츑면이 μžˆλ‹€. 좔후에 μ’€ 더 μ—„λ°€ν•œ 증λͺ…을 ν•  수 μžˆλŠ” μ‹€ν—˜ 방법을 생각해봐야겠닀.
