DJL快速入门(纯Java跑深度学习模型)

Posted iioSnail

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了DJL快速入门(纯Java跑深度学习模型)相关的知识,希望对你有一定的参考价值。

文章目录

1. 本文介绍

服务端大多都是用Java做的,而深度学习模型大多又是用Python写的,所以很多人都是用Java调Python的接口,这样效率低,而且也不优雅,最重要的是如果想使用Android做推理,那就必须要用Java写了

本文使用了一个重要的工具:Deep Java Library,这是一个用Java进行深度学习的库,你可以用它来进行模型推理,甚至是训练模型。很多文章也都介绍过该模型,但是他们都漏了一个重要的内容:深度学习代码不只是推理部分,还有很多预处理和后续处理的部分需要很多Tensor操作,但是他们都没说怎么做。

为了符合大家的实际需求,本文不使用DJL进行模型训练,只做推理。本文的具体内容包括:

  1. DJL核心内容讲解
  2. DJL加载Pytorch模型
  3. DJL的Tensor操作
  4. DJL简单案例(DJL使用Pytorch模型完成图片分类)



2. DJL核心内容讲解


2.1 DJL简介

DJL是一个开源的深度学习 Java 框架(支持Android),其可以用于深度学习模型构建和训练、Tensor操作使用预训练好的常见模型(MXNet、Pytorch、TensorFlow等)Java1.8 以上就可以用,且支持GPU


2.2 DJL核心API

在实际案例之前,先讲解下DJL的核心API,这样在后续的案例也知道代码是做什么的。


2.2.1 Criteria

Criteria 类对象定义了模型的情况,如模型路径、输入和输出等。

例如,这是一段初始化DJL模型的代码:

Criteria<Input, Output> criteria = Criteria.builder()
        .setTypes(Input.class, Output.class) // defines input and output data type
        .optTranslator(new InputOutputTranslator())
        .optModelPath(Paths.get("/var/models/my_resnet50")) // search models in specified path
        .optModelName("model/resnet50") // specify model file prefix
        .build();

ZooModel<Image, Classifications> model = criteria.loadModel();

在上述代码中,Criteria描述了模型的情况,主要包含以下几点:

  • Criteria<I, O> 定义了模型输入和输出。这里的 IO 可以是自定义的类,也可以使用DJL提供的类。
  • setTypes(I.class, O.class):这个代码是必须的。直接从泛型的I,O是获取不到输入和输出的class对象的,所以需要手动设置一下。
  • optModelTranslator:模型的输入和输出是一个Tensor类型。这里就是设置你的I类和O类应该如何与Tensor类型进行转化。后续会具体讲Translator
  • optModelName:设置一下模型名称

定义好模型的情况,就可以使用loadModel方法实例化出 Model Zoo 对象了。

Model Zoo 是DJL的模型,你需要通过该类对象对模型进行进行管理,例如创建模型、创建Predictor,保存模型等。


2.2.2 Translator

在上一节中,模型的输入类和输出类是可以自定义的,但Pytorch模型不可能接收你自己定义的类对象啊,它只会接受Tensor类型,所以我们就需要使用Translator接口来定义如何将我们的自定义输入输出类转换为Tensor类型。

private Translator<Input, Output> translator = new Translator<Input, Output>() 

    @Override
    public NDList processInput(TranslatorContext ctx, Input input) throws Exception 
        return null;
    

    @Override
    public Output processOutput(TranslatorContext ctx, NDList ndList) throws Exception 
        return null;
    
;

Translator接口包含两个接口:

  • processInput:将输入类对象转化为Tensor。这里的Input就是输入类对象,而NDList就是Tensor的集合(因为模型的forward可能会接收多个Tensor参数)。在DJL中,Tensor对应的类为NDArray(类似numpy中的ndarray),后续会详细讲解。
  • processOutput:将模型输出的Tensor转换为自定义类。由于模型可能会输出多个Tensor,所以这里也是NDList

上述这两个方法还包含一个重要的参数TranslatorContext,这个保存了Translator的上下文,可以用它来拿到一些对象(Model, Predictor等),也可以通过setAttachmentgetAttachment 方法来存取一些东西。

在官方的例子中,Translator是对图像进行处理,但Translator并非只能处理图像,这里的Input和Output可以是任意Java类。


2.2.3 NDArray

在python中,我们有numpy,而在Java中,我们有DJL的NDArray,使用该类,我们几乎可以实现Numpy中的所有Tensor操作。本节将会介绍常用的tensor操作。

