java写卷积神经网络---卷积神经网络(CupCnn)的数据结构
Posted 阳光玻璃杯
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了java写卷积神经网络---卷积神经网络(CupCnn)的数据结构相关的知识,希望对你有一定的参考价值。
前言
我在写CupCnn的时候,一个困扰我很久的问题,就是如何组织卷积神经网络的数据结构。尤其是卷积层和全连接层之间的衔接问题。卷积层至少需要四维的数据结构(batch+channel+height+width),而全连接层则只需一个二维的数据即可(batch+数据)。
CupCnn是我用java实现的一个卷积神经网络,它的源码可以从github下载:
点击下载CupCnn
卷积神经网络模型
这张图片是一个很典型的卷积神经网络模型,在CupCnn中识别手写数字的例子中使用的模型和它非常类似。从图中可以看出,卷积层和全连接层有很大的不同,全连接层用一维数组就可以表示(加上batch就是二维),而卷积层的数据则至少需要一个三维的数据结构(加上batch就是四维,batch是指训练的时候,一次送入神经网络的图片的数量。)。同一个神经网络,使用统一的数据接口会使编程更加容易,因此,我们必须使用统一的四维模型来装载一切的数据。
在CupCnn中,这个数据接口叫Blob,它的实现如下:
public class Blob implements Serializable
/**
*
*/
private static final long serialVersionUID = 1L;
private double[] data;
private int numbers;
private int channels;
private int width;
private int height;
private int id;
public Blob(int numbers,int channels,int height,int width)
this.numbers = numbers;
this.channels = channels;
this.height = height;
this.width = width;
data = new double[getSize()];
//获取第n个number的第channels个通道的第height行的第width列的数
public double getDataByParams(int numbers,int channels,int height,int width)
return data[numbers*get3DSize()+channels*get2DSize()+height*getWidth()+width];
public int getIndexByParams(int numbers,int channels,int height,int width)
return (numbers*get3DSize()+channels*get2DSize()+height*getWidth()+width);
public int getWidth()
return width;
public int getHeight()
return height;
public int getChannels()
return channels;
public int getNumbers()
return numbers;
public int get2DSize()
return width * height;
public int get3DSize()
return channels*width*height;
public int get4DSize()
return numbers*channels*width*height;
public int getSize()
return get4DSize();
public void setId(int id)
this.id = id;
public int getId()
return id;
public double[] getData()
return data;
public void fillValue(double value)
for(int i=0;i<data.length;i++)
data[i] = value;
public void cloneTo(Blob to)
to.numbers = this.numbers;
to.channels = this.channels;
to.height = this.height;
to.width = this.width;
double[] toData = to.getData();
for(int i=0;i<data.length;i++)
toData[i] = this.data[i];
Blob的实现中,所有的数据都存储在一个一维的数组中,通过四个变量batch,channel,height,width来分别记录它各个维度的大小,此外,还导出了get(x)DSize()这样获取维度大小的接口。
卷积神经网络工作的工程中,数据的变化如下:
对于一个64*64大小的三维图片,经过一个卷积层+一个池化层后,图片的大小变为一半(卷积方式为same),但是通道却极大的增多了,注意,这里要强调的是通道的增加。在CupCnn的实现过程中,假如指定的batch为10,那么每个层,它的batch都是10,至始自终不会改变,卷积层主要会增加channel,池化层不会增加channel,但会使图像减小。
如果解释的还不清楚,再来看下面这张图:
注意图片中的连线,图中,第一个卷积层有4个卷积核,分别对原始图片做卷积,得到了4个28*28的图像,这里显然是使用了valide的方式进行的卷积,如果使用的是same的方式,卷积后大小仍为32*32。池化不会再增加通道,而是将每一个图像都变小了。至于卷积和池化的具体工作流程,这里不再展开。
注意:图中的数据没有添加batch的概念,加上batch后会更加复杂。但是只要高清了这幅图中的工作机制,相信理解加上batch后的卷积神经网络也就不是事了。
卷积层与全连接层的衔接
用一个一维的数组保存所有的数据除了速度上的优势之外,还有个很大的便利就是在卷积层和全连接层进行衔接的时候,由于数据本来就是存储在一维数组上的,我们完全可以忘记它是四维的数据结构,而把它当成一个一维的数据结构。这样就可以轻易的实现卷积到全连接的过度。
Blob的传递
在卷积神经网络中,这一层的输出便是下一层的输入。CupCnn中数据流动的就是Blob这个结构。为了方便下一个层获取上一个层的输出,CupCnn中的每一个层都有一个id,这个id是他在卷积神经网络中的位置,或者序号。比如第一个输入层它的id=0,第二个层它的id=1。此外,每一个层都有一个network的引用,因为所有的数据都由network统一管理,拥有network的引用,可以轻易的通过id索引获取任意一层的数据,包括输出和diff。
一开始就创建所有的需要的数据结构:
public Network()
datas = new ArrayList<Blob>();
diffs = new ArrayList<Blob>();
layers = new ArrayList<Layer>();
根据每一个层的配置参数创建层,每一层的输出Blob和残差Blob:
public void prepare()
for(int i=0;i<layers.size();i++)
BlobParams layerParams = layers.get(i).getLayerParames();
assert (layerParams.getNumbers()>0 && layerParams.getChannels()>0 && layerParams.getHeight()>0 && layerParams.getWidth() >0):"prapare---layer params error";
Blob data = new Blob(batch,layerParams.getChannels(),layerParams.getHeight(),layerParams.getWidth());
datas.add(data);
Blob diff = new Blob(data.getNumbers(),data.getChannels(),data.getHeight(),data.getWidth());
diffs.add(diff);
layers.get(i).setId(i);
layers.get(i).prepare();
通过id获取指定层的数据:
@Override
public void forward()
// TODO Auto-generated method stub
Blob input = mNetwork.getDatas().get(id-1);
Blob output = mNetwork.getDatas().get(id);
double [] outputData = output.getData();
double [] zData = z.getData();
...
写在最后
写卷积神经网络的时候,建议先写全连接层,因为写完全连接层就可以验证神经网络的正确性。这个时候,大家还是要注意数据结构一开始就用四维的,为以后和卷积层衔接做准备。如果您在写代码的过程中遇到什么困惑或者有什么兴奋的改进,都可以家下面的QQ群互相交流:
机器学习 QQ交流群:704153141
以上是关于java写卷积神经网络---卷积神经网络(CupCnn)的数据结构的主要内容,如果未能解决你的问题,请参考以下文章
从软件工程的角度写机器学习6——深度学习之卷积神经网络(CNN)实现