
😵 nn.Embedding dimension ≠ actual data input dimension

This error occurs when the input/output dimensions defined in torch.nn.Embedding differ from the dimensions of the actual data. You can run into it in many situations, but in my case it usually appeared after adding a special token to a pretrained tokenizer loaded from Huggingface: I would forget that a token had been added and never update the input/output dimensions defined for nn.Embedding.
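Before the actual setup, here is a minimal sketch of how the mismatch surfaces (the vocabulary size and token ids below are made up for illustration): once the tokenizer's vocabulary grows past the num_embeddings the layer was built with, the new token id falls outside the embedding table and the forward pass fails.

import torch
import torch.nn as nn

# hypothetical sizes: the table was built for a 10-token vocabulary,
# but the tokenizer now emits id 10 after a new special token was added
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
token_ids = torch.tensor([3, 7, 10])  # 10 is out of range for this table

try:
    embedding(token_ids)
except IndexError as err:
    print(err)  # on CPU this is typically "index out of range in self"

The actual setup where the error showed up looks like this: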

from transformers import AutoTokenizer, AutoConfig, AutoModel

class CFG:
    model_name = 'microsoft/deberta-v3-large'
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, config=config)
    checkpoint_dir = './checkpoint'  # placeholder path; the original CFG defines the save directory elsewhere


def add_markdown_token(cfg: CFG) -> None:
    """
    Add MarkDown token to pretrained tokenizer ('[MD]')
    Args:
        cfg: CFG, needed to load tokenizer from Huggingface AutoTokenizer
    """
    markdown_token = '[MD]'
    special_tokens_dict = {'additional_special_tokens': [f'{markdown_token}']}
    cfg.tokenizer.add_special_tokens(special_tokens_dict)
    markdown_token_id = cfg.tokenizer(f'{markdown_token}', add_special_tokens=False)['input_ids'][0]

    setattr(cfg.tokenizer, 'markdown_token', f'{markdown_token}')
    setattr(cfg.tokenizer, 'markdown_token_id', markdown_token_id)
    cfg.tokenizer.save_pretrained(f'{cfg.checkpoint_dir}/tokenizer/')


add_markdown_token(CFG)
CFG.model.resize_token_embeddings(len(CFG.tokenizer))  # grow nn.Embedding to the new vocab size

Googling turns up several ways to fix this, but the simplest is to make the input/output dimensions defined for torch.nn.Embedding match the actual data dimensions. If, like me, the error appears after adding a special token, just measure the length of the tokenizer again once the new token has been added and pass that value to the model's resize_token_embeddings method so that nn.Embedding is updated.
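As a quick sanity check (a minimal sketch, reusing the CFG object from above), you can compare the size of the embedding table against the tokenizer length before training:

# after resize_token_embeddings, both numbers should agree
vocab_size = len(CFG.tokenizer)
embedding_rows = CFG.model.get_input_embeddings().num_embeddings
assert embedding_rows == vocab_size, f'{embedding_rows} != {vocab_size}'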
