PyTorch Notes - Vision Transformer (ViT)

Posted by SpikeKing


A Transformer consists of an Encoder and a Decoder; the core components are Multi-Head Self-Attention (spatial fusion: mixing information across positions) and the Feed-Forward Neural Network (channel fusion: mixing information within each token).
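To make the two roles concrete, here is a minimal pre-norm encoder block sketch (the class name EncoderBlock and the layer sizes are illustrative, not from the original post): self-attention mixes information across tokens, while the feed-forward network transforms each token's channels independently.

import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    def __init__(self, dim, nhead, ffn_mult=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, nhead, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # channel fusion: the FFN is applied to each token independently
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_mult * dim), nn.GELU(),
                                 nn.Linear(ffn_mult * dim, dim))

    def forward(self, x):  # x: (bs, seq_len, dim)
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]  # spatial fusion: tokens attend to each other
        return x + self.ffn(self.norm2(x))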

Encoder-Decoder interaction: memory-based Multi-Head Cross-Attention, where the decoder attends to the encoder's output memory.

Position Embedding injects positional information into the token sequence.

The amount of training data required is inversely related to how much Inductive Bias is built in: with little inductive bias, the performance ceiling is high, but so is the amount of data needed.

Induction vs. deduction: an inductive bias brings human prior experience into the model design (e.g., the locality and translation equivariance of CNNs).

Typical Transformer configurations:

  • Encoder only: BERT, classification tasks, non-streaming tasks
  • Decoder only: the GPT series, language modeling, autoregressive generation, streaming tasks
  • Encoder-Decoder: machine translation, speech recognition

Vision Transformer (ViT):

  • DNN perspective: Image2Patch, Patch2Embedding
  • CNN perspective: a 2D convolution over the image (kernel size = stride = patch size)
  • Class Token Embedding: a learnable placeholder token prepended to the sequence
  • Position Embedding: interpolated at inference time when the resolution changes (see the sketch after this list)
  • Transformer Encoder
  • Classification Head
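Interpolation is needed because a new input resolution changes the number of patches, while the learned position embedding table has a fixed size. A minimal sketch (interpolate_pos_embed is an illustrative name; only the patch-grid part is interpolated, the CLS slot is kept as-is):

import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed, old_grid, new_grid):
    # pos_embed: (1, 1 + old_grid*old_grid, dim); slot 0 is the CLS position
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode='bicubic', align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)


pos = torch.randn(1, 1 + 14 * 14, 768)           # e.g. 224px / patch 16 -> 14x14 grid
print(interpolate_pos_embed(pos, 14, 24).shape)  # -> (1, 1 + 24*24, 768)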

Paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

Classification Token: serves the role of a Query, aggregating information from all patch tokens.

Linear Projection of Flattened Patches -> Patch + Position Embedding -> Transformer Encoder -> MLP Head

Patch + Position Embedding: patches are taken left to right, then top to bottom, and flattened into a sequence.
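This ordering can be checked directly with F.unfold (a tiny sketch; the 4x4 example values are mine):

import torch
import torch.nn.functional as F

x = torch.arange(16.).reshape(1, 1, 4, 4)  # pixel value == pixel index
patches = F.unfold(x, kernel_size=2, stride=2).transpose(2, 1)
print(patches)
# tensor([[[ 0.,  1.,  4.,  5.],    <- top-left patch
#          [ 2.,  3.,  6.,  7.],    <- top-right patch
#          [ 8.,  9., 12., 13.],    <- bottom-left patch
#          [10., 11., 14., 15.]]])  <- bottom-right patch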

Below, Image2Embedding is implemented by hand, while the TransformerEncoder uses PyTorch's built-in implementation.

ViT:

import torch
import torch.nn as nn
import torch.nn.functional as F


# step1 convert image to embedding vector sequence
def image2emb_naive(image, patch_size, weight):
    """
    Generate patches with unfold, then project them to embeddings.
    """
    # image shape: bs*ic*h*w
    # stride == patch_size, so patches do not overlap
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    # (bs, patch_depth, num_patch) with patch_depth = patch_size*patch_size*ic
    patch = patch.transpose(2, 1)  # (bs, num_patch, patch_depth)
    print(f'patch: {patch.shape}')
    patch_embedding = patch @ weight  # output embedding: (bs, num_patch, model_dim)
    print(f'patch_embedding: {patch_embedding.shape}')
    return patch_embedding


def image2emb_conv(image, kernel, stride):
    """
    Generate patch embeddings with a 2D convolution.
    """
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs*oc*oh*ow
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh * ow)).transpose(2, 1)
    print(f'patch_embedding: {patch_embedding.shape}')
    return patch_embedding


# test code for image2emb
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8  # embedding dim
max_num_token = 16
num_classes = 10
label = torch.randint(10, (bs,))
patch_depth = patch_size * patch_size * ic

# embeddings via explicit patching
torch.manual_seed(42)
image = torch.randn((bs, ic, image_h, image_w))  # random input image
weight = torch.randn((patch_depth, model_dim))  # patch_depth -> model_dim (output channels)
print(f'weight: {weight.shape}')
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(f'patch_embedding_naive:\n{patch_embedding_naive}')

# embeddings via 2D convolution
# kernel shape: oc*ic*k_h*k_w
kernel = weight.transpose(1, 0).reshape((model_dim, ic, patch_size, patch_size))
patch_embedding_conv = image2emb_conv(image, kernel, stride=patch_size)
print(f'patch_embedding_conv:\n{patch_embedding_conv}')
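
# sanity check: both paths produce the same embeddings, because F.unfold
# flattens each patch in (ic, k_h, k_w) order, matching the conv kernel layout
assert torch.allclose(patch_embedding_naive, patch_embedding_conv, atol=1e-5)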


# step2 prepend CLS token embedding
cls_token_embedding = torch.randn((bs, 1, model_dim), requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
print(f'token_embedding: {token_embedding.shape}')


# step3 add position embedding
position_embedding_table = torch.randn((max_num_token, model_dim), requires_grad=True)
seq_len = token_embedding.shape[1]
# broadcast the position embedding table over the batch
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding = token_embedding + position_embedding
print(f'token_embedding: {token_embedding.shape}')


# step4 pass embedding to Transformer Encoder
# batch_first=True so the input layout is (bs, seq_len, model_dim)
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)


# step5 do classification
cls_token_output = encoder_output[:, 0, :]  # take the CLS token's output
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(f'loss: {loss}')
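
The five steps can also be packed into a single reusable module. The sketch below is my own wrap-up (the class name SimpleViT and the default hyperparameters are illustrative), reusing the conv-based patch embedding from above:

class SimpleViT(nn.Module):
    """A minimal sketch packing the five steps above into one module."""
    def __init__(self, ic=3, patch_size=4, model_dim=8, max_num_token=16,
                 num_classes=10, nhead=8, num_layers=6):
        super().__init__()
        # Image2Patch + Patch2Embedding as one conv (kernel size = stride = patch size)
        self.patch_embed = nn.Conv2d(ic, model_dim, kernel_size=patch_size,
                                     stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, max_num_token, model_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead,
                                                   batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, image):  # image: (bs, ic, h, w)
        x = self.patch_embed(image).flatten(2).transpose(2, 1)  # (bs, num_patch, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)          # prepend CLS token
        x = x + self.pos_embed[:, :x.shape[1]]  # add position embedding
        x = self.encoder(x)
        return self.head(x[:, 0, :])            # classify from the CLS token


model = SimpleViT()
print(model(image).shape)  # -> (bs, num_classes) == (1, 10)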