开始前先介绍与NDArray相关的几个类:

  • NDArray:相当于numpy.ndarray,可以通过getShape()方法获取其shape
  • NDManager:NDArray的管理类,全局new一个就行了,需要用该类对象创建NDArray
  • NDIndex:用于对Tensor进行切片
  • Shape: 创建NDArray的时候,需要指定Shape。获取NDArray的Shape时返回的也是该类的对象。

接下来开始具体演示Tensor的常见操作(这里只举几个例子,有不会的操作可以在评论区告知,我会进行补充):


创建NDArray(Tensor)

创建一个Shape为(1,2,3,4)的Tensor

NDManager ndManager = NDManager.newBaseManager();
NDArray ndArray = ndManager.create(new Shape(1, 2, 3, 4));

ndManager全局应只创建一个

指定值创建:

ndManager.create(new int[]1, 2, 3, 4);

变更数据类型

变为float类型

ndManager.create(new int[]1, 2, 3, 4).toType(DataType.FLOAT32, false);

变为float数组:

ndManager.create(new int[]1, 2, 3, 4).toType(DataType.FLOAT32, false)
									.toFloatArray();

注意,在toArray()前需要将NDArray转变为相对应的类型,且字节数要对上。例如在java中float是使用32个bit(4个字节)存储的,所以NDArray的类型必须是Float32,不能是Float64,否则会报错。

运算

加减乘除:

ndArray.add(1);
ndArray.sub(1);
ndArray.mul(1);
ndArray.div(1);

也可以使用NDArrays.add,类似np.add()

NDArrays.add(ndArray, ndArray);

切片

NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray = ndArray.get(new NDIndex("1:, :"));

等价于python中的[1:, :]

DJL的切片好像不能指定index,例如 x = [1,2,3], y = [2,3,4],然后切片 nums[x, y]。 DJL中我还没找到应该如何这样切,所以我只能自己用for循环实现,如果大家知道怎么弄,欢迎在评论区告诉我

赋值

NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray.set(new NDIndex("1:, :"), 1);

等价于Python的ndArray[1:, :] = 1

翻转

在Python中,对数组进行翻转可以使用[..., ::-1],但java中不行,但可以利用flip函数实现

NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray = ndArray.flip(-1);

2.2.3 Predictor

创建好模型后,需要new一个Predictor,然后用这个Predictor进行预测:

predictor = zooModel.newPredictor();
Output output = predictor.predict(input);

到这里DJL常用的API就讲完了,接下来使用一个简单的案例进行实战。


3. 实战:DJL使用Pytorch模型完成图片分类

这里使用Pytorch提供的resnet18模型完成一个图片分类任务。

  1. 首先引入依赖:
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.17.0</version>
    <scope>runtime</scope>
</dependency>

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <classifier>win-x86_64</classifier>
    <scope>runtime</scope>
    <version>1.11.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>1.11.0-0.17.0</version>
    <scope>runtime</scope>
</dependency>

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.17.0</version>
</dependency>

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.17.0</version>
</dependency>

<dependency>
    <groupId>ai.djl.opencv</groupId>
    <artifactId>opencv</artifactId>
    <version>0.17.0</version>
</dependency>
  1. 导出pytorch的resnet18模型:
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval model
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
  1. 将导出的模型拷贝到项目的model目录下:

  2. 创建Translator,这里我们定义输入为String类型,表示图片的输入路径;输出也为String,表示类别。将图片送入Resnet18网络,需要做一些预处理:

...
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
...

这里利用Java的NDArray的

Translator<String/*filename*/, String/*class*/> translator = new Translator<String, String>() 

    @Override
    public NDList processInput(TranslatorContext ctx, String input) throws Exception 
        // 根据路径读取图片
        Image image = ImageFactory.getInstance().fromFile(Paths.get(input));
        NDArray ndArray = image.toNDArray(ctx.getNDManager());
        // 在图片送入resnet前要做一些预处理,官方的例子中使用transforms,但为了本文的前后呼应,我这里就用上面将的NDArray的操作来完成
        Resize resize = new Resize(256, 256);
        ndArray = resize.transform(ndArray); // 将图片的大小resize到256x256

        // py: transforms.CenterCrop(224)
        // NDArray没有CenterCrop方法,但是我们可以通过切片的方式实现
        ndArray = ndArray.get(new NDIndex("16:240, 16:240, :"));

        // ToTensor会将Shape的(224,224,3)转变为(3,224,224),并且将值从0-255缩放到0-1
        ndArray = new ToTensor().transform(ndArray);

        // py: transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        Normalize normalize = new Normalize(new float[]0.485f, 0.456f, 0.406f, new float[]0.229f, 0.224f, 0.225f);
        ndArray = normalize.transform(ndArray);

        return new NDList(ndArray);  // resnet只接受一个Tensor
    

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) throws Exception 
        // resnet只返回一个tensor,所以要get(0)
        int index = list.get(0).argMax().toType(DataType.INT32, false).getInt();
        // 由于resnet可以识别1000种物体,这里我就只给index了。
        return index + "";
    
