๐ชข assert len(optimizer_state[โfound_inf_per_deviceโ]) > 0, โNo inf checks were recorded for this optimizer.โ AssertionError: No inf checks were recorded for this optimizer.
๐ค Optimizer๊ฐ ์์ค๊ฐ์ ์ ๋๋ก Backward ํ ์ ์๋ ๋ฌธ์
ํ
์์ ๊ณ์ฐ ๊ทธ๋ํ๊ฐ ์ค๊ฐ์ ๋์ด์ ธ ์ตํฐ๋ง์ด์ ๊ฐ ๊ทธ๋ผ๋์ธํธ๋ฅผ ์ ๋๋ก Backward ํ์ง ๋ชปํด ๋ฐ์ํ๋ ์๋ฌ๋ค. ๊ณต๋ถ๋ฅผ ์์ํ๊ณ ์ ๋ง ์ฒ์ ๋ง์ฃผํ๋ ์๋ฌ๋ผ์ ์ ๋ง ๋ง์ด ๋นํฉํ๋ค. ๋ํผ๋ฐ์ค ์๋ฃ ์ญ์ ๊ฑฐ์ ์์ด์ ํด๊ฒฐํ๋๋ฐ ์ ๋ฅผ ๋จน์๋ ์ฐ๋ผ๋ฆฐ ์ฌ์ฐ์ด ์๋ ์๋ฌ๋ค. ์ด ๊ธ์ ์ฝ๋ ๋
์๋ผ๋ฉด ๋๋ถ๋ถ ํ
์์ ๊ณ์ฐ ๊ทธ๋ํ๊ฐ ์ค๊ฐ์ ๋์ด์ง๋ค๋ ๊ฒ์ด ๋ฌด์จ ์๋ฏธ์ผ์ง ์ดํดํ์์ง ๋ชปํ ๊ฑฐ๋ผ ์๊ฐํ๋ค. ๊ทธ๊ฒ ์ ์์ด๋ค. ํ์ ์ญ์ ์๊ณ ์ถ์ง ์์์ผ๋ ์์ฌ๋ง ๋ง๊ณ ๋ฉ์ฒญํ ํ์โฆ ์๊ฒ ๋์๋ค. ์๋ ์์ ์ฝ๋๋ฅผ ๋จผ์ ์ดํด๋ณด์.
# Before Append
def forward(self, inputs: dict, position_list: Tensor) -> Tensor:
outputs = self.feature(inputs)
feature = outputs.last_hidden_state
pred = []
for i in range(self.cfg.batch_size):
""" Apply Pooling & Fully Connected Layer for each unique cell in batch (one notebook_id) """
for idx in range(len(position_list[i])):
src, end = position_list[i][idx]
embedding = self.pooling(feature[i, src:end + 1, :].unsqueeze(dim=0)) # maybe don't need mask
logit = self.fc(embedding)
pred.append(logit)
pred = torch.as_tensor(pred, device=self.cfg.device)
return pred
# After Append
def forward(self, inputs: dict, position_list: Tensor) -> Tensor:
outputs = self.feature(inputs)
feature = outputs.last_hidden_state
pred = torch.tensor([], device=self.cfg.device)
for i in range(self.cfg.batch_size):
""" Apply Pooling & Fully Connected Layer for each unique cell in batch (one notebook_id) """
for idx in range(len(position_list[i])):
src, end = position_list[i][idx]
embedding = self.pooling(feature[i, src:end + 1, :].unsqueeze(dim=0)) # maybe don't need mask
logit = self.fc(embedding)
pred = torch.cat([pred, logit], dim=0)
return pred
๋ค์ ์ฝ๋๋ ํ์๊ฐ ๊ณต๋ถ๋ฅผ ์ํด ๋ง๋ ๋ชจ๋ธ ํด๋์ค์ forward ๋ฉ์๋์ด๋ค. ์ ์๋ ์ด๋ฒ ํฌ์คํ
์ ์ฃผ์ ์ธ ์๋ฌ๋ฅผ ์ผ์ผํจ ์ฃผ์ธ๊ณต์ด๊ณ , ํ์๋ ์๋ฌ๋ฅผ ์์ ํ ์ดํ ์ ์์ ์ผ๋ก ์๋ํ๋ ์ฝ๋๋ค. ๋
์ ์ฌ๋ฌ๋ถ๋ค๋ ๋ ์ฝ๋์ ์ด๋ค ์ฐจ์ด๊ฐ ์๋์ง ์ค์ค๋ก ์ง๋ฌธ์ ๋์ง๋ฉด์ ์ฝ์ด์ฃผ์๊ธธ ๋ฐ๋๋ค.
Modeling Overview
์์ ์ฝ๋๋ค์ DeBERTa-V3-Large ์ ๋ง์ง๋ง ์ธ์ฝ๋ ๋ ์ด์ด๊ฐ ๋ฐํํ๋ last_hidden_state ๋ฅผ ๋ฏธ๋ฆฌ ์ค์ ํ ์๋ธ ๊ตฌ๊ฐ๋ณ๋ก ๋๋๊ณ ๊ฐ๋ณ์ ์ผ๋ก pooling & fully connected layer ์ ํต๊ณผ์์ผ ๋ก์ง๊ฐ์ผ๋ก ๋ณํํ๊ธฐ ์ํด ๋ง๋ค์๋ค. ์ฝ๊ฒ ๋งํด ์
๋ ฅ์ผ๋ก ํ ํฐ(๋จ์ด) 384๊ฐ ์ง๋ฆฌ ๋ฌธ์ฅ์ ํ๋ ๋ฃ์๊ณ , ๋ชจ๋ธ์ 384๊ฐ์ ๊ฐ๋ณ ํ ํฐ์ ๋ํ ์๋ฒ ๋ฉ ๊ฐ์ ๋ฐํํ๋๋ฐ ๊ทธ๊ฒ์ ์ ๋ถ ์ด์ฉํ๋ ๊ฒ์ด ์๋๋ผ ์๋ฅผ ๋ค์ด 2๋ฒ~4๋ฒ ํ ํฐ์ 1๋ฒ ๊ตฌ๊ฐ, 6๋ฒ~20๋ฒ ํ ํฐ์ 2๋ฒ ๊ตฌ๊ฐ, 30๋ฒ~50๋ฒ ํ ํฐ์ 3๋ฒ ๊ตฌ๊ฐ โฆ 370๋ฒ~380๋ฒ ํ ํฐ์ 30๋ฒ ๊ตฌ๊ฐ์ผ๋ก ์ค์ ํ๊ณ ๊ตฌ๊ฐ ๋ณ๋ก ๋ฐ๋ก pooling & fully connected layer ์ ํต๊ณผ์์ผ ๋ก์ง์ ๊ตฌํ๋ ๊ฒ์ด๋ค. ์ผ๋ฐ์ ์ด๋ผ๋ฉด 1๊ฐ์ ๋ฌธ์ฅ์์ 1๊ฐ์ ์ต์ข
๋ก์ง๊ฐ์ด ๋์ถ๋๋ ๊ฒ์ด๋ผ๋ฉด, ์ ์ฝ๋๋ 30๊ฐ์ ๋ก์ง๊ฐ์ด ๋์ถ๋๋ค.
๐๏ธ ๋ด๊ฐ ํด๊ฒฐํ ๋ฐฉ๋ฒ
์ฝ๋ ์ดํด๋ฅผ ์ํ ์ค๋ช
์ ๋ง์ณค์ผ๋ ๋ณธ๊ฒฉ์ ์ผ๋ก ๋ณธ ์๋ฌ์ ์ด๋ค ์ฐ๊ด์ด ์๋์ง ์ดํด๋ณด์. Before ์ฝ๋๋ pred ๋ผ๋ ๋ฆฌ์คํธ์ ๊ฐ๋ณ ๊ตฌ๊ฐ์ ๋ํ ๋ก์ง๊ฐ์ append ํ๊ณ ๋ง์ง๋ง์ torch.as_tensor๋ฅผ ํ์ฉํด ํ
์๋ก ๋ณํํ๊ณ ์๋ค. ํํธ ํ์๋ pred ๋ฅผ ๊นกํต ํ
์๋ก ์ ์ธํ ๋ค, torch.cat์ผ๋ก ๋ชจ๋ ๊ตฌ๊ฐ์ ๋ํ ๋ก์ง๊ฐ์ ํ๋์ ํ
์ ๊ตฌ์กฐ์ฒด์ ๋ด๊ณ ์๋ค.
์ผํ๋ณด๋ฉด ํฌ๊ฒ ๋ค๋ฅธ์ ์ด ์์ด ๋ณด์ธ๋ค. ํ์ง๋ง ์ ์๋ ํ
์ ๊ตฌ์กฐ์ฒด๋ฅผ ์๋ก ์ ์ ํ๋ฉด์ torch.Tensor[[logit1], [logit2], โฆ.] ํํ๋ฅผ ๊ฐ๊ณ ํ์๋ torch.Tensor[logit1, logit2, โฆ] ํํ๋ฅผ ๊ฐ๋๋ค. ์๋ก ๋ค๋ฅธ ํ
์ ๊ตฌ์กฐ์ฒด๋ฅผ ๊ทธ๋๋ก ๋ชจ๋ธ ๊ฐ์ฒด์ forward ๋ฉ์๋ ๋ฐ loss function์ ํต๊ณผ์ํค๊ณ ์ค์ฐจ ์ญ์ ์ ํ๋ฉด ์ด๋ค ์ผ์ด ์๊ธฐ๋์ง ์ง๊ธ๋ถํฐ ์์๋ณด์.
์ ์์ ๊ฒฝ์ฐ๋ ๋์ถ๋ ์์คํจ์์ ๋ฏธ๋ถ๊ฐ์ด ์ ์๋ ๊ณ์ฐ ๊ทธ๋ํ๋ฅผ ํ๊ณ ์ญ์ ๋ ์ ์๋ค. ์ด์ ๋ ์ ์์ pred ๊ฐ forward ๋ฉ์๋ ๋ด๋ถ์์ ์๋ก์ด ์ ์ ๋์๊ธฐ ๋๋ฌธ์ด๋ค. ํ์ ์ญ์ ๋ง์ฐฌ๊ฐ์ง ์๋๊ฐ ์ถ์ ๊ฒ์ด๋ค. ํ์์ pred ์ญ์ forward ๋ฉ์๋ ๋ด๋ถ์์ ์ ์๋ ๊ฒ์ ๋ง์ง๋ง torch.cat์ ์ฌ์ฉํ๋ฉด์ ๊ตฌ๊ฐ์ ๋ก์ง๊ฐ๋ค ์์ ์๋ก์ด ์ฐจ์์ ๋ฎ์ด์ฐ๋๊ฒ์ด ์๋๊ฒ ๋๋ค. ์ด๊ฒ์ด ๋งค์ฐ ์ค์ํ ์ฐจ์ด๊ฐ ๋๋๋ฐ, ํ์์ ๊ฐ์ ํํ๊ฐ ๋๋ ๊ฒฝ์ฐ, ์์ค๊ฐ์ผ๋ก ๋ถํฐ Backward ๋๋ ๋ฏธ๋ถ๊ฐ๋ค์ด ๊ณง๋ฐ๋ก forward ๊ณผ์ ์์ ๊ธฐ๋ก๋ ์์ ์ ๊ณ์ฐ ๊ทธ๋ํ๋ก ์ฐพ์ ๊ฐ ์ ์๋ค. ํํธ ์ ์์ ๊ฒฝ์ฐ ์๋กญ๊ฒ ๋ฎ์ด ์ฐ์ฌ์ง ์ฐจ์ ๋๋ฌธ์ ๋ฏธ๋ถ๊ฐ๋ค์ด ์์ ์ ๊ณ์ฐ ๊ทธ๋ํ๋ก ์ฐพ์๊ฐ ์ ์๊ฒ ๋๋ค. ๋ฐ๋ผ์ ์ตํฐ๋ง์ด์ ๊ฐ ๋ ์ด์ Backward ๋ฅผ ์ํํ ์ ์์ด ์ ๋ชฉ๊ณผ ๊ฐ์ ์๋ฌ๋ฅผ ๋ฐํํ๊ฒ ๋๋ ๊ฒ์ด๋ค.
์ฒ์ ์ด ์๋ฌ๋ฅผ ๋ง์ฃผํ์ ๋๋ found_inf_per_device, No inf checks ๋ผ๋ ํค์๋์ ๊ฝํ (ํนํ inf) <RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output> ์ด๊ฒ๊ณผ ์ ์ฌํ ์ข
๋ฅ์ ์๋ฌ๋ผ ์๊ฐํ๊ณ ์ด์ฌํ ์ฐ์ฐ ๊ณผ์ ์ ๋ฌธ์ ๊ฐ ์๋์ง, ์ด๋์ NaN์ด ๋ฐ์ํ๋์ง, ํ์ต๋ฅ ์ ๋๋ฌด ํฌ๊ฒ ์ค์ ํ๋์ง ๋ฑ์ ๊ฒํ ํ๋ฉฐ ํ๋ฃจ๋ฅผ ๋ ๋ ธ์๋ ๊ธฐ์ต์ด ์๋ค.
Leave a comment