Deep Learning for Java Programmers: Getting Started with DJL, Part 1

Posted by 编程圈子

I. Introduction

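Official site: http://djl.ai/
DJL (Deep Java Library) is an open-source deep learning framework for Java. It is quick to set up and easy for Java programmers to pick up.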

Below is a pseudocode example of inference:

    // Assume the user relies on a pre-trained model from the model zoo; it only needs to be loaded
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .optApplication(Application.CV.OBJECT_DETECTION) // find object detection model
                    .setTypes(Image.class, Classifications.class) // define input and output
                    .optFilter("backbone", "resnet50") // choose network architecture
                    .build();

    // try-with-resources closes the model and the predictor when done
    try (ZooModel<Image, Classifications> model = criteria.loadModel()) {
        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image img = ImageFactory.getInstance().fromUrl("http://..."); // read image
            Classifications result = predictor.predict(img);

            // get the classification and probability
            ...
        }
    }

A pseudocode example of training:

    // Construct your neural network with built-in blocks
    Block block = new Mlp(28, 28);

    try (Model model = Model.newInstance("mlp")) { // Create an empty model
        model.setBlock(block); // set neural network to model

        // Get training and validation dataset (MNIST dataset)
        Dataset trainingSet = new Mnist.Builder().setUsage(Usage.TRAIN) ... .build();
        Dataset validateSet = new Mnist.Builder().setUsage(Usage.TEST) ... .build();

        // Setup training configurations, such as Initializer, Optimizer, Loss ...
        TrainingConfig config = setupTrainingConfig();
        try (Trainer trainer = model.newTrainer(config)) {
            /*
             * Configure input shape based on dataset to initialize the trainer.
             * 1st axis is batch axis, we can use 1 for initialization.
             * MNIST images are 28x28 grayscale and are pre-processed into a 28 * 28 NDArray.
             */
            Shape inputShape = new Shape(1, 28 * 28);
            trainer.initialize(new Shape[] {inputShape});

            EasyTrain.fit(trainer, epoch, trainingSet, validateSet);
        }

        // Save the model
        model.save(modelDir, "mlp");
    }

DJL still depends on an underlying deep learning framework as its engine. The demo below requires PyTorch.
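
The training pseudocode above initializes the trainer with new Shape(1, 28 * 28): one batch axis of size 1, followed by the 28x28 MNIST image flattened into 784 values. The following is a minimal plain-Java sketch of that flattening, using ordinary arrays instead of DJL's NDArray. The class and method names (FlattenSketch, toBatchOfOne) are hypothetical and only serve the illustration.

```java
/** Plain-Java sketch of flattening one image into a (1, height * width) batch; not DJL code. */
final class FlattenSketch {

    /** Copies a [height][width] image row by row into a single batch entry. */
    static float[][] toBatchOfOne(float[][] image) {
        int height = image.length;
        int width = image[0].length;
        // First axis is the batch axis (size 1), second axis holds all pixels.
        float[][] batch = new float[1][height * width];
        for (int y = 0; y < height; y++) {
            System.arraycopy(image[y], 0, batch[0], y * width, width);
        }
        return batch;
    }

    public static void main(String[] args) {
        float[][] image = new float[28][28]; // an all-zero stand-in for one MNIST digit
        float[][] batch = toBatchOfOne(image);
        System.out.println(batch.length + " x " + batch[0].length); // prints "1 x 784"
    }
}
```

The pixel at row y and column x ends up at index y * 28 + x, which is why a fully connected network such as the MLP can take the image as one flat vector.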

II. Preparing the environment

  • A Mac
  • Anaconda installed, or a Python 3 environment installed directly.
  • IntelliJ IDEA

III. Starting from scratch

1. Create a new mxnet environment with Anaconda

conda create -n mxnet
conda activate mxnet
conda install mxnet

The demo later in this article does not necessarily need mxnet, but a Python environment is required.

2. Create a new, empty Maven project in IDEA.

The pom.xml file:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xundh</groupId>
    <artifactId>djl</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <djl.version>0.12.0</djl.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.8.1</version>
        </dependency>
    </dependencies>
</project>

3. Install the Python Community Edition plugin

4. Add the conda environment to the Module

Do this step only if you use conda to run the Python environment.

IV. A simple model

import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;

import java.io.File;
import java.io.FileInputStream;
import java.net.URL;

public class ResnetDemo {
    public static void main(String[] args) throws Exception {
        // Download the pre-trained resnet model hosted on AWS
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        // Download the labels
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
        // Image preprocessing
        Pipeline pipeline = new Pipeline();
        pipeline.add(new Resize(256))
                .add(new CenterCrop(224, 224))
                .add(new ToTensor())
                .add(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}));

        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optApplySoftmax(true)
                .build();

        // Configure the model
        System.setProperty("ai.djl.repository.zoo.location", "build/pytorch_models/resnet18");
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                // only search the model in local directory
                // "ai.djl.localmodelzoo:name of the model"
                .optArtifactId("ai.djl.localmodelzoo:resnet18")
                .optTranslator(translator)
                .optProgress(new ProgressBar()).build();
        ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
        // Load a local image from the classpath
        URL path = ResnetDemo.class.getClassLoader().getResource("demo.png");

        assert path != null;
        File fs = new File(path.getPath());

        Image img = ImageFactory.getInstance().fromInputStream(new FileInputStream(fs));
        // Run inference
        Predictor<Image, Classifications> predictor = model.newPredictor();
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);
    }
}

Project structure:

I picked a random picture of a dog:
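
Before the picture reaches the network, the Pipeline in the demo code resizes it, crops it to 224x224, converts it with ToTensor, and normalizes it per channel. As far as I understand those last two steps, ToTensor scales pixel values from 0..255 to 0..1 and moves the channel axis to the front, and Normalize then computes (value - mean) / std for each channel. The sketch below reproduces that arithmetic with ordinary Java arrays. It is an illustration under those assumptions, not DJL's implementation, and the class and method names (NormalizeSketch, toNormalizedTensor) are hypothetical.

```java
/** Plain-Java sketch of the ToTensor + Normalize arithmetic; not DJL's implementation. */
final class NormalizeSketch {

    // Per-channel mean and standard deviation, copied from the Pipeline in the demo.
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    /**
     * Converts an image stored as [height][width][channel] with values 0..255
     * into [channel][height][width], scaled to 0..1 and then normalized.
     */
    static float[][][] toNormalizedTensor(int[][][] hwc) {
        int height = hwc.length;
        int width = hwc[0].length;
        int channels = hwc[0][0].length;
        float[][][] chw = new float[channels][height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < channels; c++) {
                    float scaled = hwc[y][x][c] / 255f;          // ToTensor-style scaling
                    chw[c][y][x] = (scaled - MEAN[c]) / STD[c];  // Normalize step
                }
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        // A hypothetical 1x1 image holding a single pure-white RGB pixel.
        int[][][] image = {{{255, 255, 255}}};
        float[][][] tensor = toNormalizedTensor(image);
        for (int c = 0; c < tensor.length; c++) {
            System.out.println("channel " + c + ": " + tensor[c][0][0]);
        }
    }
}
```

The mean and std values are the ones the demo passes to Normalize; they have to match whatever statistics the pre-trained model was trained with, otherwise the predictions degrade.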

Result of the run:

Loading:     100% |========================================|
[
	class: "n02091134 whippet", probability: 0.17714
	class: "n02099712 Labrador retriever", probability: 0.05827
	class: "n02091032 Italian greyhound", probability: 0.05465
	class: "n02091831 Saluki, gazelle hound", probability: 0.05111
	class: "n02090622 borzoi, Russian wolfhound", probability: 0.05043
]

Process finished with exit code 0
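
The probabilities in this output exist because the translator was built with optApplySoftmax(true), which turns the network's raw scores into values that sum to 1 across all classes. The printed list only shows the highest-ranked classes. The sketch below illustrates both ideas, softmax and top-k selection, in plain Java. It is not DJL's code; the logits, the shortened label list, and the names SoftmaxSketch, softmax and topK are made up for the illustration.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;

/** Plain-Java sketch of softmax and top-k selection; not DJL's implementation. */
final class SoftmaxSketch {

    /** Turns raw scores (logits) into probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        // Subtract the maximum first so Math.exp cannot overflow.
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] probs = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Returns the indices of the k largest probabilities, best first. */
    static int[] topK(double[] probs, int k) {
        return IntStream.range(0, probs.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> -probs[i]))
                .limit(k)
                .mapToInt(Integer::intValue)
                .toArray();
    }

    public static void main(String[] args) {
        // Hypothetical labels and logits, for illustration only.
        String[] labels = {"whippet", "Labrador retriever", "Italian greyhound"};
        double[] logits = {2.0, 0.9, 0.8};
        double[] probs = softmax(logits);
        for (int i : topK(probs, 2)) {
            System.out.printf("class: \"%s\", probability: %.5f%n", labels[i], probs[i]);
        }
    }
}
```

Because softmax spreads probability over every class the model knows, the top entries of a real run need not add up to anything close to 1, which matches the output above where the best guess sits near 0.18.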
