用外行术语来说,pytorch 中的聚集函数有啥作用?
Posted
技术标签:
【中文标题】用外行术语来说,pytorch 中的聚集函数有啥作用?【英文标题】:What does the gather function do in pytorch in layman terms?用外行术语来说,pytorch 中的聚集函数有什么作用? 【发布时间】:2018-12-02 15:24:51 【问题描述】:我经历过official doc 和this,但很难理解发生了什么。
我正在尝试理解 DQN 源代码,它使用了第 197 行的收集函数。
有人可以简单地解释一下收集函数的作用吗?该函数的目的是什么?
【问题讨论】:
我从未使用过 DQN。你能尝试指定obs_batch
和act_batch
是什么吗?
@McLawrence obs_batch
是观察批次,act_batch
是动作批次。据我了解,这基本上意味着当我将一批观察值传递给 q 函数时,它会返回一组对应于每个观察值的 q 值。
【参考方案1】:
torch.gather
通过沿输入维度 dim
获取每一行的值,从输入张量创建一个新张量。 torch.LongTensor
中的值作为index
传递,指定从每个“行”中获取的值。输出张量的维度与索引张量的维度相同。以下来自官方文档的插图更清楚地解释了它:
(注意:在插图中,索引从 1 开始,而不是 0)。
在第一个示例中,给定的维度是沿行(从上到下),因此对于result
的 (1,1) 位置,它从index
中获取行值src
即@987654330 @。源值中的 (1,1) 处为 1
,因此,在 result
中的 (1,1) 处输出 1
。
同样,对于 (2,2),src
的索引中的行值是 3
。在 (3,2) 处,src
中的值是 8
,因此输出 8
等等。
与第二个示例类似,索引是沿列进行的,因此在 result
的 (2,2) 位置,src
的索引中的列值是 3
,因此在 (2,3)从src
,6
被取出并输出到result
at (2,2)
【讨论】:
谢谢。这真是“一张图抵千言”的例子 这是最好的答案。谢谢。index = torch.as_tensor([[0,1,2],[1,2,0]])
和 src = torch.arange(9).reshape(3,3)
然后 torch.gather(src,0,index)
和 torch.gather(src,1,index.T)
这应该是最佳答案!
我希望我能多次投票!【参考方案2】:
torch.gather
函数(或torch.Tensor.gather
)是一种多索引选择方法。请看官方文档中的以下示例:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 1],
# [ 4, 3]])
让我们从不同参数的语义开始:第一个参数input
是我们要从中选择元素的源张量。第二个,dim
,是我们想要收集的维度(或 tensorflow/numpy 中的轴)。最后,index
是索引input
的索引。
至于操作的语义,官方文档是这样解释的:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
让我们来看看这个例子。
输入张量为[[1, 2], [3, 4]]
,dim 参数为1
,即我们要从第二维收集。第二维的索引以[0, 0]
和[1, 0]
给出。
当我们“跳过”第一个维度(我们要收集的维度是1
)时,结果的第一个维度隐含地作为index
的第一个维度给出。这意味着索引包含第二维或列索引,但不包含行索引。这些由 index
张量本身的索引给出。
例如,这意味着输出将在其第一行中选择input
张量的第一行的元素,正如index
张量的第一行的第一行所给出的。由于列索引由[0, 0]
给出,因此我们选择输入第一行的第一个元素两次,得到[1, 1]
。同样,结果第二行的元素是input
张量第二行的元素通过index
张量第二行的元素索引的结果,得到[4, 3]
。
为了进一步说明这一点,让我们交换示例中的维度:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 2],
# [ 3, 2]])
如您所见,索引现在沿第一个维度收集。
对于你提到的例子,
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
gather
将通过动作的批处理列表索引 q 值的行(即一批 q 值中的每个样本 q 值)。结果将与您执行以下操作相同(尽管它比循环快得多):
q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)
【讨论】:
【参考方案3】:@Ritesh 和@cleros 给出了很好的答案(有很多 的赞成票),但是在阅读它们之后我仍然有点困惑,我知道为什么。这篇文章也许会对像我这样的人有所帮助。
对于这类带有行和列的练习,我认为 真的 有助于使用非方形对象,所以让我们从使用 @987654323 的更大的 4x3 source
(torch.Size([4, 3])
) 开始@。这会给我们
\\ This is the source tensor
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
现在让我们开始沿列 (dim=1
) 建立索引并创建 index = torch.tensor([[0,0],[1,1],[2,2],[0,1]])
,这是一个列表列表。这是键:因为我们的维度是列,并且源有4
行,所以index
必须包含4
列表!我们需要为每一行创建一个列表。运行source.gather(dim=1, index=index)
会给我们
tensor([[ 1, 1],
[ 5, 5],
[ 9, 9],
[10, 11]])
因此,index
中的每个列表都为我们提供了从中提取值的列。 index
([0,0]
) 的第一个列表告诉我们查看source
的第一行并取该行的第一列(它是零索引)两次,即[1,1]
. index
([1,1]
) 的第二个列表告诉我们看一下source
的第二行,并取该行的第二列两次,即[5,5]
。跳转到index
([0,1]
) 的第 4 个列表,它要求我们查看 source
的第 4 行也是最后一行,要求我们先取第 1 列 (10
),然后第二列 (11
) 给了我们[10,11]
。
这是一件很有趣的事情:index
的每个列表都必须具有相同的长度,但它们可以根据您的喜好而定!比如index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])
,source.gather(dim=1, index=index)
会给我们
tensor([[ 1, 2, 3, 2, 1],
[ 6, 5, 4, 5, 6],
[ 8, 9, 7, 9, 8],
[11, 10, 12, 10, 11]])
输出将始终具有与source
相同的行数,但列数将等于index
中每个列表的长度。比如index
([2,1,0,1,2]
)的第2个列表去source
的第2行,分别拉取第3、2、1、2、3项,也就是[6,5,4,5,6]
。请注意,index
中每个元素的值必须小于source
的列数(在本例中为3
),否则会出现out of bounds
错误。
切换到dim=0
,我们现在将使用行而不是列。使用相同的source
,我们现在需要一个index
,其中每个列表的长度等于source
中的列数。为什么?因为当我们逐列移动时,列表中的每个元素都代表来自source
的行。
因此,index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]])
将有 source.gather(dim=0, index=index)
给我们
tensor([[ 1, 2, 3],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
查看index
([0,0,0]
) 中的第一个列表,我们可以看到我们正在移动 source
的 3 列,选择每列的第一个元素(它是零索引),它是[1,2,3]
。 index
([0,1,2]
) 中的第二个列表告诉我们在列中移动,分别采用第一个、第二个和第三个项目,即 [1,5,9]
。以此类推。
对于dim=1
,我们的index
的列表数必须与source
中的行数相等,但每个列表的长度可以根据您的喜好而定。对于dim=0
,index
中的每个列表的长度必须与source
中的列数相同,但我们现在可以拥有任意数量的列表。但是,index
中的每个值都必须小于 source
中的行数(在本例中为 4
)。
例如,index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]])
将有 source.gather(dim=0, index=index)
给我们
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
对于dim=1
,输出的行数始终与source
相同,尽管列数将等于index
中列表的长度。 index
中的列表数必须等于source
中的行数。但是,index
中的每个值都必须小于 source
中的列数。
对于dim=0
,输出的列数始终与source
相同,但行数将等于index
中的列表数。 index
中每个列表的长度必须等于source
中的列数。但是index
中的每个值都必须小于source
中的行数。
这就是二维。超出此范围将遵循相同的模式。
【讨论】:
很棒的答案。您对每个维度所需内容的描述帮助我更轻松地可视化操作。 太棒了。他们对我来说的关键确实是你用粗体指出的。使用非方阵也非常有帮助。非常感谢!【参考方案4】:这是基于@Ritesh answer(感谢@Ritesh!)和一些真实的代码。
torch.gather
API 是
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
示例 1
当dim = 0
,
dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]]
output = torch.gather(input, dim, index))
# tensor([[10, 14, 18],
# [13, 17, 12]])
示例 2
当dim = 1
,
dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]]
output = torch.gather(input, dim, index))
# tensor([[10, 11],
# [14, 15],
# [18, 16]])
【讨论】:
以上是关于用外行术语来说,pytorch 中的聚集函数有啥作用?的主要内容,如果未能解决你的问题,请参考以下文章
什么是 PHP 中的 Closures/Lambda 或 Javascript 中的外行术语? [复制]
Rapid Miner 中的 k-means 质心图实际上是啥意思(用外行的话来说)?