Java程序员学深度学习 DJL上手5 训练自己的模型

Posted 编程圈子

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Java程序员学深度学习 DJL上手5 训练自己的模型相关的知识,希望对你有一定的参考价值。

Java程序员学深度学习 DJL上手5 训练自己的模型

一、准备环境

  • windows
  • idea
  • maven

二、创建示例项目

三、准备数据集

int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());

这里对数据集进行了分批处理,每批大小32,合适的分批大小将在训练时显著提升性能。

四、创建模型

本节会根据之前文章创建模型。由于 MNIST 数据集中的图像为 28x28 灰度图像,这里我们创建一个具有 28 x 28 输入的 MLP 块。
输出的图输出为 10,因为每个图像可能有 10 个可能的类(0 到 9)。
对于隐藏的层,其大小是猜测的值new int[] 128, 64

Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] 128, 64));

五、创建训练器

1. 训练器配置

  • 损失函数,用来测量模型与测试数据集的匹配程度,值越低越好;这里定义为softmaxCrossEntropyLoss()
  • 评估函数,也用于测量模型与数据集的匹配程度。与损失不同,它们只供人们查看,不用于优化模型。
  • 监听器,用来监控训练过程。

        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        Trainer trainer = model.newTrainer(config);

2. 初始化训练器

这里使用输入的形状来初始化训练器。初始化函数里形状的第一个参数是批次大小,这个不影响参数初始化。
第二个参数是输入图像的像素数,即28*28。

        trainer.initialize(new Shape(1, 28 * 28));

3. 训练模型

这里使用了DJL的EasyTrain,

        int epoch = 2;
        EasyTrain.fit(trainer, epoch, mnist, null);

4. 保存模型

保存模型还可以添加一些元数据,如训练迭代次数、训练精度等。

        Path modelDir = Paths.get("build/mlp");
        Files.createDirectories(modelDir);

        model.setProperty("Epoch", String.valueOf(epoch));

        model.save(modelDir, "mlp");

        System.out.println(model);

六、源代码

1. pom

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.13.0-SNAPSHOT</djl.version>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>$djl.version</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- Pytorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.7.0</version>
        </dependency>
    </dependencies>
</project>

2. java

package com.xundh;

import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class NDArrayLearning 
    public static void main(String[] args) throws IOException, TranslateException 
        int batchSize = 32;
        Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
        mnist.prepare(new ProgressBar());

        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[]128, 64));
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        Trainer trainer = model.newTrainer(config);
        trainer.initialize(new Shape(1, 28 * 28));
        int epoch = 2;

        EasyTrain.fit(trainer, epoch, mnist, null);

        Path modelDir = Paths.get("build/mlp");
        Files.createDirectories(modelDir);

        model.setProperty("Epoch", String.valueOf(epoch));

        model.save(modelDir, "mlp");

        System.out.println(model);
    


运行结果示例:

以上是关于Java程序员学深度学习 DJL上手5 训练自己的模型的主要内容,如果未能解决你的问题,请参考以下文章

Java程序员学深度学习 DJL上手6 使用自己的模型

Java程序员学深度学习 DJL上手7 使用Pytorch引擎

Java程序员学深度学习 DJL上手9 在CIFAR-10数据集使用风格迁移学习

Java程序员学深度学习 DJL上手9 在CIFAR-10数据集使用风格迁移学习

Java程序员学深度学习 DJL上手4 NDArray基本操作

Java程序员学深度学习 DJL上手4 NDArray基本操作