Deep Learning for Java Programmers, DJL Hands-On Part 9: Transfer Learning on the CIFAR-10 Dataset

Posted by 编程圈子


# 1. Overview

This article uses transfer learning to train an image classification model. As mentioned in an earlier part, transfer learning means taking a model trained for one problem and applying it to a second problem. Compared with training a model for the target problem from scratch, there are fewer features left to learn, so transfer learning yields a more flexible model in less time.

We train our own model on the CIFAR-10 dataset, which contains 60,000 32×32 color images in 10 classes.

The pre-trained model is ResNet50V1, a 50-layer deep network trained on ImageNet with more than 1.2 million images across 1,000 categories. We will modify this model so that it classifies the 10 categories of CIFAR-10 instead.

Note: this experiment has not been reproduced successfully yet; loading the pre-trained model fails.
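If `criteria.loadModel()` fails with a `ModelNotFoundException`, one thing worth checking is whether the model zoo for the active engine is on the classpath at all. A small diagnostic sketch (the class name is my own; `ModelZoo.listModels()` lists whatever the configured zoos can resolve):

```java
import ai.djl.Application;
import ai.djl.repository.Artifact;
import ai.djl.repository.zoo.ModelZoo;

import java.util.List;
import java.util.Map;

public class ListZooModels {
    public static void main(String[] args) throws Exception {
        // If no "resnet" artifact appears under CV.IMAGE_CLASSIFICATION,
        // the engine's model-zoo dependency is missing from the classpath.
        Map<Application, List<Artifact>> models = ModelZoo.listModels();
        models.forEach((app, artifacts) ->
                artifacts.forEach(artifact -> System.out.println(app + " -> " + artifact)));
    }
}
```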


(Figure: sample images from the CIFAR-10 dataset)

# 2. Walkthrough

## 1. Load the pre-trained ResNet50V1 model

ResNet50V1 is available in DJL's ModelZoo. It was trained on the ImageNet dataset and has 1,000 output classes. Since we want to retarget it to the 10 classes of CIFAR-10, we remove its last block and append a flatten block plus a new Linear layer with 10 output units; flattening turns the (batch, channels, height, width) feature maps into (batch, features) so the linear layer can map them to 10 logits. Once the block is modified, we put it back into the model.

```java
// Load the pre-trained model and replace its last layer
Criteria<Image, Classifications> criteria = Criteria.builder()
    .setTypes(Image.class, Classifications.class)
    .optProgress(new ProgressBar())
    .optArtifactId("resnet")
    .optFilter("layers", "50")
    .optFilter("flavor", "v1")
    .build();
Model model = criteria.loadModel();

SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();                             // drop the 1000-way ImageNet classifier
newBlock.add(block);                                 // keep the pre-trained feature extractor
newBlock.add(Blocks.batchFlattenBlock());            // (N, C, H, W) -> (N, C*H*W)
newBlock.add(Linear.builder().setUnits(10).build()); // new 10-way output layer
model.setBlock(newBlock);
```

## 2. Prepare the dataset

When building the dataset, we can choose the usage (training or test split), the batch size, and a preprocessing pipeline.
The pipeline preprocesses the raw data. For example, ToTensor converts a color-image NDArray of shape (32, 32, 3) with values from 0 to 255 into one of shape (3, 32, 32) with values from 0 to 1.
Normalize then standardizes the input using the dataset's per-channel mean and standard deviation.

```java
int batchSize = 32;
int limit = Integer.MAX_VALUE; // change this to a small value for a dry run
// int limit = 160; // limit the dataset to 160 records for a dry run
Pipeline pipeline = new Pipeline(
    new ToTensor(),
    new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f},    // per-channel mean
                  new float[] {0.2023f, 0.1994f, 0.2010f}));  // per-channel std
Cifar10 trainDataset =
    Cifar10.builder()
        .setSampling(batchSize, true)
        .optUsage(Dataset.Usage.TRAIN)
        .optLimit(limit)
        .optPipeline(pipeline)
        .build();
trainDataset.prepare(new ProgressBar());
```
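The post only builds the training split. A matching validation split can be built the same way; a minimal sketch, reusing the pipeline above with `Dataset.Usage.TEST` (10,000 images):

```java
// Validation split: same builder, but the TEST usage
Cifar10 testDataset =
    Cifar10.builder()
        .setSampling(batchSize, true)
        .optUsage(Dataset.Usage.TEST)
        .optLimit(limit)
        .optPipeline(pipeline)
        .build();
testDataset.prepare(new ProgressBar());
```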
## 3. Configure training
Since we start from a pre-trained model, 10 training epochs are enough.
```java
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    // softmaxCrossEntropyLoss is a standard loss for classification problems
    .addEvaluator(new Accuracy()) // use accuracy so we humans can understand how accurate the model is
    .optDevices(Device.getDevices(1)) // limit to one GPU; using more GPUs can actually slow down convergence
    .addTrainingListeners(TrainingListener.Defaults.logging());

// Now that we have our training configuration, we should create a new trainer for our model
Trainer trainer = model.newTrainer(config);
```
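Optionally, a `Metrics` object can be attached to the trainer to record timing and evaluator values during training (the full source at the end of this post does the same):

```java
// Optional: collect training metrics (requires ai.djl.metric.Metrics)
trainer.setMetrics(new Metrics());
```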
## 4. Train the model
First, initialize the trainer with the input shape the network expects, (batch, channels, height, width):
```java
int epoch = 10;
Shape inputShape = new Shape(1, 3, 32, 32);
trainer.initialize(inputShape);
```
With the trainer initialized, run the training loop by hand, epoch by epoch:
```java
for (int i = 0; i < epoch; ++i) {
    for (Batch batch : trainer.iterateDataset(trainDataset)) {
        EasyTrain.trainBatch(trainer, batch); // forward and backward pass on one batch
        trainer.step();                       // update the parameters
        batch.close();
    }

    // reset training and validation evaluators at end of epoch
    trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
```
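As an alternative to the hand-written loop, DJL's `EasyTrain.fit` runs the same epoch/batch loop together with validation in a single call; a minimal sketch, reusing the `testDataset` split sketched in step 2:

```java
// Convenience wrapper: runs the epoch loop, validation and listener callbacks
EasyTrain.fit(trainer, epoch, trainDataset, testDataset);
```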

## 5. Save the model
```java
Path modelDir = Paths.get("build/resnet");
Files.createDirectories(modelDir);

model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "resnet");
```
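The saved parameters can later be restored into the same architecture; a minimal sketch, assuming the modified `newBlock` from step 1 has been rebuilt first:

```java
// Restore the trained parameters; the block structure must match what was saved
Model loaded = Model.newInstance("resnet");
loaded.setBlock(newBlock);
loaded.load(modelDir, "resnet");
```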

# Source code
```java
package com.xundh;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import org.apache.commons.csv.CSVFormat;

import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Choose which model to download based on the active deep learning engine
        // MXNet base model
        String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
        if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
            modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
        }

        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrls)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();
        Predictor<NDList, NDList> embedder = embedding.newPredictor();
        Block classifier = new SequentialBlock()
                // text embedding layer
                .add(ndList -> {
                    NDArray data = ndList.singletonOrThrow();
                    NDList inputs = new NDList();
                    long batchSize = data.getShape().get(0);
                    float maxLength = data.getShape().get(1);

                    if ("PyTorch".equals(Engine.getInstance().getEngineName())) 
                        inputs.add(data.toType(DataType.INT64, false));
                        inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                        inputs.add(data.getManager().arange(maxLength)
                                .toType(DataType.INT64, false)
                                .broadcast(data.getShape()));
                     else 
                        inputs.add(data);
                        inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                    
                    // run embedding
                    try 
                        return embedder.predict(inputs);
                     catch (TranslateException e) 
                        throw new IllegalArgumentException("embedding error", e);
                    
                )
                // classification layer
                .add(Linear.builder().setUnits(768).build()) // pre classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.2f).build())
                .add(Linear.builder().setUnits(5).build()) // 5 star rating
                .addSingleton(nd -> nd.get(":,0")); // Take [CLS] as the head
        Model model = Model.newInstance("AmazonReviewRatingClassification");
        model.setBlock(classifier);

        // Prepare the vocabulary
        SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();
        // Prepare dataset
        int maxTokenLength = 64; // cutoff tokens length
        int batchSize = 8;
        // int limit = Integer.MAX_VALUE; // use the full dataset
        int limit = 512; // small limit for a quick test run

        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
        // split data with 7:3 train:valid ratio
        RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
        RandomAccessDataset trainingSet = datasets[0];
        RandomAccessDataset validationSet = datasets[1];
        SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
        listener.setSaveModelCallback(trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model1 = trainer.getModel();
            // track accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model1.setProperty("Accuracy", String.format("%.5f", accuracy));
            model1.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
                .addEvaluator(new Accuracy())
                .optDevices(new Device[] {Device.cpu()}) // train on the CPU
                .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
                .addTrainingListeners(listener);

        int epoch = 2;

        Trainer trainer = model.newTrainer(config);
        trainer.setMetrics(new Metrics());
        Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
        // initialize trainer with proper input shape
        trainer.initialize(encoderInputShape);
        EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
        System.out.println(trainer.getTrainingResult());

        model.save(Paths.get("build/model"), "amazon-review.param");

        String review = "It works great, but it takes too long to update itself and slows the system";
        Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
        System.out.println(predictor.predict(review));
    }

    /**
     * Download the data and build the dataset object
     */
    static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
        // ...
    }
}
```
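The listing above references a `MyTranslator` class that the truncated post never shows. As a rough reconstruction, modeled on DJL's Amazon review ranking example rather than taken from the author's original code, such a translator could look like this:

```java
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.util.Arrays;
import java.util.List;

// Hypothetical reconstruction: maps a review string to token indices on the way
// in, and the 5-unit model output to star-rating classifications on the way out.
class MyTranslator implements Translator<String, Classifications> {
    private final BertFullTokenizer tokenizer;
    private final SimpleVocabulary vocab;
    private final List<String> ranks = Arrays.asList("1", "2", "3", "4", "5");

    MyTranslator(BertFullTokenizer tokenizer) {
        this.tokenizer = tokenizer;
        this.vocab = tokenizer.getVocabulary();
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }

    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        // [CLS] <tokens> [SEP], converted to vocabulary indices
        List<String> tokens = tokenizer.tokenize(input);
        float[] indices = new float[tokens.size() + 2];
        indices[0] = vocab.getIndex("[CLS]");
        for (int i = 0; i < tokens.size(); i++) {
            indices[i + 1] = vocab.getIndex(tokens.get(i));
        }
        indices[indices.length - 1] = vocab.getIndex("[SEP]");
        return new NDList(ctx.getNDManager().create(indices));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        return new Classifications(ranks, list.singletonOrThrow().softmax(0));
    }
}
```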