提交 13429c3e 编写于 作者: S sneaxiy

clean code, remove void registration

test why MAC CI fail again
test=develop
上级 ce4a26dd
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -23,54 +24,83 @@ namespace detail { ...@@ -23,54 +24,83 @@ namespace detail {
template <int kStart, int kEnd, bool kStop> template <int kStart, int kEnd, bool kStop>
struct VarIdToTypeIndexMapInitializerImpl { struct VarIdToTypeIndexMapInitializerImpl {
static void Init(std::unordered_map<int, std::type_index> *m) { template <typename MapType1, typename MapType2>
static void Init(MapType1 *id_to_type, MapType2 *type_to_id) {
using Type = using Type =
typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type; typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type;
static_assert(!std::is_same<Type, void>::value, "Type cannot be void");
constexpr int kId = VarTypeTrait<Type>::kId; constexpr int kId = VarTypeTrait<Type>::kId;
if (!std::is_same<Type, void>::value) { auto type = std::type_index(typeid(Type));
m->emplace(kId, std::type_index(typeid(Type))); PADDLE_ENFORCE(id_to_type->count(kId) == 0,
} "Registered duplicate type id %d for type %s", kId,
type.name());
PADDLE_ENFORCE(type_to_id->count(type) == 0,
"Registered duplicate type_index %s for id %d", type.name(),
kId);
id_to_type->emplace(kId, type);
type_to_id->emplace(type, kId);
VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd, VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd,
kStart + 1 == kEnd>::Init(m); kStart + 1 == kEnd>::Init(id_to_type,
type_to_id);
} }
}; };
template <int kStart, int kEnd> template <int kStart, int kEnd>
struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> { struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> {
static void Init(std::unordered_map<int, std::type_index> *m) {} template <typename MapType1, typename MapType2>
static void Init(MapType1 *, MapType2 *) {}
}; };
// VarIdToTypeIndexMapInitializer is designed to initialize var_id -> // VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
// std::type_index map // std::type_index map and std::type_index -> var_id map
using VarIdToTypeIndexMapInitializer = using VarIdToTypeIndexMapInitializer =
VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum, VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum,
VarTypeRegistry::kRegisteredTypeNum == VarTypeRegistry::kRegisteredTypeNum ==
0>; 0>;
struct VarIdToTypeIndexMapHolder { struct VarIdToTypeIndexMapHolder {
DISABLE_COPY_AND_ASSIGN(VarIdToTypeIndexMapHolder);
public: public:
static const std::type_index &ToTypeIndex(int var_id) { static const std::type_index &ToTypeIndex(int var_id) {
static const VarIdToTypeIndexMapHolder instance; auto it = Instance().id_to_type_map_.find(var_id);
auto it = instance.var_type_map_.find(var_id); PADDLE_ENFORCE(it != Instance().id_to_type_map_.end(),
PADDLE_ENFORCE(it != instance.var_type_map_.end(),
"VarId %d is not registered.", var_id); "VarId %d is not registered.", var_id);
return it->second; return it->second;
} }
static int ToTypeId(const std::type_index &type) {
auto it = Instance().type_to_id_map_.find(type);
PADDLE_ENFORCE(it != Instance().type_to_id_map_.end(),
"VarType %s is not registered.", type.name());
return it->second;
}
private: private:
VarIdToTypeIndexMapHolder() { VarIdToTypeIndexMapHolder() {
VarIdToTypeIndexMapInitializer::Init(&var_type_map_); VarIdToTypeIndexMapInitializer::Init(&id_to_type_map_, &type_to_id_map_);
}
static const VarIdToTypeIndexMapHolder &Instance() {
static const VarIdToTypeIndexMapHolder instance;
return instance;
} }
std::unordered_map<int, std::type_index> var_type_map_;
std::unordered_map<int, std::type_index> id_to_type_map_;
std::unordered_map<std::type_index, int> type_to_id_map_;
}; };
} // namespace detail } // namespace detail
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
const std::type_index &ToTypeIndex(int var_id) { const std::type_index &ToTypeIndex(int var_id) {
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
} }
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
int ToTypeId(const std::type_index &type) {
return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -42,6 +42,7 @@ namespace framework { ...@@ -42,6 +42,7 @@ namespace framework {
const char *ToTypeName(int var_id); const char *ToTypeName(int var_id);
const std::type_index &ToTypeIndex(int var_id); const std::type_index &ToTypeIndex(int var_id);
int ToTypeId(const std::type_index &type);
namespace detail { namespace detail {
...@@ -75,10 +76,10 @@ struct VarTypeRegistryImpl { ...@@ -75,10 +76,10 @@ struct VarTypeRegistryImpl {
using ArgTuple = std::tuple<Args...>; using ArgTuple = std::tuple<Args...>;
// TypePos() returns the position in which T is inside Args... // TypePos() returns the position in which T is inside Args...
// If T is not inside Args... or T is void, return -1 // If T is not inside Args..., return -1
template <typename T> template <typename T>
static constexpr int TypePos() { static constexpr int TypePos() {
return std::is_same<T, void>::value ? -1 : TypePosFinder<T, Args...>::kPos; return TypePosFinder<T, Args...>::kPos;
} }
// IsRegistered() returns whether T is registered inside RegistryImpl // IsRegistered() returns whether T is registered inside RegistryImpl
...@@ -96,13 +97,16 @@ struct VarTypeRegistryImpl { ...@@ -96,13 +97,16 @@ struct VarTypeRegistryImpl {
static_assert(VarTypeRegistry::IsRegistered<type>(), \ static_assert(VarTypeRegistry::IsRegistered<type>(), \
"Must be registered type"); \ "Must be registered type"); \
using Type = type; \ using Type = type; \
static constexpr int kId = proto_id; \ static constexpr int kId = static_cast<int>(proto_id); \
} }
/** /**
* The following codes are designed to register variable types. * The following codes are designed to register variable types.
* Only registered types can be stored in Variable. * Only registered types can be stored in Variable.
* This registry mechanism is designed to speed up Variable. * This registry mechanism is designed to speed up Variable.
*
* Caution: If you want to add more var types, please consider carefully
* whether you really need to add it.
*/ */
// Users should add other variable types below. // Users should add other variable types below.
...@@ -110,10 +114,9 @@ struct VarTypeRegistryImpl { ...@@ -110,10 +114,9 @@ struct VarTypeRegistryImpl {
class Scope; class Scope;
using VarTypeRegistry = detail::VarTypeRegistryImpl< using VarTypeRegistry = detail::VarTypeRegistryImpl<
LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable, LoDTensorArray, Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
platform::PlaceList, ReaderHolder, Tensor, std::string, Scope *, LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder, std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
int, float,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#ifndef _WIN32 #ifndef _WIN32
ncclUniqueId, platform::Communicator, ncclUniqueId, platform::Communicator,
...@@ -123,13 +126,11 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< ...@@ -123,13 +126,11 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>, operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
operators::CudnnRNNCache, operators::CudnnRNNCache,
#endif #endif
void>; // void indicates end of registration, add other types before void int, float>;
template <typename T> template <typename T>
struct VarTypeTrait { struct VarTypeTrait {
static_assert(std::is_same<T, void>::value || static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
VarTypeRegistry::IsRegistered<T>(),
"Must be registered type");
using Type = T; using Type = T;
// Default id generation // Default id generation
static constexpr int kId = VarTypeRegistry::TypePos<T>() + static constexpr int kId = VarTypeRegistry::TypePos<T>() +
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/framework/var_type_traits.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cstdint> #include <cstdint>
#include <iostream>
#include <unordered_set> #include <unordered_set>
namespace paddle { namespace paddle {
...@@ -29,15 +30,27 @@ struct TypeIndexChecker { ...@@ -29,15 +30,27 @@ struct TypeIndexChecker {
static_assert(std::is_same<typename VarTypeTrait<Type>::Type, Type>::value, static_assert(std::is_same<typename VarTypeTrait<Type>::Type, Type>::value,
"Type must be the same"); "Type must be the same");
constexpr auto kId = VarTypeTrait<Type>::kId; constexpr auto kId = VarTypeTrait<Type>::kId;
if (!std::is_same<Type, void>::value) {
std::type_index actual_type(typeid(Type)); std::type_index actual_type(typeid(Type));
EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
EXPECT_EQ(ToTypeIndex(kId), actual_type); // For some reasons, comparing std::type_index using EXPECT_EQ would fail
// in MAC CI
bool is_same_type_index = (ToTypeIndex(kId) == actual_type);
if (!is_same_type_index) {
std::string s1 = ToTypeName(kId);
std::string s2 = actual_type.name();
PADDLE_THROW("Step %d: type %s is not the same as %s, var_id %d", kPos,
s1.c_str(), s2.c_str(), kId);
}
EXPECT_TRUE(is_same_type_index);
EXPECT_TRUE(ToTypeId(actual_type) == kId); // NOLINT
is_same_type_index = (ToTypeIndex(ToTypeId(actual_type)) == actual_type);
EXPECT_TRUE(is_same_type_index);
EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId);
EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT
EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT
var_id_set->insert(kId); var_id_set->insert(kId);
type_index_set->insert(std::type_index(typeid(Type))); type_index_set->insert(std::type_index(typeid(Type)));
}
TypeIndexChecker<kPos + 1, kEnd, kPos + 1 == kEnd>::Check(var_id_set, TypeIndexChecker<kPos + 1, kEnd, kPos + 1 == kEnd>::Check(var_id_set,
type_index_set); type_index_set);
} }
...@@ -75,13 +88,11 @@ TEST(var_type_traits, check_proto_type_id) { ...@@ -75,13 +88,11 @@ TEST(var_type_traits, check_proto_type_id) {
} }
TEST(var_type_traits, test_registry) { TEST(var_type_traits, test_registry) {
using Registry = using Registry = detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double>;
detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double, void>;
ASSERT_TRUE(Registry::TypePos<int8_t>() == 0); ASSERT_TRUE(Registry::TypePos<int8_t>() == 0);
ASSERT_TRUE(Registry::TypePos<int32_t>() == 1); ASSERT_TRUE(Registry::TypePos<int32_t>() == 1);
ASSERT_TRUE(Registry::TypePos<size_t>() == 2); ASSERT_TRUE(Registry::TypePos<size_t>() == 2);
ASSERT_TRUE(Registry::TypePos<double>() == 3); ASSERT_TRUE(Registry::TypePos<double>() == 3);
ASSERT_TRUE(Registry::TypePos<void>() == -1);
ASSERT_TRUE(Registry::TypePos<float>() == -1); ASSERT_TRUE(Registry::TypePos<float>() == -1);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册