[SpanBERT] SpanBERT: Improving Pre-training by Representing and Predicting Spans
Overview
SpanBERT is a BERT-family model published by Facebook in 2020 that records higher performance than the original by pre-training with a newly devised objective, the SBO (Span Boundary Objective). Conventional MLM and CLM predict single tokens, which makes them very well suited to word-level tasks but relatively lacking when applied directly to sentence-level tasks such as QA or sentence similarity. SBO was devised to address this problem: it is similar to MLM, but masks and denoises contiguous spans (words or phrases) at a time, which makes it a better pre-training objective for models targeting sentence-level downstream tasks.

In short, SpanBERT is an improvement to the pre-training method, not to the architecture of BERT. Any encoder language model can therefore adopt the SpanBERT recipe, and the original paper uses the vanilla BERT architecture. For that reason, this post covers only SBO and skips the explanation of BERT itself.
SBO: Span Boundary Objective
[SBO Algorithm Summary]

- 1) Build spans of consecutive tokens
    - randomly pick the two boundary tokens of a span (e.g. $x_{4}$, $x_{9}$)
        - $x_{5}$ to $x_{8}$ are the tokens inside the span
    - compute the masking budget
        - the masking budget per sequence (the summed span lengths) is 15% of the sequence length
        - example sequence length: 512
        - masking budget: about 76 (= int(512 * 0.15))
    - sample each individual span length from a geometric distribution (see the sketch after this list)
        - cap the length of any single span at 10
    - repeat masking until the summed span lengths reach the budget
        - if the remaining budget is smaller than the sampled span length, shrink the span length to the remaining budget
    - spans must therefore start on Whole Word Masking units, not on arbitrary subword positions
- 2) Compute relative positions with respect to the span start token
    - build and look up a relative position embedding for every token inside the span
    - concatenate the start (left) boundary token, the end (right) boundary token, and the relative position embedding into a hidden vector
    - pass the hidden vector through the SBOHead
- 3) Use the SBOHead output as the prediction representation for each masked position
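To make the budget and span-length sampling above concrete, here is a minimal standalone sketch (the function name and the use of NumPy are illustrative and not taken from the post's actual collator further below). With $p = 0.2$ and a cap of 10, this clipping yields the mean span length of about 3.8 reported in the paper.

```python
import numpy as np

def sample_span_lengths(seq_len: int, budget_ratio: float = 0.15, p: float = 0.2, max_len: int = 10) -> list:
    """Draw span lengths from Geo(p), clipped to [1, max_len], until the masking budget is spent."""
    budget = int(seq_len * budget_ratio)  # e.g. int(512 * 0.15) = 76 tokens
    lengths = []
    while budget > 0:
        length = min(np.random.geometric(p), max_len)  # Geo(p) has mean 1/p = 5 before clipping
        length = min(length, budget)                   # never exceed the remaining budget
        lengths.append(length)
        budget -= length
    return lengths

print(sample_span_lengths(512))  # e.g. [3, 1, 7, 2, ...] summing to 76
```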
The idea behind SBO itself is quite simple: instead of randomly selecting arbitrary individual tokens from the sequence as conventional MLM does, it picks a contiguous run of tokens of a given length at once, masks the whole run, and trains the model to restore it. The SBO algorithm proposed in the paper is summarized in the list above, and the boundary prediction is formalized below.
\[\begin{align*} h_0 &= [x_{s-1}; x_{e+1}; p_{i-s+1}] \\ h_1 &= \text{LayerNorm}(\text{GeLU}(W_1 h_0)) \\ y_i &= \text{LayerNorm}(\text{GeLU}(W_2 h_1)) \end{align*}\]

Let's walk through the algorithm with the example above (a span whose boundary tokens are $x_{4}$ and $x_{9}$). First, matching the sampled span length, the tokens that become the start and end points of the span are chosen at random. Next, the relative position index of every token inside the span is computed with respect to the start position; in the example, token $x_{7}$ gets relative position 3. The row vector with index 3 is then fetched from the pre-defined relative position embedding table, concatenated with the two boundary vectors to form $h_{0}$, and passed through the pre-defined SBOHead. The hidden vector returned by the SBOHead is used as the prediction ($y_{i}$) for the masked token at that position, and the SBO loss is computed from it. Putting everything together in formulas gives the equations above.
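As a sanity check on the equations, here is a minimal PyTorch sketch of the SBO computation for a single masked token. All sizes and names (D, V, pos_emb, w1, w2, vocab_proj) are illustrative placeholders; the post's actual SBOHead implementation appears in the implementation section below.

```python
import torch
import torch.nn as nn

D, V, MAX_SPAN = 768, 32000, 10                     # illustrative hidden size, vocab size, span cap
pos_emb = nn.Embedding(MAX_SPAN + 1, D)             # relative position table p_1 ... p_10 (index 0 unused)
w1 = nn.Sequential(nn.Linear(3 * D, D, bias=False), nn.GELU(), nn.LayerNorm(D))  # W_1 + GeLU + LayerNorm
w2 = nn.Sequential(nn.Linear(D, D, bias=False), nn.GELU(), nn.LayerNorm(D))      # W_2 + GeLU + LayerNorm
vocab_proj = nn.Linear(D, V)                        # projection of y_i onto the vocabulary

x = torch.randn(12, D)        # encoder outputs x_0 ... x_11 for a 12-token sequence
s, e, i = 5, 8, 7             # masked span is x_5 ... x_8, we predict token x_7
p = pos_emb(torch.tensor(i - s + 1))                # p_3, the relative position of x_7
h0 = torch.cat([x[s - 1], x[e + 1], p], dim=-1)     # h_0 = [x_4; x_9; p_3]
y_i = w2(w1(h0))                                    # two GeLU + LayerNorm layers
logits = vocab_proj(y_i)                            # scores over the vocabulary for token x_7
```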
SpanBERT's objective includes not only the SBO loss but also the original MLM loss. To compute the MLM loss, however, no separate masking is applied to the sequence; the span masking applied for SBO is reused as-is. The only difference is the input representation: instead of $h_{0}$ from the SBO equations above, the encoder output $x_{i}$ of the $i$-th (masked) token is fed directly into the MLM head. Putting it together, SpanBERT's final loss for a masked token $x_{i}$ is:

\[\mathcal{L}(x_i) = \mathcal{L}_{\text{MLM}}(x_i) + \mathcal{L}_{\text{SBO}}(x_i) = -\log P(x_i \mid \mathbf{x}_i) - \log P(x_i \mid \mathbf{y}_i)\]

Meanwhile, unlike in ELECTRA, the two losses are on roughly the same scale, so no separate scaling coefficient seems to be applied.
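For completeness, here is a minimal sketch of how the two terms could be combined at training time, assuming both heads produce [batch, seq_len, vocab] logits and that non-masked positions in `labels` are set to -100 (the ignore_index convention used by the collator below). The function name is illustrative.

```python
import torch
import torch.nn.functional as F

def span_bert_loss(mlm_logits: torch.Tensor, sbo_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum of the MLM and SBO cross-entropy losses over masked positions only."""
    vocab = mlm_logits.size(-1)
    mlm_loss = F.cross_entropy(mlm_logits.view(-1, vocab), labels.view(-1), ignore_index=-100)
    sbo_loss = F.cross_entropy(sbo_logits.view(-1, vocab), labels.view(-1), ignore_index=-100)
    return mlm_loss + sbo_loss  # summed with equal weight, no extra scaling coefficient
```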
Implementation by PyTorch
I implemented SpanBERT in PyTorch by putting together the contents of the paper. The ideas in the paper were not hard to understand, but actually implementing the process of finding spans that satisfy the selection constraints and masking them turned out to be quite fiddly.

This post only explains how to build the inputs for SBO training and the SBOHead. If you are curious about BERT or Whole Word Masking, please see the earlier posts; the code for the full model architecture is available through the link here.

Please note that the published code is not yet fully vectorized, so it is not optimized for parallel computation on the GPU. I will apply vectorization and upload a revised version as soon as possible.
Span Masking Algorithm
import random
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from typing import Dict, List, Tuple, Any
from ..tuner.mlm import WholeWordMaskingCollator
from configuration import CFG
BPE = [
'RobertaTokenizerFast',
'GPT2TokenizerFast',
]
SPM = [
'DebertaV2TokenizerFast',
'DebertaTokenizerFast',
'XLMRobertaTokenizerFast',
]
WORDPIECE = [
'BertTokenizerFast',
'ElectraTokenizerFast',
]
def random_non_negative_integer(max_value: int) -> int:
return random.randint(0, max_value)
class SpanCollator(WholeWordMaskingCollator):
""" Custom Collator for Span Boundary Objective Task, which is used for span masking algorithm
Span Masking is similar to Whole Word Masking, but it has some differences:
        1) Span Masking does not keep 10% of the selected tokens unchanged, nor replace 10% with random vocab tokens
            - every selected token is simply replaced with the [MASK] token
    Algorithm:
        1) select spans of consecutive whole-word tokens to mask, with lengths drawn from a geometric distribution
        2) calculate a relative position embedding for each token inside a span
        3) compute the span boundary objective from the two span-boundary tokens and the position embeddings from step 2
Args:
cfg: configuration.CFG
masking_budget: masking budget for Span Masking
(default: 0.15 => Recommended by original paper)
span_probability: probability of span length for Geometric Distribution
(default: 0.2 => Recommended by original paper)
max_span_length: maximum span length of each span in one batch sequence
(default: 10 => Recommended by original paper)
References:
https://arxiv.org/pdf/1907.10529.pdf
"""
def __init__(
self,
cfg: CFG,
masking_budget: float = 0.15,
span_probability: float = 0.2,
max_span_length: int = 10
) -> None:
super(SpanCollator, self).__init__(cfg)
self.cfg = cfg
self.tokenizer = self.cfg.tokenizer
self.masking_budget = masking_budget
self.span_probability = span_probability
self.max_span_length = max_span_length
def _whole_word_mask(
self,
input_tokens: List[str],
max_predictions: int = CFG.max_seq
) -> List[int]:
"""
0) apply Whole Word Masking Algorithm for make gathering original token index in natural language
1) calculate number of convert into masking tokens with masking budget*len(input_tokens)
2) define span length of this iteration
- span length follow geometric distribution
- span length is limited by max_span_length
"""
cand_indexes = []
for i, token in enumerate(input_tokens):
if token == "[CLS]" or token == "[SEP]":
continue
if len(cand_indexes) >= 1 and self.select_post_string(token): # method from WholeWordMaskingCollator
cand_indexes[-1].append(i)
elif self.select_src_string(token): # method from WholeWordMaskingCollator
cand_indexes.append([i])
l = len(input_tokens)
src_l = len(cand_indexes)
num_convert_tokens = int(self.masking_budget * l)
        budget = num_convert_tokens  # ints are immutable, so plain assignment already gives an independent copy
masked_lms = []
covered_indexes = set()
while budget:
            span_length = max(1, min(self.max_span_length, int(torch.distributions.Geometric(probs=self.span_probability).sample())))  # sample span length from Geo(p), clipped to [1, max_span_length]
src_index = random_non_negative_integer(src_l - 1)
if span_length > budget:
if budget < 5: # Set the span length to budget to avoid a large number of iter if the remaining budget is too small
span_length = budget
else:
continue
if cand_indexes[src_index][0] + span_length > l - 1: # If the index of the last token in the span is outside the full sequence range
continue
if len(cand_indexes[src_index]) > span_length: # handling bad case: violating WWM algorithm at start
continue
span_token_index = cand_indexes[src_index][0] # init span token index: src token
while span_length:
if span_length == 0:
break
if span_token_index in covered_indexes: # If it encounters an index that is already masked, it ends, and starts the next iteration
break
else:
covered_indexes.add(span_token_index)
masked_lms.append(span_token_index)
span_length -= 1
budget -= 1
span_token_index += 1
continue
if len(covered_indexes) != len(masked_lms):
raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
return mask_labels
def get_mask_tokens(
self,
inputs: Tensor,
mask_labels: Tensor
) -> Tuple[Tensor, Tensor]:
""" Prepare masked tokens inputs/labels for Span Boundary Objective with MLM (15%),
All of masked tokens (15%) are replaced by [MASK] token,
Unlike BERT MLM which is replaced by random token or stay original token left
"""
labels = inputs.clone()
probability_matrix = mask_labels
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer.pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = probability_matrix.bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
return inputs, labels
def forward(self, batched: List[Dict[str, Tensor]]) -> Dict:
""" Abstract Method for Collator, you must implement this method in child class """
input_ids = [torch.as_tensor(x["input_ids"]) for x in batched]
padding_mask = [self.get_padding_mask(x) for x in input_ids]
padding_mask = pad_sequence(padding_mask, batch_first=True, padding_value=True)
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
mask_labels = []
for x in batched:
ref_tokens = []
for input_id in x["input_ids"]:
token = self.tokenizer._convert_id_to_token(input_id)
ref_tokens.append(token)
mask_labels.append(self._whole_word_mask(ref_tokens))
mask_labels = [torch.as_tensor(x) for x in mask_labels]
mask_labels = pad_sequence(mask_labels, batch_first=True, padding_value=0)
input_ids, labels = self.get_mask_tokens(
input_ids,
mask_labels
)
return {
"input_ids": input_ids,
"labels": labels,
"padding_mask": padding_mask,
"mask_labels": mask_labels
}
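Below is a hypothetical usage sketch of the collator with a standard DataLoader. `train_dataset` and the exact contents of `CFG` are placeholders, and it assumes the parent `WholeWordMaskingCollator` makes the instance callable (i.e. routes `__call__` to `forward`).

```python
from torch.utils.data import DataLoader

# train_dataset is a placeholder dataset whose items are dicts containing "input_ids";
# CFG is assumed to carry a HuggingFace fast tokenizer at cfg.tokenizer.
collator = SpanCollator(cfg=CFG, masking_budget=0.15, span_probability=0.2, max_span_length=10)
loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape, batch["labels"].shape, batch["mask_labels"].shape)
```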
SBO Head
class SBOHead(nn.Module):
""" Custom Head for Span Boundary Objective Task, this module return logit value for each token
we use z for class logit, each Fully Connected Layer doesn't have bias term in original paper
so we don't use bias term in this module => nn.Linear(bias=False)
You must select option for matrix sum or concatenate with x_s-1, x_e+1, p_i-s+1
If you select concatenate option, you must pass is_concatenate=True to cfg.is_concatenate, default is True
Math:
h_0 = [x_s-1;x_e+1;p_i-s+1]
h_t = LayerNorm(GELU(W_0โขh_0))
z = LayerNorm(GELU(W_1โขh_t))
Args:
cfg: configuration.CFG
is_concatenate: option for matrix sum or concatenate with x_s-1, x_e+1, p_i-s+1, default True
max_span_length: maximum span length of each span in one batch sequence
(default: 10 => Recommended by original paper)
References:
https://arxiv.org/pdf/1907.10529.pdf
"""
def __init__(
self,
cfg: CFG,
is_concatenate: bool = True,
max_span_length: int = 10
) -> None:
super(SBOHead, self).__init__()
self.cfg = cfg
self.is_concatenate = is_concatenate # for matrix sum or concatenate with x_s-1, x_e+1, p_i-s+1
self.projector = nn.Linear(self.cfg.dim_model, self.cfg.dim_model*3, bias=False) # for concatenate x_s-1, x_e+1, p_i-s+1
self.span_pos_emb = nn.Embedding(max_span_length*3, cfg.dim_model) # size of dim_model is research topic
self.head = nn.Sequential(
nn.Linear(self.cfg.dim_model*3, self.cfg.dim_ffn),
nn.GELU(),
nn.LayerNorm(self.cfg.dim_ffn),
nn.Linear(self.cfg.dim_ffn, self.cfg.dim_model, bias=False),
nn.GELU(),
nn.LayerNorm(self.cfg.dim_model),
)
self.classifier = nn.Linear(self.cfg.dim_model, self.cfg.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(cfg.vocab_size)) # for matching vocab size
self.classifier.bias = self.bias
@staticmethod
def find_consecutive_groups(mask_labels: Tensor, target_value: int = 1) -> List[List[Dict]]:
""" Get the start and end positions of consecutive groups in tensor for the target value
This method is used for SBO Objective Function, this version is not best performance to make span groups
Args:
mask_labels: masking tensor for span
target_value: target value for finding consecutive groups
"""
all_consecutive_groups = []
for mask_label in mask_labels:
consecutive_groups = []
current_group = None
for i, value in enumerate(mask_label):
if value == target_value:
if current_group is None:
current_group = {"start": i, "end": i}
else:
current_group["end"] = i
else:
if current_group is not None:
consecutive_groups.append(current_group)
current_group = None
if current_group is not None:
consecutive_groups.append(current_group)
all_consecutive_groups.append(consecutive_groups)
return all_consecutive_groups
def cal_span_emb(self, h: Tensor, hidden_states: Tensor, consecutive_groups: List[List[Dict]]) -> Tensor:
""" Calculate span embedding for each span in one batch sequence
Args:
h: hidden states, already passed through projection layer (dim*3)
hidden_states: hidden states from encoder
consecutive_groups: consecutive groups for each batch sequence
"""
for i, batch in enumerate(consecutive_groups): # batch level
for j, span in enumerate(batch): # span level
start, end = span["start"], span["end"]
length = end - start + 1
idx = torch.arange(length, device=self.cfg.device) # .to(self.cfg.device)
context_s, context_e = hidden_states[i, start - 1, :], hidden_states[i, end + 1, :]
span_pos_emb = self.span_pos_emb(idx).squeeze(0)
if length > 1:
for k, p_h in enumerate(span_pos_emb): # length of span_pos_emb == length of span of this iterations
h[i, start+k, :] = torch.cat([context_s, p_h, context_e], dim=0)
else:
h[i, start, :] = torch.cat([context_s, span_pos_emb, context_e], dim=0)
return h
def forward(self, hidden_states: Tensor, mask_labels: Tensor) -> Tensor:
consecutive_groups = self.find_consecutive_groups(mask_labels) # [batch, num_consecutive_groups]
h = self.projector(hidden_states) # [batch, seq, dim_model*3]
h_t = self.cal_span_emb(h, hidden_states, consecutive_groups)
z = self.head(h_t)
logit = self.classifier(z)
return logit
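Finally, a hypothetical sketch of computing the SBO loss from the head above. `encoder` stands in for any encoder language model returning [batch, seq_len, dim_model] hidden states (its call signature here is illustrative), and `batch` is the output of the SpanCollator.

```python
import torch.nn.functional as F

sbo_head = SBOHead(cfg=CFG, max_span_length=10)
hidden_states = encoder(batch["input_ids"], batch["padding_mask"])  # [batch, seq, dim_model], signature illustrative
sbo_logits = sbo_head(hidden_states, batch["mask_labels"])          # [batch, seq, vocab_size]
sbo_loss = F.cross_entropy(
    sbo_logits.view(-1, CFG.vocab_size),
    batch["labels"].view(-1),
    ignore_index=-100,  # only masked positions contribute, matching get_mask_tokens() above
)
```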