From 26ce0b7a0f96f0950a6ff0ac0437028ab3dc5289 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Wed, 7 Aug 2024 14:16:18 -0700 Subject: [PATCH] Add multimap count device APIs --- .../static_multimap/static_multimap_ref.inl | 57 +++++++++++++++++++ include/cuco/static_multimap_ref.cuh | 3 - 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/include/cuco/detail/static_multimap/static_multimap_ref.inl b/include/cuco/detail/static_multimap/static_multimap_ref.inl index a79edba42..01e52d686 100644 --- a/include/cuco/detail/static_multimap/static_multimap_ref.inl +++ b/include/cuco/detail/static_multimap/static_multimap_ref.inl @@ -438,5 +438,62 @@ class operator_impl< return ref_.impl_.contains(group, key); } }; + +template +class operator_impl< + op::count_tag, + static_multimap_ref> { + using base_type = static_multimap_ref; + using ref_type = + static_multimap_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + using size_type = typename base_type::size_type; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief Counts the occurrence of a given key contained in multimap + * + * @tparam ProbeKey Input type + * + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + __device__ size_type count(ProbeKey const& key) const noexcept + { + auto const& ref_ = static_cast(*this); + return ref_.impl_.count(key); + } + + /** + * @brief Counts the occurrence of a given key contained in multimap + * + * @tparam ProbeKey Probe key type + * + * @param group The Cooperative Group used to perform group count + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + __device__ size_type count(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key) const noexcept + { + auto const& ref_ = static_cast(*this); + return ref_.impl_.count(group, key); + } +}; + } // namespace detail } // namespace cuco diff --git a/include/cuco/static_multimap_ref.cuh b/include/cuco/static_multimap_ref.cuh index fceb10489..74bc81ddb 100644 --- a/include/cuco/static_multimap_ref.cuh +++ b/include/cuco/static_multimap_ref.cuh @@ -74,9 +74,6 @@ class static_multimap_ref using impl_type = detail:: open_addressing_ref_impl; - static_assert(sizeof(T) == 4 or sizeof(T) == 8, - "sizeof(mapped_type) must be either 4 bytes or 8 bytes."); - static_assert( cuco::is_bitwise_comparable_v, "Key type must have unique object representations or have been explicitly declared as safe for "