PyTorch c++ 扩展中的可选张量

Posted

技术标签:

【中文标题】PyTorch c++ 扩展中的可选张量【英文标题】:Optional tensors in PyTorch c++ extension 【发布时间】:2019-07-08 04:29:32 【问题描述】:

我正在为 pytorch 编写 C++ 扩展,并使用 c++ api 来执行此操作。对于我的 forward 函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了这个可选参数来做不同的事情。通常,我们在 C++ 中使用 NULL 作为可选指针参数,并在函数内部检查指针是否为 NULL。我不知道如何为 at::Tensor 类型的 Torch 的 c++ api 执行此操作。

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints = something)

     if(optional_constraints)
        //do something
     else
        //do something else
     

请注意,我不能做const at::Tensor optional_constraints = at::ones 之类的,因为该参数可以采用任何实际值并且可以具有不同的大小/形状。我不能为它分配一个数值作为可选参数。是否有对应的 NULL

【问题讨论】:

也许我不明白,但你不能检查一下optional_constraints == nullptr吗? @Coolness 不幸的是optional_constrains 不是指针。 啊,我明白了。谢谢。 【参考方案1】:

一种可能性是将std::optional 用作std::optional<at::Tensor> optional_constraints = std::nullopt。它可以根据上下文转换为bool,因此您可以使用if (optional_constraints) 进行检查。传一个则使用.value()方法获取张量,否则默认为std::nullopt

【讨论】:

【参考方案2】:

因为我找不到类似的东西,例如。 API 中的 OpenCV noArray()(主要用于传递可选矩阵,如掩码),我建议您为此目的使用重载函数

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2)

     // optional tensor wasnt passed


void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints)

     // optional tensor passed

【讨论】:

以上是关于PyTorch c++ 扩展中的可选张量的主要内容,如果未能解决你的问题,请参考以下文章

1. PyTorch是什么?

1. PyTorch是什么?

libtorch (pytorch c++) 教程

libtorch (pytorch c++) 教程

PyTorch:tensor-张量维度操作(拼接维度扩展压缩转置重复……)

Pytorch 几何:张量大小有问题