Java Programmers Learn Deep Learning | DJL Quick Start 7: Using the PyTorch Engine
Posted by 编程圈子
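This article explains how to call the PyTorch engine from DJL and work with PyTorch objects. DJL's PyTorch engine only loads models in TorchScript format, so your own PyTorch models have to be converted first. The first part of the article covers the conversion; the demo that follows downloads a model that has already been converted to TorchScript.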
I. Adding the PyTorch Engine to a DJL Project with Maven
1. Add pytorch-engine
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.13.0-SNAPSHOT</version>
<scope>runtime</scope>
</dependency>
2. Add the pytorch-native-auto library
Currently each pytorch-engine release works with exactly one PyTorch version. The mapping is:
| PyTorch engine version | PyTorch native library version |
|---|---|
| pytorch-engine:0.13.0 | pytorch-native-auto:1.9.0 |
| pytorch-engine:0.12.0 | pytorch-native-auto:1.8.1 |
| pytorch-engine:0.11.0 | pytorch-native-auto:1.8.1 |
| pytorch-engine:0.10.0 | pytorch-native-auto:1.7.1 |
| pytorch-engine:0.9.0 | pytorch-native-auto:1.7.0 |
| pytorch-engine:0.8.0 | pytorch-native-auto:1.6.0 |
| pytorch-engine:0.7.0 | pytorch-native-auto:1.6.0 |
| pytorch-engine:0.6.0 | pytorch-native-auto:1.5.0 |
| pytorch-engine:0.5.0 | pytorch-native-auto:1.4.0 |
| pytorch-engine:0.4.0 | pytorch-native-auto:1.4.0 |
Example:
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>1.9.0</version>
<scope>runtime</scope>
</dependency>
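Because each engine release pairs with exactly one native version, a mismatched pair is an easy mistake when upgrading. As a minimal sketch, the table above can be encoded as a lookup that a build script or unit test could use to check the pair. The class `PtVersionCheck` and its methods are hypothetical, not part of DJL, and treating a `-SNAPSHOT` version like its release is an assumption.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper (not part of DJL): checks engine/native version pairs from the table above. */
final class PtVersionCheck {
    // pytorch-engine version -> pytorch-native-auto version
    private static final Map<String, String> NATIVE_FOR_ENGINE = new HashMap<>();

    static {
        NATIVE_FOR_ENGINE.put("0.13.0", "1.9.0");
        NATIVE_FOR_ENGINE.put("0.12.0", "1.8.1");
        NATIVE_FOR_ENGINE.put("0.11.0", "1.8.1");
        NATIVE_FOR_ENGINE.put("0.10.0", "1.7.1");
        NATIVE_FOR_ENGINE.put("0.9.0", "1.7.0");
        NATIVE_FOR_ENGINE.put("0.8.0", "1.6.0");
        NATIVE_FOR_ENGINE.put("0.7.0", "1.6.0");
        NATIVE_FOR_ENGINE.put("0.6.0", "1.5.0");
        NATIVE_FOR_ENGINE.put("0.5.0", "1.4.0");
        NATIVE_FOR_ENGINE.put("0.4.0", "1.4.0");
    }

    private PtVersionCheck() {}

    /** Returns the matching native version, or null if the engine version is not in the table. */
    static String nativeFor(String engineVersion) {
        // Assumption: a -SNAPSHOT build needs the same native version as its release
        String release = engineVersion.endsWith("-SNAPSHOT")
                ? engineVersion.substring(0, engineVersion.length() - "-SNAPSHOT".length())
                : engineVersion;
        return NATIVE_FOR_ENGINE.get(release);
    }

    /** True only if the pair appears in the table. */
    static boolean isCompatible(String engineVersion, String nativeVersion) {
        return nativeVersion.equals(nativeFor(engineVersion));
    }
}
```

For example, `PtVersionCheck.isCompatible("0.13.0-SNAPSHOT", "1.9.0")` returns `true`, while pairing 0.13.0 with 1.8.1 returns `false`.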
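The native library you need also depends on the CPU, operating system, architecture, and GPU (CUDA) setup, but pytorch-native-auto detects the platform and picks the matching build automatically.
If automatic detection does not work, look up the library your platform needs at http://docs.djl.ai/engines/pytorch/pytorch-engine/index.html and set the dependency manually.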
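Before pinning a native dependency by hand, it helps to know what the JVM reports about the platform. The sketch below maps the standard `os.name` and `os.arch` system properties to an `<os>-<arch>` string. The class name and the exact classifier spellings (`linux-x86_64`, `osx-x86_64`, `win-x86_64`, `...-aarch64`) are assumptions for illustration, so check the DJL page above for the real artifact names.

```java
import java.util.Locale;

/** Hypothetical sketch: derive an "<os>-<arch>" platform string from JVM system properties. */
final class PlatformClassifier {
    private PlatformClassifier() {}

    static String classify(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String family;
        if (os.startsWith("win")) {
            family = "win";
        } else if (os.startsWith("mac") || os.startsWith("darwin")) {
            family = "osx";
        } else if (os.startsWith("linux")) {
            family = "linux";
        } else {
            throw new IllegalArgumentException("Unsupported OS: " + osName);
        }

        String arch = osArch.toLowerCase(Locale.ROOT);
        // The JVM reports 64-bit x86 as "amd64" on Linux/Windows and "x86_64" on macOS
        if (arch.equals("amd64") || arch.equals("x86_64")) {
            arch = "x86_64";
        } else if (arch.equals("aarch64") || arch.equals("arm64")) {
            arch = "aarch64";
        } else {
            throw new IllegalArgumentException("Unsupported architecture: " + osArch);
        }
        return family + "-" + arch;
    }

    /** Classifies the machine this JVM is running on. */
    static String current() {
        return classify(System.getProperty("os.name"), System.getProperty("os.arch"));
    }
}
```

Printing `PlatformClassifier.current()` on the build machine shows which platform build you would need to look up.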
II. The PyTorch Model Zoo (Pretrained Models)
To use the pretrained models, add the model zoo dependency:
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>0.13.0-SNAPSHOT</version>
</dependency>
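The pretrained models in the model zoo are mainly computer vision models, covering:
- image classification
- object detection
- style transfer
- image generation
and more.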
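III. Converting a PyTorch Model to TorchScript
PyTorch models must be converted to TorchScript format. There are two main ways to do this: tracing and scripting. Tracing runs the model once on an example input and records the operations that were executed, so control flow that depends on the input data is not captured; scripting (torch.jit.script) compiles the model's Python code directly and preserves that control flow.
Example tracing script: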
import torch
import torchvision

# Replace this with your own model
model = torchvision.models.resnet18(pretrained=True)
# Switch to evaluation (inference) mode
model.eval()
# An example input for the model's forward() method
example = torch.rand(1, 3, 224, 224)
# Trace the model: torch.jit.trace runs the example through the model
# and records the operations as a torch.jit.ScriptModule
traced_script_module = torch.jit.trace(model, example)
# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
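IV. Loading a PyTorch Model
1. Prepare the model
The following example assumes a model in TorchScript format is already available; here it is the pretrained resnet18 model.
DownloadUtils downloads the model from the network into the build/pytorch_models/resnet18 directory.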
DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| resnet18.pt
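The resnet18 model also needs a label file, synset.txt, which maps the model's output indices to class names. Download it with DownloadUtils as well.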
DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| synset.txt
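`synset.txt` holds one class per line, a WordNet ID followed by human-readable names, and the line order is assumed to match the index of the model's output vector; that is why the predictions below print as `n02111889 Samoyed, Samoyede`. A minimal stdlib sketch for inspecting the file (the `Synset` class is hypothetical, not a DJL API):

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper for inspecting synset.txt; not part of DJL. */
final class Synset {
    /** One line of synset.txt: a WordNet ID plus the human-readable class names. */
    static final class Entry {
        final String wordnetId;
        final String names;

        Entry(String wordnetId, String names) {
            this.wordnetId = wordnetId;
            this.names = names;
        }
    }

    private Synset() {}

    /**
     * Parses lines in file order; entry i is assumed to describe model output index i.
     * Blank lines become empty entries so that indices stay aligned.
     */
    static List<Entry> parse(List<String> lines) {
        List<Entry> entries = new ArrayList<>();
        for (String raw : lines) {
            String line = raw.trim();
            int space = line.indexOf(' ');
            if (space < 0) {
                entries.add(new Entry(line, ""));
            } else {
                // Split at the first space: "n02111889" | "Samoyed, Samoyede"
                entries.add(new Entry(line.substring(0, space), line.substring(space + 1).trim()));
            }
        }
        return entries;
    }
}
```

To read the downloaded file, pass `Files.readAllLines(Paths.get("build/pytorch_models/resnet18/synset.txt"), StandardCharsets.UTF_8)` to `Synset.parse`.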
2. Create a Translator
First, here is the preprocessing pipeline every image goes through, written with torchvision in Python:
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
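`ToTensor` and `Normalize` are plain arithmetic: `ToTensor` turns 0 to 255 pixel values stored in HWC (height, width, channel) order into 0 to 1 floats in CHW (channel-first) order, and `Normalize` computes `(x - mean[c]) / std[c]` per channel with the ImageNet statistics shown above. The sketch below (a hypothetical `Preprocess` class, stdlib only) reproduces just those two steps for an interleaved RGB buffer; Resize and CenterCrop are assumed to have been applied already.

```java
/** Hypothetical sketch of what ToTensor followed by Normalize computes for one RGB image. */
final class Preprocess {
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    private Preprocess() {}

    /**
     * @param rgb interleaved pixels in HWC order (R, G, B, R, G, B, ...), each value 0..255
     * @return CHW array where each value is (rgb / 255 - MEAN[c]) / STD[c]
     */
    static float[] toNormalizedChw(int[] rgb, int height, int width) {
        if (rgb.length != height * width * 3) {
            throw new IllegalArgumentException("expected " + height * width * 3 + " values");
        }
        int plane = height * width;
        float[] out = new float[3 * plane];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < 3; c++) {
                    // ToTensor: scale to [0, 1]
                    float v = rgb[(y * width + x) * 3 + c] / 255f;
                    // Normalize: per-channel standardization, written into channel plane c
                    out[c * plane + y * width + x] = (v - MEAN[c]) / STD[c];
                }
            }
        }
        return out;
    }
}
```

A black pixel therefore maps to about -2.12 in the red channel and a white pixel to about 2.25, which is the value range the pretrained ResNet expects.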
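Then create a DJL Translator that applies the same transforms and, with optApplySoftmax(true), converts the model's output scores into probabilities: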
Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
.addTransform(new Resize(256))
.addTransform(new CenterCrop(224, 224))
.addTransform(new ToTensor())
.addTransform(new Normalize(
new float[] {0.485f, 0.456f, 0.406f},
new float[] {0.229f, 0.224f, 0.225f}))
.optApplySoftmax(true)
.build();
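`optApplySoftmax(true)` matters because torchvision's ResNet-18 returns raw scores (logits) rather than probabilities. A minimal sketch of the softmax that turns those scores into probabilities (hypothetical `Softmax` class, not DJL code; subtracting the maximum is the usual numerical-stability trick and does not change the result):

```java
/** Hypothetical sketch of softmax over a vector of logits. */
final class Softmax {
    private Softmax() {}

    static double[] softmax(double[] logits) {
        // Subtract the largest logit so Math.exp cannot overflow
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        // Divide by the total so the outputs sum to 1
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
}
```

For instance, `Softmax.softmax(new double[] {1, 2, 3})` yields roughly `[0.090, 0.245, 0.665]`: the ordering of the scores is preserved, only their scale changes.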
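3. Load your own model
Loading is configured through a Criteria: setTypes declares the input and output types, optModelPath tells DJL where the model is, optTranslator supplies the Translator created above, and optProgress shows a progress bar.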
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelPath(Paths.get("build/pytorch_models/resnet18"))
.optTranslator(translator)
.optProgress(new ProgressBar()).build();
ZooModel<Image, Classifications> model = criteria.loadModel();
Loading: 100% |████████████████████████████████████████|
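`optModelPath` points at a directory, not a file. Our understanding, which is an assumption worth verifying against your DJL version, is that DJL looks for `<directory name>.pt` inside it (hence `resnet18/resnet18.pt`; `optModelName` can override the name) and that `ImageClassificationTranslator` reads its labels from `synset.txt` in the same directory. A hypothetical pre-flight check for that layout, stdlib only:

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical pre-flight check (not part of DJL). Assumes the layout
 * <dir>/<dir name>.pt plus <dir>/synset.txt described above.
 */
final class ModelDirCheck {
    private ModelDirCheck() {}

    /** Returns the expected file names that are missing from modelDir (empty if all present). */
    static List<String> missingFiles(Path modelDir) {
        List<String> missing = new ArrayList<>();
        String modelName = modelDir.getFileName().toString();
        for (String name : new String[] {modelName + ".pt", "synset.txt"}) {
            if (!Files.isRegularFile(modelDir.resolve(name))) {
                missing.add(name);
            }
        }
        return missing;
    }
}
```

Calling `ModelDirCheck.missingFiles(Paths.get("build/pytorch_models/resnet18"))` before `criteria.loadModel()` gives a clearer message than a failed load when a download went to the wrong place.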
4. Load the image
Image img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");
img.getWrappedImage(); // in a Java notebook this displays the image; a plain Java program can omit it
5. Run inference
Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications classifications = predictor.predict(img);
Print the result:
System.out.println(classifications);
[
class: "n02111889 Samoyed, Samoyede", probability: 0.94256
class: "n02114548 white wolf, Arctic wolf, Canis lupus tundrarum", probability: 0.02820
class: "n02111500 Great Pyrenees", probability: 0.01032
class: "n02120079 Arctic fox, white fox, Alopex lagopus", probability: 0.00412
class: "n02109961 Eskimo dog, husky", probability: 0.00279
]
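The printout lists only the five most probable of the 1,000 ImageNet classes. Selecting the top k entries from a probability vector can be sketched as follows (hypothetical `TopK` helper, stdlib only, not DJL's implementation); pairing each returned index with the matching line of `synset.txt` gives entries like the ones above.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical helper: indices of the k largest probabilities, highest first. */
final class TopK {
    private TopK() {}

    static List<Integer> topK(double[] probabilities, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            indices.add(i);
        }
        // List.sort is stable, so equal probabilities keep their original index order
        indices.sort(Comparator.comparingDouble((Integer i) -> probabilities[i]).reversed());
        return new ArrayList<>(indices.subList(0, Math.min(k, indices.size())));
    }
}
```

Sorting all 1,000 indices is cheap at this size; a bounded priority queue would only pay off for much larger label sets.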
V. Full Source Code
1. pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.13.0-SNAPSHOT</djl.version>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- PyTorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.0</version>
        </dependency>
    </dependencies>
</project>
2. java
package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args)
            throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Download the traced ResNet-18 model and its label file into build/pytorch_models/resnet18
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

        // Same preprocessing as the torchvision pipeline, plus softmax on the output
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .addTransform(new Resize(256))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                .optApplySoftmax(true)
                .build();

        // Load the model from the local directory
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("build/pytorch_models/resnet18"))
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();

        // The model and predictor hold native resources, so close them when done
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }
}
VI. Loading a Local Model with Model.load
package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
        Path modelDir = Paths.get("build/pytorch_models/resnet18");
        Model model = Model.newInstance("resnet");
        model.load(mod