Java Programmers Learning Deep Learning: DJL Hands-On 9, Transfer Learning on the CIFAR-10 Dataset
Posted by 编程圈子
# I. Overview
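In this article we use transfer learning to train an image classification model. As discussed in the previous article, transfer learning means training a model for one problem and then applying that model to a second problem. Compared with training a model for the specific problem from scratch, transfer learning reduces the number of features that have to be learned, producing a more flexible model in less time.
We train our own model on the CIFAR-10 dataset, which contains 60,000 labeled 32×32 color images.
The pretrained model is ResNet50v1, a 50-layer deep learning model trained on ImageNet, which contains more than 1.2 million images in 1000 classes. We modify this ImageNet model so that it classifies the 10 classes of CIFAR-10 instead.
Note: the experiment in this article has not yet succeeded; loading the pretrained model failed.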
(Figure: the CIFAR-10 dataset)
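# II. Steps
## 1. Load the pretrained ResNet50V1 model
ResNet50V1 is available in the ModelZoo. It was trained on the ImageNet dataset and has 1000 output classes. Because we want to retarget it to the 10 classes of CIFAR-10, we remove the last layer and add a new linear layer with 10 output units. After modifying the block, we set it back on the model.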
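```java
import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;

// load model and change last layer
Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optProgress(new ProgressBar())
        .optArtifactId("resnet")
        .optFilter("layers", "50")
        .optFilter("flavor", "v1").build();
Model model = criteria.loadModel();

// keep the pretrained network, minus its 1000-class output layer
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
newBlock.add(block);
// flatten the features to (batch, features), then add a new 10-class linear head
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
model.setBlock(newBlock);
```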
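To see what the new head does, here is a minimal plain-Java sketch, independent of DJL, of a fully connected layer with 10 units applied to a flattened feature vector: `logits = W·x + b`, followed by taking the highest-scoring class. The input size of 2048 is an assumption about the width of ResNet-50's pooled features. The class `LinearHead` and its Gaussian initialization are hypothetical illustrations, not DJL's `Linear` implementation.

```java
import java.util.Arrays;
import java.util.Random;

final class LinearHead {
    final int inFeatures;
    final int units;
    final float[][] weight; // shape [units][inFeatures]
    final float[] bias;     // shape [units], starts at zero

    LinearHead(int inFeatures, int units, long seed) {
        this.inFeatures = inFeatures;
        this.units = units;
        this.weight = new float[units][inFeatures];
        this.bias = new float[units];
        Random rng = new Random(seed);
        // Hypothetical init: Gaussian scaled by 1/sqrt(inFeatures)
        float scale = (float) Math.sqrt(1.0 / inFeatures);
        for (int u = 0; u < units; u++) {
            for (int i = 0; i < inFeatures; i++) {
                weight[u][i] = (float) rng.nextGaussian() * scale;
            }
        }
    }

    // logits[u] = sum_i weight[u][i] * x[i] + bias[u]
    float[] forward(float[] x) {
        if (x.length != inFeatures) {
            throw new IllegalArgumentException("expected " + inFeatures + " features, got " + x.length);
        }
        float[] logits = new float[units];
        for (int u = 0; u < units; u++) {
            float sum = bias[u];
            for (int i = 0; i < inFeatures; i++) {
                sum += weight[u][i] * x[i];
            }
            logits[u] = sum;
        }
        return logits;
    }

    // Trainable parameters: one weight per (unit, input) pair plus one bias per unit
    long parameterCount() {
        return (long) units * inFeatures + units;
    }

    // Index of the largest score, i.e. the predicted class
    static int argmax(float[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        LinearHead head = new LinearHead(2048, 10, 42L);
        float[] features = new float[2048];
        Arrays.fill(features, 0.01f); // stand-in for real pooled features
        float[] logits = head.forward(features);
        System.out.println(logits.length + " logits, " + head.parameterCount()
                + " parameters, predicted class " + argmax(logits));
    }
}
```

Under the 2048-input assumption, the new head has 10 × 2048 + 10 = 20,490 parameters. These are the only weights that start from scratch; everything else comes from the ImageNet-pretrained network.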
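## 2. Prepare the dataset
When building the dataset, you can set the sizes of the training and test datasets, the batch size, and the preprocessing pipeline.
The pipeline preprocesses the data. For example, it can convert a color image NDArray of shape (32, 32, 3) with values from 0 to 255 into an NDArray of shape (3, 32, 32) with values from 0 to 1.
It can also normalize the input using the mean and standard deviation of the input data.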
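The two pipeline steps can be sketched in plain Java. The sketch makes the following assumptions:

- The image arrives as interleaved HWC bytes (R, G, B per pixel) with values from 0 to 255.
- `ToTensor` is mimicked as a transpose to CHW plus division by 255.
- `Normalize` is mimicked as `(x - mean[c]) / std[c]` per channel, using the CIFAR-10 mean and standard deviation from the dataset code.

The class `Preprocess` is hypothetical. It mirrors the described behavior, not DJL's source.

```java
final class Preprocess {
    // Like ToTensor: HWC bytes in [0, 255] -> CHW floats in [0, 1]
    static float[] toTensor(byte[] hwc, int height, int width, int channels) {
        float[] chw = new float[channels * height * width];
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    int src = (h * width + w) * channels + c; // interleaved (HWC) index
                    int dst = (c * height + h) * width + w;   // planar (CHW) index
                    chw[dst] = (hwc[src] & 0xFF) / 255f;      // & 0xFF reads the byte as unsigned
                }
            }
        }
        return chw;
    }

    // Like Normalize: (x - mean[c]) / std[c], applied in place to each channel plane
    static void normalize(float[] chw, int height, int width, float[] mean, float[] std) {
        int plane = height * width;
        for (int c = 0; c < mean.length; c++) {
            for (int i = 0; i < plane; i++) {
                int idx = c * plane + i;
                chw[idx] = (chw[idx] - mean[c]) / std[c];
            }
        }
    }

    public static void main(String[] args) {
        int height = 32, width = 32, channels = 3;
        byte[] image = new byte[height * width * channels];
        // Every pixel is pure red (255, 0, 0)
        for (int p = 0; p < height * width; p++) {
            image[p * channels] = (byte) 255;
        }
        float[] tensor = toTensor(image, height, width, channels);
        normalize(tensor, height, width,
                new float[] {0.4914f, 0.4822f, 0.4465f},
                new float[] {0.2023f, 0.1994f, 0.2010f});
        // First value of each channel plane after normalization
        System.out.printf("R=%.4f G=%.4f B=%.4f%n",
                tensor[0], tensor[height * width], tensor[2 * height * width]);
    }
}
```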
```java
int batchSize = 32;
int limit = Integer.MAX_VALUE; // change this to a small value for a dry run
// int limit = 160; // limit the dataset to 160 records for a dry run
Pipeline pipeline = new Pipeline(
        new ToTensor(),
        new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f}, new float[] {0.2023f, 0.1994f, 0.2010f}));
Cifar10 trainDataset =
        Cifar10.builder()
                .setSampling(batchSize, true)
                .optUsage(Dataset.Usage.TRAIN)
                .optLimit(limit)
                .optPipeline(pipeline)
                .build();
trainDataset.prepare(new ProgressBar());
```
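`setSampling(batchSize, true)` requests shuffled mini-batches of size `batchSize`. The idea can be sketched in plain Java: shuffle the record indices, then cut them into consecutive groups of `batchSize`. `ShuffledBatches` is a hypothetical helper, not DJL's sampler. Keeping a smaller final batch, rather than dropping it, is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

final class ShuffledBatches {
    // Index batches covering 0..size-1 in random order; the last batch may be smaller
    static List<List<Integer>> batches(int size, int batchSize, Random rng) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<Integer> indices = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, rng);
        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < size; start += batchSize) {
            int end = Math.min(start + batchSize, size);
            result.add(new ArrayList<>(indices.subList(start, end)));
        }
        return result;
    }

    public static void main(String[] args) {
        int batchSize = 32;
        int limit = 160; // the dry-run limit used in the dataset code
        List<List<Integer>> epochBatches = batches(limit, batchSize, new Random(0));
        System.out.println(epochBatches.size() + " batches of " + epochBatches.get(0).size());
    }
}
```

With the 160-record dry-run limit, this produces exactly 5 batches of 32. On the 50,000-image CIFAR-10 training split, it produces 1,563 batches, the last holding 16 images.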
## 3. Set the training parameters
Because we start from a pretrained model, we train for only 10 epochs.
```java
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
        // softmaxCrossEntropyLoss is a standard loss for classification problems
        .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
        .optDevices(Device.getDevices(1)) // Limit to one GPU; using more GPUs may actually slow down convergence
        .addTrainingListeners(TrainingListener.Defaults.logging());
// Now that we have our training configuration, we should create a new trainer for our model
Trainer trainer = model.newTrainer(config);
```
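The loss and the evaluator are configured above. Conceptually, they compute the following for one batch, shown here as a plain-Java sketch:

- Softmax turns each row of 10 logits into probabilities.
- Softmax cross-entropy is `-log p[label]`, averaged over the batch.
- Accuracy is the fraction of rows whose highest logit sits at the label index.

The sketch assumes integer class labels, which is what CIFAR-10 provides. `ClassificationMetrics` is a hypothetical name, not a DJL class.

```java
final class ClassificationMetrics {
    // Numerically stable softmax: subtract the max logit before exponentiating
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sum = 0;
        double[] p = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    // Mean of -log(softmax(logits)[label]) over the batch
    static double softmaxCrossEntropy(double[][] batchLogits, int[] labels) {
        double total = 0;
        for (int n = 0; n < labels.length; n++) {
            total += -Math.log(softmax(batchLogits[n])[labels[n]]);
        }
        return total / labels.length;
    }

    // Fraction of rows whose highest logit (first one on ties) is at the label index
    static double accuracy(double[][] batchLogits, int[] labels) {
        int correct = 0;
        for (int n = 0; n < labels.length; n++) {
            int best = 0;
            for (int i = 1; i < batchLogits[n].length; i++) {
                if (batchLogits[n][i] > batchLogits[n][best]) {
                    best = i;
                }
            }
            if (best == labels[n]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] logits = {
            new double[10],                 // all zeros: a uniform guess
            {0, 0, 0, 5, 0, 0, 0, 0, 0, 0}, // confident in class 3
        };
        int[] labels = {1, 3};
        System.out.printf("loss=%.4f accuracy=%.2f%n",
                softmaxCrossEntropy(logits, labels), accuracy(logits, labels));
    }
}
```

A model that outputs identical logits for all 10 classes has a loss of ln 10 ≈ 2.30, a useful reference point when reading the logged training loss.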
## 4. Train the model
```java
int epoch = 10;
Shape inputShape = new Shape(1, 3, 32, 32);
trainer.initialize(inputShape);
```
```java
for (int i = 0; i < epoch; ++i) {
    for (Batch batch : trainer.iterateDataset(trainDataset)) {
        // forward pass, loss and backward pass for this batch
        EasyTrain.trainBatch(trainer, batch);
        // apply the parameter update
        trainer.step();
        batch.close();
    }
    // reset training and validation evaluators at end of epoch
    trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
```
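In the loop above, `EasyTrain.trainBatch` runs the forward pass, computes the loss and back-propagates. `trainer.step()` then applies the optimizer update. The same epoch/batch structure can be shown end to end on a toy problem in plain Java.

The toy is a softmax-regression classifier trained with mini-batch gradient descent. It relies on the fact that the gradient of softmax cross-entropy with respect to the logits is `p - onehot(label)`. The toy data, the learning rate, and the class `ToyTrainingLoop` are all hypothetical. The sketch illustrates the loop structure, not DJL's optimizer.

```java
import java.util.Random;

final class ToyTrainingLoop {
    final int classes;
    final int features;
    final double[][] w; // weights, shape [classes][features]
    final double[] b;   // biases, shape [classes]

    ToyTrainingLoop(int classes, int features) {
        this.classes = classes;
        this.features = features;
        this.w = new double[classes][features]; // zero init is fine for this convex model
        this.b = new double[classes];
    }

    // Forward pass: softmax(W x + b), subtracting the max logit for stability
    double[] probabilities(double[] x) {
        double[] z = new double[classes];
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < classes; k++) {
            z[k] = b[k];
            for (int j = 0; j < features; j++) z[k] += w[k][j] * x[j];
            max = Math.max(max, z[k]);
        }
        double sum = 0;
        for (int k = 0; k < classes; k++) {
            z[k] = Math.exp(z[k] - max);
            sum += z[k];
        }
        for (int k = 0; k < classes; k++) z[k] /= sum;
        return z;
    }

    // Mean softmax cross-entropy over the whole dataset
    double meanLoss(double[][] xs, int[] ys) {
        double total = 0;
        for (int n = 0; n < ys.length; n++) total -= Math.log(probabilities(xs[n])[ys[n]]);
        return total / ys.length;
    }

    // Analogue of trainBatch + step: average the gradients over one batch, then update once
    void trainBatch(double[][] xs, int[] ys, int start, int end, double lr) {
        double[][] gw = new double[classes][features];
        double[] gb = new double[classes];
        for (int n = start; n < end; n++) {
            double[] p = probabilities(xs[n]);
            for (int k = 0; k < classes; k++) {
                double g = p[k] - (k == ys[n] ? 1.0 : 0.0); // dLoss/dLogit_k
                gb[k] += g;
                for (int j = 0; j < features; j++) gw[k][j] += g * xs[n][j];
            }
        }
        int size = end - start;
        for (int k = 0; k < classes; k++) {
            b[k] -= lr * gb[k] / size;
            for (int j = 0; j < features; j++) w[k][j] -= lr * gw[k][j] / size;
        }
    }

    // Returns {loss before training, loss after training}
    static double[] demo(int epochs, double lr) {
        // Toy data: the label is the index of the largest coordinate of a random 3-D point
        Random rng = new Random(7);
        int samples = 96;
        int batchSize = 32;
        double[][] xs = new double[samples][3];
        int[] ys = new int[samples];
        for (int n = 0; n < samples; n++) {
            for (int j = 0; j < 3; j++) xs[n][j] = rng.nextDouble();
            int best = 0;
            for (int j = 1; j < 3; j++) if (xs[n][j] > xs[n][best]) best = j;
            ys[n] = best;
        }
        ToyTrainingLoop model = new ToyTrainingLoop(3, 3);
        double initial = model.meanLoss(xs, ys);
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < samples; start += batchSize) {
                model.trainBatch(xs, ys, start, Math.min(start + batchSize, samples), lr);
            }
        }
        return new double[] {initial, model.meanLoss(xs, ys)};
    }

    public static void main(String[] args) {
        double[] losses = demo(10, 0.5);
        System.out.printf("loss before: %.4f, after: %.4f%n", losses[0], losses[1]);
    }
}
```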
## 5. Save the model
```java
Path modelDir = Paths.get("build/resnet");
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "resnet");
```
# Source Code
```java
package com.xundh;
import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import org.apache.commons.csv.CSVFormat;
import java.io.IOException;
import java.nio.file.Paths;
public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Choose which model to download based on the deep learning engine
        // MXNet base model
        String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
        if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
            modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
        }
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrls)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();
        Predictor<NDList, NDList> embedder = embedding.newPredictor();
        Block classifier = new SequentialBlock()
                // text embedding layer
                .add(ndList -> {
                    NDArray data = ndList.singletonOrThrow();
                    NDList inputs = new NDList();
                    long batchSize = data.getShape().get(0);
                    float maxLength = data.getShape().get(1);
                    if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                        // token ids, attention mask (all ones) and position ids
                        inputs.add(data.toType(DataType.INT64, false));
                        inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                        inputs.add(data.getManager().arange(maxLength)
                                .toType(DataType.INT64, false)
                                .broadcast(data.getShape()));
                    } else {
                        inputs.add(data);
                        inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                    }
                    // run embedding
                    try {
                        return embedder.predict(inputs);
                    } catch (TranslateException e) {
                        throw new IllegalArgumentException("embedding error", e);
                    }
                })
                // classification layer
                .add(Linear.builder().setUnits(768).build()) // pre classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.2f).build())
                .add(Linear.builder().setUnits(5).build()) // 5 star rating
                .addSingleton(nd -> nd.get(":,0")); // Take [CLS] as the head
        Model model = Model.newInstance("AmazonReviewRatingClassification");
        model.setBlock(classifier);
        // Prepare the vocabulary
        SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();
        // Prepare dataset
        int maxTokenLength = 64; // cutoff tokens length
        int batchSize = 8;
        // int limit = Integer.MAX_VALUE; // use the full dataset
        int limit = 512; // small limit for quick testing
        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
        // split data with 7:3 train:valid ratio
        RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
        RandomAccessDataset trainingSet = datasets[0];
        RandomAccessDataset validationSet = datasets[1];
        SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
        listener.setSaveModelCallback(trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model1 = trainer.getModel();
            // track for accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model1.setProperty("Accuracy", String.format("%.5f", accuracy));
            model1.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
                .addEvaluator(new Accuracy())
                .optDevices(new Device[] {Device.cpu()}) // train on the CPU
                .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
                .addTrainingListeners(listener);
        int epoch = 2;
        Trainer trainer = model.newTrainer(config);
        trainer.setMetrics(new Metrics());
        Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
        // initialize trainer with proper input shape
        trainer.initialize(encoderInputShape);
        EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
        System.out.println(trainer.getTrainingResult());
        model.save(Paths.get("build/model"), "amazon-review.param");
        String review = "It works great, but it takes too long to update itself and slows the system";
        Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
        System.out.println(predictor.predict(review));
    }
    /**
     * Download the data and create the dataset object
     */
    static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength,
```