使用gluon实现简单的CNN
Posted 白菜hxj
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用gluon实现简单的CNN相关的知识,希望对你有一定的参考价值。
import numpy as np
from mxnet import ndarray as nd
from mxnet import gluon
from mxnet import autograd
from mxnet.gluon import nn


def transform(data, label):
    """Convert one (image, label) sample into float32 tensors for the CNN.

    Args:
        data: image NDArray; its axes are reordered with (2, 0, 1), so it is
            presumably laid out (H, W, C) -- confirm against the dataset.
        label: class label supporting ``.astype`` (numpy scalar or array).

    Returns:
        A tuple ``(image, label)``:
        - ``image`` is float32, transposed to (C, H, W) and divided by 255.
          It lies in [0, 1] assuming 8-bit pixel input.
        - ``label`` is cast to float32.
    """
    return nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255, label.astype(np.float32)


mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)

batch_size = 256
# Shuffle only the training split; evaluation order does not matter.
train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
import mxnet as mx

# Prefer the first GPU and fall back to the CPU if no usable GPU exists.
# Allocating a tiny array on the GPU is meant to surface the failure here
# rather than later inside training.
# Catch only MXNet errors: a bare `except:` would also swallow
# KeyboardInterrupt and SystemExit.
try:
    ctx = mx.gpu()
    _ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
    ctx = mx.cpu()
print(ctx)
def accuracy(output, label):
    """Return the fraction of rows whose highest-scoring class equals ``label``.

    Args:
        output: NDArray of per-class scores with argmax taken over axis 1.
            Presumably (batch, num_classes) -- confirm with the caller.
        label: NDArray of class indices, one per row of ``output``.

    Returns:
        The batch accuracy as a Python scalar.
    """
    preds = output.argmax(axis=1)
    # Put labels on the predictions' device with a matching dtype, so the
    # comparison works when the model runs on a GPU and labels are on the CPU.
    label = label.as_in_context(preds.context).astype(preds.dtype)
    return nd.mean(preds == label).asscalar()


def _net_context(net):
    """Return the device of ``net``'s first parameter, or None if it has none."""
    for param in net.collect_params().values():
        return param.list_ctx()[0]
    return None


def evaluate_accuracy(data_iterator, net, ctx=None):
    """Return the accuracy of ``net`` over every sample in ``data_iterator``.

    Args:
        data_iterator: iterable of ``(data, label)`` NDArray batches.
        net: Gluon block mapping a data batch to per-class scores.
        ctx: device to run on. Defaults to the device of ``net``'s first
            parameter, so inputs are moved next to the weights.

    Returns:
        Correct predictions divided by the total number of samples. Counting
        samples (rather than averaging per-batch means) keeps a short final
        batch from being over-weighted.

    Raises:
        ZeroDivisionError: if ``data_iterator`` yields no samples.
    """
    if ctx is None:
        ctx = _net_context(net)
    correct = 0.
    total = 0
    for data, label in data_iterator:
        if ctx is not None:
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
        preds = net(data).argmax(axis=1)
        correct += nd.sum(preds == label.astype(preds.dtype)).asscalar()
        total += data.shape[0]
    return correct / total
# Small CNN: two Conv2D + MaxPool2D stages, then two Dense layers.
# The last Dense(10) has no activation; it emits raw scores that are fed to
# SoftmaxCrossEntropyLoss, which applies the softmax itself.
# The string literals use ASCII quotes; the scraped curly quotes were a
# SyntaxError.
net = nn.Sequential()
with net.name_scope():
    net.add(
        nn.Conv2D(channels=20, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=50, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(128, activation="relu"),
        nn.Dense(10),
    )
# Initialize parameters directly on the selected device.
net.initialize(ctx=ctx)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
# Plain SGD with a fixed learning rate. The ASCII quotes replace scraped
# curly quotes, which were a SyntaxError.
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.2})
# Train for 5 epochs, reporting mean loss, train accuracy and test accuracy.
# softmax_cross_entropy is the loss created alongside the trainer.
for epoch in range(5):
    train_loss = 0.
    train_acc = 0.
    for data, label in train_data:
        # Inputs must be on the same device as the parameters; previously
        # only the label was moved, which fails on a GPU.
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        # Trainer.step normalizes gradients by 1/batch_size. Use the actual
        # batch size so the smaller final batch is not under-weighted.
        trainer.step(data.shape[0])
        train_loss += nd.mean(loss).asscalar()
        train_acc += accuracy(output, label)
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
        epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
以上是关于使用gluon实现简单的CNN的主要内容,如果未能解决你的问题,请参考以下文章
基于PyTorch实现CNN：如何使用PyTorch构建CNN卷积神经网络