SwinTransformer 模型转换：PyTorch 模型转 Keras
Posted AI浩
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了《SwinTransformer 模型转换：PyTorch 模型转 Keras》相关的知识，希望对你有一定的参考价值。
SwinTransformer 官方模型只有 PyTorch 版本，没有 Keras 版本，需要转换后才能使用。
这篇文章记录如何实现模型的转换：
新建 model.py，插入如下代码：
这段代码是 SwinTransformer 模型的 Keras 实现。
import tensorflow as tf
from tensorflow.keras import Model, layers, initializers
import numpy as np
class PatchEmbed(layers.Layer):
    """2D image to patch embedding.

    Splits an NHWC image into non-overlapping ``patch_size x patch_size``
    patches with a strided convolution and flattens the patch grid into a
    token sequence.

    Args:
        patch_size: Side length of each square patch (also the conv stride).
        embed_dim: Number of output channels per token.
        norm_layer: Optional normalization layer class, instantiated as
            ``norm_layer(epsilon=1e-6, name="norm")``. When falsy, an identity
            (``Activation('linear')``) is used instead.
    """

    def __init__(self, patch_size=4, embed_dim=96, norm_layer=None):
        super(PatchEmbed, self).__init__()
        self.embed_dim = embed_dim
        self.patch_size = (patch_size, patch_size)
        self.norm = norm_layer(epsilon=1e-6, name="norm") if norm_layer else layers.Activation('linear')
        self.proj = layers.Conv2D(filters=embed_dim, kernel_size=patch_size,
                                  strides=patch_size, padding='SAME',
                                  kernel_initializer=initializers.LecunNormal(),
                                  bias_initializer=initializers.Zeros(),
                                  name="proj")

    def call(self, x, **kwargs):
        """Embed ``x`` into patch tokens.

        Args:
            x: Image tensor indexed as (B, H, W, C); H and W must be statically
                known Python ints.

        Returns:
            Tuple ``(tokens, H, W)``: ``tokens`` has shape (B, H*W, embed_dim)
            and ``H``/``W`` are the dimensions of the patch grid.
        """
        _, H, W, _ = x.shape
        ph, pw = self.patch_size
        # Multi-scale support: pad the bottom/right edges so H and W become
        # multiples of the patch size. `(-H) % ph` is 0 when H is already a
        # multiple, so no spurious extra row/column of patches is added when
        # only the other dimension needs padding.
        pad_h = (-H) % ph
        pad_w = (-W) % pw
        if pad_h or pad_w:
            # tf.pad needs one [before, after] pair per dimension of the 4-D input.
            x = tf.pad(x, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])

        # Downsample by patch_size: [B, H, W, C] -> [B, H/ph, W/pw, embed_dim]
        x = self.proj(x)
        _, H, W, C = x.shape
        # [B, H, W, C] -> [B, H*W, C]; the batch size is read dynamically so an
        # unknown (None) batch dimension is supported.
        x = tf.reshape(x, [tf.shape(x)[0], -1, C])
        x = self.norm(x)
        return x, H, W
