PyTorch搭建U-Net网络
Posted xbw12138
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch搭建U-Net网络相关的知识,希望对你有一定的参考价值。
U-Net: Convolutional Networks for Biomedical Image Segmentation
import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    The convolutions use ``padding=0`` (valid convolution), so each stage
    shrinks height and width by 2, i.e. by 4 in total.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first conv maps in_ch -> out_ch; the second keeps out_ch.
        for src_ch in (in_ch, out_ch):
            stages.extend((
                nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=0),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))
        # Built positionally so module indices (0..5), and therefore the
        # state_dict keys, are the same as a literal nn.Sequential(...).
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Apply both conv/BN/ReLU stages to ``input``."""
        features = self.conv(input)
        return features
class Unet(nn.Module):
    """U-Net encoder-decoder for image segmentation.

    Encoder: four ``DoubleConv`` + 2x2 max-pool stages (64 -> 512 channels),
    then a 1024-channel ``DoubleConv`` bottleneck. Decoder: each level
    upsamples with a 2x2 stride-2 transposed convolution, concatenates the
    centre-cropped encoder feature map of the same level (skip connection)
    and applies another ``DoubleConv``. A final 1x1 convolution maps to
    ``out_ch`` channels, and a sigmoid is applied element-wise.

    With the valid (padding=0) convolutions used by ``DoubleConv`` the output
    is spatially smaller than the input; e.g. a 572x572 input yields a
    388x388 output.

    Args:
        in_ch: number of channels of the input image.
        out_ch: number of output channels (one sigmoid map per channel).
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder (contracting path).
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder (expanding path). Transposed convolution is used for
        # upsampling; plain upsampling would also work. Each DoubleConv
        # receives twice the channels because of the skip concatenation.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 convolution to the requested number of output channels.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _center_crop(feature, reference):
        """Crop ``feature`` (N, C, H, W) to the spatial size of ``reference``, centred.

        Computing the offsets from the shapes lets the network accept any
        sufficiently large input instead of only 572x572. For a 572x572
        input the offsets are 88, 40, 16 and 4, the same as fixed slices.

        Raises:
            ValueError: if ``feature`` is spatially smaller than ``reference``.
        """
        target_h, target_w = reference.shape[-2:]
        top = (feature.shape[-2] - target_h) // 2
        left = (feature.shape[-1] - target_w) // 2
        if top < 0 or left < 0:
            raise ValueError(
                f"cannot crop {tuple(feature.shape[-2:])} to "
                f"{(target_h, target_w)}"
            )
        return feature[:, :, top:top + target_h, left:left + target_w]

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_ch, H, W).

        Returns a tensor of shape (N, out_ch, H', W') with values in (0, 1).
        H' and W' are smaller than H and W because of the unpadded convolutions.
        """
        # Encoder: keep each level's features for the skip connections.
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Decoder: upsample, concatenate the centre-cropped skip, convolve.
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, self._center_crop(c4, up_6)], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, self._center_crop(c3, up_7)], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, self._center_crop(c2, up_8)], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, self._center_crop(c1, up_9)], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # torch.sigmoid is what nn.Sigmoid() calls, without building a module
        # on every forward pass.
        return torch.sigmoid(c10)
if __name__ == "__main__":
    # Smoke test: one single-channel 572x572 image.
    test_input = torch.rand(1, 1, 572, 572)
    model = Unet(in_ch=1, out_ch=2)
    # torchsummary's summary() defaults to device="cuda" and builds CUDA
    # inputs whenever CUDA is available, while this model stays on the CPU;
    # pin the device so the script also runs on GPU machines.
    summary(model, (1, 572, 572), device="cpu")
    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(test_input)
    print(output.size())  # expected: torch.Size([1, 2, 388, 388])
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 640
BatchNorm2d-2 [-1, 64, 570, 570] 128
ReLU-3 [-1, 64, 570, 570] 0
Conv2d-4 [-1, 64, 568, 568] 36,928
BatchNorm2d-5 [-1, 64, 568, 568] 128
ReLU-6 [-1, 64, 568, 568] 0
DoubleConv-7 [-1, 64, 568, 568] 0
MaxPool2d-8 [-1, 64, 284, 284] 0
Conv2d-9 [-1, 128, 282, 282] 73,856
BatchNorm2d-10 [-1, 128, 282, 282] 256
ReLU-11 [-1, 128, 282, 282] 0
Conv2d-12 [-1, 128, 280, 280] 147,584
BatchNorm2d-13 [-1, 128, 280, 280] 256
ReLU-14 [-1, 128, 280, 280] 0
DoubleConv-15 [-1, 128, 280, 280] 0
MaxPool2d-16 [-1, 128, 140, 140] 0
Conv2d-17 [-1, 256, 138, 138] 295,168
BatchNorm2d-18 [-1, 256, 138, 138] 512
ReLU-19 [-1, 256, 138, 138] 0
Conv2d-20 [-1, 256, 136, 136] 590,080
BatchNorm2d-21 [-1, 256, 136, 136] 512
ReLU-22 [-1, 256, 136, 136] 0
DoubleConv-23 [-1, 256, 136, 136] 0
MaxPool2d-24 [-1, 256, 68, 68] 0
Conv2d-25 [-1, 512, 66, 66] 1,180,160
BatchNorm2d-26 [-1, 512, 66, 66] 1,024
ReLU-27 [-1, 512, 66, 66] 0
Conv2d-28 [-1, 512, 64, 64] 2,359,808
BatchNorm2d-29 [-1, 512, 64, 64] 1,024
ReLU-30 [-1, 512, 64, 64] 0
DoubleConv-31 [-1, 512, 64, 64] 0
MaxPool2d-32 [-1, 512, 32, 32] 0
Conv2d-33 [-1, 1024, 30, 30] 4,719,616
BatchNorm2d-34 [-1, 1024, 30, 30] 2,048
ReLU-35 [-1, 1024, 30, 30] 0
Conv2d-36 [-1, 1024, 28, 28] 9,438,208
BatchNorm2d-37 [-1, 1024, 28, 28] 2,048
ReLU-38 [-1, 1024, 28, 28] 0
DoubleConv-39 [-1, 1024, 28, 28] 0
ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664
Conv2d-41 [-1, 512, 54, 54] 4,719,104
BatchNorm2d-42 [-1, 512, 54, 54] 1,024
ReLU-43 [-1, 512, 54, 54] 0
Conv2d-44 [-1, 512, 52, 52] 2,359,808
BatchNorm2d-45 [-1, 512, 52, 52] 1,024
ReLU-46 [-1, 512, 52, 52] 0
DoubleConv-47 [-1, 512, 52, 52] 0
ConvTranspose2d-48 [-1, 256, 104, 104] 524,544
Conv2d-49 [-1, 256, 102, 102] 1,179,904
BatchNorm2d-50 [-1, 256, 102, 102] 512
ReLU-51 [-1, 256, 102, 102] 0
Conv2d-52 [-1, 256, 100, 100] 590,080
BatchNorm2d-53 [-1, 256, 100, 100] 512
ReLU-54 [-1, 256, 100, 100] 0
DoubleConv-55 [-1, 256, 100, 100] 0
ConvTranspose2d-56 [-1, 128, 200, 200] 131,200
Conv2d-57 [-1, 128, 198, 198] 295,040
BatchNorm2d-58 [-1, 128, 198, 198] 256
ReLU-59 [-1, 128, 198, 198] 0
Conv2d-60 [-1, 128, 196, 196] 147,584
BatchNorm2d-61 [-1, 128, 196, 196] 256
ReLU-62 [-1, 128, 196, 196] 0
DoubleConv-63 [-1, 128, 196, 196] 0
ConvTranspose2d-64 [-1, 64, 392, 392] 32,832
Conv2d-65 [-1, 64, 390, 390] 73,792
BatchNorm2d-66 [-1, 64, 390, 390] 128
ReLU-67 [-1, 64, 390, 390] 0
Conv2d-68 [-1, 64, 388, 388] 36,928
BatchNorm2d-69 [-1, 64, 388, 388] 128
ReLU-70 [-1, 64, 388, 388] 0
DoubleConv-71 [-1, 64, 388, 388] 0
……(原文中以下的 summary 输出已被截断)
以上是关于PyTorch搭建U-Net网络的主要内容,如果未能解决你的问题,请参考以下文章