PyTorch: Printing a Model's FLOPs (torchstat)
Posted by 梁小憨憨
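Compiled by the editors of 小常识网 (cha138.com). This post covers how to print a PyTorch model's FLOPs with torchstat.
While reading papers recently, I noticed that some authors report the number of floating-point operations of their models. I looked into how to compute it myself and am writing it down here for future reference.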
Contents
- Distinguishing FLOPs and FLOPS
- Installing torchstat
- assert len(inp.size()) == 2 and len(out.size()) == 2
- AttributeError: 'tuple' object has no attribute 'size'
Distinguishing FLOPs and FLOPS
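FLOPS: note the all-caps spelling. It is short for floating point operations per second, the number of floating-point operations executed per second. Think of it as computing speed; it is a measure of hardware performance.
FLOPs: note the lowercase s. It is short for floating point operations (the s marks the plural), the number of floating-point operations. Think of it as the amount of computation; it can be used to measure the complexity of an algorithm or model.
In deep learning the quantity we use is FLOPs, the amount of computation, as the measure of algorithm/model complexity.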
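FLOPs is an amount of work and FLOPS is a rate. Dividing the first by the second therefore gives an idealized lower bound on run time. The minimal sketch below uses made-up numbers; neither value comes from this post.

```python
# Hypothetical numbers, for illustration only.
model_flops = 2e9            # FLOPs: work for one forward pass (a property of the model)
device_flops_per_s = 1e12    # FLOPS: peak throughput (a property of the hardware)

# Idealized time = amount of work / speed of the hardware
ideal_seconds = model_flops / device_flops_per_s
print(f"{ideal_seconds * 1e3:.1f} ms")  # 2.0 ms
```

Real run time is normally longer, since hardware rarely sustains its peak rate and memory access also takes time.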
Installing torchstat
```shell
pip install torchstat
```
Example: define a model and pass it to `stat` together with the input size.
```python
import torch
import torch.nn as nn
from torchstat import stat


class Corr_CNN(nn.Module):
    def __init__(self, Filters, channels, dropoutRate_1, dropoutRate_2, n_classes):
        super(Corr_CNN, self).__init__()
        # kernel spans a whole row: (batch, 1, C, C) -> (batch, Filters, C, 1)
        self.conv_1 = nn.Conv2d(
            in_channels=1,
            out_channels=Filters,
            kernel_size=(1, channels),
            bias=False
        )
        self.activate_1 = nn.ReLU()
        self.bn_1 = nn.BatchNorm2d(num_features=Filters)
        self.dropout_1 = nn.Dropout(p=dropoutRate_1)
        # kernel spans a whole column: (batch, Filters, C, 1) -> (batch, Filters, 1, 1)
        self.conv_2 = nn.Conv2d(
            in_channels=Filters,
            out_channels=Filters,
            kernel_size=(channels, 1),
            bias=False
        )
        self.activate_2 = nn.ReLU()
        self.bn_2 = nn.BatchNorm2d(num_features=Filters)
        self.dropout_2 = nn.Dropout(p=dropoutRate_2)
        self.fc = nn.Linear(
            in_features=Filters,
            out_features=n_classes,
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # input shape (batch_size, C, C)
        if len(x.shape) != 4:  # compare values with !=, not identity with `is not`
            x = torch.unsqueeze(x, 1)
        # input shape (batch_size, 1, C, C)
        x = self.conv_1(x)
        x = self.activate_1(x)
        x = self.bn_1(x)
        x = self.dropout_1(x)
        x = self.conv_2(x)
        x = self.activate_2(x)
        x = self.bn_2(x)
        x = self.dropout_2(x)
        x = x.view(x.size()[0], -1)  # Flatten: (batch_size, Filters, 1, 1) -> (batch_size, Filters)
        x = self.fc(x)
        out = self.softmax(x)
        return out


###============================ Initialization parameters ============================###
Filters = 30
channels = 62
dropoutRate_1 = 0.3
dropoutRate_2 = 0.3
n_classes = 3


def main():
    x = torch.randn(32, channels, channels)  # a batch of 32 samples, each C x C
    model = Corr_CNN(Filters, channels, dropoutRate_1, dropoutRate_2, n_classes)
    out = model(x)
    print('===============================================================')
    print('out', out.shape)  # torch.Size([32, 3])
    print('model', model)
    # shape of a single sample, without the batch dimension
    # (compare the "input shape" of conv_1 in the output below)
    stat(model, (1, channels, channels))


if __name__ == "__main__":
    main()
```
Running the script prints:
```text
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[Flops]: Softmax is not supported!
[Memory]: Softmax is not supported!
      module name  input shape  output shape   params  memory(MB)       MAdd      Flops  MemRead(B)  MemWrite(B)  duration[%]  MemR+W(B)
0          conv_1    1  62  62    30  62   1   1860.0        0.01  228,780.0  115,320.0     22816.0       7440.0       26.41%    30256.0
1      activate_1   30  62   1    30  62   1      0.0        0.01    1,860.0    1,860.0      7440.0       7440.0        4.73%    14880.0
2            bn_1   30  62   1    30  62   1     60.0        0.01    7,440.0    3,720.0      7680.0       7440.0        9.57%    15120.0
3       dropout_1   30  62   1    30  62   1      0.0        0.01        0.0        0.0         0.0          0.0        2.73%        0.0
4          conv_2   30  62   1    30   1   1  55800.0        0.00  111,570.0   55,800.0    230640.0        120.0       34.97%   230760.0
5      activate_2   30   1   1    30   1   1      0.0        0.00       30.0       30.0       120.0        120.0        3.87%      240.0
6            bn_2   30   1   1    30   1   1     60.0        0.00      120.0       60.0       360.0        120.0        6.30%      480.0
7       dropout_2   30   1   1    30   1   1      0.0        0.00        0.0        0.0         0.0          0.0        2.03%        0.0
8              fc           30             3     93.0        0.00      177.0       90.0       492.0         12.0        6.03%      504.0
9         softmax            3             3      0.0        0.00        8.0        0.0         0.0          0.0        3.33%        0.0
total                                         57873.0        0.03  349,985.0  176,880.0         0.0          0.0       99.99%   292240.0
=====================================================================================================================================
Total params: 57,873
-------------------------------------------------------------------------------------------------------------------------------------
Total memory: 0.03MB
Total MAdd: 349.98KMAdd
Total Flops: 176.88KFlops
Total MemR+W: 285.39KB
```
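As a sanity check, the params, MAdd, Flops and memory columns of this table can be recomputed by hand. The counting rules in the sketch below were inferred from the table itself, not copied from torchstat's source; they reproduce every entry for this model. All helper names are hypothetical.

- Bias-free convolution: Flops is one multiplication per kernel element per output element. MAdd adds the kernel_mul - 1 additions on top of that.
- Linear: Flops is in_features x out_features. MAdd is out_features x (2 x in_features - 1).
- ReLU: one operation per element.
- Affine BatchNorm: 4 MAdd and 2 Flops per element.
- Memory: with float32, MemRead is (input + parameters) x 4 bytes and MemWrite is output x 4 bytes.

```python
from math import prod

# Settings from the example: Filters=30, channels=62, n_classes=3 (batch dim omitted)
F, C, N = 30, 62, 3

# Each helper returns (params, MAdd, Flops); rules inferred from the table above.
def conv_counts(in_ch, k_h, k_w, out_shape):
    """Bias-free Conv2d with groups=1."""
    kernel_mul = k_h * k_w * in_ch       # multiplications per output element
    kernel_add = kernel_mul - 1          # additions per output element (no bias)
    n_out = prod(out_shape)              # number of output elements
    params = out_shape[0] * in_ch * k_h * k_w
    return params, n_out * (kernel_mul + kernel_add), n_out * kernel_mul

def linear_counts(in_f, out_f):
    """nn.Linear with bias, 2-D input with batch size 1."""
    return in_f * out_f + out_f, out_f * (2 * in_f - 1), in_f * out_f

def relu_counts(shape):
    n = prod(shape)
    return 0, n, n

def bn_counts(num_features, shape):
    n = prod(shape)
    return 2 * num_features, 4 * n, 2 * n   # affine BN: weight + bias parameters

def softmax_counts(n):
    # n exps + (n - 1) additions + n divisions; Flops reported as 0 ("not supported")
    return 0, n + (n - 1) + n, 0

# Dropout layers are 0 in every column and are left out.
layers = {
    "conv_1": conv_counts(1, 1, C, (F, C, 1)),
    "activate_1": relu_counts((F, C, 1)),
    "bn_1": bn_counts(F, (F, C, 1)),
    "conv_2": conv_counts(F, C, 1, (F, 1, 1)),
    "activate_2": relu_counts((F, 1, 1)),
    "bn_2": bn_counts(F, (F, 1, 1)),
    "fc": linear_counts(F, N),
    "softmax": softmax_counts(N),
}
for name, (p, m, f) in layers.items():
    print(f"{name:<11} params={p:<6} MAdd={m:<7} Flops={f}")

total_params = sum(p for p, _, _ in layers.values())
total_madd = sum(m for _, m, _ in layers.values())
total_flops = sum(f for _, _, f in layers.values())
print(total_params, total_madd, total_flops)  # 57873 349985 176880

# Memory traffic (float32 = 4 bytes): read input + params, write output.
# Softmax is reported as 0 ("not supported") and is left out.
shapes = [  # (name, input shape, output shape)
    ("conv_1", (1, C, C), (F, C, 1)),
    ("activate_1", (F, C, 1), (F, C, 1)),
    ("bn_1", (F, C, 1), (F, C, 1)),
    ("conv_2", (F, C, 1), (F, 1, 1)),
    ("activate_2", (F, 1, 1), (F, 1, 1)),
    ("bn_2", (F, 1, 1), (F, 1, 1)),
    ("fc", (F,), (N,)),
]
total_rw = sum(4 * (prod(i) + layers[n][0]) + 4 * prod(o) for n, i, o in shapes)
print(total_rw, total_rw / 1024)  # 292240 285.390625
```

The summary lines use different prefixes. The K in 176.88KFlops means 1000 (176880 / 1000). The KB in 285.39KB means 1024 bytes (292240 / 1024).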
assert len(inp.size()) == 2 and len(out.size()) == 2
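Running `stat` sometimes fails with the assertion above. The input to a fully connected (nn.Linear) layer in your network is not necessarily 2-D; it can have more than two dimensions, and then this assertion fails. The fix is to relax the assertion in the torchstat source. The same line appears in three places; change each of them to:
```python
assert len(inp.size()) >= 2 and len(out.size()) >= 2
```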
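Relaxing the assertion only lets the check pass. It does not by itself guarantee that the count reported for such a layer is right. nn.Linear is applied to the last dimension, so for an input of shape (*, in_features) every leading position does in_features x out_features multiplications. The hand count below can be used to cross-check the result. The function name and the example shapes are hypothetical, and bias additions are not counted, matching the Flops column above (90 for the 30 -> 3 fc layer).

```python
from math import prod

def linear_flops(input_shape, out_features):
    """Multiplications done by nn.Linear on an input of shape (*, in_features).

    Every leading position is an independent matrix-vector product;
    bias additions are not counted.
    """
    *leading, in_features = input_shape
    return prod(leading) * in_features * out_features

# Hypothetical 3-D input: batch 1, sequence length 10, 64 -> 32 features
print(linear_flops((1, 10, 64), 32))  # 20480
# The usual 2-D case, e.g. the fc layer of the example model
print(linear_flops((1, 30), 3))       # 90
```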
AttributeError: 'tuple' object has no attribute 'size'
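Sometimes the network contains an LSTM module, so the forward pass contains a line such as `x, (h_1, c_1) = self.lstm_1(x)`. An LSTM returns a tuple rather than a single tensor, and the second item of that tuple is itself a tuple of the hidden and cell states. This is what causes the error above. We only need the first item of the LSTM output, `x`, so we can make the following change in the torchstat source: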
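```python
# Excerpt of the torchstat source (np and torch are imported in that file).
# nn.LSTM returns (output, (h_n, c_n)), so take its first item. Only unwrap tuples:
# indexing a plain tensor with [0] would select the first sample of the batch and
# drop a dimension from every other layer's shape. out_tensor is a new local name.
out_tensor = output[0] if isinstance(output, tuple) else output
module.output_shape = torch.from_numpy(
    np.array(out_tensor.size()[1:], dtype=np.int32))
inference_memory = 1
for s in out_tensor.size()[1:]:
    inference_memory *= s
```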
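torchstat reads the shape of whatever a module returns, so tuple outputs need special handling. The sketch below checks both halves of that in plain PyTorch. An nn.LSTM really returns a tuple. For an ordinary layer, an unconditional `[0]` selects the first sample and drops a dimension. The helper `first_tensor` and the layer sizes are hypothetical.

```python
import torch
import torch.nn as nn

def first_tensor(output):
    """Hypothetical helper: unwrap tuple outputs (e.g. nn.LSTM), pass tensors through."""
    return output[0] if isinstance(output, tuple) else output

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
lstm_out = lstm(torch.randn(1, 5, 8))            # (output, (h_n, c_n))
print(type(lstm_out).__name__)                   # tuple
print(tuple(first_tensor(lstm_out).size()[1:]))  # (5, 16)

conv = nn.Conv2d(1, 4, kernel_size=3)
conv_out = conv(torch.randn(1, 1, 8, 8))         # plain tensor of shape (1, 4, 6, 6)
print(tuple(first_tensor(conv_out).size()[1:]))  # (4, 6, 6): correct per-sample shape
print(tuple(conv_out[0].size()[1:]))             # (6, 6): the channel dimension is lost
```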