Java程序员学深度学习 DJL上手2 Springboot集成

Posted 编程圈子

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Java程序员学深度学习 DJL上手2 Springboot集成相关的知识,希望对你有一定的参考价值。

一、准备环境

  • windows
  • idea
  • jdk11
  • maven

本文使用 model-zoo models 运行目标检测任务。
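注意:这里的 model-zoo 指 DJL 自带的模型库(即下文 pom.xml 中的 ai.djl:model-zoo 与 mxnet-model-zoo 依赖),代码通过 Criteria 从中检索并加载预训练模型;它与来自新加坡的许靖宇建立的、包含许多深度学习模型的 Model Zoo 网站并不是同一个东西。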

二、新建项目

最终目录结构如下:

代码地址在:https://examples.javacodegeeks.com/djl-spring-boot-example/

三、pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- Build descriptor for the DJL + Spring Boot object detection demo. -->
    <groupId>ai.djl</groupId>
    <artifactId>image-object-detection</artifactId>
    <version>1.0.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <djl.version>0.11.0</djl.version>
        <!-- Single source of truth so the starter and the Maven plugin cannot drift apart. -->
        <spring-boot.version>2.3.4.RELEASE</spring-boot.version>
    </properties>

    <repositories>
        <repository>
            <id>djl.ai</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        </repository>
    </repositories>

    <dependencyManagement>
        <dependencies>
            <!-- DJL bill of materials: keeps the DJL artifacts on mutually compatible versions. -->
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
            <version>${spring-boot.version}</version>
        </dependency>

        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <!--
            Native MXNet binaries. The version must match the MXNet release that the chosen
            DJL version was built against (1.8.0 for DJL 0.11.0); the previous 1.6.0 belongs
            to a much older DJL line.
        -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <version>1.8.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.2.3</version>
            <scope>provided</scope>
        </dependency>
        <!--
            Lombok is an annotation processor needed only at compile time, hence "provided".
            A fixed version replaces the deprecated RELEASE meta-version so builds are reproducible.
        -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.20</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>
    <build>
        <finalName>${project.artifactId}</finalName>
        <plugins>
            <!-- Plugin version is kept identical to the Spring Boot version used by the application. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>${spring-boot.version}</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>repackage</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>

四、源代码

1. SpringBoot 入口

package com.jcg.djl;


import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * Bootstrap class of the image object detection demo.
 *
 * <p>Hands this class to {@link SpringApplication#run(Class, String...)} as the primary
 * configuration source.
 */
@SpringBootApplication
public class ImageObjectDetectionApplication {

    /**
     * Starts the Spring application.
     *
     * @param args command-line arguments, passed through to Spring unchanged
     */
    public static void main(String[] args) {
        final Class<?> primarySource = ImageObjectDetectionApplication.class;
        SpringApplication.run(primarySource, args);
    }
}

2. Controller

package com.jcg.djl;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.IOUtils;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;

/**
 * REST controller that runs DJL object detection on an uploaded image and serves the
 * annotated result.
 *
 * <p>{@code POST /upload} stores the upload in a temporary file, runs detection, writes a PNG
 * with the bounding boxes drawn on it and returns the URL of {@code GET /get}, which serves
 * that PNG. Only one result file exists, so each upload overwrites the previous result.
 */
@Slf4j
@RestController
public class ImageDetectController {

    /** Directory the annotated image is written to, relative to the working directory. */
    private static final Path OUTPUT_DIR = Paths.get("src/main/resources");

    /** The single result file, written by the upload flow and read back by {@code /get}. */
    private static final Path DETECTED_IMAGE = OUTPUT_DIR.resolve("detected.png");

    /**
     * Accepts an image upload and runs object detection on it.
     *
     * <p>The upload is copied to a server-generated temporary file instead of a path derived
     * from the client-supplied file name, so a crafted name such as {@code ../../x} cannot
     * write outside the intended location. The temporary file is removed afterwards.
     *
     * @param file the multipart part named {@code file}
     * @return 200 with the URL of the annotated image, or 400 when the upload is empty
     * @throws ModelException if the model cannot be found or loaded
     * @throws TranslateException if inference fails
     * @throws IOException if the upload or the result image cannot be read or written
     */
    // NOTE(review): the body is a URL string although "produces" declares image/png;
    // left unchanged to keep the endpoint contract as it was - confirm intent.
    @PostMapping(value = "/upload", produces = MediaType.IMAGE_PNG_VALUE)
    public ResponseEntity<String> diagnose(@RequestParam("file") MultipartFile file) throws ModelException, TranslateException, IOException {
        if (file.isEmpty()) {
            return ResponseEntity.badRequest().body("Uploaded file is empty");
        }

        Path imageFile = Files.createTempFile("djl-upload-", ".img");
        try {
            Files.write(imageFile, file.getBytes());
            return predict(imageFile);
        } finally {
            try {
                Files.deleteIfExists(imageFile);
            } catch (IOException e) {
                // Cleanup is best-effort: a leftover temp file must not fail the request.
                log.warn("Could not delete temporary upload {}", imageFile, e);
            }
        }
    }


    /**
     * Runs object detection on the given image file and stores the annotated result.
     *
     * @param imageFile path of the image to analyse
     * @return 200 with the URL from which the annotated image can be downloaded
     * @throws IOException if the image cannot be read or the result cannot be written
     * @throws ModelException if no model matches the criteria or it cannot be loaded
     * @throws TranslateException if inference fails
     */
    public ResponseEntity<String> predict(Path imageFile) throws IOException, ModelException, TranslateException {
        Image img = ImageFactory.getInstance().fromFile(imageFile);

        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone", "resnet50")
                        .optProgress(new ProgressBar())
                        .build();

        // NOTE(review): the model is loaded and closed on every call; consider loading it
        // once and creating only the predictor per request.
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
             Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detection = predictor.predict(img);
            return saveBoundingBoxImage(img, detection);
        }
    }


    /**
     * Draws the detected bounding boxes on a copy of the image and saves it as a PNG.
     *
     * @param img the original image
     * @param detection the detection result to draw
     * @return 200 with the URL of the {@code /get} endpoint
     * @throws IOException if the output directory or file cannot be written
     */
    private ResponseEntity<String> saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Files.createDirectories(OUTPUT_DIR);

        // Make image copy with alpha channel because original image was jpg
        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
        newImage.drawBoundingBoxes(detection);

        // OpenJDK can't save jpg with alpha channel.
        // The stream is closed explicitly so the file handle is released after saving.
        try (OutputStream out = Files.newOutputStream(DETECTED_IMAGE)) {
            newImage.save(out, "png");
        }
        log.info("Detected objects image has been saved in:{}" , DETECTED_IMAGE);

        // The leading slash keeps the URL correct when the application has a context path.
        String fileDownloadUri = ServletUriComponentsBuilder.fromCurrentContextPath()
                .path("/get")
                .toUriString();
        return ResponseEntity.ok(fileDownloadUri);
    }


    /**
     * Returns the most recently generated annotated image.
     *
     * <p>The bytes are read from the same file the upload flow writes. A classpath lookup
     * would resolve against the runtime classpath, which does not pick up files written to
     * the source tree while the application is running.
     *
     * @return the PNG bytes
     * @throws IOException if no image has been generated yet or the file cannot be read
     */
    @GetMapping(
            value = "/get",
            produces = MediaType.IMAGE_PNG_VALUE
    )
    public @ResponseBody
    byte[] getImageWithMediaType() throws IOException {
        return Files.readAllBytes(DETECTED_IMAGE);
    }


}

3. application.yml

# DJL settings grouped under the `djl` prefix.
# NOTE(review): presumably read by a DJL Spring Boot starter's auto-configuration - confirm;
# the controller shown in this article builds its Criteria in code instead.
djl:
  # Application type
  application-type: OBJECT_DETECTION
  # Input data type; some models support several input formats
  input-class: java.awt.image.BufferedImage
  # Output data type
  output-class: ai.djl.modality.cv.output.DetectedObjects
  # Filter used to select the model
  model-filter:
    size: 512
  #  backbone: mobilenet1.0
  # Override the existing input/output configuration
  arguments:
    threshold: 0.5 # only show predictions scoring greater than or equal to 0.5

五、使用方式

1. 运行程序:

mvn spring-boot:run

2. 打开网页

http://localhost:8080

3. 上传要识别的图片

4. 下载识别结果

打开地址: http://localhost:8080/get

以上是关于Java程序员学深度学习 DJL上手2 Springboot集成的主要内容,如果未能解决你的问题,请参考以下文章

Java程序员学深度学习 DJL上手2 Springboot集成

Java程序员学深度学习 DJL上手7 使用Pytorch引擎

Java程序员学深度学习 DJL上手4 NDArray基本操作

Java程序员学深度学习 DJL上手4 NDArray基本操作

Java程序员学深度学习 DJL上手5 训练自己的模型

Java程序员学深度学习 DJL上手5 训练自己的模型