Printing a Model's FLOPs in PyTorch (torchstat)

Posted by 梁小憨憨

Recently, while reading papers, I noticed that some authors report their model's floating-point operation count. I wanted to learn how to do this myself, so I am writing it down for future reference.

Distinguishing FLOPs from FLOPS

FLOPS (all uppercase) is short for floating point operations per second: the number of floating-point operations performed each second, i.e. computation speed. It is a metric for hardware performance.

FLOPs (note the lowercase s, which marks the plural) is short for floating point operations: the total number of floating-point operations, i.e. the amount of computation. It can be used to measure the complexity of an algorithm or model.

In deep learning it is FLOPs, the amount of computation, that we use to measure the complexity of an algorithm or model.
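As a concrete illustration of how FLOPs are counted, the sketch below hand-counts the operations of a fully connected layer y = Wx. The convention (MAdd counts multiplies and adds separately; Flops counts one operation per multiply) is inferred from the fc row of the torchstat report later in this post, and the helper names are made up for illustration:

def linear_madd(n, m):
    # m dot products, each with n multiplies and n - 1 adds (no bias)
    return (2 * n - 1) * m

def linear_flops(n, m):
    # torchstat appears to count one Flop per multiply
    return n * m

print(linear_madd(30, 3))   # 177 -> matches the fc row in the report below
print(linear_flops(30, 3))  # 90  -> matches the fc row in the report below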

Installing torchstat

pip install torchstat

Then define a model and pass its input shape to stat. A complete example:
import torch
import torch.nn as nn
from torchstat import stat



class Corr_CNN(nn.Module):
    def __init__(self, Filters, channels, dropoutRate_1, dropoutRate_2, n_classes):
        super(Corr_CNN, self).__init__()

        self.conv_1 = nn.Conv2d(
            in_channels=1, 
            out_channels=Filters,
            kernel_size=(1, channels), 
            bias=False
        )

        self.activate_1 = nn.ReLU()

        self.bn_1 = nn.BatchNorm2d(num_features=Filters)

        self.dropout_1 = nn.Dropout(p=dropoutRate_1)

        self.conv_2 = nn.Conv2d(
            in_channels=Filters,
            out_channels=Filters,
            kernel_size=(channels, 1),
            bias=False
        )

        self.activate_2 = nn.ReLU()

        self.bn_2 = nn.BatchNorm2d(num_features=Filters)

        self.dropout_2 = nn.Dropout(p=dropoutRate_2)

        self.fc = nn.Linear(
            in_features=Filters,
            out_features=n_classes,
        )

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # input shape (batch_size, C, C)
        if len(x.shape) != 4:
            x = torch.unsqueeze(x, 1)
        # input shape (batch_size, 1, C, C)
        x = self.conv_1(x)
        x = self.activate_1(x)
        x = self.bn_1(x)
        x = self.dropout_1(x)
        x = self.conv_2(x)
        x = self.activate_2(x)
        x = self.bn_2(x)
        x = self.dropout_2(x)
        x = x.view(x.size()[0], -1)  # flatten to (batch_size, Filters)
        x = self.fc(x)
        out = self.softmax(x)

        return out


###============================ Initialization parameters ============================###
Filters = 30
channels = 62
dropoutRate_1 = 0.3
dropoutRate_2 = 0.3
n_classes = 3

def main():
    x = torch.randn(32, channels, channels)  # a batch of 32 channels-by-channels inputs
    model = Corr_CNN(Filters, channels, dropoutRate_1, dropoutRate_2, n_classes)
    out = model(x)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    stat(model, (1, channels, channels))

if __name__ == "__main__":
    main()
Running the script prints the report below; torchstat warns about layers it cannot count (Dropout, Softmax) and skips them:

[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[Flops]: Softmax is not supported!
[Memory]: Softmax is not supported!
      module name  input shape output shape   params memory(MB)       MAdd      Flops  MemRead(B)  MemWrite(B) duration[%]  MemR+W(B)
0          conv_1    1  62  62   30  62   1   1860.0       0.01  228,780.0  115,320.0     22816.0       7440.0      26.41%    30256.0
1      activate_1   30  62   1   30  62   1      0.0       0.01    1,860.0    1,860.0      7440.0       7440.0       4.73%    14880.0
2            bn_1   30  62   1   30  62   1     60.0       0.01    7,440.0    3,720.0      7680.0       7440.0       9.57%    15120.0
3       dropout_1   30  62   1   30  62   1      0.0       0.01        0.0        0.0         0.0          0.0       2.73%        0.0
4          conv_2   30  62   1   30   1   1  55800.0       0.00  111,570.0   55,800.0    230640.0        120.0      34.97%   230760.0
5      activate_2   30   1   1   30   1   1      0.0       0.00       30.0       30.0       120.0        120.0       3.87%      240.0
6            bn_2   30   1   1   30   1   1     60.0       0.00      120.0       60.0       360.0        120.0       6.30%      480.0
7       dropout_2   30   1   1   30   1   1      0.0       0.00        0.0        0.0         0.0          0.0       2.03%        0.0
8              fc           30            3     93.0       0.00      177.0       90.0       492.0         12.0       6.03%      504.0
9         softmax            3            3      0.0       0.00        8.0        0.0         0.0          0.0       3.33%        0.0
total                                        57873.0       0.03  349,985.0  176,880.0         0.0          0.0      99.99%   292240.0
=====================================================================================================================================
Total params: 57,873
-------------------------------------------------------------------------------------------------------------------------------------
Total memory: 0.03MB
Total MAdd: 349.98KMAdd
Total Flops: 176.88KFlops
Total MemR+W: 285.39KB
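As a sanity check (a hand calculation, not torchstat output), the conv_1 row can be reproduced from the layer's shapes alone:

# conv_1 output shape is 30 x 62 x 1; each output element is a dot
# product over a 1 x 62 kernel (62 multiplies, 61 adds, no bias)
outputs = 30 * 62 * 1         # 1860 output elements
k = 62                        # multiplies per output element
print(outputs * k)            # 115320 -> matches the Flops column
print(outputs * (2 * k - 1))  # 228780 -> matches the MAdd column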

assert len(inp.size()) == 2 and len(out.size()) == 2

Sometimes running stat raises an AssertionError at the line above. The reason is that the input to a fully connected layer in your network is not necessarily two-dimensional; when it has more than two dimensions, the assertion fails.
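A minimal sketch of a model that triggers this (the class name is hypothetical; stat prepends a batch dimension, so the Linear layer below receives a 3-D tensor):

import torch.nn as nn
from torchstat import stat

class TokenClassifier(nn.Module):  # hypothetical example
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        # x arrives as (1, 10, 16); nn.Linear handles extra leading
        # dims, but torchstat's Linear handler asserts 2-D shapes
        return self.fc(x)

stat(TokenClassifier(), (10, 16))  # raises the AssertionError above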

The fix is to relax the assertion to allow higher-dimensional inputs; make the same change in every file of the torchstat source where this assert appears:

assert len(inp.size()) >= 2 and len(out.size()) >= 2

AttributeError: 'tuple' object has no attribute 'size'

Sometimes a network contains an LSTM module, so the forward pass includes a line like x, (h_1, c_1) = self.lstm_1(x). The second item that nn.LSTM returns is a tuple, and torchstat's hook receives the whole output, which causes the error above. Since we only need the first item of the LSTM output (x), we can change the hook code in the torchstat source (the part that records each module's output shape) to index output[0]:

# in torchstat's forward hook, record the shape of output[0]
# (the LSTM's sequence output) instead of the raw tuple
module.output_shape = torch.from_numpy(
    np.array(output[0].size()[1:], dtype=np.int32))

# likewise, compute inference memory from output[0]
inference_memory = 1
for s in output[0].size()[1:]:
    inference_memory *= s
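One caveat: output[0] also indexes into plain tensor outputs and drops a dimension, so if the network mixes LSTM and ordinary layers it is safer to guard the change. A hedged variant of the same hook code:

# take the first item only when the module returns a tuple (e.g. nn.LSTM)
out_tensor = output[0] if isinstance(output, tuple) else output

module.output_shape = torch.from_numpy(
    np.array(out_tensor.size()[1:], dtype=np.int32))

inference_memory = 1
for s in out_tensor.size()[1:]:
    inference_memory *= s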
