πΒ [ViT] An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
πΒ Overview
μμνκΈ° μμ, λ³Έ λ
Όλ¬Έ 리뷰λ₯Ό μμνκ² μ½μΌλ €λ©΄ Transformer μ λν μ μ΄ν΄κ° νμμ μ΄λ€. μμ§ Transformer μ λν΄μ μ λͺ¨λ₯Έλ€λ©΄ νμκ° μμ±ν ν¬μ€νΈλ₯Ό μ½κ³ μ€κΈΈ κΆμ₯νλ€. λν λ³Έλ¬Έ λ΄μ©μ μμ±νλ©΄μ μ°Έκ³ ν λ
Όλ¬Έκ³Ό μ¬λ¬ ν¬μ€νΈμ λ§ν¬λ₯Ό 맨 λ° νλ¨μ 첨λΆνμΌλ μ°Έκ³ λ°λλ€. μκ°μ΄ μμΌμ λΆλ€μ μ€κ°μ μ½λ ꡬνλΆλ₯Ό μλ΅νκ³ Insight λΆν° μ½κΈ°λ₯Ό κΆμ₯νλ€.
Vision Transformer(μ΄ν ViT)λ 2020λ
10μ Googleμμ λ°νν μ»΄ν¨ν° λΉμ μ© λͺ¨λΈμ΄λ€. μμ°μ΄ μ²λ¦¬μμ λμ±κ³΅μ κ±°λ νΈλ μ€ν¬λ¨Έ ꡬ쑰μ κΈ°λ²μ κ±°μ κ·Έλλ‘ λΉμ λΆμΌμ μ΄μνλ€λ μ μμ ν° μμκ° μμΌλ©°, μ΄ν μ»΄ν¨ν° λΉμ λΆμΌμ νΈλ μ€ν¬λ¨Έ μ μ±μλκ° μ΄λ¦¬κ² λ κ³κΈ°λ‘ μμ©νλ€.
ννΈ, ViT μ μ€κ³ μ² νμ λ°λ‘ scalability(λ²μ©μ±)μ΄λ€. μ κ²½λ§ μ€κ³μμ λ²μ©μ±μ΄λ, λͺ¨λΈμ νμ₯ κ°λ₯μ±μ λ§νλ€. μλ₯Ό λ€λ©΄ νμ΅ λ°μ΄ν°λ³΄λ€ λ ν¬κ³ 볡μ‘ν λ°μ΄ν° μΈνΈλ₯Ό μ¬μ©νκ±°λ λͺ¨λΈμ νλΌλ―Έν°λ₯Ό λλ € μ¬μ΄μ¦λ₯Ό ν€μλ μ¬μ ν μ ν¨ν μΆλ‘ κ²°κ³Όλ₯Ό λμΆνκ±°λ λ λμ μ±λ₯μ 보μ¬μ£Όκ³ λμκ° κ°μ μ μ¬μ§κ° μ¬μ ν λ¨μμμ λ βνμ₯μ±μ΄ λλ€β λΌκ³ νννλ€. μ μλ€μ λ
Όλ¬Έ μ΄λ°μ μ½ μ°μ΄μ μ»΄ν¨ν° λΉμ λΆμΌμ scalability λμ΄λ κ²μ΄ μ΄λ² λͺ¨λΈ μ€κ³μ λͺ©νμλ€κ³ λ°νκ³ μλ€. λ²μ©μ±μ μ κ²½λ§ λͺ¨λΈ μ€κ³μμ κ°μ₯ ν° νλκ° λλλ° λλ©μΈλ§λ€ μ μνλ μλ―Έμ μ°¨μ΄κ° λ―ΈμΈνκ² μ‘΄μ¬νλ€. λ°λΌμ ViTμ μ μλ€μ΄ λ§νλ λ²μ©μ±μ΄ 무μμ μλ―Ένλμ§ μμ보λ κ²μ ꡬ체μ μΈ λͺ¨λΈ ꡬ쑰λ₯Ό μ΄ν΄νλλ° ν° λμμ΄ λ κ²μ΄λ€.
π§ Β Scalability in ViT
λ Όλ¬Έ μ΄λ°λΆμμ λ€μκ³Ό κ°μ λ¬Έμ₯μ΄ μμ λμ΄μλ€.
βOur Vision Transformer (ViT) attains excellent results when pre-trained at sufficient scale and transferred to tasks with fewer datapoints"
μ΄ κ΅¬λ¬Έμ΄ ViT μ Scalabilityλ₯Ό κ°μ₯ μ μ€λͺ
νκ³ μλ€κ³ μκ°νλ€. μ μλ€μ΄ λ§νλ λ²μ©μ±μ κ²°κ΅ backbone ꡬ쑰μ νμ©μ μλ―Ένλ€. μμ°μ΄ μ²λ¦¬μ μ΅μν λ
μλΌλ©΄ μ½κ² μ΄ν΄κ° κ°λ₯ν κ²μ΄λ€. Transformer, GPT, BERTμ λ±μ₯ μ΄ν, μμ°μ΄ μ²λ¦¬λ λ²μ©μ±μ κ°λ λ°μ΄ν° μΈνΈλ‘ μ¬μ νλ ¨ν λͺ¨λΈμ νμ©ν΄ Task-Agnosticνκ² νλμ backboneμΌλ‘ κ±°μ λͺ¨λ Taskλ₯Ό μνν μ μμΌλ©°, μμ μ¬μ΄μ¦μ λ°μ΄ν°λΌλ μλΉν λμ μμ€μ μΆλ‘ μ±λ₯μ λΌ μ μμλ€. κ·Έλ¬λ λΉμ μ»΄ν¨ν° λΉμ μ λ©μΈμ΄μλ Conv κΈ°λ° λͺ¨λΈλ€μ νμΈνλν΄λ λ°μ΄ν° ν¬κΈ°κ° μμΌλ©΄ μΌλ°ν μ±λ₯μ΄ λ§€μ° λ¨μ΄μ§κ³ , Taskμ λ°λΌμ λ€λ₯Έ μν€ν
μ²λ₯Ό κ°λ λͺ¨λΈμ μλ‘κ² μ μνκ±°λ λΆλ¬μ μ¬μ©ν΄μΌ νλ λ²κ±°λ‘μμ΄ μμλ€. μλ₯Ό λ€λ©΄ Image Classfication μλ ResNet, Segmentation μλ U-Net, Object Detection μ YOLO λ₯Ό μ¬μ©νλ κ²μ²λΌ λ§μ΄λ€. λ°λ©΄ μμ°μ΄ μ²λ¦¬λ μ¬μ νμ΅λ λͺ¨λΈ νλλ‘ λͺ¨λ NLU, μ¬μ§μ΄λ NLG Taskλ μνν μ μλ€. μ μλ€μ μ΄λ¬ν λ²μ©μ±μ μ»΄ν¨ν° λΉμ μλ μ΄μ μν€κ³ μΆμλ κ² κ°λ€. κ·Έλ λ€λ©΄ λ¨Όμ μμ°μ΄ μ²λ¦¬μμ νΈλμ€ν¬λ¨Έ κ³μ΄μ΄ λ²μ©μ±μ κ°μ§ μ μμλ μ΄μ λ 무μμΈμ§ κ°λ¨ν μ΄ν΄λ³΄μ.
μ μλ€μ self-attention(λ΄μ )μ ν¨μ¨μ±, λͺ¨λΈμ ꡬ쑰μ νμμ± κ·Έλ¦¬κ³ self-supervised taskμ μ‘΄μ¬λ₯Ό κΌ½λλ€. κ·ΈλΌ μ΄κ²λ€μ΄ μ λ²μ©μ±μ λμ΄λλ° λμμ΄ λ κΉ??
self-attention(λ΄μ )μ νλ ¬ κ° κ³±μ
μΌλ‘ μ μ λμ΄ μ€κ³κ° λ§€μ° κ°νΈνκ³ λ³λ ¬λ‘ νλ²μ μ²λ¦¬νλ κ²μ΄ κ°λ₯νκΈ° λλ¬Έμ ν¨μ¨μ μΌλ‘ μ 체 λ°μ΄ν°λ₯Ό λͺ¨λ κ³ λ €ν μ°μ° κ²°κ³Όλ₯Ό μ»μ μ μλ€.
Multi-Head Attention ꡬ쑰λ μ¬λ¬ μ°¨μμ μλ―Έ κ΄κ³λ₯Ό λμμ ν¬μ°©νκ³ κ·Έκ²μ μμλΈν κ²κ³Ό κ°μ(μ€μ λ‘λ MLP) κ²°κ³Όλ₯Ό μ»μ μ μλ€λ μ μμ ꡬ쑰μ μΌλ‘ νμνλ€.
λ§μ§λ§μΌλ‘ MLM, Auto-Regression(LM) Taskλ λ°μ΄ν° μΈνΈμ λ³λμ μΈκ°μ κ°μ
(λΌλ²¨λ§)μ΄ νμνμ§ μκΈ° λλ¬Έμ κ°μ±λΉ μκ² λ°μ΄ν°μ λͺ¨λΈμ μ¬μ΄μ¦λ₯Ό λ릴 μ μκ² λλ€.
μ΄μ λ
Όλ¬Έμμ νΈλμ€ν¬λ¨Έ κ³μ΄μ΄ κ°μ§ λ²μ©μ±μ μ΄λ»κ² λΉμ λΆμΌμ μ μ©νλμ§ μ£Όλͺ©νλ©΄μ λͺ¨λΈ ꡬ쑰λ₯Ό νλ νλ μ΄ν΄λ³΄μ.
πΒ Modeling
- 1) Transfer
Scalabilityfrom pureTransformerto Computer Vision- Overcome
relianceon Convolution(Inductive Bias) in Computer Vision - Apply Self-Attention & Architecture from vanilla NLP Transformers as
closelyas possible - Treat Image as sequence of text token
- Make $P$ sub-patches from whole image, playing same role as token in NLP Transformer
- Overcome
μ μλ€μ λ¨Όμ Conv μ λν μμ‘΄μ λ²λ¦΄ κ²μ μ£Όμ₯νλ€. Convκ° κ°μ§ Inductive Bias λλ¬Έμ νμΈνλ λ 벨μμ λ°μ΄ν° ν¬κΈ°κ° μμΌλ©΄ μΌλ°ν μ±λ₯μ΄ λ¨μ΄μ§λ κ²μ΄λΌκ³ μ€λͺ
νκ³ μλ€. μ΄ λ§μ μ΄ν΄νλ €λ©΄ Inductive Biasμ λν΄μ λ¨Όμ μμμΌ νλ€. Inductive Biasλ, μ£Όμ΄μ§ λ°μ΄ν°λ‘λΆν° μΌλ°ν μ±λ₯μ λμ΄κΈ° μν΄ βμ
λ ₯λλ λ°μ΄ν°λ ~ ν κ²μ΄λ€β, βμ΄λ° νΉμ§μ κ°κ³ μμ κ²μ΄λ€βμ κ°μ κ°μ , κ°μ€μΉ, κ°μ€ λ±μ κΈ°κ³νμ΅ μκ³ λ¦¬μ¦μ μ μ©νλ κ²μ λ§νλ€.
Conv μ°μ° μ체 (κ°μ€μΉ 곡μ , νλ§ μλ Conv Blockμ΄ Invariance)μ κΈ°λ³Έ κ°μ μ translation equivariance, localityμ΄λ€. μ¬μ€ μ μμ μ£Όμ₯μ μ΄ν΄νλλ° equivarianceμ localityμ λ»μ΄ 무μμΈμ§ νμ
νλ κ²μ ν¬κ² μλ―Έκ° μλ€ (equivarianceμ invarianceμ λν΄μλ λ€λ₯Έ ν¬μ€ν
μμ μμΈν μ΄ν΄λ³΄λλ‘ νκ² λ€). μ€μν κ²μ μ
λ ₯ λ°μ΄ν°μ κ°μ μ λνλ€λ μ μ΄λ€. λ§μ½ μ£Όμ΄μ§ μ
λ ₯μ΄ λ―Έλ¦¬ κ°μ ν Inductive Bias μ λ²μ΄λλ€λ©΄ μ΄λ»κ² λ κΉ??
μλ§ μ€λ²νΌν
λκ±°λ λͺ¨λΈ νμ΅μ΄ μλ ΄μ±μ κ°μ§ λͺ»νκ² λ κ²μ΄λ€. μ΄λ―Έμ§ λ°μ΄ν°λ Taskμ λ°λΌ νμν Inductive Biasκ° λ¬λΌμ§λ€. μλ₯Ό λ€μ΄ Segmentation, Detection μ κ²½μ°λ μ΄λ―Έμ§ μ κ°μ²΄μ μμΉ, ν½μ
μ¬μ΄μ spatial variance μ λ³΄κ° λ§€μ° μ€μνλ€. ννΈ, Classificationμ spatial invarianceκ° μ€μνλ€. λͺ©ν κ°μ²΄μ μμΉμ μ£Όλ³ νΉμ§λ³΄λ€ νκ² μ체λ₯Ό μ κ²½λ§μ΄ μΈμνλ κ²μ΄ μ€μνκΈ° λλ¬Έμ΄λ€. λ°λΌμ ViT μ μλ€μ μ΄λ€ Biasλ μκ΄μμ΄ νΈν₯μ κ°κ³ λ°μ΄ν°λ₯Ό λ³Έλ€λ κ² μ체μ μλ¬Έμ ννλ©°, μ΄λ―Έμ§ μμ Inductive Biasμμ λ²μ΄λ, μ£Όμ΄μ§ λ°μ΄ν° μ 체 νΉμ§(ν¨μΉ) μ¬μ΄μ κ΄κ³λ₯Ό νμ
νλ κ³Όμ μμ scalabilityλ₯Ό νλν μ μλ€κ³ μ£Όμ₯νλ€.
κ·Έλμ Convμ λμμΌλ‘ μλμ μΌλ‘ Inductive Bias κ° λΆμ‘±ν Self-Attention, Transformer Architectureλ₯Ό μ¬μ©νλ€. λκ°μ§μ ν¨μ©μ±μ λν΄μλ μ΄λ―Έ μμμ μΈκΈνκΈ° λλ¬Έμ μλ΅νκ³ , μ¬κΈ°μ μ§κ³ λμ΄κ°μΌν μ μ Self-Attentionμ΄ Conv λλΉ Inductive Biasκ° μ λ€λ μ μ΄λ€. Self-Attention κ³Όμ μλ μ¬λ¬ μ°μ°, μ€μΌμΌ μ‘°μ κ°λ€μ΄ ν¬ν¨λμ§λ§ λ³Έμ§μ μΌλ‘ βλ΄μ β μ΄ μ€μ¬μ΄λ€. λ΄μ μ κ·Έ μ΄λ€ νΈν₯ (Convμ λμ‘°νλ €κ³ μ΄λ κ² μμ νμ§λ§ μ¬μ€ Position Embedding λνλ κ²λ μΌμ’
μ μ½ν Inductive Bias)μ΄ μ‘΄μ¬νμ§ μλλ€. μΌλ¨ μ£Όμ΄μ§ λͺ¨λ λ°μ΄ν°μ λν΄μ λ΄μ κ°μ μ°μΆνκ³ κ·Έ λ€μμ κ΄κ³κ° μλ€κ³ μκ°λλ μ 보λ₯Ό μΆλ¦¬κΈ° λλ¬Έμ΄λ€. Conv λμ λ¬λ¦¬ βμ
λ ₯λλ λ°μ΄ν°λ ~ ν κ²μ΄λ€β, βμ΄λ° νΉμ§μ κ°κ³ μμ κ²μ΄λ€β λΌλ κ°μ μ΄ μλ€. μ΄λ² ν¬μ€ν
μ λ§μ§λ§ μ―€μμ λ€μ λ€λ£¨κ² μ§λ§ κ·Έλμ ViTλ μΈμ€ν΄μ€ μ¬μ΄μ λͺ¨λ κ΄κ³λ₯Ό λ½μ보λ Self-Attention(λ΄μ ) μ κΈ°λ°μΌλ‘ λ§λ€μ΄μ‘κΈ° λλ¬Έμ μ΄λ―Έμ§μ Global Informationμ ν¬μ°©νλλ° νμν μ±λ₯μ 보μ΄κ³ , Conv λ βμ€μν μ 보λ κ·Όμ² ν½μ
μ λͺ°λ €μλ€λΌλβ Inductive Bias λλΆμ Local Informationμ ν¬μ°©νλλ° νμν μ±λ₯μ λΈλ€.
κ·Έλ λ€λ©΄ ν½μ
νλ νλλΌλ¦¬ λ΄μ ν΄μ€λ€λ κ²μΌκΉ?? μλλ€ μ¬κΈ°μ λ
Όλ¬Έμ μ λͺ©μ΄ An Image Is Worth 16x16 Words μΈ μ΄μ κ° λλ¬λλ€. μΌλ¨ ν½μ
νλ νλλΌλ¦¬ μ μ¬λλ₯Ό μΈ‘μ νλ κ²μ΄ μ μλ―Έν κΉ μκ°ν΄λ³΄μ. μμ°μ΄μ ν ν°κ³Ό λ¬λ¦¬ μ΄λ―Έμ§μ λ¨μΌ ν½μ
ν κ°λ ν° μΈμ¬μ΄νΈλ₯Ό μ»κΈ° νλ€λ€. ν½μ
μ λ§ κ·Έλλ‘ μ νλμΌ λΏμ΄λ€. ν½μ
μ μ¬λ¬ κ° λ¬Άμ΄ ν¨μΉ λ¨μλ‘ λ¬Άλλ€λ©΄ μ΄μΌκΈ°λ λ¬λΌμ§λ€. μΌμ ν¬κΈ° μ΄μμ ν¨μΉλΌλ©΄ μμ°μ΄μ ν ν°μ²λΌ κ·Έ μμ²΄λ‘ μ΄λ€ μλ―Έλ₯Ό λ΄μ μ μλ€. λ°λΌμ μ μλ μ 체 μ΄λ―Έμ§λ₯Ό μ¬λ¬ κ°μ 16x16 νΉμ 14x14 μ¬μ΄μ¦ ν¨μΉλ‘ λλμ΄ νλ νλλ₯Ό ν ν°μΌλ‘ κ°μ£Όν΄ μ΄λ―Έμ§ μνμ€λ₯Ό λ§λ€κ³ κ·Έκ²μ λͺ¨λΈμ InputμΌλ‘ μ¬μ©νλ€.
Class Diagram
λͺ¨λΈ ꡬ쑰μ λΌλκ° λλ λ΄μ©λ€μ λͺ¨λ μ΄ν΄λ³΄μκ³ , μμμ μμ ν λ΄μ©μ ꡬννκΈ° μν΄ μ΄λ€ λΈλ‘λ€μ μ¬μ©νλμ§ νμκ° μ§μ λ
Όλ¬Έμ λ³΄κ³ λ°λΌ ꡬνν μ½λμ ν¨κ» μμ보λλ‘ νμ. μμ 첨λΆν λͺ¨λΈ λͺ¨μλμ λμ μλ λΈλ‘λ€ νλ νλ μ΄ν΄λ³Ό μμ μ΄λ€. μ¬λ΄μΌλ‘ Google Researchμ Official Repo μμ ν¨κ» μ°Έκ³ νλλ°, μ½λκ° λͺ¨λ ꡬκΈμ΄ μμ μλ‘κ² λ―Έλ Jax, Flax λ‘ κ΅¬ν λμ΄ μμλ€. νμ΄ν μΉλ μ’ μ¨λ³Έ νμ μ
μ₯μμλ μ λ§ β¦ μ§μ₯λΆμ κ²½ννλ€. μ€λλ λ€μ ν λ² νμ΄μ€λΆ νμ΄ν μΉ κ°λ°νμ ν°μ λλ¦¬κ³ μΆλ€.
π¬Β Linear Projection of Flattened Patches
\[x_p \in R^{N * (P^2β’C)}\]
\[z_{0} = [x_{class}; x_p^1E;x_p^2E;x_p^3E....x_p^NE]\]
\[N = \frac{H*W}{P*P}\]
ViTμ μ
λ ₯ μλ² λ©μ μμ±νλ μν μ νλ€. ViTλ $x \in R^{H * W * C}$(H: height, W: width, C: channel)μ νμμ κ°λ μ΄λ―Έμ§λ₯Ό μ
λ ₯μΌλ‘ λ°μ κ°λ‘ μΈλ‘ κΈΈμ΄κ° $P$, μ±λ κ°μ $C$μΈ $N$κ°μ ν¨μΉλ‘ reshape νλ€. νμκ° μ½λ ꡬν μ€ κ°μ₯ νΌλν λΆλΆμ΄ λ°λ‘ ν¨μΉ κ°μ $N$μ΄μλ€. μ§κ΄μ μΌλ‘ ν¨μΉ κ°μλΌκ³ νλ©΄, μ 체 μ΄λ―Έμ§ μ¬μ΄μ¦μμ ν¨μΉ ν¬κΈ°λ₯Ό λλ κ°μ΄λΌκ³ μκ°νκΈ° μ½κΈ° λλ¬Έμ΄λ€. μλ₯Ό λ€λ©΄ 512x512μ§λ¦¬ μ΄λ―Έμ§λ₯Ό 16x16 μ¬μ΄μ¦μ ν¨μΉλ‘ λλλ€κ³ ν΄λ³΄μ. νμλ λ¨μν 512/16=32 λΌλ κ²°κ³Όλ₯Ό μ΄μ©ν΄ $N=32$λ‘ μ€μ νκ³ μ€νμ μ§ννλ€κ° ν
μ μ°¨μμ΄ λ§μ§ μμ λ°μνλ μλ¬ λ‘κ·Έλ₯Ό λ§μ£Όνμλ€. κ·Έλ¬λ λ
Όλ¬Έ μ μμμ νμΈν΄λ³΄λ©΄, $H * W / P^2$μ΄ λ°λ‘ ν¨μΉ κ°μ$N$μΌλ‘ μ μλλ€. κ·Έλμ λ§μ½ 512x512 μ¬μ΄μ¦μ RGB μ΄λ―Έμ§ 10μ₯μ ViT μ
λ ₯ μλ² λ©μ λ§κ² μ°¨μ λ³ννλ€λ©΄ κ²°κ³Όλ [10, 3, 1024, 768] μ΄ λ κ²μ΄λ€. (μ΄ μμλ₯Ό μμΌλ‘ κ³μ μ΄μ©νκ² λ€)
μ΄λ κ² μ°¨μμ λ°κΏμ€ μ΄λ―Έμ§λ₯Ό nn.Linear((channels * patch_size**2), dim_model) λ₯Ό ν΅ν΄ ViTμ μλ² λ© λ μ΄μ΄μ μ ν ν¬μν΄μ€λ€. μ¬κΈ°μ μμ°μ΄ μ²λ¦¬μ νμ΄ν μΉλ₯Ό μμ£Ό μ¬μ©νμλ λ
μλΌλ©΄ μ nn.Embeddingμ μ¬μ©νμ§ μμλκ° μλ¬Έμ κ°μ§ μ μλ€.
μμ°μ΄ μ²λ¦¬μμ μ
λ ₯ μλ² λ©μ λ§λ€λλ λͺ¨λΈμ ν ν¬λμ΄μ μ μν΄ μ¬μ μ μλ vocabμ μ¬μ΄μ¦κ° μ
λ ₯ λ¬Έμ₯μ μν ν ν° κ°μλ³΄λ€ ν¨μ¬ ν¬κΈ° λλ¬Έμ λ°μ΄ν° 룩μ
ν
μ΄λΈ λ°©μμ nn.Embedding μ μ¬μ©νκ² λλ€. μ΄κ² λ¬΄μ¨ λ§μ΄λλ©΄, ν ν¬λμ΄μ μ μν΄ μ¬μ μ μ μλ vocab μ μ²΄κ° nn.Embedding(vocab_size, dim_model)λ‘ ν¬μ λμ΄ κ°λ‘λ vocab μ¬μ΄μ¦, μΈλ‘λ λͺ¨λΈμ μ°¨μ ν¬κΈ°μ ν΄λΉνλ 룩μ
ν
μ΄λΈμ΄ μμ±λκ³ , λ΄κ° μ
λ ₯ν ν ν°λ€μ μ 체 vocabμ μΌλΆλΆμΌν
λ μ 체 μλ² λ© λ£©μ
ν
μ΄λΈμμ λ΄κ° μλ² λ©νκ³ μΆμ ν ν°λ€μ μΈλ±μ€λ§ μμλΈλ€λ κ²μ΄λ€.
κ·Έλμ nn.Embedding μ μ μλ μ°¨μκ³Ό μ€μ μ
λ ₯ λ°μ΄ν°μ μ°¨μμ΄ λ§μ§ μμλ ν¨μκ° λμνκ² λλ κ²μ΄λ€. κ·Έλ¬λ λΉμ μ κ²½μ°, μ¬μ μ μ μλ vocabμ΄λΌλ κ°λ
μ΄ μ ν μκ³ μ
λ ₯ μ΄λ―Έμ§ μμ νμ κ³ μ λ ν¬κΈ°μ μ°¨μμΌλ‘ λ€μ΄μ€κΈ° λλ¬Έμ nn.Embeddingμ΄ μλ nn.Linear μ μ¬μ©ν΄ κ³§λ°λ‘ μ ν ν¬μμ ꡬνν κ²μ΄λ€. λ λ©μλμ λν μμΈν λΉκ΅λ νμ΄ν μΉ κ΄λ ¨ ν¬μ€νΈμμ λ€μ ν λ² μμΈν λ€λ£¨λλ‘ νκ² λ€.
ννΈ, Position Embeddingμ λνκΈ° μ , Input Embeddingμ μ°¨μμ [10, 1024, 1024] μ΄ λλ€. μ§κΈκΉμ§ μ€λͺ
ν λΆλΆ(Linear Projection of Flattened Patches )μ νμ΄ν μΉ μ½λλ‘ κ΅¬ννλ©΄ λ€μκ³Ό κ°λ€.
class VisionTransformer(nn.Module):
...
μ€λ΅
...
self.num_patches = int(image_size / patch_size)**2
self.input_embedding = nn.Linear((channels * patch_size**2), dim_model) # Projection Layer for Input Embedding
...
μ€λ΅
...
def forward(self, inputs: Tensor) -> any:
""" For cls pooling """
assert inputs.ndim != 4, f"Input shape should be [BS, CHANNEL, IMAGE_SIZE, IMAGE_SIZE], but got {inputs.shape}"
x = inputs
x = self.input_embedding(
x.reshape(x.shape[0], self.num_patches, (self.patch_size**2 * x.shape[1])) # Projection Layer for Input Embedding
)
cls_token = torch.zeros(x.shape[0], 1, x.shape[2]) # can change init method
x = torch.cat([cls_token, x], dim=1)
...
μλ² λ© λ μ΄μ΄λ₯Ό κ°μ²΄λ‘ λ°λ‘ ꡬνν΄λ λμ§λ§, νμλ κ΅³μ΄ μΆμνκ° νμνμ§ μλ€κ³ μκ°ν΄ ViTμ μ΅μμ ν΄λμ€μΈ VisionTransformerμ forward λ©μλ 맨 μ΄λ°λΆμ ꡬννκ² λμλ€. μ
λ ₯ λ°μ μ΄λ―Έμ§ ν
μλ₯Ό torch.reshape μ ν΅ν΄ [ν¨μΉ κ°μ, ν½μ
κ°μ*μ±λκ°μ] λ‘ λ°κΎΌ λ€, 미리 μ μν΄λ self.input_embedding μ λ§€κ°λ³μλ‘ μ λ¬ν΄ βμμΉ μλ² λ©β κ°μ΄ λν΄μ§κΈ° μ Input Embeddingμ λ§λ λ€.
ννΈ, CLS Poolingμ μν΄ λ§μ§λ§μ [batch, 1, image_size] μ μ°¨μμ κ°λ cls_token μ μ μν΄ ν¨μΉ μνμ€μ concat (맨 μμ)ν΄μ€λ€. μ΄ λ λ
Όλ¬Έμ μ μλ μμ μ, CLS Tokenμ μ ν ν¬μνμ§ μμΌλ©°, ν¨μΉ μνμ€μ μ ν ν¬μμ΄ μ΄λ€μ§κ³ λ λ€μ 맨 μμ Concat νκ² λλ€.
CLS TokenκΉμ§ λν μ΅μ’
Input Embedding μ ν
μ μ°¨μμ [10, 1025, 1024] κ° λλ€.
π’Β Positional Embedding
\[E_{pos} \in R^{(N+1)*D}\]
μ΄λ―Έμ§λ₯Ό ν¨μΉ λ¨μμ μλ² λ©μΌλ‘ λ§λ€μλ€λ©΄ μ΄μ μμΉ μλ² λ©μ μ μν΄μ λν΄μ£Όλ©΄ λͺ¨μλ μ Embedded Patches , μ¦ μΈμ½λμ λ€μ΄κ° μ΅μ’
Patch Embedding μ΄ μμ± λλ€. μμΉ μλ² λ©μ λ§λλ λ°©μμ κΈ°μ‘΄ Transformer, BERT μ λμΌνλ€. μλ VisionEncoder ν΄λμ€λ₯Ό ꡬνν μ½λλ₯Ό μ΄ν΄λ³΄μ.
class VisionEncoder(nn.Module):
...
μ€λ΅
...
self.positional_embedding = nn.Embedding((self.num_patches + 1), dim_model) # add 1 for cls token
...
μ€λ΅
...
def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
layer_output = []
pos_x = torch.arange(self.num_patches + 1).repeat(inputs.shape[0]).to(inputs) # inputs.shape[0] = Batch Size of Input
x = self.dropout(
inputs + self.positional_embedding(pos_x)
)
...
Input Embeddingκ³Ό λ€λ₯΄κ² μμΉ μλ² λ©μ nn.EmbeddingμΌλ‘ ꡬννλλ°, μ¬κΈ°μλ μ¬μ€ nn.Linearλ₯Ό μ¬μ©ν΄λ 무방νλ€. κ·Έκ²λ³΄λ€ nn.Embeddingμ μ
λ ₯ μ°¨μμΈ self.num_patches + 1 μ μ£Όλͺ©ν΄λ³΄μ. μ 1μ λν΄μ€ κ°μ μ¬μ©νμκΉ??
ViTλ BERTμ CLS Token Pooling μ μ°¨μ©νκΈ° μν΄ ν¨μΉ μνμ€ λ§¨ μμ CLS ν ν°μ μΆκ°νκΈ° λλ¬Έμ΄λ€. μ΄λ κ² μΆκ°λ CLS Tokenμ μΈμ½λλ₯Ό κ±°μ³ μ΅μ’
MLP Headμ νλ¬λ€μ΄κ° λ‘μ§μΌλ‘ λ³νλλ€. λ§μ½ λ
μκ»μ CLS Token Pooling λμ λ€λ₯Έ νλ§ λ°©μμ μ¬μ©ν κ±°λΌλ©΄ 1μ μΆκ°ν΄μ€ νμλ μλ€.
μ μ΄μ κ°μ²΄ μΈμ€ν΄μ€ μ΄κΈ°ν λΉμμ CLS Token μ μΆκ°λ₯Ό λ°μν κ°μ μ λ¬νλ©΄ λμ§ μλκ°νλ μλ¬Έμ΄ λ€ μλ μλ€. νμ§λ§ VisionEncoder κ°μ²΄ μΈμ€ν΄μ€ μ΄κΈ°ν λΉμμλ num_patches κ°μΌλ‘ CLS Tokenμ΄ μΆκ°λκΈ° μ΄μ κ°(+1 λ°μμ΄ μλμ΄ μμ)μ μ λ¬νλλ‘ μ€κ³ λμ΄ μμ΄μ CLS Poolingμ μ¬μ©ν κ±°λΌλ©΄ 1 μΆκ°λ₯Ό κΌ ν΄μ€μΌ νλ€.
Performance Table by making Position Embedding method
ννΈ μ μλ 2D Postion Embedding, Relative Position Embedding λ°©μλ μ μ©ν΄λ΄€μ§λ§, ꡬν 볡μ‘λ & μ°μ°λ λλΉ μ±λ₯ ν₯μ νμ΄ λ§€μ° λ―Έλ―Έν΄ μΌλ°μ μΈ 1D Position Embeddingμ μ¬μ©ν κ²μ μΆμ²νκ³ μλ€.
π©βπ©βπ§βπ¦ Multi-Head Attention
\[z_t^{'} = MSA(LN(z_{t-1}) + z_{t-1})\]
\[MSA(z) = [SA_1();SA_2();SA_3()...SA_k()]*U_{msa}, \ \ U_{msa} \in R^{(k*D_h)*D} \\\]
νΈλμ€ν¬λ¨Έ κ³μ΄ λͺ¨λΈμ ν΅μ¬ Multi-Head Self-Attention λͺ¨λμ λν΄μ μμ보μ. μ¬μ€ κΈ°μ‘΄ μμ°μ΄ μ²λ¦¬ Transformer, BERT λ±μ λμ λ°©μκ³Ό μμ ν λμΌνλ©°, μ½λλ‘ κ΅¬νν λ μμ λμΌνκ² λ§λ€μ΄μ£Όλ©΄ λλ€. μμΈν μ리μ λμ λ°©μμ Attention Is All You Need 리뷰 ν¬μ€νΈμμ μ€λͺ
νκΈ° λλ¬Έμ μλ΅νκ³ λμ΄κ°κ² λ€. ννΈ νμ΄ν μΉλ‘ ꡬνν Multi-Head Self-Attention λΈλμ λν μ½λλ λ€μκ³Ό κ°λ€.
def scaled_dot_product_attention(q: Tensor, k: Tensor, v: Tensor, dot_scale: Tensor) -> Tensor:
"""
Scaled Dot-Product Attention
Args:
q: query matrix, shape (batch_size, seq_len, dim_head)
k: key matrix, shape (batch_size, seq_len, dim_head)
v: value matrix, shape (batch_size, seq_len, dim_head)
dot_scale: scale factor for Qβ’K^T result, same as pure transformer
Math:
A = softmax(qβ’k^t/sqrt(D_h)), SA(z) = Av
"""
attention_dist = F.softmax(
torch.matmul(q, k.transpose(-1, -2)) / dot_scale,
dim=-1
)
attention_matrix = torch.matmul(attention_dist, v)
return attention_matrix
class AttentionHead(nn.Module):
"""
In this class, we implement workflow of single attention head
Args:
dim_model: dimension of model's latent vector space, default 1024 from official paper
dim_head: dimension of each attention head, default 64 from official paper (1024 / 16)
dropout: dropout rate, default 0.1
Math:
[q,k,v]=zβ’U_qkv, A = softmax(qβ’k^t/sqrt(D_h)), SA(z) = Av
"""
def __init__(self, dim_model: int = 1024, dim_head: int = 64, dropout: float = 0.1) -> None:
super(AttentionHead, self).__init__()
self.dim_model = dim_model
self.dim_head = dim_head
self.dropout = dropout
self.dot_scale = torch.sqrt(torch.tensor(self.dim_head))
self.fc_q = nn.Linear(self.dim_model, self.dim_head)
self.fc_k = nn.Linear(self.dim_model, self.dim_head)
self.fc_v = nn.Linear(self.dim_model, self.dim_head)
def forward(self, x: Tensor) -> Tensor:
attention_matrix = scaled_dot_product_attention(
self.fc_q(x),
self.fc_k(x),
self.fc_v(x),
self.dot_scale
)
return attention_matrix
class MultiHeadAttention(nn.Module):
"""
In this class, we implement workflow of Multi-Head Self-Attention
Args:
dim_model: dimension of model's latent vector space, default 1024 from official paper
num_heads: number of heads in MHSA, default 16 from official paper for ViT-Large
dim_head: dimension of each attention head, default 64 from official paper (1024 / 16)
dropout: dropout rate, default 0.1
Math:
MSA(z) = [SA1(z); SA2(z); Β· Β· Β· ; SAk(z)]β’Umsa
Reference:
https://arxiv.org/abs/2010.11929
https://arxiv.org/abs/1706.03762
"""
def __init__(self, dim_model: int = 1024, num_heads: int = 8, dim_head: int = 64, dropout: float = 0.1) -> None:
super(MultiHeadAttention, self).__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_head = dim_head
self.dropout = dropout
self.attention_heads = nn.ModuleList(
[AttentionHead(self.dim_model, self.dim_head, self.dropout) for _ in range(self.num_heads)]
)
self.fc_concat = nn.Linear(self.dim_model, self.dim_model)
def forward(self, x: Tensor) -> Tensor:
""" x is already passed nn.Layernorm """
assert x.ndim == 3, f'Expected (batch, seq, hidden) got {x.shape}'
attention_output = self.fc_concat(
torch.cat([head(x) for head in self.attention_heads], dim=-1) # concat all dim_head = num_heads * dim_head
)
return attention_output
MultiHeadAttentionμ κ°μ₯ μ΅μμ κ°μ²΄λ‘ λκ³ , νμμ AttentionHeadκ°μ²΄λ₯Ό λ°λ‘ ꡬννλ€. μ΄λ κ² κ΅¬ννλ©΄, μ΄ν
μ
ν΄λλ³λ‘ 쿼리, ν€, 벨λ₯ μ μ ν¬μ νλ ¬(nn.Linear)μ λ°λ‘ ꡬνν΄μ€ νμκ° μμ΄μ§λ©°, nn.ModuleList λ₯Ό ν΅ν΄ κ°λ³ ν΄λλ₯Ό ν λ²μ κ·Έλ£Ήννκ³ loop λ₯Ό ν΅ν΄ μΆλ ₯ κ²°κ³Όλ₯Ό concat ν΄μ€ μ μμ΄ λ³΅μ‘νκ³ λ§μ μλ¬λ₯Ό μ λ°νλ ν
μ μ°¨μ μ‘°μμ νΌν μ μμΌλ©°, μ½λμ κ°λ
μ±μ΄ μ¬λΌκ°λ ν¨κ³Όκ° μλ€.
π³οΈ MLP
\[z_{t} = MLP(LN(z_{t}^{'}) + z_{t}^{'})\]
μ΄λ¦λ§ MLPλ‘ λ°λμμ λΏ, κΈ°μ‘΄ νΈλμ€ν¬λ¨Έμ νΌλ ν¬μλ λΈλκ³Ό λμΌν μν μ νλ€. μμ μμΈν λμ λ°©μμ μ¬κΈ° ν¬μ€νΈμμ νμΈνμ. νμ΄ν μΉλ‘ ꡬνν μ½λλ λ€μκ³Ό κ°λ€.
class MLP(nn.Module):
"""
Class for MLP module in ViT-Large
Args:
dim_model: dimension of model's latent vector space, default 512
dim_mlp: dimension of FFN's hidden layer, default 2048 from official paper
dropout: dropout rate, default 0.1
Math:
MLP(x) = MLP(LN(x))+x
"""
def __init__(self, dim_model: int = 1024, dim_mlp: int = 4096, dropout: float = 0.1) -> None:
super(MLP, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(dim_model, dim_mlp),
nn.GELU(),
nn.Dropout(p=dropout),
nn.Linear(dim_mlp, dim_model),
nn.Dropout(p=dropout),
)
def forward(self, x: Tensor) -> Tensor:
return self.mlp(x)
νΉμ΄ν μ μ Activation FunctionμΌλ‘ GELUλ₯Ό μ¬μ©(κΈ°μ‘΄ νΈλμ€ν¬λ¨Έλ RELU)νλ€λ μ μ΄λ€.
π Vision Encoder Layer
ViT μΈμ½λ λΈλ 1κ°μ ν΄λΉνλ νμ λͺ¨λκ³Ό λμμ ꡬνν κ°μ²΄μ΄λ€. ꡬνν μ½λλ μλμ κ°λ€.
class VisionEncoderLayer(nn.Module):
"""
Class for encoder_model module in ViT-Large
In this class, we stack each encoder_model module (Multi-Head Attention, Residual-Connection, Layer Normalization, MLP)
"""
def __init__(self, dim_model: int = 1024, num_heads: int = 16, dim_mlp: int = 4096, dropout: float = 0.1) -> None:
super(VisionEncoderLayer, self).__init__()
self.self_attention = MultiHeadAttention(
dim_model,
num_heads,
int(dim_model / num_heads),
dropout,
)
self.layer_norm1 = nn.LayerNorm(dim_model)
self.layer_norm2 = nn.LayerNorm(dim_model)
self.dropout = nn.Dropout(p=dropout)
self.mlp = MLP(
dim_model,
dim_mlp,
dropout,
)
def forward(self, x: Tensor) -> Tensor:
ln_x = self.layer_norm1(x)
residual_x = self.dropout(self.self_attention(ln_x)) + x
ln_x = self.layer_norm2(residual_x)
fx = self.mlp(ln_x) + residual_x # from official paper & code by Google Research
return fx
νΉμ΄μ μ λ§μ§λ§
λ§μ§λ§ μΈμ½λμ μΆλ ₯κ°μλ§ νλ² λ MLP Layerμ Residual κ²°κ³Όλ₯Ό λν λ€, λ€μ μΈμ½λ λΈλ‘μ μ λ¬νκΈ° μ μ μΈ΅ μ κ·νλ₯Ό ν λ² λ μ μ©νλ€λ κ²μ΄λ€. λͺ¨λΈ λͺ¨μλμλ λμ μμ§ μμ§λ§, λ³Έλ¬Έμ ν΄λΉ λ΄μ©μ΄ μ€λ € μλ€.layernormμ μ μ©νλ€.
π VisionEncoder
μ
λ ₯ μ΄λ―Έμ§λ₯Ό Patch EmbeddingμΌλ‘ μΈμ½λ© νκ³ Nκ°μ VisionEncoderLayerλ₯Ό μκΈ° μν΄ κ΅¬νλ κ°μ²΄μ΄λ€. Patch Embeddingμ λ§λλ λΆλΆμ μ΄λ―Έ μμμ μ€λͺ
νκΈ° λλ¬Έμ λμ΄κ°κ³ , μΈμ½λ λΈλμ Nκ° μλ λ°©λ²μ μμλ nn.ModuleList λ₯Ό μ¬μ©νλ©΄ κ°νΈνκ² κ΅¬νν μ μλ€. μλ μ½λλ₯Ό μ΄ν΄λ³΄μ.
class VisionEncoder(nn.Module):
"""
In this class, encode input sequence(Image) and then we stack N VisionEncoderLayer
This model is implemented by cls pooling method for classification
First, we define "positional embedding" and then add to input embedding for making patch embedding
Second, forward patch embedding to N EncoderLayer and then get output embedding
Args:
num_patches: number of patches in input image => (image_size / patch_size)**2
N: number of EncoderLayer, default 24 for large model
"""
def __init__(self, num_patches: int, N: int = 24, dim_model: int = 1024, num_heads: int = 16, dim_mlp: int = 4096, dropout: float = 0.1) -> None:
super(VisionEncoder, self).__init__()
self.num_patches = num_patches
self.positional_embedding = nn.Embedding((self.num_patches + 1), dim_model) # add 1 for cls token
self.num_layers = N
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_mlp = dim_mlp
self.dropout = nn.Dropout(p=dropout)
self.encoder_layers = nn.ModuleList(
[VisionEncoderLayer(dim_model, num_heads, dim_mlp, dropout) for _ in range(self.num_layers)]
)
self.layer_norm = nn.LayerNorm(dim_model)
def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
layer_output = []
pos_x = torch.arange(self.num_patches + 1).repeat(inputs.shape[0]).to(inputs)
x = self.dropout(
inputs + self.positional_embedding(pos_x)
)
for layer in self.encoder_layers:
x = layer(x)
layer_output.append(x)
encoded_x = self.layer_norm(x) # from official paper & code by Google Research
layer_output = torch.stack(layer_output, dim=0).to(x.device) # For Weighted Layer Pool: [N, BS, SEQ_LEN, DIM]
return encoded_x, layer_output
λ§μ§λ§ μΈ΅μ μΈμ½λ μΆλ ₯κ°μλ layernormμ μ μ©ν΄μ€μΌ ν¨μ μμ§ λ§μ. ννΈ, layer_outputλ λ μ΄μ΄ λ³ μ΄ν
μ
κ²°κ³Όλ₯Ό μκ°ν νκ±°λ λμ€μ WeightedLayerPoolμ μ¬μ©νλ €κ³ λ§λ€μλ€.
π€ VisionTransformer
ViT λͺ¨λΈμ κ°μ₯ μ΅μμ κ°μ²΄λ‘, μμμ μ€λͺ
ν λͺ¨λ λͺ¨λλ€μ λμμ΄ μ΄λ€μ§λ κ³³μ΄λ€. μ¬μ©μλ‘λΆν° νμ΄νΌνλΌλ―Έν°λ₯Ό μ
λ ₯ λ°μ λͺ¨λΈμ ν¬κΈ°, κΉμ΄, ν¨μΉ ν¬κΈ°, μ΄λ―Έμ§ μλ² λ© μΆμΆ λ°©μμ μ§μ νλ€. κ·Έλ¦¬κ³ μ
λ ₯ μ΄λ―Έμ§λ₯Ό μ λ¬λ°μ μλ² λ©μ λ§λ€κ³ μΈμ½λμ μ λ¬ν λ€, MLP Head λ₯Ό ν΅ν΄ μ΅μ’
μμΈ‘ κ²°κ³Όλ₯Ό λ°ννλ μν μ νλ€.
μ΄λ―Έμ§ μλ² λ© μΆμΆ λ°©μμ Linear Projectionκ³Ό Convolutionμ΄ μλ€. μ μκ° λ
Όλ¬Έμμ λ§νλ μΌλ°μ μΈ ViTλ₯Ό λ§νλ©° νμλ μ μκ° Hybrid ViTλΌκ³ λ°λ‘ λͺ
λͺ
νλ λͺ¨λΈμ΄λ€. μλ² λ© μΆμΆ λ°©μ μ΄μΈμ λ€λ₯Έ μ°¨μ΄λ μ ν μλ€. extractor λ§€κ°λ³μλ₯Ό ν΅ν΄ μλ² λ© μΆμΆ λ°©μμ μ§μ ν μ μμΌλ μλ μ½λλ₯Ό νμΈν΄λ³΄μ.
class VisionTransformer(nn.Module):
"""
Main class for ViT of cls pooling, Pytorch implementation
We implement pure ViT, Not hybrid version which is using CNN for extracting patch embedding
input must be [BS, CHANNEL, IMAGE_SIZE, IMAGE_SIZE]
In NLP, input_sequence is always smaller than vocab size
But in vision, input_sequence is always same as image size, not concept of vocab in vision
So, ViT use nn.Linear instead of nn.Embedding for input_embedding
Args:
num_classes: number of classes for classification task
image_size: size of input image, default 512
patch_size: size of patch, default 16 from official paper for ViT-Large
extractor: option for feature extractor, default 'base' which is crop & just flatten
if you want to use Convolution for feature extractor, set extractor='cnn' named hybrid ver in paper
classifier: option for pooling method, default token meaning that do cls pooling
if you want to use mean pooling, set classifier='mean'
mode: option for train type, default fine-tune, if you want pretrain, set mode='pretrain'
In official paper & code by Google Research, they use different classifier head for pretrain, fine-tune
Math:
image2sequence: [batch, channel, image_size, image_size] -> [batch, patch, patch_size^2*channel]
input_embedding: R^(P^2 Β·C)ΓD
Reference:
https://arxiv.org/abs/2010.11929
https://arxiv.org/abs/1706.03762
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py#L184
"""
def __init__(
self,
num_classes: int,
channels: int = 3,
image_size: int = 512,
patch_size: int = 16,
num_layers: int = 24,
dim_model: int = 1024,
num_heads: int = 16,
dim_mlp: int = 4096,
dropout: float = 0.1,
extractor: str = 'base',
classifier: str = 'token',
mode: str = 'fine_tune',
) -> None:
super(VisionTransformer, self).__init__()
self.num_patches = int(image_size / patch_size)**2
self.num_layers = num_layers
self.patch_size = patch_size
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_mlp = dim_mlp
self.dropout = nn.Dropout(p=dropout)
# Input Embedding
self.extractor = extractor
self.input_embedding = nn.Linear((channels * patch_size**2), dim_model)
self.conv = nn.Conv2d(
in_channels=channels,
out_channels=self.dim_model,
kernel_size=self.patch_size,
stride=self.patch_size
)
# Encoder Multi-Head Self-Attention
self.encoder = VisionEncoder(
self.num_patches,
self.num_layers,
self.dim_model,
self.num_heads,
self.dim_mlp,
dropout,
)
self.classifier = classifier
self.pretrain_classifier = nn.Sequential(
nn.Linear(self.dim_model, self.dim_model),
nn.Tanh(),
)
self.fine_tune_classifier = nn.Linear(self.dim_model, num_classes)
self.mode = mode
def forward(self, inputs: Tensor) -> any:
""" For cls pooling """
assert inputs.ndim != 4, f"Input shape should be [BS, CHANNEL, IMAGE_SIZE, IMAGE_SIZE], but got {inputs.shape}"
x = inputs
if self.extractor == 'cnn':
# self.conv(x).shape == [batch, dim, image_size/patch_size, image_size/patch_size]
x = self.conv(x).reshape(x.shape[0], self.dim_model, self.num_patches**2).transpose(-1, -2)
else:
# self.extractor == 'base':
x = self.input_embedding(
x.reshape(x.shape[0], self.num_patches, (self.patch_size**2 * x.shape[1]))
)
cls_token = torch.zeros(x.shape[0], 1, x.shape[2]) # can change init method
x = torch.cat([cls_token, x], dim=1)
x, layer_output = self.encoder(x) # output
# classification
x = x[:, 0, :] # select cls token, which is position 0 in sequence
if self.mode == 'fine_tune':
x = self.fine_tune_classifier(x)
if self.mode == 'pretrain':
x = self.fine_tune_classifier(self.pretrain_classifier(x))
return x
ννΈ, μ½λμμ λμ¬κ²¨λ΄μΌ ν μ μ MLP Headλ‘, μ μλ pre-train μμ κ³Ό fine-tune μμ μ μλ‘ λ€λ₯Έ Classifier Headλ₯Ό μ¬μ©νλ€. μ μμλ Activation Function 1κ°μ λ κ°μ MLP Layerλ₯Ό μ¬μ©νκ³ , νμμλ 1κ°μ MLP Layerλ₯Ό μ¬μ©νλ€.
λ€λ§, pretrain_classifierμ μ
μΆλ ₯ μ°¨μμ λν μ νν μμΉλ₯Ό λ
Όλ¬Έμ΄λ official repo codeλ₯Ό νμΈν΄λ μ°Ύμ μ μμλ€, κ·Έλμ μμλ‘ λͺ¨λΈμ μ°¨μκ³Ό λκ°μ΄ μΈν
νκ² λμλ€.
λν μ μλ CLS Poolingκ³Ό λλΆμ΄ GAP λ°©μλ μ μνλλ°, GAP λ°©μμ μΆνμ λ°λ‘ μΆκ°κ° νμνλ€. κ·Έλ¦¬κ³ μ¬μ νλ ¨κ³Ό νμΈ νλ λͺ¨λ λΆλ₯ ν
μ€ν¬λ₯Ό μννλλ° (μ¬μ§μ΄ κ°μ λ°μ΄ν° μΈνΈλ₯Ό μ¬μ©ν¨) μ κ΅³μ΄ μλ‘ λ€λ₯Έ Classifier Headλ₯Ό μ μνλμ§ μλλ₯Ό μ μ μμ΄ λ
Όλ¬Έμ λ€μ μ½μ΄λ΄€μ§λ§, μ΄μ μ λν΄μ μμΈν μΈκΈνλ λΆλΆμ΄ μμλ€.
ViTλ μ
λ ₯ μλ² λ©μ μ μνλ λΆλΆμ μ μΈνλ©΄ μ μμ μλλλ‘ κΈ°μ‘΄ νΈλμ€ν¬λ¨Έμ λμΌν λͺ¨λΈ ꡬ쑰λ₯Ό κ°μ‘λ€. μμ ν λ€λ₯Έ λ°μ΄ν°μΈ μ΄λ―Έμ§μ ν
μ€νΈμ κ°μ ꡬ쑰μ λͺ¨λΈμ μ μ©νλ€λ κ²μ΄ μ λ§ μ½μ§ μμ 보μλλ°, ν¨μΉ κ°λ
μ λ§λ€μ΄ μμ°μ΄μ ν ν°μ²λΌ κ°μ£Όνκ³ μ¬μ©ν κ²μ΄ μλλλ‘ κ΅¬ννλλ° μ§κ΄μ μ΄λ©΄μλ μ λ§ ν¨κ³Όμ μ΄μλ€κ³ μκ°νλ€. μ΄μ μ΄λ κ² λ§λ€μ΄μ§ λͺ¨λΈμ ν΅ν΄ μ§νν μ¬λ¬ μ€ν κ²°κ³Όμ μ΄λ€ μΈμ¬μ΄νΈκ° λ΄κ²¨ μλμ§ μμ보μ.
π¬Β Insight from Experiment
π‘Β Insight 1. ViTμ Scalability μ¦λͺ
Pre-Trainμ μ¬μ©λλ μ΄λ―Έμ§ λ°μ΄ν° μΈνΈμ ν¬κΈ°κ° 컀μ§μλ‘Fine-Tune StageμμViTκ°CNNλ³΄λ€ λμ μ±λ₯- κ°μ μ±λ₯μ΄λΌλ©΄
ViTκ° μλμ μΌλ‘ μ μ μ°μ°λμ κΈ°λ‘
μ λνλ Pre-Train Stageμ μ¬μ©λ μ΄λ―Έμ§ λ°μ΄ν° μΈνΈμ λ°λ₯Έ λͺ¨λΈμ Fine-Tune μ±λ₯ μΆμ΄λ₯Ό λνλΈ μλ£λ€. μ¬μ νλ ¨ λ°μ΄ν° μ€μΌμΌμ΄ ν¬μ§ μμ λλ Conv κΈ°λ°μ ResNet μ리μ¦κ° ViT μ리μ¦λ₯Ό μλνλ λͺ¨μ΅μ 보μ¬μ€λ€. νμ§λ§ λ°μ΄ν° μΈνΈμ ν¬κΈ°κ° 컀μ§μλ‘ μ μ ViT μ리μ¦μ μ±λ₯μ΄ ResNetμ λ₯κ°νλ κ²°κ³Όλ₯Ό λ³Ό μ μλ€.
ννΈ, ViT & ResNet μ±λ₯ κ²°κ³Ό λͺ¨λ ImageNetκ³Ό JFT-Imageλ‘ μ¬μ νλ ¨ λ° νμΈ νλμ κ±°μ³ λμλ€κ³ νλ μ°Έκ³ νμ. μΆκ°λ‘ νμΈ νλ κ³Όμ μμ μ¬μ νλ ¨ λλ³΄λ€ μ΄λ―Έμ§ μ¬μ΄μ¦λ₯Ό ν€μμ νλ ¨μ μμΌ°λ€κ³ λ Όλ¬Έμμ λ°νκ³ μλλ°, μ΄λ μ μμ μ€ν κ²°κ³Όμ κΈ°μΈν κ²μ΄λ€. λ Όλ¬Έμ λ°λ₯΄λ©΄ νμΈ νλ λ μ¬μ νλ ¨ λΉμλ³΄λ€ λ λμ ν΄μλμ μ΄λ―Έμ§λ₯Ό μ¬μ©νλ©΄ μ±λ₯μ΄ ν₯μ λλ€κ³ νλ κΈ°μ΅νλ€κ° μ¨λ¨Ήμ΄λ³΄μ.
μ λνλ μ°μ°λ λ³νμ λ°λ₯Έ λͺ¨λΈμ μ±λ₯ μΆμ΄λ₯Ό λνλΈ κ·Έλ¦Όμ΄λ€. λ μ§ν λͺ¨λ κ°μ μ μλΌλ©΄ ViT μ리μ¦μ μ°μ°λμ΄ νμ ν μ μμ μ μ μλ€. λν μ νλ 95% μ΄ν ꡬκ°μμ κ°μ μ±λ₯μ΄λΌλ©΄ ViTμ Hybrid λ²μ λͺ¨λΈμ μ°μ°λμ΄ μΌλ° ViT λ²μ λ³΄λ€ νμ ν μ μμ νμΈν μ μλ€. μ΄λ¬ν μ¬μ€μ μΆνμ Swin-Transformer μ€κ³μ μκ°μ μ€λ€.
λ κ°μ μ€ν κ²°κ³Όλ₯Ό μ’
ν©νμ λ, ViTκ° ResNetλ³΄λ€ μΌλ°ν μ±λ₯μ΄ λ λμΌλ©°(λν 1) λͺ¨λΈμ Saturation νμμ΄ λλλ¬μ§μ§ μμ μ±λ₯μ νκ³μΉ(λν 2) μμ λ λλ€κ³ λ³Ό μ μλ€. λ°λΌμ κΈ°μ‘΄ νΈλμ€ν¬λ¨Έμ μ°μ°β’ꡬ쑰μ μΈ‘λ©΄μμ Scalabilityλ₯Ό μ±κ³΅μ μΌλ‘ μ΄μνλ€κ³ νκ°ν μ μκ² λ€.
π‘Β Insight 2. Pure Self-Attentionμ μ’μ μ΄λ―Έμ§ νΌμ²λ₯Ό μΆμΆνκΈ°μ μΆ©λΆνλ€
- Patch Embedding Layerμ PCA κ²°κ³Ό, ν¨μΉμ κΈ°μ κ° λλ μ°¨μκ³Ό μ μ¬ν λͺ¨μμ μΆμΆ
Convolutionμμ΄Self-Attentionλ§μΌλ‘λ μΆ©λΆν μ΄λ―Έμ§μ μ’μ νΌμ²λ₯Ό μΆμΆνλ κ²μ΄ κ°λ₯VisionμμConvolutionμ λνrelianceννΌ κ°λ₯
Patch Embedding Layerβs Filter
μ μλ£λ μΆ©λΆν νμ΅μ κ±°μΉκ³ λ ViTμ Patch Embedding Layerμ νν°λ₯Ό PCAν κ²°κ³Ό μ€μμ νΉμκ°μ΄ λμ μμ 28κ°μ νΌμ²λ₯Ό λμ΄ν κ·Έλ¦Όμ΄λ€. μ΄λ―Έμ§μ κΈ°λ³Έ λΌλκ° λκΈ°μ μ ν©ν΄ 보μ΄λ νΌμ²λ€μ΄ μΆμΆλ λͺ¨μ΅μ λ³Ό μ μλ€.
λ°λΌμ Inductive Bias μμ΄, λ¨μΌ Self-Attentionλ§μΌλ‘ μ΄λ―Έμ§μ νΌμ²λ₯Ό μΆμΆνλ κ²μ΄ μΆ©λΆν κ°λ₯νλ€. λΉμ λΆμΌμ λ§μ°ν Convolution μμ‘΄μμ λ²μ΄λ μλ‘μ΄ μν€ν
μ²μ λμ
μ΄ κ°λ₯ν¨μ μμ¬ν λΆλΆμ΄λΌκ³ ν μ μκ² λ€.
π‘Β Insight 3. Bottom2General Information, Top2Specific Information
μ λ ₯κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘Global & Generalν Informationμ ν¬μ°©μΆλ ₯κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘Local & Specificν Informationμ ν¬μ°©
Multi-Head Attention Distance per Network Depth
λ€μ μλ£λ μΈμ½λμ κ°μ λ³νμ λ°λ₯Έ κ°λ³ μ΄ν
μ
ν΄λμ μ΄ν
μ
거리 λ³ν μΆμ΄λ₯Ό λνλΈ κ·Έλ¦Όμ΄λ€. μ¬κΈ°μ μ΄ν
μ
거리λ, ν΄λκ° μΌλ§λ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉλ₯Ό μ΄ν
μ
νλμ§ ν½μ
λ¨μλ‘ ννν μ§νλ€. ν΄λΉ κ°μ΄ λμμλ‘ κ±°λ¦¬μ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉμ μ΄ν
μ
μ, μμμλ‘ κ°κΉμ΄ ν¨μΉμ μ΄ν
μ
νλ€λ κ²μ μλ―Ένλ€. λ€μ λνλ₯Ό μ΄ν΄λ³΄μ. μ
λ ₯κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘(Depth 0) ν΄λλ³ μ΄ν
μ
거리μ λΆμ°μ΄ 컀μ§κ³ , μΆλ ₯κ³Ό κ°κΉμ΄ μΈμ½λμΌμλ‘(Depth 23) λΆμ°μ΄ μ μ μ€μ΄λ€λ€κ° κ±°μ ν μ μ μλ ΄νλλ―ν μμμ 보μ¬μ€λ€. λ€μ λ§ν΄, μ
λ ₯κ³Ό κ°κΉμ΄ Bottom Encoderλ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉλΆν° κ°κΉμ΄ ν¨μΉκΉμ§ λͺ¨λ μ μμ (Global)μΌλ‘ μ΄ν
μ
μ μνν΄ General ν μ 보λ₯Ό ν¬μ°©νκ² λκ³ μΆλ ₯κ³Ό κ°κΉμ΄ Top Encoderλ κ°λ³ ν΄λλ€μ΄ λͺ¨λ λΉμ·ν 거리μ μμΉν ν¨μΉ(Local)μ μ΄ν
μ
μ μνν΄ Specific ν μ 보λ₯Ό ν¬μ°©νκ² λλ€.
μ΄ λ Globalκ³Ό Localμ΄λΌλ μ©μ΄ λλ¬Έμ Bottom Encoder λ λ©λ¦¬ λ¨μ΄μ§ ν¨μΉμ μ΄ν
μ
νκ³ , Top Encoderλ κ°κΉμ΄ ν¨μΉμ μ΄ν
μ
νλ€κ³ μ°©κ°νκΈ° μ½λ€. κ·Έλ¬λ κ°λ³ ν΄λλ€μ μ΄ν
μ
κ±°λ¦¬κ° μΌλ§λ λΆμ°λμ΄ μλκ°κ° λ°λ‘ Global, Localμ ꡬλΆνλ κΈ°μ€μ΄ λλ€. μ
λ ₯λΆμ κ°κΉμ΄ λ μ΄μ΄λ€μ ν€λλ€μ μ΄ν
μ
거리 λΆμ°μ΄ λ§€μ° ν° νΈμΈλ°, μ΄κ²μ μ΄ν¨μΉ μ ν¨μΉ λͺ¨λ μ΄ν
μ
ν΄λ³΄κ³ λΉκ΅ν΄λ³Έλ€κ³ ν΄μν΄μ Globalμ΄λΌκ³ λΆλ₯΄κ³ , μΆλ ₯λΆμ κ°κΉμ΄ λ μ΄μ΄λ ν€λλ€μ μ΄ν
μ
거리 λΆμ°μ΄ λ§€μ° μμ νΈμΈλ°, μ΄κ² λ°λ‘ κ°κ°μ ν€λλ€μ΄ μ΄λ€ μ 보μ μ£Όλͺ©ν΄μΌν μ§(λΆλ₯ μμ€μ΄ κ°μ₯ μμμ§λ ν¨μΉ) λ²μλ₯Ό μΆ©λΆν μ’ν μνμμ νΉμ λΆλΆμλ§ μ§μ€νλ€λ μλ―Έλ‘ ν΄μν΄ Local μ΄λΌκ³ λΆλ₯΄κ² λμλ€.
<Revisiting Few-sample BERT Fine-tuning>λ μμ λΉμ·ν λ§₯λ½μ μ¬μ€μ λν΄ μΈκΈνκ³ μμΌλ μ°Έκ³ ν΄λ³΄μ. μ΄λ¬ν μ¬μ€μ νΈλμ€ν¬λ¨Έ μΈμ½λ κ³μ΄ λͺ¨λΈμ νλν λ Depth λ³λ‘ λ€λ₯Έ Learning Rateμ μ μ©νλ Layerwise Learning Rate Decay μ μ΄μμ΄ λκΈ°λ νλ€. Layerwise Learning Rate Decay μ λν΄μλ μ¬κΈ° ν¬μ€νΈλ₯Ό μ°Έκ³ νλλ‘ νμ.
ννΈ λ
Όλ¬Έμλ μΈκΈλμ§ μμ, νμμ λνΌμ
μ κ°κΉμ§λ§, μΆλ ₯μ κ°κΉμ΄ μΈμ½λλ€μ ν΄λκ° κ°μ§ Attention Distanceμ΄ λͺ¨λ λΉμ·νλ€λ μ¬μ€λ‘ μ΄λ―Έμ§ λΆλ₯μ κ²°μ μ μΈ μν μ νλ νΌμ²κ° μ΄λ―Έμ§μ νΉμ ꡬμμ λͺ¨μ¬ μμΌλ©°, κ·Έ μ€νμ μ΄λ―Έμ§μ μ€μ λΆκ·ΌμΌ κ°λ₯μ±μ΄ λλ€κ³ μΆμΈ‘ ν΄λ³Ό μ μλ€. λͺ¨λ ν΄λμ ν½μ
κ±°λ¦¬κ° μλ‘ λΉμ·νλ €λ©΄ μΌλ¨ λΉμ·ν μμΉμ ν¨μΉμ μ΄ν
μ
μ ν΄μΌνκΈ° λλ¬Έμ λΆλ₯ μμ€κ°μ μ΅μλ‘ μ€μ¬μ£Όλ νΌμ²λ λ³΄ν΅ ν ꡬμ(ν¨μΉ)μ λͺ°λ € μμ κ²μ΄λΌκ³ μ μΆκ° κ°λ₯νλ€. λν νΉμ μ€νμ΄ μ€μμ μμΉν μλ‘ μ΄ν
μ
거리μ λΆμ°μ΄ μ€μ΄λ€κ²μ΄λΌκ³ μκ° ν΄λ³Ό μλ μμλ€. μ μλ Attention Rolloutμ΄λΌλ κ°λ
μ ν΅ν΄ Attention Distanceμ μ°μΆνλ€κ³ μΈκΈνλλ°, μμΈν λ΄μ©μ μμ λ λ§ν¬λ₯Ό μ°Έκ³ ν΄λ³΄μ(νκ΅μ΄ μ€λͺ
λΈλ‘κ·Έ, μλ
Όλ¬Έ). μ΄λ¬ν νμμ κ°μ€μ΄ λ§λ€λ©΄, Convolution μ Inductive Bias μ€ Locality μ ν¨κ³Όμ±μ Self-Attentionμ ν΅ν΄ μ
μ¦μ΄ κ°λ₯νλ©°, λ°λλ‘ Convolutionμ λν μμ‘΄μμ λ²μ΄λ λ¨μΌ Self-Attention μΌλ‘λ κ°μ ν¨κ³Όλ₯Ό λΌ μ μλ€λ μ¦κ±° μ€ νλκ° λ κ²μ΄λ€.
π‘Β Insight 4. ViTλ CLS Pooling μ¬μ©νλκ² ν¨μ¨μ
CLS PoolingμGAPλ³΄λ€ 2λ°° μ΄μ ν° νμ΅λ₯ μ μ¬μ©ν΄λ λΉμ·ν μ±λ₯μ κΈ°λ‘- νμ΅ μλλ λ λΉ λ₯΄λ μ±λ₯μ΄ λΉμ·νκΈ° λλ¬Έμ
CLS Poolingμ΄ λ ν¨μ¨μ
- νμ΅ μλλ λ λΉ λ₯΄λ μ±λ₯μ΄ λΉμ·νκΈ° λλ¬Έμ
Performance Trend by Pooling Method with LR
λ€μ λνλ νλ§ λ°©μκ³Ό νμ΅λ₯ μ λ³λμ λ°λ₯Έ μ νλ λ³ν μΆμ΄λ₯Ό λνλΈ κ·Έλ¦Όμ΄λ€. λΉμ·ν μ±λ₯μ΄λΌλ©΄ CLS Poolingμ΄ GAPλ³΄λ€ 2λ°° μ΄μ ν° νμ΅λ₯ μ μ¬μ©νλ€. νμ΅λ₯ μ΄ ν¬λ©΄ λͺ¨λΈμ μλ ΄ μλκ° λΉ¨λΌμ Έ νμ΅ μλκ° λΉ¨λΌμ§λ μ₯μ μ΄ μλ€. κ·Έλ°λ° μ±λ₯κΉμ§ λΉμ·νλ€λ©΄ ViTλ CLS Poolingμ μ¬μ©νλ κ²μ΄ λ ν¨μ¨μ μ΄λΌκ³ ν μ μκ² λ€.
λμ€μ μκ°μ΄ λλ€λ©΄ λ€λ₯Έ νλ§ λ°©μ, μλ₯Ό λ€λ©΄ Weighted Layer Pooling, GeM Pooling, Attention Pooling κ°μ κ²μ μ μ©ν΄ μ€νν΄λ³΄κ² λ€.
π‘Β Insight 5. ViTλ Absolute 1D-Position Embedding μ¬μ©νλκ² κ°μ₯ ν¨μ¨μ
- μ΄λ€ ννλ‘λ μμΉ μλ² λ© κ°μ μ μν΄μ€λ€λ©΄, ννμ μ’ λ₯μ μκ΄μμ΄ κ±°μ λΉμ·ν μ±λ₯μ 보μ
- μ±λ₯μ΄ λΉμ·νλ©΄, μ§κ΄μ μ΄κ³ ꡬνμ΄ κ°νΈν
Absolute 1D-Position Embeddingλ°©λ²μ μ¬μ©νλ κ²μ΄ κ°μ₯ ν¨μ¨μ ViTλPatch-Levelμ¬μ©ν΄,Pixel-Levelλ³΄λ€ μλμ μΌλ‘ μνμ€ κΈΈμ΄κ° μ§§μ μμΉβ’κ³΅κ° μ 보λ₯Ό μΈμ½λ©νλ λ°©μμ μν₯μ λ λ°μ
Performance Table by making Position Embedding method
μ μ€ν κ²°κ³Όλ Position Embedding μΈμ½λ© λ°©μμ λ°λ₯Έ ViT λͺ¨λΈμ μ±λ₯ λ³ν μΆμ΄λ₯Ό λνλΈ μλ£λ€. μΈμ½λ© ννμ μκ΄μμ΄ μμΉ μλ² λ©μ μ λ¬΄κ° μ±λ₯μ ν° μν₯μ λ―ΈμΉλ€λ μ¬μ€μ μλ €μ£Όκ³ μλ€. ννΈ, μΈμ½λ© νν λ³νμ λ°λ₯Έ μ μλ―Έν μ±λ₯ λ³νλ μμλ€. νμ§λ§ Absolute 1D-Position Embeddingμ 컨μ
μ΄ κ°μ₯ μ§κ΄μ μ΄λ©° ꡬννκΈ° νΈνκ³ μ°μ°λμ΄ λ€λ₯Έ μΈμ½λ©λ³΄λ€ μ λ€λ κ²μ κ°μνλ©΄ ViTμ κ°μ₯ ν¨μ¨μ μΈ μμΉ μλ² λ© λ°©μμ΄λΌκ³ νλ¨ν μ μλ€.
λ
Όλ¬Έμ κ²°κ³Όμ λν΄ ViTκ° μ¬μ©νλ Patch-Level Embeddingμ΄ Pixel-Levelλ³΄λ€ μλμ μΌλ‘ μ§§μ μνμ€ κΈΈμ΄λ₯Ό κ°κΈ° λλ¬Έμ΄λΌκ³ μ€λͺ
νλ€. μλ₯Ό λ€μ΄ 224x224 μ¬μ΄μ¦μ μ΄λ―Έμ§λ₯Ό 16x16 μ¬μ΄μ¦μ ν¨μΉ μ¬λ¬μ₯μΌλ‘ λ§λ λ€κ³ μκ°ν΄λ³΄μ. μλ² λ© μ°¨μμ λ€μ΄κ°λ $N$ μ $(224/16)^2$ , μ¦ 196μ΄ λλ€. ννΈ μ΄κ²μ Pixel-Levelλ‘ μλ² λ© νκ² λλ©΄ $224^2$, μ¦ 50176 κ°μ μνμ€κ° μκΈ΄λ€. λ°λΌμ Pixel-Level μ λΉνλ©΄ ν¨μ¬ μ§§μ μνμ€ κΈΈμ΄λ₯Ό κ°κΈ° λλ¬Έμ Absolute 1D-Position Embedding λ§μΌλ‘λ μΆ©λΆν Spatial Relationμ νμ΅ν μ μλ κ²μ΄λ€.
Absolute 1D-Position Embedding
νμ§λ§, νμλ μμ°μ΄ μ²λ¦¬μ Transformer-XL, XLNet, DeBERTa κ°μ λͺ¨λΈλ€μ΄ Relative Position Embedding λ°©μμ μ μ©ν΄ ν° μ±κ³΅μ κ±°λ λ°κ° μλ€λ μ μ μκ°νλ©΄ μ΄λ° κ²°κ³Όκ° λ©λμ΄ κ°λ©΄μλ μμνλ€.
μ μλ μ€νμ μ¬μ©ν λͺ¨λ λ°μ΄ν° μΈνΈλ₯Ό 224x224λ‘ resize νλ€κ³ λ°νκ³ μλλ°, λ§μ½ μ΄λ―Έμ§ μ¬μ΄μ¦κ° 512x512μ λλ§ λλλΌλ $N$ κ°μ΄ 1024 μ΄λΌμ μ κ²°κ³Όμ μλΉν λ€λ₯Έ μμμ΄ λνλμ§ μμκΉ νλ μκ°μ΄ λ λ€. μΆνμ μκ°μ΄ λλ€λ©΄ μ΄ λΆλΆλ κΌ μ€νν΄λ΄μΌκ² λ€. μμΈ‘μ»¨λ° μ΄λ―Έμ μ¬μ΄μ¦κ° 컀μ§μλ‘ 2D Position Embedding νΉμ Relative Position Embeddingμ΄ λ ν¨μ¨μ μΌ κ²μ΄λΌ μμνλ€.
π§ββοΈΒ Conclusion
μ΄λ κ² ViT λͺ¨λΈμ μ μν <An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale>μ μ€λ¦° λ΄μ©μ λͺ¨λ μ΄ν΄λ³΄μλ€. Conv μ λν μμ‘΄μ ννΌ νλ€λ μ μμ λ§€μ° μλ―Έκ° μλ μλμμΌλ©°, Self-Attention & Transformer ꡬ쑰 μ±νλ§μΌλ‘λ μ»΄ν¨ν° λΉμ μμμ μ΄λ μ λ scalability λ₯Ό μ΄μνλλ° μ±κ³΅νλ€λ μ μμ νλ μ°κ΅¬μ μ€μν μμ¬μ μ λ¨κ²Όλ€. μλμ μΌλ‘ μ 체(??)λμ΄ μλ λΉμ μμμ΄ μ±λ₯μ νκ³λ₯Ό νλ¨κ³ λ°μ΄λμ μ μλ μ΄μμ λ§λ ¨ν΄μ€ μ
μ΄λ€.
νμ§λ§, ViTμ Pretrain Stageμ μ ν©ν Self-Supervised Learning λ°©λ²μ μ°Ύμ§ λͺ»ν΄ μ¬μ ν Supervised Learning λ°©μμ μ±νν μ μ λ§€μ° μμ¬μ λ€. μ΄λ κ²°κ΅ λ°μ΄ν° Scale νμ₯μ νκ³λ₯Ό μλ―ΈνκΈ° λλ¬Έμ΄λ€. μ€λλ BERTμ GPTμ μ±κ³΅ μ νλ λΉλ¨ Self-Attentionμ Transformerμ ꡬ쑰μ νμμ±μ μν΄μλ§ νμνκ² μλλ€. μ΄μ λͺ»μ§ μκ²(κ°μΈμ μΌλ‘ μ μΌ μ€μνλ€ μκ°) μ£Όμνλ κ²μ΄ λ°λ‘ λ°μ΄ν° Scale νμ₯μ΄λ€. MLM, AR λ±μ Self-Supervised Learning λλΆμ λ°μ΄ν° Scaleμ ν¨μ¨μ μΌλ‘ μ€μΌμΌ μ
μν¬ μ μμκ³ , μ¬μ νλ ¨ λ°μ΄ν°μ μ¦κ°λ λͺ¨λΈ κΉμ΄, λλΉ, μ°¨μκΉμ§ λμ± ν¬μΌ ν€μ°λλ° κΈ°μ¬νλ€.
λν ViTλ μ μ²μ μΌλ‘ Patch-Level Embeddingμ μ¬μ©νκΈ° λλ¬Έμ λ€μν μ΄λ―Έμ§ ν
μ€ν¬μ μ μ©νλ κ²μ΄ νλ€λ€. Segmentation, Object Detection κ°μ Taskλ ν½μ
λ¨μλ‘ μμΈ‘μ μνν΄ κ°μ²΄λ₯Ό νμ§νκ±°λ λΆν ν΄μΌ νλ€. νμ§λ§ Patch λ¨μλ‘ νλ ¨μ μννλ ViTλ Pixel λ¨μμ μμΈ‘μ μννλλ° μ΄λ €μμ κ²ͺλλ€.
λ§μ§λ§μΌλ‘ Self-Attention μ체μ Computational Overheadκ° λ무 μ¬ν΄ κ³ ν΄μλμ μ΄λ―Έμ§λ₯Ό μ μ ν λ€λ£¨κΈ° νλ€λ€. μμμλ μΈκΈνμ§λ§ μ΄λ―Έμ§μ μ¬μ΄μ¦κ° 512x512λ§ λμ΄λ μ΄λ―Έ ν¨μΉμ κ°μκ° 1024κ° λλ€. μ¬μ΄μ¦κ° 컀μ§μλ‘ μνμ€ κΈΈμ΄ μμ κΈ°νκΈμμ μΌλ‘ 컀μ§λλ°λ€κ° Self-Attention λ 쿼리μ ν€ νλ ¬μ λ΄μ (μκΈ° μμ κ³Ό κ³±μ΄λΌ λ³Ό μ μμ) νκΈ° λλ¬Έμ Computational Overheadκ° $N^2$μ΄ λλ€.
νμλ ViTλ₯Ό μ λ°μ μ±κ³΅μ΄λΌκ³ ννκ³ μΆλ€. λ³Έλ ViTμ μ€κ³ λͺ©μ μ λΉμ λΆμΌμ Convμ λν μμ‘΄μ ννΌνλ©΄μ, ν¨μ΄ν Self-Attentionμ λμ
ν΄ Scalabilty λ₯Ό μ΄μνλ κ²μ΄μλ€. Self-Attentionμ λμ
νλλ°λ μ±κ³΅νμ§λ§, μ¬μ ν λ€λ£° μ μλ μ΄λ―Έμ§ μ¬μ΄μ¦λ Taskμλ νκ³κ° λΆλͺ
νλ©° κ²°μ μ μΌλ‘ Self-Supervised Learning λ°©μμ λμ
νμ§ λͺ»νλ€. Scalabilty λΌλ λ¨μ΄μ μλ―Έλ₯Ό μκ°νλ©΄, λ°©κΈ λ§ν λΆλΆμμκΉμ§ νμ₯μ±μ΄ μμ΄μΌ μ€κ³ μλμ λΆν©νλ κ²°κ³ΌλΌκ³ μκ°νλ€.
Leave a comment