Java Programmers Learn Deep Learning, Getting Started with DJL (Part 7): Using the PyTorch Engine

Posted by 编程圈子


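Preface: This article was compiled by the editors of 小常识网 (cha138.com). It covers using the PyTorch engine in DJL for Java programmers learning deep learning; we hope you find it useful.

This article explains how to call the PyTorch engine from DJL and work with PyTorch objects. Because DJL only supports the TorchScript format, your own PyTorch models must be converted first. The first part describes how to convert a model; the demo that follows loads an already-converted TorchScript model from the network.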

I. Adding the PyTorch engine to a DJL project with Maven

1. Add the pytorch-engine dependency

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.13.0-SNAPSHOT</version>
    <scope>runtime</scope>
</dependency>

2. Add the pytorch-native-auto library

Currently, each pytorch-engine version works with exactly one PyTorch native library version. The mapping is:

PyTorch engine version    PyTorch native library version
pytorch-engine:0.13.0     pytorch-native-auto:1.9.0
pytorch-engine:0.12.0     pytorch-native-auto:1.8.1
pytorch-engine:0.11.0     pytorch-native-auto:1.8.1
pytorch-engine:0.10.0     pytorch-native-auto:1.7.1
pytorch-engine:0.9.0      pytorch-native-auto:1.7.0
pytorch-engine:0.8.0      pytorch-native-auto:1.6.0
pytorch-engine:0.7.0      pytorch-native-auto:1.6.0
pytorch-engine:0.6.0      pytorch-native-auto:1.5.0
pytorch-engine:0.5.0      pytorch-native-auto:1.4.0
pytorch-engine:0.4.0      pytorch-native-auto:1.4.0
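If a project manages both versions in one place, the table above can also be encoded as a lookup. The following is a minimal plain-Java sketch; the EngineVersions class and its nativeVersionFor method are hypothetical helpers, not part of DJL, and treating a -SNAPSHOT version like its release is an assumption (it matches the pom.xml used later in this article).

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, not part of DJL: encodes the engine/native version table above
final class EngineVersions {
    private static final Map<String, String> ENGINE_TO_NATIVE;

    static {
        Map<String, String> m = new LinkedHashMap<>();
        m.put("0.13.0", "1.9.0");
        m.put("0.12.0", "1.8.1");
        m.put("0.11.0", "1.8.1");
        m.put("0.10.0", "1.7.1");
        m.put("0.9.0", "1.7.0");
        m.put("0.8.0", "1.6.0");
        m.put("0.7.0", "1.6.0");
        m.put("0.6.0", "1.5.0");
        m.put("0.5.0", "1.4.0");
        m.put("0.4.0", "1.4.0");
        ENGINE_TO_NATIVE = Collections.unmodifiableMap(m);
    }

    /** Returns the matching pytorch-native-auto version, or null if the engine version is not in the table. */
    static String nativeVersionFor(String engineVersion) {
        String suffix = "-SNAPSHOT";
        // Assumption: a snapshot such as 0.13.0-SNAPSHOT pairs with the same native version as its release
        String release = engineVersion.endsWith(suffix)
                ? engineVersion.substring(0, engineVersion.length() - suffix.length())
                : engineVersion;
        return ENGINE_TO_NATIVE.get(release);
    }

    public static void main(String[] args) {
        // The pom.xml in this article pairs 0.13.0-SNAPSHOT with pytorch-native-auto 1.9.0
        System.out.println(nativeVersionFor("0.13.0-SNAPSHOT")); // prints 1.9.0
    }
}
```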

Example:

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.9.0</version>
    <scope>runtime</scope>
</dependency>

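The native library also depends on the CPU, operating system, architecture and GPU, but pytorch-native-auto picks the matching build automatically.
If the automatic detection does not work for you, look up the library your platform needs at http://docs.djl.ai/engines/pytorch/pytorch-engine/index.html and change the dependency manually.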
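To illustrate what "matching the platform" involves, here is a simplified sketch that derives an OS/architecture string from the standard JVM properties os.name and os.arch. It is only an illustration: the PlatformGuess class and its classifier strings are hypothetical, and DJL's actual detection logic and naming may differ.

```java
import java.util.Locale;

// Illustrative only: the classifier strings are hypothetical; DJL's real detection and naming may differ
final class PlatformGuess {
    static String classify(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String family;
        if (os.startsWith("win")) {
            family = "win";
        } else if (os.startsWith("mac") || os.startsWith("darwin")) {
            family = "osx";
        } else if (os.startsWith("linux")) {
            family = "linux";
        } else {
            throw new IllegalArgumentException("Unsupported OS: " + osName);
        }
        String arch = osArch.toLowerCase(Locale.ROOT);
        // The JVM reports 64-bit x86 as "amd64" on Linux/Windows and "x86_64" on macOS
        if (arch.equals("amd64") || arch.equals("x86_64")) {
            arch = "x86_64";
        } else if (arch.equals("aarch64") || arch.equals("arm64")) {
            arch = "aarch64";
        }
        return family + "-" + arch;
    }

    public static void main(String[] args) {
        // Standard JVM system properties describing the current machine
        System.out.println(classify(System.getProperty("os.name"), System.getProperty("os.arch")));
    }
}
```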

II. The PyTorch Model Zoo of pretrained models


<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.13.0-SNAPSHOT</version>
</dependency>

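The pretrained models in the Model Zoo are mainly computer vision models, including:

  • Image classification
  • Object detection
  • Style transfer
  • Image generation, etc.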

III. Converting a PyTorch model to TorchScript

A PyTorch model must be converted to TorchScript format. There are two main ways to do this: tracing and scripting.
Example script using tracing:

import torch
import torchvision

# Point this at your own model
model = torchvision.models.resnet18(pretrained=True)

# Switch to evaluation (inference) mode
model.eval()

# Provide an example input for the model's forward() method
example = torch.rand(1, 3, 224, 224)

# Run the trace
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
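Before handing the saved .pt file to DJL, a quick check on the Java side can catch a wrong or truncated file. This sketch assumes the file was written by a PyTorch version whose torch.jit.save output is a zip archive, so it should begin with the zip signature "PK\3\4"; the TorchScriptCheck class is a hypothetical helper.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical sanity check; assumes a zip-based TorchScript file (recent torch.jit.save output)
final class TorchScriptCheck {
    static boolean looksLikeZipArchive(Path file) throws IOException {
        if (!Files.isRegularFile(file)) {
            return false;
        }
        byte[] magic = new byte[4];
        try (InputStream in = Files.newInputStream(file)) {
            int read = 0;
            // InputStream.read may return fewer bytes than requested, so loop until 4 bytes are in
            while (read < magic.length) {
                int n = in.read(magic, read, magic.length - read);
                if (n < 0) {
                    return false; // file shorter than the 4-byte signature
                }
                read += n;
            }
        }
        // Zip local file header signature: 'P' 'K' 0x03 0x04
        return magic[0] == 'P' && magic[1] == 'K' && magic[2] == 3 && magic[3] == 4;
    }

    public static void main(String[] args) throws IOException {
        Path file = Paths.get(args.length > 0 ? args[0] : "traced_resnet_model.pt");
        System.out.println(file + " looks like a zip-based TorchScript file: " + looksLikeZipArchive(file));
    }
}
```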

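IV. Loading the PyTorch model

1. Prepare the model

The example below assumes a TorchScript model is already available; here we use a pretrained resnet18 model.
DownloadUtils.download fetches the model from the web; the target folder is build/pytorch_models.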

DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| resnet18.pt
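The model URL ends in .gz while the local target is resnet18.pt, so the download has to be decompressed on the way; DJL's DownloadUtils appears to do this for .gz URLs. For comparison, a minimal JDK-only sketch of that step (the Gunzip class is a hypothetical helper) that gunzips a stream into the target file and creates missing parent folders:

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.GZIPInputStream;

// Hypothetical helper: decompresses a gzip stream (e.g. traced_resnet18.pt.gz) into a target file
final class Gunzip {
    static void gunzipTo(InputStream gzipped, Path target) throws IOException {
        Path parent = target.toAbsolutePath().getParent();
        if (parent != null) {
            Files.createDirectories(parent); // e.g. build/pytorch_models/resnet18
        }
        try (InputStream in = new GZIPInputStream(gzipped)) {
            Files.copy(in, target, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    public static void main(String[] args) throws IOException {
        // Usage: java Gunzip <input.gz> <output file>
        try (InputStream in = Files.newInputStream(Paths.get(args[0]))) {
            gunzipTo(in, Paths.get(args[1]));
        }
    }
}
```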

The resnet18 model also needs a label file (synset.txt), which is downloaded with DownloadUtils as well.

DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| synset.txt
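Judging from the class names printed later (for example "n02111889 Samoyed, Samoyede"), each line of synset.txt holds a WordNet ID followed by the class names, and line i is the label for output index i. A small sketch for reading and splitting such lines under that assumption; the Synset class is hypothetical, and it also assumes there are no blank lines between entries.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for a synset.txt-style label file (one label per line)
final class Synset {
    /** Reads labels in file order, so list index i corresponds to class index i. */
    static List<String> readLabels(BufferedReader reader) throws IOException {
        List<String> labels = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (!line.isEmpty()) {
                labels.add(line);
            }
        }
        return labels;
    }

    /** Splits "n02111889 Samoyed, Samoyede" into {"n02111889", "Samoyed, Samoyede"}. */
    static String[] splitWnid(String label) {
        String[] parts = label.split(" ", 2);
        return parts.length == 2 ? parts : new String[] {parts[0], ""};
    }

    public static void main(String[] args) throws IOException {
        Path path = Paths.get(args.length > 0 ? args[0] : "build/pytorch_models/resnet18/synset.txt");
        try (BufferedReader reader = Files.newBufferedReader(path)) {
            List<String> labels = readLabels(reader);
            System.out.println(labels.size() + " labels; first: " + (labels.isEmpty() ? "-" : labels.get(0)));
        }
    }
}
```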

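2. Create the Translator

First, define the pipeline, i.e. the preprocessing every image goes through. In PyTorch (torchvision) it looks like this: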

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
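The pipeline first resizes and then cuts out the central 224x224 region. A plain-Java sketch of that geometry, assuming torchvision semantics for Resize(256) (shorter side scaled to 256, aspect ratio kept, size truncated to an integer). Whether DJL's Resize(256) behaves identically is worth checking in the DJL API docs, and libraries may round odd crop offsets differently from the integer division used here. CropGeometry is a hypothetical name.

```java
// Hypothetical helper illustrating Resize(256) + CenterCrop(224) geometry
final class CropGeometry {
    /** torchvision-style Resize(size): scale the shorter side to size and keep the aspect ratio. */
    static int[] resizeShorterSide(int width, int height, int size) {
        if (width <= height) {
            return new int[] {size, (int) ((long) size * height / width)};
        }
        return new int[] {(int) ((long) size * width / height), size};
    }

    /** Top-left corner of a centered crop: ((w - cropW) / 2, (h - cropH) / 2) with integer division. */
    static int[] centerCropOrigin(int width, int height, int cropW, int cropH) {
        if (cropW > width || cropH > height) {
            throw new IllegalArgumentException("crop is larger than the image");
        }
        return new int[] {(width - cropW) / 2, (height - cropH) / 2};
    }

    public static void main(String[] args) {
        int[] resized = resizeShorterSide(640, 480, 256);                  // 341 x 256
        int[] origin = centerCropOrigin(resized[0], resized[1], 224, 224); // (58, 16)
        System.out.println("resized to " + resized[0] + "x" + resized[1]
                + ", crop starts at (" + origin[0] + ", " + origin[1] + ")");
    }
}
```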

Then create a DJL Translator that applies the same steps:

Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
        .addTransform(new Resize(256))
        .addTransform(new CenterCrop(224, 224))
        .addTransform(new ToTensor())
        .addTransform(new Normalize(
            new float[] {0.485f, 0.456f, 0.406f},
            new float[] {0.229f, 0.224f, 0.225f}))
        .optApplySoftmax(true)
        .build();
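ToTensor scales each 0-255 channel value to [0, 1] and reorders the image from height-width-channel (HWC) to channel-height-width (CHW) layout; Normalize then computes (x - mean[c]) / std[c] per channel, with the mean and std values used above. A stdlib-only sketch of that arithmetic; NormalizeSketch is a hypothetical name, and DJL itself performs these steps on NDArrays.

```java
import java.util.Arrays;

// Hypothetical sketch of ToTensor + Normalize on a plain int[height][width][channel] image
final class NormalizeSketch {
    static float[] toNormalizedChw(int[][][] hwc, float[] mean, float[] std) {
        int height = hwc.length;
        int width = hwc[0].length;
        int channels = hwc[0][0].length;
        float[] out = new float[channels * height * width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < channels; c++) {
                    float scaled = hwc[y][x][c] / 255f;                       // ToTensor: 0..255 -> 0..1
                    out[c * height * width + y * width + x] =                  // CHW index
                            (scaled - mean[c]) / std[c];                       // Normalize
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[] mean = {0.485f, 0.456f, 0.406f};
        float[] std = {0.229f, 0.224f, 0.225f};
        int[][][] whitePixel = {{{255, 255, 255}}}; // a 1x1 white image
        System.out.println(Arrays.toString(toNormalizedChw(whitePixel, mean, std)));
    }
}
```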

3. Load your own model

Loading the model takes a few parameters; for example, optModelPath tells DJL where the model is located.

Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optModelPath(Paths.get("build/pytorch_models/resnet18"))
        .optTranslator(translator)
        .optProgress(new ProgressBar()).build();

ZooModel<Image, Classifications> model = criteria.loadModel();
Loading:     100% |████████████████████████████████████████|
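optModelPath points at a directory rather than a file. In this article that directory contains resnet18.pt, named after the folder, plus synset.txt, which ImageClassificationTranslator reads its labels from by default. A JDK-only pre-flight check, assuming exactly that layout (ModelDirCheck is a hypothetical helper), can give a clearer message than a failed load:

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical check; assumes the layout <dir>/<dirName>.pt plus <dir>/synset.txt used in this article
final class ModelDirCheck {
    static List<String> missingFiles(Path modelDir) {
        Path dirName = modelDir.getFileName();
        String modelFile = (dirName == null ? "" : dirName.toString()) + ".pt";
        List<String> missing = new ArrayList<>();
        for (String file : new String[] {modelFile, "synset.txt"}) {
            if (!Files.isRegularFile(modelDir.resolve(file))) {
                missing.add(file);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        Path modelDir = Paths.get("build/pytorch_models/resnet18");
        List<String> missing = missingFiles(modelDir);
        System.out.println(missing.isEmpty() ? "model directory looks complete" : "missing: " + missing);
    }
}
```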

4. Load the image to classify

Image img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");
// In a Jupyter notebook, ending a cell with img.getWrappedImage() displays the image

5. Run inference

Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications classifications = predictor.predict(img);

Print the result:

System.out.println(classifications);

[
    class: "n02111889 Samoyed, Samoyede", probability: 0.94256
    class: "n02114548 white wolf, Arctic wolf, Canis lupus tundrarum", probability: 0.02820
    class: "n02111500 Great Pyrenees", probability: 0.01032
    class: "n02120079 Arctic fox, white fox, Alopex lagopus", probability: 0.00412
    class: "n02109961 Eskimo dog, husky", probability: 0.00279
]
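optApplySoftmax(true) is what turns the model's raw scores (logits) into probabilities between 0 and 1, and the list above shows five classes, most likely first. A plain-Java sketch of both steps, softmax and top-k selection; SoftmaxTopK is a hypothetical name, and DJL performs these internally.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch of softmax + top-k over a model's raw output scores
final class SoftmaxTopK {
    /** Numerically stable softmax: subtract the largest logit before exponentiating. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum; // probabilities now sum to 1
        }
        return out;
    }

    /** Indices of the k largest probabilities, highest first. */
    static List<Integer> topK(double[] probs, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            indices.add(i);
        }
        indices.sort(Comparator.comparingDouble((Integer i) -> probs[i]).reversed());
        return new ArrayList<>(indices.subList(0, Math.min(k, indices.size())));
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[] {2.0, -1.0, 0.5, 3.0});
        System.out.println(Arrays.toString(probs));
        System.out.println(topK(probs, 2)); // prints [3, 0]
    }
}
```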

V. Source code

1. pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.13.0-SNAPSHOT</djl.version>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- Pytorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.0</version>
        </dependency>
    </dependencies>
</project>

2. Java code

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .addTransform(new Resize(256))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                .optApplySoftmax(true)
                .build();

        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("build/pytorch_models/resnet18"))
                .optTranslator(translator)
                .optProgress(new ProgressBar()).build();

        // ZooModel and Predictor hold native resources, so close them with try-with-resources
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image img = ImageFactory.getInstance().fromUrl("https://img-blog.csdnimg.cn/4c1c40b41c6a49afa69f7ccf96e24ddf.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBA57yW56iL5ZyI5a2Q,size_20,color_FFFFFF,t_70,g_se,x_16#pic_center");
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }
}

VI. Loading a local model

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

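        Path modelDir = Paths.get("build/pytorch_models/resnet18");
        // Assumed completion from here on, following the imports above (Pipeline, Resize, CenterCrop, ToTensor)
        try (Model model = Model.newInstance("resnet")) {
            // The prefix "resnet18" selects resnet18.pt inside modelDir
            model.load(modelDir, "resnet18");

            // Build the preprocessing as an explicit Pipeline
            Pipeline pipeline = new Pipeline();
            pipeline.add(new Resize(256))
                    .add(new CenterCrop(224, 224))
                    .add(new ToTensor());
            // By default the translator reads class labels from synset.txt in the model directory
            Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                    .setPipeline(pipeline)
                    .optApplySoftmax(true)
                    .build();

            // Test image: assumed to be the same dog picture used in section IV
            Image img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                Classifications classifications = predictor.predict(img);
                System.out.println(classifications);
            }
        }
    }
}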
