From ce4a26ddad08a9d640f1ec3ddae254d0d0abd004 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 19 Dec 2018 12:23:11 +0000 Subject: [PATCH] clean code try to fix mac compile bug? test=develop --- paddle/fluid/framework/CMakeLists.txt | 5 +- paddle/fluid/framework/var_type_traits.cc | 53 +++++++++++++++++- paddle/fluid/framework/var_type_traits.h | 55 +------------------ .../fluid/framework/var_type_traits_test.cc | 30 +++++++--- 4 files changed, 77 insertions(+), 66 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index b6372a2ef..d0beb8361 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -83,10 +83,7 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto) if (WITH_GPU) - target_link_libraries(var_type_traits cudnn) - if (NOT WIN32) - target_link_libraries(var_type_traits nccl) - endif() + target_link_libraries(var_type_traits dynload_cuda) endif() cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits) diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index 0171df6f7..c9f9f8d6c 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -17,9 +17,58 @@ namespace paddle { namespace framework { -const char* ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } +// Besides registering variable type id, it is helpful to register a +// var_id -> std::type_index map (for example, get type names according to id) +namespace detail { -const std::type_index& ToTypeIndex(int var_id) { +template +struct VarIdToTypeIndexMapInitializerImpl { + static void Init(std::unordered_map *m) { + using Type = + typename std::tuple_element::type; + constexpr int kId = VarTypeTrait::kId; + if (!std::is_same::value) { + m->emplace(kId, std::type_index(typeid(Type))); + } + VarIdToTypeIndexMapInitializerImpl::Init(m); + } +}; + +template +struct VarIdToTypeIndexMapInitializerImpl { + static void Init(std::unordered_map *m) {} +}; + +// VarIdToTypeIndexMapInitializer is designed to initialize var_id -> +// std::type_index map +using VarIdToTypeIndexMapInitializer = + VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum, + VarTypeRegistry::kRegisteredTypeNum == + 0>; + +struct VarIdToTypeIndexMapHolder { + public: + static const std::type_index &ToTypeIndex(int var_id) { + static const VarIdToTypeIndexMapHolder instance; + auto it = instance.var_type_map_.find(var_id); + PADDLE_ENFORCE(it != instance.var_type_map_.end(), + "VarId %d is not registered.", var_id); + return it->second; + } + + private: + VarIdToTypeIndexMapHolder() { + VarIdToTypeIndexMapInitializer::Init(&var_type_map_); + } + std::unordered_map var_type_map_; +}; + +} // namespace detail + +const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } + +const std::type_index &ToTypeIndex(int var_id) { return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); } diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index 88f917e74..c5e0d4707 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -40,6 +40,9 @@ namespace paddle { namespace framework { +const char *ToTypeName(int var_id); +const std::type_index &ToTypeIndex(int var_id); + namespace detail { template std::type_index (for example, get var names according to id) -namespace detail { - -template -struct VarIdToTypeIndexMapInitializerImpl { - static void Init(std::unordered_map *m) { - using Type = - typename std::tuple_element::type; - constexpr int kId = VarTypeTrait::kId; - if (!std::is_same::value) { - m->emplace(kId, std::type_index(typeid(Type))); - } - VarIdToTypeIndexMapInitializerImpl::Init(m); - } -}; - -template -struct VarIdToTypeIndexMapInitializerImpl { - static void Init(std::unordered_map *m) {} -}; - -// VarIdToTypeIndexMapInitializer is designed to initialize var_id -> -// std::type_index map -using VarIdToTypeIndexMapInitializer = - VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum, - VarTypeRegistry::kRegisteredTypeNum == - 0>; - -struct VarIdToTypeIndexMapHolder { - public: - static const std::type_index &ToTypeIndex(int var_id) { - static const VarIdToTypeIndexMapHolder instance; - auto it = instance.var_type_map_.find(var_id); - PADDLE_ENFORCE(it != instance.var_type_map_.end(), - "VarId %d is not registered.", var_id); - return it->second; - } - - private: - VarIdToTypeIndexMapHolder() { - VarIdToTypeIndexMapInitializer::Init(&var_type_map_); - } - std::unordered_map var_type_map_; -}; - -} // namespace detail - -const char *ToTypeName(int var_id); -const std::type_index &ToTypeIndex(int var_id); - template inline constexpr bool IsRegisteredVarType() { return VarTypeRegistry::IsRegistered(); diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 09fab719c..f46608233 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -15,32 +15,46 @@ #include "paddle/fluid/framework/var_type_traits.h" #include #include +#include namespace paddle { namespace framework { template struct TypeIndexChecker { - static void Check() { + template + static void Check(SetType1 *var_id_set, SetType2 *type_index_set) { using Type = typename std::tuple_element::type; + static_assert(std::is_same::Type, Type>::value, + "Type must be the same"); + constexpr auto kId = VarTypeTrait::kId; if (!std::is_same::value) { - EXPECT_TRUE(ToTypeIndex(VarTypeTrait::kId) == typeid(Type)); - EXPECT_TRUE(std::string(ToTypeName(VarTypeTrait::kId)) == - typeid(Type).name()); + std::type_index actual_type(typeid(Type)); + EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); + EXPECT_EQ(ToTypeIndex(kId), actual_type); + EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT + EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT + var_id_set->insert(kId); + type_index_set->insert(std::type_index(typeid(Type))); } - TypeIndexChecker::Check(); + TypeIndexChecker::Check(var_id_set, + type_index_set); } }; template struct TypeIndexChecker { - static void Check() {} + template + static void Check(SetType1 *, SetType2 *) {} }; -TEST(var_type_traits, check_type_index) { +TEST(var_type_traits, check_no_duplicate_registry) { constexpr size_t kRegisteredNum = VarTypeRegistry::kRegisteredTypeNum; - TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check(); + std::unordered_set var_id_set; + std::unordered_set type_index_set; + TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check( + &var_id_set, &type_index_set); } template -- GitLab