python lstm_tf_eager.py
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python lstm_tf_eager.py相关的知识,希望对你有一定的参考价值。
class LSTMModel(tf.keras.Model):
    """Single-layer LSTM classifier trained under TensorFlow eager execution.

    A ``BasicLSTMCell`` is run over every time step of the input; the cell
    output of the final step goes through dropout and a dense layer that
    produces the class logits.
    """

    def __init__(self, n_classes=None, n_hidden=64, dropout_rate=0.5):
        """Create the layers, the optimizer and the loss function.

        Args:
            n_classes: Number of logits produced by the dense layer. When
                ``None`` (the default) it is the number of distinct values
                in the module-level ``y_train``, which is what this class
                always used before the parameter existed.
            n_hidden: Number of units of the LSTM cell.
            dropout_rate: Rate passed to the dropout layer.
        """
        super().__init__()
        if n_classes is None:
            # Backward-compatible fallback: relies on a global ``y_train``.
            n_classes = len(np.unique(y_train))
        self.rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
        self.dropout = tf.layers.Dropout(rate=dropout_rate)
        self.dense1 = tf.layers.Dense(n_classes)
        self.optimizer = tf.train.AdamOptimizer()
        self.cross_entropy = tf.losses.sparse_softmax_cross_entropy

    def predict(self, input_data, train=False):
        """Return the class logits for a batch of sequences.

        Args:
            input_data: Batch of sequences laid out as
                (batch_size, seq_len, n_features); it is split into time
                steps along axis 1.
            train: Forwarded to the dropout layer as ``training``.

        Returns:
            The dense-layer logits computed from the LSTM output of the last
            time step of every sample.

        Raises:
            ValueError: If ``input_data`` contains no time steps.
        """
        batch_size = tf.shape(input_data)[0]
        # NOTE(review): the initial state is always float64, so the inputs
        # are presumably float64 as well -- confirm against the callers.
        state = self.rnn_cell.zero_state(batch_size, dtype=tf.float64)

        # Only the output of the final time step feeds the classifier, so
        # there is no need to keep (and later gather from) earlier outputs.
        last_output = None
        for step_input in tf.unstack(input_data, axis=1):
            last_output, state = self.rnn_cell(step_input, state)
        if last_output is None:
            raise ValueError("input_data must contain at least one time step")

        dropped = self.dropout(last_output, training=train)
        return self.dense1(dropped)

    def forward_pass(self, input_data):
        """Return the logits for ``input_data`` with dropout in training mode."""
        return self.predict(input_data, train=True)

    def backward_pass(self, loss, tape):
        """Apply one optimizer step.

        Args:
            loss: Loss value that was computed while ``tape`` was recording.
            tape: Gradient tape that recorded the forward pass.
        """
        grad = tape.gradient(loss, self.variables)
        self.optimizer.apply_gradients(zip(grad, self.variables))

    def fit(self, X, Y, val_X=None, val_Y=None, epoch=1000, print_every=100):
        """Train on the full batch ``(X, Y)`` and periodically print metrics.

        Args:
            X: Training inputs, passed to ``forward_pass``.
            Y: Training labels for the sparse softmax cross-entropy loss.
            val_X: Optional validation inputs.
            val_Y: Optional validation labels; validation accuracy is only
                reported when both ``val_X`` and ``val_Y`` are given.
            epoch: Number of full-batch training iterations.
            print_every: Metrics are printed every ``print_every`` iterations.
        """
        for i in range(epoch):
            # Record only the forward pass; gradients are taken after the
            # tape context has been closed.
            with tfe.GradientTape() as tape:
                preds = self.forward_pass(X)
                loss = self.cross_entropy(labels=Y, logits=preds)
            self.backward_pass(loss, tape)

            if (i + 1) % print_every != 0:
                continue

            loss_value = tf.cast(loss, tf.float64).numpy()
            acc = compute_accuracy(preds, Y)
            if val_X is not None and val_Y is not None:
                val_acc = compute_accuracy(self.predict(val_X), val_Y)
                # Accuracy is rounded here too, matching the branch below.
                print("epoch", i + 1, "loss", round(loss_value, 5),
                      "acc", round(acc, 2), "val_acc", round(val_acc, 2))
            else:
                print("epoch", i + 1, "loss", round(loss_value, 5),
                      "acc", round(acc, 2))
# Build the classifier and train it on (X, Y), reporting accuracy on the
# held-out pair (X_t, Y_t) alongside the training metrics.
lstmmodel = LSTMModel()
lstmmodel.fit(X, Y, val_X=X_t, val_Y=Y_t, epoch=1000)
以上是关于python lstm_tf_eager.py的主要内容,如果未能解决你的问题,请参考以下文章