将 tensorflow 1 contrib 转换为 tensorflow 2 Keras 版本

Posted

技术标签:

【中文标题】将 tensorflow 1 contrib 转换为 tensorflow 2 Keras 版本【英文标题】:Convert tensorflow 1 contrib to tensorflow 2 Keras version 【发布时间】:2021-09-08 08:18:32 【问题描述】:

我正在将我的代码从 tf1 迁移到 tf2,我认为我必须解决大部分问题才能使用 tf2 运行它。但是在将其迁移到与 tfa.seq2seq.LuongAttention 和 tfa.seq2seq.AttentionWrapper 兼容的 Tf2 时遇到了问题。已经将 contrib 替换为 v2,但不确定为什么它不起作用。

def _single_cell(num_units, keep_prob, device_str=None):
    single_cell = tf.compat.v1.nn.rnn_cell.GRUCell(num_units)
    if keep_prob < 1.0:
        single_cell = tf.contrib.rnn.DropoutWrapper(cell=single_cell, input_keep_prob=keep_prob)
    # Device Wrapper
    if device_str:
        single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)
    return single_cell


def create_rnn_cell(num_units, num_layers, keep_prob):
    """Create multi-layer RNN cell."""
    cell_list = []
    for i in range(num_layers):
        single_cell = _single_cell(num_units=num_units, keep_prob=keep_prob)
        cell_list.append(single_cell)
    if len(cell_list) == 1:  # Single layer.
        return cell_list[0]
    else:  # Multi layers
        return tf.compat.v1.nn.rnn_cell.MultiRNNCell(cell_list)

cell = create_rnn_cell(
            num_units=hparams.num_units,
            num_layers=hparams.num_layers,
            keep_prob=hparams.keep_prob)

encoder_outputs, encoder_state = tf.compat.v1.nn.dynamic_rnn(
                cell,
                encoder_emb_inp,
                dtype=dtype,
                sequence_length=self.batch_input.source_sequence_length,
                time_major=self.time_major)

我参考了https://github.com/tensorflow/addons/tree/master/tensorflow_addons/seq2seq,并且能够迁移除这两个函数之外的大部分代码

【问题讨论】:

【参考方案1】:

Tensorflow 2.x 中很少有库被移动到其他存储库,例如插件和操作。

tf.contrib.rnn.DropoutWrapper 替换为tf.compat.v1.nn.rnn_cell.DropoutWrapper 以获取有关库的更多信息,请查找here。

tf.contrib.rnn.DeviceWrapper 替换为tf.compat.v1.nn.rnn_cell.DeviceWrapper 以获取有关库的更多信息,请查找here。

【讨论】:

以上是关于将 tensorflow 1 contrib 转换为 tensorflow 2 Keras 版本的主要内容,如果未能解决你的问题,请参考以下文章

ImportError:没有名为'tensorflow.contrib.lite.python.tflite_convert'的模块

将时间序列元素的Tensorflow数据集转换为窗口序列的数据集

将 CudnnGRU 参数转换为正常的权重和偏差

无法在 tensorflow r1.14 中导入“tensorflow.contrib.tensorrt”

不降级解决No module named ‘tensorflow.contrib‘

AttributeError:模块“tensorflow.contrib”没有属性“估计器”