nVidia 推力:device_ptr 常量正确性
Posted
技术标签:
【中文标题】nVidia 推力:device_ptr 常量正确性【英文标题】:nVidia Thrust: device_ptr Const-Correctness 【发布时间】:2016-02-26 20:14:53 【问题描述】:在我广泛使用 nVidia CUDA 的项目中,我有时将 Thrust 用于它做得非常非常好的事情。 Reduce 是一种在该库中实现得特别好的算法,reduce 的一个用途是通过将每个元素除以所有元素的总和来归一化非负元素的向量元素。
template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output)
const thrust::device_ptr<T> X = thrust::device_pointer_cast(const_cast<T*>(d_input));
T sum = thrust::reduce(X, X + size);
thrust::constant_iterator<T> denominator(sum);
thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
thrust::transform(X, X + size, denominator, Y, thrust::divides<T>());
(T
通常是 float
或 double
)
一般来说,我不想在整个代码库中都依赖 Thrust,因此我尝试确保像上面示例这样的函数只接受原始 CUDA 设备指针。这意味着一旦它们被 NVCC 编译,我可以将它们静态链接到没有 NVCC 的其他代码中。
然而,这段代码让我担心。我希望函数是 const 正确的,但我似乎找不到 const
版本的 thrust::device_pointer_cast(...)
- 这样的事情存在吗?在这个版本的代码中,我使用了const_cast
,所以我在函数签名中使用了const
,这让我很伤心。
在旁注中,将 reduce 的结果复制到主机只是为了将其发送回设备以进行下一步,这感觉很奇怪。有没有更好的方法来做到这一点?
【问题讨论】:
为什么你认为是non const?你为什么投? 我强制转换是因为没有强制转换就无法编译。 “没有构造函数实例”thrust::pointerthrust::device_ptr<const T>
?
@Angew - 这行得通。 thrust::device_ptr<const T> X = thrust::device_pointer_cast<const T>(d_input)
就好了。
@Angew - 感谢您的提示。
【参考方案1】:
如果你想要 const 正确性,你需要在任何地方都是 const 正确的。 input
是指向 const T
的指针,因此应该是 X
:
const thrust::device_ptr<const T> X = thrust::device_pointer_cast(d_input);
【讨论】:
@Drop 是的(虽然那将是const auto X = ...
以完全匹配 OP 的代码)。
如果@Angew 在他的回答中使用了auto
,我从这个问题中学到的东西很少。正是他给了我完整的类型规范这一事实才很有帮助。
确保 T 也匹配指针中包含的类型。例如,如果 d_input 为 uint64 并且您尝试:
以上是关于nVidia 推力:device_ptr 常量正确性的主要内容,如果未能解决你的问题,请参考以下文章