Deep Learning for Java Programmers: Getting Started with DJL, Part 1
Posted by 编程圈子
Foreword: this article was compiled by the editors of 小常识网 (cha138.com) and covers getting started with DJL for Java programmers.
I. Introduction
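Official website: http://djl.ai/
DJL (Deep Java Library) is an open-source deep learning framework for Java. It is quick to set up, and Java programmers can pick it up easily.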
Below is a pseudocode example of inference:
// Assume the user takes a pre-trained model from the model zoo; they only need to load it
Criteria<Image, Classifications> criteria =
        Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION) // find an image classification model
                .setTypes(Image.class, Classifications.class)        // define input and output
                .optFilter("backbone", "resnet50")                   // choose network architecture
                .build();
try (ZooModel<Image, Classifications> model = criteria.loadModel()) {
    try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Image img = ImageFactory.getInstance().fromUrl("http://..."); // read image
        Classifications result = predictor.predict(img);
        // get the classification and probability
        ...
    }
}
A pseudocode example of training:
// Construct your neural network with built-in blocks:
// an MLP with 28 * 28 inputs, 10 outputs (digits 0-9) and hidden layers of 128 and 64 units
Block block = new Mlp(28 * 28, 10, new int[] {128, 64});
try (Model model = Model.newInstance("mlp")) { // create an empty model
    model.setBlock(block); // set the neural network on the model
    // Get the training and validation datasets (MNIST)
    Dataset trainingSet = new Mnist.Builder().setUsage(Usage.TRAIN) ... .build();
    Dataset validateSet = new Mnist.Builder().setUsage(Usage.TEST) ... .build();
    // Set up the training configuration, such as Initializer, Optimizer, Loss ...
    TrainingConfig config = setupTrainingConfig();
    try (Trainer trainer = model.newTrainer(config)) {
        /*
         * Configure the input shape based on the dataset to initialize the trainer.
         * The 1st axis is the batch axis; 1 is enough for initialization.
         * MNIST images are 28x28 grayscale, preprocessed into a 28 * 28 NDArray.
         */
        Shape inputShape = new Shape(1, 28 * 28);
        trainer.initialize(new Shape[] {inputShape});
        // epoch, modelDir and setupTrainingConfig() are assumed to be defined elsewhere
        EasyTrain.fit(trainer, epoch, trainingSet, validateSet);
    }
    // Save the model
    model.save(modelDir, "mlp");
}
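The input-shape comment above says each 28x28 MNIST image becomes a 28 * 28 vector with a leading batch axis of size 1. The following plain-Java sketch makes that concrete. It does not use DJL. It flattens a hypothetical 28x28 grayscale image row by row into a [1][784] array, which is the layout `new Shape(1, 28 * 28)` describes. The class and method names are illustrative only.

```java
public class FlattenDemo {
    // Flatten an H x W image into a [1][H*W] array (batch axis of size 1), row-major order
    static float[][] flattenWithBatchAxis(float[][] image) {
        int height = image.length;
        int width = image[0].length;
        float[][] batch = new float[1][height * width];
        for (int r = 0; r < height; r++) {
            for (int c = 0; c < width; c++) {
                // pixel (r, c) lands at index r * width + c
                batch[0][r * width + c] = image[r][c];
            }
        }
        return batch;
    }

    public static void main(String[] args) {
        // Hypothetical 28x28 image whose pixel value encodes its own position
        float[][] image = new float[28][28];
        for (int r = 0; r < 28; r++) {
            for (int c = 0; c < 28; c++) {
                image[r][c] = r * 28 + c;
            }
        }
        float[][] batch = flattenWithBatchAxis(image);
        System.out.println(batch.length + " x " + batch[0].length); // prints "1 x 784"
        System.out.println(batch[0][29]); // pixel (1, 1), prints 29.0
    }
}
```

Row-major order matters here. Pixel (r, c) must land at the same index for every image, or the first layer of the MLP would see inconsistent features.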
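DJL does not implement the computation itself. It runs on top of an existing deep learning engine such as PyTorch, MXNet or TensorFlow. The demo below uses the PyTorch engine.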
II. Prerequisites
- macOS
- Anaconda, or a plain Python 3 installation
- IntelliJ IDEA
III. Starting from Scratch
1. Create an mxnet environment with Anaconda
conda create -n mxnet
conda activate mxnet
conda install mxnet
The demo that follows does not necessarily need MXNet. This step mainly provides the Python environment used in this setup.
2. Create an empty Maven project in IntelliJ IDEA.
The pom.xml file:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.xundh</groupId>
    <artifactId>djl</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <djl.version>0.12.0</djl.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <!-- must be a PyTorch version supported by this DJL release -->
            <version>1.8.1</version>
        </dependency>
    </dependencies>
</project>
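3. Install the Python Community Edition plugin in IntelliJ IDEA.
4. Add the conda environment to the module in the Module settings.
This step is only needed if you run Python through conda.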
IV. A Simple Model
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.net.URL;
public class ResnetDemo {
    public static void main(String[] args) throws Exception {
        // Download the pre-trained (TorchScript-traced) ResNet-18 model hosted on AWS
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        // Download the class labels (one ImageNet synset per line)
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
        // Image preprocessing: resize, center-crop to 224x224, convert to a tensor, normalize with ImageNet mean/std
        Pipeline pipeline = new Pipeline();
        pipeline.add(new Resize(256))
                .add(new CenterCrop(224, 224))
                .add(new ToTensor())
                .add(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}));
        // The translator turns an Image into model input and the model output into Classifications (softmax applied)
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optApplySoftmax(true)
                .build();
        // Configure the model: use the download directory as a local model zoo
        System.setProperty("ai.djl.repository.zoo.location", "build/pytorch_models/resnet18");
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                // only search the model in local directory
                // "ai.djl.localmodelzoo:name of the model"
                .optArtifactId("ai.djl.localmodelzoo:resnet18")
                .optTranslator(translator)
                .optProgress(new ProgressBar()).build();
        // Locate demo.png on the classpath (a plain assert would be skipped unless -ea is set)
        URL path = ResnetDemo.class.getClassLoader().getResource("demo.png");
        if (path == null) {
            throw new IllegalStateException("demo.png not found on the classpath");
        }
        // Load the model, create a predictor and open the image; try-with-resources closes all three
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
             Predictor<Image, Classifications> predictor = model.newPredictor();
             InputStream in = new FileInputStream(new File(path.getPath()))) {
            Image img = ImageFactory.getInstance().fromInputStream(in);
            // Run inference and print the most likely classes
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }
}
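The preprocessing pipeline above applies Resize, CenterCrop, ToTensor and Normalize in turn. The plain-Java sketch below shows the per-pixel arithmetic without using DJL. It is a sketch under the following assumptions:
- `Resize(256)` scales the image to 256x256, and `CenterCrop(224, 224)` then keeps the central 224x224 region.
- `ToTensor` divides each 0 to 255 value by 255 and reorders the layout to channel-first.
- `Normalize` subtracts the per-channel ImageNet mean and divides by the standard deviation.

The class and method names are hypothetical.

```java
public class PreprocessSketch {
    // ImageNet per-channel mean and standard deviation, as passed to Normalize in the demo
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    // Top-left offset when cropping `crop` pixels from the middle of `size` pixels
    static int centerCropOffset(int size, int crop) {
        return (size - crop) / 2;
    }

    // Scale one RGB pixel (0..255 per channel) to [0, 1], then normalize each channel
    static float[] normalizePixel(int[] rgb) {
        float[] out = new float[3];
        for (int ch = 0; ch < 3; ch++) {
            float scaled = rgb[ch] / 255f;           // the scaling ToTensor applies
            out[ch] = (scaled - MEAN[ch]) / STD[ch]; // the per-channel Normalize step
        }
        return out;
    }

    public static void main(String[] args) {
        int offset = centerCropOffset(256, 224);
        System.out.println("crop starts at (" + offset + ", " + offset + ")"); // prints "crop starts at (16, 16)"
        float[] n = normalizePixel(new int[] {255, 128, 0});
        System.out.printf("R=%.3f G=%.3f B=%.3f%n", n[0], n[1], n[2]);
    }
}
```

The model was trained on images prepared this way, so skipping or changing the normalization constants would usually degrade its predictions.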
[Figure: project structure]
[Figure: a random picture of a dog, used as the input image]
Output:
Loading: 100% |========================================|
[
class: "n02091134 whippet", probability: 0.17714
class: "n02099712 Labrador retriever", probability: 0.05827
class: "n02091032 Italian greyhound", probability: 0.05465
class: "n02091831 Saluki, gazelle hound", probability: 0.05111
class: "n02090622 borzoi, Russian wolfhound", probability: 0.05043
]
Process finished with exit code 0
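The probabilities in the output come from `optApplySoftmax(true)`, and the printed result lists the five most likely classes. The plain-Java sketch below illustrates that last step with made-up scores and labels. Softmax turns one raw score per class into probabilities that sum to 1, and the indices are then sorted to pick the top k. It does not read the real model output or synset.txt, and all names and values are hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class SoftmaxTopK {
    // Numerically stable softmax: subtract the max score before exponentiating
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    // Indices of the k largest probabilities, highest first
    static List<Integer> topK(double[] probs, int k) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            idx.add(i);
        }
        idx.sort(Comparator.comparingDouble((Integer i) -> probs[i]).reversed());
        return idx.subList(0, Math.min(k, idx.size()));
    }

    public static void main(String[] args) {
        // Hypothetical labels and raw scores, for illustration only
        String[] labels = {"cat", "whippet", "borzoi", "Labrador retriever"};
        double[] logits = {0.5, 2.0, 1.0, 1.5};
        double[] probs = softmax(logits);
        // Prints whippet first, then Labrador retriever, then borzoi
        for (int i : topK(probs, 3)) {
            System.out.printf("class: \"%s\", probability: %.5f%n", labels[i], probs[i]);
        }
    }
}
```

With 1000 ImageNet classes, the probability mass is spread much more thinly than in this four-class toy. That spread helps explain why even the top class in the output above stays below 0.2.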