Tensorflow - tf.split使用

Posted jesee

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow - tf.split使用相关的知识,希望对你有一定的参考价值。

XDeepFM的CIN中第一层实现需要使两个二维矩阵相乘得到一个三维张量,于是来复习下split函数(需要用到):
首先看下函数原理:

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name=split
)

这个函数是用来切割张量的:输入切割的张量和参数,返回切割的结果。
value传入的就是需要切割的张量,axis的数值代表切割哪个维度。
这个函数有两种切割的方式:

以三个维度的张量为例,比如说一个20 * 30 * 40的张量my_tensor,就如同一个长20厘米宽30厘米高40厘米的蛋糕,每立方厘米都是一个分量。

有两种切割方式:
1. 如果num_or_size_splits传入的是一个整数,这个整数代表这个张量最后会被切成几个小张量。此时,传入axis的数值就代表切割哪个维度(从0开始计数)。调用tf.split(my_tensor, 2,0)返回两个10 * 30 * 40的小张量。
2. 如果num_or_size_splits传入的是一个向量,那么向量有几个分量就分成几份,切割的维度还是由axis决定。比如调用tf.split(my_tensor, [10, 5, 25], 2),则返回三个张量分别大小为 20 * 30 * 10、20 * 30 * 5、20 * 30 * 25。很显然,传入的这个向量各个分量加和必须等于axis所指示原张量维度的大小 (10 + 5 + 25 = 40)。

一个实例:

import tensorflow as tf
import numpy as np

arr1 = tf.convert_to_tensor(np.arange(1,25).reshape(2,4,3),dtype=tf.int32)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    split_arr1 = tf.split(arr1,[1,1,1],2) # 切割成2个2*4*1的张量
   print(sess.run(split_arr1)

可以看到原来的2*4*3的张量被切割为了3个2*4*1的张量

技术图片

 

Reference:

https://blog.csdn.net/SangrealLilith/article/details/80272346

以上是关于Tensorflow - tf.split使用的主要内容,如果未能解决你的问题,请参考以下文章

Keras Tensorflow 中的切片张量

text tf.split()分割张量用法

Tensor的合并与分割

tensorflow tensorboard 摘要示例

如何使用 TensorFlow 连接两个具有不同形状的张量?

解决Tensorflow源码安装的之后TensorBoard 无法使用的问题