Pytorch搭建U-Net网络

Posted xbw12138

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch搭建U-Net网络相关的知识,希望对你有一定的参考价值。

U-Net: Convolutional Networks for Biomedical Image Segmentation

import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use ``padding=0`` ("valid" convolution), so each stage
    trims one pixel from every border. The output is therefore 4 pixels
    smaller than the input in both height and width.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage maps in_ch -> out_ch; the second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=0),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Kept as one Sequential named ``conv`` so the state_dict keys
        # (conv.0.*, conv.1.*, conv.3.*, conv.4.*) stay the same.
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Run both conv/BN/ReLU stages on ``input`` and return the result."""
        features = self.conv(input)
        return features

class Unet(nn.Module):
    """U-Net built from unpadded ("valid") DoubleConv blocks.

    Every DoubleConv shrinks its input by 4 pixels per spatial side. The
    encoder feature maps are therefore larger than the upsampled decoder maps
    they are concatenated with. Each skip connection is center-cropped to the
    decoder map's height/width before ``torch.cat``.

    A 572x572 input gives a 388x388 output. Smaller inputs are accepted down
    to 188 pixels per side, which gives a 4x4 output. Below that, some
    intermediate feature map shrinks to zero size and the forward pass fails.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels produced by the final 1x1 conv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Contracting path.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Expansive path. Transposed convolutions do the 2x upsampling; a plain
        # upsampling layer could be used instead.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _center_crop(skip, target):
        """Center-crop ``skip`` spatially to the height/width of ``target``.

        Args:
            skip: 4-D feature map (N, C, H, W) from the contracting path.
            target: tensor whose last two dimensions give the crop size.

        Returns:
            A slice of ``skip`` with spatial size ``target.shape[-2:]``. If the
            size difference is odd, the extra row/column is dropped from the
            bottom/right.

        Raises:
            ValueError: if ``target`` is spatially larger than ``skip``.
        """
        h, w = skip.shape[-2:]
        th, tw = target.shape[-2:]
        if th > h or tw > w:
            raise ValueError(
                f"cannot crop a {h}x{w} feature map to {th}x{tw}")
        top = (h - th) // 2
        left = (w - tw) // 2
        return skip[:, :, top:top + th, left:left + tw]

    def forward(self, x):
        """Segment ``x`` and return per-pixel sigmoid outputs.

        Args:
            x: input batch of shape (N, in_ch, H, W), with H, W >= 188.

        Returns:
            Tensor of shape (N, out_ch, H', W') with values in [0, 1].
            The sigmoid is applied element-wise, not as a softmax across
            channels. H' and W' are smaller than H and W because the
            convolutions are unpadded (572 -> 388).
        """
        # Contracting path.
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Expansive path. Each skip is cropped to the upsampled map's size,
        # which for a 572x572 input gives the crops 4:60, 16:120, 40:240 and
        # 88:480.
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, self._center_crop(c4, up_6)], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, self._center_crop(c3, up_7)], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, self._center_crop(c2, up_8)], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, self._center_crop(c1, up_9)], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # Functional sigmoid: no new module is created on every call.
        return torch.sigmoid(c10)


if __name__ == "__main__":
    # Smoke test at the 572x572 tile size; the expected output size is
    # torch.Size([1, 2, 388, 388]).
    test_input = torch.rand(1, 1, 572, 572)
    model = Unet(in_ch=1, out_ch=2)
    # torchsummary's summary() defaults to device="cuda". The model lives on
    # the CPU, so request CPU inputs explicitly; otherwise this crashes with
    # an input/weight type mismatch on machines where CUDA is available.
    summary(model, (1, 572, 572), device="cpu")
    with torch.no_grad():  # shape check only; no autograd graph needed
        output = model(test_input)
    print(output.size())
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
        DoubleConv-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
       DoubleConv-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
       DoubleConv-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
       DoubleConv-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
       DoubleConv-39         [-1, 1024, 28, 28]               0
  ConvTranspose2d-40          [-1, 512, 56, 56]       2,097,664
           Conv2d-41          [-1, 512, 54, 54]       4,719,104
      BatchNorm2d-42          [-1, 512, 54, 54]           1,024
             ReLU-43          [-1, 512, 54, 54]               0
           Conv2d-44          [-1, 512, 52, 52]       2,359,808
      BatchNorm2d-45          [-1, 512, 52, 52]           1,024
             ReLU-46          [-1, 512, 52, 52]               0
       DoubleConv-47          [-1, 512, 52, 52]               0
  ConvTranspose2d-48        [-1, 256, 104, 104]         524,544
           Conv2d-49        [-1, 256, 102, 102]       1,179,904
      BatchNorm2d-50        [-1, 256, 102, 102]             512
             ReLU-51        [-1, 256, 102, 102]               0
           Conv2d-52        [-1, 256, 100, 100]         590,080
      BatchNorm2d-53        [-1, 256, 100, 100]             512
             ReLU-54        [-1, 256, 100, 100]               0
       DoubleConv-55        [-1, 256, 100, 100]               0
  ConvTranspose2d-56        [-1, 128, 200, 200]         131,200
           Conv2d-57        [-1, 128, 198, 198]         295,040
      BatchNorm2d-58        [-1, 128, 198, 198]             256
             ReLU-59        [-1, 128, 198, 198]               0
           Conv2d-60        [-1, 128, 196, 196]         147,584
      BatchNorm2d-61        [-1, 128, 196, 196]             256
             ReLU-62        [-1, 128, 196, 196]               0
       DoubleConv-63        [-1, 128, 196, 196]               0
  ConvTranspose2d-64         [-1, 64, 392, 392]          32,832
           Conv2d-65         [-1, 64, 390, 390]          73,792
      BatchNorm2d-66         [-1, 64, 390, 390]             128
             ReLU-67         [-1, 64, 390, 390]               0
           Conv2d-68         [-1, 64, 388, 388]          36,928
      BatchNorm2d-69         [-1, 64, 388, 388]             128
             ReLU-70         [-1, 64, 388, 388]               0
       DoubleConv-71         [-1, 64, 388, 388]               0
            Conv2d-72          [-1, 2, 388, 388]             130

以上是关于Pytorch搭建U-Net网络的主要内容,如果未能解决你的问题,请参考以下文章

具有预训练主干的 U-net:在哪里进行跳过连接?

Pytorch U-net 分割模型的“ValueError:轴与数组错误不匹配”可能是啥原因?

libtorch(pytorch c++)教程

libtorch(pytorch c++)教程

pytorch搭建简单神经网络

如何入门Pytorch之二:如何搭建实用神经网络