diff --git a/thrift/compiler/generate/t_mstch_cpp2_generator.cc b/thrift/compiler/generate/t_mstch_cpp2_generator.cc index d100b3d2e55..60257edf7fc 100644 --- a/thrift/compiler/generate/t_mstch_cpp2_generator.cc +++ b/thrift/compiler/generate/t_mstch_cpp2_generator.cc @@ -2042,6 +2042,9 @@ void t_mstch_cpp2_generator::generate_visitation(const t_program* program) { cache_->programs_[id], "module_for_each_field.h", name + "_for_each_field.h"); + + render_to_file( + cache_->programs_[id], "module_visit_union.h", name + "_visit_union.h"); } void t_mstch_cpp2_generator::generate_structs(t_program const* program) { diff --git a/thrift/compiler/generate/templates/cpp2/module_visit_union.h.mustache b/thrift/compiler/generate/templates/cpp2/module_visit_union.h.mustache new file mode 100644 index 00000000000..b9a7e478873 --- /dev/null +++ b/thrift/compiler/generate/templates/cpp2/module_visit_union.h.mustache @@ -0,0 +1,48 @@ +<%! + + Copyright (c) Facebook, Inc. and its affiliates. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +%><% > Autogen%> +#pragma once + +#include "<%program:include_prefix%><%program:name%>_metadata.h" +#include + +namespace apache { +namespace thrift { +namespace detail { + +<%#program:structs%> +<%#struct:union?%> +template <> +struct VisitUnion<<% > common/namespace_cpp2%><%struct:name%>> { + template + void operator()(F&& f, T&& t) const { + using Union = std::remove_reference_t; + constexpr auto get_metadata = get_field_metadata<<% > common/namespace_cpp2%><%struct:name%>>; + switch (t.getType()) { + <%#struct:fields%> + case Union::Type::<%field:cpp_name%>: + return f(get_metadata(<%field:index%>), *static_cast(t).<%field:cpp_name%>_ref()); + <%/struct:fields%> + case Union::Type::__EMPTY__: ; + } + } +}; +<%/struct:union?%> +<%/program:structs%> +} // namespace detail +} // namespace thrift +} // namespace apache diff --git a/thrift/compiler/test/fixtures/visitation/gen-cpp2/module_visit_union.h b/thrift/compiler/test/fixtures/visitation/gen-cpp2/module_visit_union.h new file mode 100644 index 00000000000..82e53e4f05a --- /dev/null +++ b/thrift/compiler/test/fixtures/visitation/gen-cpp2/module_visit_union.h @@ -0,0 +1,163 @@ +/** + * Autogenerated by Thrift + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +#pragma once + +#include "thrift/compiler/test/fixtures/visitation/gen-cpp2/module_metadata.h" +#include + +namespace apache { +namespace thrift { +namespace detail { + +template <> +struct VisitUnion<::test_cpp2::cpp_reflection::union1> { + template + void operator()(F&& f, T&& t) const { + using Union = std::remove_reference_t; + constexpr auto get_metadata = get_field_metadata<::test_cpp2::cpp_reflection::union1>; + switch (t.getType()) { + case Union::Type::ui: + return f(get_metadata(0), *static_cast(t).ui_ref()); + case Union::Type::ud: + return f(get_metadata(1), *static_cast(t).ud_ref()); + case Union::Type::us: + return f(get_metadata(2), *static_cast(t).us_ref()); + case Union::Type::ue: + return f(get_metadata(3), *static_cast(t).ue_ref()); + case Union::Type::__EMPTY__: ; + } + } +}; +template <> +struct VisitUnion<::test_cpp2::cpp_reflection::union2> { + template + void operator()(F&& f, T&& t) const { + using Union = std::remove_reference_t; + constexpr auto get_metadata = get_field_metadata<::test_cpp2::cpp_reflection::union2>; + switch (t.getType()) { + case Union::Type::ui_2: + return f(get_metadata(0), *static_cast(t).ui_2_ref()); + case Union::Type::ud_2: + return f(get_metadata(1), *static_cast(t).ud_2_ref()); + case Union::Type::us_2: + return f(get_metadata(2), *static_cast(t).us_2_ref()); + case Union::Type::ue_2: + return f(get_metadata(3), *static_cast(t).ue_2_ref()); + case Union::Type::__EMPTY__: ; + } + } +}; +template <> +struct VisitUnion<::test_cpp2::cpp_reflection::union3> { + template + void operator()(F&& f, T&& t) const { + using Union = std::remove_reference_t; + constexpr auto get_metadata = get_field_metadata<::test_cpp2::cpp_reflection::union3>; + switch (t.getType()) { + case Union::Type::ui_3: + return f(get_metadata(0), *static_cast(t).ui_3_ref()); + case Union::Type::ud_3: + return f(get_metadata(1), *static_cast(t).ud_3_ref()); + case Union::Type::us_3: + return f(get_metadata(2), *static_cast(t).us_3_ref()); + case Union::Type::ue_3: + return f(get_metadata(3), *static_cast(t).ue_3_ref()); + case Union::Type::__EMPTY__: ; + } + } +}; +template <> +struct VisitUnion<::test_cpp2::cpp_reflection::unionA> { + template + void operator()(F&& f, T&& t) const { + using Union = std::remove_reference_t; + constexpr auto get_metadata = get_field_metadata<::test_cpp2::cpp_reflection::unionA>; + switch (t.getType()) { + case Union::Type::i: + return f(get_metadata(0), *static_cast(t).i_ref()); + case Union::Type::d: + return f(get_metadata(1), *static_cast(t).d_ref()); + case Union::Type::s: + return f(get_metadata(2), *static_cast(t).s_ref()); + case Union::Type::e: + return f(get_metadata(3), *static_cast(t).e_ref()); + case Union::Type::a: + return f(get_metadata(4), *static_cast(t).a_ref()); + case Union::Type::__EMPTY__: ; + } + } +}; +template <> +struct VisitUnion<::test_cpp2::cpp_reflection::union_with_special_names> { + template + void operator()(F&& f, T&& t) const { + using Union = std::remove_reference_t; + constexpr auto get_metadata = get_field_metadata<::test_cpp2::cpp_reflection::union_with_special_names>; + switch (t.getType()) { + case Union::Type::get: + return f(get_metadata(0), *static_cast(t).get_ref()); + case Union::Type::getter: + return f(get_metadata(1), *static_cast(t).getter_ref()); + case Union::Type::lists: + return f(get_metadata(2), *static_cast(t).lists_ref()); + case Union::Type::maps: + return f(get_metadata(3), *static_cast(t).maps_ref()); + case Union::Type::name: + return f(get_metadata(4), *static_cast(t).name_ref()); + case Union::Type::name_to_value: + return f(get_metadata(5), *static_cast(t).name_to_value_ref()); + case Union::Type::names: + return f(get_metadata(6), *static_cast(t).names_ref()); + case Union::Type::prefix_tree: + return f(get_metadata(7), *static_cast(t).prefix_tree_ref()); + case Union::Type::sets: + return f(get_metadata(8), *static_cast(t).sets_ref()); + case Union::Type::setter: + return f(get_metadata(9), *static_cast(t).setter_ref()); + case Union::Type::str: + return f(get_metadata(10), *static_cast(t).str_ref()); + case Union::Type::strings: + return f(get_metadata(11), *static_cast(t).strings_ref()); + case Union::Type::type: + return f(get_metadata(12), *static_cast(t).type_ref()); + case Union::Type::value: + return f(get_metadata(13), *static_cast(t).value_ref()); + case Union::Type::value_to_name: + return f(get_metadata(14), *static_cast(t).value_to_name_ref()); + case Union::Type::values: + return f(get_metadata(15), *static_cast(t).values_ref()); + case Union::Type::id: + return f(get_metadata(16), *static_cast(t).id_ref()); + case Union::Type::ids: + return f(get_metadata(17), *static_cast(t).ids_ref()); + case Union::Type::descriptor: + return f(get_metadata(18), *static_cast(t).descriptor_ref()); + case Union::Type::descriptors: + return f(get_metadata(19), *static_cast(t).descriptors_ref()); + case Union::Type::key: + return f(get_metadata(20), *static_cast(t).key_ref()); + case Union::Type::keys: + return f(get_metadata(21), *static_cast(t).keys_ref()); + case Union::Type::annotation: + return f(get_metadata(22), *static_cast(t).annotation_ref()); + case Union::Type::annotations: + return f(get_metadata(23), *static_cast(t).annotations_ref()); + case Union::Type::member: + return f(get_metadata(24), *static_cast(t).member_ref()); + case Union::Type::members: + return f(get_metadata(25), *static_cast(t).members_ref()); + case Union::Type::field: + return f(get_metadata(26), *static_cast(t).field_ref()); + case Union::Type::fields: + return f(get_metadata(27), *static_cast(t).fields_ref()); + case Union::Type::__EMPTY__: ; + } + } +}; +} // namespace detail +} // namespace thrift +} // namespace apache diff --git a/thrift/lib/cpp2/visitation/visit_union.h b/thrift/lib/cpp2/visitation/visit_union.h new file mode 100644 index 00000000000..613380bedf2 --- /dev/null +++ b/thrift/lib/cpp2/visitation/visit_union.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace apache { +namespace thrift { +namespace detail { +template +struct VisitUnion { + static_assert(sizeof(T) < 0, "Must include visitation header"); +}; +} // namespace detail + +/** + * Applies the callable to active member of thrift union. Example: + * + * visit_union(thriftUnion, [](const ThriftField& meta, auto&& value) { + * LOG(INFO) << *meta.name_ref() << " --> " << value; + * }) + * + * ThriftField schema is defined here: https://git.io/JJQpY + * If `no_metadata` thrift option is enabled, ThriftField will be empty. + * If union is empty, callable won't be called. + * + * @param t thrift union + * @param f a callable that accepts all member types from union + */ +template +void visit_union(T&& t, F f) { + return detail::VisitUnion>()(f, static_cast(t)); +} +} // namespace thrift +} // namespace apache diff --git a/thrift/test/UnionFieldRef.thrift b/thrift/test/UnionFieldRef.thrift index 70e3fe97108..180842b7549 100644 --- a/thrift/test/UnionFieldRef.thrift +++ b/thrift/test/UnionFieldRef.thrift @@ -17,9 +17,9 @@ namespace cpp2 apache.thrift.test union Basic { - 1: string str - 2: i64 int64 - 3: list list_i32 + 2: string str + 1: i64 int64 + 4: list list_i32 } union DuplicateType { diff --git a/thrift/test/visitation_visit_union_test.cpp b/thrift/test/visitation_visit_union_test.cpp new file mode 100644 index 00000000000..00efea64af0 --- /dev/null +++ b/thrift/test/visitation_visit_union_test.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // @manual=:union_field_ref-cpp2-visitation +using namespace std; + +namespace apache { +namespace thrift { +namespace test { +TEST(UnionFieldTest, basic) { + Basic a; + visit_union(a, [&](auto&&, auto&&) { FAIL(); }); + + static const string str = "foo"; + a.str_ref() = str; + visit_union(a, [](auto&& meta, auto&& v) { + EXPECT_EQ(meta.name, "str"); + EXPECT_EQ(meta.type.getType(), meta.type.t_primitive); + EXPECT_EQ(meta.id, 2); + EXPECT_EQ(meta.is_optional, false); + if constexpr (std::is_same_v) { + EXPECT_EQ(v, str); + } else { + FAIL(); + } + }); + + static const int64_t int64 = 42LL << 42; + a.int64_ref() = int64; + visit_union(a, [](auto&& meta, auto&& v) { + EXPECT_EQ(meta.name, "int64"); + EXPECT_EQ(meta.type.getType(), meta.type.t_primitive); + EXPECT_EQ(meta.id, 1); + EXPECT_EQ(meta.is_optional, false); + EXPECT_EQ(typeid(v), typeid(int64_t)); + if constexpr (std::is_same_v) { + EXPECT_EQ(v, int64); + } else { + FAIL(); + } + }); + + static const vector list_i32 = {3, 1, 2}; + a.list_i32_ref() = list_i32; + visit_union(a, [](auto&& meta, auto&& v) { + EXPECT_EQ(meta.name, "list_i32"); + EXPECT_EQ(meta.type.getType(), meta.type.t_list); + EXPECT_EQ(meta.id, 4); + EXPECT_EQ(meta.is_optional, false); + if constexpr (std::is_same_v&>) { + EXPECT_EQ(v, list_i32); + } else { + FAIL(); + } + }); +} +} // namespace test +} // namespace thrift +} // namespace apache