Commit 66e82b98 authored by yuyang18

Change implementation to fit sphinx model

Parent 715c933d
@@ -14,7 +14,6 @@
 #include "paddle/fluid/framework/data_type.h"
 #include <stdint.h>
-#include <mutex>  // NOLINT
 #include <string>
 #include <unordered_map>
@@ -28,20 +27,27 @@ struct DataTypeMap {
   std::unordered_map<std::type_index, size_t> cpp_to_size_;
 };
 
-static DataTypeMap g_data_type_map_;
+static DataTypeMap* InitDataTypeMap();
+
+static DataTypeMap& gDataTypeMap() {
+  static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
+  return *g_data_type_map_;
+}
 
 template <typename T>
-static inline void RegisterType(proto::VarType::Type proto_type,
-                                const std::string &name) {
-  g_data_type_map_.proto_to_cpp_.emplace(static_cast<int>(proto_type),
-                                         typeid(T));
-  g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type);
-  g_data_type_map_.proto_to_str_.emplace(static_cast<int>(proto_type), name);
-  g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T));
+static inline void RegisterType(DataTypeMap* map,
+                                proto::VarType::Type proto_type,
+                                const std::string& name) {
+  map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
+  map->cpp_to_proto_.emplace(typeid(T), proto_type);
+  map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
+  map->cpp_to_size_.emplace(typeid(T), sizeof(T));
 }
 
-static int RegisterAllTypes() {
-#define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type)
+static DataTypeMap* InitDataTypeMap() {
+  auto retv = new DataTypeMap();
+#define RegType(cc_type, proto_type) \
+  RegisterType<cc_type>(retv, proto_type, #cc_type)
 
   // NOTE: Add your customize type here.
   RegType(platform::float16, proto::VarType::FP16);
@@ -52,24 +58,20 @@ static int RegisterAllTypes() {
   RegType(bool, proto::VarType::BOOL);
 
 #undef RegType
-  return 0;
+  return retv;
 }
 
-static std::once_flag register_once_flag_;
-
 proto::VarType::Type ToDataType(std::type_index type) {
-  std::call_once(register_once_flag_, RegisterAllTypes);
-  auto it = g_data_type_map_.cpp_to_proto_.find(type);
-  if (it != g_data_type_map_.cpp_to_proto_.end()) {
+  auto it = gDataTypeMap().cpp_to_proto_.find(type);
+  if (it != gDataTypeMap().cpp_to_proto_.end()) {
     return it->second;
   }
   PADDLE_THROW("Not support %s as tensor type", type.name());
 }
 
 std::type_index ToTypeIndex(proto::VarType::Type type) {
-  std::call_once(register_once_flag_, RegisterAllTypes);
-  auto it = g_data_type_map_.proto_to_cpp_.find(static_cast<int>(type));
-  if (it != g_data_type_map_.proto_to_cpp_.end()) {
+  auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_cpp_.end()) {
     return it->second;
   }
   PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
@@ -77,9 +79,8 @@ std::type_index ToTypeIndex(proto::VarType::Type type) {
 }
 
 std::string DataTypeToString(const proto::VarType::Type type) {
-  std::call_once(register_once_flag_, RegisterAllTypes);
-  auto it = g_data_type_map_.proto_to_str_.find(static_cast<int>(type));
-  if (it != g_data_type_map_.proto_to_str_.end()) {
+  auto it = gDataTypeMap().proto_to_str_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_str_.end()) {
     return it->second;
   }
   PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
@@ -87,9 +88,8 @@ std::string DataTypeToString(const proto::VarType::Type type) {
 }
 
 size_t SizeOfType(std::type_index type) {
-  std::call_once(register_once_flag_, RegisterAllTypes);
-  auto it = g_data_type_map_.cpp_to_size_.find(type);
-  if (it != g_data_type_map_.cpp_to_size_.end()) {
+  auto it = gDataTypeMap().cpp_to_size_.find(type);
+  if (it != gDataTypeMap().cpp_to_size_.end()) {
     return it->second;
  }
   PADDLE_THROW("Not support %s as tensor type", type.name());