『Pytorch』静动态图构建对比

Posted 叠加态的猫

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了『Pytorch』静动态图构建对比相关的知识,希望对你有一定的参考价值。

对比TensorFlow和Pytorch的动静态图构建上的差异

静态图框架设计好了不能够修改,且定义静态图时需要使用新的特殊语法,这也意味着图设定时无法使用if、while、for-loop等结构,而是需要特殊的由框架专门设计的语法,在构建图时,我们需要考虑到所有的情况(即各个if分支图结构必须全部在图中,即使不一定会在每一次运行时使用到),使得静态图异常庞大占用过多显存。

以动态图没有这个顾虑,它兼容python的各种逻辑控制语法,最终创建的图取决于每次运行时的条件分支选择,下面我们对比一下TensorFlow和Pytorch的if条件分支构建图的实现:

# Author : Hellcat
# Time   : 2018/2/9

def tf_graph_if():
    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(3, 4))
    z = tf.placeholder(tf.float32, shape=None)
    w1 = tf.placeholder(tf.float32, shape=(4, 5))
    w2 = tf.placeholder(tf.float32, shape=(4, 5))

    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    y = tf.cond(tf.less(z, 0), f1, f2)

    with tf.Session() as sess:
        y_out = sess.run(y, feed_dict={
            x: np.random.randn(3, 4),
            z: 10,
            w1: np.random.randn(4, 5),
            w2: np.random.randn(4, 5)})
    return y_out

def t_graph_if():
    import torch as t
    from torch.autograd import Variable

    x = Variable(t.randn(3, 4))
    w1 = Variable(t.randn(4, 5))
    w2 = Variable(t.randn(4, 5))

    z = 10
    if z > 0:
        y = x.mm(w1)
    else:
        y = x.mm(w2)

    return y


if __name__ == "__main__":
    print(tf_graph_if())
    print(t_graph_if())

 计算输出如下:

[[ 4.0871315   0.90317607 -4.65211582  0.71610922 -2.70281982]
 [ 3.67874336 -0.58160967 -3.43737102  1.9781189  -2.18779659]
 [ 2.6638422  -0.81783319 -0.30386463 -0.61386991 -3.80232286]]
Variable containing:
-0.2474  0.1269  0.0830  3.4642  0.2255
 0.7555 -0.8057 -2.8159  3.7416  0.6230
 0.9010 -0.9469 -2.5086 -0.8848  0.2499
[torch.FloatTensor of size 3x5]

 

个人感觉上面的对比不太完美,如果使用TensorFlow的变量来对比,上面函数应该改写如下,

# Author : Hellcat
# Time   : 2018/2/9

def tf_graph_if():
    import tensorflow as tf

    x = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[3, 4]))
    z = tf.constant(dtype=tf.float32, value=10)
    w1 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))
    w2 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))

    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    y = tf.cond(tf.less(z, 0), f1, f2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y_out = sess.run(y)
    return y_out

 输出没什么变化,

[[ 1.89582038  1.12734962  0.59730953  0.99833554  0.86517167]
 [ 1.2659111   0.77320379  0.63649696  0.5804953   0.82271856]
 [ 1.92151642  1.64715886  1.19869363  1.31581473  1.5636673 ]]

 

可以看到,TensorFlow的if条件分支使用函数tf.cond(tf.less(z, 0), f1, f2)来实现,这和Pytorch直接使用if的逻辑很不同,而且,动态图不需要feed,直接运行便可。简单对比,可以看到Pytorch的逻辑更为简洁,让人很感兴趣。

 

以上是关于『Pytorch』静动态图构建对比的主要内容,如果未能解决你的问题,请参考以下文章

观点 | 属于动态图的未来:横向对比PyTorch与Keras

pytorch 计算图与动态图机制

PyTorch简易入门

PyTorch-入门与安装

pytorch介绍和环境配置

最新PyTorch0.4.0教程01PyTorch的动态计算图深入浅出