Skip to content

Commit

Permalink
Pass callback as universal reference
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Jun 15, 2024
1 parent 7053703 commit 18c5f60
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ class open_addressing_ref_impl {
* @param callback Function to call on every element found
*/
template <class ProbeKey, class Callback>
- __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept
+ __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent());
Expand Down Expand Up @@ -1027,7 +1027,7 @@ class open_addressing_ref_impl {
template <class ProbeKey, class Callback>
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key,
- Callback callback) const noexcept
+ Callback&& callback) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

Expand Down
10 changes: 6 additions & 4 deletions include/cuco/detail/static_multiset/static_multiset_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include <cooperative_groups.h>

+ #include <utility>
+
namespace cuco {

template <typename Key,
Expand Down Expand Up @@ -481,11 +483,11 @@ class operator_impl<
* @param callback Function to call on every element found
*/
template <class ProbeKey, class Callback>
- __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept
+ __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
- ref_.impl_.for_each(key, callback);
+ ref_.impl_.for_each(key, std::forward<Callback>(callback));
}

/**
Expand All @@ -509,11 +511,11 @@ class operator_impl<
template <class ProbeKey, class Callback>
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key,
- Callback callback) const noexcept
+ Callback&& callback) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
- ref_.impl_.for_each(group, key, callback);
+ ref_.impl_.for_each(group, key, std::forward<Callback>(callback));
}
};

Expand Down
1 change: 1 addition & 0 deletions tests/static_multiset/for_each_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ CUCO_KERNEL void for_each_check_cooperative(Ref ref,
ref.for_each(tile, key, [&] __device__(auto const it) {
if (ref.key_eq()(key, *it)) { thread_matches++; }
});
+ tile.sync();
auto const tile_matches =
cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus<std::size_t>());
if (tile_matches != multiplicity and tile.thread_rank() == 0) {
Expand Down

0 comments on commit 18c5f60

Please sign in to comment.