Implementing the BP Neural Network Algorithm in Java

Posted by MrZhaoyx

At work I needed to predict how long a process would take, so I decided to use a BP neural network for the prediction.

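Introduction

A BP (Back Propagation) neural network is an artificial neural network based on the BP algorithm, which it uses to adjust its weights and thresholds. In the 1980s several researchers independently developed backpropagation algorithms for training multilayer perceptrons; the most influential formulation is the one popularized in 1986 by David Rumelhart, Geoffrey Hinton, and Ronald Williams, published in the Parallel Distributed Processing volumes edited by Rumelhart and James McClelland. The algorithm consists of two main phases: forward propagation of the working signal, which computes the network's actual output, and backward propagation of the error signal, which corrects the weights and thresholds layer by layer from back to front so that the actual output moves closer to the expected output.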

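(1) Forward propagation of the working signal. The input signal enters at the input layer, passes through the synapses (the weighted connections) to the hidden-layer neurons, is transformed by the transfer function, and is passed on to the output layer, where the output signal is computed and emitted. During forward propagation the weights and thresholds are held fixed, and the state of each layer depends only on the net output of the previous layer and on the weights and thresholds. If the forward pass yields the expected output at the output layer, learning ends and the current weights and thresholds are kept; otherwise the weights and thresholds are corrected during backward propagation of the error signal.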
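The forward pass described above can be sketched as follows. This is a minimal illustration, not the author's implementation: the class and method names are hypothetical, and it assumes a sigmoid transfer function on both the hidden and the output layer and a net input of the form "weighted sum minus threshold".

```java
/** Hypothetical sketch: forward pass of an input-hidden-output network. */
final class ForwardPassSketch {

    /** Logistic sigmoid transfer function (assumed for both layers). */
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Computes one layer: out[j] = sigmoid(sum_i in[i] * w[i][j] - theta[j]).
     * w[i][j] is the weight from unit i of the previous layer to unit j.
     */
    static double[] layer(double[] in, double[][] w, double[] theta) {
        double[] out = new double[theta.length];
        for (int j = 0; j < theta.length; j++) {
            double net = -theta[j];
            for (int i = 0; i < in.length; i++) {
                net += in[i] * w[i][j];
            }
            out[j] = sigmoid(net);
        }
        return out;
    }

    /** Input -> hidden -> output; weights and thresholds are only read, never changed. */
    static double[] forward(double[] x, double[][] w1, double[] t1,
                            double[][] w2, double[] t2) {
        double[] hidden = layer(x, w1, t1);
        return layer(hidden, w2, t2);
    }

    public static void main(String[] args) {
        double[][] w1 = {{0.5, -0.5}, {0.25, 0.75}}; // 2 inputs -> 2 hidden units
        double[] t1 = {0.0, 0.1};
        double[][] w2 = {{1.0}, {-1.0}};             // 2 hidden units -> 1 output
        double[] t2 = {0.2};
        double[] y = forward(new double[]{1.0, 2.0}, w1, t1, w2, t2);
        System.out.println("network output: " + y[0]);
    }
}
```

Because the sigmoid maps every net input into the open interval (0, 1), target values usually have to be scaled into that range before training.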

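(2) Backward propagation of the error signal. If the forward pass does not produce the expected output, the difference between the network's actual output and the expected output is taken as the error signal and propagated layer by layer from the output layer back toward the input layer. As the error passes back through each layer, that layer's weights and thresholds are adjusted, and this continues until the input layer is reached. The purpose is to bring the network's output closer to the expected result.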

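If the error still does not meet the requirement after one forward and one backward pass, the process repeats until the error reaches the required precision or another configured stopping condition, such as a maximum number of iterations, is met.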
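The two phases and the stopping conditions can be put together in a small training loop. The sketch below is only an illustration under stated assumptions: the class name, the fixed starting weights, per-sample (online) updates, a squared-error measure, sigmoid units on both layers, and the "weighted sum minus threshold" convention are all choices made here, not details taken from the original project.

```java
/** Hypothetical sketch of BP training for an input-hidden-output network. */
final class BpTrainingSketch {
    // Fixed starting values keep the run reproducible (an assumption;
    // real implementations usually initialise randomly).
    double[][] w1 = {{0.1, -0.2}};   // input -> hidden weights
    double[] t1 = {0.0, 0.0};        // hidden thresholds
    double[][] w2 = {{0.3}, {-0.1}}; // hidden -> output weights
    double[] t2 = {0.0};             // output thresholds

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** out[j] = sigmoid(sum_i in[i] * w[i][j] - theta[j]). */
    static double[] layer(double[] in, double[][] w, double[] theta) {
        double[] out = new double[theta.length];
        for (int j = 0; j < theta.length; j++) {
            double net = -theta[j];
            for (int i = 0; i < in.length; i++) {
                net += in[i] * w[i][j];
            }
            out[j] = sigmoid(net);
        }
        return out;
    }

    /** One forward and one backward pass; returns the error measured before the update. */
    double trainSample(double[] x, double[] d, double eta) {
        double[] h = layer(x, w1, t1);
        double[] y = layer(h, w2, t2);
        double err = 0.0;
        double[] deltaOut = new double[y.length];
        for (int k = 0; k < y.length; k++) {
            double e = d[k] - y[k];
            err += 0.5 * e * e;
            deltaOut[k] = e * y[k] * (1.0 - y[k]);
        }
        // Hidden deltas must use the old output weights, so compute them first.
        double[] deltaHid = new double[h.length];
        for (int j = 0; j < h.length; j++) {
            double sum = 0.0;
            for (int k = 0; k < y.length; k++) {
                sum += deltaOut[k] * w2[j][k];
            }
            deltaHid[j] = sum * h[j] * (1.0 - h[j]);
        }
        for (int k = 0; k < y.length; k++) {
            for (int j = 0; j < h.length; j++) {
                w2[j][k] += eta * deltaOut[k] * h[j];
            }
            t2[k] -= eta * deltaOut[k];
        }
        for (int j = 0; j < h.length; j++) {
            for (int i = 0; i < x.length; i++) {
                w1[i][j] += eta * deltaHid[j] * x[i];
            }
            t1[j] -= eta * deltaHid[j];
        }
        return err;
    }

    /** Repeats epochs until the mean error is within tolerance or maxEpochs is reached. */
    double train(double[][] xs, double[][] ds, double eta, double tolerance, int maxEpochs) {
        double mean = Double.MAX_VALUE;
        for (int epoch = 0; epoch < maxEpochs && mean > tolerance; epoch++) {
            double total = 0.0;
            for (int n = 0; n < xs.length; n++) {
                total += trainSample(xs[n], ds[n], eta);
            }
            mean = total / xs.length;
        }
        return mean; // mean error of the last epoch that was run
    }

    public static void main(String[] args) {
        BpTrainingSketch net = new BpTrainingSketch();
        double[][] xs = {{0.0}, {1.0}};
        double[][] ds = {{0.2}, {0.8}};
        double finalError = net.train(xs, ds, 0.5, 1e-4, 20000);
        System.out.println("mean error after training: " + finalError);
    }
}
```

Both stopping conditions from the text appear in the loop header of train: the error tolerance and the cap on the number of epochs.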

For the derivation, see https://zh.wikipedia.org/wiki/%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95
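For orientation, the update rules that such a derivation arrives at can be written compactly. This is a sketch under assumptions not stated in the original: a squared-error loss E = 1/2 Σ_k (d_k − y_k)², a sigmoid transfer function on both layers, a net input of the form "weighted sum minus threshold", and a learning rate η. The symbols x_i (input), h_j (hidden output), y_k (actual output), and d_k (expected output) are chosen here for illustration.

```latex
\begin{aligned}
\delta_k &= (d_k - y_k)\, y_k (1 - y_k), &
\Delta w_{jk} &= \eta\, \delta_k\, h_j, &
\Delta \theta_k &= -\eta\, \delta_k, \\
\delta_j &= h_j (1 - h_j) \sum_k \delta_k\, w_{jk}, &
\Delta w_{ij} &= \eta\, \delta_j\, x_i, &
\Delta \theta_j &= -\eta\, \delta_j .
\end{aligned}
```

The first row updates the hidden-to-output weights and output thresholds; the second row propagates the output deltas back through w_jk to update the input-to-hidden weights and hidden thresholds.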

BPNN structure

This BPNN has a single input layer, a single hidden layer, and a single output layer.

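Project structure

The classes used are:

  • ActivationFunction: interface for activation functions
  • BPModel: entity class for the BP model
  • BPNeuralNetworkFactory: BP neural network factory; provides training, computation, serialization, and related functions
  • BPParameter: entity class for the BP neural network parameters
  • Matrix: matrix entity class
  • Sigmoid: the sigmoid transfer function; implements the ActivationFunction interface
  • MatrixUtil: matrix utility class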
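To make the relationship between ActivationFunction and Sigmoid concrete, here is one possible shape for the pair. The two type names come from the list above, but the method names apply and derivative are assumptions made for this sketch; the original interface may declare different methods.

```java
/** Sketch of an activation-function contract; method names are hypothetical. */
interface ActivationFunction {
    /** Value of the transfer function at net input x. */
    double apply(double x);

    /** Derivative of the transfer function at net input x. */
    double derivative(double x);
}

/** Logistic sigmoid: f(x) = 1 / (1 + e^(-x)), f'(x) = f(x) * (1 - f(x)). */
final class Sigmoid implements ActivationFunction {
    @Override
    public double apply(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    @Override
    public double derivative(double x) {
        double s = apply(x);
        return s * (1.0 - s);
    }
}
```

Keeping the derivative next to the function lets the backward pass stay independent of which transfer function is plugged in.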

Implementation

Matrix entity class

Models the basic matrix operations.
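Element-wise operations such as the class's pointMultiple, plus, and subtract are only meaningful when both operands have exactly the same shape, so the guard has to reject a pair as soon as either the row counts or the column counts differ. Joining the two comparisons with && instead of || would let a 2 x 3 and a 2 x 2 operand through. A minimal sketch on raw arrays, where ShapeCheckSketch and its method are hypothetical names used only for illustration:

```java
/** Hypothetical helper illustrating the shape check for element-wise operations. */
final class ShapeCheckSketch {

    /** Element-wise sum; rejects the operands when EITHER dimension differs. */
    static double[][] plus(double[][] a, double[][] b) {
        if (a.length != b.length || a[0].length != b[0].length) {
            throw new IllegalArgumentException("matrix dimensions do not match");
        }
        double[][] result = new double[a.length][a[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}, {4, 5, 6}}; // 2 x 3
        double[][] b = {{1, 2}, {3, 4}};       // 2 x 2: same row count, different column count
        try {
            plus(a, b);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```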

package com.top.matrix;

import com.top.constants.OrderEnum;

import java.io.Serializable;

public class Matrix implements Serializable {
    private double[][] matrix;
    // number of columns
    private int matrixColCount;
    // number of rows
    private int matrixRowCount;

    /**
     * Constructs an empty matrix.
     */
    public Matrix() {
        this.matrix = null;
        this.matrixColCount = 0;
        this.matrixRowCount = 0;
    }

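    /**
     * Constructs a matrix backed by the given two-dimensional array.
     * @param matrix the backing array (must contain at least one row)
     */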
    public Matrix(double[][] matrix) {
        this.matrix = matrix;
        this.matrixRowCount = matrix.length;
        this.matrixColCount = matrix[0].length;
    }

    /**
     * Constructs a rowCount x colCount matrix filled with zeros.
     * @param rowCount number of rows
     * @param colCount number of columns
     */
    public Matrix(int rowCount,int colCount) {
        double[][] matrix = new double[rowCount][colCount];
        for (int i = 0; i < rowCount; i++) {
            for (int j = 0; j < colCount; j++) {
                matrix[i][j] = 0;
            }
        }
        this.matrix = matrix;
        this.matrixRowCount = rowCount;
        this.matrixColCount = colCount;
    }

    /**
     * Constructs a rowCount x colCount matrix with every element set to val.
     * @param val the fill value
     * @param rowCount number of rows
     * @param colCount number of columns
     */
    public Matrix(double val,int rowCount,int colCount) {
        double[][] matrix = new double[rowCount][colCount];
        for (int i = 0; i < rowCount; i++) {
            for (int j = 0; j < colCount; j++) {
                matrix[i][j] = val;
            }
        }
        this.matrix = matrix;
        this.matrixRowCount = rowCount;
        this.matrixColCount = colCount;
    }

    public double[][] getMatrix() {
        return matrix;
    }

    public void setMatrix(double[][] matrix) {
        this.matrix = matrix;
        this.matrixRowCount = matrix.length;
        this.matrixColCount = matrix[0].length;
    }

    public int getMatrixColCount() {
        return matrixColCount;
    }

    public int getMatrixRowCount() {
        return matrixRowCount;
    }

    /**
     * Returns the value at the given position.
     *
     * @param x row index
     * @param y column index
     * @return the element at row x, column y
     */
    public double getValOfIdx(int x, int y) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (x < 0 || x > matrixRowCount - 1) {
            throw new IllegalArgumentException("index x out of bounds");
        }
        if (y < 0 || y > matrixColCount - 1) {
            throw new IllegalArgumentException("index y out of bounds");
        }
        return matrix[x][y];
    }

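    /**
     * Returns the given row as a 1 x colCount matrix.
     *
     * @param x row index
     * @return the row at index x
     */
    public Matrix getRowOfIdx(int x) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (x < 0 || x > matrixRowCount - 1) {
            throw new IllegalArgumentException("index x out of bounds");
        }
        double[][] result = new double[1][];
        // Copy the row so the result does not share storage with this matrix.
        result[0] = matrix[x].clone();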
        return new Matrix(result);
    }

    /**
     * Returns the given column as a rowCount x 1 matrix.
     *
     * @param y column index
     * @return the column at index y
     */
    public Matrix getColOfIdx(int y) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (y < 0 || y > matrixColCount - 1) {
            throw new IllegalArgumentException("index y out of bounds");
        }
        double[][] result = new double[matrixRowCount][1];
        for (int i = 0; i < matrixRowCount; i++) {
            result[i][0] = matrix[i][y];
        }
        return new Matrix(result);
    }

    /**
     * Sets the element at position (x, y).
     * @param x row index
     * @param y column index
     * @param val the new value
     */
    public void setValue(int x, int y, double val) {
        if (x < 0 || x > this.matrixRowCount - 1) {
            throw new IllegalArgumentException("row index out of bounds");
        }
        if (y < 0 || y > this.matrixColCount - 1) {
            throw new IllegalArgumentException("column index out of bounds");
        }
        this.matrix[x][y] = val;
    }

    /**
     * Matrix product of this matrix and a.
     *
     * @param a the right-hand matrix; its row count must equal this matrix's column count
     * @return the product matrix
     * @throws IllegalArgumentException
     */
    public Matrix multiple(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("argument matrix is empty");
        }
        if (matrixColCount != a.getMatrixRowCount()) {
            throw new IllegalArgumentException("matrix dimensions do not match; cannot compute");
        }
        double[][] result = new double[matrixRowCount][a.getMatrixColCount()];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < a.getMatrixColCount(); j++) {
                for (int k = 0; k < matrixColCount; k++) {
                    result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j];
                }
            }
        }
        return new Matrix(result);
    }

    /**
     * Multiplies every element by a scalar.
     *
     * @param a the scalar factor
     * @return the scaled matrix
     */
    public Matrix multiple(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] * a;
            }
        }
        return new Matrix(result);
    }

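    /**
     * Element-wise (Hadamard) product.
     *
     * @param a a matrix of the same shape as this one
     * @return the element-wise product
     */
    public Matrix pointMultiple(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("argument matrix is empty");
        }
        // Reject the operands if either dimension differs (|| rather than &&).
        if (matrixRowCount != a.getMatrixRowCount() || matrixColCount != a.getMatrixColCount()) {
            throw new IllegalArgumentException("matrix dimensions do not match; cannot compute");
        }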
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] * a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Divides every element by a scalar.
     * @param a the scalar divisor
     * @return the resulting matrix
     * @throws IllegalArgumentException
     */
    public Matrix divide(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] / a;
            }
        }
        return new Matrix(result);
    }

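    /**
     * Matrix addition.
     *
     * @param a a matrix of the same shape as this one
     * @return the element-wise sum
     */
    public Matrix plus(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("argument matrix is empty");
        }
        // Reject the operands if either dimension differs (|| rather than &&).
        if (matrixRowCount != a.getMatrixRowCount() || matrixColCount != a.getMatrixColCount()) {
            throw new IllegalArgumentException("matrix dimensions do not match; cannot compute");
        }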
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] + a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Adds a scalar to every element.
     * @param a the scalar to add
     * @return the resulting matrix
     * @throws IllegalArgumentException
     */
    public Matrix plus(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] + a;
            }
        }
        return new Matrix(result);
    }

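    /**
     * Matrix subtraction.
     *
     * @param a a matrix of the same shape as this one
     * @return the element-wise difference (this minus a)
     */
    public Matrix subtract(Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("argument matrix is empty");
        }
        // Reject the operands if either dimension differs (|| rather than &&).
        if (matrixRowCount != a.getMatrixRowCount() || matrixColCount != a.getMatrixColCount()) {
            throw new IllegalArgumentException("matrix dimensions do not match; cannot compute");
        }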
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] - a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Subtracts a scalar from every element.
     * @param a the scalar to subtract
     * @return the resulting matrix
     * @throws IllegalArgumentException
     */
    public Matrix subtract(double a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] - a;
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums each row.
     *
     * @return a rowCount x 1 matrix of row sums
     */
    public Matrix sumRow() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixRowCount][1];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][0] += matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums each column.
     *
     * @return a 1 x colCount matrix of column sums
     */
    public Matrix sumCol() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[1][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[0][j] += matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums all elements of the matrix.
     *
     * @return the total of all elements
     */
    public double sumAll() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double result = 0;
        for (double[] doubles : matrix) {
            for (int j = 0; j < matrixColCount; j++) {
                result += doubles[j];
            }
        }
        return result;
    }

    /**
     * Squares every element.
     *
     * @return the matrix of squared elements
     */
    public Matrix square() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = matrix[i][j] * matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Raises every element to the n-th power.
     *
     * @param n the exponent
     * @return the matrix of powers
     */
    public Matrix pow(double n) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixRowCount][matrixColCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[i][j] = Math.pow(matrix[i][j],n);
            }
        }
        return new Matrix(result);
    }

    /**
     * Matrix transpose.
     *
     * @return the transposed matrix
     */
    public Matrix transpose() throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        double[][] result = new double[matrixColCount][matrixRowCount];
        for (int i = 0; i < matrixRowCount; i++) {
            for (int j = 0; j < matrixColCount; j++) {
                result[j][i] = matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Extracts a sub-matrix.
     * @param startRowIndex index of the first row to copy
     * @param rowCount   number of rows to copy
     * @param startColIndex index of the first column to copy
     * @param colCount   number of columns to copy
     * @return the extracted sub-matrix
     * @throws IllegalArgumentException
     */
    public Matrix subMatrix(int startRowIndex,int rowCount,int startColIndex,int colCount) throws IllegalArgumentException {
        if (startRowIndex + rowCount > matrixRowCount) {
            throw new IllegalArgumentException("row index out of bounds");
        }
        if (startColIndex + colCount > matrixColCount) {
            throw new IllegalArgumentException("column index out of bounds");
        }
        double[][] result = new double[rowCount][colCount];
        for (int i = startRowIndex; i < startRowIndex + rowCount; i++) {
            // Copy colCount values of row i, starting at startColIndex.
            System.arraycopy(matrix[i], startColIndex, result[i - startRowIndex], 0, colCount);
        }
        return new Matrix(result);
    }

    /**
     * Concatenates this matrix with another matrix.
     * @param direction concatenation direction: 1 for horizontal, 2 for vertical
     * @param a the matrix to append
     * @return the concatenated matrix
     * @throws IllegalArgumentException
     */
    public Matrix splice(int direction, Matrix a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new IllegalArgumentException("argument matrix is empty");
        }
        if (direction == 1) {
            // horizontal concatenation
            if (matrixRowCount != a.getMatrixRowCount()) {
                throw new IllegalArgumentException("row counts differ; cannot concatenate");
            }
            double[][] result = new double[matrixRowCount][matrixColCount + a.getMatrixColCount()];
            for (int i = 0; i < matrixRowCount; i++) {
                System.arraycopy(matrix[i],0,result[i],0,matrixColCount);
                System.arraycopy(a.getMatrix()[i],0,result[i],matrixColCount,a.getMatrixColCount());
            }
            return new Matrix(result);
        } else if (direction == 2) {
            // vertical concatenation
            if (matrixColCount != a.getMatrixColCount()) {
                throw new IllegalArgumentException("column counts differ; cannot concatenate");
            }
            double[][] result = new double[matrixRowCount + a.getMatrixRowCount()][];
            // Copy each row so the result does not share storage with either operand.
            for (int i = 0; i < matrixRowCount; i++) {
                result[i] = matrix[i].clone();
            }
            for (int i = 0; i < a.getMatrixRowCount(); i++) {
                result[matrixRowCount + i] = a.getMatrix()[i].clone();
            }
            return new Matrix(result);
        } else {
            throw new IllegalArgumentException("invalid direction argument");
        }
    }

    /**
     * Extends the matrix by repeating it.
     * @param direction extension direction: 1 for horizontal, 2 for vertical
     * @param a number of copies in the result
     * @return the extended matrix
     * @throws IllegalArgumentException
     */
    public Matrix extend(int direction , int a) throws IllegalArgumentException {
        if (matrix == null) {
            throw new IllegalArgumentException("matrix is empty");
        }
        if (direction == 1) {
            // repeat horizontally
            double[][] result = new double[matrixRowCount][matrixColCount*a];
            for (int i = 0; i < matrixRowCount; i++) {
                for (int j = 0; j < a; j++) {
                    System.arraycopy(matrix[i],0,result[i],j*matrixColCount,matrixColCount);
                }
            }
            return new Matrix(result);
        } else if (direction == 2) {
            // repeat vertically
            double[][] result = new double[matrixRowCount*a][];
            for (int i = 0; i < matrixRowCount*a; i++) {
                // Copy the row; otherwise repeated rows would all share one array.
                result[i] = matrix[i%matrixRowCount].clone();
            }
            return new Matrix(result);
        } else {
            throw new IllegalArgumentException("invalid direction argument");
        }
    }

    /**
     * Returns the average of each column.
     * @return a 1 x colCount matrix of column averages
     * @throws IllegalArgumentException
     */
    public Matrix getColAvg() throws IllegalArgumentException {
        Matrix tmp = this.sumCol();
        return tmp.divide(matrixRowCount);
    }

    /**
     * Sorts the rows of the matrix in place.
     * @param index column whose values are used as the sort key
     * @param order sort order, ascending or descending
     * @throws IllegalArgumentException
     */
    public void sort(int index,OrderEnum order) throws IllegalArgumentException{
        switch (order){
            case ASC:
                for (int i = 0; i < this.matrixRowCount; i++) {
                    for (int j = 0; j < this.matrixRowCount - 1 - i; j++) {
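                        // The source listing breaks off inside this comparison. Everything from its
                        // right-hand side onward is a reconstruction that assumes a plain bubble sort.
                        if (this.matrix[j][index] > this.matrix[j + 1][index]) {
                            // Swap adjacent rows that are out of order.
                            double[] tmp = this.matrix[j];
                            this.matrix[j] = this.matrix[j + 1];
                            this.matrix[j + 1] = tmp;
                        }
                    }
                }
                break;
            default:
                // Assumption: any order other than ASC sorts in descending order.
                for (int i = 0; i < this.matrixRowCount; i++) {
                    for (int j = 0; j < this.matrixRowCount - 1 - i; j++) {
                        if (this.matrix[j][index] < this.matrix[j + 1][index]) {
                            double[] tmp = this.matrix[j];
                            this.matrix[j] = this.matrix[j + 1];
                            this.matrix[j + 1] = tmp;
                        }
                    }
                }
                break;
        }
    }
}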
