diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index c9f9f8d6c653768a19a33fae7cadc3dc3ea5ece7..690c4895c1df3c34c7a3586afb0d4734b74decd1 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/var_type_traits.h" +#include "paddle/fluid/platform/macros.h" namespace paddle { namespace framework { @@ -23,54 +24,83 @@ namespace detail { template struct VarIdToTypeIndexMapInitializerImpl { - static void Init(std::unordered_map *m) { + template + static void Init(MapType1 *id_to_type, MapType2 *type_to_id) { using Type = typename std::tuple_element::type; + static_assert(!std::is_same::value, "Type cannot be void"); constexpr int kId = VarTypeTrait::kId; - if (!std::is_same::value) { - m->emplace(kId, std::type_index(typeid(Type))); - } + auto type = std::type_index(typeid(Type)); + PADDLE_ENFORCE(id_to_type->count(kId) == 0, + "Registered duplicate type id %d for type %s", kId, + type.name()); + PADDLE_ENFORCE(type_to_id->count(type) == 0, + "Registered duplicate type_index %s for id %d", type.name(), + kId); + id_to_type->emplace(kId, type); + type_to_id->emplace(type, kId); VarIdToTypeIndexMapInitializerImpl::Init(m); + kStart + 1 == kEnd>::Init(id_to_type, + type_to_id); } }; template struct VarIdToTypeIndexMapInitializerImpl { - static void Init(std::unordered_map *m) {} + template + static void Init(MapType1 *, MapType2 *) {} }; // VarIdToTypeIndexMapInitializer is designed to initialize var_id -> -// std::type_index map +// std::type_index map and std::type_index -> var_id map using VarIdToTypeIndexMapInitializer = VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum, VarTypeRegistry::kRegisteredTypeNum == 0>; struct VarIdToTypeIndexMapHolder { + DISABLE_COPY_AND_ASSIGN(VarIdToTypeIndexMapHolder); + public: static const std::type_index &ToTypeIndex(int var_id) { - static const VarIdToTypeIndexMapHolder instance; - auto it = instance.var_type_map_.find(var_id); - PADDLE_ENFORCE(it != instance.var_type_map_.end(), + auto it = Instance().id_to_type_map_.find(var_id); + PADDLE_ENFORCE(it != Instance().id_to_type_map_.end(), "VarId %d is not registered.", var_id); return it->second; } + static int ToTypeId(const std::type_index &type) { + auto it = Instance().type_to_id_map_.find(type); + PADDLE_ENFORCE(it != Instance().type_to_id_map_.end(), + "VarType %s is not registered.", type.name()); + return it->second; + } + private: VarIdToTypeIndexMapHolder() { - VarIdToTypeIndexMapInitializer::Init(&var_type_map_); + VarIdToTypeIndexMapInitializer::Init(&id_to_type_map_, &type_to_id_map_); + } + + static const VarIdToTypeIndexMapHolder &Instance() { + static const VarIdToTypeIndexMapHolder instance; + return instance; } - std::unordered_map var_type_map_; + + std::unordered_map id_to_type_map_; + std::unordered_map type_to_id_map_; }; } // namespace detail -const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } - const std::type_index &ToTypeIndex(int var_id) { return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); } +const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } + +int ToTypeId(const std::type_index &type) { + return detail::VarIdToTypeIndexMapHolder::ToTypeId(type); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index c5e0d4707efe22ed83d8abad7e0de99628743d9b..a58414c3d4e56e6f24f7637132e78f3a60b1ac9a 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -42,6 +42,7 @@ namespace framework { const char *ToTypeName(int var_id); const std::type_index &ToTypeIndex(int var_id); +int ToTypeId(const std::type_index &type); namespace detail { @@ -75,10 +76,10 @@ struct VarTypeRegistryImpl { using ArgTuple = std::tuple; // TypePos() returns the position in which T is inside Args... - // If T is not inside Args... or T is void, return -1 + // If T is not inside Args..., return -1 template static constexpr int TypePos() { - return std::is_same::value ? -1 : TypePosFinder::kPos; + return TypePosFinder::kPos; } // IsRegistered() returns whether T is registered inside RegistryImpl @@ -90,19 +91,22 @@ struct VarTypeRegistryImpl { } // namespace detail -#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \ - template <> \ - struct VarTypeTrait { \ - static_assert(VarTypeRegistry::IsRegistered(), \ - "Must be registered type"); \ - using Type = type; \ - static constexpr int kId = proto_id; \ +#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \ + template <> \ + struct VarTypeTrait { \ + static_assert(VarTypeRegistry::IsRegistered(), \ + "Must be registered type"); \ + using Type = type; \ + static constexpr int kId = static_cast(proto_id); \ } /** * The following codes are designed to register variable types. * Only registered types can be stored in Variable. * This registry mechanism is designed to speed up Variable. + * + * Caution: If you want to add more var types, please consider carefully + * whether you really need to add it. */ // Users should add other variable types below. @@ -110,10 +114,9 @@ struct VarTypeRegistryImpl { class Scope; using VarTypeRegistry = detail::VarTypeRegistryImpl< - LoDTensor, SelectedRows, std::vector, LoDRankTable, LoDTensorArray, - platform::PlaceList, ReaderHolder, Tensor, std::string, Scope *, + Tensor, LoDTensor, SelectedRows, std::vector, LoDRankTable, + LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *, std::map, operators::reader::LoDTensorBlockingQueueHolder, - int, float, #ifdef PADDLE_WITH_CUDA #ifndef _WIN32 ncclUniqueId, platform::Communicator, @@ -123,13 +126,11 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< operators::AlgorithmsCache, operators::CudnnRNNCache, #endif - void>; // void indicates end of registration, add other types before void + int, float>; template struct VarTypeTrait { - static_assert(std::is_same::value || - VarTypeRegistry::IsRegistered(), - "Must be registered type"); + static_assert(VarTypeRegistry::IsRegistered(), "Must be registered type"); using Type = T; // Default id generation static constexpr int kId = VarTypeRegistry::TypePos() + diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index f46608233ab90cd1e17cf9117e8a6c342dde9ed5..4dad4cb27b7720b4a72f12b50b6f7abd1a5b7e79 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/var_type_traits.h" #include #include +#include #include namespace paddle { @@ -29,15 +30,27 @@ struct TypeIndexChecker { static_assert(std::is_same::Type, Type>::value, "Type must be the same"); constexpr auto kId = VarTypeTrait::kId; - if (!std::is_same::value) { - std::type_index actual_type(typeid(Type)); - EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); - EXPECT_EQ(ToTypeIndex(kId), actual_type); - EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT - EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT - var_id_set->insert(kId); - type_index_set->insert(std::type_index(typeid(Type))); + std::type_index actual_type(typeid(Type)); + EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); + // For some reasons, comparing std::type_index using EXPECT_EQ would fail + // in MAC CI + bool is_same_type_index = (ToTypeIndex(kId) == actual_type); + if (!is_same_type_index) { + std::string s1 = ToTypeName(kId); + std::string s2 = actual_type.name(); + PADDLE_THROW("Step %d: type %s is not the same as %s, var_id %d", kPos, + s1.c_str(), s2.c_str(), kId); } + EXPECT_TRUE(is_same_type_index); + EXPECT_TRUE(ToTypeId(actual_type) == kId); // NOLINT + is_same_type_index = (ToTypeIndex(ToTypeId(actual_type)) == actual_type); + EXPECT_TRUE(is_same_type_index); + EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId); + + EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT + EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT + var_id_set->insert(kId); + type_index_set->insert(std::type_index(typeid(Type))); TypeIndexChecker::Check(var_id_set, type_index_set); } @@ -75,13 +88,11 @@ TEST(var_type_traits, check_proto_type_id) { } TEST(var_type_traits, test_registry) { - using Registry = - detail::VarTypeRegistryImpl; + using Registry = detail::VarTypeRegistryImpl; ASSERT_TRUE(Registry::TypePos() == 0); ASSERT_TRUE(Registry::TypePos() == 1); ASSERT_TRUE(Registry::TypePos() == 2); ASSERT_TRUE(Registry::TypePos() == 3); - ASSERT_TRUE(Registry::TypePos() == -1); ASSERT_TRUE(Registry::TypePos() == -1); }