基于chatGPT设计卷积神经网络

Posted littletomatodonkey

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于chatGPT设计卷积神经网络相关的知识,希望对你有一定的参考价值。

1. 简介

本文主要介绍基于chatGPT,设计一个针对骁龙855芯片设计的友好型神经网络。

提问->跑通总共花了5min左右,最终得到的网络在Cifar100数据集上与ResNet18的精度对比如下。

模型flopsparamstrain acc1/5test acc1/5
ResNet18(timm)1.8211.18~98%/100%63.11%/86.11%
CNN-GPT1.516.66~98%/100%68.44%/89.14%

此外,GPT生成的模型训练速度更快,ResNet18才训练170个epoch,CNN-GPT已经完整地训练完了,具体的模型结构在文末,欢迎尝试使用~

个人认为比较有意思的一点是,GPT生成的模型并没有在stem block中就进行降采样,而是卷了一会之后才开始,这个与之前在文本任务中的经验结果类似,特征图越大,会避免信息损失。
基于chatGPT生成模型结构的方式或许在某些情况下可以解决我们的问题,但是基础的神经网络知识还是需要的,否则它给出的答案有些小bug,我们也不知道怎么修改。
模型训练结果与超参设置强相关,在这里仅设置默认的训练策略作为对比,根据训练结果看出,两个模型均过拟合严重。而且,不管是ResNet18或者CNN-GPT,均与Cifar100数据集的SOTA精度相差很远,此处需要注意。

2. 具体步骤

Step1 骁龙855芯片上可以正常运行的算子

基于上述问题,chatGPT给出的回答如下所示。

以下是在骁龙855芯片上比较友好的神经网络算子:

1. Conv2d
2. Depthwise Conv2d
3. Maxpool2d
4. Avgpool2d
5. BatchNorm2d
6. ReLU
7. Flatten
8. Linear

整体看起来还算比较靠谱。

Step2 构建一个模型大小在10M左右的模型

针对上面的问题,chatGPT首先给出了思路。

如果想设计一个在骁龙855芯片上运行速度很快,模型大小在10M以内的模型,可以考虑使用以下方法来进行优化:

1. 使用深度可分离卷积代替标准卷积,以获得更快的速度和更小的模型尺寸。
2. 使用深度方向的滤波器来替换全连接层,以减少参数数量并减少内存使用。
3. 对于每个卷积层,使用Batch Norm层来加速收敛。
4. 使用ReLU代替传统的Sigmoid或Tanh激活函数,因为ReLU的计算效率更高。

然后给出了一个模型示例。

以下是一个可以作为参考的模型:

Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
BatchNorm2d(32)
ReLU()

DepthwiseConv2d(32, kernel_size=3, stride=1, padding=1)
BatchNorm2d(32)
ReLU()

Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
BatchNorm2d(64)
ReLU()

DepthwiseConv2d(64, kernel_size=3, stride=1, padding=1)
BatchNorm2d(64)
ReLU()

Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
BatchNorm2d(128)
ReLU()

DepthwiseConv2d(128, kernel_size=3, stride=1, padding=1)
BatchNorm2d(128)
ReLU()

Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
BatchNorm2d(256)
ReLU()

DepthwiseConv2d(256, kernel_size=3, stride=1, padding=1)
BatchNorm2d(256)
ReLU()

Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
BatchNorm2d(512)
ReLU()

DepthwiseConv2d(512, kernel_size=3, stride=1, padding=1)
BatchNorm2d(512)
ReLU()

Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
BatchNorm2d(1024)
ReLU()

Avgpool2d(kernel_size=4)
Flatten()
Linear(1024, 10)

上面生成的模型有几个小问题。

  1. DepthwiseConv2d没有被定义
  2. DepthwiseConv2d的参数列表中没有out channel的参数,需要手动补充。

Step3 生成一个DepthwiseConv2d结构

chatGPT回答如下

以下是一个非常简单的torch的DepthwiseConv2d实现:

import torch.nn as nn

class DepthwiseConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None):
        super(DepthwiseConv2d, self).__init__()
        if groups is None:
            groups = in_channels
        self.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups)
        self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1)

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x

加入到模型结构定义文件中,同时修改类别数,就可以进行模型训练了。

设置训练超参数

为保证对比的公平性,下面是训练过程中的一些超参数,

	parser.add_argument("--epoch", type=int, default=240)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--lr_decay_rate", type=float, default=0.1)
    parser.add_argument("--lr_decay_stages", type=str, default="150,180,210")
    parser.add_argument("--wd", type=float, default=0.0005)

3. 完整的模型结构

完整的模型结构如下,也欢迎大家在自己的数据集上尝试。

import torch
import torch.nn as nn

class DepthwiseConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None):
        super(DepthwiseConv2d, self).__init__()
        if groups is None:
            groups = in_channels
        self.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups)
        self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1)

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x


class CNNGPT(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            DepthwiseConv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            DepthwiseConv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            DepthwiseConv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            DepthwiseConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            DepthwiseConv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(1024),

            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 100),
        )
    
    def forward(self, x):
        y = self.model(x)
        return y


def get_flops_params(model):
    from thop import profile

    model.eval()
    flops, params = profile(
        model,
        inputs=[
            torch.randn([1, 3, 224, 224]),
        ],
    )
    print(f"flops: flops/1000**3 G, params: params/1000**2 M")
    return flops, params


if __name__ == "__main__":
    model = CNNGPT()
    get_flops_params(model)

以上是关于基于chatGPT设计卷积神经网络的主要内容,如果未能解决你的问题,请参考以下文章

基于卷积神经网络CNN的分类研究,基于卷积神经网络的手写体识别

基于卷积神经网络CNN的水果分类预测,卷积神经网络水果等级识别

CNN回归预测基于matlab卷积神经网络CNN数据回归预测含Matlab源码 2003期

CNN回归预测基于matlab卷积神经网络CNN数据回归预测含Matlab源码 2003期

基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络

基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络