PyTorch c++ 扩展中的可选张量
Posted
技术标签:
【中文标题】PyTorch c++ 扩展中的可选张量【英文标题】:Optional tensors in PyTorch c++ extension 【发布时间】:2019-07-08 04:29:32 【问题描述】:我正在为 pytorch 编写 C++ 扩展,并使用 c++ api 来执行此操作。对于我的 forward
函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了这个可选参数来做不同的事情。通常,我们在 C++ 中使用 NULL 作为可选指针参数,并在函数内部检查指针是否为 NULL。我不知道如何为 at::Tensor
类型的 Torch 的 c++ api 执行此操作。
void xyz_forward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor optional_constraints = something)
if(optional_constraints)
//do something
else
//do something else
请注意,我不能做const at::Tensor optional_constraints = at::ones
之类的,因为该参数可以采用任何实际值并且可以具有不同的大小/形状。我不能为它分配一个数值作为可选参数。是否有对应的 NULL
?
【问题讨论】:
也许我不明白,但你不能检查一下optional_constraints == nullptr
吗?
@Coolness 不幸的是optional_constrains
不是指针。
啊,我明白了。谢谢。
【参考方案1】:
一种可能性是将std::optional 用作std::optional<at::Tensor> optional_constraints = std::nullopt
。它可以根据上下文转换为bool
,因此您可以使用if (optional_constraints)
进行检查。传一个则使用.value()
方法获取张量,否则默认为std::nullopt
。
【讨论】:
【参考方案2】:因为我找不到类似的东西,例如。 API 中的 OpenCV noArray()
(主要用于传递可选矩阵,如掩码),我建议您为此目的使用重载函数
void xyz_forward(
const at::Tensor xyz1,
const at::Tensor xyz2)
// optional tensor wasnt passed
void xyz_forward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor optional_constraints)
// optional tensor passed
【讨论】:
以上是关于PyTorch c++ 扩展中的可选张量的主要内容,如果未能解决你的问题,请参考以下文章