diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index ab71e0e63ce18e4f221a046eeb2c39499c1c3816..ed1e70c6460b513c1d2e1add18ac037f71d36944 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) - +cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context) if(WITH_GPU) - nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto) + nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type) else() - cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto) + cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type) endif() cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9c90cb0c32f337ba82ce1eaa5b43199540491ef --- /dev/null +++ b/paddle/fluid/framework/data_type.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/data_type.h" +#include +#include +#include + +namespace paddle { +namespace framework { + +struct DataTypeMap { + std::unordered_map cpp_to_proto_; + std::unordered_map proto_to_cpp_; + std::unordered_map proto_to_str_; + std::unordered_map cpp_to_size_; +}; + +static DataTypeMap* InitDataTypeMap(); +static DataTypeMap& gDataTypeMap() { + static DataTypeMap* g_data_type_map_ = InitDataTypeMap(); + return *g_data_type_map_; +} + +template +static inline void RegisterType(DataTypeMap* map, + proto::VarType::Type proto_type, + const std::string& name) { + map->proto_to_cpp_.emplace(static_cast(proto_type), typeid(T)); + map->cpp_to_proto_.emplace(typeid(T), proto_type); + map->proto_to_str_.emplace(static_cast(proto_type), name); + map->cpp_to_size_.emplace(typeid(T), sizeof(T)); +} + +static DataTypeMap* InitDataTypeMap() { + auto retv = new DataTypeMap(); + +#define RegType(cc_type, proto_type) \ + RegisterType(retv, proto_type, #cc_type) + + // NOTE: Add your customize type here. + RegType(platform::float16, proto::VarType::FP16); + RegType(float, proto::VarType::FP32); + RegType(double, proto::VarType::FP64); + RegType(int, proto::VarType::INT32); + RegType(int64_t, proto::VarType::INT64); + RegType(bool, proto::VarType::BOOL); + RegType(size_t, proto::VarType::SIZE_T); + RegType(int16_t, proto::VarType::INT16); + +#undef RegType + return retv; +} + +proto::VarType::Type ToDataType(std::type_index type) { + auto it = gDataTypeMap().cpp_to_proto_.find(type); + if (it != gDataTypeMap().cpp_to_proto_.end()) { + return it->second; + } + PADDLE_THROW("Not support %s as tensor type", type.name()); +} + +std::type_index ToTypeIndex(proto::VarType::Type type) { + auto it = gDataTypeMap().proto_to_cpp_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_cpp_.end()) { + return it->second; + } + PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", + static_cast(type)); +} + +std::string DataTypeToString(const proto::VarType::Type type) { + auto it = gDataTypeMap().proto_to_str_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_str_.end()) { + return it->second; + } + PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", + static_cast(type)); +} + +size_t SizeOfType(std::type_index type) { + auto it = gDataTypeMap().cpp_to_size_.find(type); + if (it != gDataTypeMap().cpp_to_size_.end()) { + return it->second; + } + PADDLE_THROW("Not support %s as tensor type", type.name()); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 2a528eb3aa562568c92059250f2c9bc5a75ec103..4b9f572ec5f1cda71c8b8dd8fae54b42e9f16f7a 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -17,51 +17,14 @@ limitations under the License. */ #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/platform/enforce.h" + #include "paddle/fluid/platform/float16.h" namespace paddle { namespace framework { -inline proto::VarType::Type ToDataType(std::type_index type) { - if (typeid(platform::float16).hash_code() == type.hash_code()) { - return proto::VarType::FP16; - } else if (typeid(const float).hash_code() == type.hash_code()) { - // CPPLint complains Using C-style cast. Use static_cast() instead - // One fix to this is to replace float with const float because - // typeid(T) == typeid(const T) - // http://en.cppreference.com/w/cpp/language/typeid - return proto::VarType::FP32; - } else if (typeid(const double).hash_code() == type.hash_code()) { - return proto::VarType::FP64; - } else if (typeid(const int).hash_code() == type.hash_code()) { - return proto::VarType::INT32; - } else if (typeid(const int64_t).hash_code() == type.hash_code()) { - return proto::VarType::INT64; - } else if (typeid(const bool).hash_code() == type.hash_code()) { - return proto::VarType::BOOL; - } else { - PADDLE_THROW("Not supported"); - } -} - -inline std::type_index ToTypeIndex(proto::VarType::Type type) { - switch (type) { - case proto::VarType::FP16: - return typeid(platform::float16); - case proto::VarType::FP32: - return typeid(float); - case proto::VarType::FP64: - return typeid(double); - case proto::VarType::INT32: - return typeid(int); - case proto::VarType::INT64: - return typeid(int64_t); - case proto::VarType::BOOL: - return typeid(bool); - default: - PADDLE_THROW("Not support type %d", type); - } -} +extern proto::VarType::Type ToDataType(std::type_index type); +extern std::type_index ToTypeIndex(proto::VarType::Type type); template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { @@ -89,32 +52,12 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { } } -inline std::string DataTypeToString(const proto::VarType::Type type) { - switch (type) { - case proto::VarType::FP16: - return "float16"; - case proto::VarType::FP32: - return "float32"; - case proto::VarType::FP64: - return "float64"; - case proto::VarType::INT16: - return "int16"; - case proto::VarType::INT32: - return "int32"; - case proto::VarType::INT64: - return "int64"; - case proto::VarType::BOOL: - return "bool"; - default: - PADDLE_THROW("Not support type %d", type); - } -} - +extern std::string DataTypeToString(const proto::VarType::Type type); +extern size_t SizeOfType(std::type_index type); inline std::ostream& operator<<(std::ostream& out, const proto::VarType::Type& type) { out << DataTypeToString(type); return out; } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 96f53dc1bc8747e1b8ea84166614f98ff363ae5e..d2558f111f49139b33f921f7260b41830279edc8 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -101,6 +101,8 @@ message VarType { FP16 = 4; FP32 = 5; FP64 = 6; + // Tensor is used in C++. + SIZE_T = 19; // Other types that may need additional descriptions LOD_TENSOR = 7; diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/paddle/fluid/framework/op_kernel_type_test.cc index d37ce149ce3df63692b41289bb03448d54e392f5..db95861c510b52a5b52229541434e6437d3fb9f4 100644 --- a/paddle/fluid/framework/op_kernel_type_test.cc +++ b/paddle/fluid/framework/op_kernel_type_test.cc @@ -27,7 +27,7 @@ TEST(OpKernelType, ToString) { LibraryType::kCUDNN); ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type), - "data_type[float32]:data_layout[NCHW]:place[CPUPlace]:library_type[" + "data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type[" "CUDNN]"); } diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index f49d1a47a325b2aac6185073203df124be18b54d..0a1db7758bd9ec0dac133efcbf495de1d690021d 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -13,54 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace framework { - -template -struct SizeOfTypeFunctor; - -template -struct SizeOfTypeFunctor { - size_t operator()(std::type_index type) const { - if (typeid(T).hash_code() == type.hash_code()) { - return sizeof(T); - } else { - return 0UL; - } - } -}; - -template <> -struct SizeOfTypeFunctor<> { - size_t operator()(std::type_index type) const { return 0UL; } -}; - -template -struct SizeOfTypeFunctor { - size_t operator()(std::type_index type) const { - SizeOfTypeFunctor head; - size_t head_size = head(type); - if (head_size != 0) { - return head_size; - } - SizeOfTypeFunctor tail; - return tail(type); - } -}; - -static inline size_t SizeOfType(std::type_index type) { - SizeOfTypeFunctor - functor; - size_t size = functor(type); - PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); - return size; -} - +extern size_t SizeOfType(std::type_index type); inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");