mnist数据集进行自编码
Posted czz0508
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了mnist数据集进行自编码相关的知识,希望对你有一定的参考价值。
""" 自动编码的核心就是各种全连接的组合,它是一种无监督的形式,因为他的标签是自己。 """ import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from matplotlib import cm import numpy as np # 超参数 EPOCH = 10 BATCH_SIZE = 64 LR = 0.005 DOWNLOAD_MNIST = False N_TEST_IMG = 5 # Mnist数据集 train_data = torchvision.datasets.MNIST( root=‘./mnist/‘, train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST, ) print(train_data.train_data.size()) # (60000, 28, 28) print(train_data.train_labels.size()) # (60000) # 显示出一个例子 plt.imshow(train_data.train_data[2].numpy(), cmap=‘gray‘) plt.title(‘%i‘ % train_data.train_labels[2]) plt.show() # 将数据集分为多批数据 train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # 搭建自编码网络框架 class AutoEncoder(nn.Module): def __init__(self): super(AutoEncoder, self).__init__() self.encoder = nn.Sequential( nn.Linear(28*28, 128), nn.Tanh(), nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 12), nn.Tanh(), nn.Linear(12, 3), ) self.decoder = nn.Sequential( nn.Linear(3, 12), nn.Tanh(), nn.Linear(12, 64), nn.Tanh(), nn.Linear(64, 128), nn.Tanh(), nn.Linear(128, 28*28), nn.Sigmoid(), # 将输出结果压缩到0到1之间,因为train_data的数据在0到1之间 ) def forward(self, x): encoded = self.encoder(x) decoded = self.decoder(encoded) return encoded, decoded autoencoder = AutoEncoder() optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR) loss_func = nn.MSELoss() # initialize figure f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2)) plt.ion() # 设置为实时打印 # 第一行是原始图片 view_data = Variable(train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.) for i in range(N_TEST_IMG): a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap=‘gray‘); a[0][i].set_xticks(()); a[0][i].set_yticks(()) for epoch in range(EPOCH): for step, (x, y) in enumerate(train_loader): b_x = Variable(x.view(-1, 28*28)) b_y = Variable(x.view(-1, 28*28)) encoded, decoded = autoencoder(b_x) loss = loss_func(decoded, b_y) optimizer.zero_grad() # 将上一部的梯度清零 loss.backward() # 反向传播,计算梯度 optimizer.step() # 优化网络中的各个参数 if step % 100 == 0: print(‘Epoch: ‘, epoch, ‘| train loss: %.4f‘ % loss.data[0]) # 第二行画出解码后的图片 _, decoded_data = autoencoder(view_data) for i in range(N_TEST_IMG): a[1][i].clear() a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap=‘gray‘) a[1][i].set_xticks(()); a[1][i].set_yticks(()) plt.draw(); plt.pause(0.05) plt.ioff() plt.show() # 可视化三维图 view_data = Variable(train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.) encoded_data, _ = autoencoder(view_data) fig = plt.figure(2); ax = Axes3D(fig) X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy() values = train_data.train_labels[:200].numpy() for x, y, z, s in zip(X, Y, Z, values): c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c) ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max()) plt.show()
以上是关于mnist数据集进行自编码的主要内容,如果未能解决你的问题,请参考以下文章