将 tf.data.Dataset 转换为 jax.numpy 迭代器

Posted

技术标签:

【中文标题】将 tf.data.Dataset 转换为 jax.numpy 迭代器【英文标题】:Turn a tf.data.Dataset to a jax.numpy iterator 【发布时间】:2021-12-15 08:26:20 【问题描述】:

我对使用 JAX 训练神经网络很感兴趣。我查看了tf.data.Dataset,但它只提供 tf 张量。我寻找一种将数据集更改为 JAX numpy 数组的方法,我发现许多使用 Dataset.as_numpy_generator() 将 tf 张量转换为 numpy 数组的实现。但是我想知道这是否是一个好习惯,因为 numpy 数组存储在 CPU 内存中,这不是我想要的训练(我使用 GPU)。所以我发现的最后一个想法是通过调用jnp.array 手动重铸数组,但这并不是很优雅(我担心GPU内存中的副本)。有人对此有更好的想法吗?

快速代码说明:

import os
import jax.numpy as jnp
import tensorflow as tf

def generator():
    for _ in range(2):
        yield tf.random.uniform((1, ))

ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
                                    output_shapes=tf.TensorShape([1]))

ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)

for i, batch in enumerate(ds1):
    print(type(batch))

for i, batch in enumerate(ds2):
    print(type(jnp.array(batch)))

# returns:

<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant

【问题讨论】:

欢迎来到 SO;如果下面的答案解决了您的问题,请接受 - 请参阅What should I do when someone answers my question? 【参考方案1】:

tensorflow 和 JAX 都能够在不复制内存的情况下将数组转换为 dlpack 张量,因此您可以从 tensorflow 数组创建 JAX 数组而不复制底层数据缓冲区的一种方法是通过 dlpack 进行:

import numpy as np
import tensorflow as tf
import jax.dlpack

tf_arr = tf.random.uniform((10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)

np.testing.assert_array_equal(tf_arr, jax_arr)

通过往返于 JAX,您可以比较 unsafe_buffer_pointer() 以确保数组指向同一个缓冲区,而不是一路复制缓冲区:

def tf_to_jax(arr):
  return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_arr))

def jax_to_tf(arr):
  return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))

jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)

print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True

【讨论】:

非常感谢!您知道是否可以在所有数据集中运行该函数一次?我尝试了 .map() 方法,但它失败了,因为 The argument to `to_dlpack` must be a TF tensor, not Python object 即使我的数据集是由 tf.Tensor... 我不知道你所说的“在所有数据集中运行一次函数”是什么意思 类似dataset.map(tf_to_jax) 以避免在数据集的每次迭代中调用该函数 不,我认为 tensorflow 不支持类似的东西。

以上是关于将 tf.data.Dataset 转换为 jax.numpy 迭代器的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow-读写数据tf.data

两个tf.data.Dataset可以共存并由tf.cond()控制

tensorflow-tf.data

002.tf.data.DataSet

tf.data.Dataset:不能为给定的输入类型指定 `batch_size` 参数

如何在 tf.data.Dataset 对象上使用序列/生成器将部分数据放入内存?