
🤔 When the Optimizer Can't Properly Backward the Loss

ํ…์„œ์˜ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๊ฐ€ ์ค‘๊ฐ„์— ๋Š์–ด์ ธ ์˜ตํ‹ฐ๋งˆ์ด์ €๊ฐ€ ๊ทธ๋ผ๋””์–ธํŠธ๋ฅผ ์ œ๋Œ€๋กœ Backward ํ•˜์ง€ ๋ชปํ•ด ๋ฐœ์ƒํ•˜๋Š” ์—๋Ÿฌ๋‹ค. ๊ณต๋ถ€๋ฅผ ์‹œ์ž‘ํ•˜๊ณ  ์ •๋ง ์ฒ˜์Œ ๋งˆ์ฃผํ•˜๋Š” ์—๋Ÿฌ๋ผ์„œ ์ •๋ง ๋งŽ์ด ๋‹นํ™ฉํ–ˆ๋‹ค. ๋ž˜ํผ๋Ÿฐ์Šค ์ž๋ฃŒ ์—ญ์‹œ ๊ฑฐ์˜ ์—†์–ด์„œ ํ•ด๊ฒฐํ•˜๋Š”๋ฐ ์• ๋ฅผ ๋จน์—ˆ๋˜ ์“ฐ๋ผ๋ฆฐ ์‚ฌ์—ฐ์ด ์žˆ๋Š” ์—๋Ÿฌ๋‹ค. ์ด ๊ธ€์„ ์ฝ๋Š” ๋…์ž๋ผ๋ฉด ๋Œ€๋ถ€๋ถ„ ํ…์„œ์˜ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๊ฐ€ ์ค‘๊ฐ„์— ๋Š์–ด์ง„๋‹ค๋Š” ๊ฒƒ์ด ๋ฌด์Šจ ์˜๋ฏธ์ผ์ง€ ์ดํ•ดํ•˜์‹œ์ง€ ๋ชปํ• ๊ฑฐ๋ผ ์ƒ๊ฐํ•œ๋‹ค. ๊ทธ๊ฒŒ ์ •์ƒ์ด๋‹ค. ํ•„์ž ์—ญ์‹œ ์•Œ๊ณ  ์‹ถ์ง€ ์•Š์•˜์œผ๋‚˜ ์š•์‹ฌ๋งŒ ๋งŽ๊ณ  ๋ฉ์ฒญํ•œ ํƒ“์—โ€ฆ ์•Œ๊ฒŒ ๋˜์—ˆ๋‹ค. ์•„๋ž˜ ์˜ˆ์‹œ ์ฝ”๋“œ๋ฅผ ๋จผ์ € ์‚ดํŽด๋ณด์ž.

import torch
from torch import Tensor

# Before: collect logits in a Python list, then convert it with torch.as_tensor (breaks the graph)
def forward(self, inputs: dict, position_list: Tensor) -> Tensor:
    outputs = self.feature(inputs)
    feature = outputs.last_hidden_state
    pred = []
    for i in range(self.cfg.batch_size):
        # Apply pooling & fully connected layer to each unique cell in the batch (one notebook_id)
        for idx in range(len(position_list[i])):
            src, end = position_list[i][idx]
            embedding = self.pooling(feature[i, src:end + 1, :].unsqueeze(dim=0))  # maybe don't need mask
            logit = self.fc(embedding)
            pred.append(logit)
    # torch.as_tensor copies the logit values into a brand-new tensor with no grad_fn,
    # so the returned pred is disconnected from the computation graph
    pred = torch.as_tensor(pred, device=self.cfg.device)
    return pred

# After: concatenate each logit onto pred with torch.cat (keeps the graph)
def forward(self, inputs: dict, position_list: Tensor) -> Tensor:
    outputs = self.feature(inputs)
    feature = outputs.last_hidden_state
    pred = torch.tensor([], device=self.cfg.device)  # empty starting tensor
    for i in range(self.cfg.batch_size):
        # Apply pooling & fully connected layer to each unique cell in the batch (one notebook_id)
        for idx in range(len(position_list[i])):
            src, end = position_list[i][idx]
            embedding = self.pooling(feature[i, src:end + 1, :].unsqueeze(dim=0))  # maybe don't need mask
            logit = self.fc(embedding)
            # torch.cat is a differentiable op, so pred stays connected to the graph
            pred = torch.cat([pred, logit], dim=0)
    return pred

๋‹ค์Œ ์ฝ”๋“œ๋Š” ํ•„์ž๊ฐ€ ๊ณต๋ถ€๋ฅผ ์œ„ํ•ด ๋งŒ๋“  ๋ชจ๋ธ ํด๋ž˜์Šค์˜ forward ๋ฉ”์„œ๋“œ์ด๋‹ค. ์ „์ž๋Š” ์ด๋ฒˆ ํฌ์ŠคํŒ…์˜ ์ฃผ์ œ์ธ ์—๋Ÿฌ๋ฅผ ์ผ์œผํ‚จ ์ฃผ์ธ๊ณต์ด๊ณ , ํ›„์ž๋Š” ์—๋Ÿฌ๋ฅผ ์ˆ˜์ •ํ•œ ์ดํ›„ ์ •์ƒ์ ์œผ๋กœ ์ž‘๋™ํ•˜๋Š” ์ฝ”๋“œ๋‹ค. ๋…์ž ์—ฌ๋Ÿฌ๋ถ„๋“ค๋„ ๋‘ ์ฝ”๋“œ์— ์–ด๋–ค ์ฐจ์ด๊ฐ€ ์žˆ๋Š”์ง€ ์Šค์Šค๋กœ ์งˆ๋ฌธ์„ ๋˜์ง€๋ฉด์„œ ์ฝ์–ด์ฃผ์‹œ๊ธธ ๋ฐ”๋ž€๋‹ค.

[Figure: Modeling Overview]

์œ„์˜ ์ฝ”๋“œ๋“ค์€ DeBERTa-V3-Large ์˜ ๋งˆ์ง€๋ง‰ ์ธ์ฝ”๋” ๋ ˆ์ด์–ด๊ฐ€ ๋ฐ˜ํ™˜ํ•˜๋Š” last_hidden_state ๋ฅผ ๋ฏธ๋ฆฌ ์„ค์ •ํ•œ ์„œ๋ธŒ ๊ตฌ๊ฐ„๋ณ„๋กœ ๋‚˜๋ˆ„๊ณ  ๊ฐœ๋ณ„์ ์œผ๋กœ pooling & fully connected layer ์— ํ†ต๊ณผ์‹œ์ผœ ๋กœ์ง“๊ฐ’์œผ๋กœ ๋ณ€ํ™˜ํ•˜๊ธฐ ์œ„ํ•ด ๋งŒ๋“ค์—ˆ๋‹ค. ์‰ฝ๊ฒŒ ๋งํ•ด ์ž…๋ ฅ์œผ๋กœ ํ† ํฐ(๋‹จ์–ด) 384๊ฐœ ์งœ๋ฆฌ ๋ฌธ์žฅ์„ ํ•˜๋‚˜ ๋„ฃ์—ˆ๊ณ , ๋ชจ๋ธ์€ 384๊ฐœ์˜ ๊ฐœ๋ณ„ ํ† ํฐ์— ๋Œ€ํ•œ ์ž„๋ฒ ๋”ฉ ๊ฐ’์„ ๋ฐ˜ํ™˜ํ–ˆ๋Š”๋ฐ ๊ทธ๊ฒƒ์„ ์ „๋ถ€ ์ด์šฉํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ ์˜ˆ๋ฅผ ๋“ค์–ด 2๋ฒˆ~4๋ฒˆ ํ† ํฐ์„ 1๋ฒˆ ๊ตฌ๊ฐ„, 6๋ฒˆ~20๋ฒˆ ํ† ํฐ์„ 2๋ฒˆ ๊ตฌ๊ฐ„, 30๋ฒˆ~50๋ฒˆ ํ† ํฐ์„ 3๋ฒˆ ๊ตฌ๊ฐ„ โ€ฆ 370๋ฒˆ~380๋ฒˆ ํ† ํฐ์„ 30๋ฒˆ ๊ตฌ๊ฐ„์œผ๋กœ ์„ค์ •ํ•˜๊ณ  ๊ตฌ๊ฐ„ ๋ณ„๋กœ ๋”ฐ๋กœ pooling & fully connected layer ์— ํ†ต๊ณผ์‹œ์ผœ ๋กœ์ง“์„ ๊ตฌํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ์ผ๋ฐ˜์ ์ด๋ผ๋ฉด 1๊ฐœ์˜ ๋ฌธ์žฅ์—์„œ 1๊ฐœ์˜ ์ตœ์ข… ๋กœ์ง“๊ฐ’์ด ๋„์ถœ๋˜๋Š” ๊ฒƒ์ด๋ผ๋ฉด, ์œ„ ์ฝ”๋“œ๋Š” 30๊ฐœ์˜ ๋กœ์ง“๊ฐ’์ด ๋„์ถœ๋œ๋‹ค.

๐Ÿ–๏ธ ๋‚ด๊ฐ€ ํ•ด๊ฒฐํ•œ ๋ฐฉ๋ฒ•

Now that the code has been explained, let's see how it relates to this error. The Before code appends each span's logit to a list called pred, then converts that list into a tensor with torch.as_tensor at the end. The After code instead declares pred as an empty tensor and uses torch.cat to gather the logits of all spans into a single tensor.
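The After code grows pred one torch.cat at a time. A common variant keeps the Python list from the Before code and joins it once with torch.cat instead of torch.as_tensor. The sketch below compares the two. It assumes a toy `nn.Linear(8, 1)` standing in for `self.fc` and random `[1, 8]` tensors standing in for the pooled span embeddings.

```python
import torch
from torch import nn

torch.manual_seed(0)
fc = nn.Linear(8, 1)  # toy stand-in for self.fc
embeddings = [torch.randn(1, 8) for _ in range(5)]  # stand-ins for pooled span embeddings

# Variant 1: the post's After approach, growing pred with torch.cat on every step
pred_incremental = torch.tensor([])
for emb in embeddings:
    pred_incremental = torch.cat([pred_incremental, fc(emb)], dim=0)

# Variant 2: keep the list, but join it once with torch.cat instead of torch.as_tensor
logits = [fc(emb) for emb in embeddings]
pred_once = torch.cat(logits, dim=0)

print(pred_incremental.shape, pred_once.shape)  # torch.Size([5, 1]) torch.Size([5, 1])
print(torch.allclose(pred_incremental, pred_once))  # True
```

Both variants stay attached to the graph. Joining once avoids reallocating and copying the whole pred tensor on every iteration, which matters as the number of spans grows.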

At first glance the two look almost the same. The difference is in how the final pred tensor is built. The former constructs a brand-new tensor from the values stored in the list. The latter produces pred by applying an operation, torch.cat, directly to the logits themselves. Let's look at what happens when each result is passed through the model's forward method and the loss function, and then backpropagated.

์ „์ž์˜ ๊ฒฝ์šฐ๋Š” ๋„์ถœ๋œ ์†์‹คํ•จ์ˆ˜์˜ ๋ฏธ๋ถ„๊ฐ’์ด ์ •์˜๋œ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋ฅผ ํƒ€๊ณ  ์—ญ์ „๋  ์ˆ˜ ์—†๋‹ค. ์ด์œ ๋Š” ์ „์ž์˜ pred ๊ฐ€ forward ๋ฉ”์„œ๋“œ ๋‚ด๋ถ€์—์„œ ์ƒˆ๋กœ์ด ์ •์˜ ๋˜์—ˆ๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ํ›„์ž ์—ญ์‹œ ๋งˆ์ฐฌ๊ฐ€์ง€ ์•„๋‹Œ๊ฐ€ ์‹ถ์„ ๊ฒƒ์ด๋‹ค. ํ›„์ž์˜ pred ์—ญ์‹œ forward ๋ฉ”์„œ๋“œ ๋‚ด๋ถ€์—์„œ ์ •์˜๋œ ๊ฒƒ์€ ๋งž์ง€๋งŒ torch.cat์„ ์‚ฌ์šฉํ•˜๋ฉด์„œ ๊ตฌ๊ฐ„์˜ ๋กœ์ง“๊ฐ’๋“ค ์œ„์— ์ƒˆ๋กœ์ด ์ฐจ์›์„ ๋ฎ์–ด์“ฐ๋Š”๊ฒƒ์ด ์•„๋‹ˆ๊ฒŒ ๋œ๋‹ค. ์ด๊ฒƒ์ด ๋งค์šฐ ์ค‘์š”ํ•œ ์ฐจ์ด๊ฐ€ ๋˜๋Š”๋ฐ, ํ›„์ž์™€ ๊ฐ™์€ ํ˜•ํƒœ๊ฐ€ ๋˜๋Š” ๊ฒฝ์šฐ, ์†์‹ค๊ฐ’์œผ๋กœ ๋ถ€ํ„ฐ Backward ๋˜๋Š” ๋ฏธ๋ถ„๊ฐ’๋“ค์ด ๊ณง๋ฐ”๋กœ forward ๊ณผ์ •์—์„œ ๊ธฐ๋ก๋œ ์ž์‹ ์˜ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋กœ ์ฐพ์•„ ๊ฐˆ ์ˆ˜ ์žˆ๋‹ค. ํ•œํŽธ ์ „์ž์˜ ๊ฒฝ์šฐ ์ƒˆ๋กญ๊ฒŒ ๋ฎ์–ด ์“ฐ์—ฌ์ง„ ์ฐจ์› ๋•Œ๋ฌธ์— ๋ฏธ๋ถ„๊ฐ’๋“ค์ด ์ž์‹ ์˜ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋กœ ์ฐพ์•„๊ฐˆ ์ˆ˜ ์—†๊ฒŒ ๋œ๋‹ค. ๋”ฐ๋ผ์„œ ์˜ตํ‹ฐ๋งˆ์ด์ €๊ฐ€ ๋” ์ด์ƒ Backward ๋ฅผ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์—†์–ด ์ œ๋ชฉ๊ณผ ๊ฐ™์€ ์—๋Ÿฌ๋ฅผ ๋ฐ˜ํ™˜ํ•˜๊ฒŒ ๋˜๋Š” ๊ฒƒ์ด๋‹ค.

์ฒ˜์Œ ์ด ์—๋Ÿฌ๋ฅผ ๋งˆ์ฃผํ–ˆ์„ ๋•Œ๋Š” found_inf_per_device, No inf checks ๋ผ๋Š” ํ‚ค์›Œ๋“œ์— ๊ฝ‚ํ˜€ (ํŠนํžˆ inf) <RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output> ์ด๊ฒƒ๊ณผ ์œ ์‚ฌํ•œ ์ข…๋ฅ˜์˜ ์—๋Ÿฌ๋ผ ์ƒ๊ฐํ•˜๊ณ  ์—ด์‹ฌํžˆ ์—ฐ์‚ฐ ๊ณผ์ •์— ๋ฌธ์ œ๊ฐ€ ์—†๋Š”์ง€, ์–ด๋””์„œ NaN์ด ๋ฐœ์ƒํ•˜๋Š”์ง€, ํ•™์Šต๋ฅ ์„ ๋„ˆ๋ฌด ํฌ๊ฒŒ ์„ค์ •ํ–ˆ๋Š”์ง€ ๋“ฑ์„ ๊ฒ€ํ† ํ•˜๋ฉฐ ํ•˜๋ฃจ๋ฅผ ๋‚ ๋ ธ์—ˆ๋˜ ๊ธฐ์–ต์ด ์žˆ๋‹ค.
