
πŸ”₯ NaN Values Popping Up During the PyTorch Backward Pass

This is the error I run into most often when defining custom models, pooling layers, metrics, and loss functions. Honestly, these days I see it far more often than CUDA OOM. The error means that the input passed to the LogSoftmax layer contains nan or inf values, so the computation cannot proceed. It has been the trickiest error to resolve in my deep learning experiments, because pinning down the cause is hard. And the reason it is hard is precisely that what we are doing is 'deep learning': the error usually comes from some operator behaving in a way we never intended, and there are far too many operators to debug one by one. On top of that, deep learning moves enormous matrices in and out of every layer, so it is not easy to even notice that a nan or inf value exists.
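To see the failure concretely, here is a minimal sketch (the tensor values are made up for illustration) of how a single nan entering LogSoftmax poisons both the forward output and the gradients:

import torch
import torch.nn as nn

logits = torch.tensor([[1.0, float('nan'), 2.0]], requires_grad=True)
log_probs = nn.LogSoftmax(dim=-1)(logits)  # the nan spreads across the whole row
log_probs.sum().backward()                 # the backward pass produces nan gradients too
print(log_probs)    # tensor([[nan, nan, nan]], grad_fn=<LogSoftmaxBackward0>)
print(logits.grad)  # tensor([[nan, nan, nan]])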

μœ„ μ—λŸ¬λŠ” ν•„μžμ˜ κ²½ν—˜μƒ λŒ€λΆ€λΆ„ μ»€μŠ€ν…€μœΌλ‘œ μ •μ˜ν•œ λ ˆμ΄μ–΄μ—μ„œ λ°œμƒν•˜λŠ” κ²½μš°κ°€ λ§Žμ•˜μœΌλ©° 특히 λΆ„μˆ˜, 각도, 제곱근, μ§€μˆ˜ κ°œλ…μ„ μ‚¬μš©ν•˜λŠ” μ—°μ‚°μžκ°€ λŒ€λΆ€λΆ„ μ›μΈμ΄μ—ˆλ‹€. 예λ₯Ό λ“€μ–΄ 코사인 μœ μ‚¬λ„λ₯Ό κ΅¬ν•˜λŠ” κ³Όμ •μ—μ„œ μ—°μ‚° λŒ€μƒ 벑터값에 zero-value κ°€ ν¬ν•¨λœ 경우 λΆ„λͺ¨κ°€ 0이 되기 λ•Œλ¬Έμ— μ—°μ‚° μ •μ˜κ°€ λ˜μ§€ μ•Šμ•„ nan 을 λ°˜ν™˜ν•΄ μœ„μ™€ 같은 μ—λŸ¬κ°€ λ°œμƒν•˜λŠ” κ²½μš°κ°€ μžˆλ‹€.

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig
from sentence_transformers.util import cos_sim  # assumed source of cos_sim (pairwise similarity matrix)

def check_nan(x: torch.Tensor) -> bool:
    """ Check if there is any NaN value in the tensor """
    return bool(torch.isnan(x).any())

def zero_filtering(x: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """
    Replace zero (and near-zero) values with eps, because the competition metric is cosine similarity
    and cosine similarity returns NaN when an input embedding contains zeros
    """
    # out-of-place clamp does the same as x[x <= eps] = eps without mutating the input in place
    return torch.clamp(x, min=eps)

def nan_filtering(x: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """
    Replace NaN values in the embedding with eps, because the competition metric is cosine similarity
    and cosine similarity returns NaN as soon as any input value is NaN
    """
    return torch.nan_to_num(x, nan=eps)

class CLIPGEMPooling(nn.Module):
    """
    Generalized Mean Pooling for Natural Language Processing
    This is the CLIP version of my GEMPooling class, transferred from NLP task code
    ViT doesn't use an attention mask, because every input image has the same shape

    Mean Pooling <= GEMPooling <= Max Pooling
    Because every token embedding is raised to the power p, GEMPooling effectively gives more weight
    to the more strongly activated tokens

    The original paper uses p=3, but this class uses an even exponent (defaulting to p=2), because
    torch.pow returns NaN for fractional powers of negative values: an even p keeps the intermediate
    tensor non-negative, so the final 1/p root is always well defined
    """
    def __init__(self, auto_cfg: AutoConfig) -> None:
        super().__init__()

    @staticmethod
    def forward(last_hidden_state: Tensor, p: int = 2) -> Tensor:
        """
        last_hidden_state.size: [batch_size, patches_sequence, hidden_size]
        1) raise last_hidden_state to the power p, then average over the sequence dimension
        2) raise sum_embeddings to the power 1/p
        """
        p_embeddings = zero_filtering(torch.pow(last_hidden_state, p))
        # Check NaN value in Embedding after applying torch.pow
        if check_nan(p_embeddings):
            p_embeddings = nan_filtering(p_embeddings)
        sum_embeddings = torch.mean(p_embeddings, 1)

        gem_embeddings = zero_filtering(torch.pow(sum_embeddings, 1. / p))
        # Check NaN value in Embedding after applying torch.pow
        if check_nan(gem_embeddings):
            gem_embeddings = nan_filtering(gem_embeddings)
        return gem_embeddings
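A quick usage sketch under assumed shapes (a ViT-style [batch, patches, hidden] tensor; auto_cfg is unused inside the class, so None is passed purely for illustration). For non-negative activations the result really does land between mean and max pooling, as the docstring claims:

pooling = CLIPGEMPooling(auto_cfg=None)
hidden = torch.rand(2, 196, 768)   # hypothetical ViT output: [batch, patches, hidden]
gem = pooling(hidden, p=2)         # -> [2, 768]

mean_pool = hidden.mean(dim=1)
max_pool = hidden.max(dim=1).values
print((mean_pool <= gem + 1e-3).all())  # tensor(True)
print((gem <= max_pool + 1e-3).all())   # tensor(True)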

class CLIPMultipleNegativeRankingLoss(nn.Module):
    """
    Multiple Negative Ranking Loss for the CLIP Model
    The main concept is the same as the original, but adapted so it also fits models that are not
    built on Sentence-Transformers
    The larger the batch size, the more in-batch negatives each anchor & positive pair gets
    Args:
        scale: the similarity scores are multiplied by this value; it acts as an inverse softmax
               temperature, since raw cosine similarities in [-1, 1] leave the softmax too flat
        similarity_fct: the distance metric to use, cosine similarity by default
    """
    def __init__(self, reduction: str, scale: float = 20.0, similarity_fct=cos_sim) -> None:
        super().__init__()
        self.reduction = reduction
        self.scale = scale
        self.similarity_fct = similarity_fct
        # reduction must be passed by keyword: the first positional arg of CrossEntropyLoss is class weights
        self.cross_entropy_loss = CrossEntropyLoss(reduction=self.reduction)

    def forward(self, embeddings_a: Tensor, embeddings_b: Tensor) -> Tensor:
        similarity_scores = zero_filtering(self.similarity_fct(embeddings_a, embeddings_b)) * self.scale
        # Check for NaN values in similarity_scores
        if check_nan(similarity_scores):
            similarity_scores = nan_filtering(similarity_scores)

        # each anchor's positive pair sits on the diagonal of the similarity matrix
        labels = torch.arange(len(similarity_scores), dtype=torch.long, device=similarity_scores.device)
        return self.cross_entropy_loss(similarity_scores, labels)
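And a hypothetical usage sketch for the loss: each row of the batch similarity matrix is one anchor, its positive sits on the diagonal, and every off-diagonal entry serves as an in-batch negative:

criterion = CLIPMultipleNegativeRankingLoss(reduction='mean')
image_emb = torch.randn(8, 512, requires_grad=True)  # hypothetical CLIP image embeddings
text_emb = torch.randn(8, 512)                       # hypothetical CLIP text embeddings
loss = criterion(image_emb, text_emb)                # 8x8 similarity matrix, 7 in-batch negatives per anchor
loss.backward()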

ν•„μžμ˜ 경우, 두 개의 μž…λ ₯ 행렬에 각각 sqrt() λ₯Ό μ μš©ν•˜κ³  두 ν–‰λ ¬μ˜ κ°œλ³„ μ›μ†Œ μ‚¬μ΄μ˜ 코사인 μœ μ‚¬λ„λ₯Ό ꡬ해야 ν–ˆλ˜ 적이 μžˆλ‹€. sqrt κ³Όμ •μ—μ„œ λ„ˆλ¬΄ μž‘μ€ 값듀이 μž…λ ₯으둜 λ“€μ–΄κ°€ underflow κ°€ λ°œμƒν•΄ 행렬에 zero-value κ°€ 생겼고, 이λ₯Ό λͺ¨λ₯Έμ±„ 코사인 μœ μ‚¬λ„λ₯Ό κ΅¬ν•˜λ‹€κ°€ ν•œμ°Έμ„ μœ„ μ—λŸ¬μ™€ μ‹Έμ› λ˜ 적이 μžˆλ‹€. 심지어 연산속도 ν–₯상을 μœ„ν•΄μ„œ torch.autocast 클래슀의 grad_scaler(float32 to float16) κΉŒμ§€ μ μš©ν•˜κ³  μžˆμ—ˆλ‹€.

πŸ–οΈ λ‚΄κ°€ ν•΄κ²°ν•œ 방법

If you are using sqrt or pow, I strongly recommend adding a suitable epsilon before or after the operation as needed, as in the example code above, to guard against underflow. Set the epsilon to match the floating point precision you are currently working with; with float32, 1e-6 seems to be the most common choice. Honestly, I still don't know exactly which value is ideal... For what it's worth, I have never had inf show up from overflow in my deep learning experiments.
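If you would rather derive the epsilon from the dtype than hard-code it, torch.finfo reports the machine epsilon and the smallest normal value for each precision:

import torch

print(torch.finfo(torch.float32).eps)   # ~1.19e-07
print(torch.finfo(torch.float16).eps)   # ~9.77e-04
print(torch.finfo(torch.float16).tiny)  # smallest normal value, ~6.10e-05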

If you add the epsilon indiscriminately before the problematic operation, even a tiny value can noticeably distort the result, depending on the kind of operation. So instead I defined custom functions that run the operation first, check the result for NaN, Inf, or zero values, and add the epsilon only to the entries where they actually appear, which solved the problem.
(see the example code above: check_nan, zero_filtering, nan_filtering)

Meanwhile, if you call torch.autograd.set_detect_anomaly(True) early in the training loop, execution stops the moment a NaN appears and PyTorch prints the operation that produced it. Definitely make use of it.
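A sketch of both ways to turn it on (model and batch below are placeholders for your own objects); note that anomaly detection adds real overhead, so disable it once the bug is found:

import torch

torch.autograd.set_detect_anomaly(True)  # global switch, set once before training

# or scope it to just the suspicious region:
with torch.autograd.detect_anomaly():
    loss = model(batch)
    loss.backward()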
