diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 92ad4fe3c..1c77c76be 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -976,7 +976,7 @@ class open_addressing_ref_impl { * @param callback Function to call on every element found */ template - __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept + __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent()); @@ -1027,7 +1027,7 @@ class open_addressing_ref_impl { template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, - Callback callback) const noexcept + Callback&& callback) const noexcept { auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index d34586579..78c54b1b4 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -22,6 +22,8 @@ #include +#include + namespace cuco { template - __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept + __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(key, callback); + ref_.impl_.for_each(key, std::forward(callback)); } /** @@ -509,11 +511,11 @@ class operator_impl< template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, - Callback callback) const noexcept + Callback&& callback) const noexcept { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(group, key, callback); + ref_.impl_.for_each(group, key, std::forward(callback)); } }; diff --git a/tests/static_multiset/for_each_test.cu b/tests/static_multiset/for_each_test.cu index e7ecece1e..b0cb81091 100644 --- a/tests/static_multiset/for_each_test.cu +++ b/tests/static_multiset/for_each_test.cu @@ -72,6 +72,7 @@ CUCO_KERNEL void for_each_check_cooperative(Ref ref, ref.for_each(tile, key, [&] __device__(auto const it) { if (ref.key_eq()(key, *it)) { thread_matches++; } }); + tile.sync(); auto const tile_matches = cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus()); if (tile_matches != multiplicity and tile.thread_rank() == 0) {