πΒ [ViT] An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
πΒ Overview
μμνκΈ° μμ, λ³Έ λ
Όλ¬Έ 리뷰λ₯Ό μμνκ² μ½μΌλ €λ©΄ Transformer
μ λν μ μ΄ν΄κ° νμμ μ΄λ€. μμ§ Transformer
μ λν΄μ μ λͺ¨λ₯Έλ€λ©΄ νμκ° μμ±ν ν¬μ€νΈλ₯Ό μ½κ³ μ€κΈΈ κΆμ₯νλ€. λν λ³Έλ¬Έ λ΄μ©μ μμ±νλ©΄μ μ°Έκ³ ν λ
Όλ¬Έκ³Ό μ¬λ¬ ν¬μ€νΈμ λ§ν¬λ₯Ό 맨 λ° νλ¨μ 첨λΆνμΌλ μ°Έκ³ λ°λλ€. μκ°μ΄ μμΌμ λΆλ€μ μ€κ°μ μ½λ ꡬνλΆλ₯Ό μλ΅νκ³ Insight
λΆν° μ½κΈ°λ₯Ό κΆμ₯νλ€.
Vision Transformer
(μ΄ν ViT
)λ 2020λ
10μ Googleμμ λ°νν μ»΄ν¨ν° λΉμ μ© λͺ¨λΈμ΄λ€. μμ°μ΄ μ²λ¦¬μμ λμ±κ³΅μ κ±°λ νΈλ μ€ν¬λ¨Έ ꡬ쑰μ κΈ°λ²μ κ±°μ κ·Έλλ‘ λΉμ λΆμΌμ μ΄μνλ€λ μ μμ ν° μμκ° μμΌλ©°, μ΄ν μ»΄ν¨ν° λΉμ λΆμΌμ νΈλ μ€ν¬λ¨Έ μ μ±μλκ° μ΄λ¦¬κ² λ κ³κΈ°λ‘ μμ©νλ€.
ννΈ, ViT
μ μ€κ³ μ² νμ λ°λ‘ scalability(λ²μ©μ±)
μ΄λ€. μ κ²½λ§ μ€κ³μμ λ²μ©μ±μ΄λ, λͺ¨λΈμ νμ₯ κ°λ₯μ±μ λ§νλ€. μλ₯Ό λ€λ©΄ νμ΅ λ°μ΄ν°λ³΄λ€ λ ν¬κ³ 볡μ‘ν λ°μ΄ν° μΈνΈλ₯Ό μ¬μ©νκ±°λ λͺ¨λΈμ νλΌλ―Έν°λ₯Ό λλ € μ¬μ΄μ¦λ₯Ό ν€μλ μ¬μ ν μ ν¨ν μΆλ‘ κ²°κ³Όλ₯Ό λμΆνκ±°λ λ λμ μ±λ₯μ 보μ¬μ£Όκ³ λμκ° κ°μ μ μ¬μ§κ° μ¬μ ν λ¨μμμ λ βνμ₯μ±μ΄ λλ€β
λΌκ³ νννλ€. μ μλ€μ λ
Όλ¬Έ μ΄λ°μ μ½ μ°μ΄μ μ»΄ν¨ν° λΉμ λΆμΌμ scalability
λμ΄λ κ²μ΄ μ΄λ² λͺ¨λΈ μ€κ³μ λͺ©νμλ€κ³ λ°νκ³ μλ€. λ²μ©μ±
μ μ κ²½λ§ λͺ¨λΈ μ€κ³μμ κ°μ₯ ν° νλκ° λλλ° λλ©μΈλ§λ€ μ μνλ μλ―Έμ μ°¨μ΄κ° λ―ΈμΈνκ² μ‘΄μ¬νλ€. λ°λΌμ ViT
μ μ μλ€μ΄ λ§νλ λ²μ©μ±
μ΄ λ¬΄μμ μλ―Ένλμ§ μμ보λ κ²μ ꡬ체μ μΈ λͺ¨λΈ ꡬ쑰λ₯Ό μ΄ν΄νλλ° ν° λμμ΄ λ κ²μ΄λ€.
π§ Β Scalability in ViT
λ Όλ¬Έ μ΄λ°λΆμμ λ€μκ³Ό κ°μ λ¬Έμ₯μ΄ μμ λμ΄μλ€.
βOur Vision Transformer (ViT) attains excellent results when pre-trained at sufficient scale and transferred to tasks with fewer datapoints"
μ΄ κ΅¬λ¬Έμ΄ ViT
μ Scalability
λ₯Ό κ°μ₯ μ μ€λͺ
νκ³ μλ€κ³ μκ°νλ€. μ μλ€μ΄ λ§νλ λ²μ©μ±μ κ²°κ΅ backbone
ꡬ쑰μ νμ©μ μλ―Ένλ€. μμ°μ΄ μ²λ¦¬μ μ΅μν λ
μλΌλ©΄ μ½κ² μ΄ν΄κ° κ°λ₯ν κ²μ΄λ€. Transformer
, GPT
, BERT
μ λ±μ₯ μ΄ν, μμ°μ΄ μ²λ¦¬λ λ²μ©μ±μ κ°λ λ°μ΄ν° μΈνΈλ‘ μ¬μ νλ ¨ν λͺ¨λΈμ νμ©ν΄ Task-Agnostic
νκ² νλμ backbone
μΌλ‘ κ±°μ λͺ¨λ Taskλ₯Ό μνν μ μμΌλ©°, μμ μ¬μ΄μ¦μ λ°μ΄ν°λΌλ μλΉν λμ μμ€μ μΆλ‘ μ±λ₯μ λΌ μ μμλ€. κ·Έλ¬λ λΉμ μ»΄ν¨ν° λΉμ μ λ©μΈμ΄μλ Conv
κΈ°λ° λͺ¨λΈλ€μ νμΈνλν΄λ λ°μ΄ν° ν¬κΈ°κ° μμΌλ©΄ μΌλ°ν μ±λ₯μ΄ λ§€μ° λ¨μ΄μ§κ³ , Taskμ λ°λΌμ λ€λ₯Έ μν€ν
μ²λ₯Ό κ°λ λͺ¨λΈμ μλ‘κ² μ μνκ±°λ λΆλ¬μ μ¬μ©ν΄μΌ νλ λ²κ±°λ‘μμ΄ μμλ€. μλ₯Ό λ€λ©΄ Image Classfication
μλ ResNet
, Segmentation
μλ U-Net
, Object Detection
μ YOLO
λ₯Ό μ¬μ©νλ κ²μ²λΌ λ§μ΄λ€. λ°λ©΄ μμ°μ΄ μ²λ¦¬λ μ¬μ νμ΅λ λͺ¨λΈ νλλ‘ λͺ¨λ NLU, μ¬μ§μ΄λ NLG Taskλ μνν μ μλ€. μ μλ€μ μ΄λ¬ν λ²μ©μ±μ μ»΄ν¨ν° λΉμ μλ μ΄μ μν€κ³ μΆμλ κ² κ°λ€. κ·Έλ λ€λ©΄ λ¨Όμ μμ°μ΄ μ²λ¦¬μμ νΈλμ€ν¬λ¨Έ κ³μ΄μ΄ λ²μ©μ±μ κ°μ§ μ μμλ μ΄μ λ 무μμΈμ§ κ°λ¨ν μ΄ν΄λ³΄μ.
μ μλ€μ self-attention
(λ΄μ )μ ν¨μ¨μ±, λͺ¨λΈμ ꡬ쑰μ νμμ± κ·Έλ¦¬κ³ self-supervised task
μ μ‘΄μ¬λ₯Ό κΌ½λλ€. κ·ΈλΌ μ΄κ²λ€μ΄ μ λ²μ©μ±μ λμ΄λλ° λμμ΄ λ κΉ??
self-attention(λ΄μ )
μ νλ ¬ κ° κ³±μ
μΌλ‘ μ μ λμ΄ μ€κ³κ° λ§€μ° κ°νΈνκ³ λ³λ ¬λ‘ νλ²μ μ²λ¦¬νλ κ²μ΄ κ°λ₯νκΈ° λλ¬Έμ ν¨μ¨μ μΌλ‘ μ 체 λ°μ΄ν°λ₯Ό λͺ¨λ κ³ λ €ν μ°μ° κ²°κ³Όλ₯Ό μ»μ μ μλ€.
Multi-Head Attention
ꡬ쑰λ μ¬λ¬ μ°¨μμ μλ―Έ κ΄κ³λ₯Ό λμμ ν¬μ°©νκ³ κ·Έκ²μ μμλΈν κ²κ³Ό κ°μ(μ€μ λ‘λ MLP) κ²°κ³Όλ₯Ό μ»μ μ μλ€λ μ μμ ꡬ쑰μ μΌλ‘ νμνλ€.
λ§μ§λ§μΌλ‘ MLM
, Auto-Regression(LM) Task
λ λ°μ΄ν° μΈνΈμ λ³λμ μΈκ°μ κ°μ
(λΌλ²¨λ§)μ΄ νμνμ§ μκΈ° λλ¬Έμ κ°μ±λΉ μκ² λ°μ΄ν°μ λͺ¨λΈμ μ¬μ΄μ¦λ₯Ό λ릴 μ μκ² λλ€.
μ΄μ λ
Όλ¬Έμμ νΈλμ€ν¬λ¨Έ κ³μ΄μ΄ κ°μ§ λ²μ©μ±μ μ΄λ»κ² λΉμ λΆμΌμ μ μ©νλμ§ μ£Όλͺ©νλ©΄μ λͺ¨λΈ ꡬ쑰λ₯Ό νλ νλ μ΄ν΄λ³΄μ.
πΒ Modeling
- 1) Transfer
Scalability
from pureTransformer
to Computer Vision- Overcome
reliance
on Convolution(Inductive Bias
) in Computer Vision - Apply Self-Attention & Architecture from vanilla NLP Transformers as
closely
as possible - Treat Image as sequence of text token
- Make $P$ sub-patches from whole image, playing same role as token in NLP Transformer
- Overcome
μ μλ€μ λ¨Όμ Conv
μ λν μμ‘΄μ λ²λ¦΄ κ²μ μ£Όμ₯νλ€. Conv
κ° κ°μ§ Inductive Bias
λλ¬Έμ νμΈνλ λ 벨μμ λ°μ΄ν° ν¬κΈ°κ° μμΌλ©΄ μΌλ°ν μ±λ₯μ΄ λ¨μ΄μ§λ κ²μ΄λΌκ³ μ€λͺ
νκ³ μλ€. μ΄ λ§μ μ΄ν΄νλ €λ©΄ Inductive Bias
μ λν΄μ λ¨Όμ μμμΌ νλ€. Inductive Bias
λ, μ£Όμ΄μ§ λ°μ΄ν°λ‘λΆν° μΌλ°ν μ±λ₯μ λμ΄κΈ° μν΄ βμ
λ ₯λλ λ°μ΄ν°λ ~ ν κ²μ΄λ€β
, βμ΄λ° νΉμ§μ κ°κ³ μμ κ²μ΄λ€β
μ κ°μ κ°μ , κ°μ€μΉ, κ°μ€ λ±μ κΈ°κ³νμ΅ μκ³ λ¦¬μ¦μ μ μ©νλ κ²μ λ§νλ€.
Conv
μ°μ° μ체 (κ°μ€μΉ 곡μ , νλ§ μλ Conv Block
μ΄ Invariance
)μ κΈ°λ³Έ κ°μ μ translation equivariance
, locality
μ΄λ€. μ¬μ€ μ μμ μ£Όμ₯μ μ΄ν΄νλλ° equivariance
μ locality
μ λ»μ΄ 무μμΈμ§ νμ
νλ κ²μ ν¬κ² μλ―Έκ° μλ€ (equivariance
μ invariance
μ λν΄μλ λ€λ₯Έ ν¬μ€ν
μμ μμΈν μ΄ν΄λ³΄λλ‘ νκ² λ€). μ€μν κ²μ μ
λ ₯ λ°μ΄ν°μ κ°μ μ λνλ€λ μ μ΄λ€. λ§μ½ μ£Όμ΄μ§ μ
λ ₯μ΄ λ―Έλ¦¬ κ°μ ν Inductive Bias
μ λ²μ΄λλ€λ©΄ μ΄λ»κ² λ κΉ??
μλ§ μ€λ²νΌν
λκ±°λ λͺ¨λΈ νμ΅μ΄ μλ ΄μ±μ κ°μ§ λͺ»νκ² λ κ²μ΄λ€. μ΄λ―Έμ§ λ°μ΄ν°λ Taskμ λ°λΌ νμν Inductive Bias
κ° λ¬λΌμ§λ€. μλ₯Ό λ€μ΄ Segmentation
, Detection
μ κ²½μ°λ μ΄λ―Έμ§ μ κ°μ²΄μ μμΉ, ν½μ
μ¬μ΄μ spatial variance
μ λ³΄κ° λ§€μ° μ€μνλ€. ννΈ, Classification
μ spatial invariance
κ° μ€μνλ€. λͺ©ν κ°μ²΄μ μμΉμ μ£Όλ³ νΉμ§λ³΄λ€ νκ² μ체λ₯Ό μ κ²½λ§μ΄ μΈμνλ κ²μ΄ μ€μνκΈ° λλ¬Έμ΄λ€. λ°λΌμ ViT
μ μλ€μ μ΄λ€ Biasλ μκ΄μμ΄ νΈν₯μ κ°κ³ λ°μ΄ν°λ₯Ό λ³Έλ€λ κ² μ체μ μλ¬Έμ ννλ©°, μ΄λ―Έμ§ μμ Inductive Bias
μμ λ²μ΄λ, μ£Όμ΄μ§ λ°μ΄ν° μ 체 νΉμ§(ν¨μΉ) μ¬μ΄μ κ΄κ³λ₯Ό νμ
νλ κ³Όμ μμ scalability
λ₯Ό νλν μ μλ€κ³ μ£Όμ₯νλ€.
κ·Έλμ Conv
μ λμμΌλ‘ μλμ μΌλ‘ Inductive Bias
κ° λΆμ‘±ν Self-Attention
, Transformer Architecture
λ₯Ό μ¬μ©νλ€. λκ°μ§μ ν¨μ©μ±μ λν΄μλ μ΄λ―Έ μμμ μΈκΈνκΈ° λλ¬Έμ μλ΅νκ³ , μ¬κΈ°μ μ§κ³ λμ΄κ°μΌν μ μ Self-Attention
μ΄ Conv
λλΉ Inductive Bias
κ° μ λ€λ μ μ΄λ€. Self-Attention κ³Όμ μλ μ¬λ¬ μ°μ°, μ€μΌμΌ μ‘°μ κ°λ€μ΄ ν¬ν¨λμ§λ§ λ³Έμ§μ μΌλ‘ βλ΄μ β
μ΄ μ€μ¬μ΄λ€. λ΄μ μ κ·Έ μ΄λ€ νΈν₯ (Conv
μ λμ‘°νλ €κ³ μ΄λ κ² μμ νμ§λ§ μ¬μ€ Position Embedding
λνλ κ²λ μΌμ’
μ μ½ν Inductive Bias
)μ΄ μ‘΄μ¬νμ§ μλλ€. μΌλ¨ μ£Όμ΄μ§ λͺ¨λ λ°μ΄ν°μ λν΄μ λ΄μ κ°μ μ°μΆνκ³ κ·Έ λ€μμ κ΄κ³κ° μλ€κ³ μκ°λλ μ 보λ₯Ό μΆλ¦¬κΈ° λλ¬Έμ΄λ€. Conv
λμ λ¬λ¦¬ βμ
λ ₯λλ λ°μ΄ν°λ ~ ν κ²μ΄λ€β
, βμ΄λ° νΉμ§μ κ°κ³ μμ κ²μ΄λ€β
λΌλ κ°μ μ΄ μλ€. μ΄λ² ν¬μ€ν
μ λ§μ§λ§ μ―€μμ λ€μ λ€λ£¨κ² μ§λ§ κ·Έλμ ViT
λ μΈμ€ν΄μ€ μ¬μ΄μ λͺ¨λ κ΄κ³λ₯Ό λ½μ보λ Self-Attention(λ΄μ )
μ κΈ°λ°μΌλ‘ λ§λ€μ΄μ‘κΈ° λλ¬Έμ μ΄λ―Έμ§μ Global Information
μ ν¬μ°©νλλ° νμν μ±λ₯μ 보μ΄κ³ , Conv
λ βμ€μν μ 보λ κ·Όμ² ν½μ
μ λͺ°λ €μλ€λΌλβ Inductive Bias
λλΆμ Local Information
μ ν¬μ°©νλλ° νμν μ±λ₯μ λΈλ€.
κ·Έλ λ€λ©΄ ν½μ
νλ νλλΌλ¦¬ λ΄μ ν΄μ€λ€λ κ²μΌκΉ?? μλλ€ μ¬κΈ°μ λ
Όλ¬Έμ μ λͺ©μ΄ An Image Is Worth 16x16 Words
μΈ μ΄μ κ° λλ¬λλ€. μΌλ¨ ν½μ
νλ νλλΌλ¦¬ μ μ¬λλ₯Ό μΈ‘μ νλ κ²μ΄ μ μλ―Έν κΉ μκ°ν΄λ³΄μ. μμ°μ΄μ ν ν°κ³Ό λ¬λ¦¬ μ΄λ―Έμ§μ λ¨μΌ ν½μ
ν κ°λ ν° μΈμ¬μ΄νΈλ₯Ό μ»κΈ° νλ€λ€. ν½μ
μ λ§ κ·Έλλ‘ μ νλμΌ λΏμ΄λ€. ν½μ
μ μ¬λ¬ κ° λ¬Άμ΄ ν¨μΉ λ¨μλ‘ λ¬Άλλ€λ©΄ μ΄μΌκΈ°λ λ¬λΌμ§λ€. μΌμ ν¬κΈ° μ΄μμ ν¨μΉλΌλ©΄ μμ°μ΄μ ν ν°μ²λΌ κ·Έ μμ²΄λ‘ μ΄λ€ μλ―Έλ₯Ό λ΄μ μ μλ€. λ°λΌμ μ μλ μ 체 μ΄λ―Έμ§λ₯Ό μ¬λ¬ κ°μ 16x16 νΉμ 14x14 μ¬μ΄μ¦ ν¨μΉλ‘ λλμ΄ νλ νλλ₯Ό ν ν°μΌλ‘ κ°μ£Όν΄ μ΄λ―Έμ§ μνμ€λ₯Ό λ§λ€κ³ κ·Έκ²μ λͺ¨λΈμ InputμΌλ‘ μ¬μ©νλ€.
Class Diagram
λͺ¨λΈ ꡬ쑰μ λΌλκ° λλ λ΄μ©λ€μ λͺ¨λ μ΄ν΄λ³΄μκ³ , μμμ μμ ν λ΄μ©μ ꡬννκΈ° μν΄ μ΄λ€ λΈλ‘λ€μ μ¬μ©νλμ§ νμκ° μ§μ λ
Όλ¬Έμ λ³΄κ³ λ°λΌ ꡬνν μ½λμ ν¨κ» μμ보λλ‘ νμ. μμ 첨λΆν λͺ¨λΈ λͺ¨μλμ λμ μλ λΈλ‘λ€ νλ νλ μ΄ν΄λ³Ό μμ μ΄λ€. μ¬λ΄μΌλ‘ Google Researchμ Official Repo μμ ν¨κ» μ°Έκ³ νλλ°, μ½λκ° λͺ¨λ ꡬκΈμ΄ μμ μλ‘κ² λ―Έλ Jax
, Flax
λ‘ κ΅¬ν λμ΄ μμλ€. νμ΄ν μΉλ μ’ μ¨λ³Έ νμ μ
μ₯μμλ μ λ§ β¦ μ§μ₯λΆμ κ²½ννλ€. μ€λλ λ€μ ν λ² νμ΄μ€λΆ νμ΄ν μΉ κ°λ°νμ ν°μ λλ¦¬κ³ μΆλ€.
π¬Β Linear Projection of Flattened Patches
\[x_p \in R^{N * (P^2β’C)}\]
\[z_{0} = [x_{class}; x_p^1E;x_p^2E;x_p^3E....x_p^NE]\]
\[N = \frac{H*W}{P*P}\]
ViT
μ μ
λ ₯ μλ² λ©μ μμ±νλ μν μ νλ€. ViT
λ $x \in R^{H * W * C}$(H: height, W: width, C: channel)μ νμμ κ°λ μ΄λ―Έμ§λ₯Ό μ
λ ₯μΌλ‘ λ°μ κ°λ‘ μΈλ‘ κΈΈμ΄κ° $P$, μ±λ κ°μ $C$μΈ $N$κ°μ ν¨μΉλ‘ reshape
νλ€. νμκ° μ½λ ꡬν μ€ κ°μ₯ νΌλν λΆλΆμ΄ λ°λ‘ ν¨μΉ κ°μ $N$μ΄μλ€. μ§κ΄μ μΌλ‘ ν¨μΉ κ°μλΌκ³ νλ©΄, μ 체 μ΄λ―Έμ§ μ¬μ΄μ¦μμ ν¨μΉ ν¬κΈ°λ₯Ό λλ κ°μ΄λΌκ³ μκ°νκΈ° μ½κΈ° λλ¬Έμ΄λ€. μλ₯Ό λ€λ©΄ 512x512
μ§λ¦¬ μ΄λ―Έμ§λ₯Ό 16x16
μ¬μ΄μ¦μ ν¨μΉλ‘ λλλ€κ³ ν΄λ³΄μ. νμλ λ¨μν 512/16=32
λΌλ κ²°κ³Όλ₯Ό μ΄μ©ν΄ $N=32$λ‘ μ€μ νκ³ μ€νμ μ§ννλ€κ° ν
μ μ°¨μμ΄ λ§μ§ μμ λ°μνλ μλ¬ λ‘κ·Έλ₯Ό λ§μ£Όνμλ€. κ·Έλ¬λ λ
Όλ¬Έ μ μμμ νμΈν΄λ³΄λ©΄, $H * W / P^2$μ΄ λ°λ‘ ν¨μΉ κ°μ$N$μΌλ‘ μ μλλ€. κ·Έλμ λ§μ½ 512x512
μ¬μ΄μ¦μ RGB
μ΄λ―Έμ§ 10μ₯
μ ViT μ
λ ₯ μλ² λ©μ λ§κ² μ°¨μ λ³ννλ€λ©΄ κ²°κ³Όλ [10, 3, 1024, 768]
μ΄ λ κ²μ΄λ€. (μ΄ μμλ₯Ό μμΌλ‘ κ³μ μ΄μ©νκ² λ€)
μ΄λ κ² μ°¨μμ λ°κΏμ€ μ΄λ―Έμ§λ₯Ό nn.Linear((channels * patch_size**2), dim_model)
λ₯Ό ν΅ν΄ ViT
μ μλ² λ© λ μ΄μ΄μ μ ν ν¬μν΄μ€λ€. μ¬κΈ°μ μμ°μ΄ μ²λ¦¬μ νμ΄ν μΉλ₯Ό μμ£Ό μ¬μ©νμλ λ
μλΌλ©΄ μ nn.Embedding
μ μ¬μ©νμ§ μμλκ° μλ¬Έμ κ°μ§ μ μλ€.
μμ°μ΄ μ²λ¦¬μμ μ
λ ₯ μλ² λ©μ λ§λ€λλ λͺ¨λΈμ ν ν¬λμ΄μ μ μν΄ μ¬μ μ μλ vocabμ μ¬μ΄μ¦κ° μ
λ ₯ λ¬Έμ₯μ μν ν ν° κ°μλ³΄λ€ ν¨μ¬ ν¬κΈ° λλ¬Έμ λ°μ΄ν° 룩μ
ν
μ΄λΈ λ°©μμ nn.Embedding
μ μ¬μ©νκ² λλ€. μ΄κ² λ¬΄μ¨ λ§μ΄λλ©΄, ν ν¬λμ΄μ μ μν΄ μ¬μ μ μ μλ vocab
μ μ²΄κ° nn.Embedding(vocab_size, dim_model)
λ‘ ν¬μ λμ΄ κ°λ‘λ vocab μ¬μ΄μ¦, μΈλ‘λ λͺ¨λΈμ μ°¨μ ν¬κΈ°μ ν΄λΉνλ 룩μ
ν
μ΄λΈμ΄ μμ±λκ³ , λ΄κ° μ
λ ₯ν ν ν°λ€μ μ 체 vocab
μ μΌλΆλΆμΌν
λ μ 체 μλ² λ© λ£©μ
ν
μ΄λΈμμ λ΄κ° μλ² λ©νκ³ μΆμ ν ν°λ€μ μΈλ±μ€λ§ μμλΈλ€λ κ²μ΄λ€.
κ·Έλμ nn.Embedding
μ μ μλ μ°¨μκ³Ό μ€μ μ
λ ₯ λ°μ΄ν°μ μ°¨μμ΄ λ§μ§ μμλ ν¨μκ° λμνκ² λλ κ²μ΄λ€. κ·Έλ¬λ λΉμ μ κ²½μ°, μ¬μ μ μ μλ vocab
μ΄λΌλ κ°λ
μ΄ μ ν μκ³ μ
λ ₯ μ΄λ―Έμ§ μμ νμ κ³ μ λ ν¬κΈ°μ μ°¨μμΌλ‘ λ€μ΄μ€κΈ° λλ¬Έμ nn.Embedding
μ΄ μλ nn.Linear
μ μ¬μ©ν΄ 곧λ°λ‘ μ ν ν¬μμ ꡬνν κ²μ΄λ€. λ λ©μλμ λν μμΈν λΉκ΅λ νμ΄ν μΉ κ΄λ ¨ ν¬μ€νΈμμ λ€μ ν λ² μμΈν λ€λ£¨λλ‘ νκ² λ€.
ννΈ, Position Embedding
μ λνκΈ° μ , Input Embedding
μ μ°¨μμ [10, 1024, 1024]
μ΄ λλ€. μ§κΈκΉμ§ μ€λͺ
ν λΆλΆ(Linear Projection of Flattened Patches
)μ νμ΄ν μΉ μ½λλ‘ κ΅¬ννλ©΄ λ€μκ³Ό κ°λ€.
class VisionTransformer(nn.Module):
...
μ€λ΅
...
self.num_patches = int(image_size / patch_size)**2
self.input_embedding = nn.Linear((channels * patch_size**2), dim_model) # Projection Layer for Input Embedding
...
μ€λ΅
...
def forward(self, inputs: Tensor) -> any:
""" For cls pooling """
assert inputs.ndim != 4, f"Input shape should be [BS, CHANNEL, IMAGE_SIZE, IMAGE_SIZE], but got {inputs.shape}"
x = inputs
x = self.input_embedding(
x.reshape(x.shape[0], self.num_patches, (self.patch_size**2 * x.shape[1])) # Projection Layer for Input Embedding
)
cls_token = torch.zeros(x.shape[0], 1, x.shape[2]) # can change init method
x = torch.cat([cls_token, x], dim=1)
...
μλ² λ© λ μ΄μ΄λ₯Ό κ°μ²΄λ‘ λ°λ‘ ꡬνν΄λ λμ§λ§, νμλ κ΅³μ΄ μΆμνκ° νμνμ§ μλ€κ³ μκ°ν΄ ViTμ μ΅μμ ν΄λμ€μΈ VisionTransformer
μ forward
λ©μλ 맨 μ΄λ°λΆμ ꡬννκ² λμλ€. μ
λ ₯ λ°μ μ΄λ―Έμ§ ν
μλ₯Ό torch.reshape
μ ν΅ν΄ [ν¨μΉ κ°μ, ν½μ
κ°μ*μ±λκ°μ]
λ‘ λ°κΎΌ λ€, 미리 μ μν΄λ self.input_embedding
μ 맀κ°λ³μλ‘ μ λ¬ν΄ βμμΉ μλ² λ©β
κ°μ΄ λν΄μ§κΈ° μ Input Embedding
μ λ§λ λ€.
ννΈ, CLS Pooling
μ μν΄ λ§μ§λ§μ [batch, 1, image_size]
μ μ°¨μμ κ°λ cls_token
μ μ μν΄ ν¨μΉ μνμ€μ concat
(맨 μμ)ν΄μ€λ€. μ΄ λ λ
Όλ¬Έμ μ μλ μμ μ, CLS Token
μ μ ν ν¬μνμ§ μμΌλ©°, ν¨μΉ μνμ€μ μ ν ν¬μμ΄ μ΄λ€μ§κ³ λ λ€μ 맨 μμ Concat
νκ² λλ€.
CLS Token
κΉμ§ λν μ΅μ’
Input Embedding
μ ν
μ μ°¨μμ [10, 1025, 1024]
κ° λλ€.
π’Β Positional Embedding
\[E_{pos} \in R^{(N+1)*D}\]
μ΄λ―Έμ§λ₯Ό ν¨μΉ λ¨μμ μλ² λ©μΌλ‘ λ§λ€μλ€λ©΄ μ΄μ μμΉ μλ² λ©μ μ μν΄μ λν΄μ£Όλ©΄ λͺ¨μλ μ Embedded Patches
, μ¦ μΈμ½λμ λ€μ΄κ° μ΅μ’
Patch Embedding
μ΄ μμ± λλ€. μμΉ μλ² λ©μ λ§λλ λ°©μμ κΈ°μ‘΄ Transformer
, BERT
μ λμΌνλ€. μλ VisionEncoder
ν΄λμ€λ₯Ό ꡬνν μ½λλ₯Ό μ΄ν΄λ³΄μ.
class VisionEncoder(nn.Module):
...
μ€λ΅
...
self.positional_embedding = nn.Embedding((self.num_patches + 1), dim_model) # add 1 for cls token
...
μ€λ΅
...
def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
layer_output = []
pos_x = torch.arange(self.num_patches + 1).repeat(inputs.shape[0]).to(inputs) # inputs.shape[0] = Batch Size of Input
x = self.dropout(
inputs + self.positional_embedding(pos_x)
)
...
Input Embedding
κ³Ό λ€λ₯΄κ² μμΉ μλ² λ©μ nn.Embedding
μΌλ‘ ꡬννλλ°, μ¬κΈ°μλ μ¬μ€ nn.Linear
λ₯Ό μ¬μ©ν΄λ 무방νλ€. κ·Έκ²λ³΄λ€ nn.Embedding
μ μ
λ ₯ μ°¨μμΈ self.num_patches + 1
μ μ£Όλͺ©ν΄λ³΄μ. μ 1μ λν΄μ€ κ°μ μ¬μ©νμκΉ??
ViT
λ BERTμ CLS Token Pooling
μ μ°¨μ©νκΈ° μν΄ ν¨μΉ μνμ€ λ§¨ μμ CLS ν ν°μ μΆκ°νκΈ° λλ¬Έμ΄λ€. μ΄λ κ² μΆκ°λ CLS Token
μ μΈμ½λλ₯Ό κ±°μ³ μ΅μ’
MLP Head
μ νλ¬λ€μ΄κ° λ‘μ§μΌλ‘ λ³νλλ€. λ§μ½ λ
μκ»μ CLS Token Pooling
λμ λ€λ₯Έ νλ§ λ°©μμ μ¬μ©ν κ±°λΌλ©΄ 1μ μΆκ°ν΄μ€ νμλ μλ€.
μ μ΄μ κ°μ²΄ μΈμ€ν΄μ€ μ΄κΈ°ν λΉμμ CLS Token
μ μΆκ°λ₯Ό λ°μν κ°μ μ λ¬νλ©΄ λμ§ μλκ°νλ μλ¬Έμ΄ λ€ μλ μλ€. νμ§λ§ VisionEncoder
κ°μ²΄ μΈμ€ν΄μ€ μ΄κΈ°ν λΉμμλ num_patches
κ°μΌλ‘ CLS Token
μ΄ μΆκ°λκΈ° μ΄μ κ°(+1 λ°μμ΄ μλμ΄ μμ)μ μ λ¬νλλ‘ μ€κ³ λμ΄ μμ΄μ CLS Pooling
μ μ¬μ©ν κ±°λΌλ©΄ 1 μΆκ°λ₯Ό κΌ ν΄μ€μΌ νλ€.
Performance Table by making Position Embedding method
ννΈ μ μλ 2D Postion Embedding
, Relative Position Embedding
λ°©μλ μ μ©ν΄λ΄€μ§λ§, ꡬν 볡μ‘λ & μ°μ°λ λλΉ μ±λ₯ ν₯μ νμ΄ λ§€μ° λ―Έλ―Έν΄ μΌλ°μ μΈ 1D Position Embedding
μ μ¬μ©ν κ²μ μΆμ²νκ³ μλ€.
π©βπ©βπ§βπ¦ Multi-Head Attention
\[z_t^{'} = MSA(LN(z_{t-1}) + z_{t-1})\]
\[MSA(z) = [SA_1();SA_2();SA_3()...SA_k()]*U_{msa}, \ \ U_{msa} \in R^{(k*D_h)*D} \\\]
νΈλμ€ν¬λ¨Έ κ³μ΄ λͺ¨λΈμ ν΅μ¬ Multi-Head Self-Attention
λͺ¨λμ λν΄μ μμ보μ. μ¬μ€ κΈ°μ‘΄ μμ°μ΄ μ²λ¦¬ Transformer
, BERT
λ±μ λμ λ°©μκ³Ό μμ ν λμΌνλ©°, μ½λλ‘ κ΅¬νν λ μμ λμΌνκ² λ§λ€μ΄μ£Όλ©΄ λλ€. μμΈν μ리μ λμ λ°©μμ Attention Is All You Need 리뷰 ν¬μ€νΈμμ μ€λͺ
νκΈ° λλ¬Έμ μλ΅νκ³ λμ΄κ°κ² λ€. ννΈ νμ΄ν μΉλ‘ ꡬνν Multi-Head Self-Attention
λΈλμ λν μ½λλ λ€μκ³Ό κ°λ€.
def scaled_dot_product_attention(q: Tensor, k: Tensor, v: Tensor, dot_scale: Tensor) -> Tensor:
"""
Scaled Dot-Product Attention
Args:
q: query matrix, shape (batch_size, seq_len, dim_head)
k: key matrix, shape (batch_size, seq_len, dim_head)
v: value matrix, shape (batch_size, seq_len, dim_head)
dot_scale: scale factor for Qβ’K^T result, same as pure transformer
Math:
A = softmax(qβ’k^t/sqrt(D_h)), SA(z) = Av
"""
attention_dist = F.softmax(
torch.matmul(q, k.transpose(-1, -2)) / dot_scale,
dim=-1
)
attention_matrix = torch.matmul(attention_dist, v)
return attention_matrix
class AttentionHead(nn.Module):
"""
In this class, we implement workflow of single attention head
Args:
dim_model: dimension of model's latent vector space, default 1024 from official paper
dim_head: dimension of each attention head, default 64 from official paper (1024 / 16)
dropout: dropout rate, default 0.1
Math:
[q,k,v]=zβ’U_qkv, A = softmax(qβ’k^t/sqrt(D_h)), SA(z) = Av
"""
def __init__(self, dim_model: int = 1024, dim_head: int = 64, dropout: float = 0.1) -> None:
super(AttentionHead, self).__init__()
self.dim_model = dim_model
self.dim_head = dim_head
self.dropout = dropout
self.dot_scale = torch.sqrt(torch.tensor(self.dim_head))
self.fc_q = nn.Linear(self.dim_model, self.dim_head)
self.fc_k = nn.Linear(self.dim_model, self.dim_head)
self.fc_v = nn.Linear(self.dim_model, self.dim_head)
def forward(self, x: Tensor) -> Tensor:
attention_matrix = scaled_dot_product_attention(
self.fc_q(x),
self.fc_k(x),
self.fc_v(x),
self.dot_scale
)
return attention_matrix
class MultiHeadAttention(nn.Module):
"""
In this class, we implement workflow of Multi-Head Self-Attention
Args:
dim_model: dimension of model's latent vector space, default 1024 from official paper
num_heads: number of heads in MHSA, default 16 from official paper for ViT-Large
dim_head: dimension of each attention head, default 64 from official paper (1024 / 16)
dropout: dropout rate, default 0.1
Math:
MSA(z) = [SA1(z); SA2(z); Β· Β· Β· ; SAk(z)]β’Umsa
Reference:
https://arxiv.org/abs/2010.11929
https://arxiv.org/abs/1706.03762
"""
def __init__(self, dim_model: int = 1024, num_heads: int = 8, dim_head: int = 64, dropout: float = 0.1) -> None:
super(MultiHeadAttention, self).__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_head = dim_head
self.dropout = dropout
self.attention_heads = nn.ModuleList(
[AttentionHead(self.dim_model, self.dim_head, self.dropout) for _ in range(self.num_heads)]
)
self.fc_concat = nn.Linear(self.dim_model, self.dim_model)
def forward(self, x: Tensor) -> Tensor:
""" x is already passed nn.Layernorm """
assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'
attention_output = self.fc_concat(
torch.cat([head(x) for head in self.attention_heads], dim=-1) # concat all dim_head = num_heads * dim_head
)
return attention_output
MultiHeadAttention
μ κ°μ₯ μ΅μμ κ°μ²΄λ‘ λκ³ , νμμ AttentionHead
κ°μ²΄λ₯Ό λ°λ‘ ꡬννλ€. μ΄λ κ² κ΅¬ννλ©΄, μ΄ν
μ
ν΄λλ³λ‘ 쿼리, ν€, 벨λ₯ μ μ ν¬μ νλ ¬(nn.Linear
)μ λ°λ‘ ꡬνν΄μ€ νμκ° μμ΄μ§λ©°, nn.ModuleList
λ₯Ό ν΅ν΄ κ°λ³ ν΄λλ₯Ό ν λ²μ κ·Έλ£Ήννκ³ loop
λ₯Ό ν΅ν΄ μΆλ ₯ κ²°κ³Όλ₯Ό concat
ν΄μ€ μ μμ΄ λ³΅μ‘νκ³ λ§μ μλ¬λ₯Ό μ λ°νλ ν
μ μ°¨μ μ‘°μμ νΌν μ μμΌλ©°, μ½λμ κ°λ
μ±μ΄ μ¬λΌκ°λ ν¨κ³Όκ° μλ€.
π³οΈ MLP
\[z_{t} = MLP(LN(z_{t}^{'}) + z_{t}^{'})\]
μ΄λ¦λ§ MLP
λ‘ λ°λμμ λΏ, κΈ°μ‘΄ νΈλμ€ν¬λ¨Έμ νΌλ ν¬μλ λΈλκ³Ό λμΌν μν μ νλ€. μμ μμΈν λμ λ°©μμ μ¬κΈ° ν¬μ€νΈμμ νμΈνμ. νμ΄ν μΉλ‘ ꡬνν μ½λλ λ€μκ³Ό κ°λ€.
class MLP(nn.Module):
"""
Class for MLP module in ViT-Large
Args:
dim_model: dimension of model's latent vector space, default 512
dim_mlp: dimension of FFN's hidden layer, default 2048 from official paper
dropout: dropout rate, default 0.1
Math:
MLP(x) = MLP(LN(x))+x
"""
def __init__(self, dim_model: int = 1024, dim_mlp: int = 4096, dropout: float = 0.1) -> None:
super(MLP, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(dim_model, dim_mlp),
nn.GELU(),
nn.Dropout(p=dropout),
nn.Linear(dim_mlp, dim_model),
nn.Dropout(p=dropout),
)
def forward(self, x: Tensor) -> Tensor:
return self.mlp(x)
νΉμ΄ν μ μ Activation Function
μΌλ‘ GELU
λ₯Ό μ¬μ©(κΈ°μ‘΄ νΈλμ€ν¬λ¨Έλ RELU
)νλ€λ μ μ΄λ€.
π Vision Encoder Layer
ViT
μΈμ½λ λΈλ 1κ°μ ν΄λΉνλ νμ λͺ¨λκ³Ό λμμ ꡬνν κ°μ²΄μ΄λ€. ꡬνν μ½λλ μλμ κ°λ€.
class VisionEncoderLayer(nn.Module):
"""
Class for encoder_model module in ViT-Large
In this class, we stack each encoder_model module (Multi-Head Attention, Residual-Connection, Layer Normalization, MLP)
"""
def __init__(self, dim_model: int = 1024, num_heads: int = 16, dim_mlp: int = 4096, dropout: float = 0.1) -> None:
super(VisionEncoderLayer, self).__init__()
self.self_attention = MultiHeadAttention(
dim_model,
num_heads,
int(dim_model / num_heads),
dropout,
)
self.layer_norm1 = nn.LayerNorm(dim_model)
self.layer_norm2 = nn.LayerNorm(dim_model)
self.dropout = nn.Dropout(p=dropout)
self.mlp = MLP(
dim_model,
dim_mlp,
dropout,
)
def forward(self, x: Tensor) -> Tensor:
ln_x = self.layer_norm1(x)
residual_x = self.dropout(self.self_attention(ln_x)) + x
ln_x = self.layer_norm2(residual_x)
fx = self.mlp(ln_x) + residual_x # from official paper & code by Google Research
return fx
νΉμ΄μ μ λ§μ§λ§
λ§μ§λ§ μΈμ½λμ μΆλ ₯κ°μλ§ νλ² λ MLP Layer
μ Residual
κ²°κ³Όλ₯Ό λν λ€, λ€μ μΈμ½λ λΈλ‘μ μ λ¬νκΈ° μ μ μΈ΅ μ κ·νλ₯Ό ν λ² λ μ μ©νλ€λ κ²μ΄λ€. λͺ¨λΈ λͺ¨μλμλ λμ μμ§ μμ§λ§, λ³Έλ¬Έμ ν΄λΉ λ΄μ©μ΄ μ€λ € μλ€.layernorm
μ μ μ©νλ€.
π VisionEncoder
μ
λ ₯ μ΄λ―Έμ§λ₯Ό Patch Embedding
μΌλ‘ μΈμ½λ© νκ³ Nκ°μ VisionEncoderLayer
λ₯Ό μκΈ° μν΄ κ΅¬νλ κ°μ²΄μ΄λ€. Patch Embedding
μ λ§λλ λΆλΆμ μ΄λ―Έ μμμ μ€λͺ
νκΈ° λλ¬Έμ λμ΄κ°κ³ , μΈμ½λ λΈλμ Nκ° μλ λ°©λ²μ μμλ nn.ModuleList
λ₯Ό μ¬μ©νλ©΄ κ°νΈνκ² κ΅¬νν μ μλ€. μλ μ½λλ₯Ό μ΄ν΄λ³΄μ.
class VisionEncoder(nn.Module):
"""
In this class, encode input sequence(Image) and then we stack N VisionEncoderLayer
This model is implemented by cls pooling method for classification
First, we define "positional embedding" and then add to input embedding for making patch embedding
Second, forward patch embedding to N EncoderLayer and then get output embedding
Args:
num_patches: number of patches in input image => (image_size / patch_size)**2
N: number of EncoderLayer, default 24 for large model
"""
def __init__(self, num_patches: int, N: int = 24, dim_model: int = 1024, num_heads: int = 16, dim_mlp: int = 4096, dropout: float = 0.1) -> None:
super(VisionEncoder, self).__init__()
self.num_patches = num_patches
self.positional_embedding = nn.Embedding((self.num_patches + 1), dim_model) # add 1 for cls token
self.num_layers = N
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_mlp = dim_mlp
self.dropout = nn.Dropout(p=dropout)
self.encoder_layers = nn.ModuleList(
[VisionEncoderLayer(dim_model, num_heads, dim_mlp, dropout) for _ in range(self.num_layers)]
)
self.layer_norm = nn.LayerNorm(dim_model)
def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
layer_output = []
pos_x = torch.arange(self.num_patches + 1).repeat(inputs.shape[0]).to(inputs)
x = self.dropout(
inputs + self.positional_embedding(pos_x)
)
for layer in self.encoder_layers:
x = layer(x)
layer_output.append(x)
encoded_x = self.layer_norm(x) # from official paper & code by Google Research
layer_output = torch.stack(layer_output, dim=0).to(x.device) # For Weighted Layer Pool: [N, BS, SEQ_LEN, DIM]
return encoded_x, layer_output
λ§μ§λ§ μΈ΅μ μΈμ½λ μΆλ ₯κ°μλ layernorm
μ μ μ©ν΄μ€μΌ ν¨μ μμ§ λ§μ. ννΈ, layer_output
λ λ μ΄μ΄ λ³ μ΄ν
μ
κ²°κ³Όλ₯Ό μκ°ν νκ±°λ λμ€μ WeightedLayerPool
μ μ¬μ©νλ €κ³ λ§λ€μλ€.
π€ VisionTransformer
ViT
λͺ¨λΈμ κ°μ₯ μ΅μμ κ°μ²΄λ‘, μμμ μ€λͺ
ν λͺ¨λ λͺ¨λλ€μ λμμ΄ μ΄λ€μ§λ κ³³μ΄λ€. μ¬μ©μλ‘λΆν° νμ΄νΌνλΌλ―Έν°λ₯Ό μ
λ ₯ λ°μ λͺ¨λΈμ ν¬κΈ°, κΉμ΄, ν¨μΉ ν¬κΈ°, μ΄λ―Έμ§ μλ² λ© μΆμΆ λ°©μμ μ§μ νλ€. κ·Έλ¦¬κ³ μ
λ ₯ μ΄λ―Έμ§λ₯Ό μ λ¬λ°μ μλ² λ©μ λ§λ€κ³ μΈμ½λμ μ λ¬ν λ€, MLP Head
λ₯Ό ν΅ν΄ μ΅μ’
μμΈ‘ κ²°κ³Όλ₯Ό λ°ννλ μν μ νλ€.
μ΄λ―Έμ§ μλ² λ© μΆμΆ λ°©μμ Linear Projection
κ³Ό Convolution
μ΄ μλ€. μ μκ° λ
Όλ¬Έμμ λ§νλ μΌλ°μ μΈ ViT
λ₯Ό λ§νλ©° νμλ μ μκ° Hybrid ViT
λΌκ³ λ°λ‘ λͺ
λͺ
νλ λͺ¨λΈμ΄λ€. μλ² λ© μΆμΆ λ°©μ μ΄μΈμ λ€λ₯Έ μ°¨μ΄λ μ ν μλ€. extractor
맀κ°λ³μλ₯Ό ν΅ν΄ μλ² λ© μΆμΆ λ°©μμ μ§μ ν μ μμΌλ μλ μ½λλ₯Ό νμΈν΄λ³΄μ.
class VisionTransformer(nn.Module):
"""
Main class for ViT of cls pooling, Pytorch implementation
We implement pure ViT, Not hybrid version which is using CNN for extracting patch embedding
input must be [BS, CHANNEL, IMAGE_SIZE, IMAGE_SIZE]
In NLP, input_sequence is always smaller than vocab size
But in vision, input_sequence is always same as image size, not concept of vocab in vision
So, ViT use nn.Linear instead of nn.Embedding for input_embedding
Args:
num_classes: number of classes for classification task
image_size: size of input image, default 512
patch_size: size of patch, default 16 from official paper for ViT-Large
extractor: option for feature extractor, default 'base' which is crop & just flatten
if you want to use Convolution for feature extractor, set extractor='cnn' named hybrid ver in paper
classifier: option for pooling method, default token meaning that do cls pooling
if you want to use mean pooling, set classifier='mean'
mode: option for train type, default fine-tune, if you want pretrain, set mode='pretrain'
In official paper & code by Google Research, they use different classifier head for pretrain, fine-tune
Math:
image2sequence: [batch, channel, image_size, image_size] -> [batch, patch, patch_size^2*channel]
input_embedding: R^(P^2 Β·C)ΓD
Reference:
https://arxiv.org/abs/2010.11929
https://arxiv.org/abs/1706.03762
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py#L184
"""
def __init__(
self,
num_classes: int,
channels: int = 3,
image_size: int = 512,
patch_size: int = 16,
num_layers: int = 24,
dim_model: int = 1024,
num_heads: int = 16,
dim_mlp: int = 4096,
dropout: float = 0.1,
extractor: str = 'base',
classifier: str = 'token',
mode: str = 'fine_tune',
) -> None:
super(VisionTransformer, self).__init__()
self.num_patches = int(image_size / patch_size)**2
self.num_layers = num_layers
self.patch_size = patch_size
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_mlp = dim_mlp
self.dropout = nn.Dropout(p=dropout)
# Input Embedding
self.extractor = extractor
self.input_embedding = nn.Linear((channels * patch_size**2), dim_model)
self.conv = nn.Conv2d(
in_channels=channels,
out_channels=self.dim_model,
kernel_size=self.patch_size,
stride=self.patch_size
)
# Encoder Multi-Head Self-Attention
self.encoder = VisionEncoder(
self.num_patches,
self.num_layers,
self.dim_model,
self.num_heads,
self.dim_mlp,
dropout,
)
self.classifier = classifier
self.pretrain_classifier = nn.Sequential(
nn.Linear(self.dim_model, self.dim_model),
nn.Tanh(),
)
self.fine_tune_classifier = nn.Linear(self.dim_model, num_classes)
self.mode = mode
def forward(self, inputs: Tensor) -> any:
""" For cls pooling """
assert inputs.ndim != 4, f"Input shape should be [BS, CHANNEL, IMAGE_SIZE, IMAGE_SIZE], but got {inputs.shape}"
x = inputs
if self.extractor == 'cnn':
# self.conv(x).shape == [batch, dim, image_size/patch_size, image_size/patch_size]
x = self.conv(x).reshape(x.shape[0], self.dim_model, self.num_patches**2).transpose(-1, -2)
else:
# self.extractor == 'base':
x = self.input_embedding(
x.reshape(x.shape[0], self.num_patches, (self.patch_size**2 * x.shape[1]))
)
cls_token = torch.zeros(x.shape[0], 1, x.shape[2]) # can change init method
x = torch.cat([cls_token, x], dim=1)
x, layer_output = self.encoder(x) # output
# classification
x = x[:, 0, :] # select cls token, which is position 0 in sequence
if self.mode == 'fine_tune':
x = self.fine_tune_classifier(x)
if self.mode == 'pretrain':
x = self.fine_tune_classifier(self.pretrain_classifier(x))
return x
ννΈ, μ½λμμ λμ¬κ²¨λ΄μΌ ν μ μ MLP Head
λ‘, μ μλ pre-train
μμ κ³Ό fine-tune
μμ μ μλ‘ λ€λ₯Έ Classifier Head
λ₯Ό μ¬μ©νλ€. μ μμλ Activation Function
1κ°μ λ κ°μ MLP Layer
λ₯Ό μ¬μ©νκ³ , νμμλ 1κ°μ MLP Layer
λ₯Ό μ¬μ©νλ€.
λ€λ§, pretrain_classifier
μ μ
μΆλ ₯ μ°¨μμ λν μ νν μμΉλ₯Ό λ
Όλ¬Έμ΄λ official repo codeλ₯Ό νμΈν΄λ μ°Ύμ μ μμλ€, κ·Έλμ μμλ‘ λͺ¨λΈμ μ°¨μκ³Ό λκ°μ΄ μΈν
νκ² λμλ€.
λν μ μλ CLS Pooling
κ³Ό λλΆμ΄ GAP
λ°©μλ μ μνλλ°, GAP
λ°©μμ μΆνμ λ°λ‘ μΆκ°κ° νμνλ€. κ·Έλ¦¬κ³ μ¬μ νλ ¨κ³Ό νμΈ νλ λͺ¨λ λΆλ₯ ν
μ€ν¬λ₯Ό μννλλ° (μ¬μ§μ΄ κ°μ λ°μ΄ν° μΈνΈλ₯Ό μ¬μ©ν¨) μ κ΅³μ΄ μλ‘ λ€λ₯Έ Classifier Head
λ₯Ό μ μνλμ§ μλλ₯Ό μ μ μμ΄ λ
Όλ¬Έμ λ€μ μ½μ΄λ΄€μ§λ§, μ΄μ μ λν΄μ μμΈν μΈκΈνλ λΆλΆμ΄ μμλ€.
ViT
λ μ
λ ₯ μλ² λ©μ μ μνλ λΆλΆμ μ μΈνλ©΄ μ μμ μλλλ‘ κΈ°μ‘΄ νΈλμ€ν¬λ¨Έμ λμΌν λͺ¨λΈ ꡬ쑰λ₯Ό κ°μ‘λ€. μμ ν λ€λ₯Έ λ°μ΄ν°μΈ μ΄λ―Έμ§μ ν
μ€νΈμ κ°μ ꡬ쑰μ λͺ¨λΈμ μ μ©νλ€λ κ²μ΄ μ λ§ μ½μ§ μμ 보μλλ°, ν¨μΉ κ°λ
μ λ§λ€μ΄ μμ°μ΄μ ν ν°μ²λΌ κ°μ£Όνκ³ μ¬μ©ν κ²μ΄ μλλλ‘ κ΅¬ννλλ° μ§κ΄μ μ΄λ©΄μλ μ λ§ ν¨κ³Όμ μ΄μλ€κ³ μκ°νλ€. μ΄μ μ΄λ κ² λ§λ€μ΄μ§ λͺ¨λΈμ ν΅ν΄ μ§νν μ¬λ¬ μ€ν κ²°κ³Όμ μ΄λ€ μΈμ¬μ΄νΈκ° λ΄κ²¨ μλμ§ μμ보μ.
π¬Β Insight from Experiment
π‘Β Insight 1. ViTμ Scalability μ¦λͺ
Pre-Train
μ μ¬μ©λλ μ΄λ―Έμ§ λ°μ΄ν° μΈνΈμ ν¬κΈ°κ° 컀μ§μλ‘Fine-Tune Stage
μμViT
κ°CNN
λ³΄λ€ λμ μ±λ₯- κ°μ μ±λ₯μ΄λΌλ©΄
ViT
κ° μλμ μΌλ‘ μ μ μ°μ°λμ κΈ°λ‘
μ λνλ Pre-Train Stage
μ μ¬μ©λ μ΄λ―Έμ§ λ°μ΄ν° μΈνΈμ λ°λ₯Έ λͺ¨λΈμ Fine-Tune
μ±λ₯ μΆμ΄λ₯Ό λνλΈ μλ£λ€. μ¬μ νλ ¨ λ°μ΄ν° μ€μΌμΌμ΄ ν¬μ§ μμ λλ Conv
κΈ°λ°μ ResNet
μ리μ¦κ° ViT
μ리μ¦λ₯Ό μλνλ λͺ¨μ΅μ 보μ¬μ€λ€. νμ§λ§ λ°μ΄ν° μΈνΈμ ν¬κΈ°κ° 컀μ§μλ‘ μ μ ViT
μ리μ¦μ μ±λ₯μ΄ ResNet
μ λ₯κ°νλ κ²°κ³Όλ₯Ό λ³Ό μ μλ€.
ννΈ, ViT & ResNet μ±λ₯ κ²°κ³Ό λͺ¨λ ImageNetκ³Ό JFT-Imageλ‘ μ¬μ νλ ¨ λ° νμΈ νλμ κ±°μ³ λμλ€κ³ νλ μ°Έκ³ νμ. μΆκ°λ‘ νμΈ νλ κ³Όμ μμ μ¬μ νλ ¨ λλ³΄λ€ μ΄λ―Έμ§ μ¬μ΄μ¦λ₯Ό ν€μμ νλ ¨μ μμΌ°λ€κ³ λ Όλ¬Έμμ λ°νκ³ μλλ°, μ΄λ μ μμ μ€ν κ²°κ³Όμ κΈ°μΈν κ²μ΄λ€. λ Όλ¬Έμ λ°λ₯΄λ©΄ νμΈ νλ λ μ¬μ νλ ¨ λΉμλ³΄λ€ λ λμ ν΄μλμ μ΄λ―Έμ§λ₯Ό μ¬μ©νλ©΄ μ±λ₯μ΄ ν₯μ λλ€κ³ νλ κΈ°μ΅νλ€κ° μ¨λ¨Ήμ΄λ³΄μ.
μ λνλ μ°μ°λ λ³νμ λ°λ₯Έ λͺ¨λΈμ μ±λ₯ μΆμ΄λ₯Ό λνλΈ κ·Έλ¦Όμ΄λ€. λ μ§ν λͺ¨λ κ°μ μ μλΌλ©΄ ViT
μ리μ¦μ μ°μ°λμ΄ νμ ν μ μμ μ μ μλ€. λν μ νλ 95% μ΄ν ꡬκ°μμ κ°μ μ±λ₯μ΄λΌλ©΄ ViT
μ Hybrid
λ²μ λͺ¨λΈμ μ°μ°λμ΄ μΌλ° ViT
λ²μ λ³΄λ€ νμ ν μ μμ νμΈν μ μλ€. μ΄λ¬ν μ¬μ€μ μΆνμ Swin-Transformer
μ€κ³μ μκ°μ μ€λ€.
λ κ°μ μ€ν κ²°κ³Όλ₯Ό μ’
ν©νμ λ, ViT
κ° ResNet
λ³΄λ€ μΌλ°ν μ±λ₯μ΄ λ λμΌλ©°(λν 1) λͺ¨λΈμ Saturation
νμμ΄ λλλ¬μ§μ§ μμ μ±λ₯μ νκ³μΉ(λν 2) μμ λ λλ€κ³ λ³Ό μ μλ€. λ°λΌμ κΈ°μ‘΄ νΈλμ€ν¬λ¨Έμ μ°μ°β’ꡬ쑰μ μΈ‘λ©΄μμ Scalability
λ₯Ό μ±κ³΅μ μΌλ‘ μ΄μνλ€κ³ νκ°ν μ μκ² λ€.
π‘Β Insight 2. Pure Self-Attentionμ μ’μ μ΄λ―Έμ§ νΌμ²λ₯Ό μΆμΆνκΈ°μ μΆ©λΆνλ€
- Patch Embedding Layerμ PCA κ²°κ³Ό, ν¨μΉμ κΈ°μ κ° λλ μ°¨μκ³Ό μ μ¬ν λͺ¨μμ μΆμΆ
Convolution
μμ΄Self-Attention
λ§μΌλ‘λ μΆ©λΆν μ΄λ―Έμ§μ μ’μ νΌμ²λ₯Ό μΆμΆνλ κ²μ΄ κ°λ₯Vision
μμConvolution
μ λνreliance
ννΌ κ°λ₯
Patch Embedding Layerβs Filter
μ μλ£λ μΆ©λΆν νμ΅μ κ±°μΉκ³ λ ViT
μ Patch Embedding Layer
μ νν°λ₯Ό PCA
ν κ²°κ³Ό μ€μμ νΉμκ°μ΄ λμ μμ 28κ°μ νΌμ²λ₯Ό λμ΄ν κ·Έλ¦Όμ΄λ€. μ΄λ―Έμ§μ κΈ°λ³Έ λΌλκ° λκΈ°μ μ ν©ν΄ 보μ΄λ νΌμ²λ€μ΄ μΆμΆλ λͺ¨μ΅μ λ³Ό μ μλ€.
λ°λΌμ Inductive Bias
μμ΄, λ¨μΌ Self-Attention
λ§μΌλ‘ μ΄λ―Έμ§μ νΌμ²λ₯Ό μΆμΆνλ κ²μ΄ μΆ©λΆν κ°λ₯νλ€. λΉμ λΆμΌμ λ§μ°ν Convolution
μμ‘΄μμ λ²μ΄λ μλ‘μ΄ μν€ν
μ²μ λμ
μ΄ κ°λ₯ν¨μ μμ¬ν λΆλΆμ΄λΌκ³ ν μ μκ² λ€.
π‘Β Insight 3. Bottom2General Information, Top2Specific Information
μ λ ₯
κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘Global & General
ν Informationμ ν¬μ°©μΆλ ₯
κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘Local & Specific
ν Informationμ ν¬μ°©
Multi-Head Attention Distance per Network Depth
λ€μ μλ£λ μΈμ½λμ κ°μ λ³νμ λ°λ₯Έ κ°λ³ μ΄ν
μ
ν΄λμ μ΄ν
μ
거리 λ³ν μΆμ΄λ₯Ό λνλΈ κ·Έλ¦Όμ΄λ€. μ¬κΈ°μ μ΄ν
μ
거리λ, ν΄λκ° μΌλ§λ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉλ₯Ό μ΄ν
μ
νλμ§ ν½μ
λ¨μλ‘ ννν μ§νλ€. ν΄λΉ κ°μ΄ λμμλ‘ κ±°λ¦¬μ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉμ μ΄ν
μ
μ, μμμλ‘ κ°κΉμ΄ ν¨μΉμ μ΄ν
μ
νλ€λ κ²μ μλ―Ένλ€. λ€μ λνλ₯Ό μ΄ν΄λ³΄μ. μ
λ ₯κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘(Depth 0) ν΄λλ³ μ΄ν
μ
거리μ λΆμ°μ΄ 컀μ§κ³ , μΆλ ₯κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘(Depth 23) λΆμ°μ΄ μ μ μ€μ΄λ€λ€κ° κ±°μ ν μ μ μλ ΄νλλ―ν μμμ 보μ¬μ€λ€. λ€μ λ§ν΄, μ
λ ₯κ³Ό κ°κΉμ΄ Bottom Encoder
λ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉλΆν° κ°κΉμ΄ ν¨μΉκΉμ§ λͺ¨λ μ μμ (Global
)μΌλ‘ μ΄ν
μ
μ μνν΄ General
ν μ 보λ₯Ό ν¬μ°©νκ² λκ³ μΆλ ₯κ³Ό κ°κΉμ΄ Top Encoder
λ κ°λ³ ν΄λλ€μ΄ λͺ¨λ λΉμ·ν 거리μ μμΉν ν¨μΉ(Local
)μ μ΄ν
μ
μ μνν΄ Specific
ν μ 보λ₯Ό ν¬μ°©νκ² λλ€.
μ΄ λ Global
κ³Ό Local
μ΄λΌλ μ©μ΄ λλ¬Έμ Bottom Encoder
λ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉμ μ΄ν
μ
νκ³ , Top Encoder
λ κ°κΉμ΄ ν¨μΉμ μ΄ν
μ
νλ€κ³ μ°©κ°νκΈ° μ½λ€. κ·Έλ¬λ κ°λ³ ν΄λλ€μ μ΄ν
μ
κ±°λ¦¬κ° μΌλ§λ λΆμ°λμ΄ μλκ°κ° λ°λ‘ Global
, Local
μ ꡬλΆνλ κΈ°μ€μ΄ λλ€. μ
λ ₯λΆμ κ°κΉμ΄ λ μ΄μ΄λ€μ ν€λλ€μ μ΄ν
μ
거리 λΆμ°μ΄ λ§€μ° ν° νΈμΈλ°, μ΄κ²μ μ΄ν¨μΉ μ ν¨μΉ λͺ¨λ μ΄ν
μ
ν΄λ³΄κ³ λΉκ΅ν΄λ³Έλ€κ³ ν΄μν΄μ Global
μ΄λΌκ³ λΆλ₯΄κ³ , μΆλ ₯λΆμ κ°κΉμ΄ λ μ΄μ΄λ ν€λλ€μ μ΄ν
μ
거리 λΆμ°μ΄ λ§€μ° μμ νΈμΈλ°, μ΄κ² λ°λ‘ κ°κ°μ ν€λλ€μ΄ μ΄λ€ μ 보μ μ£Όλͺ©ν΄μΌν μ§(λΆλ₯ μμ€μ΄ κ°μ₯ μμμ§λ ν¨μΉ) λ²μλ₯Ό μΆ©λΆν μ’ν μνμμ νΉμ λΆλΆμλ§ μ§μ€νλ€λ μλ―Έλ‘ ν΄μν΄ Local
μ΄λΌκ³ λΆλ₯΄κ² λμλ€.
<Revisiting Few-sample BERT Fine-tuning>λ μμ λΉμ·ν λ§₯λ½μ μ¬μ€μ λν΄ μΈκΈνκ³ μμΌλ μ°Έκ³ ν΄λ³΄μ. μ΄λ¬ν μ¬μ€μ νΈλμ€ν¬λ¨Έ μΈμ½λ κ³μ΄ λͺ¨λΈμ νλν λ Depth
λ³λ‘ λ€λ₯Έ Learning Rate
μ μ μ©νλ Layerwise Learning Rate Decay
μ μ΄μμ΄ λκΈ°λ νλ€. Layerwise Learning Rate Decay
μ λν΄μλ μ¬κΈ° ν¬μ€νΈλ₯Ό μ°Έκ³ νλλ‘ νμ.
ννΈ λ
Όλ¬Έμλ μΈκΈλμ§ μμ, νμμ λνΌμ
μ κ°κΉμ§λ§, μΆλ ₯μ κ°κΉμ΄ μΈμ½λλ€μ ν΄λκ° κ°μ§ Attention Distance
μ΄ λͺ¨λ λΉμ·νλ€λ μ¬μ€λ‘ μ΄λ―Έμ§ λΆλ₯μ κ²°μ μ μΈ μν μ νλ νΌμ²κ° μ΄λ―Έμ§μ νΉμ ꡬμμ λͺ¨μ¬ μμΌλ©°, κ·Έ μ€νμ μ΄λ―Έμ§μ μ€μ λΆκ·ΌμΌ κ°λ₯μ±μ΄ λλ€κ³ μΆμΈ‘ ν΄λ³Ό μ μλ€. λͺ¨λ ν΄λμ ν½μ
κ±°λ¦¬κ° μλ‘ λΉμ·νλ €λ©΄ μΌλ¨ λΉμ·ν μμΉμ ν¨μΉμ μ΄ν
μ
μ ν΄μΌνκΈ° λλ¬Έμ λΆλ₯ μμ€κ°μ μ΅μλ‘ μ€μ¬μ£Όλ νΌμ²λ λ³΄ν΅ ν ꡬμ(ν¨μΉ)μ λͺ°λ € μμ κ²μ΄λΌκ³ μ μΆκ° κ°λ₯νλ€. λν νΉμ μ€νμ΄ μ€μμ μμΉν μλ‘ μ΄ν
μ
거리μ λΆμ°μ΄ μ€μ΄λ€κ²μ΄λΌκ³ μκ° ν΄λ³Ό μλ μμλ€. μ μλ Attention Rollout
μ΄λΌλ κ°λ
μ ν΅ν΄ Attention Distance
μ μ°μΆνλ€κ³ μΈκΈνλλ°, μμΈν λ΄μ©μ μμ λ λ§ν¬λ₯Ό μ°Έκ³ ν΄λ³΄μ(νκ΅μ΄ μ€λͺ
λΈλ‘κ·Έ, μλ
Όλ¬Έ). μ΄λ¬ν νμμ κ°μ€μ΄ λ§λ€λ©΄, Convolution
μ Inductive Bias
μ€ Locality
μ ν¨κ³Όμ±μ Self-Attention
μ ν΅ν΄ μ
μ¦μ΄ κ°λ₯νλ©°, λ°λλ‘ Convolution
μ λν μμ‘΄μμ λ²μ΄λ λ¨μΌ Self-Attention
μΌλ‘λ κ°μ ν¨κ³Όλ₯Ό λΌ μ μλ€λ μ¦κ±° μ€ νλκ° λ κ²μ΄λ€.
π‘Β Insight 4. ViTλ CLS Pooling μ¬μ©νλκ² ν¨μ¨μ
CLS Pooling
μGAP
λ³΄λ€ 2λ°° μ΄μ ν° νμ΅λ₯ μ μ¬μ©ν΄λ λΉμ·ν μ±λ₯μ κΈ°λ‘- νμ΅ μλλ λ λΉ λ₯΄λ μ±λ₯μ΄ λΉμ·νκΈ° λλ¬Έμ
CLS Pooling
μ΄ λ ν¨μ¨μ
- νμ΅ μλλ λ λΉ λ₯΄λ μ±λ₯μ΄ λΉμ·νκΈ° λλ¬Έμ
Performance Trend by Pooling Method with LR
λ€μ λνλ νλ§ λ°©μκ³Ό νμ΅λ₯ μ λ³λμ λ°λ₯Έ μ νλ λ³ν μΆμ΄λ₯Ό λνλΈ κ·Έλ¦Όμ΄λ€. λΉμ·ν μ±λ₯μ΄λΌλ©΄ CLS Pooling
μ΄ GAP
λ³΄λ€ 2λ°° μ΄μ ν° νμ΅λ₯ μ μ¬μ©νλ€. νμ΅λ₯ μ΄ ν¬λ©΄ λͺ¨λΈμ μλ ΄ μλκ° λΉ¨λΌμ Έ νμ΅ μλκ° λΉ¨λΌμ§λ μ₯μ μ΄ μλ€. κ·Έλ°λ° μ±λ₯κΉμ§ λΉμ·νλ€λ©΄ ViT
λ CLS Pooling
μ μ¬μ©νλ κ²μ΄ λ ν¨μ¨μ μ΄λΌκ³ ν μ μκ² λ€.
λμ€μ μκ°μ΄ λλ€λ©΄ λ€λ₯Έ νλ§ λ°©μ, μλ₯Ό λ€λ©΄ Weighted Layer Pooling
, GeM Pooling
, Attention Pooling
κ°μ κ²μ μ μ©ν΄ μ€νν΄λ³΄κ² λ€.
π‘Β Insight 5. ViTλ Absolute 1D-Position Embedding μ¬μ©νλκ² κ°μ₯ ν¨μ¨μ
- μ΄λ€ ννλ‘λ μμΉ μλ² λ© κ°μ μ μν΄μ€λ€λ©΄, ννμ μ’ λ₯μ μκ΄μμ΄ κ±°μ λΉμ·ν μ±λ₯μ 보μ
- μ±λ₯μ΄ λΉμ·νλ©΄, μ§κ΄μ μ΄κ³ ꡬνμ΄ κ°νΈν
Absolute 1D-Position Embedding
λ°©λ²μ μ¬μ©νλ κ²μ΄ κ°μ₯ ν¨μ¨μ ViT
λPatch-Level
μ¬μ©ν΄,Pixel-Level
λ³΄λ€ μλμ μΌλ‘ μνμ€ κΈΈμ΄κ° 짧μ μμΉβ’κ³΅κ° μ 보λ₯Ό μΈμ½λ©νλ λ°©μμ μν₯μ λ λ°μ
Performance Table by making Position Embedding method
μ μ€ν κ²°κ³Όλ Position Embedding
μΈμ½λ© λ°©μμ λ°λ₯Έ ViT
λͺ¨λΈμ μ±λ₯ λ³ν μΆμ΄λ₯Ό λνλΈ μλ£λ€. μΈμ½λ© ννμ μκ΄μμ΄ μμΉ μλ² λ©μ μ λ¬΄κ° μ±λ₯μ ν° μν₯μ λ―ΈμΉλ€λ μ¬μ€μ μλ €μ£Όκ³ μλ€. ννΈ, μΈμ½λ© νν λ³νμ λ°λ₯Έ μ μλ―Έν μ±λ₯ λ³νλ μμλ€. νμ§λ§ Absolute 1D-Position Embedding
μ 컨μ
μ΄ κ°μ₯ μ§κ΄μ μ΄λ©° ꡬννκΈ° νΈνκ³ μ°μ°λμ΄ λ€λ₯Έ μΈμ½λ©λ³΄λ€ μ λ€λ κ²μ κ°μνλ©΄ ViTμ κ°μ₯ ν¨μ¨μ μΈ μμΉ μλ² λ© λ°©μμ΄λΌκ³ νλ¨ν μ μλ€.
λ
Όλ¬Έμ κ²°κ³Όμ λν΄ ViT
κ° μ¬μ©νλ Patch-Level Embedding
μ΄ Pixel-Level
λ³΄λ€ μλμ μΌλ‘ 짧μ μνμ€ κΈΈμ΄λ₯Ό κ°κΈ° λλ¬Έμ΄λΌκ³ μ€λͺ
νλ€. μλ₯Ό λ€μ΄ 224x224
μ¬μ΄μ¦μ μ΄λ―Έμ§λ₯Ό 16x16
μ¬μ΄μ¦μ ν¨μΉ μ¬λ¬μ₯μΌλ‘ λ§λ λ€κ³ μκ°ν΄λ³΄μ. μλ² λ© μ°¨μμ λ€μ΄κ°λ $N$ μ $(224/16)^2$ , μ¦ 196
μ΄ λλ€. ννΈ μ΄κ²μ Pixel-Level
λ‘ μλ² λ© νκ² λλ©΄ $224^2$, μ¦ 50176
κ°μ μνμ€κ° μκΈ΄λ€. λ°λΌμ Pixel-Level
μ λΉνλ©΄ ν¨μ¬ 짧μ μνμ€ κΈΈμ΄λ₯Ό κ°κΈ° λλ¬Έμ Absolute 1D-Position Embedding
λ§μΌλ‘λ μΆ©λΆν Spatial Relation
μ νμ΅ν μ μλ κ²μ΄λ€.
Absolute 1D-Position Embedding
νμ§λ§, νμλ μμ°μ΄ μ²λ¦¬μ Transformer-XL
, XLNet
, DeBERTa
κ°μ λͺ¨λΈλ€μ΄ Relative Position Embedding
λ°©μμ μ μ©ν΄ ν° μ±κ³΅μ κ±°λ λ°κ° μλ€λ μ μ μκ°νλ©΄ μ΄λ° κ²°κ³Όκ° λ©λμ΄ κ°λ©΄μλ μμνλ€.
μ μλ μ€νμ μ¬μ©ν λͺ¨λ λ°μ΄ν° μΈνΈλ₯Ό 224x224
λ‘ resize
νλ€κ³ λ°νκ³ μλλ°, λ§μ½ μ΄λ―Έμ§ μ¬μ΄μ¦κ° 512x512
μ λλ§ λλλΌλ $N$ κ°μ΄ 1024
μ΄λΌμ μ κ²°κ³Όμ μλΉν λ€λ₯Έ μμμ΄ λνλμ§ μμκΉ νλ μκ°μ΄ λ λ€. μΆνμ μκ°μ΄ λλ€λ©΄ μ΄ λΆλΆλ κΌ μ€νν΄λ΄μΌκ² λ€. μμΈ‘μ»¨λ° μ΄λ―Έμ μ¬μ΄μ¦κ° 컀μ§μλ‘ 2D Position Embedding
νΉμ Relative Position Embedding
μ΄ λ ν¨μ¨μ μΌ κ²μ΄λΌ μμνλ€.
π§ββοΈΒ Conclusion
μ΄λ κ² ViT
λͺ¨λΈμ μ μν <An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale>μ μ€λ¦° λ΄μ©μ λͺ¨λ μ΄ν΄λ³΄μλ€. Conv
μ λν μμ‘΄μ ννΌ νλ€λ μ μμ λ§€μ° μλ―Έκ° μλ μλμμΌλ©°, Self-Attention & Transformer ꡬ쑰 μ±νλ§μΌλ‘λ μ»΄ν¨ν° λΉμ μμμ μ΄λ μ λ scalability
λ₯Ό μ΄μνλλ° μ±κ³΅νλ€λ μ μμ νλ μ°κ΅¬μ μ€μν μμ¬μ μ λ¨κ²Όλ€. μλμ μΌλ‘ μ 체(??)λμ΄ μλ λΉμ μμμ΄ μ±λ₯μ νκ³λ₯Ό νλ¨κ³ λ°μ΄λμ μ μλ μ΄μμ λ§λ ¨ν΄μ€ μ
μ΄λ€.
νμ§λ§, ViT
μ Pretrain Stage
μ μ ν©ν Self-Supervised Learning
λ°©λ²μ μ°Ύμ§ λͺ»ν΄ μ¬μ ν Supervised Learning
λ°©μμ μ±νν μ μ λ§€μ° μμ¬μ λ€. μ΄λ κ²°κ΅ λ°μ΄ν° Scale
νμ₯μ νκ³λ₯Ό μλ―ΈνκΈ° λλ¬Έμ΄λ€. μ€λλ BERTμ GPTμ μ±κ³΅ μ νλ λΉλ¨ Self-Attention
μ Transformer
μ ꡬ쑰μ νμμ±μ μν΄μλ§ νμνκ² μλλ€. μ΄μ λͺ»μ§ μκ²(κ°μΈμ μΌλ‘ μ μΌ μ€μνλ€ μκ°) μ£Όμνλ κ²μ΄ λ°λ‘ λ°μ΄ν° Scale
νμ₯μ΄λ€. MLM
, AR
λ±μ Self-Supervised Learning
λλΆμ λ°μ΄ν° Scale
μ ν¨μ¨μ μΌλ‘ μ€μΌμΌ μ
μν¬ μ μμκ³ , μ¬μ νλ ¨ λ°μ΄ν°μ μ¦κ°λ λͺ¨λΈ κΉμ΄, λλΉ, μ°¨μκΉμ§ λμ± ν¬μΌ ν€μ°λλ° κΈ°μ¬νλ€.
λν ViT
λ μ μ²μ μΌλ‘ Patch-Level Embedding
μ μ¬μ©νκΈ° λλ¬Έμ λ€μν μ΄λ―Έμ§ ν
μ€ν¬μ μ μ©νλ κ²μ΄ νλ€λ€. Segmentation
, Object Detection
κ°μ Taskλ ν½μ
λ¨μλ‘ μμΈ‘μ μνν΄ κ°μ²΄λ₯Ό νμ§νκ±°λ λΆν ν΄μΌ νλ€. νμ§λ§ Patch
λ¨μλ‘ νλ ¨μ μννλ ViT
λ Pixel
λ¨μμ μμΈ‘μ μννλλ° μ΄λ €μμ κ²ͺλλ€.
λ§μ§λ§μΌλ‘ Self-Attention
μ체μ Computational Overhead
κ° λ무 μ¬ν΄ κ³ ν΄μλμ μ΄λ―Έμ§λ₯Ό μ μ ν λ€λ£¨κΈ° νλ€λ€. μμμλ μΈκΈνμ§λ§ μ΄λ―Έμ§μ μ¬μ΄μ¦κ° 512x512
λ§ λμ΄λ μ΄λ―Έ ν¨μΉμ κ°μκ° 1024
κ° λλ€. μ¬μ΄μ¦κ° 컀μ§μλ‘ μνμ€ κΈΈμ΄ μμ κΈ°νκΈμμ μΌλ‘ 컀μ§λλ°λ€κ° Self-Attention
λ 쿼리μ ν€ νλ ¬μ λ΄μ (μκΈ° μμ κ³Ό κ³±μ΄λΌ λ³Ό μ μμ) νκΈ° λλ¬Έμ Computational Overhead
κ° $N^2$μ΄ λλ€.
νμλ ViT
λ₯Ό μ λ°μ μ±κ³΅μ΄λΌκ³ ννκ³ μΆλ€. λ³Έλ ViT
μ μ€κ³ λͺ©μ μ λΉμ λΆμΌμ Conv
μ λν μμ‘΄μ ννΌνλ©΄μ, ν¨μ΄ν Self-Attention
μ λμ
ν΄ Scalabilty
λ₯Ό μ΄μνλ κ²μ΄μλ€. Self-Attention
μ λμ
νλλ°λ μ±κ³΅νμ§λ§, μ¬μ ν λ€λ£° μ μλ μ΄λ―Έμ§ μ¬μ΄μ¦λ Taskμλ νκ³κ° λΆλͺ
νλ©° κ²°μ μ μΌλ‘ Self-Supervised Learning
λ°©μμ λμ
νμ§ λͺ»νλ€. Scalabilty
λΌλ λ¨μ΄μ μλ―Έλ₯Ό μκ°νλ©΄, λ°©κΈ λ§ν λΆλΆμμκΉμ§ νμ₯μ±μ΄ μμ΄μΌ μ€κ³ μλμ λΆν©νλ κ²°κ³ΌλΌκ³ μκ°νλ€.
Leave a comment