pytorch解决鸢尾花分类
Posted mc-curry
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch解决鸢尾花分类相关的知识,希望对你有一定的参考价值。
半年前用numpy写了个鸢尾花分类200行。。每一步计算都是手写的 python构建bp神经网络_鸢尾花分类
现在用pytorch简单写一遍,pytorch语法解释请看上一篇pytorch搭建简单网络
1 import pandas as pd 2 import torch.nn as nn 3 import torch 4 5 6 class MyNet(nn.Module): 7 def __init__(self): 8 super(MyNet, self).__init__() 9 self.fc = nn.Sequential( 10 nn.Linear(4, 3), 11 nn.Sigmoid(), 12 nn.Linear(3, 3), 13 nn.Sigmoid(), 14 nn.Linear(3, 1), 15 ) 16 self.mls = nn.MSELoss() 17 self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001) 18 19 def get_data(self): 20 inputs = [] 21 labels = [] 22 with open(‘flower.csv‘) as file: 23 df = pd.read_csv(file, header=None) 24 x = df.iloc[:, 0:4].values 25 y = df.iloc[:, 4].values 26 for i in range(len(x)): 27 inputs.append(x[i]) 28 for j in range(len(y)): 29 a = [] 30 a.append(y[j]) 31 labels.append(a) 32 33 return inputs, labels 34 35 def forward(self, inputs): 36 out = self.fc(inputs) 37 return out 38 39 def train(self, x, label): 40 out = self.forward(x) 41 loss = self.mls(out, label) 42 self.opt.zero_grad() 43 loss.backward() 44 self.opt.step() 45 46 def test(self, x): 47 return self.fc(x) 48 49 50 if __name__ == ‘__main__‘: 51 net = MyNet() 52 inputs, labels = net.get_data() 53 for i in range(1000): 54 for index, input in enumerate(inputs): 55 # 这里不加.float()会报错,可能是数据格式的问题吧 56 input = torch.from_numpy(input).float() 57 label = torch.Tensor(labels[index]) 58 net.train(input, label) 59 # 简单测试一下 60 c = torch.Tensor([[5.6, 2.7, 4.2, 1.3]]) 61 print(net.test(c))
运行结果趋近于0.5 正确,单纯练一下pytorch,就没有分训练集,测试集
1 tensor([[0.5392]], grad_fn=<AddmmBackward>)
不用手写反向传播和梯度下降 是多么幸福一件事~
以上是关于pytorch解决鸢尾花分类的主要内容,如果未能解决你的问题,请参考以下文章