From c4d6daac58f89c234e8a250c4e1d2c080d83972b Mon Sep 17 00:00:00 2001
From: yuyang18 <reyoung@126.com>
Date: Fri, 11 May 2018 12:28:25 +0800
Subject: [PATCH] Polish SizeOfType

---
 paddle/fluid/framework/data_type.cc  | 12 ++++++++
 paddle/fluid/framework/data_type.h   |  3 ++
 paddle/fluid/framework/tensor_impl.h | 44 ++--------------------------
 3 files changed, 17 insertions(+), 42 deletions(-)

diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc
index b60c6186588..f3225849002 100644
--- a/paddle/fluid/framework/data_type.cc
+++ b/paddle/fluid/framework/data_type.cc
@@ -21,6 +21,7 @@ struct DataTypeMap {
   std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
   std::unordered_map<proto::VarType::Type, std::type_index> proto_to_cpp_;
   std::unordered_map<proto::VarType::Type, std::string> proto_to_str_;
+  std::unordered_map<std::type_index, size_t> cpp_to_size_;
 };
 
 static DataTypeMap g_data_type_map_;
@@ -31,11 +32,13 @@ static inline void RegisterType(proto::VarType::Type proto_type,
   g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T));
   g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type);
   g_data_type_map_.proto_to_str_.emplace(proto_type, name);
+  g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T));
 }
 
 static int RegisterAllTypes() {
 #define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type)
 
+  // NOTE: Add your customize type here.
   RegType(platform::float16, proto::VarType::FP16);
   RegType(float, proto::VarType::FP32);
   RegType(double, proto::VarType::FP64);
@@ -78,5 +81,14 @@ std::string DataTypeToString(const proto::VarType::Type type) {
                static_cast<int>(type));
 }
 
+size_t SizeOfType(std::type_index type) {
+  std::call_once(register_once_flag_, RegisterAllTypes);
+  auto it = g_data_type_map_.cpp_to_size_.find(type);
+  if (it != g_data_type_map_.cpp_to_size_.end()) {
+    return it->second;
+  }
+  PADDLE_THROW("Not support %s as tensor type", type.name());
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 06cc5940b75..4b9f572ec5f 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <typeindex>
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/enforce.h"
+
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -24,6 +25,7 @@ namespace framework {
 
 extern proto::VarType::Type ToDataType(std::type_index type);
 extern std::type_index ToTypeIndex(proto::VarType::Type type);
+
 template <typename Visitor>
 inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
   switch (type) {
@@ -51,6 +53,7 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
 }
 
 extern std::string DataTypeToString(const proto::VarType::Type type);
+extern size_t SizeOfType(std::type_index type);
 inline std::ostream& operator<<(std::ostream& out,
                                 const proto::VarType::Type& type) {
   out << DataTypeToString(type);
diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h
index f49d1a47a32..0a1db7758bd 100644
--- a/paddle/fluid/framework/tensor_impl.h
+++ b/paddle/fluid/framework/tensor_impl.h
@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace framework {
-
-template <typename... T>
-struct SizeOfTypeFunctor;
-
-template <typename T>
-struct SizeOfTypeFunctor<T> {
-  size_t operator()(std::type_index type) const {
-    if (typeid(T).hash_code() == type.hash_code()) {
-      return sizeof(T);
-    } else {
-      return 0UL;
-    }
-  }
-};
-
-template <>
-struct SizeOfTypeFunctor<> {
-  size_t operator()(std::type_index type) const { return 0UL; }
-};
-
-template <typename HEAD, typename... TAIL>
-struct SizeOfTypeFunctor<HEAD, TAIL...> {
-  size_t operator()(std::type_index type) const {
-    SizeOfTypeFunctor<HEAD> head;
-    size_t head_size = head(type);
-    if (head_size != 0) {
-      return head_size;
-    }
-    SizeOfTypeFunctor<TAIL...> tail;
-    return tail(type);
-  }
-};
-
-static inline size_t SizeOfType(std::type_index type) {
-  SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t,
-                    platform::float16>
-      functor;
-  size_t size = functor(type);
-  PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
-  return size;
-}
-
+extern size_t SizeOfType(std::type_index type);
 inline void Tensor::check_memory_size() const {
   PADDLE_ENFORCE_NOT_NULL(
       holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
-- 
GitLab