;
  1. 定义Criteria,然后实例化模型,并newPredictor
Criteria<String, String> criteria = Criteria.builder()
        .setTypes(String.class, String.class)
        .optModelPath(Paths.get("model/traced_resnet_model.pt"))
        .optOption("mapLocation", "true")
        .optTranslator(translator)
        .build();

ZooModel model = criteria.loadModel();
Predictor predictor = model.newPredictor();
  1. 准备一张图片,我这里放在项目的test目录下:

  1. 进行预测
System.out.println(predictor.predict("test/test.jpg"));

由于resnet可以识别1000个物体,太多了,所以我只输出了index,全部的类别可以到该链接查找。最终输出为:

258

258对应的类别为Samoyed(萨摩耶),可以看得到预测对了。

DJL更多的例子可以参考官方Demo






参考资料

Deep Java Library官方文档:https://docs.djl.ai/

Dive Into Deep Learning: https://d2l.djl.ai/chapter_preliminaries/ndarray.html

djl-demo: https://github.com/deepjavalibrary/djl-demo

Java程序员学深度学习 DJL上手6 使用自己的模型

本文使用前一节训练的模型执行图片推理任务。

一、加载手写体待推理图片

        Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
        img.getWrappedImage();

图像比较小,就是个数字0:

二、加载模型

        Path modelDir = Paths.get("build/mlp");
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
        model.load(modelDir);

这里的参数与训练时定义的参数一致。

三、创建 Translator

Translator是DJL封装的推理预处理、后处理功能,输入参数是单个数据项,而不是一批数据。

        Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {

            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                // Convert Image to NDArray
                NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
                return new NDList(NDImageUtils.toTensor(array));
            }

            @Override
            public Classifications processOutput(TranslatorContext ctx, NDList list) {
                // Create a Classifications with the output probabilities
                NDArray probabilities = list.singletonOrThrow().softmax(0);
                List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
                return new Classifications(classNames, probabilities);
            }

            @Override
            public Batchifier getBatchifier() {
                // The Batchifier describes how to combine a batch together
                // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
                return Batchifier.STACK;
            }
        };

四、创建推理对象

        Predictor<Image, Classifications> predictor = model.newPredictor(translator);

按DJL官网文档描述,每次执行推理任务的时候最好创建新的推理器。

五、执行推理任务

        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);

另外,在ModelZoo里有一些训练好的模型可以拿来测试使用,类似本系列第一篇文章所写的操作。

六、源代码

1. pom.xml

与前一文章相同

2. java

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class NDArrayLearning {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
        Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
        img.getWrappedImage();
        Path modelDir = Paths.get("build/mlp");
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
        model.load(modelDir);

        Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {

            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                // Convert Image to NDArray
                NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
                return new NDList(NDImageUtils.toTensor(array));
            }

            @Override
            public Classifications processOutput(TranslatorContext ctx, NDList list) {
                // Create a Classifications with the output probabilities
                NDArray probabilities = list.singletonOrThrow().softmax(0);
                List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
                return new Classifications(classNames, probabilities);
            }

            @Override
            public Batchifier getBatchifier() {
                // The Batchifier describes how to combine a batch together
                // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
                return Batchifier.STACK;
            }
        };
        Predictor<Image, Classifications> predictor = model.newPredictor(translator);
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);

    }
}

3. 执行结果

[
	class: "0", probability: 0.99994
	class: "2", probability: 0.00004
	class: "6", probability: 2.9e-06
	class: "9", probability: 5.7e-07
	class: "1", probability: 2.7e-07
]

推测结果是 数字0的可能最大。

以上是关于DJL快速入门(纯Java跑深度学习模型)的主要内容,如果未能解决你的问题,请参考以下文章

Java程序员学深度学习 DJL上手5 训练自己的模型

Java程序员学深度学习 DJL上手6 使用自己的模型

Java程序员学深度学习 DJL上手7 使用Pytorch引擎

Java程序员学深度学习 DJL上手9 在CIFAR-10数据集使用风格迁移学习

Java程序员学深度学习 DJL上手9 在CIFAR-10数据集使用风格迁移学习

Java程序员学深度学习 DJL上手1