๐ฅย Pytorch Tensor Indexing ์์ฃผ ์ฌ์ฉํ๋ ๋ฉ์๋ ๋ชจ์์ง
ํ์ดํ ์น์์ ํ์๊ฐ ์์ฃผ ์ฌ์ฉํ๋ ํ ์ ์ธ๋ฑ์ฑ ๊ด๋ จ ๋ฉ์๋์ ์ฌ์ฉ๋ฒ ๋ฐ ์ฌ์ฉ ์์๋ฅผ ํ๋ฐฉ์ ์ ๋ฆฌํ ํฌ์คํธ๋ค. ๋ฉ์๋ ํ๋๋น ํ๋์ ํฌ์คํธ๋ก ๋ง๋ค๊ธฐ์๋ ๋๋ฌด ๊ธธ์ด๊ฐ ์งง๋ค ์๊ฐํด ํ ํ์ด์ง์ ๋ชจ๋ ๋ฃ๊ฒ ๋์๋ค. ์ง์์ ์ผ๋ก ์ ๋ฐ์ดํธ ๋ ์์ ์ด๋ค. ๋ํ ํ ์ ์ธ๋ฑ์ฑ ๋ง๊ณ ๋ ๋ค๋ฅธ ์ฃผ์ ๋ก๋ ๊ด๋ จ ๋ฉ์๋๋ฅผ ์ ๋ฆฌํด ์ฌ๋ฆด ์์ ์ด๋ ๋ง์ ๊ด์ฌ ๋ถํ๋๋ฆฐ๋ค.
๐ย torch.argmax
์ ๋ ฅ ํ ์์์ ๊ฐ์ฅ ํฐ ๊ฐ์ ๊ฐ๊ณ ์๋ ์์์ ์ธ๋ฑ์ค๋ฅผ ๋ฐํํ๋ค. ์ต๋๊ฐ์ ์ฐพ์ ์ฐจ์์ ์ง์ ํด์ค ์ ์๋ค. ์๋ ์์ ์ฝ๋๋ฅผ ํ์ธํด๋ณด์.
# torch.argmax params
torch.argmax(tensor, dim=None, keepdim=False)
# torch.argmax example 1
test = torch.tensor([1,29,2,45,22,3])
torch.argmax(test)
torch.argmax(test2, keepdim=True)
<Result>
tensor(3)
# torch.argmax example 2
test = torch.tensor([[4, 2, 3],
[4, 5, 6]])
torch.argmax(test, dim=0, keepdim=True)
<Result>
tensor([[0, 1, 1]])
# torch.argmax example 3
test = torch.tensor([[4, 2, 3],
[4, 5, 6]])
torch.argmax(test, dim=1, keepdim=True)
tensor([[0],
[2]])
dim
๋งค๊ฐ๋ณ์์ ์ํ๋ ์ฐจ์์ ์
๋ ฅํ๋ฉด ํด๋น ์ฐจ์ ๋ทฐ์์ ๊ฐ์ฅ ํฐ ์์๋ฅผ ์ฐพ์ ์ธ๋ฑ์ค ๊ฐ์ ๋ฐํํด์ค ๊ฒ์ด๋ค. ์ด ๋ keepdim=True
๋ก ์ค์ ํ๋ค๋ฉด ์
๋ ฅ ์ฐจ์์์ ๊ฐ์ฅ ํฐ ์์์ ์ธ๋ฑ์ค๋ฅผ ๋ฐํํ๋ ์๋ณธ ํ
์์ ์ฐจ์๊ณผ ๋์ผํ ํํ๋ก ์ถ๋ ฅํด์ค๋ค. example 2
์ ๊ฒฝ์ฐ dim=0
๋ผ์ ํ์ด ๋์ ๋ ๋ฐฉํฅ์ผ๋ก ํ
์๋ฅผ ๋ฐ๋ผ๋ด์ผ ํ๋ค. ํ์ด ๋์ ๋ ๋ฐฉํฅ์ผ๋ก ํ
์๋ฅผ ๋ณด๊ฒ ๋๋ฉด tensor([[0, 1, 1]])
์ด ๋๋ค.
๐ย torch.stack
"""
torch.stack
Args:
tensors(sequence of Tensors): ํ
์๊ฐ ๋ด๊ธด ํ์ด์ฌ ์ํ์ค ๊ฐ์ฒด
dim(int): ์ถ๊ฐํ ์ฐจ์ ๋ฐฉํฅ์ ์ธํ
, ๊ธฐ๋ณธ๊ฐ์ 0
"""
torch.stack(tensors, dim=0)
๋งค๊ฐ๋ณ์๋ก ์ฃผ์ด์ง ํ์ด์ฌ ์ํ์ค ๊ฐ์ฒด(๋ฆฌ์คํธ, ํํ)๋ฅผ ์ฌ์ฉ์๊ฐ ์ง์ ํ ์๋ก์ด ์ฐจ์์ ์๋ ๊ธฐ๋ฅ์ ํ๋ค. ๋งค๊ฐ๋ณ์ tensors
๋ ํ
์๊ฐ ๋ด๊ธด ํ์ด์ฌ์ ์ํ์ค ๊ฐ์ฒด๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ๋๋ค. dim
์ ์ฌ์ฉ์๊ฐ ํ
์ ์ ์ฌ๋ฅผ ํ๊ณ ์ถ์ ์๋ก์ด ์ฐจ์์ ์ง์ ํด์ฃผ๋ฉด ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ 0์ฐจ์์ผ๋ก ์ง์ ๋์ด์์ผ๋ฉฐ, ํ
์์ ๋งจ ์์ฐจ์์ด ์๋กญ๊ฒ ์๊ธฐ๊ฒ ๋๋ค. torch.stack
์ ๊ธฐ๊ณํ์ต, ํนํ ๋ฅ๋ฌ๋์์ ์ ๋ง ์์ฃผ ์ฌ์ฉ๋๊ธฐ ๋๋ฌธ์ ์ฌ์ฉ๋ฒ ๋ฐ ์ฌ์ฉ์ํฉ์ ์ตํ๋๋ฉด ๋์์ด ๋๋ค. ์์๋ฅผ ํตํด ํด๋น ๋ฉ์๋๋ฅผ ์ด๋ค ์ํฉ์์ ์ด๋ป๊ฒ ์ฌ์ฉํ๋์ง ์์๋ณด์.
""" torch.stack example """
class Projector(nn.Module):
"""
Making projection matrix(Q, K, V) for each attention head
When you call this class, it returns projection matrix of each attention head
For example, if you call this class with 8 heads, it returns 8 set of projection matrices (Q, K, V)
Args:
num_heads: number of heads in MHA, default 8
dim_head: dimension of each attention head, default 64
"""
def __init__(self, num_heads: int = 8, dim_head: int = 64) -> None:
super(Projector, self).__init__()
self.dim_model = num_heads * dim_head
self.num_heads = num_heads
self.dim_head = dim_head
def __call__(self):
fc_q = nn.Linear(self.dim_model, self.dim_head)
fc_k = nn.Linear(self.dim_model, self.dim_head)
fc_v = nn.Linear(self.dim_model, self.dim_head)
return fc_q, fc_k, fc_v
num_heads = 8
dim_head = 64
projector = Projector(num_heads, dim_head) # init instance
projector_list = [list(projector()) for _ in range(num_heads)] # call instance
x = torch.rand(10, 512, 512) # x.shape: [Batch_Size, Sequence_Length, Dim_model]
Q, K, V = [], [], []
for i in range(num_heads):
Q.append(projector_list[i][0](x)) # [10, 512, 64]
K.append(projector_list[i][1](x)) # [10, 512, 64]
V.append(projector_list[i][2](x)) # [10, 512, 64]
Q = torch.stack(Q, dim=1) # Q.shape: [10, 8, 512, 64]
K = torch.stack(K, dim=1) # K.shape: [10, 8, 512, 64]
V = torch.stack(V, dim=1) # V.shape: [10, 8, 512, 64]
์ ์ฝ๋๋ Transformer
์ Multi-Head Attention
๊ตฌํ์ฒด ์ผ๋ถ๋ฅผ ๋ฐ์ทํด์จ ๊ฒ์ด๋ค. Multi-Head Attention
์ ๊ฐ๋ณ ์ดํ
์
ํด๋๋ณ๋ก ํ๋ ฌ $Q, K, V$๋ฅผ ๊ฐ์ ธ์ผ ํ๋ค. ๋ฐ๋ผ์ ์
๋ ฅ ์๋ฒ ๋ฉ์ ๊ฐ๋ณ ์ดํ
์
ํค๋์ Linear Combination
ํด์ค์ผ ํ๋๋ฐ ํค๋ ๊ฐ์๊ฐ 8๊ฐ๋ ๋๊ธฐ ๋๋ฌธ์ ๊ฐ๋ณ์ ์ผ๋ก Projection Matrix
๋ฅผ ์ ์ธํด์ฃผ๋ ๊ฒ์ ๋งค์ฐ ๋นํจ์จ์ ์ด๋ค. ๋ฐ๋ผ์ ๊ฐ์ฒด Projector
์ ํ๋ ฌ $Q, K, V$์ ๋ํ Projection Matrix
๋ฅผ ์ ์ํด์คฌ๋ค. ์ดํ ํค๋ ๊ฐ์๋งํผ ๊ฐ์ฒด Projector
๋ฅผ ํธ์ถํด ๋ฆฌ์คํธ์ ํด๋๋ณ Projection Matrix
๋ฅผ ๋ด์์ค๋ค. ๊ทธ ๋ค์ torch.stack
์ ์ฌ์ฉํด Attention Head
๋ฐฉํฅ์ ์ฐจ์์ผ๋ก ๋ฆฌ์คํธ ๋ด๋ถ ํ
์๋ค์ ์์์ฃผ๋ฉด ๋๋ค.
๐ขย torch.arange
์ฌ์ฉ์๊ฐ ์ง์ ํ ์์์ ๋ถํฐ ๋์ ๊น์ง ์ผ์ ํ ๊ฐ๊ฒฉ์ผ๋ก ํ
์๋ฅผ ๋์ดํ๋ค. Python์ ๋ด์ฅ ๋ฉ์๋ range
์ ๋์ผํ ์ญํ ์ ํ๋๋ฐ, ๋์ ํ
์ ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ํ
์ ๊ตฌ์กฐ์ฒด๋ก ๋ฐํํ๋ค๊ณ ์๊ฐํ๋ฉด ๋๊ฒ ๋ค.
# torch.arange usage
torch.arange(start=0, end, step=1)
>>> torch.arange(5)
tensor([ 0, 1, 2, 3, 4])
>>> torch.arange(1, 4)
tensor([ 1, 2, 3])
>>> torch.arange(1, 2.5, 0.5)
tensor([ 1.0000, 1.5000, 2.0000])
step
๋งค๊ฐ๋ณ์๋ก ์์๊ฐ ๊ฐ๊ฒฉ ์กฐ์ ์ ํ ์ ์๋๋ฐ, ๊ธฐ๋ณธ์ 1๋ก ์ง์ ๋์ด ์์ผ๋ ์ฐธ๊ณ ํ์. ํ์์ ๊ฒฝ์ฐ์๋ nn.Embedding
์ ์
๋ ฅ ํ
์๋ฅผ ๋ง๋ค ๋ ๊ฐ์ฅ ๋ง์ด ์ฌ์ฉํ๋ค. nn.Embedding
์ ๊ฒฝ์ฐ Input์ผ๋ก IntTensor
, LongTensor
๋ฅผ ๋ฐ๊ฒ ๋์ด ์์ผ๋ ์์๋์.
๐ย torch.repeat
์
๋ ฅ๊ฐ์ผ๋ก ์ฃผ์ด์ง ํ
์๋ฅผ ์ฌ์ฉ์๊ฐ ์ง์ ํ ๋ฐ๋ณต ํ์๋งํผ ํน์ ์ฐจ์ ๋ฐฉํฅ์ผ๋ก ๋๋ฆฐ๋ค. ์๋ฅผ ๋ค๋ฉด [1,2,3] * 3
์ ๊ฒฐ๊ณผ๋ [1, 2, 3, 1, 2, 3, 1, 2, 3]
์ธ๋ฐ, ์ด๊ฒ์ ์ฌ์ฉ์๊ฐ ์ง์ ํ ๋ฐ๋ณต ํ์๋งํผ ํน์ ์ฐจ์์ผ๋ก ์ํํ๊ฒ ๋ค๋ ๊ฒ์ด๋ค. ์๋ ์ฌ์ฉ ์์ ๋ฅผ ํ์ธํด๋ณด์.
# torch.repeat example
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]])
>>> x.repeat(4, 2, 1)
tensor([[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]]])
>>> x.repeat(4, 2, 1).size
torch.Size([4, 2, 3])
>>> x.repeat(4, 2, 2)
tensor([[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]],
[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]],
[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]],
[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]]])
$t$๋ฅผ ์ด๋ค ํ
์ ๊ตฌ์กฐ์ฒด $x$์ ์ต๋ ์ฐจ์์ด๋ผ๊ณ ํ์ , $x_t$๋ฅผ ๊ฐ์ฅ ์ผ์ชฝ์ ๋ฃ๊ณ ๊ฐ์ฅ ๋ฎ์ ์ฐจ์์ธ 0์ฐจ์์ ๋ํ ๋ฐ๋ณต ํ์๋ฅผ ์ค๋ฅธ์ชฝ ๋์ ๋์
ํด์ ์ฌ์ฉํ๋ฉด ๋๋ค. (torch.repeat(
$x_t, x_{t-1}, โฆ x_2, x_1, x_0$))
.
# torch.arange & torch.repeate usage example
>>> pos_x = torch.arange(self.num_patches + 1).repeat(inputs.shape[0]).to(inputs)
>>> pos_x.shape
torch.tensor([16, 1025])
ํ์์ ๊ฒฝ์ฐ, position embedding
์ ์
๋ ฅ์ ๋ง๋ค๊ณ ์ถ์ ๋ torch.arange
์ ์ฐ๊ณํด ์์ฃผ ์ฌ์ฉ ํ๋ ๊ฒ ๊ฐ๋ค. ์ ์ฝ๋๋ฅผ ์ฐธ๊ณ ํ์.
๐ฌย torch.clamp
์ ๋ ฅ ํ ์์ ์์๊ฐ์ ์ฌ์ฉ์๊ฐ ์ง์ ํ ์ต๋โข์ต์๊ฐ ๋ฒ์ ์ด๋ด๋ก ์ ํํ๋ ๋ฉ์๋๋ค.
# torch.clamp params
>>> torch.clamp(input, min=None, max=None, *, out=None) โ Tensor
# torch.clamp usage example
>>> a = torch.randn(4)
>>> a
tensor([-1.7120, 0.1734, -0.0478, -0.0922])
>>> torch.clamp(a, min=-0.5, max=0.5)
tensor([-0.5000, 0.1734, -0.0478, -0.0922])
์
๋ ฅ๋ ํ
์์ ์์๋ฅผ ์ง์ ์ต๋โข์ต์ ์ค์ ๊ฐ๊ณผ ํ๋ ํ๋ ๋์กฐํด์ ํ
์ ๋ด๋ถ์ ๋ชจ๋ ์์๊ฐ ์ง์ ๋ฒ์ ์์ ๋ค๋๋ก ๋ง๋ค์ด์ค๋ค. torch.clamp
์ญ์ ๋ค์ํ ์ํฉ์์ ์ฌ์ฉ๋๋๋ฐ, ํ์์ ๊ฒฝ์ฐ ๋ชจ๋ธ ๋ ์ด์ด ์ค๊ฐ์ ์ ๊ณฑ๊ทผ์ด๋ ์ง์, ๋ถ์ ํน์ ๊ฐ๋ ๊ด๋ จ ์ฐ์ฐ์ด ๋ค์ด๊ฐ Backward Pass
์์ NaN
์ด ๋ฐ์ํ ์ ์๋ ๊ฒฝ์ฐ์ ์์ ์ฅ์น๋ก ๋ง์ด ์ฌ์ฉํ๊ณ ์๋ค. (์์ธํ ์๊ณ ์ถ๋ค๋ฉด ํด๋ฆญ)
๐ฉโ๐ฉโ๐งโ๐ฆย torch.gather
ํ
์ ๊ฐ์ฒด ๋ด๋ถ์์ ์ํ๋ ์ธ๋ฑ์ค์ ์์นํ ์์๋ง ์ถ์ถํ๊ณ ์ถ์ ๋ ์ฌ์ฉํ๋ฉด ๋งค์ฐ ์ ์ฉํ ๋ฉ์๋๋ค. ํ
์ ์ญ์ iterable
๊ฐ์ฒด๋ผ์ loop
๋ฅผ ์ฌ์ฉํด ์ ๊ทผํ๋ ๊ฒ์ด ์ง๊ด์ ์ผ๋ก ๋ณด์ผ ์ ์์ผ๋, ํต์์ ์ผ๋ก ํ
์๋ฅผ ์ฌ์ฉํ๋ ์ํฉ์ด๋ผ๋ฉด ๊ฐ์ฒด์ ์ฐจ์์ด ์ด๋ง๋ฌด์ ํ๊ธฐ ๋๋ฌธ์ ๋ฃจํ๋ก ์ ๊ทผํด ๊ด๋ฆฌํ๋ ๊ฒ์ ๋งค์ฐ ๋นํจ์จ์ ์ด๋ค. ๋ฃจํ๋ฅผ ํตํด ์ ๊ทผํ๋ฉด ํ์ด์ฌ์ ๋ด์ฅ ๋ฆฌ์คํธ๋ฅผ ์ฌ์ฉํ๋ ๊ฒ๊ณผ ๋ณ๋ฐ ๋ค๋ฅผ๊ฒ ์์ด์ง๊ธฐ ๋๋ฌธ์, ํ
์๋ฅผ ์ฌ์ฉํ๋ ๋ฉ๋ฆฌํธ๊ฐ ์ฌ๋ผ์ง๋ค. ๋น๊ต์ ํฌ์ง ์์ 2~3์ฐจ์์ ํ
์ ์ ๋๋ผ๋ฉด ์ฌ์ฉํด๋ ํฌ๊ฒ ๋ฌธ์ ๋ ์์๊ฑฐ๋ผ ์๊ฐํ์ง๋ง ๊ทธ๋๋ ์ฝ๋์ ์ผ๊ด์ฑ์ ์ํด torch.gather
์ฌ์ฉ์ ๊ถ์ฅํ๋ค. ์ด์ torch.gather
์ ์ฌ์ฉ๋ฒ์ ๋ํด ์์๋ณด์.
# torch.gather params
>>> torch.gather(input, dim, index, *, sparse_grad=False, out=None)
dim
๊ณผ index
์ ์ฃผ๋ชฉํด๋ณด์. ๋จผ์ dim
์ ์ฌ์ฉ์๊ฐ ์ธ๋ฑ์ฑ์ ์ ์ฉํ๊ณ ์ถ์ ์ฐจ์์ ์ง์ ํด์ฃผ๋ ์ญํ ์ ํ๋ค. index
๋งค๊ฐ๋ณ์๋ก ์ ๋ฌํ๋ ํ
์ ์์๋ ์์์ โ์ธ๋ฑ์คโ
๋ฅผ ์๋ฏธํ๋ ์ซ์๋ค์ด ๋ง๊ตฌ์ก์ด๋ก ๋ด๊ฒจ์๋๋ฐ, ํด๋น ์ธ๋ฑ์ค๊ฐ ๋์ ํ
์์ ์ด๋ ์ฐจ์์ ๊ฐ๋ฆฌํฌ ๊ฒ์ธ์ง๋ฅผ ์ปดํจํฐ์๊ฒ ์๋ ค์ค๋ค๊ณ ์๊ฐํ๋ฉด ๋๋ค. index
๋ ์์์ ์ค๋ช
ํ๋ฏ์ด ์์์ โ์ธ๋ฑ์คโ
๋ฅผ ์๋ฏธํ๋ ์ซ์๋ค์ด ๋ด๊ธด ํ
์๋ฅผ ์
๋ ฅ์ผ๋ก ํ๋ ๋งค๊ฐ๋ณ์๋ค. ์ด ๋ ์ฃผ์ํ ์ ์ ๋์ ํ
์(input
)์ ์ธ๋ฑ์ค ํ
์์ ์ฐจ์ ํํ๊ฐ ๋ฐ๋์ ๋์ผํด์ผ ํ๋ค๋ ๊ฒ์ด๋ค. ์ญ์ ๋ง๋ก๋ง ๋ค์ผ๋ฉด ์ดํดํ๊ธฐ ํ๋๋ ์ฌ์ฉ ์์๋ฅผ ํจ๊ผ ์ดํด๋ณด์.
# torch.gather usage example
>>> q, kr = torch.randn(10, 1024, 64), torch.randn(10, 1024, 64) # [batch, sequence, dim_head], [batch, 2*sequence, dim_head]
>>> tmp_c2p = torch.matmul(q, kr.transpose(-1, -2))
>>> tmp_c2p, tmp_c2p.shape
(tensor([[-2.6477, -4.7478, -5.3250, ..., 1.6062, -1.9717, 3.8004],
[ 0.0662, 1.5240, 0.1182, ..., 0.1653, 2.8476, 1.6337],
[-0.5010, -4.2267, -1.1179, ..., 1.1447, 1.7845, -0.1493],
...,
[-2.1073, -1.2149, -4.8630, ..., 0.8238, -0.5833, -1.2066],
[ 2.1747, 3.2924, 6.5808, ..., -0.2926, -0.2511, 2.6996],
[-2.8362, 2.8700, -0.9729, ..., -4.9913, -0.3616, -0.1708]],
grad_fn=<MmBackward0>)
torch.Size([1024, 1024]))
>>> max_seq, max_relative_position = 1024, 512
>>> q_index, k_index = torch.arange(max_seq), torch.arange(2*max_relative_position)
>>> q_index, k_index
(tensor([ 0, 1, 2, ..., 1021, 1022, 1023]),
tensor([ 0, 1, 2, ..., 1021, 1022, 1023]))
>>> tmp_pos = q_index.view(-1, 1) - k_index.view(1, -1)
>>> rel_pos_matrix = tmp_pos + max_relative_position
>>> rel_pos_matrix
tensor([[ 512, 511, 510, ..., -509, -510, -511],
[ 513, 512, 511, ..., -508, -509, -510],
[ 514, 513, 512, ..., -507, -508, -509],
...,
[1533, 1532, 1531, ..., 512, 511, 510],
[1534, 1533, 1532, ..., 513, 512, 511],
[1535, 1534, 1533, ..., 514, 513, 512]])
>>> rel_pos_matrix = torch.clamp(rel_pos_matrix, 0, 2*max_relative_position - 1).repeat(10, 1, 1)
>>> tmp_c2p = tmp_c2p.repeat(10, 1, 1)
>>> rel_pos_matrix, rel_pos_matrix.shape, tmp_c2p.shape
(tensor([[[ 512, 511, 510, ..., 0, 0, 0],
[ 513, 512, 511, ..., 0, 0, 0],
[ 514, 513, 512, ..., 0, 0, 0],
...,
[1023, 1023, 1023, ..., 512, 511, 510],
[1023, 1023, 1023, ..., 513, 512, 511],
[1023, 1023, 1023, ..., 514, 513, 512]],
[[ 512, 511, 510, ..., 0, 0, 0],
[ 513, 512, 511, ..., 0, 0, 0],
[ 514, 513, 512, ..., 0, 0, 0],
...,
[1023, 1023, 1023, ..., 512, 511, 510],
[1023, 1023, 1023, ..., 513, 512, 511],
[1023, 1023, 1023, ..., 514, 513, 512]],
torch.Size([10, 1024, 1024]),
torch.Size([10, 1024, 1024]))
>>> torch.gather(tmp_c2p, dim=-1, index=rel_pos_matrix)
tensor([[[-0.8579, -0.2178, 1.6323, ..., -2.6477, -2.6477, -2.6477],
[ 1.1601, 2.1752, 0.7187, ..., 0.0662, 0.0662, 0.0662],
[ 3.4379, -1.2573, 0.1375, ..., -0.5010, -0.5010, -0.5010],
...,
[-1.2066, -1.2066, -1.2066, ..., 0.5943, -0.5169, -3.0820],
[ 2.6996, 2.6996, 2.6996, ..., 0.2014, 1.1458, 3.2626],
[-0.1708, -0.1708, -0.1708, ..., 1.9955, 4.1549, 2.6356]],
์ ์ฝ๋๋ DeBERTa
์ Disentangled Self-Attention
์ ๊ตฌํํ ์ฝ๋์ ์ผ๋ถ๋ถ์ด๋ค. ์์ธํ ์๋ฆฌ๋ DeBERTa
๋
ผ๋ฌธ ๋ฆฌ๋ทฐ ํฌ์คํ
์์ ํ์ธํ๋ฉด ๋๊ณ , ์ฐ๋ฆฌ๊ฐ ์ง๊ธ ์ฃผ๋ชฉํ ๋ถ๋ถ์ ๋ฐ๋ก tmp_c2p
, rel_pos_matrix
๊ทธ๋ฆฌ๊ณ ๋ง์ง๋ง ์ค์ ์์นํ torch.gather
๋ค. [10, 1024, 1024]
๋ชจ์์ ๊ฐ์ง ๋์ ํ
์ tmp_c2p
์์ ๋ด๊ฐ ์ํ๋ ์์๋ง ์ถ์ถํ๋ ค๋ ์ํฉ์ธ๋ฐ, ์ถ์ถํด์ผํ ์์์ ์ธ๋ฑ์ค ๊ฐ์ด ๋ด๊ธด ํ
์๋ฅผ rel_pos_matrix
๋ก ์ ์ํ๋ค. rel_pos_matrix
์ ์ฐจ์์ [10, 1024, 1024]
๋ก tmp_c2p
์ ๋์ผํ๋ค. ์ฐธ๊ณ ๋ก ์ถ์ถํด์ผ ํ๋ ์ฐจ์ ๋ฐฉํฅ์ ๊ฐ๋ก ๋ฐฉํฅ(๋ ๋ฒ์งธ 1024)์ด๋ค.
์ด์ torch.gather
์ ๋์์ ์ดํด๋ณด์. ์ฐ๋ฆฌ๊ฐ ํ์ฌ ์ถ์ถํ๊ณ ์ถ์ ๋์์ 3์ฐจ์ ํ
์์ ๊ฐ๋ก ๋ฐฉํฅ(๋ ๋ฒ์งธ 1024, ํ
์์ ํ ๋ฒกํฐ), ์ฆ 2 * max_sequence_length
๋ฅผ ์๋ฏธํ๋ ์ฐจ์ ๋ฐฉํฅ์ ์์๋ค. ๋ฐ๋ผ์ dim=-1
์ผ๋ก ์ค์ ํด์ค๋ค. ์ด์ ๋ฉ์๋๊ฐ ์๋๋๋ก ์ ์ฉ๋์๋์ง ํ์ธํด๋ณด์. rel_pos_matrix
์ 0๋ฒ ๋ฐฐ์น, 0๋ฒ์งธ ์ํ์ค์ ๊ฐ์ฅ ๋ง์ง๋ง ์ฐจ์์ ๊ฐ์ 0
์ผ๋ก ์ด๊ธฐํ ๋์ด ์๋ค. ๋ค์ ๋งํด, ๋์ ํ
์์ ๋์ ์ฐจ์์์ 0๋ฒ์งธ ์ธ๋ฑ์ค์ ํด๋นํ๋ ๊ฐ์ ๊ฐ์ ธ์ค๋ผ๋ ์๋ฏธ๋ฅผ ๋ด๊ณ ์๋ค. ๊ทธ๋ ๋ค๋ฉด torch.gather
์คํ ๊ฒฐ๊ณผ๊ฐ tmp_c2p
์ 0๋ฒ ๋ฐฐ์น, 0๋ฒ์งธ ์ํ์ค์ 0๋ฒ์งธ ์ฐจ์ ๊ฐ๊ณผ ์ผ์นํ๋์ง ํ์ธํด๋ณด์. ๋ ๋ค -2.6477
, -2.6477
์ผ๋ก ๊ฐ์ ๊ฐ์ ๋ํ๋ด๊ณ ์๋ค. ๋ฐ๋ผ์ ์ฐ๋ฆฌ ์๋๋๋ก ์ ์คํ๋์๋ค๋ ์ฌ์ค์ ์ ์ ์๋ค.
๐ฉโ๐ฉโ๐งโ๐ฆย torch.triu, torch.tril
๊ฐ๊ฐ ์
๋ ฅ ํ
์๋ฅผ ์์ผ๊ฐํ๋ ฌ
, ํ์ผ๊ฐํ๋ ฌ
๋ก ๋ง๋ ๋ค. triu
๋ tril
์ ์ฌ์ค ๋ค์ง์ผ๋ฉด ๊ฐ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํํ๊ธฐ ๋๋ฌธ์ tril
์ ๊ธฐ์ค์ผ๋ก ์ค๋ช
์ ํ๊ฒ ๋ค. ๋ฉ์๋์ ๋งค๊ฐ๋ณ์๋ ๋ค์๊ณผ ๊ฐ๋ค.
# torch.triu, tril params
upper_tri_matrix = torch.triu(input_tensor, diagonal=0, *, out=None)
lower_tri_matrix = torch.tril(input_tensors, diagonal=0, *, out=None)
diagonal
์ ์ฃผ๋ชฉํด๋ณด์. ์์๋ฅผ ์ ๋ฌํ๋ฉด ์ฃผ๋๊ฐ์ฑ๋ถ์์ ํด๋นํ๋ ๊ฐ๋งํผ ๋จ์ด์ง ๊ณณ์ ๋๊ฐ์ฑ๋ถ๊น์ง ๊ทธ ๊ฐ์ ์ด๋ ค๋๋ค. ํํธ ์์๋ฅผ ์ ๋ฌํ๋ฉด ์ฃผ๋๊ฐ์ฑ๋ถ์ ํฌํจํด ์ฃผ์ด์ง ๊ฐ๋งํผ ๋จ์ด์ง ๊ณณ๊น์ง์ ๋๊ฐ์ฑ๋ถ์ ๋ชจ๋ 0์ผ๋ก ๋ง๋ค์ด๋ฒ๋ฆฐ๋ค. ๊ธฐ๋ณธ์ 0์ผ๋ก ์ค์ ๋์ด ์์ผ๋ฉฐ, ์ด๋ ์ฃผ๋๊ฐ์ฑ๋ถ๋ถํฐ ์ผ์ชฝ ํ๋จ์ ์์๋ฅผ ๋ชจ๋ ์ด๋ ค๋๊ฒ ๋ค๋ ์๋ฏธ๊ฐ ๋๋ค.
# torch.tril usage example
>>> lm_mask = torch.tril(torch.ones(x.shape[0], x.shape[-1], x.shape[-1]))
>>> lm_mask
1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
๋ ๋ฉ์๋๋ ์ ํ๋์ํ์ด ํ์ํ ๋ค์ํ ๋ถ์ผ์์ ์ฌ์ฉ๋๋๋ฐ, ํ์์ ๊ฒฝ์ฐ, GPT
์ฒ๋ผ Transformer
์ Decoder
๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ๋น๋ํ ๋ ๊ฐ์ฅ ๋ง์ด ์ฌ์ฉํ๋ ๊ฒ ๊ฐ๋ค. Decoder
๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ๋๋ถ๋ถ ๊ตฌ์กฐ์ Language Modeling
์ ์ํด์ Masked Multi-Head Self-Attention Block
์ ์ฌ์ฉํ๋๋ฐ ์ด ๋ ๋ฏธ๋ ์์ ์ ํ ํฐ ์๋ฒ ๋ฉ ๊ฐ์ ๋ง์คํน์ ํด์ฃผ๊ธฐ ์ํด torch.tril
์ ์ฌ์ฉํ๊ฒ ๋๋ ์ฐธ๊ณ ํ์.
๐ฉโ๐ฉโ๐งโ๐ฆย torch.Tensor.masked_fill
์ฌ์ฉ์๊ฐ ์ง์ ํ ๊ฐ์ ํด๋น๋๋ ์์๋ฅผ ๋ชจ๋ ๋ง์คํน ์ฒ๋ฆฌํด์ฃผ๋ ๋ฉ์๋๋ค. ๋จผ์ ๋งค๊ฐ๋ณ์๋ฅผ ํ์ธํด๋ณด์.
# torch.Tensor.masked_fill params
input_tensors = torch.Tensor([[1,2,3], [4,5,6]])
input_tensors.masked_fill(mask: BoolTensor, value: float)
masked_fill
์ ํ
์ ๊ฐ์ฒด์ ๋ด๋ถ attribute
๋ก ์ ์๋๊ธฐ ๋๋ฌธ์ ํด๋น ๋ฉ์๋๋ฅผ ์ฌ์ฉํ๊ณ ์ถ๋ค๋ฉด ๋จผ์ ๋ง์คํน ๋์ ํ
์๋ฅผ ๋ง๋ค์ด์ผ ํ๋ค. ํ
์๋ฅผ ์ ์ํ๋ค๋ฉด ํ
์ ๊ฐ์ฒด์ attributes
์ ๊ทผ์ ํตํด masked_fill()
์ ํธ์ถํ ๋ค, ํ์ํ ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฌํด์ฃผ๋ ๋ฐฉ์์ผ๋ก ์ฌ์ฉํ๋ฉด ๋๋ค.
mask
๋งค๊ฐ๋ณ์์๋ ๋ง์คํน ํ
์๋ฅผ ์ ๋ฌํด์ผ ํ๋๋ฐ, ์ด ๋ ๋ด๋ถ ์์๋ ๋ชจ๋ boolean
์ด์ด์ผ ํ๊ณ ํ
์์ ํํ๋ ๋์ ํ
์์ ๋์ผํด์ผ ํ๋ค(์์ ํ ๊ฐ์ ํ์๋ ์๊ณ , ๋ธ๋ก๋ ์บ์คํ
๋ง ๊ฐ๋ฅํ๋ฉด ์๊ด ์์).
value
๋งค๊ฐ๋ณ์์๋ ๋ง์คํน ๋์ ์์๋ค์ ์ผ๊ด์ ์ผ๋ก ์ ์ฉํด์ฃผ๊ณ ์ถ์ ๊ฐ์ ์ ๋ฌํ๋ค. ์ด๊ฒ ๋ง๋ก๋ง ๋ค์ผ๋ฉด ์ดํดํ๊ธฐ ์ฝ์ง ์๋ค. ์๋ ์ฌ์ฉ ์์๋ฅผ ํจ๊ป ์ฒจ๋ถํ์ผ๋ ์ฐธ๊ณ ๋ฐ๋๋ค.
# torch.masked_fill usage
>>> lm_mask = torch.tril(torch.ones(x.shape[0], x.shape[-1], x.shape[-1]))
>>> lm_mask
1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
>>> attention_matrix = torch.matmul(q, k.transpose(-1, -2)) / dot_scale
>>> attention_matrix
1.22 2.1 3.4 1.2 1.1
1.22 2.1 3.4 9.9 9.9
1.22 2.1 3.4 9.9 9.9
1.22 2.1 3.4 9.9 9.9
>>> attention_matrix = attention_matrix.masked_fill(lm_mask == 0, float('-inf'))
>>> attention_matrix
1.22 -inf -inf -inf -inf
1.22 2.1 -inf -inf -inf
1.22 2.1 3.4 -inf -inf
1.22 2.1 3.4 9.9 -inf
๐๏ธย torch.clone
inputs
์ธ์๋ก ์ ๋ฌํ ํ
์๋ฅผ ๋ณต์ฌํ๋ ํ์ดํ ์น ๋ด์ฅ ๋ฉ์๋๋ค. ์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ๋ค.
""" torch.clone """
torch.clone(
input,ย
*,
ย memory_format=torch.preserve_format
)ย โย [Tensor]
๋ฅ๋ฌ๋ ํ์ดํ๋ผ์ธ์ ๋ง๋ค๋ค ๋ณด๋ฉด ๋ง์ด ์ฌ์ฉํ๊ฒ ๋๋ ๊ธฐ๋ณธ์ ์ธ ๋ฉ์๋์ธ๋ฐ, ์ด๋ ๊ฒ ๋ฐ๋ก ์ ๋ฆฌํ๊ฒ ๋ ์ด์ ๊ฐ ์๋ค. ์ ๋ ฅ๋ ํ ์๋ฅผ ๊ทธ๋๋ก ๋ณต์ฌํ๋ค๋ ํน์ฑ ๋๋ฌธ์ ์ฌ์ฉ์ ์ฃผ์ํด์ผ ํ ์ ์ด ์๊ธฐ ๋๋ฌธ์ด๋ค. ํด๋น ๋ฉ์๋๋ฅผ ์ฌ์ฉํ๊ธฐ ์ ์ ๋ฐ๋์ ์ ๋ ฅํ ํ ์๊ฐ ํ์ฌ ์ด๋ ๋๋ฐ์ด์ค(CPU, GPU) ์์ ์๋์ง, ๊ทธ๋ฆฌ๊ณ ํด๋น ํ ์๊ฐ ๊ณ์ฐ ๊ทธ๋ํ๋ฅผ ๊ฐ์ง๊ณ ์๋์ง๋ฅผ ๋ฐ๋์ ํ์ ํด์ผ ํ๋ค.
ํ์๋ ELECTRA ๋ชจ๋ธ์ ์ง์ ๊ตฌํํ๋ ๊ณผ์ ์์ clone()
๋ฉ์๋๋ฅผ ์ฌ์ฉํ๋๋ฐ, Generator ๋ชจ๋ธ์ ๊ฒฐ๊ณผ ๋ก์ง์ Discriminator์ ์
๋ ฅ์ผ๋ก ๋ณํํด์ฃผ๊ธฐ ์ํจ์ด์๋ค. ๊ทธ ๊ณผ์ ์์ Generator๊ฐ ๋ฐํํ ๋ก์ง์ ๊ทธ๋๋ก clone
ํ ๋ค, ์
๋ ฅ์ ๋ง๋ค์ด ์ฃผ์๊ณ ๊ทธ ๊ฒฐ๊ณผ ์๋์ ๊ฐ์ ์๋ฌ๋ฅผ ๋ง์ฃผํ๋ค.
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [
torch.cuda.LongTensor [8, 511]] is at version 1; expected version 0
instead. Hint: the backtrace further above shows the operation that failed to compute its gradient.
The variable in question was changed in there or anywhere later. Good luck!
์๋ฌ ๋ก๊ทธ๋ฅผ ์์ธํ ์ฝ์ด๋ณด๋ฉด ํ
์ ๋ฒ์ ์ ๋ณ๊ฒฝ์ผ๋ก ์ธํด ๊ทธ๋ผ๋์ธํธ ๊ณ์ฐ์ด ๋ถ๊ฐํ๋ค๋ ๋ด์ฉ์ด ๋ด๊ฒจ์๋ค. ๊ตฌ๊ธ๋งํด๋ด๋ ์ ์๋์์ ํฌ๊ธฐํ๋ ค๋ ์ฐฐ๋ผ์ ์ฐ์ฐํ torch.clone()
๋ฉ์๋์ ์ ํํ ์ฌ์ฉ๋ฒ์ด ๊ถ๊ธํด ๊ณต์ Docs๋ฅผ ์ฝ๊ฒ ๋์๊ณ , ๊ฑฐ๊ธฐ์ ์์ฒญ๋ ์ฌ์ค์ ๋ฐ๊ฒฌํ๋ค. clone()
๋ฉ์๋๊ฐ ์
๋ ฅ๋ ํ
์์ ํ์ฌ ๋๋ฐ์ด์ค ์์น์ ๋๊ฐ์ด ๋ณต์ฌ๋ ๊ฒ์ด๋ ์์์ ํ์ง๋ง, ์
๋ ฅ ํ
์์ ๊ณ์ฐ๊ทธ๋ํ๊น์ง ๋ณต์ฌ๋ ๊ฒ์ด๋ ์๊ฐ์ ์ ํ ํ์ง ๋ชปํ๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ๋์ ์์ ๊ฐ์ ์๋ฌ๋ฅผ ๋ง์ฃผํ์ง ์์ผ๋ ค๋ฉด, clone()
์ ํธ์ถํ ๋ ๋ค์ ๋ฐ๋์ detach()
๋ฅผ ํจ๊ป ํธ์ถํด์ค์ผ ํ๋ค.
clone()
๋ฉ์๋๋ ์
๋ ฅ๋ ํ
์์ ๋ชจ๋ ๊ฒ์ ๋ณต์ฌํ๋ค๋ ์ ์ ๋ฐ๋์ ๊ธฐ์ตํ์.
Leave a comment