Java程序员学深度学习 DJL上手6 使用自己的模型

Posted 编程圈子

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Java程序员学深度学习 DJL上手6 使用自己的模型相关的知识,希望对你有一定的参考价值。

本文使用前一节训练的模型执行图片推理任务。

一、加载手写体待推理图片

        Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
        img.getWrappedImage();

图像比较小,就是个数字0:

二、加载模型

        Path modelDir = Paths.get("build/mlp");
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
        model.load(modelDir);

这里的参数与训练时定义的参数一致。

三、创建 Translator

Translator是DJL封装的推理预处理、后处理功能,输入参数是单个数据项,而不是一批数据。

        Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {

            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                // Convert Image to NDArray
                NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
                return new NDList(NDImageUtils.toTensor(array));
            }

            @Override
            public Classifications processOutput(TranslatorContext ctx, NDList list) {
                // Create a Classifications with the output probabilities
                NDArray probabilities = list.singletonOrThrow().softmax(0);
                List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
                return new Classifications(classNames, probabilities);
            }

            @Override
            public Batchifier getBatchifier() {
                // The Batchifier describes how to combine a batch together
                // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
                return Batchifier.STACK;
            }
        };

四、创建推理对象

        Predictor<Image, Classifications> predictor = model.newPredictor(translator);

按DJL官网文档描述,每次执行推理任务的时候最好创建新的推理器。

五、执行推理任务

        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);

另外,在ModelZoo里有一些训练好的模型可以拿来测试使用,类似本系列第一篇文章所写的操作。

六、源代码

1. pom.xml

与前一文章相同

2. java

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class NDArrayLearning {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
        Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
        img.getWrappedImage();
        Path modelDir = Paths.get("build/mlp");
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
        model.load(modelDir);

        Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {

            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                // Convert Image to NDArray
                NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
                return new NDList(NDImageUtils.toTensor(array));
            }

            @Override
            public Classifications processOutput(TranslatorContext ctx, NDList list) {
                // Create a Classifications with the output probabilities
                NDArray probabilities = list.singletonOrThrow().softmax(0);
                List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
                return new Classifications(classNames, probabilities);
            }

            @Override
            public Batchifier getBatchifier() {
                // The Batchifier describes how to combine a batch together
                // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
                return Batchifier.STACK;
            }
        };
        Predictor<Image, Classifications> predictor = model.newPredictor(translator);
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);

    }
}

3. 执行结果

[
	class: "0", probability: 0.99994
	class: "2", probability: 0.00004
	class: "6", probability: 2.9e-06
	class: "9", probability: 5.7e-07
	class: "1", probability: 2.7e-07
]

推测结果是 数字0的可能最大。

以上是关于Java程序员学深度学习 DJL上手6 使用自己的模型的主要内容,如果未能解决你的问题,请参考以下文章

Java程序员学深度学习 DJL上手5 训练自己的模型

Java程序员学深度学习 DJL上手7 使用Pytorch引擎

Java程序员学深度学习 DJL上手4 NDArray基本操作

Java程序员学深度学习 DJL上手4 NDArray基本操作

Java程序员学深度学习 DJL上手1

Java程序员学深度学习 DJL上手1