Tensorflow - tf.split使用
Posted jesee
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow - tf.split使用相关的知识,希望对你有一定的参考价值。
XDeepFM的CIN中第一层实现需要使两个二维矩阵相乘得到一个三维张量,于是来复习下split函数(需要用到):
首先看下函数原理:
tf.split( value, num_or_size_splits, axis=0, num=None, name=‘split‘ )
这个函数是用来切割张量的:输入切割的张量和参数,返回切割的结果。
value传入的就是需要切割的张量,axis的数值代表切割哪个维度。
这个函数有两种切割的方式:
以三个维度的张量为例,比如说一个20 * 30 * 40的张量my_tensor,就如同一个长20厘米宽30厘米高40厘米的蛋糕,每立方厘米都是一个分量。
有两种切割方式:
1. 如果num_or_size_splits传入的是一个整数,这个整数代表这个张量最后会被切成几个小张量。此时,传入axis的数值就代表切割哪个维度(从0开始计数)。调用tf.split(my_tensor, 2,0)返回两个10 * 30 * 40的小张量。
2. 如果num_or_size_splits传入的是一个向量,那么向量有几个分量就分成几份,切割的维度还是由axis决定。比如调用tf.split(my_tensor, [10, 5, 25], 2),则返回三个张量分别大小为 20 * 30 * 10、20 * 30 * 5、20 * 30 * 25。很显然,传入的这个向量各个分量加和必须等于axis所指示原张量维度的大小 (10 + 5 + 25 = 40)。
一个实例:
import tensorflow as tf import numpy as np arr1 = tf.convert_to_tensor(np.arange(1,25).reshape(2,4,3),dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) split_arr1 = tf.split(arr1,[1,1,1],2) # 切割成2个2*4*1的张量
print(sess.run(split_arr1)
可以看到原来的2*4*3的张量被切割为了3个2*4*1的张量
Reference:
https://blog.csdn.net/SangrealLilith/article/details/80272346
以上是关于Tensorflow - tf.split使用的主要内容,如果未能解决你的问题,请参考以下文章