AttributeError:“ModifiedTensorBoard”对象没有属性“_train_dir”

Posted

技术标签:

【中文标题】AttributeError:“ModifiedTensorBoard”对象没有属性“_train_dir”【英文标题】:AttributeError: 'ModifiedTensorBoard' object has no attribute '_train_dir' 【发布时间】:2020-12-10 10:19:45 【问题描述】:

我正在关注 youtube 上的 DeepQlearning 教程。但是,我很难让它运行。它说我没有属性“_train_dir”。当我什至没有调用该代码时。代码如下:

class ModifiedTensorBoard(TensorBoard):

    # Overriding init to set initial step and writer (we want one log file for all .fit() calls)
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.step = 1
        self.writer = tf.summary.create_file_writer(self.log_dir)
        self._log_write_dir= self.log_dir

    def _write_logs(self, logs, index):
        with self.writer.as_default():
            for name, value in logs.items():
                tf.summary.scalar(name, value, step=index)
                self.step += 1
                self.writer.flush()
                
    # Overriding this method to stop creating default log writer
    def set_model(self, model):
        pass

    # Overrided, saves logs with our step number
    # (otherwise every .fit() will start writing from 0th step)
    def on_epoch_end(self, epoch, logs=None):
        self.update_stats(**logs)

    # Overrided
    # We train for one batch only, no need to save anything at epoch end
    def on_batch_end(self, batch, logs=None):
        pass

    # Overrided, so won't close writer
    def on_train_end(self, _):
        pass

    # Custom method for saving own metrics
    # Creates writer, writes custom metrics and closes writer
    def update_stats(self, **stats):
        self._write_logs(stats, self.step)

一直编译到这里:

Traceback (most recent call last):
  File "dqn-1.py", line 387, in <module>
    agent.train(done, step)
  File "dqn-1.py", line 334, in train
    verbose=0, shuffle=False, callbacks=[self.tensorboard] if terminal_state else None)
  File "C:\Users\Anthony\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "C:\Users\Anthony\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1079, in fit
    callbacks.on_train_begin()
  File "C:\Users\Anthony\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\callbacks.py", line 497, in on_train_begin
    callback.on_train_begin(logs)
  File "C:\Users\Anthony\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\callbacks.py", line 2141, in on_train_begin
    self._push_writer(self._train_writer, self._train_step)
  File "C:\Users\Anthony\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\callbacks.py", line 1988, in _train_writer
    self._train_dir)

我做错了什么?

【问题讨论】:

【参考方案1】:

这是 TensorFlow 2.4.1 的更新工作代码,只需照原样复制粘贴即可:

class ModifiedTensorBoard(TensorBoard):

def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.step = 1
    self.writer = tf.summary.create_file_writer(self.log_dir)
    self._log_write_dir = self.log_dir

def set_model(self, model):
    self.model = model

    self._train_dir = os.path.join(self._log_write_dir, 'train')
    self._train_step = self.model._train_counter

    self._val_dir = os.path.join(self._log_write_dir, 'validation')
    self._val_step = self.model._test_counter

    self._should_write_train_graph = False

def on_epoch_end(self, epoch, logs=None):
    self.update_stats(**logs)

def on_batch_end(self, batch, logs=None):
    pass

def on_train_end(self, _):
    pass

def update_stats(self, **stats):
    with self.writer.as_default():
        for key, value in stats.items():
            tf.summary.scalar(key, value, step = self.step)
            self.writer.flush()

【讨论】:

【参考方案2】:

我有同样的问题,是因为 tensorflow 版本。我有 2.3,我的更改是这样的:

import tensorflow as tf
#tf.compat.v1.disable_eager_execution() # uncomment if needed
if tf.executing_eagerly():
    print('Executing eagerly')

print(f'tensorflow version tf.__version__')
print(f'tensorflow.keras version tf.keras.__version__')

# Own Tensorboard class
class ModifiedTensorBoard(TensorBoard):

    # Overriding init to set initial step and writer (we want one log file for all .fit() calls)
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.step = 1
        self.model = None
        self.TB_graph = tf.compat.v1.Graph()
        with self.TB_graph.as_default():
            self.writer = tf.summary.create_file_writer(self.log_dir, flush_millis=5000)
            self.writer.set_as_default()
            self.all_summary_ops = tf.compat.v1.summary.all_v2_summary_ops()
        self.TB_sess = tf.compat.v1.InteractiveSession(graph=self.TB_graph)
        self.TB_sess.run(self.writer.init())

    # Overriding this method to stop creating default log writer
    def set_model(self, model):
        self.model = model
        self._train_dir = self.log_dir + '\\train'

    # Overrided, saves logs with our step number
    # (otherwise every .fit() will start writing from 0th step)
    def on_epoch_end(self, epoch, logs=None):
        self.update_stats(**logs)

    # Overrided
    # We train for one batch only, no need to save anything at epoch end
    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass
    
    # Overrided, so won't close writer
    def on_train_end(self, _):
        pass

    # added for performance?
    def on_train_batch_end(self, _, __):
        pass

    # Custom method for saving own metrics
    # Creates writer, writes custom metrics and closes writer
    def update_stats(self, **stats):
        self._write_logs(stats, self.step)

    def _write_logs(self, logs, index):
        for name, value in logs.items():
            self.TB_sess.run(self.all_summary_ops)
            if self.model is not None:
                name = f'name_self.model.name'
            self.TB_sess.run(tf.summary.scalar(name, value, step=index))
        self.model = None

【讨论】:

以上是关于AttributeError:“ModifiedTensorBoard”对象没有属性“_train_dir”的主要内容,如果未能解决你的问题,请参考以下文章

初学者 Python:AttributeError:'list' 对象没有属性

AttributeError:“字节”对象没有属性“告诉”

AttributeError: 'RDD' 对象没有属性 'show'

AttributeError:“NumpyArrayIterator”对象没有属性“类”

AttributeError:模块 'dbus' 没有属性 'lowlevel'

AttributeError:模块'keras'没有属性'initializers'