
ํŒŒ์ดํ† ์น˜์—์„œ ํ•„์ž๊ฐ€ ์ž์ฃผ ์‚ฌ์šฉํ•˜๋Š” ํ…์„œ ์ธ๋ฑ์‹ฑ ๊ด€๋ จ ๋ฉ”์„œ๋“œ์˜ ์‚ฌ์šฉ๋ฒ• ๋ฐ ์‚ฌ์šฉ ์˜ˆ์‹œ๋ฅผ ํ•œ๋ฐฉ์— ์ •๋ฆฌํ•œ ํฌ์ŠคํŠธ๋‹ค. ๋ฉ”์„œ๋“œ ํ•˜๋‚˜๋‹น ํ•˜๋‚˜์˜ ํฌ์ŠคํŠธ๋กœ ๋งŒ๋“ค๊ธฐ์—๋Š” ๋„ˆ๋ฌด ๊ธธ์ด๊ฐ€ ์งง๋‹ค ์ƒ๊ฐํ•ด ํ•œ ํŽ˜์ด์ง€์— ๋ชจ๋‘ ๋„ฃ๊ฒŒ ๋˜์—ˆ๋‹ค. ์ง€์†์ ์œผ๋กœ ์—…๋ฐ์ดํŠธ ๋  ์˜ˆ์ •์ด๋‹ค. ๋˜ํ•œ ํ…์„œ ์ธ๋ฑ์‹ฑ ๋ง๊ณ ๋„ ๋‹ค๋ฅธ ์ฃผ์ œ๋กœ๋„ ๊ด€๋ จ ๋ฉ”์„œ๋“œ๋ฅผ ์ •๋ฆฌํ•ด ์˜ฌ๋ฆด ์˜ˆ์ •์ด๋‹ˆ ๋งŽ์€ ๊ด€์‹ฌ ๋ถ€ํƒ๋“œ๋ฆฐ๋‹ค.

๐Ÿ”Žย torch.argmax

์ž…๋ ฅ ํ…์„œ์—์„œ ๊ฐ€์žฅ ํฐ ๊ฐ’์„ ๊ฐ–๊ณ  ์žˆ๋Š” ์›์†Œ์˜ ์ธ๋ฑ์Šค๋ฅผ ๋ฐ˜ํ™˜ํ•œ๋‹ค. ์ตœ๋Œ€๊ฐ’์„ ์ฐพ์„ ์ฐจ์›์„ ์ง€์ •ํ•ด์ค„ ์ˆ˜ ์žˆ๋‹ค. ์•„๋ž˜ ์˜ˆ์‹œ ์ฝ”๋“œ๋ฅผ ํ™•์ธํ•ด๋ณด์ž.

# torch.argmax params
torch.argmax(tensor, dim=None, keepdim=False)

# torch.argmax example 1
test = torch.tensor([1, 29, 2, 45, 22, 3])
torch.argmax(test)

<Result>
tensor(3)

# torch.argmax example 2
test = torch.tensor([[4, 2, 3],
                     [4, 5, 6]])

torch.argmax(test, dim=0, keepdim=True)
<Result>
tensor([[0, 1, 1]])

# torch.argmax example 3
test = torch.tensor([[4, 2, 3],
                     [4, 5, 6]])

torch.argmax(test, dim=1, keepdim=True)
<Result>
tensor([[0],
        [2]])

dim ๋งค๊ฐœ๋ณ€์ˆ˜์— ์›ํ•˜๋Š” ์ฐจ์›์„ ์ž…๋ ฅํ•˜๋ฉด ํ•ด๋‹น ์ฐจ์› ๋ทฐ์—์„œ ๊ฐ€์žฅ ํฐ ์›์†Œ๋ฅผ ์ฐพ์•„ ์ธ๋ฑ์Šค ๊ฐ’์„ ๋ฐ˜ํ™˜ํ•ด์ค„ ๊ฒƒ์ด๋‹ค. ์ด ๋•Œ keepdim=True ๋กœ ์„ค์ •ํ•œ๋‹ค๋ฉด ์ž…๋ ฅ ์ฐจ์›์—์„œ ๊ฐ€์žฅ ํฐ ์›์†Œ์˜ ์ธ๋ฑ์Šค๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋˜ ์›๋ณธ ํ…์„œ์˜ ์ฐจ์›๊ณผ ๋™์ผํ•œ ํ˜•ํƒœ๋กœ ์ถœ๋ ฅํ•ด์ค€๋‹ค. example 2 ์˜ ๊ฒฝ์šฐ dim=0 ๋ผ์„œ ํ–‰์ด ๋ˆ„์ ๋œ ๋ฐฉํ–ฅ์œผ๋กœ ํ…์„œ๋ฅผ ๋ฐ”๋ผ๋ด์•ผ ํ•œ๋‹ค. ํ–‰์ด ๋ˆ„์ ๋œ ๋ฐฉํ–ฅ์œผ๋กœ ํ…์„œ๋ฅผ ๋ณด๊ฒŒ ๋˜๋ฉด tensor([[0, 1, 1]])์ด ๋œ๋‹ค.

๐Ÿ“šย torch.stack

"""
torch.stack
Args:
	tensors(sequence of Tensors): ํ…์„œ๊ฐ€ ๋‹ด๊ธด ํŒŒ์ด์ฌ ์‹œํ€€์Šค ๊ฐ์ฒด
	dim(int): ์ถ”๊ฐ€ํ•  ์ฐจ์› ๋ฐฉํ–ฅ์„ ์„ธํŒ…, ๊ธฐ๋ณธ๊ฐ’์€ 0
"""
torch.stack(tensors, dim=0)

๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ์ฃผ์–ด์ง„ ํŒŒ์ด์ฌ ์‹œํ€€์Šค ๊ฐ์ฒด(๋ฆฌ์ŠคํŠธ, ํŠœํ”Œ)๋ฅผ ์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •ํ•œ ์ƒˆ๋กœ์šด ์ฐจ์›์— ์Œ“๋Š” ๊ธฐ๋Šฅ์„ ํ•œ๋‹ค. ๋งค๊ฐœ๋ณ€์ˆ˜ tensors ๋Š” ํ…์„œ๊ฐ€ ๋‹ด๊ธด ํŒŒ์ด์ฌ์˜ ์‹œํ€€์Šค ๊ฐ์ฒด๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›๋Š”๋‹ค. dim ์€ ์‚ฌ์šฉ์ž๊ฐ€ ํ…์„œ ์ ์žฌ๋ฅผ ํ•˜๊ณ  ์‹ถ์€ ์ƒˆ๋กœ์šด ์ฐจ์›์„ ์ง€์ •ํ•ด์ฃผ๋ฉด ๋œ๋‹ค. ๊ธฐ๋ณธ๊ฐ’์€ 0์ฐจ์›์œผ๋กœ ์ง€์ • ๋˜์–ด์žˆ์œผ๋ฉฐ, ํ…์„œ์˜ ๋งจ ์•ž์ฐจ์›์ด ์ƒˆ๋กญ๊ฒŒ ์ƒ๊ธฐ๊ฒŒ ๋œ๋‹ค. torch.stack ์€ ๊ธฐ๊ณ„ํ•™์Šต, ํŠนํžˆ ๋”ฅ๋Ÿฌ๋‹์—์„œ ์ •๋ง ์ž์ฃผ ์‚ฌ์šฉ๋˜๊ธฐ ๋•Œ๋ฌธ์— ์‚ฌ์šฉ๋ฒ• ๋ฐ ์‚ฌ์šฉ์ƒํ™ฉ์„ ์ตํ˜€๋‘๋ฉด ๋„์›€์ด ๋œ๋‹ค. ์˜ˆ์‹œ๋ฅผ ํ†ตํ•ด ํ•ด๋‹น ๋ฉ”์„œ๋“œ๋ฅผ ์–ด๋–ค ์ƒํ™ฉ์—์„œ ์–ด๋–ป๊ฒŒ ์‚ฌ์šฉํ•˜๋Š”์ง€ ์•Œ์•„๋ณด์ž.

""" torch.stack example """

class Projector(nn.Module):
    """
    Making projection matrix(Q, K, V) for each attention head
    When you call this class, it returns projection matrix of each attention head
    For example, if you call this class with 8 heads, it returns 8 set of projection matrices (Q, K, V)
    Args:
        num_heads: number of heads in MHA, default 8
        dim_head: dimension of each attention head, default 64
    """
    def __init__(self, num_heads: int = 8, dim_head: int = 64) -> None:
        super(Projector, self).__init__()
        self.dim_model = num_heads * dim_head
        self.num_heads = num_heads
        self.dim_head = dim_head

    def __call__(self):
        fc_q = nn.Linear(self.dim_model, self.dim_head)
        fc_k = nn.Linear(self.dim_model, self.dim_head)
        fc_v = nn.Linear(self.dim_model, self.dim_head)
        return fc_q, fc_k, fc_v

num_heads = 8
dim_head = 64
projector = Projector(num_heads, dim_head)  # init instance
projector_list = [list(projector()) for _ in range(num_heads)]  # call instance
x = torch.rand(10, 512, 512) # x.shape: [Batch_Size, Sequence_Length, Dim_model]
Q, K, V = [], [], []

for i in range(num_heads):
    Q.append(projector_list[i][0](x)) # [10, 512, 64]
    K.append(projector_list[i][1](x)) # [10, 512, 64]
    V.append(projector_list[i][2](x)) # [10, 512, 64]
 
Q = torch.stack(Q, dim=1) # Q.shape: [10, 8, 512, 64]
K = torch.stack(K, dim=1) # K.shape: [10, 8, 512, 64]
V = torch.stack(V, dim=1) # V.shape: [10, 8, 512, 64]

์œ„ ์ฝ”๋“œ๋Š” Transformer ์˜ Multi-Head Attention ๊ตฌํ˜„์ฒด ์ผ๋ถ€๋ฅผ ๋ฐœ์ทŒํ•ด์˜จ ๊ฒƒ์ด๋‹ค. Multi-Head Attention ์€ ๊ฐœ๋ณ„ ์–ดํ…์…˜ ํ•ด๋“œ๋ณ„๋กœ ํ–‰๋ ฌ $Q, K, V$๋ฅผ ๊ฐ€์ ธ์•ผ ํ•œ๋‹ค. ๋”ฐ๋ผ์„œ ์ž…๋ ฅ ์ž„๋ฒ ๋”ฉ์„ ๊ฐœ๋ณ„ ์–ดํ…์…˜ ํ—ค๋“œ์— Linear Combination ํ•ด์ค˜์•ผ ํ•˜๋Š”๋ฐ ํ—ค๋“œ ๊ฐœ์ˆ˜๊ฐ€ 8๊ฐœ๋‚˜ ๋˜๊ธฐ ๋•Œ๋ฌธ์— ๊ฐœ๋ณ„์ ์œผ๋กœ Projection Matrix ๋ฅผ ์„ ์–ธํ•ด์ฃผ๋Š” ๊ฒƒ์€ ๋งค์šฐ ๋น„ํšจ์œจ์ ์ด๋‹ค. ๋”ฐ๋ผ์„œ ๊ฐ์ฒด Projector ์— ํ–‰๋ ฌ $Q, K, V$์— ๋Œ€ํ•œ Projection Matrix ๋ฅผ ์ •์˜ํ•ด์คฌ๋‹ค. ์ดํ›„ ํ—ค๋“œ ๊ฐœ์ˆ˜๋งŒํผ ๊ฐ์ฒด Projector ๋ฅผ ํ˜ธ์ถœํ•ด ๋ฆฌ์ŠคํŠธ์— ํ•ด๋“œ๋ณ„ Projection Matrix ๋ฅผ ๋‹ด์•„์ค€๋‹ค. ๊ทธ ๋‹ค์Œ torch.stack์„ ์‚ฌ์šฉํ•ด Attention Head ๋ฐฉํ–ฅ์˜ ์ฐจ์›์œผ๋กœ ๋ฆฌ์ŠคํŠธ ๋‚ด๋ถ€ ํ…์„œ๋“ค์„ ์Œ“์•„์ฃผ๋ฉด ๋œ๋‹ค.

๐Ÿ”ขย torch.arange

์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •ํ•œ ์‹œ์ž‘์ ๋ถ€ํ„ฐ ๋์ ๊นŒ์ง€ ์ผ์ •ํ•œ ๊ฐ„๊ฒฉ์œผ๋กœ ํ…์„œ๋ฅผ ๋‚˜์—ดํ•œ๋‹ค. Python์˜ ๋‚ด์žฅ ๋ฉ”์„œ๋“œ range์™€ ๋™์ผํ•œ ์—ญํ• ์„ ํ•˜๋Š”๋ฐ, ๋Œ€์‹  ํ…์„œ ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ํ…์„œ ๊ตฌ์กฐ์ฒด๋กœ ๋ฐ˜ํ™˜ํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ•˜๋ฉด ๋˜๊ฒ ๋‹ค.

# torch.arange usage
torch.arange(start=0, end, step=1)

>>> torch.arange(5)
tensor([ 0,  1,  2,  3,  4])

>>> torch.arange(1, 4)
tensor([ 1,  2,  3])

>>> torch.arange(1, 2.5, 0.5)
tensor([ 1.0000,  1.5000,  2.0000])

step ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ์›์†Œ๊ฐ„ ๊ฐ„๊ฒฉ ์กฐ์ •์„ ํ•  ์ˆ˜ ์žˆ๋Š”๋ฐ, ๊ธฐ๋ณธ์€ 1๋กœ ์ง€์ • ๋˜์–ด ์žˆ์œผ๋‹ˆ ์ฐธ๊ณ ํ•˜์ž. ํ•„์ž์˜ ๊ฒฝ์šฐ์—๋Š” nn.Embedding์˜ ์ž…๋ ฅ ํ…์„œ๋ฅผ ๋งŒ๋“ค ๋•Œ ๊ฐ€์žฅ ๋งŽ์ด ์‚ฌ์šฉํ–ˆ๋‹ค. nn.Embedding ์˜ ๊ฒฝ์šฐ Input์œผ๋กœ IntTensor, LongTensor๋ฅผ ๋ฐ›๊ฒŒ ๋˜์–ด ์žˆ์œผ๋‹ˆ ์•Œ์•„๋‘์ž.

๐Ÿ”ย torch.repeat

์ž…๋ ฅ๊ฐ’์œผ๋กœ ์ฃผ์–ด์ง„ ํ…์„œ๋ฅผ ์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •ํ•œ ๋ฐ˜๋ณต ํšŸ์ˆ˜๋งŒํผ ํŠน์ • ์ฐจ์› ๋ฐฉํ–ฅ์œผ๋กœ ๋Š˜๋ฆฐ๋‹ค. ์˜ˆ๋ฅผ ๋“ค๋ฉด [1,2,3] * 3์˜ ๊ฒฐ๊ณผ๋Š” [1, 2, 3, 1, 2, 3, 1, 2, 3] ์ธ๋ฐ, ์ด๊ฒƒ์„ ์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •ํ•œ ๋ฐ˜๋ณต ํšŸ์ˆ˜๋งŒํผ ํŠน์ • ์ฐจ์›์œผ๋กœ ์ˆ˜ํ–‰ํ•˜๊ฒ ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์•„๋ž˜ ์‚ฌ์šฉ ์˜ˆ์ œ๋ฅผ ํ™•์ธํ•ด๋ณด์ž.

# torch.repeat example

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3]])

>>> x.repeat(4, 2, 1)
tensor([[[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]]])

>>> x.repeat(4, 2, 1).shape
torch.Size([4, 2, 3])

>>> x.repeat(4, 2, 2)
tensor([[[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]]])

$t$๋ฅผ ์–ด๋–ค ํ…์„œ ๊ตฌ์กฐ์ฒด $x$์˜ ์ตœ๋Œ€ ์ฐจ์›์ด๋ผ๊ณ  ํ–ˆ์„ , $x_t$๋ฅผ ๊ฐ€์žฅ ์™ผ์ชฝ์— ๋„ฃ๊ณ  ๊ฐ€์žฅ ๋‚ฎ์€ ์ฐจ์›์ธ 0์ฐจ์›์— ๋Œ€ํ•œ ๋ฐ˜๋ณต ํšŸ์ˆ˜๋ฅผ ์˜ค๋ฅธ์ชฝ ๋์— ๋Œ€์ž…ํ•ด์„œ ์‚ฌ์šฉํ•˜๋ฉด ๋œ๋‹ค. (torch.repeat($x_t, x_{t-1}, โ€ฆ x_2, x_1, x_0$)).

# torch.arange & torch.repeat usage example

>>> pos_x = torch.arange(self.num_patches + 1).repeat(inputs.shape[0], 1).to(inputs)
>>> pos_x.shape
torch.Size([16, 1025])

ํ•„์ž์˜ ๊ฒฝ์šฐ, position embedding์˜ ์ž…๋ ฅ์„ ๋งŒ๋“ค๊ณ  ์‹ถ์„ ๋•Œ torch.arange ์™€ ์—ฐ๊ณ„ํ•ด ์ž์ฃผ ์‚ฌ์šฉ ํ–ˆ๋˜ ๊ฒƒ ๊ฐ™๋‹ค. ์œ„ ์ฝ”๋“œ๋ฅผ ์ฐธ๊ณ ํ•˜์ž.

๐Ÿ”ฌย torch.clamp

์ž…๋ ฅ ํ…์„œ์˜ ์›์†Œ๊ฐ’์„ ์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •ํ•œ ์ตœ๋Œ€โ€ข์ตœ์†Œ๊ฐ’ ๋ฒ”์œ„ ์ด๋‚ด๋กœ ์ œํ•œํ•˜๋Š” ๋ฉ”์„œ๋“œ๋‹ค.

# torch.clamp params

>>> torch.clamp(input, min=None, max=None, *, out=None) โ†’ Tensor

# torch.clamp usage example

>>> a = torch.randn(4)
>>> a
tensor([-1.7120,  0.1734, -0.0478, -0.0922])

>>> torch.clamp(a, min=-0.5, max=0.5)
tensor([-0.5000,  0.1734, -0.0478, -0.0922])

์ž…๋ ฅ๋œ ํ…์„œ์˜ ์›์†Œ๋ฅผ ์ง€์ • ์ตœ๋Œ€โ€ข์ตœ์†Œ ์„ค์ •๊ฐ’๊ณผ ํ•˜๋‚˜ ํ•˜๋‚˜ ๋Œ€์กฐํ•ด์„œ ํ…์„œ ๋‚ด๋ถ€์˜ ๋ชจ๋“  ์›์†Œ๊ฐ€ ์ง€์ • ๋ฒ”์œ„ ์•ˆ์— ๋“ค๋„๋ก ๋งŒ๋“ค์–ด์ค€๋‹ค. torch.clamp ์—ญ์‹œ ๋‹ค์–‘ํ•œ ์ƒํ™ฉ์—์„œ ์‚ฌ์šฉ๋˜๋Š”๋ฐ, ํ•„์ž์˜ ๊ฒฝ์šฐ ๋ชจ๋ธ ๋ ˆ์ด์–ด ์ค‘๊ฐ„์— ์ œ๊ณฑ๊ทผ์ด๋‚˜ ์ง€์ˆ˜, ๋ถ„์ˆ˜ ํ˜น์€ ๊ฐ๋„ ๊ด€๋ จ ์—ฐ์‚ฐ์ด ๋“ค์–ด๊ฐ€ Backward Pass์—์„œ NaN์ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฝ์šฐ์— ์•ˆ์ „์žฅ์น˜๋กœ ๋งŽ์ด ์‚ฌ์šฉํ•˜๊ณ  ์žˆ๋‹ค. (์ž์„ธํžˆ ์•Œ๊ณ  ์‹ถ๋‹ค๋ฉด ํด๋ฆญ)

๐Ÿ‘ฉโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆย torch.gather

ํ…์„œ ๊ฐ์ฒด ๋‚ด๋ถ€์—์„œ ์›ํ•˜๋Š” ์ธ๋ฑ์Šค์— ์œ„์น˜ํ•œ ์›์†Œ๋งŒ ์ถ”์ถœํ•˜๊ณ  ์‹ถ์„ ๋•Œ ์‚ฌ์šฉํ•˜๋ฉด ๋งค์šฐ ์œ ์šฉํ•œ ๋ฉ”์„œ๋“œ๋‹ค. ํ…์„œ ์—ญ์‹œ iterable ๊ฐ์ฒด๋ผ์„œ loop ๋ฅผ ์‚ฌ์šฉํ•ด ์ ‘๊ทผํ•˜๋Š” ๊ฒƒ์ด ์ง๊ด€์ ์œผ๋กœ ๋ณด์ผ ์ˆ˜ ์žˆ์œผ๋‚˜, ํ†ต์ƒ์ ์œผ๋กœ ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ์ƒํ™ฉ์ด๋ผ๋ฉด ๊ฐ์ฒด์˜ ์ฐจ์›์ด ์–ด๋งˆ๋ฌด์‹œ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ฃจํ”„๋กœ ์ ‘๊ทผํ•ด ๊ด€๋ฆฌํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ๋น„ํšจ์œจ์ ์ด๋‹ค. ๋ฃจํ”„๋ฅผ ํ†ตํ•ด ์ ‘๊ทผํ•˜๋ฉด ํŒŒ์ด์ฌ์˜ ๋‚ด์žฅ ๋ฆฌ์ŠคํŠธ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ๊ณผ ๋ณ„๋ฐ˜ ๋‹ค๋ฅผ๊ฒŒ ์—†์–ด์ง€๊ธฐ ๋•Œ๋ฌธ์—, ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ฉ”๋ฆฌํŠธ๊ฐ€ ์‚ฌ๋ผ์ง„๋‹ค. ๋น„๊ต์  ํฌ์ง€ ์•Š์€ 2~3์ฐจ์›์˜ ํ…์„œ ์ •๋„๋ผ๋ฉด ์‚ฌ์šฉํ•ด๋„ ํฌ๊ฒŒ ๋ฌธ์ œ๋Š” ์—†์„๊ฑฐ๋ผ ์ƒ๊ฐํ•˜์ง€๋งŒ ๊ทธ๋ž˜๋„ ์ฝ”๋“œ์˜ ์ผ๊ด€์„ฑ์„ ์œ„ํ•ด torch.gather ์‚ฌ์šฉ์„ ๊ถŒ์žฅํ•œ๋‹ค. ์ด์ œ torch.gather์˜ ์‚ฌ์šฉ๋ฒ•์— ๋Œ€ํ•ด ์•Œ์•„๋ณด์ž.

# torch.gather params

>>> torch.gather(input, dim, index, *, sparse_grad=False, out=None)

dim๊ณผ index์— ์ฃผ๋ชฉํ•ด๋ณด์ž. ๋จผ์ € dim์€ ์‚ฌ์šฉ์ž๊ฐ€ ์ธ๋ฑ์‹ฑ์„ ์ ์šฉํ•˜๊ณ  ์‹ถ์€ ์ฐจ์›์„ ์ง€์ •ํ•ด์ฃผ๋Š” ์—ญํ• ์„ ํ•œ๋‹ค. index ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ์ „๋‹ฌํ•˜๋Š” ํ…์„œ ์•ˆ์—๋Š” ์›์†Œ์˜ โ€˜์ธ๋ฑ์Šคโ€™๋ฅผ ์˜๋ฏธํ•˜๋Š” ์ˆซ์ž๋“ค์ด ๋งˆ๊ตฌ์žก์ด๋กœ ๋‹ด๊ฒจ์žˆ๋Š”๋ฐ, ํ•ด๋‹น ์ธ๋ฑ์Šค๊ฐ€ ๋Œ€์ƒ ํ…์„œ์˜ ์–ด๋Š ์ฐจ์›์„ ๊ฐ€๋ฆฌํ‚ฌ ๊ฒƒ์ธ์ง€๋ฅผ ์ปดํ“จํ„ฐ์—๊ฒŒ ์•Œ๋ ค์ค€๋‹ค๊ณ  ์ƒ๊ฐํ•˜๋ฉด ๋œ๋‹ค. index ๋Š” ์•ž์—์„œ ์„ค๋ช…ํ–ˆ๋“ฏ์ด ์›์†Œ์˜ โ€˜์ธ๋ฑ์Šคโ€™๋ฅผ ์˜๋ฏธํ•˜๋Š” ์ˆซ์ž๋“ค์ด ๋‹ด๊ธด ํ…์„œ๋ฅผ ์ž…๋ ฅ์œผ๋กœ ํ•˜๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜๋‹ค. ์ด ๋•Œ ์ฃผ์˜ํ•  ์ ์€ ๋Œ€์ƒ ํ…์„œ(input)์™€ ์ธ๋ฑ์Šค ํ…์„œ์˜ ์ฐจ์› ํ˜•ํƒœ๊ฐ€ ๋ฐ˜๋“œ์‹œ ๋™์ผํ•ด์•ผ ํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์—ญ์‹œ ๋ง๋กœ๋งŒ ๋“ค์œผ๋ฉด ์ดํ•ดํ•˜๊ธฐ ํž˜๋“œ๋‹ˆ ์‚ฌ์šฉ ์˜ˆ์‹œ๋ฅผ ํ•จ๊ผ ์‚ดํŽด๋ณด์ž.

# torch.gather usage example
>>> q, kr = torch.randn(1024, 64, requires_grad=True), torch.randn(1024, 64) # [max_seq, dim_head], [2*max_relative_position, dim_head]
>>> tmp_c2p = torch.matmul(q, kr.transpose(-1, -2))
>>> tmp_c2p, tmp_c2p.shape
(tensor([[-2.6477, -4.7478, -5.3250,  ...,  1.6062, -1.9717,  3.8004],
         [ 0.0662,  1.5240,  0.1182,  ...,  0.1653,  2.8476,  1.6337],
         [-0.5010, -4.2267, -1.1179,  ...,  1.1447,  1.7845, -0.1493],
         ...,
         [-2.1073, -1.2149, -4.8630,  ...,  0.8238, -0.5833, -1.2066],
         [ 2.1747,  3.2924,  6.5808,  ..., -0.2926, -0.2511,  2.6996],
         [-2.8362,  2.8700, -0.9729,  ..., -4.9913, -0.3616, -0.1708]],
        grad_fn=<MmBackward0>)
torch.Size([1024, 1024]))

>>> max_seq, max_relative_position = 1024, 512
>>> q_index, k_index = torch.arange(max_seq), torch.arange(2*max_relative_position)
>>> q_index, k_index
(tensor([   0,    1,    2,  ..., 1021, 1022, 1023]),
 tensor([   0,    1,    2,  ..., 1021, 1022, 1023]))

>>> tmp_pos = q_index.view(-1, 1) - k_index.view(1, -1)
>>> rel_pos_matrix = tmp_pos + max_relative_position
>>> rel_pos_matrix
tensor([[ 512,  511,  510,  ..., -509, -510, -511],
        [ 513,  512,  511,  ..., -508, -509, -510],
        [ 514,  513,  512,  ..., -507, -508, -509],
        ...,
        [1533, 1532, 1531,  ...,  512,  511,  510],
        [1534, 1533, 1532,  ...,  513,  512,  511],
        [1535, 1534, 1533,  ...,  514,  513,  512]])

>>> rel_pos_matrix = torch.clamp(rel_pos_matrix, 0, 2*max_relative_position - 1).repeat(10, 1, 1)
>>> tmp_c2p = tmp_c2p.repeat(10, 1, 1)
>>> rel_pos_matrix, rel_pos_matrix.shape, tmp_c2p.shape 
(tensor([[[ 512,  511,  510,  ...,    0,    0,    0],
          [ 513,  512,  511,  ...,    0,    0,    0],
          [ 514,  513,  512,  ...,    0,    0,    0],
          ...,
          [1023, 1023, 1023,  ...,  512,  511,  510],
          [1023, 1023, 1023,  ...,  513,  512,  511],
          [1023, 1023, 1023,  ...,  514,  513,  512]],
 
         [[ 512,  511,  510,  ...,    0,    0,    0],
          [ 513,  512,  511,  ...,    0,    0,    0],
          [ 514,  513,  512,  ...,    0,    0,    0],
          ...,
          [1023, 1023, 1023,  ...,  512,  511,  510],
          [1023, 1023, 1023,  ...,  513,  512,  511],
          [1023, 1023, 1023,  ...,  514,  513,  512]],
          ...]),
 torch.Size([10, 1024, 1024]),
 torch.Size([10, 1024, 1024]))

>>> torch.gather(tmp_c2p, dim=-1, index=rel_pos_matrix)
tensor([[[-0.8579, -0.2178,  1.6323,  ..., -2.6477, -2.6477, -2.6477],
         [ 1.1601,  2.1752,  0.7187,  ...,  0.0662,  0.0662,  0.0662],
         [ 3.4379, -1.2573,  0.1375,  ..., -0.5010, -0.5010, -0.5010],
         ...,
         [-1.2066, -1.2066, -1.2066,  ...,  0.5943, -0.5169, -3.0820],
         [ 2.6996,  2.6996,  2.6996,  ...,  0.2014,  1.1458,  3.2626],
         [-0.1708, -0.1708, -0.1708,  ...,  1.9955,  4.1549,  2.6356]],
        ...])

์œ„ ์ฝ”๋“œ๋Š” DeBERTa ์˜ Disentangled Self-Attention์„ ๊ตฌํ˜„ํ•œ ์ฝ”๋“œ์˜ ์ผ๋ถ€๋ถ„์ด๋‹ค. ์ž์„ธํ•œ ์›๋ฆฌ๋Š” DeBERTa ๋…ผ๋ฌธ ๋ฆฌ๋ทฐ ํฌ์ŠคํŒ…์—์„œ ํ™•์ธํ•˜๋ฉด ๋˜๊ณ , ์šฐ๋ฆฌ๊ฐ€ ์ง€๊ธˆ ์ฃผ๋ชฉํ•  ๋ถ€๋ถ„์€ ๋ฐ”๋กœ tmp_c2p, rel_pos_matrix ๊ทธ๋ฆฌ๊ณ  ๋งˆ์ง€๋ง‰ ์ค„์— ์œ„์น˜ํ•œ torch.gather ๋‹ค. [10, 1024, 1024] ๋ชจ์–‘์„ ๊ฐ€์ง„ ๋Œ€์ƒ ํ…์„œ tmp_c2p ์—์„œ ๋‚ด๊ฐ€ ์›ํ•˜๋Š” ์›์†Œ๋งŒ ์ถ”์ถœํ•˜๋ ค๋Š” ์ƒํ™ฉ์ธ๋ฐ, ์ถ”์ถœํ•ด์•ผํ•  ์›์†Œ์˜ ์ธ๋ฑ์Šค ๊ฐ’์ด ๋‹ด๊ธด ํ…์„œ๋ฅผ rel_pos_matrix ๋กœ ์ •์˜ํ–ˆ๋‹ค. rel_pos_matrix ์˜ ์ฐจ์›์€ [10, 1024, 1024]๋กœ tmp_c2p์™€ ๋™์ผํ•˜๋‹ค. ์ฐธ๊ณ ๋กœ ์ถ”์ถœํ•ด์•ผ ํ•˜๋Š” ์ฐจ์› ๋ฐฉํ–ฅ์€ ๊ฐ€๋กœ ๋ฐฉํ–ฅ(๋‘ ๋ฒˆ์งธ 1024)์ด๋‹ค.

์ด์ œ torch.gather์˜ ๋™์ž‘์„ ์‚ดํŽด๋ณด์ž. ์šฐ๋ฆฌ๊ฐ€ ํ˜„์žฌ ์ถ”์ถœํ•˜๊ณ  ์‹ถ์€ ๋Œ€์ƒ์€ 3์ฐจ์› ํ…์„œ์˜ ๊ฐ€๋กœ ๋ฐฉํ–ฅ(๋‘ ๋ฒˆ์งธ 1024, ํ…์„œ์˜ ํ–‰ ๋ฒกํ„ฐ), ์ฆ‰ 2 * max_sequence_length ๋ฅผ ์˜๋ฏธํ•˜๋Š” ์ฐจ์› ๋ฐฉํ–ฅ์˜ ์›์†Œ๋‹ค. ๋”ฐ๋ผ์„œ dim=-1์œผ๋กœ ์„ค์ •ํ•ด์ค€๋‹ค. ์ด์ œ ๋ฉ”์„œ๋“œ๊ฐ€ ์˜๋„๋Œ€๋กœ ์ ์šฉ๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•ด๋ณด์ž. rel_pos_matrix ์˜ 0๋ฒˆ ๋ฐฐ์น˜, 0๋ฒˆ์งธ ์‹œํ€€์Šค์˜ ๊ฐ€์žฅ ๋งˆ์ง€๋ง‰ ์ฐจ์›์˜ ๊ฐ’์€ 0์œผ๋กœ ์ดˆ๊ธฐํ™” ๋˜์–ด ์žˆ๋‹ค. ๋‹ค์‹œ ๋งํ•ด, ๋Œ€์ƒ ํ…์„œ์˜ ๋Œ€์ƒ ์ฐจ์›์—์„œ 0๋ฒˆ์งธ ์ธ๋ฑ์Šค์— ํ•ด๋‹นํ•˜๋Š” ๊ฐ’์„ ๊ฐ€์ ธ์˜ค๋ผ๋Š” ์˜๋ฏธ๋ฅผ ๋‹ด๊ณ  ์žˆ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด torch.gather ์‹คํ–‰ ๊ฒฐ๊ณผ๊ฐ€ tmp_c2p์˜ 0๋ฒˆ ๋ฐฐ์น˜, 0๋ฒˆ์งธ ์‹œํ€€์Šค์˜ 0๋ฒˆ์งธ ์ฐจ์› ๊ฐ’๊ณผ ์ผ์น˜ํ•˜๋Š”์ง€ ํ™•์ธํ•ด๋ณด์ž. ๋‘˜ ๋‹ค -2.6477, -2.6477 ์œผ๋กœ ๊ฐ™์€ ๊ฐ’์„ ๋‚˜ํƒ€๋‚ด๊ณ  ์žˆ๋‹ค. ๋”ฐ๋ผ์„œ ์šฐ๋ฆฌ ์˜๋„๋Œ€๋กœ ์ž˜ ์‹คํ–‰๋˜์—ˆ๋‹ค๋Š” ์‚ฌ์‹ค์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค.

๐Ÿ‘ฉโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆย torch.triu, torch.tril

๊ฐ๊ฐ ์ž…๋ ฅ ํ…์„œ๋ฅผ ์ƒ์‚ผ๊ฐํ–‰๋ ฌ, ํ•˜์‚ผ๊ฐํ–‰๋ ฌ๋กœ ๋งŒ๋“ ๋‹ค. triu๋‚˜ tril์€ ์‚ฌ์‹ค ๋’ค์ง‘์œผ๋ฉด ๊ฐ™์€ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ˜ํ™˜ํ•˜๊ธฐ ๋•Œ๋ฌธ์— tril์„ ๊ธฐ์ค€์œผ๋กœ ์„ค๋ช…์„ ํ•˜๊ฒ ๋‹ค. ๋ฉ”์„œ๋“œ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

# torch.triu, tril params
upper_tri_matrix = torch.triu(input_tensor, diagonal=0, *, out=None)
lower_tri_matrix = torch.tril(input_tensor, diagonal=0, *, out=None)

diagonal ์— ์ฃผ๋ชฉํ•ด๋ณด์ž. ์–‘์ˆ˜๋ฅผ ์ „๋‹ฌํ•˜๋ฉด ์ฃผ๋Œ€๊ฐ์„ฑ๋ถ„์—์„œ ํ•ด๋‹นํ•˜๋Š” ๊ฐ’๋งŒํผ ๋–จ์–ด์ง„ ๊ณณ์˜ ๋Œ€๊ฐ์„ฑ๋ถ„๊นŒ์ง€ ๊ทธ ๊ฐ’์„ ์‚ด๋ ค๋‘”๋‹ค. ํ•œํŽธ ์Œ์ˆ˜๋ฅผ ์ „๋‹ฌํ•˜๋ฉด ์ฃผ๋Œ€๊ฐ์„ฑ๋ถ„์„ ํฌํ•จํ•ด ์ฃผ์–ด์ง„ ๊ฐ’๋งŒํผ ๋–จ์–ด์ง„ ๊ณณ๊นŒ์ง€์˜ ๋Œ€๊ฐ์„ฑ๋ถ„์„ ๋ชจ๋‘ 0์œผ๋กœ ๋งŒ๋“ค์–ด๋ฒ„๋ฆฐ๋‹ค. ๊ธฐ๋ณธ์€ 0์œผ๋กœ ์„ค์ •๋˜์–ด ์žˆ์œผ๋ฉฐ, ์ด๋Š” ์ฃผ๋Œ€๊ฐ์„ฑ๋ถ„๋ถ€ํ„ฐ ์™ผ์ชฝ ํ•˜๋‹จ์˜ ์›์†Œ๋ฅผ ๋ชจ๋‘ ์‚ด๋ ค๋‘๊ฒ ๋‹ค๋Š” ์˜๋ฏธ๊ฐ€ ๋œ๋‹ค.

# torch.tril usage example
>>> lm_mask = torch.tril(torch.ones(x.shape[0], x.shape[-1], x.shape[-1]))
>>> lm_mask
1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
1 1 1 1 1

๋‘ ๋ฉ”์„œ๋“œ๋Š” ์„ ํ˜•๋Œ€์ˆ˜ํ•™์ด ํ•„์š”ํ•œ ๋‹ค์–‘ํ•œ ๋ถ„์•ผ์—์„œ ์‚ฌ์šฉ๋˜๋Š”๋ฐ, ํ•„์ž์˜ ๊ฒฝ์šฐ, GPT์ฒ˜๋Ÿผ Transformer์˜ Decoder ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ชจ๋ธ์„ ๋นŒ๋“œํ•  ๋•Œ ๊ฐ€์žฅ ๋งŽ์ด ์‚ฌ์šฉํ–ˆ๋˜ ๊ฒƒ ๊ฐ™๋‹ค. Decoder๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ชจ๋ธ์€ ๋Œ€๋ถ€๋ถ„ ๊ตฌ์กฐ์ƒ Language Modeling์„ ์œ„ํ•ด์„œ Masked Multi-Head Self-Attention Block์„ ์‚ฌ์šฉํ•˜๋Š”๋ฐ ์ด ๋•Œ ๋ฏธ๋ž˜ ์‹œ์ ์˜ ํ† ํฐ ์ž„๋ฒ ๋”ฉ ๊ฐ’์— ๋งˆ์Šคํ‚น์„ ํ•ด์ฃผ๊ธฐ ์œ„ํ•ด torch.tril ์„ ์‚ฌ์šฉํ•˜๊ฒŒ ๋˜๋‹ˆ ์ฐธ๊ณ ํ•˜์ž.

๐Ÿ‘ฉโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆย torch.Tensor.masked_fill

์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •ํ•œ ๊ฐ’์— ํ•ด๋‹น๋˜๋Š” ์›์†Œ๋ฅผ ๋ชจ๋‘ ๋งˆ์Šคํ‚น ์ฒ˜๋ฆฌํ•ด์ฃผ๋Š” ๋ฉ”์„œ๋“œ๋‹ค. ๋จผ์ € ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ํ™•์ธํ•ด๋ณด์ž.

# torch.Tensor.masked_fill params
input_tensors = torch.tensor([[1, 2, 3], [4, 5, 6]])
input_tensors.masked_fill(mask: BoolTensor, value: float)

masked_fill ์€ ํ…์„œ ๊ฐ์ฒด์˜ ๋‚ด๋ถ€ attribute ๋กœ ์ •์˜๋˜๊ธฐ ๋•Œ๋ฌธ์— ํ•ด๋‹น ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ๋จผ์ € ๋งˆ์Šคํ‚น ๋Œ€์ƒ ํ…์„œ๋ฅผ ๋งŒ๋“ค์–ด์•ผ ํ•œ๋‹ค. ํ…์„œ๋ฅผ ์ •์˜ํ–ˆ๋‹ค๋ฉด ํ…์„œ ๊ฐ์ฒด์˜ attributes ์ ‘๊ทผ์„ ํ†ตํ•ด masked_fill() ์„ ํ˜ธ์ถœํ•œ ๋’ค, ํ•„์š”ํ•œ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ „๋‹ฌํ•ด์ฃผ๋Š” ๋ฐฉ์‹์œผ๋กœ ์‚ฌ์šฉํ•˜๋ฉด ๋œ๋‹ค.

mask ๋งค๊ฐœ๋ณ€์ˆ˜์—๋Š” ๋งˆ์Šคํ‚น ํ…์„œ๋ฅผ ์ „๋‹ฌํ•ด์•ผ ํ•˜๋Š”๋ฐ, ์ด ๋•Œ ๋‚ด๋ถ€ ์›์†Œ๋Š” ๋ชจ๋‘ boolean์ด์–ด์•ผ ํ•˜๊ณ  ํ…์„œ์˜ ํ˜•ํƒœ๋Š” ๋Œ€์ƒ ํ…์„œ์™€ ๋™์ผํ•ด์•ผ ํ•œ๋‹ค(์™„์ „ํžˆ ๊ฐ™์„ ํ•„์š”๋Š” ์—†๊ณ , ๋ธŒ๋กœ๋“œ ์บ์ŠคํŒ…๋งŒ ๊ฐ€๋Šฅํ•˜๋ฉด ์ƒ๊ด€ ์—†์Œ).

value ๋งค๊ฐœ๋ณ€์ˆ˜์—๋Š” ๋งˆ์Šคํ‚น ๋Œ€์ƒ ์›์†Œ๋“ค์— ์ผ๊ด„์ ์œผ๋กœ ์ ์šฉํ•ด์ฃผ๊ณ  ์‹ถ์€ ๊ฐ’์„ ์ „๋‹ฌํ•œ๋‹ค. ์ด๊ฒŒ ๋ง๋กœ๋งŒ ๋“ค์œผ๋ฉด ์ดํ•ดํ•˜๊ธฐ ์‰ฝ์ง€ ์•Š๋‹ค. ์•„๋ž˜ ์‚ฌ์šฉ ์˜ˆ์‹œ๋ฅผ ํ•จ๊ป˜ ์ฒจ๋ถ€ํ–ˆ์œผ๋‹ˆ ์ฐธ๊ณ  ๋ฐ”๋ž€๋‹ค.

# torch.masked_fill usage

>>> lm_mask = torch.tril(torch.ones(x.shape[0], x.shape[-1], x.shape[-1]))
>>> lm_mask
1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
1 1 1 1 1
>>> attention_matrix = torch.matmul(q, k.transpose(-1, -2)) / dot_scale
>>> attention_matrix
1.22 2.1 3.4 1.2 1.1
1.22 2.1 3.4 9.9 9.9
1.22 2.1 3.4 9.9 9.9
1.22 2.1 3.4 9.9 9.9
1.22 2.1 3.4 9.9 9.9

>>> attention_matrix = attention_matrix.masked_fill(lm_mask == 0, float('-inf'))
>>> attention_matrix
1.22 -inf -inf -inf -inf
1.22 2.1 -inf -inf -inf
1.22 2.1 3.4 -inf -inf
1.22 2.1 3.4 9.9 -inf
1.22 2.1 3.4 9.9 9.9
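The reason for filling with float('-inf') rather than 0 is that the attention matrix is fed to a softmax right afterwards: exp(-inf) equals 0, so masked positions get exactly zero attention weight. A minimal check (my own snippet):

# -inf entries vanish after softmax
>>> torch.softmax(torch.tensor([1.22, float('-inf'), float('-inf')]), dim=-1)
tensor([1., 0., 0.])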

๐Ÿ—‚๏ธย torch.clone

inputs ์ธ์ž๋กœ ์ „๋‹ฌํ•œ ํ…์„œ๋ฅผ ๋ณต์‚ฌํ•˜๋Š” ํŒŒ์ดํ† ์น˜ ๋‚ด์žฅ ๋ฉ”์„œ๋“œ๋‹ค. ์‚ฌ์šฉ๋ฒ•์€ ์•„๋ž˜์™€ ๊ฐ™๋‹ค.

""" torch.clone """
torch.clone(
    input,ย 
    *,
   ย memory_format=torch.preserve_format
)ย โ†’ย [Tensor]

๋”ฅ๋Ÿฌ๋‹ ํŒŒ์ดํ”„๋ผ์ธ์„ ๋งŒ๋“ค๋‹ค ๋ณด๋ฉด ๋งŽ์ด ์‚ฌ์šฉํ•˜๊ฒŒ ๋˜๋Š” ๊ธฐ๋ณธ์ ์ธ ๋ฉ”์„œ๋“œ์ธ๋ฐ, ์ด๋ ‡๊ฒŒ ๋”ฐ๋กœ ์ •๋ฆฌํ•˜๊ฒŒ ๋œ ์ด์œ ๊ฐ€ ์žˆ๋‹ค. ์ž…๋ ฅ๋œ ํ…์„œ๋ฅผ ๊ทธ๋Œ€๋กœ ๋ณต์‚ฌํ•œ๋‹ค๋Š” ํŠน์„ฑ ๋•Œ๋ฌธ์— ์‚ฌ์šฉ์‹œ ์ฃผ์˜ํ•ด์•ผ ํ•  ์ ์ด ์žˆ๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ํ•ด๋‹น ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ์ „์— ๋ฐ˜๋“œ์‹œ ์ž…๋ ฅํ•  ํ…์„œ๊ฐ€ ํ˜„์žฌ ์–ด๋Š ๋””๋ฐ”์ด์Šค(CPU, GPU) ์œ„์— ์žˆ๋Š”์ง€, ๊ทธ๋ฆฌ๊ณ  ํ•ด๋‹น ํ…์„œ๊ฐ€ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋Š”์ง€๋ฅผ ๋ฐ˜๋“œ์‹œ ํŒŒ์•…ํ•ด์•ผ ํ•œ๋‹ค.

ํ•„์ž๋Š” ELECTRA ๋ชจ๋ธ์„ ์ง์ ‘ ๊ตฌํ˜„ํ•˜๋Š” ๊ณผ์ •์—์„œ clone() ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ–ˆ๋Š”๋ฐ, Generator ๋ชจ๋ธ์˜ ๊ฒฐ๊ณผ ๋กœ์ง“์„ Discriminator์˜ ์ž…๋ ฅ์œผ๋กœ ๋ณ€ํ™˜ํ•ด์ฃผ๊ธฐ ์œ„ํ•จ์ด์—ˆ๋‹ค. ๊ทธ ๊ณผ์ •์—์„œ Generator๊ฐ€ ๋ฐ˜ํ™˜ํ•œ ๋กœ์ง“์„ ๊ทธ๋Œ€๋กœ cloneํ•œ ๋’ค, ์ž…๋ ฅ์„ ๋งŒ๋“ค์–ด ์ฃผ์—ˆ๊ณ  ๊ทธ ๊ฒฐ๊ณผ ์•„๋ž˜์™€ ๊ฐ™์€ ์—๋Ÿฌ๋ฅผ ๋งˆ์ฃผํ–ˆ๋‹ค.

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:
[torch.cuda.LongTensor [8, 511]] is at version 1; expected version 0 instead.
Hint: the backtrace further above shows the operation that failed to compute its gradient.
The variable in question was changed in there or anywhere later. Good luck!

์—๋Ÿฌ ๋กœ๊ทธ๋ฅผ ์ž์„ธํžˆ ์ฝ์–ด๋ณด๋ฉด ํ…์„œ ๋ฒ„์ „์˜ ๋ณ€๊ฒฝ์œผ๋กœ ์ธํ•ด ๊ทธ๋ผ๋””์–ธํŠธ ๊ณ„์‚ฐ์ด ๋ถˆ๊ฐ€ํ•˜๋‹ค๋Š” ๋‚ด์šฉ์ด ๋‹ด๊ฒจ์žˆ๋‹ค. ๊ตฌ๊ธ€๋งํ•ด๋ด๋„ ์ž˜ ์•ˆ๋‚˜์™€์„œ ํฌ๊ธฐํ•˜๋ ค๋˜ ์ฐฐ๋ผ์— ์šฐ์—ฐํžˆ torch.clone() ๋ฉ”์„œ๋“œ์˜ ์ •ํ™•ํ•œ ์‚ฌ์šฉ๋ฒ•์ด ๊ถ๊ธˆํ•ด ๊ณต์‹ Docs๋ฅผ ์ฝ๊ฒŒ ๋˜์—ˆ๊ณ , ๊ฑฐ๊ธฐ์„œ ์—„์ฒญ๋‚œ ์‚ฌ์‹ค์„ ๋ฐœ๊ฒฌํ–ˆ๋‹ค. clone() ๋ฉ”์„œ๋“œ๊ฐ€ ์ž…๋ ฅ๋œ ํ…์„œ์˜ ํ˜„์žฌ ๋””๋ฐ”์ด์Šค ์œ„์น˜์— ๋˜‘๊ฐ™์ด ๋ณต์‚ฌ๋  ๊ฒƒ์ด๋ž€ ์˜ˆ์ƒ์€ ํ–ˆ์ง€๋งŒ, ์ž…๋ ฅ ํ…์„œ์˜ ๊ณ„์‚ฐ๊ทธ๋ž˜ํ”„๊นŒ์ง€ ๋ณต์‚ฌ๋  ๊ฒƒ์ด๋ž€ ์ƒ๊ฐ์€ ์ „ํ˜€ ํ•˜์ง€ ๋ชปํ–ˆ๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ๊ทธ๋ž˜์„œ ์œ„์™€ ๊ฐ™์€ ์—๋Ÿฌ๋ฅผ ๋งˆ์ฃผํ•˜์ง€ ์•Š์œผ๋ ค๋ฉด, clone()์„ ํ˜ธ์ถœํ•  ๋•Œ ๋’ค์— ๋ฐ˜๋“œ์‹œ detach()๋ฅผ ํ•จ๊ป˜ ํ˜ธ์ถœํ•ด์ค˜์•ผ ํ•œ๋‹ค.

clone() ๋ฉ”์„œ๋“œ๋Š” ์ž…๋ ฅ๋œ ํ…์„œ์˜ ๋ชจ๋“  ๊ฒƒ์„ ๋ณต์‚ฌํ•œ๋‹ค๋Š” ์ ์„ ๋ฐ˜๋“œ์‹œ ๊ธฐ์–ตํ•˜์ž.
