How to change this RNN text classification code to text generation?
【Title】How to change this RNN text classification code to text generation? 【Posted】2020-01-23 02:09:52 【Question】I have this code that does text classification with a TensorFlow RNN, but how can I change it to do text generation instead?
The text classification below takes 3D input but produces 2D output. Should it be changed to 3D input and 3D output to generate text? And if so, how?
Sample data:

        t0      t1      t2
    british gray    is      => cat (y=0)
    0       1       2
    white   samoyed is      => dog (y=1)
    3       4       2
For classification, feeding "british gray is" yields "cat". What I want instead is that feeding "british" yields the next word, "gray".
import tensorflow as tf;
tf.reset_default_graph();
#data
'''
    t0      t1      t2
british gray    is      => cat (y=0)
0       1       2
white   samoyed is      => dog (y=1)
3       4       2
'''
Bsize = 2;
Times = 3;
Max_X = 4;
Max_Y = 1;
X = [[[0],[1],[2]], [[3],[4],[2]]];
Y = [[0], [1] ];
#normalise
for I in range(len(X)):
    for J in range(len(X[I])):
        X[I][J][0] /= Max_X;
for I in range(len(Y)):
    Y[I][0] /= Max_Y;
#model
Inputs = tf.placeholder(tf.float32, [Bsize,Times,1]);
Expected = tf.placeholder(tf.float32, [Bsize, 1]);
#single LSTM layer
#'''
Layer1 = tf.keras.layers.LSTM(20);
Hidden1 = Layer1(Inputs);
#'''
#multi LSTM layers
'''
Layers = tf.keras.layers.RNN([
tf.keras.layers.LSTMCell(30), #hidden 1
tf.keras.layers.LSTMCell(20) #hidden 2
]);
Hidden2 = Layers(Inputs);
'''
Weight3 = tf.Variable(tf.random_uniform([20,1], -1,1));
Bias3 = tf.Variable(tf.random_uniform([ 1], -1,1));
Output = tf.sigmoid(tf.matmul(Hidden1,Weight3) + Bias3);
Loss = tf.reduce_sum(tf.square(Expected-Output));
Optim = tf.train.GradientDescentOptimizer(1e-1);
Training = Optim.minimize(Loss);
#train
Sess = tf.Session();
Init = tf.global_variables_initializer();
Sess.run(Init);
Feed = {Inputs:X, Expected:Y};
for I in range(1000): #number of feeds, 1 feed = 1 batch
    if I%100==0:
        Lossvalue = Sess.run(Loss,Feed);
        print("Loss:",Lossvalue);
    #end if
    Sess.run(Training,Feed);
#end for
Lastloss = Sess.run(Loss,Feed);
print("Loss:",Lastloss,"(Last)");
#eval
Results = Sess.run(Output,Feed);
print("\nEval:");
print(Results);
print("\nDone.");
#eof
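To turn classification into generation, the targets become the input sequence shifted one step to the left. A minimal sketch of building such (x, y) pairs, using the same illustrative vocabulary indices as the example data:

```python
# Build next-word training pairs: y is x shifted left by one token.
# Vocabulary indices mirror the example data (illustrative only).
vocab = ["british", "gray", "is", "cat", "white", "samoyed", "dog"]
word_to_id = {w: i for i, w in enumerate(vocab)}

def make_pairs(sentence):
    """Return (x, y) where y[t] is the word following x[t]."""
    ids = [word_to_id[w] for w in sentence.split()]
    return ids[:-1], ids[1:]

x, y = make_pairs("british gray is cat")
print(x)  # [0, 1, 2]
print(y)  # [1, 2, 3]
```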
【Comments】:
Do you mean in its current state, or can you retrain it?

@Recessive I mean: how do I get the next word instead of a class? For example, feeding "british" I should be able to get "gray", rather than feeding "british gray is" to get "cat".

The sample data are confusing, but they look incompatible. Since you didn't answer, I'll assume you can retrain the network; in that case the best approach is to use the same input and output dimensions, possibly 1-D. To do that, take all the words in the training data and use them as a very large one-hot vector for both input and output. For example, if your words are ['hello', 'hi', 'is', 'that', 'yes'], then your input is a 1-D vector of length 5, and to feed 'hello' you would put a 1 at index 0.
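The one-hot scheme suggested in this comment could look like the following sketch (the five-word vocabulary is the commenter's example):

```python
# One-hot encode words over a small vocabulary (commenter's example words).
vocab = ["hello", "hi", "is", "that", "yes"]
word_to_id = {w: i for i, w in enumerate(vocab)}

def one_hot(word):
    """Return a length-len(vocab) vector with a 1 at the word's index."""
    vec = [0.0] * len(vocab)
    vec[word_to_id[word]] = 1.0
    return vec

print(one_hot("hello"))  # [1.0, 0.0, 0.0, 0.0, 0.0]
```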
【Answer 1】:
I figured out how to switch the code to a text-generation task, using 3D input (X) and 3D labels (Y), as in the source code below:

Source code:
import tensorflow as tf;
tf.reset_default_graph();
#data
'''
    t0      t1      t2
british gray    is      cat
0       1       2       (3)     <= x
1       2       3               <= y
white   samoyed is      dog
4       5       2       (6)     <= x
5       2       6               <= y
'''
Bsize = 2;
Times = 3;
Max_X = 5;
Max_Y = 6;
X = [[[0],[1],[2]], [[4],[5],[2]]];
Y = [[[1],[2],[3]], [[5],[2],[6]]];
#normalise
for I in range(len(X)):
    for J in range(len(X[I])):
        X[I][J][0] /= Max_X;
for I in range(len(Y)):
    for J in range(len(Y[I])):
        Y[I][J][0] /= Max_Y;
#model
Input = tf.placeholder(tf.float32, [Bsize,Times,1]);
Expected = tf.placeholder(tf.float32, [Bsize,Times,1]);
#single LSTM layer
'''
Layer1 = tf.keras.layers.LSTM(20);
Hidden1 = Layer1(Input);
'''
#multi LSTM layers
#'''
Layers = tf.keras.layers.RNN([
    tf.keras.layers.LSTMCell(30), #hidden 1
    tf.keras.layers.LSTMCell(20)  #hidden 2
],
return_sequences=True);
Hidden2 = Layers(Input);
#'''
Weight3 = tf.Variable(tf.random_uniform([20,1], -1,1));
Bias3 = tf.Variable(tf.random_uniform([ 1], -1,1));
Output = tf.sigmoid(tf.matmul(Hidden2,Weight3) + Bias3); #sequence of 2d * 2d
Loss = tf.reduce_sum(tf.square(Expected-Output));
Optim = tf.train.GradientDescentOptimizer(1e-1);
Training = Optim.minimize(Loss);
#train
Sess = tf.Session();
Init = tf.global_variables_initializer();
Sess.run(Init);
Feed = {Input:X, Expected:Y};
Epochs = 10000;
for I in range(Epochs): #number of feeds, 1 feed = 1 batch
    if I%(Epochs/10)==0:
        Lossvalue = Sess.run(Loss,Feed);
        print("Loss:",Lossvalue);
    #end if
    Sess.run(Training,Feed);
#end for
Lastloss = Sess.run(Loss,Feed);
print("Loss:",Lastloss,"(Last)");
#eval
Results = Sess.run(Output,Feed).tolist();
print("\nEval:");
for I in range(len(Results)):
    for J in range(len(Results[I])):
        for K in range(len(Results[I][J])):
            Results[I][J][K] = round(Results[I][J][K]*Max_Y);
#end for i
print(Results);
print("\nDone.");
#eof
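To read the evaluation output as words, the rounded values can be mapped back through the vocabulary. A small decoding sketch (the id-to-word table is an assumption based on the example data; Max_Y matches the scaling constant in the code above):

```python
# Map the model's sigmoid outputs (scaled token ids) back to words.
Max_Y = 6  # same scaling constant as in the training code
id_to_word = {0: "british", 1: "gray", 2: "is", 3: "cat",
              4: "white", 5: "samoyed", 6: "dog"}

def decode(outputs):
    """outputs: [batch][timestep][1] sigmoid values in [0, 1]."""
    return [[id_to_word[round(step[0] * Max_Y)] for step in seq]
            for seq in outputs]

preds = [[[0.166], [0.333], [0.5]]]  # roughly ids 1, 2, 3 after scaling
print(decode(preds))  # [['gray', 'is', 'cat']]
```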