Pytorch学习
Posted Rgylin
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch学习相关的知识,希望对你有一定的参考价值。
import torch
import torchvision
from torch.utils.data import DataLoader
#下载 并加载数据源
das= torchvision.datasets.CIFAR10("./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
#是间接通过 Dataset 来获得数据的,然后进行组装成一个 batch 返回
dataloader= DataLoader(das,batch_size=35)
class my_nn(torch.nn.Module):
def __init__(self):
super(my_nn, self).__init__()
#定义卷积 三个渠道RGB 6个 卷积大小为3*3 里面值看图片大小而定, stride是 移动步数 一列一行走
self.conv1= torch.nn.Conv2d(in_channels=3,out_channels=6,kernel_size=3, stride=1,padding=0)
#输出函数
def forward(self,input):
x=self.conv1(input)
return x
#显示数据
from torch.utils.tensorboard import SummaryWriter
wirte= SummaryWriter("nn_rgylin")
count=0
rgylin= my_nn()
for i in dataloader:
img,target= i
#卷积后操作
print(img.shape)
output= rgylin(img)
print(output.shape)
#由于维度不同所以要进行 torch.reshape将图片改为相同shape值 由于第一个batch大小未知,所以为-1表示
output1= torch.reshape(output,(-1,3,30,30))
#print(output1)
wirte.add_images("rglyin_nn",output1,count,dataformats="NCHW")
count+=1
以上是关于Pytorch学习的主要内容,如果未能解决你的问题,请参考以下文章
对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码