tf.app.run() 是如何工作的?

Posted

技术标签:

【中文标题】tf.app.run() 是如何工作的?【英文标题】:How does tf.app.run() work? 【发布时间】:2016-02-15 16:14:50 【问题描述】:

tf.app.run() 在 TensorFlow 翻译演示中如何工作?

tensorflow/models/rnn/translate/translate.py 中,有一个对tf.app.run() 的调用。它是如何处理的?

if __name__ == "__main__":
    tf.app.run() 

【问题讨论】:

【参考方案1】:

它只是一个非常快速的包装器,可以处理标志解析,然后分派到您自己的主程序。请参阅code。

【讨论】:

“处理标志解析”是什么意思?也许您可以添加一个链接来告知初学者这是什么意思? 它使用 flags 包解析提供给程序的命令行参数。 (它在幕后使用标准的“argparse”库,并带有一些包装器)。它链接自我在答案中链接到的代码。 在app.py中,main = main or sys.modules['__main__'].mainsys.exit(main(sys.argv[:1] + flags_passthrough))是什么意思? 这对我来说似乎很奇怪,如果你可以直接调用它main(),为什么还要将 main 函数包含在内? hAcKnRoCk:如果文件中没有 main,则使用 sys.modules['main'].main 中的任何内容。 sys.exit 意味着使用 args 和通过的任何标志运行由此找到的 main 命令,并以 main 的返回值退出。 @CharlieParker - 与 Google 现有的 python 应用程序库(如 gflags 和 google-apputils)兼容。例如,请参阅github.com/google/google-apputils【参考方案2】:
if __name__ == "__main__":

表示当前文件在 shell 下执行,而不是作为模块导入。

tf.app.run()

通过文件app.py可以看到

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

让我们逐行分解:

flags_passthrough = f._parse_flags(args=args)

这确保您通过命令行传递的参数是有效的,例如 python my_model.py --data_dir='...' --max_iteration=10000其实这个特性是基于python标准argparse模块实现的。

main = main or sys.modules['__main__'].main

=右侧的第一个main是当前函数run(main=None, argv=None)的第一个参数 .而sys.modules['__main__'] 表示当前正在运行的文件(例如my_model.py)。

所以有两种情况:

    你在my_model.py 中没有main 函数那么你必须 拨打tf.app.run(my_main_running_function)

    你在my_model.py 中有一个main 函数。 (大部分情况都是这样。)

最后一行:

sys.exit(main(sys.argv[:1] + flags_passthrough))

确保您的 main(argv)my_main_running_function(argv) 函数被正确地使用已解析的参数调用。

【讨论】:

Tensorflow 初学者缺少的一块拼图:Tensorflow 有一些内置的命令行标志处理机制。您可以定义您的标志,如tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.'),然后如果您使用tf.app.run(),它将进行设置,以便您可以全局访问您定义的标志的传递值,例如tf.flags.FLAGS.batch_size,无论您在代码中需要它的任何地方。 在我看来,这是(当前)三个中更好的答案。它解释了“tf.app.run() 是如何工作的”,而其他两个答案只是说明它的作用。 看起来标志是由 abseil 处理的,TF 肯定已经吸收了 abseil.io/docs/python/guides/flags【参考方案3】:

tf.app 没有什么特别之处。这只是一个generic entry point script,

使用可选的“main”函数和“argv”列表运行程序。

它与神经网络无关,它只是调用主函数,将任何参数传递给它。

【讨论】:

【参考方案4】:

简单来说,tf.app.run() 的工作是首先设置全局标志以供以后使用,例如:

from tensorflow.python.platform import flags
f = flags.FLAGS

然后使用一组参数运行您的 自定义 main 函数。

例如在TensorFlow NMT 代码库中,训练/推理程序执行的第一个入口点从此时开始(参见下面的代码)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

使用argparse 解析参数后,使用tf.app.run() 运行函数“main”,其定义如下:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

因此,在设置全局使用的标志之后,tf.app.run() 只需运行您传递给它的 main 函数,并将 argv 作为其参数。

PS:正如Salvador Dali's answer 所说,我猜这只是一个很好的软件工程实践,虽然我不确定 TensorFlow 是否对 main 函数执行任何优化运行而不是使用普通 CPython 运行。

【讨论】:

【参考方案5】:

Google 代码很大程度上依赖于在库/二进制文件/python 脚本中访问的全局标志,因此 tf.app.run() 解析出这些标志以在 FLAGs(或类似的东西)变量中创建一个全局状态,然后调用python main() 应该的。

如果他们没有对 tf.app.run() 进行此调用,那么用户可能会忘记进行 FLAG 解析,从而导致这些库/二进制文件/脚本无法访问他们需要的 FLAG。

【讨论】:

【参考方案6】:

2.0兼容答案:如果你想在Tensorflow 2.0中使用tf.app.run(),我们应该使用命令,

tf.compat.v1.app.run() 或者您可以使用tf_upgrade_v21.x 代码转换为2.0

【讨论】:

以上是关于tf.app.run() 是如何工作的?的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow2:tf.app.run()

if __name__ == "__main__": tf.app.run()

Tensorflow tf.app.flags 的使用

tensorflow的函数

CIFAR-10 DEMO代码阅读与理解

mnist实例