与此 Python 函数等效的 JaxNumpy 兼容是啥?

Posted

技术标签:

【中文标题】与此 Python 函数等效的 JaxNumpy 兼容是啥?【英文标题】:What is JaxNumpy-compatible equivalent to this Python function?与此 Python 函数等效的 JaxNumpy 兼容是什么? 【发布时间】:2021-07-21 05:52:03 【问题描述】:

如何以与 JAX 兼容的方式实现以下功能(例如,使用 jax.numpy)?

def actions(state: tuple[int, ...]) -> list[tuple[int, ...]]:
    l = []
    iterables = [range(1, i+1) for i in state]
    ns = list(range(len(iterables)))
    for i, iterable in enumerate(iterables):
        for value in iterable:
            action = tuple(value if n == i else 0 for n in ns)
            l.append(action)
    return l

>>> state = (3, 1, 2)
>>> actions(state)
[(1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 2)]

【问题讨论】:

Jax 和 numpy 一样,不能有效地对元组和列表进行操作——输出一个二维数组是否足以满足您的用例? 当然,可以将数组作为输入(1D ... n)和输出(2D ... m x n)。元组只是纯 Python 等价物(因为我需要它们是不可变的)。 【参考方案1】:

Jax 和 numpy 一样,不能有效地对 Python 容器类型(如列表和元组)进行操作,因此实际上没有任何与 JAX 兼容的方式来创建具有您上面指定的确切签名的函数。

但是,如果您对返回值是一个二维数组感到满意,您可以根据jnp.vstack 执行类似的操作:

from typing import Tuple
import jax.numpy as jnp
from jax import jit, partial

@partial(jit, static_argnums=0)
def actions(state: Tuple[int, ...]) -> jnp.ndarray:
  return jnp.vstack([
    jnp.zeros((val, len(state)), int).at[:, i].set(jnp.arange(1, val + 1))
    for i, val in enumerate(state)])
>>> state = (3, 1, 2)
>>> actions(state)
DeviceArray([[1, 0, 0],
             [2, 0, 0],
             [3, 0, 0],
             [0, 1, 0],
             [0, 0, 1],
             [0, 0, 2]], dtype=int32)

请注意,由于输出数组的大小取决于state 的内容,state 必须是静态量,因此元组是输入的不错选择。

【讨论】:

这行得通,谢谢。根据我在上面评论中的回复,采用state:jnp.ndarray 输入是否允许不同的构造? 不,因为输出的大小取决于状态的内容,它的内容必须是静态的,所以最好将其指定为静态元组。 我的犹豫是下一步是从状态中减去一个动作(这是一个 Nim 游戏)。因此,如果它们具有相同的类型,那就太好了,因为这是强化学习设置中的环境函数。

以上是关于与此 Python 函数等效的 JaxNumpy 兼容是啥?的主要内容,如果未能解决你的问题,请参考以下文章

python函数.size()的sql等效函数是啥?

Python中matlab imfilter的等效函数

Javascript 等效于 Python 的 zip 函数

C++ 中是不是有与 python 中的 astype() 函数等效的函数?

matlab中是不是有等效的python函数id()?

Python/Django中PHP“in”函数的等效函数