如何在Tensorflow中使用自定义/非默认tf.Graph正确的方法?

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何在Tensorflow中使用自定义/非默认tf.Graph正确的方法?相关的知识,希望对你有一定的参考价值。

我是Tensorflow的新手,我正在阅读https://www.amazon.com/TensorFlow-Machine-Learning-Cookbook-McClure/dp/1786462168。我在tf.Session注意到的一个论点是graph。我喜欢完全控制流程,我想知道如何正确使用tf.Graphtf.Session以及如何将计算添加到特定的图形?另外,人们在Tensorflow中向特定图形添加操作的规范语法(如果有的话)是什么?

类似于以下内容:

t = np.linspace(0,2*np.pi)
fig, ax = plt.subplots()
ax.scatter(x=t, y=np.sin(t))

相比:

plt.scatter(x=t, y=np.sin(t))

我如何才能与tf.Graph()具有相同的灵活性?

G = tf.Graph()

t_query = np.linspace(0,2*np.pi,50)
pH_t = tf.placeholder(tf.float32, shape=t_query.shape)

def simple_sinewave(t, name=None):
    return tf.sin(t, name=name)

with tf.Session() as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
# array([  0.00000000e+00,   1.27877161e-01,   2.53654599e-01,
# ...
#         -1.27877384e-01,   1.74845553e-07], dtype=float32)

现在尝试指定一个graph参数:

with tf.Session(graph=G) as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-51-d73a1f0963e3> in <module>()
     26 #         -1.27877384e-01,   1.74845553e-07], dtype=float32)
     27 with tf.Session(graph=G) as sess:
---> 28     r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})

... RuntimeError:会话图是空的。在调用run()之前向图形添加操作。

使用David Parks更新回答此问题:

# Custom Function
def simple_sinewave(t, name=None):
    return tf.sin(t, name=name)

# Synth graph
G = tf.Graph()

# Build Graph
with G.as_default():
    t_query = np.linspace(0,2*np.pi,50)
    pH_t = tf.placeholder(tf.float32, shape=t_query.shape)

# Run session using Graph
with tf.Session(graph=G) as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
r
# array([  0.00000000e+00,   1.27877161e-01,   2.53654599e-01,
#          3.75266999e-01,   4.90717560e-01,   5.98110557e-01,
# ...
#         -4.90717530e-01,  -3.75267059e-01,  -2.53654718e-01,
#         -1.27877384e-01,   1.74845553e-07], dtype=float32)

额外:在Tensorflow中是否有一个特定的命名符来命名占位符变量?像pd.DataFrame一样df_data

以上是关于如何在Tensorflow中使用自定义/非默认tf.Graph正确的方法?的主要内容,如果未能解决你的问题,请参考以下文章

在 TensorFlow 2.0 中实现自定义损失函数

如何在 keras 自定义回调中访问 tf.data.Dataset?

如何在 Tensorflow 2.x Keras 自定义层中使用多个输入?

在 tensorflow 中编写自定义成本函数

如何强制 TensorFlow 在 float16 下运行?

如何在 tensorflow 服务器上使用自定义操作