提交 13429c3e 编写于 作者: S sneaxiy

clean code, remove void registration

test why MAC CI fail again
test=develop
上级 ce4a26dd
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
......@@ -23,54 +24,83 @@ namespace detail {
template <int kStart, int kEnd, bool kStop>
struct VarIdToTypeIndexMapInitializerImpl {
static void Init(std::unordered_map<int, std::type_index> *m) {
template <typename MapType1, typename MapType2>
static void Init(MapType1 *id_to_type, MapType2 *type_to_id) {
using Type =
typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type;
static_assert(!std::is_same<Type, void>::value, "Type cannot be void");
constexpr int kId = VarTypeTrait<Type>::kId;
if (!std::is_same<Type, void>::value) {
m->emplace(kId, std::type_index(typeid(Type)));
}
auto type = std::type_index(typeid(Type));
PADDLE_ENFORCE(id_to_type->count(kId) == 0,
"Registered duplicate type id %d for type %s", kId,
type.name());
PADDLE_ENFORCE(type_to_id->count(type) == 0,
"Registered duplicate type_index %s for id %d", type.name(),
kId);
id_to_type->emplace(kId, type);
type_to_id->emplace(type, kId);
VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd,
kStart + 1 == kEnd>::Init(m);
kStart + 1 == kEnd>::Init(id_to_type,
type_to_id);
}
};
template <int kStart, int kEnd>
struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> {
static void Init(std::unordered_map<int, std::type_index> *m) {}
template <typename MapType1, typename MapType2>
static void Init(MapType1 *, MapType2 *) {}
};
// VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
// std::type_index map
// std::type_index map and std::type_index -> var_id map
using VarIdToTypeIndexMapInitializer =
VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum,
VarTypeRegistry::kRegisteredTypeNum ==
0>;
struct VarIdToTypeIndexMapHolder {
DISABLE_COPY_AND_ASSIGN(VarIdToTypeIndexMapHolder);
public:
static const std::type_index &ToTypeIndex(int var_id) {
static const VarIdToTypeIndexMapHolder instance;
auto it = instance.var_type_map_.find(var_id);
PADDLE_ENFORCE(it != instance.var_type_map_.end(),
auto it = Instance().id_to_type_map_.find(var_id);
PADDLE_ENFORCE(it != Instance().id_to_type_map_.end(),
"VarId %d is not registered.", var_id);
return it->second;
}
static int ToTypeId(const std::type_index &type) {
auto it = Instance().type_to_id_map_.find(type);
PADDLE_ENFORCE(it != Instance().type_to_id_map_.end(),
"VarType %s is not registered.", type.name());
return it->second;
}
private:
VarIdToTypeIndexMapHolder() {
VarIdToTypeIndexMapInitializer::Init(&var_type_map_);
VarIdToTypeIndexMapInitializer::Init(&id_to_type_map_, &type_to_id_map_);
}
static const VarIdToTypeIndexMapHolder &Instance() {
static const VarIdToTypeIndexMapHolder instance;
return instance;
}
std::unordered_map<int, std::type_index> var_type_map_;
std::unordered_map<int, std::type_index> id_to_type_map_;
std::unordered_map<std::type_index, int> type_to_id_map_;
};
} // namespace detail
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
const std::type_index &ToTypeIndex(int var_id) {
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
}
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
int ToTypeId(const std::type_index &type) {
return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
}
} // namespace framework
} // namespace paddle
......@@ -42,6 +42,7 @@ namespace framework {
const char *ToTypeName(int var_id);
const std::type_index &ToTypeIndex(int var_id);
int ToTypeId(const std::type_index &type);
namespace detail {
......@@ -75,10 +76,10 @@ struct VarTypeRegistryImpl {
using ArgTuple = std::tuple<Args...>;
// TypePos() returns the position in which T is inside Args...
// If T is not inside Args... or T is void, return -1
// If T is not inside Args..., return -1
template <typename T>
static constexpr int TypePos() {
return std::is_same<T, void>::value ? -1 : TypePosFinder<T, Args...>::kPos;
return TypePosFinder<T, Args...>::kPos;
}
// IsRegistered() returns whether T is registered inside RegistryImpl
......@@ -90,19 +91,22 @@ struct VarTypeRegistryImpl {
} // namespace detail
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \
template <> \
struct VarTypeTrait<type> { \
static_assert(VarTypeRegistry::IsRegistered<type>(), \
"Must be registered type"); \
using Type = type; \
static constexpr int kId = proto_id; \
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \
template <> \
struct VarTypeTrait<type> { \
static_assert(VarTypeRegistry::IsRegistered<type>(), \
"Must be registered type"); \
using Type = type; \
static constexpr int kId = static_cast<int>(proto_id); \
}
/**
* The following codes are designed to register variable types.
* Only registered types can be stored in Variable.
* This registry mechanism is designed to speed up Variable.
*
* Caution: If you want to add more var types, please consider carefully
* whether you really need to add it.
*/
// Users should add other variable types below.
......@@ -110,10 +114,9 @@ struct VarTypeRegistryImpl {
class Scope;
using VarTypeRegistry = detail::VarTypeRegistryImpl<
LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable, LoDTensorArray,
platform::PlaceList, ReaderHolder, Tensor, std::string, Scope *,
Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
int, float,
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
ncclUniqueId, platform::Communicator,
......@@ -123,13 +126,11 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
operators::CudnnRNNCache,
#endif
void>; // void indicates end of registration, add other types before void
int, float>;
template <typename T>
struct VarTypeTrait {
static_assert(std::is_same<T, void>::value ||
VarTypeRegistry::IsRegistered<T>(),
"Must be registered type");
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
using Type = T;
// Default id generation
static constexpr int kId = VarTypeRegistry::TypePos<T>() +
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/var_type_traits.h"
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <unordered_set>
namespace paddle {
......@@ -29,15 +30,27 @@ struct TypeIndexChecker {
static_assert(std::is_same<typename VarTypeTrait<Type>::Type, Type>::value,
"Type must be the same");
constexpr auto kId = VarTypeTrait<Type>::kId;
if (!std::is_same<Type, void>::value) {
std::type_index actual_type(typeid(Type));
EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
EXPECT_EQ(ToTypeIndex(kId), actual_type);
EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT
EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT
var_id_set->insert(kId);
type_index_set->insert(std::type_index(typeid(Type)));
std::type_index actual_type(typeid(Type));
EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
// For some reasons, comparing std::type_index using EXPECT_EQ would fail
// in MAC CI
bool is_same_type_index = (ToTypeIndex(kId) == actual_type);
if (!is_same_type_index) {
std::string s1 = ToTypeName(kId);
std::string s2 = actual_type.name();
PADDLE_THROW("Step %d: type %s is not the same as %s, var_id %d", kPos,
s1.c_str(), s2.c_str(), kId);
}
EXPECT_TRUE(is_same_type_index);
EXPECT_TRUE(ToTypeId(actual_type) == kId); // NOLINT
is_same_type_index = (ToTypeIndex(ToTypeId(actual_type)) == actual_type);
EXPECT_TRUE(is_same_type_index);
EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId);
EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT
EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT
var_id_set->insert(kId);
type_index_set->insert(std::type_index(typeid(Type)));
TypeIndexChecker<kPos + 1, kEnd, kPos + 1 == kEnd>::Check(var_id_set,
type_index_set);
}
......@@ -75,13 +88,11 @@ TEST(var_type_traits, check_proto_type_id) {
}
TEST(var_type_traits, test_registry) {
using Registry =
detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double, void>;
using Registry = detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double>;
ASSERT_TRUE(Registry::TypePos<int8_t>() == 0);
ASSERT_TRUE(Registry::TypePos<int32_t>() == 1);
ASSERT_TRUE(Registry::TypePos<size_t>() == 2);
ASSERT_TRUE(Registry::TypePos<double>() == 3);
ASSERT_TRUE(Registry::TypePos<void>() == -1);
ASSERT_TRUE(Registry::TypePos<float>() == -1);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册