
πŸ“ κ°€λ³€ 길이의 ν…μ„œλ₯Ό λ°μ΄ν„°λ‘œλ”μ— μ „λ‹¬ν•˜λŠ” 경우

This error occurs when the tensors in the data instances returned through a custom dataset class and DataLoader are not of uniform size. It comes up frequently in natural language processing, and it can be resolved by adding collate_fn=collate to the parameter options when declaring the DataLoader object. The value (callable) passed to the collate_fn parameter must be defined by the user. The Hugging Face libraries also ship ready-made collate functions for common situations, so you can simply use those; in my case, I use a function and a class that I defined myself.
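For reference, here is a minimal sketch that reproduces the error with the default collate (the dataset and names are illustrative, and the exact error wording may vary by PyTorch version):

import torch
from torch.utils.data import DataLoader, Dataset

class VarLenDataset(Dataset):
    # toy dataset whose instances have different lengths
    def __init__(self):
        self.data = [torch.ones(n) for n in (3, 5, 7, 4)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

next(iter(DataLoader(VarLenDataset(), batch_size=2)))
# RuntimeError: stack expects each tensor to be equal size, but got [3] at entry 0 and [5] at entry 1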

# 데이터 λ‘œλ” μ˜ˆμ‹œ
loader_train = DataLoader(
            train_dataset,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            worker_init_fn=seed_worker,
            collate_fn=MiniBatchCollate,  # 여기에 μ‚¬μš©ν•˜λ €λŠ” collate function ν˜Ήμ€ 객체λ₯Ό μ „λ‹¬ν•˜μž!!
            generator=self.generator,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            drop_last=False,
        )

# Example collate class:
import torch
from torch import Tensor


class MiniBatchCollate(object):
    """
    Collate class for torch.utils.data.DataLoader
    Use this when your data instances have variable lengths, such as NLP text sequences.
    If you apply static padding with AutoTokenizer, you don't need this class,
    but with dynamic padding you must pass a callable like this to collate_fn.
    Args:
        batch: list of data instances returned by torch.utils.data.Dataset
    """
    def __call__(self, batch: list) -> tuple[tuple[dict[str, Tensor], ...], Tensor, Tensor]:
        # unzip the list of (inputs, labels, position_list) tuples into three sequences
        inputs, labels, position_list = zip(*batch)
        # pad labels and positions to the longest sequence in the mini-batch
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=-1
        )
        position_list = torch.nn.utils.rnn.pad_sequence(
            position_list,
            batch_first=True,
            padding_value=-1
        )
        return inputs, labels, position_list
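As a quick sanity check, here is a sketch of what the class above returns for dummy variable-length samples (the values are illustrative):

batch = [
    ({"input_ids": torch.tensor([1, 2, 3])}, torch.tensor([0, 1, 0]), torch.tensor([0, 2])),
    ({"input_ids": torch.tensor([4, 5])}, torch.tensor([1, 1]), torch.tensor([1])),
]
inputs, labels, position_list = MiniBatchCollate()(batch)
print(labels)         # tensor([[ 0,  1,  0], [ 1,  1, -1]])  -> padded to the longest label sequence
print(position_list)  # tensor([[ 0,  2], [ 1, -1]])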

def collate(inputs: dict[str, Tensor]) -> dict[str, Tensor]:
    """
    Slice every tensor in the mini-batch down to the longest real sequence in it, to speed up training.
    If you want to slice other variables such as label features, add parameters for them.
    This function should be called after the DataLoader returns a mini-batch instance.
    Args:
        inputs: dict of tensors with keys such as "input_ids", "attention_mask", "token_type_ids"
    """
    # longest non-padded sequence length in this mini-batch
    mask_len = int(inputs["attention_mask"].sum(dim=1).max())
    for k, v in inputs.items():
        inputs[k] = v[:, :mask_len]
    return inputs

collate is usually implemented as a plain function, but as in the code above, it can also be implemented as an object with __call__ defined inside. I too kept using the single-function form until recently, when I ran into a situation where I had to train different datasets and models within a single epoch; since then I have gone back to implementing it as an object, along the lines of the sketch below.
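A minimal sketch of that pattern (the class name and parameters here are illustrative, not my actual implementation): padding options become constructor state, so one class can be configured per dataset:

import torch

class ConfigurableCollate:
    # padding behavior is configured once per dataset via the constructor
    def __init__(self, padding_value: int = -1):
        self.padding_value = padding_value

    def __call__(self, batch):
        inputs, labels = zip(*batch)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=self.padding_value
        )
        return inputs, labels

# each DataLoader gets its own configured instance, e.g.
# loader_a = DataLoader(dataset_a, collate_fn=ConfigurableCollate(padding_value=-1), ...)
# loader_b = DataLoader(dataset_b, collate_fn=ConfigurableCollate(padding_value=0), ...)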

Meanwhile, the very last collate function in the example code was written under the assumption that the input sequences have already been padded up to a user-specified max_len with Hugging Face's AutoTokenizer.encode_plus. Rather than solving the error above, this function is for the case where the longest sequence in the mini-batch does not reach the user-specified max_len, yet padding was applied all the way to max_len anyway; truncating the unnecessary padding speeds up neural network training. It cannot be used to resolve the error in this post's title, but since collate came up, I thought I would cover it together. This function is not an argument to torch.utils.data.DataLoader; it is used inside the main training loop. In other words, call it after the DataLoader returns a mini-batch instance, as in the sketch below. A detailed discussion of padding strategies and the collate mechanism will follow in a separate post.
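As a sketch, assuming each mini-batch's inputs is a dict of tensors already padded to max_len (the loop variables are illustrative), the call site looks like this:

for inputs, labels, position_list in loader_train:
    inputs = collate(inputs)  # truncate unnecessary padding before the forward pass
    # ... forward / backward as usual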

In contrast, the MiniBatchCollate object is passed to the collate_fn argument of torch.utils.data.DataLoader. Since I use the dynamic padding technique, the instances inside a mini-batch often end up with different sequence lengths, and the DataLoader cannot stack data into a batch unless the lengths within the mini-batch are uniform. So, to unify the lengths per mini-batch, torch.nn.utils.rnn.pad_sequence is used: it pads every sequence in the input mini-batch to the length of the longest one. Note batch_first=True. If you set this argument to False, the batch dimension is placed in the middle rather than first. Most workflows put the batch dimension first, so be sure to set it to True.
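A quick shape check (the tensor shapes here are illustrative):

import torch

seqs = [torch.ones(3, 8), torch.ones(5, 8)]  # two sequences of lengths 3 and 5, hidden size 8
torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True).shape   # torch.Size([2, 5, 8]) -> (batch, seq, hidden)
torch.nn.utils.rnn.pad_sequence(seqs, batch_first=False).shape  # torch.Size([5, 2, 8]) -> (seq, batch, hidden)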
