
🔭 Overview

LoRA

LoRA, published in 2021 by researchers at Microsoft, drew attention for dramatically reducing the GPU memory required to fine-tune LLMs while delivering performance nearly identical to full fine-tuning (it even scores higher on some benchmarks). It has been so popular that the community half-jokingly nicknamed it "LoRA is All You Need."

As we saw in the DistilBERT review, even though the arrival of BERT and GPT brought dramatic performance gains across every NLP domain, the enormous resource requirements and latency of deep NLP models kept holding back their use in real products. After LoRA was published, however, the number of parameters that must be trained at fine-tuning time dropped sharply, and model checkpoint sizes shrank by orders of magnitude. The required GPU VRAM fell accordingly, and many cases where serving was impossible due to resource limits simply disappeared. Today LoRA, together with Mixed Precision and Quantization, is one of the most important topics in model compression and optimization.

๋‚ด์šฉ์„ ์‚ดํŽด๋ณด๊ธฐ์ „, LoRA ๋Š” ์ด๋ฏธ ์‚ฌ์ „ํ•™์Šต์„ ์™„๋ฃŒํ•œ ๋ชจ๋ธ์„ ํŒŒ์ธํŠœ๋‹ํ•  ๋•Œ ์‚ฌ์šฉํ•ด์•ผํ•จ์„ ๋‹ค์‹œ ํ•œ ๋ฒˆ ๋ช…์‹ฌํ•˜์ž. ์ด๋ฒˆ ํฌ์ŠคํŒ…์—์„œ๋Š” ๋‘๊ฐ€์ง€๋ฅผ ์ง‘์ค‘์ ์œผ๋กœ ๋‹ค๋ฃฐ ๊ฒƒ์ด๋‹ค.

1) ๋ชจ๋ธ ํฌ๊ธฐ ์ค„์ธ ๋ฐฉ๋ฒ•, 2) ํฌ๊ธฐ๋ฅผ ์ค„์ด๋ฉด์„œ๋„ ๋น„์Šทํ•œ ์„ฑ๋Šฅ์„ ๋‚ผ ์ˆ˜ ์žˆ์—ˆ๋˜ ์ด์œ 

🤔 Concept: Low-Rank Adaptation

\[h = W_0x + \Delta Wx = W_0x + BAx\]

The idea is remarkably simple. The input is passed through both $W_0$, the weight matrix that has converged during pretraining, and a new weight matrix $\Delta W$; the two outputs are then summed and used as the input to the next layer. But if we are adding a new weight matrix on top for fine-tuning, how can this possibly reduce the number of trainable parameters?

๊ทธ ๋น„๋ฐ€์€ Freeze(Stop Gradient, require_grad=False)์™€ Matrix Factorization์— ์ˆจ์–ด ์žˆ๋‹ค. ๋จผ์ € ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ์— Freeze(Stop Gradient, require_grad=False) ๋ฅผ ์ ์šฉํ•ด ๊ทธ๋ผ๋””์–ธํŠธ๊ฐ€ ํ๋ฅด์ง€ ์•Š๋„๋ก ํ•œ๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ํŒŒ์ธํŠœ๋‹ ๊ณผ์ •์—์„œ ๊ฐ€์ค‘์น˜๊ฐ€ ์—…๋ฐ์ดํŠธ ๋˜์ง€ ์•Š์•„ ์‚ฌ์ „ ํ•™์Šต์—์„œ ์Šต๋“ํ•œ ์ง€์‹์„ ์œ ์ง€ํ•  ์ˆ˜ ์žˆ์„ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ, ํ•™์Šต์„ ์œ„ํ•ด ๊ทธ๋ผ๋””์–ธํŠธ๋ฅผ ์ €์žฅํ•  ํ•„์š”๊ฐ€ ์—†์–ด์ ธ ํŒŒ์ธํŠœ๋‹ ๋•Œ ํ•„์š”ํ•œ GPU VRAM์„ ํš๊ธฐ์ ์œผ๋กœ ์ค„์ผ ์ˆ˜ ์žˆ๋‹ค.

์ฒ˜์Œ์— ์‚ฌ์ „ํ•™์Šต ๊ฐ€์ค‘์น˜๋ฅผ ํ†ต๊ณผํ•œ ๊ฐ’๊ณผ ์ƒˆ๋กœ์šด ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ $\Delta W$๋ฅผ ํ†ต๊ณผํ•œ ๊ฐ’์„ ์„œ๋กœ ๋”ํ•œ๋‹ค๊ณ  ์–ธ๊ธ‰ํ–ˆ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด, ๋‘ ๊ฒฐ๊ณผ ํ–‰๋ ฌ์˜ ํ–‰๋ ฌ ํฌ๊ธฐ๊ฐ€ ๋™์ผํ•ด์•ผ ํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์–ด๋–ป๊ฒŒ ๊ธฐ์กด๋ณด๋‹ค ์‚ฌ์ด์ฆˆ๋Š” ์ค„์ด๋ฉด์„œ ๊ฒฐ๊ณผ ํ–‰๋ ฌ์˜ ํฌ๊ธฐ๋Š” ๋™์ผํ•˜๊ฒŒ ๋งŒ๋“ค์–ด์ค„ ์ˆ˜ ์žˆ์„๊นŒ?? ๋ฐ”๋กœ Low Rank value $r$์„ ๋„์ž…ํ•ด Matrix Factorization ์„ ํ•œ๋‹ค.

\[W_{d \times d} = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,d} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,d} \\ \vdots & \vdots & \ddots & \vdots \\ w_{d,1} & w_{d,2} & \cdots & w_{d,d} \end{bmatrix}\]

Recall how matrix multiplication works: multiplying an $M \times N$ matrix by an $N \times K$ matrix yields an $M \times K$ matrix. The same trick applies here. To match the $d \times d$ pretrained weight matrix $W_{d \times d}$, we decompose the $d \times d$ update into two matrices $B$ and $A$ of shapes $d \times r$ and $r \times d$. The shared inner dimension — the number of columns of $B$ and the number of rows of $A$ — is exactly where the low-rank value $r$ goes.

\[\Delta W_{d \times d} = B_{d \times r}\ A_{r \times d} = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,r} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,r} \\ \vdots & \vdots & \ddots & \vdots \\ w_{d,1} & w_{d,2} & \cdots & w_{d,r} \end{bmatrix}\begin{bmatrix} w_{1,1} & w_{2,1} & \cdots & w_{d,1} \\ w_{1,2} & w_{2,2} & \cdots & w_{d,2} \\ \vdots & \vdots & \ddots & \vdots \\ w_{1,r} & w_{2,r} & \cdots & w_{d,r} \end{bmatrix}\]

Suppose $r=3$ and compare the parameter counts of the original 768×768 weight matrix $W$ against $\Delta W = BA$ with shapes 768×3 and 3×768. The former has 589,824 parameters, the latter 4,608 — exactly a 128× difference. A transformer holds many weight matrices of the same size as $W$: a single attention layer inside one encoder already has four of them ($W_q, W_k, W_v, W_o$). BERT-base has 12 encoders, so 48 such matrices in total, and applying LoRA to every one of them yields the 128-fold reduction in trainable parameters across all 48. The more layers a model has, the greater the savings.
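The arithmetic above can be checked directly:

```python
d, r = 768, 3

full = d * d              # full-rank delta-W
low_rank = d * r + r * d  # B (d x r) plus A (r x d)

print(full)              # 589824
print(low_rank)          # 4608
print(full // low_rank)  # 128
```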

ResNet50 Memory Types in GPU

The figure above, taken from the official PyTorch blog, shows in detail how ResNet50's GPU VRAM usage evolves during training, broken down by each memory category of the model. Look first at Parameter and Optimizer State. Parameter refers to every component of the model that needs to be updated through training — in other words, every tensor inside the model that is not affected by freezing (`requires_grad=False`), `@torch.no_grad()`, or `register_buffer()`.

Optimizer State, meanwhile, is all the information the optimizer needs to perform its updates — metadata about the tensors to be updated and various hyperparameter values; for Adam, for instance, the running averages of the gradient and its square.

์ด ๋‘ ์š”์†Œ๊ฐ€ ๋ชจ๋ธ์˜ GPU VRAM์„ ์ฐจ์ง€ํ•˜๋Š” ๋น„์œจ์ด ์ƒ๋‹นํžˆ ํฌ๋‹ค. ํ•˜์ง€๋งŒ ๋‘ ์š”์†Œ ๋ชจ๋‘ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ฐœ์ˆ˜์— ๋น„๋ก€ํ•˜๋ฏ€๋กœ LoRA ์ ์šฉ์œผ๋กœ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ฐœ์ˆ˜๋ฅผ ์ค„์ด๋ฉด, GPU VRAM์„ ํš๊ธฐ์ ์œผ๋กœ ์ค„์ผ ์ˆ˜ ์žˆ๋‹ค.

Also, given that PyTorch stores each gradient as a tensor with the same shape as its parameter for backpropagation, training low-rank tensors instead of the original full-rank ones shrinks the gradient tensors just as dramatically.
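A back-of-envelope sketch of that gradient-memory saving, assuming fp32 gradients and an illustrative rank of 8 (neither figure is from the paper):

```python
# gradient memory for the 48 attention weight matrices of BERT-base
d, r, n_matrices, bytes_fp32 = 768, 8, 48, 4

full_grad = n_matrices * d * d * bytes_fp32      # a full-rank gradient per matrix
lora_grad = n_matrices * 2 * d * r * bytes_fp32  # gradients for B and A only

print(full_grad / 2**20)  # 108.0 MiB
print(lora_grad / 2**20)  # 2.25 MiB
```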

Transformer-family models have overwhelmingly more parameters than ResNet, so applying LoRA to them should pay off even more.

The left figure shows the fine-tuning results reported in the paper for BERT-family LMs with LoRA applied; the FT row is the result of conventional fine-tuning. The two trade wins back and forth and look nearly identical, and LoRA's benchmark average is actually higher. The benchmarks were most likely cherry-picked to show a reasonable performance gap, but the result is still quite meaningful. The right figure shows GPT-2 with LoRA applied; it likewise tracks the baseline closely.

So far we have seen how LoRA's method dramatically reduces the number of trainable parameters and, in turn, the GPU VRAM a model occupies. Now let's interpret why applying LoRA can still match conventional fine-tuning performance. This corresponds to Chapter 7, UNDERSTANDING THE LOW-RANK UPDATES, of the paper, which offers three insights.

💡 Insight 1. Apply LoRA to (Wq, Wv) or (Wq, Wk, Wv, Wo)

Which Matrix is the BEST

ํ•„์ž๋Š” ๋…ผ๋ฌธ์„ ์ฝ๋Š” ๋‚ด๋‚ด, โ€˜๊ทธ๋ž˜์„œ ์–ด๋–ค ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ์— ์ ์šฉํ•ด์•ผ ํ• ๊นŒ?? ๋ชจ๋“  ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ์— ์ ์šฉํ•ด๋„ ๋˜๋Š”๊ฑธ๊นŒ??โ€™ํ•˜๋Š” ์˜๋ฌธ์„ ๊ฐ–๊ณ  ์žˆ์—ˆ๋‹ค. ๊ทผ๋ฐ ๋งˆ์นจ ์ €์ž๋“ค์ด ์ด๋Ÿฌํ•œ ์˜๋ฌธ๋“ค์„ ์—์ƒํ•œ ๋“ฏ, ์œ„์™€ ๊ฐ™์ด ์ ์šฉ ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ์— ๋”ฐ๋ฅธ ๋ฒค์น˜๋งˆํฌ ์„ฑ๋Šฅ ๊ฒฐ๊ณผ๋ฅผ ํ‘œ๋กœ ์ •๋ฆฌํ•ด์ฃผ์—ˆ๋‹ค. ๋ชจ๋ธ์€ GPT3์„ ์‚ฌ์šฉํ–ˆ๋‹ค๊ณ  ๋…ผ๋ฌธ์—์„œ ๋ฐํžˆ๊ณ  ์žˆ๋‹ค.

๋ณด์ด๋Š” ๊ฒƒ๊ณผ ๊ฐ™์ด, ($W_q, W_v$) ํ˜น์€ ($W_q, W_k, W_v, W_o$)์— LoRA๋ฅผ ์ ์šฉํ•˜๋Š”๊ฒŒ ๊ฐ€์žฅ ์ข‹์€ ๋ฒค์น˜๋งˆํฌ ์„ฑ๋Šฅ์„ ๋ณด์—ฌ์ค€๋‹ค. ์ฃผ๋ชฉํ•  ์ ์€ ๋žญํฌ๊ฐ€ ๊ฐ€์žฅ ๋‚ฎ์œผ๋ฉด์„œ, ๊ฐ€์žฅ ๋งŽ์€ ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ์— LoRA๋ฅผ ์ ์šฉํ•˜๋Š”๊ฒŒ ๊ฐ€์žฅ ์„ฑ๋Šฅ์ด ์ข‹๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์‹คํ—˜๊ฒฐ๊ณผ ์ œ์‹œ ์ด์™ธ์— ๋‹ค๋ฅธ ์ฆ๋ช…์ด๋‚˜ ์ธ์‚ฌ์ดํŠธ ์ œ์‹œ๊ฐ€ ์—†๋Š”๊ฒŒ ์•„์‰ฝ์ง€๋งŒ, ์ด๋ฅผ ํ†ตํ•ด ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์‚ฌ์‹ค๋“ค์„ ๋– ์˜ฌ๋ ค ๋ณด์•˜๋‹ค.

  • 1) The contextual information needed to solve the fine-tuning task is spread appropriately across the query, key, and value matrices
    • All three weight matrices learn meaningful contextual representations
  • 2) Even a low rank is enough to extract the embeddings needed for fine-tuning
    • Which suggests the embeddings captured during pretraining are that rich and generalize that well
      • Couldn't the fine-tuning stage be simplified further the deeper and longer we pretrain?
      • But what if there is a large gap between pretraining and fine-tuning?
      • For instance, pretraining in English but fine-tuning on a Korean dataset?

The paper's footnote happens to remark: "However, we do not expect a small r to work for every task or dataset. Consider the following thought experiment: if the downstream task were in a different language than the one used for pre-training, retraining the entire model". So the guideline is to pick as low a rank as possible, but when the gap between pretraining and fine-tuning is judged to be large, choose an appropriate value by comparing experiments against higher ranks.

💡 Insight 2. A Low Rank Is Enough

Insight 2

๋‚ฎ์€ ๋žญํฌ๋กœ๋„ ์ถฉ๋ถ„ํžˆ, FT์— ํ•„์š”ํ•œ ์ž„๋ฒ ๋”ฉ ์ถ”์ถœ ๊ฐ€๋Šฅํ•˜๋‹ค๋Š” ๊ฒƒ์„ ์ข€ ๋” ๊ตฌ์ฒด์ ์ธ ์‹คํ—˜์œผ๋กœ ์ฆ๋ช…ํ•˜๊ณ  ์žˆ๋‹ค. ๊ทธ๋ž˜ํ”„ $y$์ถ•๊ณผ $x$์ถ•์€ ๊ฐ๊ฐ $A_r = 8$, $A_r = 64$์ธ (ํ…์„œ ๋ชจ์–‘ [r, dim_model]) ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ์„ SVDํ•˜์—ฌ ์–ป์€ right-singular matrix ์—์„œ top-i(1 โ‰ค i โ‰ค 8), top-j (1 โ‰ค j โ‰ค 64)๊ฐœ์˜ ํŠน์ด๊ฐ’์„ ์ถ”์ถœํ•œ ๋’ค, Grassmann Distance๋ฅผ ๊ฑฐ๋ฆฌ ๋งคํŠธ๋ฆญ์œผ๋กœ ์ด์šฉํ•ด ๋ถ€๋ถ„ ๊ณต๊ฐ„ ์‚ฌ์ด์˜ ์œ ์‚ฌ๋„๋ฅผ ์ธก์ •ํ•œ ๊ฒฐ๊ณผ๋‹ค.

Why the right-singular matrix? To compare the top-$i$ and top-$j$ singular vectors drawn from the full vector space, the two matrices $A_{r=8}$ and $A_{r=64}$ must be defined over a common subspace. By the definition of SVD, their left-singular matrices are 8×8 and 64×64 respectively, making direct comparison awkward, while the right-singular matrices are $d \times d$ for both. If you wanted to use matrix $B$ instead of $A$, you would use the left-singular vectors instead.
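The measurement described above can be sketched as follows — a hedged reconstruction of the paper's normalized subspace similarity (built on the Grassmann distance), using random matrices in place of the real adaptation matrices:

```python
import torch

def subspace_similarity(a1: torch.Tensor, a2: torch.Tensor, i: int, j: int) -> float:
    """Normalized similarity between the top-i and top-j right-singular
    subspaces of two adaptation matrices, as in the LoRA paper."""
    _, _, vh1 = torch.linalg.svd(a1, full_matrices=False)
    _, _, vh2 = torch.linalg.svd(a2, full_matrices=False)
    u_i, u_j = vh1[:i], vh2[:j]  # top-i / top-j right-singular vectors (rows)
    return (torch.norm(u_i @ u_j.T) ** 2 / min(i, j)).item()

# hypothetical stand-ins for A_{r=8} and A_{r=64}
a_r8, a_r64 = torch.randn(8, 768), torch.randn(64, 768)
phi = subspace_similarity(a_r8, a_r64, i=1, j=64)
print(round(phi, 3))  # lies in [0, 1]; 1 means the subspaces fully overlap
```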

The closer a cell is to orange, the more information the two subspaces share. Regardless of which encoder layer is used, the top column vectors of $A_{r=8}$ show high similarity (closer to orange) with the rest of $A_{r=64}$'s column vectors, while the bottom column vectors of $A_{r=8}$ turn murky in color, showing low similarity with the rest of $A_{r=64}$'s column vectors (if it gets confusing, it is easier to look only at the two left-hand plots).

In the end, for a model that has been pretrained thoroughly, only a small subspace is needed to adapt it to the fine-tuning task; if you insist on fine-tuning with the full space as trainable parameters, most of the column vectors end up encoding redundant, unused representations.

A caveat applies here too: these are results of fine-tuning GPT-3 on WikiSQL and MultiNLI, which are close in distribution to its pretraining data. With a multilingual dataset, the outcome might well change.

The Grassmann distance is a notion used to measure distances between linear subspaces; covering it here would make this post far too long, so I will save it for a separate post.

💡 Insight 3. W ≠ ΔW

Insight 3

The paper also demonstrates experimentally how similar the pretrained weight matrix $W$ and LoRA's $\Delta W$ are. The experimental procedure in the paper can be summarized as follows:

  • 1) Project the pretrained, converged query matrix $W_q$ onto the task-specific subspace ($U^T, V^T$: the top-$r$ left/right singular vectors of $\Delta W$)
  • 2) The converged delta query matrix $\Delta W_q$ learned by LoRA already carries task-specific information across the whole matrix
    • So compute its Frobenius norm over the entire matrix, without extracting a top-$r$ part
  • 3) Compute the Frobenius norm of the projection from step 1, $U^T W_q V^T$
  • 4) Divide step 2 by step 3: an indicator of how much more LoRA emphasizes the task-specific subspace than the pretrained weights do
    • The paper defines this as the Feature Amplification Factor

Geometrically, the Frobenius norm measures the magnitude of a matrix — that is, of the linear transformation it represents. The Feature Amplification Factor is therefore a ratio of matrix magnitudes, and since both numerator and denominator measure the size of a transformation into the same task-specific subspace, it tells us how much more the numerator ($\Delta W$) emphasizes those features than the denominator ($W$) does. The larger the factor, the more strongly LoRA amplifies features that pretraining did not emphasize. With that, let's analyze the table again.

When the low-rank value is $r=4$, the denominator of the Feature Amplification Factor is 0.32 and the numerator is 6.91, giving a factor of roughly 21.5. In other words, for the 48th layer of GPT-3, LoRA emphasizes the task-specific subspace needed for fine-tuning adaptation about 21.5 times more strongly than the pretrained query matrix does.

When the low-rank value is $r=64$, the factor is roughly 1.9. That it is markedly lower than at $r=4$ is consistent with the result of Insight 2: a low rank is already enough to express the task-specific information needed for fine-tuning.

์ฒ˜์Œ ์ฝ์—ˆ์„ ๋•Œ ์ด ๋ถ€๋ถ„์— ๋Œ€ํ•œ ํ•ด์„์ด ๋„ˆ๋ฌด ๋‚œํ•ดํ•ด, ์ €์ž๋“ค์ด ๊นƒํ—ˆ๋ธŒ์— ๊ณต๊ฐœํ•œ, RoBERTa๋ฅผ LoRA์™€ ํ•จ๊ป˜ MRPC ๋ฒค์น˜๋งˆํฌ์— ํŒŒ์ธํŠœ๋‹ํ•œ ๊ฐ€์ค‘์น˜๋ฅผ ๋ถˆ๋Ÿฌ์™€ ๋˜‘๊ฐ™์€ ๋ฐฉ์‹์œผ๋กœ ์‹คํ—˜์„ ์ง„ํ–‰ํ•ด๋ดค๋‹ค. ๋จผ์ € ์ „์ฒด ์‹คํ—˜ ๋ฐฉ์‹์„ ์š”์•ฝํ•˜๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

  • 1) Load the pretrained RoBERTa-base weights from the Huggingface Hub
  • 2) Load roberta_base_lora_mrpc.bin from the official LoRA GitHub repository
  • 3) From both, extract the weights of the query matrix in the 6th encoder layer
  • 4) Follow the paper's procedure above for the remaining steps

The whole process, in code:

""" Insight 3 Experiment Code Exanple """
import torch
from transformers import AutoModel, AutoConfig

""" LoRA ๊ฒฐ๊ณผ ํ•ด์„ ์žฌํ˜„ """
pt_config = AutoConfig.from_pretrained('FacebookAI/roberta-base')  
pt_model = AutoModel.from_pretrained( # pretrained model
    'roberta-base',
    config=pt_config
)

lora_checkpoint = torch.load('model/roberta_base_lora_mrpc.bin', map_location='cpu')
lora_checkpoint

""" Select Wq in 6-th encoder layer """
pt_wq = pt_model.encoder.layer[6].attention.self.query.weight
lora_a = lora_checkpoint['roberta.encoder.layer.6.attention.self.query.lora_A']
lora_b = lora_checkpoint['roberta.encoder.layer.6.attention.self.query.lora_B']
delta_wq = lora_b @ lora_a
pt_wq.shape, lora_a.shape, lora_b.shape, delta_wq.shape

>>> (torch.Size([768, 768]),
>>> torch.Size([8, 768]),
>>> torch.Size([768, 8]),
>>> torch.Size([768, 768]))

""" Let's SVD, select top-r singular vector, ๋ถ„์ž  """
U, S, V = torch.svd(delta_wq)
print(f"Delta W U: {U.shape}")
print(f"Delta W S: {S.shape}")
print(f"Delta W V: {V.shape}")

>>> Delta W U: torch.Size([768, 768])
>>> Delta W S: torch.Size([768])
>>> Delta W V: torch.Size([768, 768])

r = 4
r_U, r_V = U[:, :r], V[:r, :]
result1 = torch.matmul(r_U.T @ pt_wq, r_V.T)
fwq_norm = torch.norm(result1)  # denominator: ||U^T Wq V^T||_F
result1, fwq_norm

>>> (tensor([[-0.0441,  0.0447,  0.0323,  0.0963],
         [-0.0038, -0.0412, -0.0903, -0.0949],
         [-0.0314,  0.1003, -0.0599,  0.0023],
         [-0.0222, -0.1090,  0.0315,  0.0575]], grad_fn=<MmBackward0>),
>>> tensor(0.2539, grad_fn=<LinalgVectorNormBackward0>))

""" ๋ถ„๋ชจ """
fdwq_norm = torch.norm(delta_wq)  # ๋ถ„๋ชจ๊ฐ’
fdwq_norm

>>> tensor(5.0820)

"""๊ฒฐ๊ณผ: Feature Amplication Factor """
fdwq_norm / fwq_norm

>>> tensor(20.0170, grad_fn=<DivBackward0>)

๐Ÿ‘ฉโ€๐Ÿ’ปย Implementation by Pytorch

import math
import torch
import torch.nn as nn
from torch import Tensor


class LoRA(nn.Module):
    """ class module for Low-Rank adaptation of LLM SFT
    This module return result of "BAx*(a/r)" in mathematical expression in official paper

    Args:
        dim: dimension of input tensor
        rank: rank of tensor, which is hyperparameter for LoRA
        alpha: hyperparameter for LoRA, trainable parameter, which is initialized by rank value
        options: default str, 'rlora' which is already proved to work better than pure lora
                 you can select pure lora as passing argument 'lora'

    Math:
        h = W0x + โˆ†Wx = W0x + BAx*(a/r)

    Notes:
        we use sqrt(rank) value, it is already proven to work better in LoRA,
        from Huggingface PEFT library official docs

    References:
        https://arxiv.org/abs/2106.09685
        https://pytorch.org/blog/understanding-gpu-memory-1/
    """
    def __init__(self, dim: int, rank: int, alpha: int, options: str = 'rlora'):
        super().__init__()
        self.a = nn.Parameter(torch.randn(rank, dim))  # init by random Gaussian distribution (normal distribution)
        self.b = nn.Parameter(torch.zeros(dim, rank))  # init by zero
        self.alpha = alpha / math.sqrt(rank) if options == 'rlora' else alpha / rank

    def forward(self, inputs: Tensor) -> Tensor:
        return torch.matmul(inputs, self.b @ self.a) * self.alpha

๋‹ค์Œ์€ ํ•„์ž๊ฐ€ ์ง์ ‘ ๊ตฌํ˜„ํ•œ LoRA ๊ฐ์ฒด๋‹ค. ๊ตฌํ˜„์ƒ ํŠน์ด์ ์ด๋ผ์„œ ํฌ์ŠคํŒ… ๋‚ด์šฉ์—๋Š” ํฌํ•จ๋˜์ง€ ์•Š์€ ๋ถ€๋ถ„๋“ค์ด ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ์ฝ”๋“œ๋ฅผ ํ•จ๊ป˜ ์‚ดํŽด๋ณด์ž. ์ผ๋‹จ ํ–‰๋ ฌ $A$์— ํ•ด๋‹น๋˜๋Š” self.a๋Š” nn.Parameter๋ฅผ ํ˜ธ์ถœํ•ด ๋ชจ๋ธ์ด ํ•™์Šต ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ์ธ์‹ํ•˜๋„๋ก ๋งŒ๋“ ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ๋…ผ๋ฌธ์— ๋‚˜์˜จ๋Œ€๋กœ, ๋žœ๋ค ๊ฐ€์šฐ์‹œ์•ˆ ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅด๋„๋ก ํ…์„œ๋ฅผ ์ดˆ๊ธฐํ™” ํ•ด์ค€๋‹ค, ํ–‰๋ ฌ $B$์— ํ•ด๋‹น๋˜๋Š” self.b nn.Parameter๋ฅผ ํ˜ธ์ถœํ•ด ๋ชจ๋ธ์ด ํ•™์Šต ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ์‚ฌ์šฉํ•˜๋„๋ก ๋งŒ๋“ค๊ณ , ๋…ผ๋ฌธ์— ๋‚˜์˜จ๋Œ€๋กœ ์˜ํ–‰๋ ฌ๋กœ ์ดˆ๊ธฐํ™” ํ•ด์ค€๋‹ค. ๋งˆ์ง€๋ง‰์œผ๋กœ $\Delta W$ ๊ฐ’์„ ์Šค์ผ€์ผ๋ง ํ•ด์ค„ ์Šค์ผ€์ผ ํŒฉํ„ฐ alpha๋ฅผ ๋„์ž…ํ•œ๋‹ค. ์ด๋•Œ, options ์ธ์ž๋ฅผ ํ†ตํ•ด LoRA์™€ RLORA ์ค‘ ์–ด๋–ค ๊ฒƒ์„ ์‚ฌ์šฉํ• ์ง€ ์„ ํƒํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด์ค‘์—์„œ, ๋ถ„๋ชจ์˜ ๋žญํฌ๊ฐ’์— ์ œ๊ณฑ๊ทผ์„ ์ทจํ•ด์ฃผ๋Š” RLORA๊ฐ€ ๋” ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์ธ๋‹ค๊ณ  ํ›„์† ์—ฐ๊ตฌ์—์„œ ๋ฐํ˜€์กŒ๋‹ค๊ณ  ํ•œ๋‹ค.

That covers the LoRA module itself — the implementation is very simple. The important part, however, is applying the LoRA module to the weights of a pretrained model to build a new model object, as in the code below.

""" before MHA """

q = self.fc_q(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
k = self.fc_k(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()
v = self.fc_v(x).reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()

attention_output = self.fc_concat(attention_matrix)

""" after MHA """
self.lora_q = LoRA(self.dim_model, rank, alpha)
self.lora_k = LoRA(self.dim_model, rank, alpha)
self.lora_v = LoRA(self.dim_model, rank, alpha)
self.lora_o = LoRA(self.dim_model, rank, alpha)

q = self.fc_q(x) + self.lora_q(x)  # freeze + trainable
q = q.reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()

k = self.fc_k(x) + self.lora_k(x)  # freeze + trainable
k = k.reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()

v = self.fc_v(x) + self.lora_v(x)  # freeze + trainable
v = v.reshape(-1, x.shape[1], self.num_attention_heads, self.dim_head).permute(0, 2, 1, 3).contiguous()

attention_output = self.fc_concat(attention_matrix) + self.lora_o(attention_matrix)  # freeze + trainable

If you could simply add a LoRA module to each of the linearly projected query, key, and value matrices, the problem would be trivially solved — but that only works if the pretrained model's Multi-Head Attention module was defined that way from the start. Most open-source transformer implementations, including my own model code, are not written like that. I considered other approaches, but for now they look too complex (I suspect it would mean overturning the structure of my experiment application), so I will stop here. If you want to apply LoRA to a pretrained model for fine-tuning, use Huggingface's PEFT library, which integrates flexibly with Huggingface's AutoModel and Trainer objects. A Usage Example adapted from the official PEFT documentation is attached below for reference.
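As a sketch of the workaround hinted at above — and roughly what PEFT automates under the hood — one can wrap each existing `nn.Linear` in a container that freezes it and adds a trainable low-rank branch, then swap modules in place by name. All names and dimensions here are illustrative:

```python
import math
import torch
import torch.nn as nn

class LinearWithLoRA(nn.Module):
    """ Wrap a frozen pretrained nn.Linear and add a trainable low-rank branch """
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze W0
        self.a = nn.Parameter(torch.randn(rank, base.in_features))   # Gaussian init
        self.b = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init
        self.scale = alpha / math.sqrt(rank)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ (self.b @ self.a).T) * self.scale  # W0x + BAx*(a/sqrt(r))

def apply_lora(model: nn.Module, target_names=('query', 'key', 'value')) -> None:
    """ Recursively swap matching nn.Linear submodules for wrapped versions, in place """
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and name in target_names:
            setattr(model, name, LinearWithLoRA(child))
        else:
            apply_lora(child, target_names)

# tiny demo model, with attribute names matching target_names
class TinyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.query = nn.Linear(16, 16)
        self.out = nn.Linear(16, 16)
    def forward(self, x):
        return self.out(self.query(x))

m = TinyAttn()
apply_lora(m)
print(type(m.query).__name__)  # LinearWithLoRA
print(type(m.out).__name__)    # Linear -- name not in target_names, left untouched
```

The same `apply_lora` call would work on a Huggingface model object, since its attention projections are also `nn.Linear` submodules reachable through `named_children`.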

""" PEFT LoRA Usage Example

Reference:
		https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/model.py
"""
>>> from transformers import AutoModelForSeq2SeqLM
>>> from peft import LoraModel, LoraConfig

>>> config = LoraConfig(
...     task_type="SEQ_2_SEQ_LM",
...     r=8,  # rank value in official paper
...     lora_alpha=32,  # alpha value in official paper
...     target_modules=["q", "v"],
...     lora_dropout=0.01,
... )

>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> lora_model = LoraModel(model, config, "default")

>>> import torch
>>> import transformers
>>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

>>> rank = ...
>>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"]  # target for projection matrix, MLP
>>> config = LoraConfig(
...     r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
... )

>>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
>>> tokenizer = transformers.AutoTokenizer.from_pretrained(
...     "kakaobrain/kogpt",
...     revision="KoGPT6B-ryan1.5b-float16",  # or float32 version: revision=KoGPT6B-ryan1.5b
...     bos_token="[BOS]",
...     eos_token="[EOS]",
...     unk_token="[UNK]",
...     pad_token="[PAD]",
...     mask_token="[MASK]",
... )
>>> model = transformers.GPTJForCausalLM.from_pretrained(
...     "kakaobrain/kogpt",
...     revision="KoGPT6B-ryan1.5b-float16",  # or float32 version: revision=KoGPT6B-ryan1.5b
...     pad_token_id=tokenizer.eos_token_id,
...     use_cache=False,
...     device_map={"": rank},
...     torch_dtype=torch.float16,
...     quantization_config=quantization_config,
... )
>>> model = prepare_model_for_kbit_training(model)
>>> lora_model = get_peft_model(model, config) 
