diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc index b60c6186588988e0156ae36c1dc549b1cda4cba0..f322584900269fac9feba1d04cd3a73648d421ee 100644 --- a/paddle/fluid/framework/data_type.cc +++ b/paddle/fluid/framework/data_type.cc @@ -21,6 +21,7 @@ struct DataTypeMap { std::unordered_map cpp_to_proto_; std::unordered_map proto_to_cpp_; std::unordered_map proto_to_str_; + std::unordered_map cpp_to_size_; }; static DataTypeMap g_data_type_map_; @@ -31,11 +32,13 @@ static inline void RegisterType(proto::VarType::Type proto_type, g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T)); g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type); g_data_type_map_.proto_to_str_.emplace(proto_type, name); + g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T)); } static int RegisterAllTypes() { #define RegType(cc_type, proto_type) RegisterType(proto_type, #cc_type) + // NOTE: Add your customize type here. RegType(platform::float16, proto::VarType::FP16); RegType(float, proto::VarType::FP32); RegType(double, proto::VarType::FP64); @@ -78,5 +81,14 @@ std::string DataTypeToString(const proto::VarType::Type type) { static_cast(type)); } +size_t SizeOfType(std::type_index type) { + std::call_once(register_once_flag_, RegisterAllTypes); + auto it = g_data_type_map_.cpp_to_size_.find(type); + if (it != g_data_type_map_.cpp_to_size_.end()) { + return it->second; + } + PADDLE_THROW("Not support %s as tensor type", type.name()); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 06cc5940b75fd3ff7259fefadb202b1693b39879..4b9f572ec5f1cda71c8b8dd8fae54b42e9f16f7a 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/platform/enforce.h" + #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -24,6 +25,7 @@ namespace framework { extern proto::VarType::Type ToDataType(std::type_index type); extern std::type_index ToTypeIndex(proto::VarType::Type type); + template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { switch (type) { @@ -51,6 +53,7 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { } extern std::string DataTypeToString(const proto::VarType::Type type); +extern size_t SizeOfType(std::type_index type); inline std::ostream& operator<<(std::ostream& out, const proto::VarType::Type& type) { out << DataTypeToString(type); diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index f49d1a47a325b2aac6185073203df124be18b54d..0a1db7758bd9ec0dac133efcbf495de1d690021d 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -13,54 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace framework { - -template -struct SizeOfTypeFunctor; - -template -struct SizeOfTypeFunctor { - size_t operator()(std::type_index type) const { - if (typeid(T).hash_code() == type.hash_code()) { - return sizeof(T); - } else { - return 0UL; - } - } -}; - -template <> -struct SizeOfTypeFunctor<> { - size_t operator()(std::type_index type) const { return 0UL; } -}; - -template -struct SizeOfTypeFunctor { - size_t operator()(std::type_index type) const { - SizeOfTypeFunctor head; - size_t head_size = head(type); - if (head_size != 0) { - return head_size; - } - SizeOfTypeFunctor tail; - return tail(type); - } -}; - -static inline size_t SizeOfType(std::type_index type) { - SizeOfTypeFunctor - functor; - size_t size = functor(type); - PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); - return size; -} - +extern size_t SizeOfType(std::type_index type); inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");