Java程序员学深度学习 DJL上手3 创建神经网络
Posted 编程圈子
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Java程序员学深度学习 DJL上手3 创建神经网络相关的知识,希望对你有一定的参考价值。
Java程序员学深度学习 DJL上手3 创建神经网络
本文目的是创建一个图像分类模型 。由于本系列文章主要关注的是Java编写深度学习应用,而且我本人对深度学习原理也不熟悉,所以本系列文章不过多讲解深度学习的原理。
一、 准备环境
- windows
- 安装idea
- 安装maven
本文使用 MNIST 手写字体数据集,数据集包含0-9的黑白手写数字,每个图像大小28*28。使用 MNIST 的手写字体数据集对神经网络进行训练和预测,可以认为是深度神经网络的“Hello World”入门程序。
本文并没有实现完整的训练、识别,只是建了一个神经网络层。
二、 多层感知器的简单概念
本文将要创建的神经网络称为MLP,它是最简单的神经网络,而其它一些神经网络,如:误差反向传播(BP)、概率神经网络、卷积神经网络(CNN,适用于图像识别)、时间神经网络(LSTM,适用于语音识别)等,都可以认为是MLP的变种。
MLP是分层来处理数据,第一层是包含要输入数据的输入层,最后一层是产生结果的输出层。当中的层称为隐藏层。
下面是一个MLP示例,包含大小为3的输入层、大小3的单个隐藏层和大小为2的输出层。
一个MLP的隐藏层数据、大小通常要经过试验测试来确定。每对层之间都有一个线性操作(称为全连接操作,是个矩阵乘法运算);每个线性操作后面还要有非线性的激活功能。
三、确定输入和输出大小
MLP模型 使用一维矢量作为输入和输出。比如本文将要使用的MNIST数据集里,图像大小是28*28,每个像素点用0和1即可表示出来,下面定义的变量:
long inputSize = 28*28;
long outputSize = 10;
输出变量10表示 有每个图像有10个可能的分类。
四、创建顺序块(SequentialBlock)
1. NDArray和NDList
NDArray是深度学习的核心数据类型。NDArray与Numpy里的多维数组有点像,代表多维、固定大小的均匀阵列。
NDList是一个NDArrays的列表,可以有不同的大小和数据类型。
2. 块(block)
DJL中由块来组成单个操作或神经网络,它可以表示单个操作、或g作为神经网络的一部分,甚至用来拼成整个神经网络。
这里调用API 创建一个顺序块。
SequentialBlock block = new SequentialBlock();
五、将块添加到顺序块
MLP 分为多个层,每个层由线性块和非线性的激活函数组成。这里使用常见的 ReLU 激活函数。
第一层和最后一层有固定的大小,中间层则取决于试验和经验。
block.add(Blocks.batchFlattenBlock(inputSize));
block.add(Linear.builder().setUnits(128).build());
block.add(Activation::relu);
block.add(Linear.builder().setUnits(64).build());
block.add(Activation::relu);
block.add(Linear.builder().setUnits(outputSize).build());
log.info(block.ToString());
通过设置的断点,可以查看block的结构:
六、源代码
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.xundh</groupId>
<artifactId>djl</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<djl.version>0.9.0</djl.version>
</properties>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>$djl.version</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>$djl.version</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>$djl.version</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>1.8.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.projectlombok/lombok -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.20</version>
<scope>provided</scope>
</dependency>
</dependencies>
</project>
learn1.java
package com.jcg.djl;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class Learn1
public static void main(String[] args)
long inputSize = 28*28;
long outputSize = 10;
SequentialBlock block = new SequentialBlock();
block.add(Blocks.batchFlattenBlock(inputSize));
block.add(Linear.builder().setUnits(128).build());
block.add(Activation::relu);
block.add(Linear.builder().setUnits(64).build());
block.add(Activation::relu);
block.add(Linear.builder().setUnits(outputSize).build());
log.info(block.toString());
以上是关于Java程序员学深度学习 DJL上手3 创建神经网络的主要内容,如果未能解决你的问题,请参考以下文章
Java程序员学深度学习 DJL上手4 NDArray基本操作
Java程序员学深度学习 DJL上手4 NDArray基本操作