diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index b6372a2ef5934aa9b8d1a5ba86f99eb028e8ea86..d0beb8361c206a280e3e3ebe797fd49d7a0e691c 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -83,10 +83,7 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto) if (WITH_GPU) - target_link_libraries(var_type_traits cudnn) - if (NOT WIN32) - target_link_libraries(var_type_traits nccl) - endif() + target_link_libraries(var_type_traits dynload_cuda) endif() cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits) diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index 0171df6f7389deb73b9e771de9bf023234a12d0e..c9f9f8d6c653768a19a33fae7cadc3dc3ea5ece7 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -17,9 +17,58 @@ namespace paddle { namespace framework { -const char* ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } +// Besides registering variable type id, it is helpful to register a +// var_id -> std::type_index map (for example, get type names according to id) +namespace detail { -const std::type_index& ToTypeIndex(int var_id) { +template +struct VarIdToTypeIndexMapInitializerImpl { + static void Init(std::unordered_map *m) { + using Type = + typename std::tuple_element::type; + constexpr int kId = VarTypeTrait::kId; + if (!std::is_same::value) { + m->emplace(kId, std::type_index(typeid(Type))); + } + VarIdToTypeIndexMapInitializerImpl::Init(m); + } +}; + +template +struct VarIdToTypeIndexMapInitializerImpl { + static void Init(std::unordered_map *m) {} +}; + +// VarIdToTypeIndexMapInitializer is designed to initialize var_id -> +// std::type_index map +using VarIdToTypeIndexMapInitializer = + VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum, + VarTypeRegistry::kRegisteredTypeNum == + 0>; + +struct VarIdToTypeIndexMapHolder { + public: + static const std::type_index &ToTypeIndex(int var_id) { + static const VarIdToTypeIndexMapHolder instance; + auto it = instance.var_type_map_.find(var_id); + PADDLE_ENFORCE(it != instance.var_type_map_.end(), + "VarId %d is not registered.", var_id); + return it->second; + } + + private: + VarIdToTypeIndexMapHolder() { + VarIdToTypeIndexMapInitializer::Init(&var_type_map_); + } + std::unordered_map var_type_map_; +}; + +} // namespace detail + +const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } + +const std::type_index &ToTypeIndex(int var_id) { return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); } diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index 88f917e74fc1416eafc95d0e98a1378715d6ed84..c5e0d4707efe22ed83d8abad7e0de99628743d9b 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -40,6 +40,9 @@ namespace paddle { namespace framework { +const char *ToTypeName(int var_id); +const std::type_index &ToTypeIndex(int var_id); + namespace detail { template std::type_index (for example, get var names according to id) -namespace detail { - -template -struct VarIdToTypeIndexMapInitializerImpl { - static void Init(std::unordered_map *m) { - using Type = - typename std::tuple_element::type; - constexpr int kId = VarTypeTrait::kId; - if (!std::is_same::value) { - m->emplace(kId, std::type_index(typeid(Type))); - } - VarIdToTypeIndexMapInitializerImpl::Init(m); - } -}; - -template -struct VarIdToTypeIndexMapInitializerImpl { - static void Init(std::unordered_map *m) {} -}; - -// VarIdToTypeIndexMapInitializer is designed to initialize var_id -> -// std::type_index map -using VarIdToTypeIndexMapInitializer = - VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum, - VarTypeRegistry::kRegisteredTypeNum == - 0>; - -struct VarIdToTypeIndexMapHolder { - public: - static const std::type_index &ToTypeIndex(int var_id) { - static const VarIdToTypeIndexMapHolder instance; - auto it = instance.var_type_map_.find(var_id); - PADDLE_ENFORCE(it != instance.var_type_map_.end(), - "VarId %d is not registered.", var_id); - return it->second; - } - - private: - VarIdToTypeIndexMapHolder() { - VarIdToTypeIndexMapInitializer::Init(&var_type_map_); - } - std::unordered_map var_type_map_; -}; - -} // namespace detail - -const char *ToTypeName(int var_id); -const std::type_index &ToTypeIndex(int var_id); - template inline constexpr bool IsRegisteredVarType() { return VarTypeRegistry::IsRegistered(); diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 09fab719c16747db8c23854c255a33b09dbf0285..f46608233ab90cd1e17cf9117e8a6c342dde9ed5 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -15,32 +15,46 @@ #include "paddle/fluid/framework/var_type_traits.h" #include #include +#include namespace paddle { namespace framework { template struct TypeIndexChecker { - static void Check() { + template + static void Check(SetType1 *var_id_set, SetType2 *type_index_set) { using Type = typename std::tuple_element::type; + static_assert(std::is_same::Type, Type>::value, + "Type must be the same"); + constexpr auto kId = VarTypeTrait::kId; if (!std::is_same::value) { - EXPECT_TRUE(ToTypeIndex(VarTypeTrait::kId) == typeid(Type)); - EXPECT_TRUE(std::string(ToTypeName(VarTypeTrait::kId)) == - typeid(Type).name()); + std::type_index actual_type(typeid(Type)); + EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); + EXPECT_EQ(ToTypeIndex(kId), actual_type); + EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT + EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT + var_id_set->insert(kId); + type_index_set->insert(std::type_index(typeid(Type))); } - TypeIndexChecker::Check(); + TypeIndexChecker::Check(var_id_set, + type_index_set); } }; template struct TypeIndexChecker { - static void Check() {} + template + static void Check(SetType1 *, SetType2 *) {} }; -TEST(var_type_traits, check_type_index) { +TEST(var_type_traits, check_no_duplicate_registry) { constexpr size_t kRegisteredNum = VarTypeRegistry::kRegisteredTypeNum; - TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check(); + std::unordered_set var_id_set; + std::unordered_set type_index_set; + TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check( + &var_id_set, &type_index_set); } template