Java Study (Day 37)
Posted by 言山兮尺川
Study source: 日撸 Java 三百行(81-90天,CNN 卷积神经网络), 闵帆's blog on CSDN
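Preface
The code in this post comes from the CSDN article: 日撸 Java 三百行(81-90天,CNN 卷积神经网络)
I will use this code to build a deeper understanding of CNNs.
Convolutional Neural Networks (Code)
I. Reading and Storing the Dataset
1. Dataset description
Here is a brief description of the dataset we need to read.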
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
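At first glance this is just a collection of 0s and 1s. But if we picture these numbers as an image, a few tools can render them as a picture.
The image is $28 \times 28$ pixels, so a row holds 784 pixel values. But doesn't the row have one extra number at the end? What does that number mean? Look closely at the image: it resembles the digit '0'. To check this idea, let's look at another row.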
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3
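Although the digit is not written neatly, it can still be recognized as '3', and the extra number is exactly 3. The conclusion: each row of the dataset represents one image, with 0 and 1 marking its black and white pixels, and the last number in the row is the digit shown in the image.
So reading this dataset means storing each image's pixels in an array whose size equals the image size, and storing the digit shown in the image in a separate value that serves as the image's label.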
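The row layout described above can be sketched as follows. This is a minimal illustration, not the article's Dataset class: the class name PixelRowDemo and its methods are hypothetical, and the toy row is 3 x 3 instead of 28 x 28 so that it fits on one line.

```java
// Hypothetical helper for illustration; none of these names come from the original code.
class PixelRowDemo {
    /** Returns every value except the last one, i.e. the pixel values of the row. */
    static int[] parsePixels(String row) {
        String[] parts = row.trim().split(",");
        int[] pixels = new int[parts.length - 1];
        for (int i = 0; i < pixels.length; i++) {
            pixels[i] = Integer.parseInt(parts[i].trim());
        }
        return pixels;
    }

    /** Returns the last value of the row, i.e. the label. */
    static int parseLabel(String row) {
        String[] parts = row.trim().split(",");
        return Integer.parseInt(parts[parts.length - 1].trim());
    }

    /** Renders the pixels as text, one image row per line: '#' for 1, '.' for 0. */
    static String render(int[] pixels, int width) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < pixels.length; i++) {
            sb.append(pixels[i] == 1 ? '#' : '.');
            if ((i + 1) % width == 0) {
                sb.append('\n');
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Toy 3 x 3 "image" with label 7; a real row holds 784 pixels plus one label.
        String row = "0,1,0,1,1,1,0,1,0,7";
        System.out.print(render(parsePixels(row), 3));
        System.out.println("label = " + parseLabel(row));
    }
}
```

Calling render with a width of 28 on one of the real rows should reproduce the digit shape as text.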
2. Code
package cnn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Manage the dataset.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public class Dataset {
    /**
     * All instances organized by a list.
     */
    private List<Instance> instances;

    /**
     * The label index.
     */
    private int labelIndex;

    /**
     * The max label (label start from 0).
     */
    private double maxLabel = -1;

    /**
     * **********************
     * The first constructor.
     * **********************
     */
    public Dataset() {
        labelIndex = -1;
        instances = new ArrayList<>();
    } // Of the first constructor

    /**
     * **********************
     * The second constructor.
     *
     * @param paraFilename   The filename.
     * @param paraSplitSign  Often comma.
     * @param paraLabelIndex Often the last column.
     * **********************
     */
    public Dataset(String paraFilename, String paraSplitSign, int paraLabelIndex) {
        instances = new ArrayList<>();
        labelIndex = paraLabelIndex;

        File tempFile = new File(paraFilename);
        try {
            BufferedReader tempReader = new BufferedReader(new FileReader(tempFile));
            String tempLine;
            while ((tempLine = tempReader.readLine()) != null) {
                String[] tempDatum = tempLine.split(paraSplitSign);
                if (tempDatum.length == 0) {
                    continue;
                } // Of if

                double[] tempData = new double[tempDatum.length];
                for (int i = 0; i < tempDatum.length; i++) {
                    tempData[i] = Double.parseDouble(tempDatum[i]);
                }
                Instance tempInstance = new Instance(tempData);
                append(tempInstance);
            } // Of while
            tempReader.close();
        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("Unable to load " + paraFilename);
            System.exit(0);
        } // Of try
    } // Of the second constructor

    /**
     * **********************
     * Append an instance.
     *
     * @param paraInstance The given record.
     * **********************
     */
    public void append(Instance paraInstance) {
        instances.add(paraInstance);
    } // Of append

    /**
     * **********************
     * Append an instance specified by double values.
     * **********************
     */
    public void append(double[] paraAttributes, Double paraLabel) {
        instances.add(new Instance(paraAttributes, paraLabel));
    } // Of append

    /**
     * **********************
     * Getter.
     * **********************
     */
    public Instance getInstance(int paraIndex) {
        return instances.get(paraIndex);
    } // Of getInstance

    /**
     * **********************
     * Getter.
     * **********************
     */
    public int size() {
        return instances.size();
    } // Of size

    /**
     * **********************
     * Getter.
     * **********************
     */
    public double[] getAttributes(int paraIndex) {
        return instances.get(paraIndex).getAttributes();
    } // Of getAttrs

    /**
     * **********************
     * Getter.
     * **********************
     */
    public Double getLabel(int paraIndex) {
        return instances.get(paraIndex).getLabel();
    } // Of getLabel

    /**
     * **********************
     * Unit test.
     * **********************
     */
    public static void main(String[] args) {
        Dataset tempData = new Dataset("D:/Work/Data/sampledata/train.format", ",", 784);
        Instance tempInstance = tempData.getInstance(0);
        System.out.println("The first instance is: " + tempInstance);
        System.out.println("The first instance label is: " + tempInstance.label);

        tempInstance = tempData.getInstance(1);
        System.out.println("The second instance is: " + tempInstance);
        System.out.println("The second instance label is: " + tempInstance.label);
    } // Of main

    /**
     * **********************
     * An instance.
     * **********************
     */
    public class Instance {
        /**
         * Conditional attributes.
         */
        private double[] attributes;

        /**
         * Label.
         */
        private Double label;

        /**
         * **********************
         * The first constructor.
         * **********************
         */
        private Instance(double[] paraAttrs, Double paraLabel) {
            attributes = paraAttrs;
            label = paraLabel;
        } // Of the first constructor

        /**
         * **********************
         * The second constructor.
         * **********************
         */
        public Instance(double[] paraData) {
            if (labelIndex == -1) {
                // No label
                attributes = paraData;
            } else {
                label = paraData[labelIndex];
                if (label > maxLabel) {
                    // It is a new label
                    maxLabel = label;
                } // Of if

                if (labelIndex == 0) {
                    // The first column is the label
                    attributes = Arrays.copyOfRange(paraData, 1, paraData.length);
                } else {
                    // The last column is the label
                    attributes = Arrays.copyOfRange(paraData, 0, paraData.length - 1);
                } // Of if
            } // Of if
        } // Of the second constructor

        /**
         * **********************
         * Getter.
         * **********************
         */
        public double[] getAttributes() {
            return attributes;
        } // Of getAttributes

        /**
         * **********************
         * Getter.
         * **********************
         */
        public Double getLabel() {
            if (labelIndex == -1) {
                return null;
            }
            return label;
        } // Of getLabel

        /**
         * **********************
         * toString.
         * **********************
         */
        public String toString() {
            return Arrays.toString(attributes) + ", " + label;
        } // Of toString
    } // Of class Instance
} // Of class Dataset
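3. Screenshot of the run
II. Basic Operations on Kernel Size
1. Operations
These operations work on the size of a convolution kernel, that is, on its width and height.
One operation divides the width and height by two integers and throws an error if either division leaves a remainder. For example: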
(4, 12) / (2, 3) -> (2, 4)
(2, 2) / (4, 6) -> Error
The other operation subtracts two integers from the width and height and then adds 1. For example:
(4, 6) - (2, 2) + 1 -> (3, 5)
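One plausible reading of why these two operations matter: subtract-then-add-1 gives the output size of a convolution that only places the kernel where it fits entirely, and exact division gives the output size of non-overlapping pooling. The sketch below works through that arithmetic. It is an assumption-laden illustration: the class LayerSizeDemo, the 5 x 5 kernel, and the 2 x 2 pooling scale are hypothetical and chosen only to go with the 28 x 28 images.

```java
// Hypothetical demo; sizes are {width, height} pairs and all names are illustrative.
class LayerSizeDemo {
    /** (size - kernel + append) in both dimensions. */
    static int[] subtract(int[] size, int[] kernel, int append) {
        return new int[] { size[0] - kernel[0] + append, size[1] - kernel[1] + append };
    }

    /** Exact division in both dimensions; rejects sizes that do not divide evenly. */
    static int[] divide(int[] size, int[] scale) {
        if (size[0] % scale[0] != 0 || size[1] % scale[1] != 0) {
            throw new IllegalArgumentException("Size is not divisible by the scale");
        }
        return new int[] { size[0] / scale[0], size[1] / scale[1] };
    }

    public static void main(String[] args) {
        int[] input = { 28, 28 };
        // Assumed 5 x 5 kernel: 28 - 5 + 1 = 24 in each dimension.
        int[] afterConvolution = subtract(input, new int[] { 5, 5 }, 1);
        // Assumed 2 x 2 pooling: 24 / 2 = 12 in each dimension.
        int[] afterPooling = divide(afterConvolution, new int[] { 2, 2 });
        System.out.println(afterConvolution[0] + " x " + afterConvolution[1]);
        System.out.println(afterPooling[0] + " x " + afterPooling[1]);
    }
}
```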
2. Code
package cnn;

/**
 * The size of a convolution core.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public class Size {
    /**
     * Cannot be changed after initialization.
     */
    public final int width;

    /**
     * Cannot be changed after initialization.
     */
    public final int height;

    /**
     * **********************
     * The first constructor.
     *
     * @param paraWidth  The given width.
     * @param paraHeight The given height.
     * **********************
     */
    public Size(int paraWidth, int paraHeight) {
        width = paraWidth;
        height = paraHeight;
    } // Of the first constructor

    /**
     * **********************
     * Divide a scale with another one. For example (4, 12) / (2, 3) = (2, 4).
     *
     * @param paraScaleSize The given scale size.
     * @return The new size.
     * **********************
     */
    public Size divide(Size paraScaleSize) {
        int resultWidth = width / paraScaleSize.width;
        int resultHeight = height / paraScaleSize.height;
        if (resultWidth * paraScaleSize.width != width || resultHeight * paraScaleSize.height != height) {
            throw new RuntimeException("Unable to divide " + this + " with " + paraScaleSize);
        }
        return new Size(resultWidth, resultHeight);
    } // Of divide

    /**
     * **********************
     * Subtract a scale with another one, and add a value. For example (4, 12) -
     * (2, 3) + 1 = (3, 10).
     *
     * @param paraScaleSize The given scale size.
     * @param paraAppend    The appended size to both dimensions.
     * @return The new size.
     * **********************
     */
    public Size subtract(Size paraScaleSize, int paraAppend) {
        int resultWidth = width - paraScaleSize.width + paraAppend;
        int resultHeight = height - paraScaleSize.height + paraAppend;
        return new Size(resultWidth, resultHeight);
    } // Of subtract

    public String toString() {
        String resultString = "(" + width + ", " + height + ")";
        return resultString;
    } // Of toString

    /**
     * **********************
     * Unit test.
     * **********************
     */
    public static void main(String[] args) {
        Size tempSize1 = new Size(4, 6);
        Size tempSize2 = new Size(2, 2);
        System.out.println("" + tempSize1 + " divide " + tempSize2 + " = " + tempSize1.divide(tempSize2));

        try {
            System.out.println("" + tempSize2 + " divide " + tempSize1 + " = " + tempSize2.divide(tempSize1));
        } catch (Exception ee) {
            System.out.println("Error is :" + ee);
        } // Of try

        System.out.println("" + tempSize1 + " - " + tempSize2 + " + 1 = " + tempSize1.subtract(tempSize2, 1));
    } // Of main
} // Of class Size
3. Screenshot of the run
III. Math Utility Class
1. Utility functions
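An operator is defined so that a matrix operation can be applied to every element. Some operators act on a single matrix, for example subtracting each value from 1 or applying the Sigmoid function to each value. Others act on two matrices, for example element-wise addition and subtraction.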
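The element-wise idea can be sketched with Java's standard functional interfaces. Note the swap: the sketch uses DoubleUnaryOperator and DoubleBinaryOperator from java.util.function rather than the article's own Operator interfaces, and the class ElementwiseDemo with its map and zip methods is hypothetical.

```java
import java.util.Arrays;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;

// Hypothetical demo class; the names here are not from the original code.
class ElementwiseDemo {
    /** Applies op to every element of m and returns a new matrix. */
    static double[][] map(double[][] m, DoubleUnaryOperator op) {
        double[][] result = new double[m.length][];
        for (int i = 0; i < m.length; i++) {
            result[i] = new double[m[i].length];
            for (int j = 0; j < m[i].length; j++) {
                result[i][j] = op.applyAsDouble(m[i][j]);
            }
        }
        return result;
    }

    /** Combines a and b element by element; both are assumed to have the same shape. */
    static double[][] zip(double[][] a, double[][] b, DoubleBinaryOperator op) {
        double[][] result = new double[a.length][];
        for (int i = 0; i < a.length; i++) {
            result[i] = new double[a[i].length];
            for (int j = 0; j < a[i].length; j++) {
                result[i][j] = op.applyAsDouble(a[i][j], b[i][j]);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] m = { { 0.0, 0.25 }, { 0.5, 1.0 } };
        double[][] oneMinus = map(m, v -> 1 - v);                  // one minus each value
        double[][] sigmoid = map(m, v -> 1 / (1 + Math.exp(-v))); // sigmoid of each value
        double[][] sum = zip(m, oneMinus, (x, y) -> x + y);        // element-wise addition
        System.out.println(Arrays.deepToString(oneMinus));
        System.out.println(Arrays.deepToString(sigmoid));
        System.out.println(Arrays.deepToString(sum));
    }
}
```

Passing the operation as a parameter keeps the two loops in one place, so adding a new element-wise operation only requires a new lambda.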
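Rotating a matrix by 180 degrees is simply rotating it by 90 degrees twice. The formula for a 90-degree rotation of an $n \times n$ matrix is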
$$matrix[row][col] \overset{rotate}{=} matrix_{new}[col][n - row - 1]$$
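The rotation formula can be sketched directly in code. This is a minimal illustration assuming a square matrix; the class RotateDemo and its method names are hypothetical and are not the article's implementation.

```java
import java.util.Arrays;

// Hypothetical demo class; assumes a square n x n matrix.
class RotateDemo {
    /** Moves each element at [row][col] to [col][n - row - 1], a 90-degree clockwise turn. */
    static double[][] rotate90(double[][] matrix) {
        int n = matrix.length;
        double[][] rotated = new double[n][n];
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                rotated[col][n - row - 1] = matrix[row][col];
            }
        }
        return rotated;
    }

    /** A 180-degree rotation is two 90-degree rotations in a row. */
    static double[][] rotate180(double[][] matrix) {
        return rotate90(rotate90(matrix));
    }

    public static void main(String[] args) {
        double[][] m = { { 1, 2 }, { 3, 4 } };
        System.out.println(Arrays.deepToString(rotate90(m)));  // [[3.0, 1.0], [4.0, 2.0]]
        System.out.println(Arrays.deepToString(rotate180(m))); // [[4.0, 3.0], [2.0, 1.0]]
    }
}
```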
convnValid is the convolution operation. convnFull is its reverse operation.
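The relationship between the two modes is easiest to see through their output sizes. The sketch below is an illustration under stated assumptions, not the article's convnValid or convnFull: the kernel is applied without flipping (cross-correlation form), the full mode is built by zero-padding and then reusing the valid mode, and the class and method names are hypothetical.

```java
import java.util.Arrays;

// Hypothetical demo class; the kernel is not flipped (an assumption of this sketch).
class ConvolutionDemo {
    /** Valid mode: the kernel only sits where it fits entirely. Output is (m - k + 1) per dimension. */
    static double[][] convValid(double[][] matrix, double[][] kernel) {
        int rows = matrix.length - kernel.length + 1;
        int cols = matrix[0].length - kernel[0].length + 1;
        double[][] out = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sum = 0;
                for (int ki = 0; ki < kernel.length; ki++) {
                    for (int kj = 0; kj < kernel[0].length; kj++) {
                        sum += matrix[i + ki][j + kj] * kernel[ki][kj];
                    }
                }
                out[i][j] = sum;
            }
        }
        return out;
    }

    /** Full mode: zero-pad by (k - 1) on every side, then run valid mode. Output is (m + k - 1) per dimension. */
    static double[][] convFull(double[][] matrix, double[][] kernel) {
        int padRows = kernel.length - 1;
        int padCols = kernel[0].length - 1;
        double[][] padded = new double[matrix.length + 2 * padRows][matrix[0].length + 2 * padCols];
        for (int i = 0; i < matrix.length; i++) {
            System.arraycopy(matrix[i], 0, padded[i + padRows], padCols, matrix[i].length);
        }
        return convValid(padded, kernel);
    }

    public static void main(String[] args) {
        double[][] image = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
        double[][] kernel = { { 1, 0 }, { 0, 1 } };
        double[][] valid = convValid(image, kernel); // 3 x 3 input, 2 x 2 kernel: 2 x 2 output
        double[][] full = convFull(valid, kernel);   // 2 x 2 input, 2 x 2 kernel: 3 x 3 output
        System.out.println(Arrays.deepToString(valid));
        System.out.println(full.length + " x " + full[0].length);
    }
}
```

The valid mode shrinks a 3 x 3 input to 2 x 2, and the full mode grows a 2 x 2 input back to 3 x 3, which is the sense in which one reverses the size change of the other.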
scaleMatrix is mean pooling. kronecker is the reverse operation of pooling.
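A minimal sketch of the two pooling-related operations follows. It assumes non-overlapping blocks and a matrix whose dimensions divide evenly by the scale; the class PoolingDemo and its methods are hypothetical stand-ins, not the article's scaleMatrix or kronecker.

```java
import java.util.Arrays;

// Hypothetical demo class; blocks do not overlap and must tile the matrix exactly.
class PoolingDemo {
    /** Mean pooling: replaces every scaleRows x scaleCols block with its average. */
    static double[][] meanPool(double[][] m, int scaleRows, int scaleCols) {
        if (m.length % scaleRows != 0 || m[0].length % scaleCols != 0) {
            throw new IllegalArgumentException("Matrix size is not divisible by the scale");
        }
        int rows = m.length / scaleRows;
        int cols = m[0].length / scaleCols;
        double[][] out = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sum = 0;
                for (int di = 0; di < scaleRows; di++) {
                    for (int dj = 0; dj < scaleCols; dj++) {
                        sum += m[i * scaleRows + di][j * scaleCols + dj];
                    }
                }
                out[i][j] = sum / (scaleRows * scaleCols);
            }
        }
        return out;
    }

    /** Expansion: copies each element into a scaleRows x scaleCols block, restoring the pre-pooling size. */
    static double[][] kronecker(double[][] m, int scaleRows, int scaleCols) {
        double[][] out = new double[m.length * scaleRows][m[0].length * scaleCols];
        for (int i = 0; i < out.length; i++) {
            for (int j = 0; j < out[0].length; j++) {
                out[i][j] = m[i / scaleRows][j / scaleCols];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] m = { { 1, 3, 2, 2 }, { 5, 7, 2, 2 } };
        double[][] pooled = meanPool(m, 2, 2);
        System.out.println(Arrays.deepToString(pooled));                  // [[4.0, 2.0]]
        System.out.println(Arrays.deepToString(kronecker(pooled, 2, 2))); // [[4.0, 4.0, 2.0, 2.0], [4.0, 4.0, 2.0, 2.0]]
    }
}
```

The expansion restores the shape but not the original values, so it reverses pooling only in terms of size.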
2. Code
package cnn;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

/**
 * Math operations. Adopted from cnn-master.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public class MathUtils {
    /**
     * An interface for different on-demand operators.
     */
    public interface Operator extends Serializable {
        double process(double value);
    } // Of interface Operator

    /**
     * The one-minus-the-value operator.
     */
    public static final Operator one_value = new Operator() {
        private static final long serialVersionUID = 3752139491940330714L;

        @Override
        public double process(double value) {
            return 1 - value;
        } // Of process
    };

    /**
     * The sigmoid operator.
     */
    public static final Operator sigmoid = new Operator() {
        private static final long serialVersionUID = -1952718905019847589L;

        @Override
        public double process(double value) {
            return 1 / (1 + Math.pow(Math.E, -value));
        } // Of process
    };

    /**
     * An interface for operations with two operands.
     */
    interface OperatorOnTwo extends Serializable {
        double process(double a, double b);
    } // Of interface OperatorOnTwo

    /**
     * Plus.
     */
    public static final OperatorOnTwo plus = new OperatorOnTwo() {
        private static final long serialVersionUID = -6298144029766839945L;

        @Override
        public // ... the source listing breaks off here