diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 340b891e41671df7e61a4a66ec538d4603bb9842..0dfcf27b1c04871ba44390cb889ecbd3f433c8c9 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) - +cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim) if(WITH_GPU) - nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto) + nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory device_context data_type) else() - cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto) + cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory device_context data_type) endif() cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc new file mode 100644 index 0000000000000000000000000000000000000000..b60c6186588988e0156ae36c1dc549b1cda4cba0 --- /dev/null +++ b/paddle/fluid/framework/data_type.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/data_type.h" + +namespace paddle { +namespace framework { + +struct DataTypeMap { + std::unordered_map cpp_to_proto_; + std::unordered_map proto_to_cpp_; + std::unordered_map proto_to_str_; +}; + +static DataTypeMap g_data_type_map_; + +template +static inline void RegisterType(proto::VarType::Type proto_type, + const std::string &name) { + g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T)); + g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type); + g_data_type_map_.proto_to_str_.emplace(proto_type, name); +} + +static int RegisterAllTypes() { +#define RegType(cc_type, proto_type) RegisterType(proto_type, #cc_type) + + RegType(platform::float16, proto::VarType::FP16); + RegType(float, proto::VarType::FP32); + RegType(double, proto::VarType::FP64); + RegType(int, proto::VarType::INT32); + RegType(int64_t, proto::VarType::INT64); + RegType(bool, proto::VarType::BOOL); + +#undef RegType + return 0; +} + +static std::once_flag register_once_flag_; + +proto::VarType::Type ToDataType(std::type_index type) { + std::call_once(register_once_flag_, RegisterAllTypes); + auto it = g_data_type_map_.cpp_to_proto_.find(type); + if (it != g_data_type_map_.cpp_to_proto_.end()) { + return it->second; + } + PADDLE_THROW("Not support %s as tensor type", type.name()); +} + +std::type_index ToTypeIndex(proto::VarType::Type type) { + std::call_once(register_once_flag_, RegisterAllTypes); + auto it = g_data_type_map_.proto_to_cpp_.find(type); + if (it != g_data_type_map_.proto_to_cpp_.end()) { + return it->second; + } + PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", + static_cast(type)); +} + +std::string DataTypeToString(const proto::VarType::Type type) { + std::call_once(register_once_flag_, RegisterAllTypes); + auto it = g_data_type_map_.proto_to_str_.find(type); + if (it != g_data_type_map_.proto_to_str_.end()) { + return it->second; + } + PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", + static_cast(type)); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 2a528eb3aa562568c92059250f2c9bc5a75ec103..06cc5940b75fd3ff7259fefadb202b1693b39879 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -22,47 +22,8 @@ limitations under the License. */ namespace paddle { namespace framework { -inline proto::VarType::Type ToDataType(std::type_index type) { - if (typeid(platform::float16).hash_code() == type.hash_code()) { - return proto::VarType::FP16; - } else if (typeid(const float).hash_code() == type.hash_code()) { - // CPPLint complains Using C-style cast. Use static_cast() instead - // One fix to this is to replace float with const float because - // typeid(T) == typeid(const T) - // http://en.cppreference.com/w/cpp/language/typeid - return proto::VarType::FP32; - } else if (typeid(const double).hash_code() == type.hash_code()) { - return proto::VarType::FP64; - } else if (typeid(const int).hash_code() == type.hash_code()) { - return proto::VarType::INT32; - } else if (typeid(const int64_t).hash_code() == type.hash_code()) { - return proto::VarType::INT64; - } else if (typeid(const bool).hash_code() == type.hash_code()) { - return proto::VarType::BOOL; - } else { - PADDLE_THROW("Not supported"); - } -} - -inline std::type_index ToTypeIndex(proto::VarType::Type type) { - switch (type) { - case proto::VarType::FP16: - return typeid(platform::float16); - case proto::VarType::FP32: - return typeid(float); - case proto::VarType::FP64: - return typeid(double); - case proto::VarType::INT32: - return typeid(int); - case proto::VarType::INT64: - return typeid(int64_t); - case proto::VarType::BOOL: - return typeid(bool); - default: - PADDLE_THROW("Not support type %d", type); - } -} - +extern proto::VarType::Type ToDataType(std::type_index type); +extern std::type_index ToTypeIndex(proto::VarType::Type type); template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { switch (type) { @@ -89,32 +50,11 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { } } -inline std::string DataTypeToString(const proto::VarType::Type type) { - switch (type) { - case proto::VarType::FP16: - return "float16"; - case proto::VarType::FP32: - return "float32"; - case proto::VarType::FP64: - return "float64"; - case proto::VarType::INT16: - return "int16"; - case proto::VarType::INT32: - return "int32"; - case proto::VarType::INT64: - return "int64"; - case proto::VarType::BOOL: - return "bool"; - default: - PADDLE_THROW("Not support type %d", type); - } -} - +extern std::string DataTypeToString(const proto::VarType::Type type); inline std::ostream& operator<<(std::ostream& out, const proto::VarType::Type& type) { out << DataTypeToString(type); return out; } - } // namespace framework } // namespace paddle