Java程序员学深度学习 DJL上手2 Springboot集成
Posted 编程圈子
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Java程序员学深度学习 DJL上手2 Springboot集成相关的知识,希望对你有一定的参考价值。
Java程序员学深度学习 DJL上手2 Springboot集成
一、准备环境
- windows
- idea
- JDK 11（pom.xml 中编译级别设置为 1.8，JDK 11 可以正常编译）
- maven
本文使用 model-zoo models 运行目标检测任务。
model-zoo 是 DJL（Deep Java Library）提供的模型库，其中包含许多预训练的深度学习模型；本例通过 mxnet-model-zoo 使用基于 MXNet 引擎的目标检测模型。
二、新建项目
最终目录结构如下:
代码地址在:https://examples.javacodegeeks.com/djl-spring-boot-example/
三、pom.xml
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>ai.djl</groupId>
    <artifactId>image-object-detection</artifactId>
    <version>1.0.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <djl.version>0.11.0</djl.version>
        <!-- Single source of truth so the starter and the Maven plugin stay on the same release. -->
        <spring-boot.version>2.3.4.RELEASE</spring-boot.version>
    </properties>

    <repositories>
        <!-- Only needed when djl.version is a -SNAPSHOT; DJL releases resolve from Maven Central. -->
        <repository>
            <id>djl.ai</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        </repository>
    </repositories>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <!-- Also brings logback-classic (via spring-boot-starter-logging), so it is not redeclared. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
            <version>${spring-boot.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <!-- Native MXNet binaries. Must match the MXNet version the DJL engine was built against
             (1.8.0 for DJL 0.11.0 - re-check the DJL release notes whenever djl.version changes). -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <version>1.8.0</version>
            <scope>runtime</scope>
        </dependency>
        <!-- Annotation processor only (@Slf4j); pinned for reproducible builds, not shipped at runtime. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.30</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>

    <build>
        <finalName>${project.artifactId}</finalName>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>${spring-boot.version}</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>repackage</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
四、源代码
1. SpringBoot 入口
package com.jcg.djl;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point for the DJL object-detection demo.
 *
 * <p>Component scanning starts from this class's package, which picks up the REST controller.
 */
@SpringBootApplication
public class ImageObjectDetectionApplication {

    /**
     * Boots the Spring application context with this class as its primary source.
     *
     * @param args command-line arguments forwarded unchanged to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application =
                new SpringApplication(ImageObjectDetectionApplication.class);
        application.run(args);
    }
}
2. Controller
package com.jcg.djl;
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.IOUtils;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
/**
 * REST endpoints that run DJL object detection on an image and serve the annotated result.
 *
 * <p>{@code POST /upload} runs detection and returns the URL of {@code GET /get}, which returns
 * the most recent annotated image as PNG.
 */
@Slf4j
@RestController
public class ImageDetectController {

    /**
     * Filesystem directory the annotated image is written to. It is read back from this same path,
     * not from the classpath. At runtime the classpath is target/classes or the packaged jar, so a
     * file written under src/main/resources would never be visible there.
     */
    private static final Path OUTPUT_DIR = Paths.get("build/output");

    /** Name of the annotated result file; every detection overwrites it. */
    private static final String OUTPUT_FILE_NAME = "detected.png";

    /**
     * Runs object detection on an uploaded image.
     *
     * <p>The upload is decoded straight from the request stream. The client-supplied file name is
     * never used as a filesystem path, because it could contain {@code ..} segments.
     *
     * @param file the uploaded image (multipart field {@code file})
     * @return 200 with the absolute URL of the {@code /get} endpoint serving the annotated image
     * @throws ModelException if no matching model can be found or loaded
     * @throws TranslateException if inference fails
     * @throws IOException if the upload cannot be decoded or the result cannot be written
     */
    @PostMapping(value = "/upload", produces = MediaType.TEXT_PLAIN_VALUE)
    public ResponseEntity<String> diagnose(@RequestParam("file") MultipartFile file)
            throws ModelException, TranslateException, IOException {
        Image img;
        try (InputStream in = file.getInputStream()) {
            img = ImageFactory.getInstance().fromInputStream(in);
        }
        return detect(img);
    }

    /**
     * Runs object detection on an image file that is already on disk.
     *
     * @param imageFile path of the image to analyse
     * @return 200 with the absolute URL of the {@code /get} endpoint serving the annotated image
     * @throws IOException if the image cannot be read or the result cannot be written
     * @throws ModelException if no matching model can be found or loaded
     * @throws TranslateException if inference fails
     */
    public ResponseEntity<String> predict(Path imageFile)
            throws IOException, ModelException, TranslateException {
        return detect(ImageFactory.getInstance().fromFile(imageFile));
    }

    /** Loads the object-detection model, runs it on {@code img} and saves the annotated result. */
    private ResponseEntity<String> detect(Image img)
            throws ModelException, TranslateException, IOException {
        Criteria<Image, DetectedObjects> criteria = Criteria.builder()
                .optApplication(Application.CV.OBJECT_DETECTION)
                .setTypes(Image.class, DetectedObjects.class)
                .optFilter("backbone", "resnet50")
                .optProgress(new ProgressBar())
                .build();
        // NOTE(review): the model is located and loaded on every request. Consider loading the
        // ZooModel once and creating only a Predictor per request.
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
                Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detection = predictor.predict(img);
            return saveBoundingBoxImage(img, detection);
        }
    }

    /**
     * Draws the detections onto a copy of {@code img}, writes it as PNG to the output file and
     * returns the download URL.
     */
    private ResponseEntity<String> saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Files.createDirectories(OUTPUT_DIR);
        // Make an image copy with an alpha channel because the original image may be a JPEG
        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
        newImage.drawBoundingBoxes(detection);
        Path imagePath = OUTPUT_DIR.resolve(OUTPUT_FILE_NAME);
        // OpenJDK can't save JPEG with an alpha channel, so PNG is used; close the stream so the
        // file handle is released (and not left locked on Windows).
        try (OutputStream out = Files.newOutputStream(imagePath)) {
            newImage.save(out, "png");
        }
        log.info("Detected objects image has been saved in:{}", imagePath);
        // Leading slash matters: path() appends verbatim, so "get" would become "/ctxget" when the
        // application runs under a non-root context path.
        String fileDownloadUri = ServletUriComponentsBuilder.fromCurrentContextPath()
                .path("/get")
                .toUriString();
        return ResponseEntity.ok(fileDownloadUri);
    }

    /**
     * Returns the most recently written annotated image.
     *
     * @return PNG bytes of the output file
     * @throws IOException if no detection has run yet or the file cannot be read
     */
    @GetMapping(value = "/get", produces = MediaType.IMAGE_PNG_VALUE)
    public @ResponseBody byte[] getImageWithMediaType() throws IOException {
        return Files.readAllBytes(OUTPUT_DIR.resolve(OUTPUT_FILE_NAME));
    }
}
3. application.yml
# NOTE(review): nothing in the controller shown here reads these properties, because it builds its
# Criteria in code (OBJECT_DETECTION, backbone=resnet50). pom.xml also declares no DJL Spring Boot
# starter. Presumably these keys target such a starter; confirm before relying on them. Also note
# that input-class names BufferedImage, while the controller works with ai.djl.modality.cv.Image.
djl:
  # Application type
  application-type: OBJECT_DETECTION
  # Input data class; some models accept several input formats
  input-class: java.awt.image.BufferedImage
  # Output data class
  output-class: ai.djl.modality.cv.output.DetectedObjects
  # Filter used to select the model
  model-filter:
    size: 512
    # backbone: mobilenet1.0
  # Override the model's default input/output (translator) arguments
  arguments:
    threshold: 0.5 # only show predictions with a score >= 0.5
五、使用方式
1. 运行程序:
mvn spring-boot:run
2. 打开网页
http://localhost:8080
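3. 上传要识别的图片：以 multipart 表单字段 file 向 /upload 提交 POST 请求（例如 curl -F "file=@dog.jpg" http://localhost:8080/upload），接口返回结果图片的下载地址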
4. 下载识别结果
打开地址: http://localhost:8080/get
以上是关于Java程序员学深度学习 DJL上手2 Springboot集成的主要内容,如果未能解决你的问题,请参考以下文章
Java程序员学深度学习 DJL上手2 Springboot集成
Java程序员学深度学习 DJL上手7 使用Pytorch引擎
Java程序员学深度学习 DJL上手4 NDArray基本操作