Java程序员学深度学习 DJL上手3 创建神经网络

Posted 编程圈子

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Java程序员学深度学习 DJL上手3 创建神经网络相关的知识,希望对你有一定的参考价值。

Java程序员学深度学习 DJL上手3 创建神经网络

本文目的是创建一个图像分类模型 。由于本系列文章主要关注的是Java编写深度学习应用,而且我本人对深度学习原理也不熟悉,所以本系列文章不过多讲解深度学习的原理。

一、 准备环境

  • windows
  • 安装idea
  • 安装maven

本文使用 MNIST 手写字体数据集,数据集包含0-9的黑白手写数字,每个图像大小28*28。使用 MNIST 的手写字体数据集对神经网络进行训练和预测,可以认为是深度神经网络的“Hello World”入门程序。

本文并没有实现完整的训练、识别,只是建了一个神经网络层。

二、 多层感知器的简单概念

本文将要创建的神经网络称为MLP,它是最简单的神经网络,而其它一些神经网络,如:误差反向传播(BP)、概率神经网络、卷积神经网络(CNN,适用于图像识别)、时间神经网络(LSTM,适用于语音识别)等,都可以认为是MLP的变种。

MLP是分层来处理数据,第一层是包含要输入数据的输入层,最后一层是产生结果的输出层。当中的层称为隐藏层。

下面是一个MLP示例,包含大小为3的输入层、大小3的单个隐藏层和大小为2的输出层。

一个MLP的隐藏层数据、大小通常要经过试验测试来确定。每对层之间都有一个线性操作(称为全连接操作,是个矩阵乘法运算);每个线性操作后面还要有非线性的激活功能。

三、确定输入和输出大小

MLP模型 使用一维矢量作为输入和输出。比如本文将要使用的MNIST数据集里,图像大小是28*28,每个像素点用0和1即可表示出来,下面定义的变量:

long inputSize = 28*28;
long outputSize = 10;

输出变量10表示 有每个图像有10个可能的分类。

四、创建顺序块(SequentialBlock)

1. NDArray和NDList

NDArray是深度学习的核心数据类型。NDArray与Numpy里的多维数组有点像,代表多维、固定大小的均匀阵列。
NDList是一个NDArrays的列表,可以有不同的大小和数据类型。

2. 块(block)

DJL中由块来组成单个操作或神经网络,它可以表示单个操作、或g作为神经网络的一部分,甚至用来拼成整个神经网络。
这里调用API 创建一个顺序块。

SequentialBlock block = new SequentialBlock();

五、将块添加到顺序块

MLP 分为多个层,每个层由线性块和非线性的激活函数组成。这里使用常见的 ReLU 激活函数。

第一层和最后一层有固定的大小,中间层则取决于试验和经验。

block.add(Blocks.batchFlattenBlock(inputSize));
block.add(Linear.builder().setUnits(128).build());
block.add(Activation::relu);
block.add(Linear.builder().setUnits(64).build());
block.add(Activation::relu);
block.add(Linear.builder().setUnits(outputSize).build());
log.info(block.ToString());

通过设置的断点,可以查看block的结构:

六、源代码

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xundh</groupId>
    <artifactId>djl</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <djl.version>0.9.0</djl.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>$djl.version</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>$djl.version</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>$djl.version</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.8.1</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.projectlombok/lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.20</version>
            <scope>provided</scope>
        </dependency>

    </dependencies>
</project>

learn1.java

package com.jcg.djl;

import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class Learn1 
    public static void main(String[] args)
        long inputSize = 28*28;
        long outputSize = 10;

        SequentialBlock block = new SequentialBlock();

        block.add(Blocks.batchFlattenBlock(inputSize));
        block.add(Linear.builder().setUnits(128).build());
        block.add(Activation::relu);
        block.add(Linear.builder().setUnits(64).build());
        block.add(Activation::relu);
        block.add(Linear.builder().setUnits(outputSize).build());

        log.info(block.toString());
    


以上是关于Java程序员学深度学习 DJL上手3 创建神经网络的主要内容,如果未能解决你的问题,请参考以下文章

Java程序员学深度学习 DJL上手4 NDArray基本操作

Java程序员学深度学习 DJL上手4 NDArray基本操作

Java程序员学深度学习 DJL上手5 训练自己的模型

Java程序员学深度学习 DJL上手5 训练自己的模型

Java程序员学深度学习 DJL上手7 使用Pytorch引擎

Java程序员学深度学习 DJL上手1