encode与decode
Posted wmy-ncut
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 encode 与 decode 相关的知识，希望对你有一定的参考价值。
"""Train a small fully connected autoencoder on MNIST and visualise the results.

Figure 1 shows N_TEST_IMG training images (top row) and their reconstructions
(bottom row); the bottom row is refreshed every 100 training steps.  After
training, figure 2 scatters the first three latent dimensions of N_LATENT_PLOT
training images, each point drawn as its digit label.

matplotlib and torchvision are imported inside the functions that need them so
the model and tensor helpers can be imported (for reuse or testing) without
those packages installed.
"""
import numpy as np
import torch
import torch.utils.data as Data
from torch import nn

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005  # learning rate
# torchvision skips the download when the dataset files already exist under root,
# so True is safe on later runs and avoids "Dataset not found" on a fresh checkout.
DOWNLOAD_MNIST = True
N_TEST_IMG = 5  # images shown in the original/reconstruction grid
N_LATENT_PLOT = 200  # images whose latent codes are drawn in the 3-D plot


class AutoEncoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    Encoder: 784 -> 128 -> 64 -> latent_dim (Tanh between linear layers).
    Decoder: latent_dim -> 64 -> 128 -> 784, ending in a Sigmoid so the
    reconstruction lies in [0, 1], the same range as ``ToTensor()`` pixels.

    Args:
        latent_dim: size of the bottleneck code (default 12).  The 3-D latent
            plot in this script shows only the first three code dimensions,
            so it needs latent_dim >= 3; latent_dim=3 makes that plot show
            the whole code.
    """

    def __init__(self, latent_dim=12):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(code, reconstruction)`` for a batch of flattened images ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


def flatten_images(pixels):
    """Flatten raw 0-255 pixel tensors of shape (N, 28, 28) to float (N, 784) in [0, 1]."""
    return pixels.reshape(-1, 28 * 28).float() / 255


def reconstruct(model, images):
    """Return the model's reconstructions of ``images`` as an (N, 28, 28) numpy array.

    Runs without gradient tracking: the result is only used for display.
    """
    with torch.no_grad():
        _, decoded = model(images)
    return decoded.numpy().reshape(-1, 28, 28)


def encode(model, images):
    """Return the latent codes of ``images`` as an (N, latent_dim) numpy array (no grad)."""
    with torch.no_grad():
        encoded, _ = model(images)
    return encoded.numpy()


def _show_images(row, images):
    """Draw each 28x28 image on the matching axis of ``row``, hiding the ticks."""
    for ax, img in zip(row, images):
        ax.clear()
        ax.imshow(img, cmap='gray')
        ax.set_xticks(())
        ax.set_yticks(())


def train(model, loader, view_data, axes, epochs=EPOCH, lr=LR):
    """Train ``model`` to reproduce its input, refreshing the reconstruction row.

    Args:
        model: the AutoEncoder to optimise in place.
        loader: yields ``(images, labels)`` batches; labels are ignored.
        view_data: fixed flattened images whose reconstructions are displayed.
        axes: the 2 x N_TEST_IMG axes grid; row 1 receives the reconstructions.
        epochs: number of passes over ``loader``.
        lr: Adam learning rate.
    """
    import matplotlib.pyplot as plt

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.MSELoss()
    for epoch in range(epochs):
        for step, (x, _label) in enumerate(loader):
            b_x = x.view(-1, 28 * 28)  # the input is also the reconstruction target
            _, decoded = model(b_x)
            loss = loss_func(decoded, b_x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print('Epoch:|', epoch, 'train loss:%0.4f' % loss.item())
                # Show reconstructions of the fixed test images (not the training
                # batch) so they line up with the originals in the top row.
                _show_images(axes[1], reconstruct(model, view_data))
                plt.draw()
                plt.pause(0.05)


def plot_latent(codes, labels):
    """Scatter the first three latent dimensions, drawing each point as its label.

    Args:
        codes: (N, latent_dim) numpy array with latent_dim >= 3.
        labels: length-N array of digit labels 0-9, used as text and colour.
    """
    import matplotlib.pyplot as plt
    from matplotlib import cm
    # Needed on matplotlib < 3.2 to register the '3d' projection.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    fig = plt.figure(2)
    # Axes3D(fig) no longer attaches itself to the figure on newer matplotlib.
    ax = fig.add_subplot(111, projection='3d')
    xs, ys, zs = codes[:, 0], codes[:, 1], codes[:, 2]
    for x, y, z, s in zip(xs, ys, zs, labels):
        c = cm.rainbow(int(255 * s / 9))
        ax.text(x, y, z, s, backgroundcolor=c)
    ax.set_xlim(xs.min(), xs.max())
    ax.set_ylim(ys.min(), ys.max())
    ax.set_zlim(zs.min(), zs.max())
    plt.show()


def main():
    """Load MNIST, train the autoencoder with live previews, then plot the latent space."""
    import matplotlib.pyplot as plt
    import torchvision

    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )
    train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    autoencoder = AutoEncoder()  # lower-case name: don't shadow the class

    _, axes = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
    plt.ion()  # continuously plot
    view_data = flatten_images(train_data.data[:N_TEST_IMG])
    _show_images(axes[0], view_data.numpy().reshape(-1, 28, 28))

    train(autoencoder, train_loader, view_data, axes)

    plt.ioff()
    plt.show()

    latent_input = flatten_images(train_data.data[:N_LATENT_PLOT])
    plot_latent(encode(autoencoder, latent_input), train_data.targets[:N_LATENT_PLOT].numpy())


if __name__ == '__main__':
    main()
从训练集中选出前五张图片，用于可视化测试。
图像按 2 行 × 5 列显示：上面一行是原始图像，下面一行是同一批图像经过编码再解码后的重建结果（训练过程中每 100 步刷新一次）。
以上是关于 encode 与 decode 的主要内容，如果未能解决你的问题，请参考以下文章。