JAVA实现BP神经网络算法
Posted MrZhaoyx
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了JAVA实现BP神经网络算法相关的知识,希望对你有一定的参考价值。
工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测。
简介
BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法进行权值与阈值的调整。在20世纪80年代,几位不同的学者分别开发出了用于训练多层感知机的反向传播算法,其中David Rumelhart和James McClelland提出的反向传播算法最具影响力。BP算法包含两大主要过程,即工作信号的正向传播与误差信号的反向传播,分别负责神经网络输出的计算以及权值和阈值的更新。工作信号的正向传播通过计算得到BP神经网络的实际输出,误差信号的反向传播则由后往前逐层修正权值与阈值,以使实际输出更接近期望输出。
(1)工作信号正向传播。输入信号从输入层进入,通过突触进入隐含层神经元,经传递函数(即激活函数)运算后,传递到输出层,并且在输出层计算出输出信号传出。当工作信号正向传播时,权值与阈值固定不变,神经网络中每层的状态只与前一层的净输出、权值和阈值有关。若正向传播在输出层得到期望的输出,则学习结束,并保留当前的权值与阈值;若正向传播在输出层得不到期望的输出,则在误差信号的反向传播中修正权值与阈值。
(2)误差信号反向传播。在工作信号正向传播后若得不到期望的输出,则将BP神经网络的实际输出与期望输出之间的差值作为误差信号,由神经网络的输出层逐层向输入层传播。在此过程中,误差信号每向输入层方向传播一层,就对该层的权值与阈值进行修改,如此一直传播至输入层,目的是使神经网络的结果与期望的结果更相近。
当进行一次正向传播和反向传播后,若误差仍不能达到要求,则重复该过程,直至误差满足精度要求,或者达到最大迭代次数等其他预先设置的结束条件。
推导请见 https://zh.wikipedia.org/wiki/%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95
BPNN结构
该BPNN为单输入层单隐含层单输出层结构
项目结构
介绍一些用到的类
- ActivationFunction:激活函数的接口
- BPModel:BP模型实体类
- BPNeuralNetworkFactory:BP神经网络工厂,包括训练BP神经网络,计算,序列化等功能
- BPParameter:BP神经网络参数实体类
- Matrix:矩阵实体类
- Sigmoid:Sigmoid激活函数(又称传输函数),实现了ActivationFunction接口
- MatrixUtil:矩阵工具类
实现代码
Matrix实体类
模拟了矩阵的基本运算方法。
package com.top.matrix; import com.top.constants.OrderEnum; import java.io.Serializable; public class Matrix implements Serializable { private double[][] matrix; //矩阵列数 private int matrixColCount; //矩阵行数 private int matrixRowCount; /** * 构造一个空矩阵 */ public Matrix() { this.matrix = null; this.matrixColCount = 0; this.matrixRowCount = 0; } /** * 构造一个matrix矩阵 * @param matrix */ public Matrix(double[][] matrix) { this.matrix = matrix; this.matrixRowCount = matrix.length; this.matrixColCount = matrix[0].length; } /** * 构造一个rowCount行colCount列值为0的矩阵 * @param rowCount * @param colCount */ public Matrix(int rowCount,int colCount) { double[][] matrix = new double[rowCount][colCount]; for (int i = 0; i < rowCount; i++) { for (int j = 0; j < colCount; j++) { matrix[i][j] = 0; } } this.matrix = matrix; this.matrixRowCount = rowCount; this.matrixColCount = colCount; } /** * 构造一个rowCount行colCount列值为val的矩阵 * @param val * @param rowCount * @param colCount */ public Matrix(double val,int rowCount,int colCount) { double[][] matrix = new double[rowCount][colCount]; for (int i = 0; i < rowCount; i++) { for (int j = 0; j < colCount; j++) { matrix[i][j] = val; } } this.matrix = matrix; this.matrixRowCount = rowCount; this.matrixColCount = colCount; } public double[][] getMatrix() { return matrix; } public void setMatrix(double[][] matrix) { this.matrix = matrix; this.matrixRowCount = matrix.length; this.matrixColCount = matrix[0].length; } public int getMatrixColCount() { return matrixColCount; } public int getMatrixRowCount() { return matrixRowCount; } /** * 获取矩阵指定位置的值 * * @param x * @param y * @return */ public double getValOfIdx(int x, int y) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (x > matrixRowCount - 1) { throw new IllegalArgumentException("索引x越界"); } if (y > matrixColCount - 1) { throw new IllegalArgumentException("索引y越界"); } return matrix[x][y]; } /** * 获取矩阵指定行 * * @param x * @return */ public Matrix getRowOfIdx(int 
x) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (x > matrixRowCount - 1) { throw new IllegalArgumentException("索引x越界"); } double[][] result = new double[1][matrixColCount]; result[0] = matrix[x]; return new Matrix(result); } /** * 获取矩阵指定列 * * @param y * @return */ public Matrix getColOfIdx(int y) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (y > matrixColCount - 1) { throw new IllegalArgumentException("索引y越界"); } double[][] result = new double[matrixRowCount][1]; for (int i = 0; i < matrixRowCount; i++) { result[i][0] = matrix[i][y]; } return new Matrix(result); } /** * 设置矩阵中x,y位置元素的值 * @param x * @param y * @param val */ public void setValue(int x, int y, double val) { if (x > this.matrixRowCount - 1) { throw new IllegalArgumentException("行索引越界"); } if (y > this.matrixColCount - 1) { throw new IllegalArgumentException("列索引越界"); } this.matrix[x][y] = val; } /** * 矩阵乘矩阵 * * @param a * @return * @throws IllegalArgumentException */ public Matrix multiple(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixColCount != a.getMatrixRowCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][a.getMatrixColCount()]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < a.getMatrixColCount(); j++) { for (int k = 0; k < matrixColCount; k++) { result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j]; } } } return new Matrix(result); } /** * 矩阵乘一个数字 * * @param a * @return */ public Matrix multiple(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < 
matrixColCount; j++) { result[i][j] = matrix[i][j] * a; } } return new Matrix(result); } /** * 矩阵点乘 * * @param a * @return */ public Matrix pointMultiple(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] * a.getMatrix()[i][j]; } } return new Matrix(result); } /** * 矩阵除一个数字 * @param a * @return * @throws IllegalArgumentException */ public Matrix divide(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] / a; } } return new Matrix(result); } /** * 矩阵加法 * * @param a * @return */ public Matrix plus(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] + a.getMatrix()[i][j]; } } return new Matrix(result); } /** * 矩阵加一个数字 * @param a * @return * @throws IllegalArgumentException */ public Matrix plus(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new 
double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] + a; } } return new Matrix(result); } /** * 矩阵减法 * * @param a * @return */ public Matrix subtract(Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵纬度不同,不可计算"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] - a.getMatrix()[i][j]; } } return new Matrix(result); } /** * 矩阵减一个数字 * @param a * @return * @throws IllegalArgumentException */ public Matrix subtract(double a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] - a; } } return new Matrix(result); } /** * 矩阵行求和 * * @return */ public Matrix sumRow() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][1]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][0] += matrix[i][j]; } } return new Matrix(result); } /** * 矩阵列求和 * * @return */ public Matrix sumCol() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[1][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[0][j] += matrix[i][j]; } } return new Matrix(result); } /** * 矩阵所有元素求和 * * @return */ public double sumAll() 
throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double result = 0; for (double[] doubles : matrix) { for (int j = 0; j < matrixColCount; j++) { result += doubles[j]; } } return result; } /** * 矩阵所有元素求平方 * * @return */ public Matrix square() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = matrix[i][j] * matrix[i][j]; } } return new Matrix(result); } /** * 矩阵所有元素求N次方 * * @return */ public Matrix pow(double n) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixRowCount][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[i][j] = Math.pow(matrix[i][j],n); } } return new Matrix(result); } /** * 矩阵转置 * * @return */ public Matrix transpose() throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } double[][] result = new double[matrixColCount][matrixRowCount]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < matrixColCount; j++) { result[j][i] = matrix[i][j]; } } return new Matrix(result); } /** * 截取矩阵 * @param startRowIndex 开始行索引 * @param rowCount 截取行数 * @param startColIndex 开始列索引 * @param colCount 截取列数 * @return * @throws IllegalArgumentException */ public Matrix subMatrix(int startRowIndex,int rowCount,int startColIndex,int colCount) throws IllegalArgumentException { if (startRowIndex + rowCount > matrixRowCount) { throw new IllegalArgumentException("行索引越界"); } if (startColIndex + colCount> matrixColCount) { throw new IllegalArgumentException("列索引越界"); } double[][] result = new double[rowCount][colCount]; for (int i = startRowIndex; i < startRowIndex + rowCount; i++) { if (startColIndex + colCount - 
startColIndex >= 0) System.arraycopy(matrix[i], startColIndex, result[i - startRowIndex], 0, colCount); } return new Matrix(result); } /** * 矩阵合并 * @param direction 合并方向,1为横向,2为竖向 * @param a * @return * @throws IllegalArgumentException */ public Matrix splice(int direction, Matrix a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if (a.getMatrix() == null) { throw new IllegalArgumentException("参数矩阵为空"); } if(direction == 1){ //横向拼接 if (matrixRowCount != a.getMatrixRowCount()) { throw new IllegalArgumentException("矩阵行数不一致,无法拼接"); } double[][] result = new double[matrixRowCount][matrixColCount + a.getMatrixColCount()]; for (int i = 0; i < matrixRowCount; i++) { System.arraycopy(matrix[i],0,result[i],0,matrixColCount); System.arraycopy(a.getMatrix()[i],0,result[i],matrixColCount,a.getMatrixColCount()); } return new Matrix(result); }else if(direction == 2){ //纵向拼接 if (matrixColCount != a.getMatrixColCount()) { throw new IllegalArgumentException("矩阵列数不一致,无法拼接"); } double[][] result = new double[matrixRowCount + a.getMatrixRowCount()][matrixColCount]; for (int i = 0; i < matrixRowCount; i++) { result[i] = matrix[i]; } for (int i = 0; i < a.getMatrixRowCount(); i++) { result[matrixRowCount + i] = a.getMatrix()[i]; } return new Matrix(result); }else{ throw new IllegalArgumentException("方向参数有误"); } } /** * 扩展矩阵 * @param direction 扩展方向,1为横向,2为竖向 * @param a * @return * @throws IllegalArgumentException */ public Matrix extend(int direction , int a) throws IllegalArgumentException { if (matrix == null) { throw new IllegalArgumentException("矩阵为空"); } if(direction == 1){ //横向复制 double[][] result = new double[matrixRowCount][matrixColCount*a]; for (int i = 0; i < matrixRowCount; i++) { for (int j = 0; j < a; j++) { System.arraycopy(matrix[i],0,result[i],j*matrixColCount,matrixColCount); } } return new Matrix(result); }else if(direction == 2){ //纵向复制 double[][] result = new double[matrixRowCount*a][matrixColCount]; for (int i = 
0; i < matrixRowCount*a; i++) { result[i] = matrix[i%matrixRowCount]; } return new Matrix(result); }else{ throw new IllegalArgumentException("方向参数有误"); } } /** * 获取每列的平均值 * @return * @throws IllegalArgumentException */ public Matrix getColAvg() throws IllegalArgumentException { Matrix tmp = this.sumCol(); return tmp.divide(matrixRowCount); } /** * 矩阵行排序 * @param index 根据第几列的数进行行排序 * @param order 排序顺序,升序或降序 * @return * @throws IllegalArgumentException */ public void sort(int index,OrderEnum order) throws IllegalArgumentException{ switch (order){ case ASC: for (int i = 0; i < this.matrixRowCount; i++) { for (int j = 0; j < this.matrixRowCount - 1 - i; j++) { if (this.matrix[j][index] > 基于蝙蝠算法优化BP神经网络的数据分类算法及其MATLAB实现-附代码基于Matlab的遗传算法优化BP神经网络的算法实现(附算法介绍与代码详解)