๐ช [LoRA] Low-Rank Adaptation of Large Language Models
๐ญย Overview
LoRA๋ 2021๋
MS ์ฐ๊ตฌ์ง์ด ๋ฐํํ ๋
ผ๋ฌธ์ผ๋ก ์๋ณธ(Full ํ์ธํ๋)๊ณผ ๊ฑฐ์ ์ ์ฌํ ์ฑ๋ฅ(์ฌ์ง์ด ์ผ๋ถ ๋ฒค์น๋งํฌ๋ ๋ ๋์)์ผ๋ก LLM ํ์ธํ๋์ ํ์ํ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ๊ธฐ์ ์ผ๋ก ์ค์ด๋๋ฐ ์ฑ๊ณตํด ์ฃผ๋ชฉ์ ๋ฐ์๋ค. ์ปค๋ฎค๋ํฐ์์ LoRA is All You Need
๋ผ๋ ๋ณ๋ช
๊น์ง ์ป์ผ๋ฉฐ ๊ทธ ์ธ๊ธฐ๋ฅผ ๊ตฌ๊ฐํ๊ณ ์๋ค.
DistilBERT
๋ฆฌ๋ทฐ์์๋ ์ดํด๋ณด์๋ฏ, BERT์ GPT์ ๋ฑ์ฅ ์ดํ, ๋ชจ๋ NLP ๋๋ฉ์ธ์์ ๋น์ฝ์ ์ธ ์ฑ๋ฅ ํฅ์์ด ์ด๋ค์คฌ์์๋ ๋ถ๊ตฌํ๊ณ , NLP์ฉ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ค์ํ์ ํ์ฉํ๊ธฐ์๋ ๋๋ฌด ํฐ ๋ฆฌ์์ค ์๊ตฌ๋๊ณผ ๋ ์ดํด์๊ฐ ๋ฐ๋ชฉ์ ์ก์๋ค. ํ์ง๋ง LoRA
๋ฐํ ์ดํ, ํ์ธํ๋ ์์ ์ ํ๋ จํด์ผ ํ๋ ํ๋ผ๋ฏธํฐ ์๊ฐ ํ์ ํ ์ค์ด๋ค๋ฉด์ ๋ชจ๋ธ์ ์ฒดํฌํฌ์ธํธ ์ฉ๋์ด ๊ธฐํ๊ธ์์ ์ผ๋ก ๊ฐ์ํ๋ค. ๋๋ถ์ ์๊ตฌ GPU VRAM์ด ํ์ ํ ๋ฎ์์ ธ, ๋ฆฌ์์ค ์ ํ ๋๋ฌธ์ ์๋นํ์ง ๋ชปํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ด ์ฌ๋ผ์ก๋ค. ๊ทธ๋์ ์ค๋๋ Mixed Precision
, Quantization
๊ณผ ํจ๊ป ๋ชจ๋ธ ๊ฒฝ๋โข์ต์ ํ ๋ถ์ผ์์ ๊ฐ์ฅ ์ค์ํ ์ฃผ์ ๋ก ๋ ์ค๋ฅด๊ณ ์๋ค.
๋ด์ฉ์ ์ดํด๋ณด๊ธฐ์ , LoRA
๋ ์ด๋ฏธ ์ฌ์ ํ์ต์ ์๋ฃํ ๋ชจ๋ธ์ ํ์ธํ๋ํ ๋ ์ฌ์ฉํด์ผํจ์ ๋ค์ ํ ๋ฒ ๋ช
์ฌํ์. ์ด๋ฒ ํฌ์คํ
์์๋ ๋๊ฐ์ง๋ฅผ ์ง์ค์ ์ผ๋ก ๋ค๋ฃฐ ๊ฒ์ด๋ค.
1) ๋ชจ๋ธ ํฌ๊ธฐ ์ค์ธ ๋ฐฉ๋ฒ, 2) ํฌ๊ธฐ๋ฅผ ์ค์ด๋ฉด์๋ ๋น์ทํ ์ฑ๋ฅ์ ๋ผ ์ ์์๋ ์ด์
๐คย Concept: Low-Rank Adaptation
\[h = W_0x + \Delta Wx = W_0x + BAx\]
์์ด๋์ด๋ ์๋นํ ๊ฐ๋จํ๋ค. ์ฌ์ ํ์ต์ ๋ง์น๊ณ ์๋ ด๋ ์ํ์ ๊ฐ์ค์น ํ๋ ฌ์ ์๋ฏธํ๋ $W_0$๊ณผ ์๋ก์ด ๊ฐ์ค์น ํ๋ ฌ $\Delta W$์ ๋ชจ๋ ์ ๋ ฅ์ ํต๊ณผ์ํจ๋ค. ๊ทธ๋ฆฌ๊ณ ๋์จ ๊ฒฐ๊ณผ๋ฅผ ๋ํด ๋ค์์ธต์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ค. ์คํ๋ ค ์๋ก์ด ๊ฐ์ค์น ํ๋ ฌ์ ์ถ๊ฐํด ํ์ธํ๋์ ํ๋๋ฐ ์ด๋ป๊ฒ ํ๋ จํด์ผ ํ๋ ํ๋ผ๋ฏธํฐ ์๋ฅผ ์ค์ผ ์ ์์์๊น??
๊ทธ ๋น๋ฐ์ Freeze(Stop Gradient, require_grad=False)
์ Matrix Factorization
์ ์จ์ด ์๋ค. ๋จผ์ ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น ํ๋ ฌ์ Freeze(Stop Gradient, require_grad=False)
๋ฅผ ์ ์ฉํด ๊ทธ๋ผ๋์ธํธ๊ฐ ํ๋ฅด์ง ์๋๋ก ํ๋ค. ์ด๋ ๊ฒ ํ๋ฉด ํ์ธํ๋ ๊ณผ์ ์์ ๊ฐ์ค์น๊ฐ ์
๋ฐ์ดํธ ๋์ง ์์ ์ฌ์ ํ์ต์์ ์ต๋ํ ์ง์์ ์ ์งํ ์ ์์ ๋ฟ๋ง ์๋๋ผ, ํ์ต์ ์ํด ๊ทธ๋ผ๋์ธํธ๋ฅผ ์ ์ฅํ ํ์๊ฐ ์์ด์ ธ ํ์ธํ๋ ๋ ํ์ํ GPU VRAM์ ํ๊ธฐ์ ์ผ๋ก ์ค์ผ ์ ์๋ค.
์ฒ์์ ์ฌ์ ํ์ต ๊ฐ์ค์น๋ฅผ ํต๊ณผํ ๊ฐ๊ณผ ์๋ก์ด ๊ฐ์ค์น ํ๋ ฌ $\Delta W$๋ฅผ ํต๊ณผํ ๊ฐ์ ์๋ก ๋ํ๋ค๊ณ ์ธ๊ธํ๋ค. ๊ทธ๋ ๋ค๋ฉด, ๋ ๊ฒฐ๊ณผ ํ๋ ฌ์ ํ๋ ฌ ํฌ๊ธฐ๊ฐ ๋์ผํด์ผ ํ๋ค๋ ๊ฒ์ด๋ค. ์ด๋ป๊ฒ ๊ธฐ์กด๋ณด๋ค ์ฌ์ด์ฆ๋ ์ค์ด๋ฉด์ ๊ฒฐ๊ณผ ํ๋ ฌ์ ํฌ๊ธฐ๋ ๋์ผํ๊ฒ ๋ง๋ค์ด์ค ์ ์์๊น?? ๋ฐ๋ก Low Rank value $r$์ ๋์ ํด Matrix Factorization ์ ํ๋ค.
\[W_{d \times d} = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,d} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,d} \\ \vdots & \vdots & \ddots & \vdots \\ w_{d,1} & w_{d,2} & \cdots & w_{d,d} \end{bmatrix}\]ํ๋ ฌ ๊ณฑ์
(matrix multiplication)์ ๋ค์ ํ ๋ฒ ์๊ธฐํด๋ณด์. MxN
์ ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ํ๋ ฌ์ NxK
์ ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ํ๋ ฌ์ ๊ณฑํด์ฃผ๋ฉด MxK
์ ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ํ๋ ฌ์ ๋ง๋ค์ด์ค ์ ์๋ค. ๋ง์ฐฌ๊ฐ์ง๋ค. dxd
ํฌ๊ธฐ์ธ ์ฌ์ ํ์ต์ ๊ฐ์ค์น ํ๋ ฌ $W_{d \times d}$๊ณผ ํฌ๊ธฐ๋ฅผ ๋ง์ถ๊ธฐ ์ํด, dxd
์ง๋ฆฌ ํ๋ ฌ์ ๊ฐ๊ฐ dxr
, rxd
์ ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ๋ ํ๋ ฌ $B, A$๋ก ๋ถํดํ๋ค. ์ด ๋, ํ๋ ฌ $B$์ ์ด์ฐจ์๊ณผ ํ๋ ฌ $A$์ฐจ์์ ํ์ฐจ์ ํฌ๊ธฐ๋ฅผ ํํํ๋ $r$์ ๋ฐ๋ก Low Rank value $r$์ ๋์
ํ๋ฉด ๋๋ค.
$r=3$์ด๋ผ๊ณ ๊ฐ์ ํ๊ณ 768x768
์ง๋ฆฌ ๊ธฐ์กด ๊ฐ์ค์น ํ๋ ฌ $W$๊ณผ 768x3
, 3x768
์ ํฌ๊ธฐ๋ฅผ ๊ฐ๋ $\Delta W = BA$์ ํ๋ผ๋ฏธํฐ ๊ฐ์๋ฅผ ๋น๊ตํด๋ณด์. ๊ณ์ฐํด๋ณด๋ฉด ์ ์๋ 589,824
๊ฐ, ํ์๋ 4608
๊ฐ๊ฐ ๋๋ค. ์ ํํ๊ฒ 128
๋ฐฐ ์ฐจ์ด๊ฐ ๋๋ค. ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ์์๋ ํ๋ ฌ $W$๊ณผ ๊ฐ์ ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ๊ฐ์ค์น ํ๋ ฌ์ด ๋จ์ผ ์ธ์ฝ๋ ๋ด๋ถ, ํ๋์ ์ดํ
์
๋ ์ด์ด๋ง ํด๋ 4๊ฐ($W_q, W_k, W_v, W_o$)๊ฐ ์๋ค. BERT-base
๋ชจ๋ธ์ ๊ธฐ์ค์ผ๋ก ๋ณด๋ฉด, ํด๋น ๋ชจ๋ธ์ด 12
๊ฐ์ ์ธ์ฝ๋๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋๊น ์ด 48
๊ฐ์ ๊ฐ์ค์น ํ๋ ฌ์ด ์๊ณ , ์ด๋ฆผ์ก์๋ 48*128
๋ฐฐ์ ํ์ต ํ๋ผ๋ฏธํฐ ๊ฐ์ ํจ๊ณผ๋ฅผ ๋ผ ์ ์๋ค. ๋ชจ๋ธ์ ๋ ์ด์ด๊ฐ ๋ง์ผ๋ฉด ๋ง์์๋ก ๋ ์ข์ ํจ์จ์ ๋ณด์ธ๋ค.
์ ๊ทธ๋ฆผ์ ํ์ดํ ์น ๊ณต์ ๋ธ๋ก๊ทธ์์ ํผ์จ ์๋ฃ๋ก, ํ์ต ๋ ResNet50์ GPU VRAM ์ ์ ์จ ์ถ์ด๋ ๋ฌผ๋ก ๋ชจ๋ธ์ ๊ฐ๋ณ ๊ตฌ์ฑ์์์ ๋ฉ๋ชจ๋ฆฌ ๋น์จ๊น์ง ์์ธํ ๋ณด์ฌ์ค๋ค. ๋จผ์ Parameter
์ Optimizer State
๋ฅผ ๋ณด์. Parameter
๋ ๋ชจ๋ธ์์ ํ๋ จ์ ํตํด ์
๋ฐ์ดํธ๊ฐ ํ์ํ ๋ชจ๋ ๊ตฌ์ฑ ์์๋ฅผ ๋งํ๋ค. Freeze
, require_grad=False
@torch.no_grad()
, torch.register_buffer()
์ ์ํฅ์ ๋ฐ์ง ์์ ๋ชจ๋ธ ๋ด๋ถ์ ๋ชจ๋ ํ
์๋ผ๊ณ ๋ณด๋ฉด ๋๋ค.
ํํธ, Optimizer State
๋ ์ตํฐ๋ง์ด์ ์ ์ต์ ํ ์ํ์ ํ์ํ ๋ชจ๋ ์ ๋ณด๋ค์ ์๋ฏธํ๋๋ฐ, ์๋ฅผ ๋ค์ด ์
๋ฐ์ดํธ ๋ ํ
์์ ๋ฉํ ์ ๋ณด, ์ฌ๋ฌ ํ์ดํผํ๋ผ๋ฏธํฐ ๊ฐ ๊ฐ์ ๊ฒ๋ค์ด ๋ด๊ฒจ ์๋ค.
์ด ๋ ์์๊ฐ ๋ชจ๋ธ์ GPU VRAM์ ์ฐจ์งํ๋ ๋น์จ์ด ์๋นํ ํฌ๋ค. ํ์ง๋ง ๋ ์์ ๋ชจ๋ ํ๋ผ๋ฏธํฐ ๊ฐ์์ ๋น๋กํ๋ฏ๋ก LoRA
์ ์ฉ์ผ๋ก ํ๋ผ๋ฏธํฐ ๊ฐ์๋ฅผ ์ค์ด๋ฉด, GPU VRAM์ ํ๊ธฐ์ ์ผ๋ก ์ค์ผ ์ ์๋ค.
๋ํ ํ์ดํ ์น๋ ์ญ์ ํ ์ํ์ ์ํด ๊ทธ๋ผ๋์ธํธ๋ฅผ ํ๋ผ๋ฏธํฐ์ ๋์ผํ ๋ชจ์(shape)์ ๊ฐ๋ ํ ์๋ก ์ ์ฅ๋๋ค๋ ์ ์ ๊ฐ์ํ๋ฉด, ๊ธฐ์กด์ Full-Rank ํ ์ ๋์ Low-Rank ํ ์๋ฅผ ํ์ต์ ์ด์ฉํจ์ผ๋ก์ ๊ทธ๋ผ๋์ธํธ ํ ์์ ํฌ๊ธฐ ์ญ์ ํ๊ธฐ์ ์ผ๋ก ์ค์ผ ์ ์๊ฒ ๋ค.
ํธ๋์คํฌ๋จธ ๊ณ์ด์ ๋ชจ๋ธ๋ค์ด ResNet ๋๋น ์๋์ ์ผ๋ก ํ๋ผ๋ฏธํฐ ๊ฐ์๊ฐ ๋ง๊ธฐ ๋๋ฌธ์ LoRA
๋ฅผ ์ ์ฉํ๋ค๋ฉด ํจ์ฌ ํฐ ํจ๊ณผ๋ฅผ ๋ณผ ์ ์์ ๊ฒ์ด๋ค.
์ผ์ชฝ ๊ทธ๋ฆผ์ ๋
ผ๋ฌธ์์ ์ ์ํ, BERT
๊ณ์ด์ LM
์ LoRA
๋ฅผ ์ ์ฉํ ํ์ธํ๋ ๊ฒฐ๊ณผ๋ค. ํ์ FT
๊ฐ ์ผ๋ฐ์ ์ธ ํ์ธํ๋ ๋ฐฉ๋ฒ์ ์ํด ๋์จ ๊ฒฐ๊ณผ๋ค. ์์น๋ฝ๋ค์น๋ฝํ๋ฉด์ ๊ฑฐ์ ๋น์ทํ ์์์ ๋ณด์ธ๋ค. ๋ฒค์น๋งํฌ ํ๊ท ์ฑ๋ฅ์ LoRA
๊ฐ ๋ ๋๋ค. ์๋ง, ์ ๋นํ ์ฑ๋ฅ ์ฐจ์ด๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ ์ํด ์ทจ์ฌ์ ํ๋ ๋ฒค์น๋งํฌ์ผ ๊ฐ๋ฅ์ฑ์ด ๋์ง๋ง, ๊ทธ๋๋ ์๋นํ ์ ์๋ฏธํ ๊ฒฐ๊ณผ๋ผ๊ณ ์๊ฐํ๋ค. ์ฐ์ธก์ GPT2
์ LoRA
๋ฅผ ์ ์ฉํ ๊ฒฐ๊ณผ๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก, ์๋น์ทํ ์ฑ๋ฅ ์ถ์ด๋ฅผ ๋ณด์ฌ์ค๋ค.
์ง๊ธ๊น์ง LoRA
๊ฐ ์ ์ํ๋ ๋ฐฉ๋ฒ๋ก ์ด ์ด๋ป๊ฒ ํ๊ธฐ์ ์ผ๋ก ํ์ต ํ๋ผ๋ฏธํฐ๋ฅผ ์ค์ด๊ณ ๋์๊ฐ ๋ชจ๋ธ์ด ์ฐจ์งํ๋ GPU VRAM
ํฌ๊ธฐ๋ฅผ ๊ฐ์์์ผฐ๋์ง ์์ ๋ณด์๋ค. ์ด์ LoRA
๋ฅผ ์ ์ฉํด๋ ์ผ๋ฐ์ ์ธ ํ์ธํ๋ ๋ฐฉ๋ฒ๊ณผ ๋น์ทํ ์ฑ๋ฅ์ ์ ์งํ ์ ์์๋์ง ๊ทธ ๊ฒฐ๊ณผ์ ๋ํด ํด์ํด๋ณด์. ๋
ผ๋ฌธ์ Chapter 7. UNDERSTANDIGN THE LOW-RANK UPDATES
๋ด์ฉ์ ํด๋น๋๋ค. ํด๋น ํํธ๋ 3๊ฐ์ง ์ธ์ฌ์ดํธ๋ฅผ ์ ์ํ๋ค.
๐กย Inisght 1. Apply to LoRA (Wq, Wv) or (Wq, Wk, Wv, Wo)
ํ์๋ ๋
ผ๋ฌธ์ ์ฝ๋ ๋ด๋ด, โ๊ทธ๋์ ์ด๋ค ๊ฐ์ค์น ํ๋ ฌ์ ์ ์ฉํด์ผ ํ ๊น?? ๋ชจ๋ ๊ฐ์ค์น ํ๋ ฌ์ ์ ์ฉํด๋ ๋๋๊ฑธ๊น??โ
ํ๋ ์๋ฌธ์ ๊ฐ๊ณ ์์๋ค. ๊ทผ๋ฐ ๋ง์นจ ์ ์๋ค์ด ์ด๋ฌํ ์๋ฌธ๋ค์ ์์ํ ๋ฏ, ์์ ๊ฐ์ด ์ ์ฉ ๊ฐ์ค์น ํ๋ ฌ์ ๋ฐ๋ฅธ ๋ฒค์น๋งํฌ ์ฑ๋ฅ ๊ฒฐ๊ณผ๋ฅผ ํ๋ก ์ ๋ฆฌํด์ฃผ์๋ค. ๋ชจ๋ธ์ GPT3
์ ์ฌ์ฉํ๋ค๊ณ ๋
ผ๋ฌธ์์ ๋ฐํ๊ณ ์๋ค.
๋ณด์ด๋ ๊ฒ๊ณผ ๊ฐ์ด, ($W_q, W_v$) ํน์ ($W_q, W_k, W_v, W_o$)์ LoRA
๋ฅผ ์ ์ฉํ๋๊ฒ ๊ฐ์ฅ ์ข์ ๋ฒค์น๋งํฌ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋ค. ์ฃผ๋ชฉํ ์ ์ ๋ญํฌ๊ฐ ๊ฐ์ฅ ๋ฎ์ผ๋ฉด์, ๊ฐ์ฅ ๋ง์ ๊ฐ์ค์น ํ๋ ฌ์ LoRA
๋ฅผ ์ ์ฉํ๋๊ฒ ๊ฐ์ฅ ์ฑ๋ฅ์ด ์ข๋ค๋ ๊ฒ์ด๋ค. ์คํ๊ฒฐ๊ณผ ์ ์ ์ด์ธ์ ๋ค๋ฅธ ์ฆ๋ช
์ด๋ ์ธ์ฌ์ดํธ ์ ์๊ฐ ์๋๊ฒ ์์ฝ์ง๋ง, ์ด๋ฅผ ํตํด ๋ค์๊ณผ ๊ฐ์ ์ฌ์ค๋ค์ ๋ ์ฌ๋ ค ๋ณด์๋ค.
- 1)
FT
๋ฌธ์ ํด๊ฒฐ์ ํ์ํ ๋ฌธ๋งฅ ์ ๋ณด๋ค์ด์ฟผ๋ฆฌ
,ํค
,๋ฒจ๋ฅ
ํ๋ ฌ์ ์ ์ ํ ๋ถ์ฐ- ์ธ๊ฐ์ง ๊ฐ์ค์น ํ๋ ฌ์ด ๋ชจ๋ ์ ์๋ฏธํ ๋ฌธ๋งฅ ํํ์ ํ์ต
- 2) ๋ฎ์ ๋ญํฌ๋ก๋ ์ถฉ๋ถํ,
FT
์ ํ์ํ ์๋ฒ ๋ฉ ์ถ์ถ ๊ฐ๋ฅ- ๊ทธ๋งํผ, ์ฌ์ ํ์ต์์ ํฌ์ฐฉํ ์ ์๋ ์๋ฒ ๋ฉ์ด ํ๋ถํ๋ฉฐ ์ผ๋ฐํ ๋ฅ๋ ฅ์ด ์ข๋ค๊ณ ํ๋จํ ์ ์์
- ์ฌ์ ํ์ต ๋จ๊ณ์์ ์ต๋ํ ๊น๊ฒ ๋ง์ด ํ์ต์ํฌ์๋ก FT ๋จ๊ณ๊ฐ ๊ฐ์ํ ๋ ์ ์์ง ์์๊น??
- ๋ค๋ง, ์ฌ์ ํ์ต๊ณผ ํ์ธํ๋ ์ฌ์ด์ ๊ดด๋ฆฌ๊ฐ ํฐ ๊ฒฝ์ฐ๋ผ๋ฉด??
- ์ฌ์ ํ์ต์ ์์ด๋ก, ํ์ธํ๋์ ํ๊ตญ์ด ๋ฐ์ดํฐ ์ธํธ๋ก ํ๋ ๊ฒฝ์ฐ๋ผ๋ฉด??
- ๊ทธ๋งํผ, ์ฌ์ ํ์ต์์ ํฌ์ฐฉํ ์ ์๋ ์๋ฒ ๋ฉ์ด ํ๋ถํ๋ฉฐ ์ผ๋ฐํ ๋ฅ๋ ฅ์ด ์ข๋ค๊ณ ํ๋จํ ์ ์์
๋ง์นจ ์ฃผ์์ However, we do not expect a small r to work for every task or dataset. Consider the following thought experiment: if the downstream task were in a different language than the one used for pre-training, retraining the entire model
์ด๋ผ๋ ์ธ๊ธ์ด ์๋ ๊ฒ์ผ๋ก ๋ณด์, ๋ญํฌ ๊ฐ์ ๋๋๋ก ๋ฎ์ ๊ฐ์ ์ ์ ํ๋, ์ฌ์ ํ์ต๊ณผ ํ์ธ ํ๋์ ๊ดด๋ฆฌ๊ฐ ์ฌํ๋ค๊ณ ํ๋จ๋๋ ๊ฒฝ์ฐ, ๋์ ๋ญํฌ๊ฐ๊ณผ ์คํ ๊ฒฐ๊ณผ ๋น๊ต๋ฅผ ํตํด ์ ์ ํ ๊ฐ์ ์ ์ ํด์ผ๊ฒ ๋ค.
๐กย Inisght 2. ๋ฎ์ ๋ญํฌ๋ก๋ ์ถฉ๋ถ
๋ฎ์ ๋ญํฌ๋ก๋ ์ถฉ๋ถํ, FT
์ ํ์ํ ์๋ฒ ๋ฉ ์ถ์ถ ๊ฐ๋ฅํ๋ค๋ ๊ฒ์ ์ข ๋ ๊ตฌ์ฒด์ ์ธ ์คํ์ผ๋ก ์ฆ๋ช
ํ๊ณ ์๋ค. ๊ทธ๋ํ $y$์ถ๊ณผ $x$์ถ์ ๊ฐ๊ฐ $A_r = 8$, $A_r = 64$์ธ (ํ
์ ๋ชจ์ [r, dim_model]
) ๊ฐ์ค์น ํ๋ ฌ์ SVD
ํ์ฌ ์ป์ right-singular matrix
์์ top-i(1 โค i โค 8)
, top-j (1 โค j โค 64)
๊ฐ์ ํน์ด๊ฐ์ ์ถ์ถํ ๋ค, Grassmann Distance
๋ฅผ ๊ฑฐ๋ฆฌ ๋งคํธ๋ฆญ์ผ๋ก ์ด์ฉํด ๋ถ๋ถ ๊ณต๊ฐ ์ฌ์ด์ ์ ์ฌ๋๋ฅผ ์ธก์ ํ ๊ฒฐ๊ณผ๋ค.
ํํธ, ์ right-singular matrix
์ผ๊น ๋ค์ ํ ๋ฒ ์๊ฐํด๋ดค๋ค. ์ ์ฒด ๋ฒกํฐ ๊ณต๊ฐ์์ top-i(1 โค i โค 8)
, top-j (1 โค j โค 64)
๊ฐ์ ํน์ด๊ฐ์ ๋ฝ์๋ด ์๋ก ๋น๊ตํ๋ ค๋ฉด, ๋ ํ๋ ฌ $A_{r = 8}$, $A_{r = 64}$์ด ๊ฐ์ ๋ถ๋ถ ๊ณต๊ฐ์์ ์ ์ ๋์ด์ผ ํ๋ค. SVD
์ ์์, ์ผ์ชฝ ํน์ด๋ฒกํฐ๋ ๊ฐ๊ฐ 8x8
, 64x64
์ฐจ์์ด ๋์ด ๋น๊ตํ๊ธฐ ์ด๋ ต๋ค. ํํธ, ์ค๋ฅธ์ชฝ ๋ฒกํฐ๋ ๋ ํ๋ ฌ ๋ชจ๋ dxd
๋ก ์ ์๋๋ค. ๋ง์ฝ ํ๋ ฌ $A$ ๋์ $B$๋ฅผ ์ฌ์ฉํ๊ณ ์ถ๋ค๋ฉด ์ผ์ชฝ ํน์ด ๋ฒกํฐ๋ฅผ ์ฌ์ฉํ๋ฉด ๋๋ค.
์ค๋์ง ์์ ๊ฐ๊น์ธ์๋ก ์๋ก ๊ฒน์น๋ ์ ๋ณด๊ฐ ๋ง๋ค๋ ์๋ฏธ๋ฅผ ๊ฐ๋๋ฐ, ์ฌ์ฉํ ์ธ์ฝ๋ ์์น์ ์๊ด์์ด $A_{r = 8}$์ top
์ด๋ฒกํฐ์ผ์๋ก, $A_{r = 64}$์ ๋๋จธ์ง ์ด๋ฒกํฐ๋ค๊ณผ ๋์ ์ ์ฌ๋(์ค๋์ง์์ ๊ฐ๊น์)์ ๊ธฐ๋กํ๊ณ ์๋ค.(ํท๊ฐ๋ฆฌ๋๊น ์ผ์ชฝ ๋ ๊ฐ ๊ทธ๋ํ๋ง ๋ณด๋๊ฒ ๋ซ๋ค). ๊ทธ๋ฆฌ๊ณ $A_{r = 8}$์ bottom
์ด๋ฒกํฐ์ผ์๋ก, ๊ฑฐ๋ฌด์ฃฝ์ฃฝํ ์๊น์ ๊ฐ์ง๋ฉฐ $A_{r = 64}$์ ๋๋จธ์ง ์ด๋ฒกํฐ๋ค๊ณผ ๋ฎ์ ์ ์ฌ๋๋ฅผ ๋ณด์ธ๋ค.
๊ฒฐ๊ตญ ์๊ณ ๋ดค๋๋ ์ฌ์ ํ์ต์ ์ถฉ๋ถํ ์ํํ ๋ชจ๋ธ์ ๊ฒฝ์ฐ, ํ์ธํ๋ Task
์ ๋ํด ์ ์์ํค๋๋ฐ ํ์ํ ๊ณต๊ฐ์ ์์, ๊ตณ์ด ์ ์ฒด ๊ณต๊ฐ์ ํ์ต ํ๋ผ๋ฏธํฐ๋ก ๋๊ณ ํ์ธํ๋ํด๋ด์ผ ๋๋ถ๋ถ์ ์ด๋ฒกํฐ๋ ์ฐ์๋ฐ๊ธฐ ์๋ ํํ์ ์ธ์ฝ๋ฉํ๋๋ฐ ์ฐ์ด๊ณ ์์๋ค๊ณ ๋ณผ ์ ์๊ฒ ๋ค.
๋ฌผ๋ก ์ฌ๊ธฐ์๋ ์ฃผ์ํ ์ ์, GPT3
์ ์ฌ์ ํ์ต๊ณผ ๊ถค๊ฐ ๋น์ทํ WikiSQL
, MNLU
์ ๋ํด ํ์ธํ๋ํ ๊ฒฐ๊ณผ๋ผ๋ ์ ์ด๋ค. ๋ค๊ตญ์ด๋ก ๊ตฌ์ฑ๋ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ํ์ฉํ๊ฒ ๋๋ฉด, ์ด ๊ฒฐ๊ณผ๊ฐ ์ด๋ป๊ฒ ๋ฐ๋์ง ๋ชจ๋ฅธ๋ค.
Grassmann Distance
๋ ์ ํ ๋ถ๋ถ๊ณต๊ฐ(linear subspace) ๊ฐ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ธก์ ํ๋ ๋ฐ ์ฌ์ฉ๋๋ ๊ฐ๋
์ด๋ผ๊ณ ํ๋๋ฐ, ์ฌ๊ธฐ์ ์ด๊ฒ๊น์ง ๋ค๋ฃจ๋ฉด ํฌ์คํ
๊ธธ์ด๊ฐ ๋๋ฌด ๊ธธ์ด์ง ๊ฒ ๊ฐ์์, ๋์ค์ ๋ค๋ฅธ ํฌ์คํธ์์ ๋ค๋ฃจ๋๋ก ํ๊ฒ ๋ค.
๐กย Inisght 3. w โ delta w
์ฌ์ ํ์ตํ ๊ฐ์ค์น ํ๋ ฌ $W$๊ณผ LoRA
์ $\Delta W$๊ฐ ์๋ก ์ผ๋ง๋ ์ ์ฌํ์ง, ์คํ์ ์ผ๋ก ์ฆ๋ช
ํ๊ณ ์๋ค. ๋
ผ๋ฌธ์์ ์ ๊ณตํ ์คํ ๋ฐฉ์์ ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- 1) ์ฌ์ ํ์ต์ผ๋ก ์๋ ด๋ ์ฟผ๋ฆฌ ํ๋ ฌ, $W_q$๋ฅผ
task-specific
ํ ๊ณต๊ฐ($U^T, V^T$: $\Delta W$์Top-r
๊ฐ์Left, Right-Singular Vector
)์ผ๋ก ํฌ์ - 2) LoRA์ ์ํด ์๋ ด๋ ๋ธํ ์ฟผ๋ฆฌํ๋ ฌ, $\Delta W_q$์ ์ด๋ฏธ ํ๋ ฌ ์ ๋ฐ์
task-specific
ํ ์ ๋ณด๋ฅผ ๋ด๊ณ ์์.- ๊ทธ๋์
top-r
์ถ์ถํ์ง ์๊ณ ์ ์ฒด์ ๋ํด์ํ๋ก๋ฒ ๋์ฐ์ค ๋
๊ตฌํ๊ธฐ
- ๊ทธ๋์
- 3) 1๋ฒ ์คํญ์์ ๊ตฌํ ํฌ์ ํ๋ ฌ, $U^TW_qV^T$์ ๋ํด
ํ๋ก๋ฒ ๋์ฐ์ค ๋
๊ณ์ฐ - 4) 2๋ฒ/3๋ฒ ์ํ:
task-specific
ํ ๊ณต๊ฐ์LoRA
๊ฐ ์ฌ์ ํ์ต ๊ฐ์ค์น์ ๋นํด ์ผ๋ง๋ ๋ง์ด ๊ฐ์กฐํ๋์ง ๋ํ๋ด๋ ์งํ- ๋
ผ๋ฌธ์์๋
Feature Amplication Factor
๋ผ๊ณ ์ ์
- ๋
ผ๋ฌธ์์๋
ํ๋ก๋ฒ ๋์ฐ์ค ๋
์ ๊ธฐํํ์ ์ผ๋ก ํ๋ ฌ์ ํฌ๊ธฐ, ์ฆ ์ ํ๋ณํ
์ ํฌ๊ธฐ๋ฅผ ์๋ฏธํ๋ค. ๊ทธ๋์ ๊ณง Feature Amplication Factor
๊ฐ ํ๋ ฌ์ ํฌ๊ธฐ/ํ๋ ฌ์ ํฌ๊ธฐ๋ฅผ ๋ํ๋ด๋ ์งํ๊ฐ ๋๊ณ , ๋ถ์์ ๋ถ๋ชจ์ ํ๋ ฌ์ ๋ชจ๋ task-specific
ํ ๊ณต๊ฐ์ผ๋ก์ ๋ณํ ํฌ๊ธฐ๋ฅผ ์๋ฏธํ๊ธฐ ๋๋ฌธ์, ๊ฐ์ ํน์ง์ ๋ถ์($\Delta W$)๊ฐ ๋ถ๋ชจ($W$)์ ๋นํด์ ์ผ๋ง๋ ๋ ๊ฐ์กฐํ๋์ง๋ฅผ ๋ปํ๊ฒ ๋๋ค. ๋ ๋ฆฌ์ factor
๊ฐ์ด ํด์๋ก LoRA
๊ฐ ์ฌ์ ํ์ต์์ ๊ฐ์กฐํ์ง ์์๋ ํน์ง์ ๋์ฑ ๊ฐ์กฐํ๋ค๊ณ ํด์ํ ์ ์๊ฒ ๋๋ค. ์ด์ ๋ค์ ํ๋ฅผ ๋ถ์ํด๋ณด์.
Low Rank value
$r=4$์ผ ๋, Feature Amplication Factor
์ ๋ถ๋ชจ๋ 0.32
, ๋ถ์๋ 6.91
์ด ๋๋ค. ๋ฐ๋ผ์ factor
๊ฐ์ ๋๋ต 21.5
๊ฐ ๋๋ค. ๋ค์ ๋งํด GPT3
์ 48๋ฒ์งธ ๋ ์ด์ด์ ๊ฒฝ์ฐ, FT
์ ์์ ํ์ํ task-specific
ํ ๊ณต๊ฐ์ LoRA
๊ฐ ์ฌ์ ํ์ต ์ฟผ๋ฆฌ ํ๋ ฌ
๋ณด๋ค 21.5
๋ฐฐ ๊ฐ์กฐํ๊ณ ์๋ค๋ ๊ฒ์ด๋ค.
Low Rank value
$r=64$์ผ ๋๋ factor
๊ฐ ๋๋ต 1.9
๊ฐ ๋๋ค. $r=4$์ผ ๋๋ณด๋ค factor
๊ฐ์ด ํ์ ํ ๋ฎ์ ์ด์ ๋ Insight 2
์ ๊ฒฐ๊ณผ(๋ฎ์ ๋ญํฌ๋ก๋ ์ถฉ๋ถํ FT
์ task-specific
์ ๋ณด ํํ ๊ฐ๋ฅ)์ ์ผ๋งฅ์ํตํ๋ค๊ณ ๋ณผ ์ ์๋ค.
์ฒ์ ์ฝ์์ ๋ ์ด ๋ถ๋ถ์ ๋ํ ํด์์ด ๋๋ฌด ๋ํดํด, ์ ์๋ค์ด ๊นํ๋ธ์ ๊ณต๊ฐํ, RoBERTa
๋ฅผ LoRA์ ํจ๊ป MRPC
๋ฒค์น๋งํฌ์ ํ์ธํ๋ํ ๊ฐ์ค์น๋ฅผ ๋ถ๋ฌ์ ๋๊ฐ์ ๋ฐฉ์์ผ๋ก ์คํ์ ์งํํด๋ดค๋ค. ๋จผ์ ์ ์ฒด ์คํ ๋ฐฉ์์ ์์ฝํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- 1)
Huggingface Hub
์์RoBERTa-base
์ ์ฌ์ ํ์ต ๊ฐ์ค์น ๋ถ๋ฌ์ค๊ธฐ - 2)
LoRA official github
์์roberta_base_lora_mrpc.bin
๋ถ๋ฌ์ค๊ธฐ - 3)
1,2
๋ฒ์์ ๋ชจ๋6
๋ฒ์งธ์ธ์ฝ๋ ๋ ์ด์ด
์์ฟผ๋ฆฌ ํ๋ ฌ
์ ๋ํ ๊ฐ์ค์น ์ถ์ถ - 4) ์ดํ ๋๋จธ์ง ๊ณผ์ ์ ์์ ๋ ผ๋ฌธ์ ์คํ ๋ฐฉ์์ ๋ฐ๋ฆ
์ ์ฒด ๊ณผ์ ์ ์ฝ๋๋ก ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
""" Insight 3 Experiment Code Exanple """
import torch
from transformers import AutoModel, AutoConfig
""" LoRA ๊ฒฐ๊ณผ ํด์ ์ฌํ """
pt_config = AutoConfig.from_pretrained('FacebookAI/roberta-base')
pt_model = AutoModel.from_pretrained( # pretrained model
'roberta-base',
config=pt_config
)
lora_checkpoint = torch.load('model/roberta_base_lora_mrpc.bin', map_location='cpu')
lora_checkpoint
""" Select Wq in 6-th encoder layer """
pt_wq, lora_a, lora_b = pt_model.encoder.layer[6].attention.self.query.weight, lora_checkpoint['roberta.encoder.layer.6.attention.self.query.lora_A'], lora_checkpoint['roberta.encoder.layer.6.attention.self.query.lora_B']
delta_wq = lora_b @ lora_a
pt_wq.shape, lora_a.shape, lora_b.shape, delta_wq.shape
>>> (torch.Size([768, 768]),
>>> torch.Size([8, 768]),
>>> torch.Size([768, 8]),
>>> torch.Size([768, 768]))
""" Let's SVD, select top-r singular vector, ๋ถ์ """
U, S, V = torch.svd(delta_wq)
print(f"Delta W U: {U.shape}")
print(f"Delta W S: {S.shape}")
print(f"Delta W V: {V.shape}")
>>> Delta W U: torch.Size([768, 768])
>>> Delta W S: torch.Size([768])
>>> Delta W V: torch.Size([768, 768])
r = 4
r_U, r_V = U[:, :r], V[:r, :]
result1 = torch.matmul(r_U.T @ pt_wq, r_V.T)
fwq_norm = torch.norm(result1) # ๋ถ์๊ฐ
result1, fwq_norm
>>> (tensor([[-0.0441, 0.0447, 0.0323, 0.0963],
[-0.0038, -0.0412, -0.0903, -0.0949],
[-0.0314, 0.1003, -0.0599, 0.0023],
[-0.0222, -0.1090, 0.0315, 0.0575]], grad_fn=<MmBackward0>),
>>> tensor(0.2539, grad_fn=<LinalgVectorNormBackward0>))
""" ๋ถ๋ชจ """
fdwq_norm = torch.norm(delta_wq) # ๋ถ๋ชจ๊ฐ
fdwq_norm
>>> tensor(5.0820)
"""๊ฒฐ๊ณผ: Feature Amplication Factor """
fdwq_norm / fwq_norm
>>> tensor(20.0170, grad_fn=<DivBackward0>)
๐ฉโ๐ปย Implementation by Pytorch
import math
import torch
import torch.nn as nn
from torch import Tensor
class LoRA(nn.Module):
""" class module for Low-Rank adaptation of LLM SFT
This module return result of "BAx*(a/r)" in mathematical expression in official paper
Args:
dim: dimension of input tensor
rank: rank of tensor, which is hyperparameter for LoRA
alpha: hyperparameter for LoRA, trainable parameter, which is initialized by rank value
options: default str, 'rlora' which is already proved to work better than pure lora
you can select pure lora as passing argument 'lora'
Math:
h = W0x + โWx = W0x + BAx*(a/r)
Notes:
we use sqrt(rank) value, it is already proven to work better in LoRA,
from Huggingface PEFT library official docs
References:
https://arxiv.org/abs/2106.09685
https://pytorch.org/blog/understanding-gpu-memory-1/
"""
def __init__(self, dim: int, rank: int, alpha: int, options: str = 'rlora'):
super().__init__()
self.a = nn.Parameter(torch.randn(rank, dim)) # init by random Gaussian distribution (normal distribution)
self.b = nn.Parameter(torch.zeros(dim, rank)) # init by zero
self.alpha = alpha / math.sqrt(rank) if options == 'rlora' else alpha / rank
def forward(self, inputs: Tensor) -> Tensor:
return torch.matmul(inputs, self.b @ self.a) * self.alpha
๋ค์์ ํ์๊ฐ ์ง์ ๊ตฌํํ LoRA
๊ฐ์ฒด๋ค. ๊ตฌํ์ ํน์ด์ ์ด๋ผ์ ํฌ์คํ
๋ด์ฉ์๋ ํฌํจ๋์ง ์์ ๋ถ๋ถ๋ค์ด ์๊ธฐ ๋๋ฌธ์ ์ฝ๋๋ฅผ ํจ๊ป ์ดํด๋ณด์. ์ผ๋จ ํ๋ ฌ $A$์ ํด๋น๋๋ self.a
๋ nn.Parameter
๋ฅผ ํธ์ถํด ๋ชจ๋ธ์ด ํ์ต ํ๋ผ๋ฏธํฐ๋ก ์ธ์ํ๋๋ก ๋ง๋ ๋ค. ๊ทธ๋ฆฌ๊ณ ๋
ผ๋ฌธ์ ๋์จ๋๋ก, ๋๋ค ๊ฐ์ฐ์์ ๋ถํฌ๋ฅผ ๋ฐ๋ฅด๋๋ก ํ
์๋ฅผ ์ด๊ธฐํ ํด์ค๋ค, ํ๋ ฌ $B$์ ํด๋น๋๋ self.b
nn.Parameter
๋ฅผ ํธ์ถํด ๋ชจ๋ธ์ด ํ์ต ํ๋ผ๋ฏธํฐ๋ก ์ฌ์ฉํ๋๋ก ๋ง๋ค๊ณ , ๋
ผ๋ฌธ์ ๋์จ๋๋ก ์ํ๋ ฌ๋ก ์ด๊ธฐํ ํด์ค๋ค. ๋ง์ง๋ง์ผ๋ก $\Delta W$ ๊ฐ์ ์ค์ผ์ผ๋ง ํด์ค ์ค์ผ์ผ ํฉํฐ alpha
๋ฅผ ๋์
ํ๋ค. ์ด๋, options
์ธ์๋ฅผ ํตํด LoRA
์ RLORA
์ค ์ด๋ค ๊ฒ์ ์ฌ์ฉํ ์ง ์ ํํ ์ ์๋ค. ์ด์ค์์, ๋ถ๋ชจ์ ๋ญํฌ๊ฐ์ ์ ๊ณฑ๊ทผ์ ์ทจํด์ฃผ๋ RLORA
๊ฐ ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ธ๋ค๊ณ ํ์ ์ฐ๊ตฌ์์ ๋ฐํ์ก๋ค๊ณ ํ๋ค.
์ด๋ ๊ฒ LoRA
๊ฐ์ฒด์ ๋ํด์ ํจ๊ป ์ดํด๋ณด์๋ค. ๊ตฌํ ์์ฒด๋ ๋งค์ฐ ๊ฐ๋จํ๋ค. ํ์ง๋ง, ์ค์ํ ์ ์ ์ฌ์ ํ์ต ๋ชจ๋ธ์ ๊ฐ์ค์น์ LoRA
๊ฐ์ฒด๋ฅผ ์ ์ฉํ์ฌ ์๋ก์ด ๋ชจ๋ธ ๊ฐ์ฒด๋ฅผ ๋ง๋ค์ด ๋ด๋ ๊ฒ์ด๋ค. ์๋ ์ฝ๋์ฒ๋ผ,
""" before MHA """
q = self.fc_q(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
k = self.fc_k(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
v = self.fc_v(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
attention_output = self.fc_concat(attention_matrix)
""" after MHA """
self.lora_q = lora()
self.lora_k = lora()
self.lora_v = lora()
self.lora_o = lora()
q = self.fc_q(x) + self.lora_q(x) # freeze + trainable
q.reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
k = self.fc_k(x) + self.lora_k(x) # freeze + trainable
k.reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
v = self.fc_v(x) + self.lora_v(x) # freeze + trainable
v.reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
attention_output = self.fc_concat(attention_matrix) + self.lora_o(attention_matrix) # freeze + trainable
์ ํ ํฌ์๋ ์ฟผ๋ฆฌ, ํค, ๋ฒจ๋ฅ ํ๋ ฌ๊ณผ ๊ฐ๊ฐ์ LoRA
๊ฐ์ฒด๋ฅผ ๋ํด์ค ์๋ง ์๋ค๋ฉด ๋งค์ฐ ๊ฐ๋จํ๊ฒ ํด๊ฒฐ๋ ๋ฌธ์ ์ง๋ง, ์ฌ์ ํ์ต ๋ชจ๋ธ์ Multi-Head Attention
๊ฐ์ฒด๋ฅผ ์ฒ์๋ถํฐ ์ ๋ฐ์์ผ๋ก ์ ์ํด์ผ๋ง ๊ฐ๋ฅํ ์ผ์ด๋ค. ํ์๊ฐ ์์ฑํ ๋ชจ๋ธ ์ฝ๋๋ฅผ ๋น๋กฏํด ๋๋ถ๋ถ์ ์คํ์์ค๋ก ํ๋ ค์๋ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ๋ค์ ์ ๋ฐ์์ผ๋ก ์์ฑ๋์ด ์์ง ์๋ค. ๋ฐ๋ผ์ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ๋ ์ฌ๋ คํ๋๋ฐ, ๋น์ฅ์ ๋๋ฌด ๋ณต์กํ ์์
์ด ๋ ๊ฒ ๊ฐ์(์คํ ์ดํ๋ฆฌ์ผ์ด์
๊ตฌ์กฐ๋ฅผ ๋ค์์ด์ผ ๊ฐ๋ฅํ ๊ฒ์ผ๋ก ์์ธก) ์ผ๋จ์ ์ฌ๊ธฐ์ ๋ง๋ฌด๋ฆฌํ๋ ค๊ณ ํ๋ค. ๋ง์ฝ LoRA
๋ฅผ ์ฌ์ ํ์ต ๋ชจ๋ธ์ ์ ์ฉํด ํ์ธํ๋์ ํด๋ณด๊ณ ์ถ๋ค๋ฉด, Huggingface
์ PEFT
๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ด์ฉํด๋ณด์. Hugginface
์ Automodel
, Trainer
๊ฐ์ฒด์ ์ ์ฐํ๊ฒ ์ฐ๋์ด ๊ฐ๋ฅํ๋ค. ์๋์ PEFT
๊ณต์ ๋ฌธ์์์ ์ฐธ๊ณ ํ Usage Example
์ฝ๋๋ฅผ ์ฒจ๋ถํ์ผ๋ ์ฐธ๊ณ ๋ถํ๋ฐ๋๋ค.
""" PEFT LoRA Usage Example
Reference:
https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/model.py
"""
>>> from transformers import AutoModelForSeq2SeqLM
>>> from peft import LoraModel, LoraConfig
>>> config = LoraConfig(
... task_type="SEQ_2_SEQ_LM",
... r=8, # rank value in official paper
... lora_alpha=32, # alpha value in official paper
... target_modules=["q", "v"],
... lora_dropout=0.01,
... )
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> lora_model = LoraModel(model, config, "default")
>>> import torch
>>> import transformers
>>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
>>> rank = ...
>>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"] # target for projection matrix, MLP
>>> config = LoraConfig(
... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
... )
>>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
>>> tokenizer = transformers.AutoTokenizer.from_pretrained(
... "kakaobrain/kogpt",
... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b
... bos_token="[BOS]",
... eos_token="[EOS]",
... unk_token="[UNK]",
... pad_token="[PAD]",
... mask_token="[MASK]",
... )
>>> model = transformers.GPTJForCausalLM.from_pretrained(
... "kakaobrain/kogpt",
... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b
... pad_token_id=tokenizer.eos_token_id,
... use_cache=False,
... device_map={"": rank},
... torch_dtype=torch.float16,
... quantization_config=quantization_config,
... )
>>> model = prepare_model_for_kbit_training(model)
>>> lora_model = get_peft_model(model, config)
Leave a comment