深度学习 之 GRU 算法例子
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习 之 GRU 算法例子相关的知识,希望对你有一定的参考价值。
首先下载代码:https://github.com/whk6688/rnn例子1:预测下文
private void train(CharText ctext, double lr) {
Map<Integer, String> indexChar = ctext.getIndexChar();
Map<String, DoubleMatrix> charVector = ctext.getCharVector();
List<String> sequence = ctext.getSequence();
for (int i = 0; i < 100; i++) {
double error = 0;
double num = 0;
double start = System.currentTimeMillis();
for (int s = 0; s < sequence.size(); s++) {
String seq = sequence.get(s);
if (seq.length() < 3) {
continue;
}
Map<String, DoubleMatrix> acts = new HashMap<>();
// forward pass
System.out.print(String.valueOf(seq.charAt(0)+"->"));
for (int t = 0; t < seq.length() - 1; t++) {
DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
acts.put("x" + t, xt);
gru.active(t, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
acts.put("py" + t, predcitYt);
DoubleMatrix trueYt = charVector.get(String.valueOf(seq.charAt(t + 1)));
acts.put("y" + t, trueYt);
System.out.print(indexChar.get(predcitYt.argmax()));
//error += LossFunction.getMeanCategoricalCrossEntropy(predcitYt, trueYt);
}
System.out.println();
// bptt
gru.bptt(acts, seq.length() - 2, lr);
num += seq.length();
}
System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
}
}
private void test(CharText ctext) {
Map<Integer, String> indexChar = ctext.getIndexChar();
Map<String, DoubleMatrix> charVector = ctext.getCharVector();
Map<String, DoubleMatrix> acts = new HashMap<>();
String seq="不";
int t=0;
DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
acts.put("x" + t, xt);
gru.active(t, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
acts.put("py" + t, predcitYt);
System.out.print(indexChar.get(predcitYt.argmax()));
}
训练的文本为:
行尸走肉
金蝉脱壳
百里挑一
金玉满堂
不花不四
不花千里
不花×××
背水一战
霸王别姬
天上人间
不吐不快
海阔天空
情非得已
满腹经纶
兵临城下
春暖花开
插翅难逃
黄道吉日
天下无双
偷天换日
两小无猜
卧虎藏龙
珠光宝气
簪缨世族
×××
绘声绘影
国色天香
相亲相爱
八仙过海
金玉良缘
掌上明珠
皆大欢喜
逍遥法外
当输入“不”时,下一个词会提示为“花”。 因为此算法是有时间概念的,因此当你在加入两条不字开头的成语,会发现结果不同。
例子二:预测结果
public static void main(String[] args) {
loadData();
int hiddenSize = 4;//隐含层数量
double lr = 0.1;
gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);//4是输入层,3是输出层
for (int i = 0; i < 2000; i++) {//迭代2000次
double error = 0;
double num = 0;
double start = System.currentTimeMillis();
Map<String, DoubleMatrix> acts = new HashMap<>();
for (int s = 0; s < train_x.length; s++) {
double newx[][] = new double[1][4];
newx[0] = train_x[s];
DoubleMatrix xt = new DoubleMatrix(newx);//获取字的矩阵
//System.out.println(xt.getColumns()+" "+xt.getRows());
acts.put("x" + s, xt);
gru.active(s, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
acts.put("py" + s, predcitYt);
double newy[][] = new double[1][3];
newy[0] = train_y[s];
DoubleMatrix trueYt = new DoubleMatrix(newy);
acts.put("y" + s, trueYt);
//System.out.println(predcitYt.argmax()+"-->"+trueYt.argmax());
if(predcitYt.argmax()!=trueYt.argmax())
error++;
// bptt
num ++;
}
gru.bptt(acts, train_x.length-1, lr);
System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
}//结束迭代
//开始测试
int num = 0,error = 0;
Map<String, DoubleMatrix> acts = new HashMap<>();
for(int s = 0; s<test_x.length;s++){
double newx[][] = new double[1][4];
newx[0] = test_x[s];
DoubleMatrix xt = new DoubleMatrix(newx);
acts.put("x" + s, xt);
gru.active(s, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
acts.put("py" + s, predcitYt);
double newy[][] = new double[1][3];
newy[0] = test_y[s];
DoubleMatrix trueYt = new DoubleMatrix(newy);
acts.put("y" + s, trueYt);
if(predcitYt.argmax()!=trueYt.argmax())
error++;
// bptt
num ++;
}
System.out.println("错误数:"+error+"/"+num);
}
这个例子来预测花的种类,当然也可以使用决策树来实现。换种方式也感觉挺好
引用:https://blog.csdn.net/czs1130/article/details/70717348
以上是关于深度学习 之 GRU 算法例子的主要内容,如果未能解决你的问题,请参考以下文章
深度学习时间序列预测:GRU算法构建单变量时间序列预测模型+代码实战
深度学习多变量时间序列预测:GRU算法构建时间序列多变量模型预测交通流量+代码实战
深度学习100例 | 第32天(GRU模型):利用算法生成小说(斗罗大陆版)
从零开始学习深度学习35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍Pytorch实现GRU并进行训练预测