def window_partition(x, window_size: int):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: Tensor of shape (B, H, W, C). H, W and C must be statically known,
            and H, W must be multiples of ``window_size`` (otherwise the
            reshape below cannot preserve the element count).
        window_size (int): Window side length (M).

    Returns:
        windows: Tensor of shape (num_windows*B, window_size, window_size, C).
    """
    _, H, W, C = x.shape
    # Batch is inferred (-1) so this also works when B is None.
    # reshape: [B, H, W, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    x = tf.reshape(x, [-1, H // window_size, window_size, W // window_size, window_size, C])
    # transpose: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    # reshape: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = tf.reshape(x, [-1, window_size, window_size, C])
    return windows
def window_reverse(windows, window_size: int, H: int, W: int):
    """Reassemble windows into a feature map (inverse of ``window_partition``).

    Args:
        windows: Tensor of shape (num_windows*B, window_size, window_size, C);
            the channel dimension C must be statically known.
        window_size (int): Window size (M).
        H (int): Height of the feature map.
        W (int): Width of the feature map.

    Returns:
        x: Tensor of shape (B, H, W, C).
    """
    C = windows.shape[-1]
    # The batch size B is inferred by tf.reshape (-1) rather than computed by
    # float division of the static leading dimension, which fails when that
    # dimension is None.
    # reshape: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = tf.reshape(windows, [-1, H // window_size, W // window_size, window_size, window_size, C])
    # transpose: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    # reshape: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = tf.reshape(x, [-1, H, W, C])
    return x
class PatchMerging(layers.Layer):
    """Patch merging (downsampling) layer.

    Concatenates the four tokens of every 2x2 neighbourhood (giving 4*C
    channels), normalizes them, and projects them linearly to 2*C channels.
    This halves the spatial resolution (rounding up for odd sizes).

    Args:
        dim: Number of input channels C.
        norm_layer: Normalization layer class, instantiated as
            ``norm_layer(epsilon=1e-6, name="norm")``.
        name: Optional layer name.
    """

    def __init__(self, dim: int, norm_layer=layers.LayerNormalization, name=None):
        super(PatchMerging, self).__init__(name=name)
        self.dim = dim
        self.reduction = layers.Dense(2 * dim,
                                      use_bias=False,
                                      kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
                                      name="reduction")
        self.norm = norm_layer(epsilon=1e-6, name="norm")

    def call(self, x, H, W):
        """Merge 2x2 token neighbourhoods.

        Args:
            x: Tokens of shape [B, H*W, C].
            H: Height of the token grid (Python int).
            W: Width of the token grid (Python int).

        Returns:
            Tensor of shape [B, ceil(H/2)*ceil(W/2), 2*C].
        """
        _, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        # Dynamic batch size so an unknown (None) batch dimension works.
        batch = tf.shape(x)[0]
        x = tf.reshape(x, [batch, H, W, C])

        # Pad ONLY the odd dimension(s) by one row/column. Padding both
        # dimensions whenever either is odd would turn an even dimension odd,
        # so the 0::2 and 1::2 slices below would differ in size and tf.concat
        # would fail.
        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            x = tf.pad(x, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = tf.concat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = tf.reshape(x, [batch, -1, 4 * C])  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]
        return x
class MLP(layers.Layer):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Computes fc1 -> GELU -> dropout -> fc2 -> dropout. The hidden width is
    ``int(in_features * mlp_ratio)``, and the output width equals ``in_features``.

    Args:
        in_features: Input (and output) channel count.
        mlp_ratio: Hidden width multiplier. Default: 4.0
        drop: Dropout rate applied after each stage. Default: 0.0
        name: Optional layer name.
    """

    # Class-level initializer templates, kept for backward compatibility and
    # subclass overriding. Each Dense layer receives its own copy via `_fresh`.
    # Per the TF 2.10 release notes, an unseeded Keras initializer instance
    # returns the same values every time it is called. A single shared
    # instance would therefore give every MLP of the same width identical
    # starting weights.
    k_ini = initializers.TruncatedNormal(stddev=0.02)
    b_ini = initializers.Zeros()

    @staticmethod
    def _fresh(initializer):
        """Return a new, independent initializer with the same config."""
        return initializer.__class__.from_config(initializer.get_config())

    def __init__(self, in_features, mlp_ratio=4.0, drop=0., name=None):
        super(MLP, self).__init__(name=name)
        self.fc1 = layers.Dense(int(in_features * mlp_ratio), name="fc1",
                                kernel_initializer=self._fresh(self.k_ini),
                                bias_initializer=self._fresh(self.b_ini))
        self.act = layers.Activation("gelu")
        self.fc2 = layers.Dense(in_features, name="fc2",
                                kernel_initializer=self._fresh(self.k_ini),
                                bias_initializer=self._fresh(self.b_ini))
        self.drop = layers.Dropout(drop)

    def call(self, x, training=None):
        """Apply the MLP; ``training`` toggles the dropout layers."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x, training=training)
        x = self.fc2(x)
        x = self.drop(x, training=training)
        return x
class WindowAttention(layers.Layer):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.

    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
        attn_drop_ratio (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop_ratio (float, optional): Dropout ratio of output. Default: 0.0
        name (str, optional): Layer name. Default: None
    """

    # Class-level initializer templates, kept for backward compatibility and
    # subclass overriding. Each Dense layer receives its own copy via `_fresh`.
    # Per the TF 2.10 release notes, an unseeded Keras initializer instance
    # returns the same values every time it is called. A single shared
    # instance would give every same-width attention layer identical
    # starting weights.
    k_ini = initializers.GlorotUniform()
    b_ini = initializers.Zeros()

    @staticmethod
    def _fresh(initializer):
        """Return a new, independent initializer with the same config."""
        return initializer.__class__.from_config(initializer.get_config())

    def __init__(self,
                 dim,
                 window_size,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.,
                 name=None):
        super(WindowAttention, self).__init__(name=name)
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias, name="qkv",
                                kernel_initializer=self._fresh(self.k_ini),
                                bias_initializer=self._fresh(self.b_ini))
        self.attn_drop = layers.Dropout(attn_drop_ratio)
        self.proj = layers.Dense(dim, name="proj",
                                 kernel_initializer=self._fresh(self.k_ini),
                                 bias_initializer=self._fresh(self.b_ini))
        self.proj_drop = layers.Dropout(proj_drop_ratio)

    def build(self, input_shape):
        """Create the relative position bias table and its lookup index."""
        # Learnable parameter table of relative position bias:
        # [(2*Mh-1) * (2*Mw-1), nH]
        self.relative_position_bias_table = self.add_weight(
            shape=[(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads],
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
            dtype=tf.float32,
            name="relative_position_bias_table"
        )

        # Pairwise relative (row, col) offsets between all tokens of one window.
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = np.reshape(coords, [2, -1])  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = np.transpose(relative_coords, [1, 2, 0])  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Row-major flattening of the 2-D offset into a single table index.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # Cast explicitly: np.arange yields int64 on most platforms, and
        # tf.Variable(..., dtype=tf.int32) will not convert an int64 tensor.
        relative_position_index = relative_coords.sum(-1).astype(np.int32)  # [Mh*Mw, Mh*Mw]

        # Fixed lookup index, stored as a non-trainable variable.
        self.relative_position_index = tf.Variable(tf.convert_to_tensor(relative_position_index),
                                                   trainable=False,
                                                   dtype=tf.int32,
                                                   name="relative_position_index")

    def call(self, x, mask=None, training=None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
            training: whether training mode
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]; the leading dim is
        # inferred (-1) in every reshape so an unknown batch size works.
        _, N, C = x.shape
        head_dim = C // self.num_heads

        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        qkv = self.qkv(x)
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        qkv = tf.reshape(qkv, [-1, N, 3, self.num_heads, head_dim])
        # transpose: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        # each: [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # q @ k^T -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        attn = tf.matmul(a=q, b=k, transpose_b=True) * self.scale

        # relative_position_bias: gather [Mh*Mw*Mh*Mw, nH] -> reshape [Mh*Mw, Mh*Mw, nH]
        relative_position_bias = tf.gather(self.relative_position_bias_table,
                                           tf.reshape(self.relative_position_index, [-1]))
        relative_position_bias = tf.reshape(relative_position_bias,
                                            [self.window_size[0] * self.window_size[1],
                                             self.window_size[0] * self.window_size[1],
                                             -1])
        relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + tf.expand_dims(relative_position_bias, 0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn(reshape): [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask(expand_dims): [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = tf.reshape(attn, [-1, nW, self.num_heads, N, N]) + tf.expand_dims(tf.expand_dims(mask, 1), 0)
            attn = tf.reshape(attn, [-1, self.num_heads, N, N])

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        # attn @ v -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        x = tf.matmul(attn, v)
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        x = tf.transpose(x, [0, 2, 1, 3])
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = tf.reshape(x, [-1, N, C])

        x = self.proj(x)
        x = self.proj_drop(x, training=training)
        return x
class SwinTransformerBlock(layers.Layer):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., name=None):
super().__init__(name=name)
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = layers.LayerNormalization(epsilon=1e-6, name="norm1")
self.attn = WindowAttention(dim,
window_size=(window_size, window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop_ratio=attn_drop,
以上是关于《SwinTransformer 模型转换：PyTorch 模型转 Keras》的主要内容，如果未能解决你的问题，请参考以下文章
PyTorch笔记 - SwinTransformer的原理与实现