用外行术语来说,pytorch 中的聚集函数有啥作用?

Posted

技术标签:

【中文标题】用外行术语来说,pytorch 中的聚集函数有啥作用?【英文标题】:What does the gather function do in pytorch in layman terms?用外行术语来说,pytorch 中的聚集函数有什么作用? 【发布时间】:2018-12-02 15:24:51 【问题描述】:

我经历过official doc 和this,但很难理解发生了什么。

我正在尝试理解 DQN 源代码,它使用了第 197 行的收集函数。

有人可以简单地解释一下收集函数的作用吗?该函数的目的是什么?

【问题讨论】:

我从未使用过 DQN。你能尝试指定obs_batchact_batch 是什么吗? @McLawrence obs_batch 是观察批次,act_batch 是动作批次。据我了解,这基本上意味着当我将一批观察值传递给 q 函数时,它会返回一组对应于每个观察值的 q 值。 【参考方案1】:

torch.gather 通过沿输入维度 dim 获取每一行的值,从输入张量创建一个新张量。 torch.LongTensor 中的值作为index 传递,指定从每个“行”中获取的值。输出张量的维度与索引张量的维度相同。以下来自官方文档的插图更清楚地解释了它:

(注意:在插图中,索引从 1 开始,而不是 0)。

在第一个示例中,给定的维度是沿行(从上到下),因此对于result 的 (1,1) 位置,它从index 中获取行值src 即@987654330 @。源值中的 (1,1) 处为 1,因此,在 result 中的 (1,1) 处输出 1。 同样,对于 (2,2),src 的索引中的行值是 3。在 (3,2) 处,src 中的值是 8,因此输出 8 等等。

与第二个示例类似,索引是沿列进行的,因此在 result 的 (2,2) 位置,src 的索引中的列值是 3,因此在 (2,3)从src ,6 被取出并输出到result at (2,2)

【讨论】:

谢谢。这真是“一张图抵千言”的例子 这是最好的答案。谢谢。 index = torch.as_tensor([[0,1,2],[1,2,0]])src = torch.arange(9).reshape(3,3) 然后 torch.gather(src,0,index)torch.gather(src,1,index.T) 这应该是最佳答案! 我希望我能多次投票!【参考方案2】:

torch.gather 函数(或torch.Tensor.gather)是一种多索引选择方法。请看官方文档中的以下示例:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

让我们从不同参数的语义开始:第一个参数input 是我们要从中选择元素的源张量。第二个,dim,是我们想要收集的维度(或 tensorflow/numpy 中的轴)。最后,index 是索引input 的索引。 至于操作的语义,官方文档是这样解释的:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

让我们来看看这个例子。

输入张量为[[1, 2], [3, 4]],dim 参数为1,即我们要从第二维收集。第二维的索引以[0, 0][1, 0] 给出。

当我们“跳过”第一个维度(我们要收集的维度是1)时,结果的第一个维度隐含地作为index 的第一个维度给出。这意味着索引包含第二维或列索引,但不包含行索引。这些由 index 张量本身的索引给出。 例如,这意味着输出将在其第一行中选择input 张量的第一行的元素,正如index 张量的第一行的第一行所给出的。由于列索引由[0, 0] 给出,因此我们选择输入第一行的第一个元素两次,得到[1, 1]。同样,结果第二行的元素是input张量第二行的元素通过index张量第二行的元素索引的结果,得到[4, 3]

为了进一步说明这一点,让我们交换示例中的维度:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

如您所见,索引现在沿第一个维度收集。

对于你提到的例子,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

gather 将通过动作的批处理列表索引 q 值的行(即一批 q 值中的每个样本 q 值)。结果将与您执行以下操作相同(尽管它比循环快得多):

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)

【讨论】:

【参考方案3】:

@Ritesh 和@cleros 给出了很好的答案(有很多 的赞成票),但是在阅读它们之后我仍然有点困惑,我知道为什么。这篇文章也许会对像我这样的人有所帮助。

对于这类带有行和列的练习,我认为 真的 有助于使用非方形对象,所以让我们从使用 @987654323 的更大的 4x3 source (torch.Size([4, 3])) 开始@。这会给我们

\\ This is the source tensor
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

