cuda Thrust 如何获取与键关联的值
Posted
技术标签:
【中文标题】cuda Thrust 如何获取与键关联的值【英文标题】:cuda Thrust how to get the values associated to key 【发布时间】:2022-01-11 20:16:47 【问题描述】:我在一侧有一个值/键列表,在另一侧我有一个键列表。
我想在列表中的这个键列表中获取关联的值。
这是我的伪代码
我不知道如何制作我的谓词,顺便说一句,这是否是实现目标的好方法。
void pseudoCode()
const int N = 7;
thrust::device_vector<vec3> keys1(N);
values[0] = vec3(1.01,1.01,1.0156);
values[1] = vec3(1.01,1.01,1.01561);
values[2] = vec3(1.02,1.52,1.02);
values[3] = vec3(1.02,1.52,1.02);
values[4] = vec3(1.0,1.0,1.0);
values[5] = vec3(5.0,1.0,1.0);
values[6] = vec3(5.0,1.5,1.0);
thrust::device_vector<long> values(N);
keys1[0] = 0;
keys1[1] = 1;
keys1[2] = 2;
keys1[3] = 5;
keys1[4] = 9;
keys1[5] = 19;
keys1[6] = 22;
thrust::device_vector<long> to_find(6);
to_find[0] = 1;
to_find[1] = 5;
to_find[2] = 9;
thrust::device_vector<vec3> output(6);
auto begin = thrust::make_zip_iterator(thrust::make_tuple(keys1.begin(), values.begin()));
auto end = thrust::make_zip_iterator(thrust::make_tuple(keys1.end(), values.end()));
thrust::copy_if(begin, end, output, pred(to_find));
// result
output[0]=values[1];
output[1]=values[3];
output[2]=values[4];
我试着在这里做我的谓词是我的代码
struct FindValue
thrust::device_vector<long> ToFind;
explicit FindValue(thrust::device_vector<long> toFind) : ToFindtoFind
__host__ __device__
bool operator()(thrust::tuple<vec3, long> x)
long v = thrust::get<1>(x);
size_t N = ToFind.size();
for (size_t index=0; index<N; index++)
if (v==ToFind[index]) return true;
return false;
;
并更改 copy_if 的调用
thrust::copy_if(begin, end, output, FindValue(to_find));
但我得到了错误
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/copy_if.inl(75): error: class "thrust::iterator_system<thrust::device_vector<vec3, thrust::device_allocator<vec3>>>" has no member "type"
detected during instantiation of "OutputIterator thrust::copy_if(InputIterator, InputIterator, OutputIterator, Predicate) [with InputIterator=thrust::zip_iterator<thrust::tuple<thrust::detail::normal_iterator<thrust::device_ptr<vec3>>, thrust::detail::normal_iterator<thrust::device_ptr<long>>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>, OutputIterator=thrust::device_vector<vec3, thrust::device_allocator<vec3>>, Predicate=FindValue]"
cpp_addin_new2.cu(768): here
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/copy_if.inl(80): error: no instance of overloaded function "select_system" matches the argument list
argument types are: (System1, System2)
detected during instantiation of "OutputIterator thrust::copy_if(InputIterator, InputIterator, OutputIterator, Predicate) [with InputIterator=thrust::zip_iterator<thrust::tuple<thrust::detail::normal_iterator<thrust::device_ptr<vec3>>, thrust::detail::normal_iterator<thrust::device_ptr<long>>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>, OutputIterator=thrust::device_vector<vec3, thrust::device_allocator<vec3>>, Predicate=FindValue]"
cpp_addin_new2.cu(768): here
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/copy_if.inl(80): error: no instance of overloaded function "thrust::copy_if" matches the argument list
argument types are: (<error-type>, thrust::zip_iterator<thrust::tuple<thrust::detail::normal_iterator<thrust::device_ptr<vec3>>, thrust::detail::normal_iterator<thrust::device_ptr<long>>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>, thrust::zip_iterator<thrust::tuple<thrust::detail::normal_iterator<thrust::device_ptr<vec3>>, thrust::detail::normal_iterator<thrust::device_ptr<long>>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>, thrust::device_vector<vec3, thrust::device_allocator<vec3>>, FindValue)
detected during instantiation of "OutputIterator thrust::copy_if(InputIterator, InputIterator, OutputIterator, Predicate) [with InputIterator=thrust::zip_iterator<thrust::tuple<thrust::detail::normal_iterator<thrust::device_ptr<vec3>>, thrust::detail::normal_iterator<thrust::device_ptr<long>>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>, OutputIterator=thrust::device_vector<vec3, thrust::device_allocator<vec3>>, Predicate=FindValue]"
cpp_addin_new2.cu(768): here
3 errors detected in the compilation of "cpp_addin_new2.cu".
【问题讨论】:
你的谓词可以对to_find
中的键执行二分搜索
我有效地尝试做到这一点,但没有设法做到这一点。有预感
你能显示你当前的谓词,哪个不起作用?
【参考方案1】:
您当前方法的主要问题是thrust::device_vector
不能在设备代码中使用。函子需要存储指针和大小。
数据类型也存在问题,例如,您有 long
值,并且想要复制值,但输出向量包含 vec3
即键。当前函子还搜索 to_find 值而不是键。
无论如何,搜索键的正确仿函数可能如下所示。
template<class KeyType>
struct CopyIfPred
int numToFind;
KeyType* to_find;
__host__ __device__
CopyIfPred(KeyType* to_find_, int numToFind_) : numToFind(numToFind_), to_find(to_find_)
template<class Tup>
__host__ __device__
bool operator()(const Tup& tup) const noexcept
const auto key = thrust::get<0>(tup);
const bool found = thrust::binary_search(
thrust::seq,
to_find,
to_find + numToFind,
key);
return found;
;
可以这样使用:
auto begin = thrust::make_zip_iterator(thrust::make_tuple(
keys1.begin(),
values.begin()
));
auto end = begin + N;
auto outputbegin = thrust::make_zip_iterator(thrust::make_tuple(
thrust::make_discard_iterator(), // discard copied keys
output.begin()
));
CopyIfPred<KeyType> pred(
thrust::raw_pointer_cast(to_find.data()),
to_find.size());
auto endIters = thrust::copy_if(begin, end, outputbegin, pred);
std::size_t outputsize = thrust::distance(outputbegin, endIters);
for(std::size_t i = 0; i < outputsize; i++)
std::cout << output[i] << "\n";
【讨论】:
这是一个漂亮的解决方案。工作正常。谢谢@Abator Abetor【参考方案2】:这正是thrust::lower_bound
的设计初衷。请看这里的示例代码。
https://thrust.github.io/doc/group__vectorized__binary__search_ga04b5a67cd0daf7be4b35c2bc75d22bee.html
【讨论】:
请不要仅发布链接答案。如果该链接失效,您的答案将变得毫无用处以上是关于cuda Thrust 如何获取与键关联的值的主要内容,如果未能解决你的问题,请参考以下文章
Swift Parse JSON 错误:没有与键 CodingKeys 关联的值(stringValue: \"_source\