如何在Tensorflow中使用自定义/非默认tf.Graph正确的方法?
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何在Tensorflow中使用自定义/非默认tf.Graph正确的方法?相关的知识,希望对你有一定的参考价值。
我是Tensorflow
的新手,我正在阅读https://www.amazon.com/TensorFlow-Machine-Learning-Cookbook-McClure/dp/1786462168。我在tf.Session
注意到的一个论点是graph
。我喜欢完全控制流程,我想知道如何正确使用tf.Graph
与tf.Session
以及如何将计算添加到特定的图形?另外,人们在Tensorflow
中向特定图形添加操作的规范语法(如果有的话)是什么?
类似于以下内容:
t = np.linspace(0,2*np.pi)
fig, ax = plt.subplots()
ax.scatter(x=t, y=np.sin(t))
相比:
plt.scatter(x=t, y=np.sin(t))
我如何才能与tf.Graph()
具有相同的灵活性?
G = tf.Graph()
t_query = np.linspace(0,2*np.pi,50)
pH_t = tf.placeholder(tf.float32, shape=t_query.shape)
def simple_sinewave(t, name=None):
return tf.sin(t, name=name)
with tf.Session() as sess:
r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
# array([ 0.00000000e+00, 1.27877161e-01, 2.53654599e-01,
# ...
# -1.27877384e-01, 1.74845553e-07], dtype=float32)
现在尝试指定一个graph
参数:
with tf.Session(graph=G) as sess:
r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-51-d73a1f0963e3> in <module>()
26 # -1.27877384e-01, 1.74845553e-07], dtype=float32)
27 with tf.Session(graph=G) as sess:
---> 28 r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
... RuntimeError:会话图是空的。在调用run()之前向图形添加操作。
使用David Parks更新回答此问题:
# Custom Function
def simple_sinewave(t, name=None):
return tf.sin(t, name=name)
# Synth graph
G = tf.Graph()
# Build Graph
with G.as_default():
t_query = np.linspace(0,2*np.pi,50)
pH_t = tf.placeholder(tf.float32, shape=t_query.shape)
# Run session using Graph
with tf.Session(graph=G) as sess:
r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
r
# array([ 0.00000000e+00, 1.27877161e-01, 2.53654599e-01,
# 3.75266999e-01, 4.90717560e-01, 5.98110557e-01,
# ...
# -4.90717530e-01, -3.75267059e-01, -2.53654718e-01,
# -1.27877384e-01, 1.74845553e-07], dtype=float32)
额外:在Tensorflow中是否有一个特定的命名符来命名占位符变量?像pd.DataFrame
一样df_data
。
以上是关于如何在Tensorflow中使用自定义/非默认tf.Graph正确的方法?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 keras 自定义回调中访问 tf.data.Dataset?
如何在 Tensorflow 2.x Keras 自定义层中使用多个输入?