๐ชข assert len(optimizer_state[โfound_inf_per_deviceโ]) > 0, โNo inf checks were recorded for this optimizer.โ AssertionError: No inf checks were recorded for this optimizer.
๐ค Optimizer๊ฐ ์์ค๊ฐ์ ์ ๋๋ก Backward ํ ์ ์๋ ๋ฌธ์
ํ
์์ ๊ณ์ฐ ๊ทธ๋ํ๊ฐ ์ค๊ฐ์ ๋์ด์ ธ ์ตํฐ๋ง์ด์ ๊ฐ ๊ทธ๋ผ๋์ธํธ๋ฅผ ์ ๋๋ก Backward
ํ์ง ๋ชปํด ๋ฐ์ํ๋ ์๋ฌ๋ค. ๊ณต๋ถ๋ฅผ ์์ํ๊ณ ์ ๋ง ์ฒ์ ๋ง์ฃผํ๋ ์๋ฌ๋ผ์ ์ ๋ง ๋ง์ด ๋นํฉํ๋ค. ๋ํผ๋ฐ์ค ์๋ฃ ์ญ์ ๊ฑฐ์ ์์ด์ ํด๊ฒฐํ๋๋ฐ ์ ๋ฅผ ๋จน์๋ ์ฐ๋ผ๋ฆฐ ์ฌ์ฐ์ด ์๋ ์๋ฌ๋ค. ์ด ๊ธ์ ์ฝ๋ ๋
์๋ผ๋ฉด ๋๋ถ๋ถ ํ
์์ ๊ณ์ฐ ๊ทธ๋ํ๊ฐ ์ค๊ฐ์ ๋์ด์ง๋ค๋ ๊ฒ์ด ๋ฌด์จ ์๋ฏธ์ผ์ง ์ดํดํ์์ง ๋ชปํ ๊ฑฐ๋ผ ์๊ฐํ๋ค. ๊ทธ๊ฒ ์ ์์ด๋ค. ํ์ ์ญ์ ์๊ณ ์ถ์ง ์์์ผ๋ ์์ฌ๋ง ๋ง๊ณ ๋ฉ์ฒญํ ํ์โฆ ์๊ฒ ๋์๋ค. ์๋ ์์ ์ฝ๋๋ฅผ ๋จผ์ ์ดํด๋ณด์.
# Before Append
def forward(self, inputs: dict, position_list: Tensor) -> Tensor:
outputs = self.feature(inputs)
feature = outputs.last_hidden_state
pred = []
for i in range(self.cfg.batch_size):
""" Apply Pooling & Fully Connected Layer for each unique cell in batch (one notebook_id) """
for idx in range(len(position_list[i])):
src, end = position_list[i][idx]
embedding = self.pooling(feature[i, src:end + 1, :].unsqueeze(dim=0)) # maybe don't need mask
logit = self.fc(embedding)
pred.append(logit)
pred = torch.as_tensor(pred, device=self.cfg.device)
return pred
# After Append
def forward(self, inputs: dict, position_list: Tensor) -> Tensor:
outputs = self.feature(inputs)
feature = outputs.last_hidden_state
pred = torch.tensor([], device=self.cfg.device)
for i in range(self.cfg.batch_size):
""" Apply Pooling & Fully Connected Layer for each unique cell in batch (one notebook_id) """
for idx in range(len(position_list[i])):
src, end = position_list[i][idx]
embedding = self.pooling(feature[i, src:end + 1, :].unsqueeze(dim=0)) # maybe don't need mask
logit = self.fc(embedding)
pred = torch.cat([pred, logit], dim=0)
return pred
๋ค์ ์ฝ๋๋ ํ์๊ฐ ๊ณต๋ถ๋ฅผ ์ํด ๋ง๋ ๋ชจ๋ธ ํด๋์ค์ forward
๋ฉ์๋์ด๋ค. ์ ์๋ ์ด๋ฒ ํฌ์คํ
์ ์ฃผ์ ์ธ ์๋ฌ๋ฅผ ์ผ์ผํจ ์ฃผ์ธ๊ณต์ด๊ณ , ํ์๋ ์๋ฌ๋ฅผ ์์ ํ ์ดํ ์ ์์ ์ผ๋ก ์๋ํ๋ ์ฝ๋๋ค. ๋
์ ์ฌ๋ฌ๋ถ๋ค๋ ๋ ์ฝ๋์ ์ด๋ค ์ฐจ์ด๊ฐ ์๋์ง ์ค์ค๋ก ์ง๋ฌธ์ ๋์ง๋ฉด์ ์ฝ์ด์ฃผ์๊ธธ ๋ฐ๋๋ค.
Modeling Overview
์์ ์ฝ๋๋ค์ DeBERTa-V3-Large
์ ๋ง์ง๋ง ์ธ์ฝ๋ ๋ ์ด์ด๊ฐ ๋ฐํํ๋ last_hidden_state
๋ฅผ ๋ฏธ๋ฆฌ ์ค์ ํ ์๋ธ ๊ตฌ๊ฐ๋ณ๋ก ๋๋๊ณ ๊ฐ๋ณ์ ์ผ๋ก pooling & fully connected layer
์ ํต๊ณผ์์ผ ๋ก์ง๊ฐ์ผ๋ก ๋ณํํ๊ธฐ ์ํด ๋ง๋ค์๋ค. ์ฝ๊ฒ ๋งํด ์
๋ ฅ์ผ๋ก ํ ํฐ(๋จ์ด) 384๊ฐ ์ง๋ฆฌ ๋ฌธ์ฅ์ ํ๋ ๋ฃ์๊ณ , ๋ชจ๋ธ์ 384๊ฐ์ ๊ฐ๋ณ ํ ํฐ์ ๋ํ ์๋ฒ ๋ฉ ๊ฐ์ ๋ฐํํ๋๋ฐ ๊ทธ๊ฒ์ ์ ๋ถ ์ด์ฉํ๋ ๊ฒ์ด ์๋๋ผ ์๋ฅผ ๋ค์ด 2๋ฒ~4๋ฒ
ํ ํฐ์ 1๋ฒ ๊ตฌ๊ฐ, 6๋ฒ~20๋ฒ
ํ ํฐ์ 2๋ฒ ๊ตฌ๊ฐ, 30๋ฒ~50๋ฒ
ํ ํฐ์ 3๋ฒ ๊ตฌ๊ฐ โฆ 370๋ฒ~380๋ฒ
ํ ํฐ์ 30๋ฒ ๊ตฌ๊ฐ์ผ๋ก ์ค์ ํ๊ณ ๊ตฌ๊ฐ ๋ณ๋ก ๋ฐ๋ก pooling & fully connected layer
์ ํต๊ณผ์์ผ ๋ก์ง์ ๊ตฌํ๋ ๊ฒ์ด๋ค. ์ผ๋ฐ์ ์ด๋ผ๋ฉด 1๊ฐ์ ๋ฌธ์ฅ์์ 1๊ฐ์ ์ต์ข
๋ก์ง๊ฐ์ด ๋์ถ๋๋ ๊ฒ์ด๋ผ๋ฉด, ์ ์ฝ๋๋ 30๊ฐ์ ๋ก์ง๊ฐ์ด ๋์ถ๋๋ค.
๐๏ธ ๋ด๊ฐ ํด๊ฒฐํ ๋ฐฉ๋ฒ
์ฝ๋ ์ดํด๋ฅผ ์ํ ์ค๋ช
์ ๋ง์ณค์ผ๋ ๋ณธ๊ฒฉ์ ์ผ๋ก ๋ณธ ์๋ฌ์ ์ด๋ค ์ฐ๊ด์ด ์๋์ง ์ดํด๋ณด์. Before
์ฝ๋๋ pred
๋ผ๋ ๋ฆฌ์คํธ์ ๊ฐ๋ณ ๊ตฌ๊ฐ์ ๋ํ ๋ก์ง๊ฐ์ append
ํ๊ณ ๋ง์ง๋ง์ torch.as_tensor
๋ฅผ ํ์ฉํด ํ
์๋ก ๋ณํํ๊ณ ์๋ค. ํํธ ํ์๋ pred
๋ฅผ ๊นกํต ํ
์๋ก ์ ์ธํ ๋ค, torch.cat
์ผ๋ก ๋ชจ๋ ๊ตฌ๊ฐ์ ๋ํ ๋ก์ง๊ฐ์ ํ๋์ ํ
์ ๊ตฌ์กฐ์ฒด์ ๋ด๊ณ ์๋ค.
์ผํ๋ณด๋ฉด ํฌ๊ฒ ๋ค๋ฅธ์ ์ด ์์ด ๋ณด์ธ๋ค. ํ์ง๋ง ์ ์๋ ํ
์ ๊ตฌ์กฐ์ฒด๋ฅผ ์๋ก ์ ์ ํ๋ฉด์ torch.Tensor[[logit1], [logit2], โฆ.]
ํํ๋ฅผ ๊ฐ๊ณ ํ์๋ torch.Tensor[logit1, logit2, โฆ]
ํํ๋ฅผ ๊ฐ๋๋ค. ์๋ก ๋ค๋ฅธ ํ
์ ๊ตฌ์กฐ์ฒด๋ฅผ ๊ทธ๋๋ก ๋ชจ๋ธ ๊ฐ์ฒด์ forward
๋ฉ์๋ ๋ฐ loss function
์ ํต๊ณผ์ํค๊ณ ์ค์ฐจ ์ญ์ ์ ํ๋ฉด ์ด๋ค ์ผ์ด ์๊ธฐ๋์ง ์ง๊ธ๋ถํฐ ์์๋ณด์.
์ ์์ ๊ฒฝ์ฐ๋ ๋์ถ๋ ์์คํจ์์ ๋ฏธ๋ถ๊ฐ์ด ์ ์๋ ๊ณ์ฐ ๊ทธ๋ํ๋ฅผ ํ๊ณ ์ญ์ ๋ ์ ์๋ค. ์ด์ ๋ ์ ์์ pred
๊ฐ forward ๋ฉ์๋ ๋ด๋ถ์์ ์๋ก์ด ์ ์ ๋์๊ธฐ ๋๋ฌธ์ด๋ค. ํ์ ์ญ์ ๋ง์ฐฌ๊ฐ์ง ์๋๊ฐ ์ถ์ ๊ฒ์ด๋ค. ํ์์ pred
์ญ์ forward
๋ฉ์๋ ๋ด๋ถ์์ ์ ์๋ ๊ฒ์ ๋ง์ง๋ง torch.cat
์ ์ฌ์ฉํ๋ฉด์ ๊ตฌ๊ฐ์ ๋ก์ง๊ฐ๋ค ์์ ์๋ก์ด ์ฐจ์์ ๋ฎ์ด์ฐ๋๊ฒ์ด ์๋๊ฒ ๋๋ค. ์ด๊ฒ์ด ๋งค์ฐ ์ค์ํ ์ฐจ์ด๊ฐ ๋๋๋ฐ, ํ์์ ๊ฐ์ ํํ๊ฐ ๋๋ ๊ฒฝ์ฐ, ์์ค๊ฐ์ผ๋ก ๋ถํฐ Backward
๋๋ ๋ฏธ๋ถ๊ฐ๋ค์ด ๊ณง๋ฐ๋ก forward
๊ณผ์ ์์ ๊ธฐ๋ก๋ ์์ ์ ๊ณ์ฐ ๊ทธ๋ํ๋ก ์ฐพ์ ๊ฐ ์ ์๋ค. ํํธ ์ ์์ ๊ฒฝ์ฐ ์๋กญ๊ฒ ๋ฎ์ด ์ฐ์ฌ์ง ์ฐจ์ ๋๋ฌธ์ ๋ฏธ๋ถ๊ฐ๋ค์ด ์์ ์ ๊ณ์ฐ ๊ทธ๋ํ๋ก ์ฐพ์๊ฐ ์ ์๊ฒ ๋๋ค. ๋ฐ๋ผ์ ์ตํฐ๋ง์ด์ ๊ฐ ๋ ์ด์ Backward
๋ฅผ ์ํํ ์ ์์ด ์ ๋ชฉ๊ณผ ๊ฐ์ ์๋ฌ๋ฅผ ๋ฐํํ๊ฒ ๋๋ ๊ฒ์ด๋ค.
์ฒ์ ์ด ์๋ฌ๋ฅผ ๋ง์ฃผํ์ ๋๋ found_inf_per_device
, No inf checks
๋ผ๋ ํค์๋์ ๊ฝํ (ํนํ inf
) <RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output>
์ด๊ฒ๊ณผ ์ ์ฌํ ์ข
๋ฅ์ ์๋ฌ๋ผ ์๊ฐํ๊ณ ์ด์ฌํ ์ฐ์ฐ ๊ณผ์ ์ ๋ฌธ์ ๊ฐ ์๋์ง, ์ด๋์ NaN์ด ๋ฐ์ํ๋์ง, ํ์ต๋ฅ ์ ๋๋ฌด ํฌ๊ฒ ์ค์ ํ๋์ง ๋ฑ์ ๊ฒํ ํ๋ฉฐ ํ๋ฃจ๋ฅผ ๋ ๋ ธ์๋ ๊ธฐ์ต์ด ์๋ค.
Leave a comment