Skip to content

Commit

Permalink
Add multimap count device APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Aug 7, 2024
1 parent c523f46 commit 26ce0b7
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
57 changes: 57 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -438,5 +438,62 @@ class operator_impl<
return ref_.impl_.contains(group, key);
}
};

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class operator_impl<
op::count_tag,
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
using base_type = static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type =
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using size_type = typename base_type::size_type;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;

public:
/**
* @brief Counts the occurrence of a given key contained in multimap
*
* @tparam ProbeKey Input type
*
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
__device__ size_type count(ProbeKey const& key) const noexcept
{
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.count(key);
}

/**
* @brief Counts the occurrence of a given key contained in multimap
*
* @tparam ProbeKey Probe key type
*
* @param group The Cooperative Group used to perform group count
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
__device__ size_type count(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key) const noexcept
{
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.count(group, key);
}
};

} // namespace detail
} // namespace cuco
3 changes: 0 additions & 3 deletions include/cuco/static_multimap_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ class static_multimap_ref
using impl_type = detail::
open_addressing_ref_impl<Key, Scope, KeyEqual, ProbingScheme, StorageRef, allows_duplicates>;

static_assert(sizeof(T) == 4 or sizeof(T) == 8,
"sizeof(mapped_type) must be either 4 bytes or 8 bytes.");

static_assert(
cuco::is_bitwise_comparable_v<Key>,
"Key type must have unique object representations or have been explicitly declared as safe for "
Expand Down

0 comments on commit 26ce0b7

Please sign in to comment.