现在让我们开始沿列 (dim=1) 建立索引并创建 index = torch.tensor([[0,0],[1,1],[2,2],[0,1]]),这是一个列表列表。这是:因为我们的维度是列,并且源有4 行,所以index 必须包含4 列表!我们需要为每一行创建一个列表。运行source.gather(dim=1, index=index) 会给我们

tensor([[ 1,  1],
        [ 5,  5],
        [ 9,  9],
        [10, 11]])

因此,index 中的每个列表都为我们提供了从中提取值的列。 index ([0,0]) 的第一个列表告诉我们查看source 的第一行并取该行的第一列(它是零索引)两次,即[1,1] . index ([1,1]) 的第二个列表告诉我们看一下source 的第二行,并取该行的第二列两次,即[5,5]。跳转到index ([0,1]) 的第 4 个列表,它要求我们查看 source 的第 4 行也是最后一行,要求我们先取第 1 列 (10),然后第二列 (11) 给了我们[10,11]

这是一件很有趣的事情:index 的每个列表都必须具有相同的长度,但它们可以根据您的喜好而定!比如index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])source.gather(dim=1, index=index)会给我们

tensor([[ 1,  2,  3,  2,  1],
        [ 6,  5,  4,  5,  6],
        [ 8,  9,  7,  9,  8],
        [11, 10, 12, 10, 11]])

输出将始终具有与source 相同的行数,但列数将等于index 中每个列表的长度。比如index([2,1,0,1,2])的第2个列表去source的第2行,分别拉取第3、2、1、2、3项,也就是[6,5,4,5,6]。请注意,index 中每个元素的值必须小于source 的列数(在本例中为3),否则会出现out of bounds 错误。

切换到dim=0,我们现在将使用行而不是列。使用相同的source,我们现在需要一个index,其中每个列表的长度等于source 中的列数。为什么?因为当我们逐列移动时,列表中的每个元素都代表来自source 的行。

因此,index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]]) 将有 source.gather(dim=0, index=index) 给我们

tensor([[ 1,  2,  3],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])

查看index ([0,0,0]) 中的第一个列表,我们可以看到我们正在移动 source 的 3 列,选择每列的第一个元素(它是零索引),它是[1,2,3]index ([0,1,2]) 中的第二个列表告诉我们在列中移动,分别采用第一个、第二个和第三个项目,即 [1,5,9]。以此类推。

对于dim=1,我们的index 的列表数必须与source 中的行数相等,但每个列表的长度可以根据您的喜好而定。对于dim=0index 中的每个列表的长度必须与source 中的列数相同,但我们现在可以拥有任意数量的列表。但是,index 中的每个值都必须小于 source 中的行数(在本例中为 4)。

例如,index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]]) 将有 source.gather(dim=0, index=index) 给我们

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])

对于dim=1,输出的行数始终与source 相同,尽管列数将等于index 中列表的长度。 index 中的列表数必须等于source 中的行数。但是,index 中的每个值都必须小于 source 中的列数。

对于dim=0,输出的列数始终与source 相同,但行数将等于index 中的列表数。 index 中每个列表的长度必须等于source 中的列数。但是index 中的每个值都必须小于source 中的行数。

这就是二维。超出此范围将遵循相同的模式。

【讨论】:

很棒的答案。您对每个维度所需内容的描述帮助我更轻松地可视化操作。 太棒了。他们对我来说的关键确实是你用粗体指出的。使用非方阵也非常有帮助。非常感谢!【参考方案4】:

这是基于@Ritesh answer(感谢@Ritesh!)和一些真实的代码。

torch.gather API 是

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

示例 1

dim = 0

dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]]

output = torch.gather(input, dim, index))
# tensor([[10, 14, 18],
#         [13, 17, 12]])

示例 2

dim = 1

dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]]

output = torch.gather(input, dim, index))
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])

【讨论】:

以上是关于用外行术语来说,pytorch 中的聚集函数有啥作用?的主要内容,如果未能解决你的问题,请参考以下文章

外行术语中的 Spring 传播示例

外行术语中的同源策略

什么是 PHP 中的 Closures/Lambda 或 Javascript 中的外行术语? [复制]

Rapid Miner 中的 k-means 质心图实际上是啥意思(用外行的话来说)?

任何人都可以用外行术语解释,当使用域名打开我的网站时,它显示 https,但使用静态 IP,它显示 http。为啥?

您将如何用外行术语解释这个查询?