与此 Python 函数等效的 JaxNumpy 兼容是啥?
Posted
技术标签:
【中文标题】与此 Python 函数等效的 JaxNumpy 兼容是啥?【英文标题】:What is JaxNumpy-compatible equivalent to this Python function?与此 Python 函数等效的 JaxNumpy 兼容是什么? 【发布时间】:2021-07-21 05:52:03 【问题描述】:如何以与 JAX 兼容的方式实现以下功能(例如,使用 jax.numpy
)?
def actions(state: tuple[int, ...]) -> list[tuple[int, ...]]:
l = []
iterables = [range(1, i+1) for i in state]
ns = list(range(len(iterables)))
for i, iterable in enumerate(iterables):
for value in iterable:
action = tuple(value if n == i else 0 for n in ns)
l.append(action)
return l
>>> state = (3, 1, 2)
>>> actions(state)
[(1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 2)]
【问题讨论】:
Jax 和 numpy 一样,不能有效地对元组和列表进行操作——输出一个二维数组是否足以满足您的用例? 当然,可以将数组作为输入(1D ... n)和输出(2D ... m x n)。元组只是纯 Python 等价物(因为我需要它们是不可变的)。 【参考方案1】:Jax 和 numpy 一样,不能有效地对 Python 容器类型(如列表和元组)进行操作,因此实际上没有任何与 JAX 兼容的方式来创建具有您上面指定的确切签名的函数。
但是,如果您对返回值是一个二维数组感到满意,您可以根据jnp.vstack
执行类似的操作:
from typing import Tuple
import jax.numpy as jnp
from jax import jit, partial
@partial(jit, static_argnums=0)
def actions(state: Tuple[int, ...]) -> jnp.ndarray:
return jnp.vstack([
jnp.zeros((val, len(state)), int).at[:, i].set(jnp.arange(1, val + 1))
for i, val in enumerate(state)])
>>> state = (3, 1, 2)
>>> actions(state)
DeviceArray([[1, 0, 0],
[2, 0, 0],
[3, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 2]], dtype=int32)
请注意,由于输出数组的大小取决于state
的内容,state
必须是静态量,因此元组是输入的不错选择。
【讨论】:
这行得通,谢谢。根据我在上面评论中的回复,采用state:jnp.ndarray
输入是否允许不同的构造?
不,因为输出的大小取决于状态的内容,它的内容必须是静态的,所以最好将其指定为静态元组。
我的犹豫是下一步是从状态中减去一个动作(这是一个 Nim 游戏)。因此,如果它们具有相同的类型,那就太好了,因为这是强化学习设置中的环境函数。以上是关于与此 Python 函数等效的 JaxNumpy 兼容是啥?的主要内容,如果未能解决你的问题,请参考以下文章
Javascript 等效于 Python 的 zip 函数