提交 66e82b98 编写于 作者: Y yuyang18

Change implementation to fit sphinx model

上级 715c933d
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include <stdint.h> #include <stdint.h>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -28,20 +27,27 @@ struct DataTypeMap { ...@@ -28,20 +27,27 @@ struct DataTypeMap {
std::unordered_map<std::type_index, size_t> cpp_to_size_; std::unordered_map<std::type_index, size_t> cpp_to_size_;
}; };
// Builds and populates the registry; defined later in this file.
static DataTypeMap* InitDataTypeMap();

// Accessor for the process-wide type registry.
// The map is heap-allocated on first use (thread-safe via C++11 magic
// statics) and intentionally never freed, which sidesteps
// static-destruction-order problems at program exit.
static DataTypeMap& gDataTypeMap() {
  static DataTypeMap* instance = InitDataTypeMap();
  return *instance;
}
template <typename T> template <typename T>
static inline void RegisterType(proto::VarType::Type proto_type, static inline void RegisterType(DataTypeMap* map,
const std::string &name) { proto::VarType::Type proto_type,
g_data_type_map_.proto_to_cpp_.emplace(static_cast<int>(proto_type), const std::string& name) {
typeid(T)); map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type); map->cpp_to_proto_.emplace(typeid(T), proto_type);
g_data_type_map_.proto_to_str_.emplace(static_cast<int>(proto_type), name); map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T)); map->cpp_to_size_.emplace(typeid(T), sizeof(T));
} }
static int RegisterAllTypes() { static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type) auto retv = new DataTypeMap();
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here. // NOTE: Add your customize type here.
RegType(platform::float16, proto::VarType::FP16); RegType(platform::float16, proto::VarType::FP16);
...@@ -52,24 +58,20 @@ static int RegisterAllTypes() { ...@@ -52,24 +58,20 @@ static int RegisterAllTypes() {
RegType(bool, proto::VarType::BOOL); RegType(bool, proto::VarType::BOOL);
#undef RegType #undef RegType
return 0; return retv;
} }
static std::once_flag register_once_flag_;
// Maps a C++ type (as std::type_index) to its proto::VarType enum value.
// Throws via PADDLE_THROW when the type was never registered.
proto::VarType::Type ToDataType(std::type_index type) {
  const auto& cpp_to_proto = gDataTypeMap().cpp_to_proto_;
  auto found = cpp_to_proto.find(type);
  if (found == cpp_to_proto.end()) {
    PADDLE_THROW("Not support %s as tensor type", type.name());
  }
  return found->second;
}
std::type_index ToTypeIndex(proto::VarType::Type type) { std::type_index ToTypeIndex(proto::VarType::Type type) {
std::call_once(register_once_flag_, RegisterAllTypes); auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
auto it = g_data_type_map_.proto_to_cpp_.find(static_cast<int>(type)); if (it != gDataTypeMap().proto_to_cpp_.end()) {
if (it != g_data_type_map_.proto_to_cpp_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
...@@ -77,9 +79,8 @@ std::type_index ToTypeIndex(proto::VarType::Type type) { ...@@ -77,9 +79,8 @@ std::type_index ToTypeIndex(proto::VarType::Type type) {
} }
std::string DataTypeToString(const proto::VarType::Type type) { std::string DataTypeToString(const proto::VarType::Type type) {
std::call_once(register_once_flag_, RegisterAllTypes); auto it = gDataTypeMap().proto_to_str_.find(static_cast<int>(type));
auto it = g_data_type_map_.proto_to_str_.find(static_cast<int>(type)); if (it != gDataTypeMap().proto_to_str_.end()) {
if (it != g_data_type_map_.proto_to_str_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
...@@ -87,9 +88,8 @@ std::string DataTypeToString(const proto::VarType::Type type) { ...@@ -87,9 +88,8 @@ std::string DataTypeToString(const proto::VarType::Type type) {
} }
size_t SizeOfType(std::type_index type) { size_t SizeOfType(std::type_index type) {
std::call_once(register_once_flag_, RegisterAllTypes); auto it = gDataTypeMap().cpp_to_size_.find(type);
auto it = g_data_type_map_.cpp_to_size_.find(type); if (it != gDataTypeMap().cpp_to_size_.end()) {
if (it != g_data_type_map_.cpp_to_size_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", type.name()); PADDLE_THROW("Not support %s as tensor type", type.name());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册