提交 c4d6daac 编写于 作者: Y yuyang18

Polish SizeOfType

上级 711d86bb
...@@ -21,6 +21,7 @@ struct DataTypeMap { ...@@ -21,6 +21,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_; std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<proto::VarType::Type, std::type_index> proto_to_cpp_; std::unordered_map<proto::VarType::Type, std::type_index> proto_to_cpp_;
std::unordered_map<proto::VarType::Type, std::string> proto_to_str_; std::unordered_map<proto::VarType::Type, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_;
}; };
static DataTypeMap g_data_type_map_; static DataTypeMap g_data_type_map_;
...@@ -31,11 +32,13 @@ static inline void RegisterType(proto::VarType::Type proto_type, ...@@ -31,11 +32,13 @@ static inline void RegisterType(proto::VarType::Type proto_type,
g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T)); g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T));
g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type); g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type);
g_data_type_map_.proto_to_str_.emplace(proto_type, name); g_data_type_map_.proto_to_str_.emplace(proto_type, name);
g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T));
} }
static int RegisterAllTypes() { static int RegisterAllTypes() {
#define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type) #define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type)
// NOTE: Add your customize type here.
RegType(platform::float16, proto::VarType::FP16); RegType(platform::float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32); RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64); RegType(double, proto::VarType::FP64);
...@@ -78,5 +81,14 @@ std::string DataTypeToString(const proto::VarType::Type type) { ...@@ -78,5 +81,14 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type)); static_cast<int>(type));
} }
size_t SizeOfType(std::type_index type) {
std::call_once(register_once_flag_, RegisterAllTypes);
auto it = g_data_type_map_.cpp_to_size_.find(type);
if (it != g_data_type_map_.cpp_to_size_.end()) {
return it->second;
}
PADDLE_THROW("Not support %s as tensor type", type.name());
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
...@@ -24,6 +25,7 @@ namespace framework { ...@@ -24,6 +25,7 @@ namespace framework {
extern proto::VarType::Type ToDataType(std::type_index type); extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type); extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) { switch (type) {
...@@ -51,6 +53,7 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { ...@@ -51,6 +53,7 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
} }
extern std::string DataTypeToString(const proto::VarType::Type type); extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type);
inline std::ostream& operator<<(std::ostream& out, inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) { const proto::VarType::Type& type) {
out << DataTypeToString(type); out << DataTypeToString(type);
......
...@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
extern size_t SizeOfType(std::type_index type);
template <typename... T>
struct SizeOfTypeFunctor;
template <typename T>
struct SizeOfTypeFunctor<T> {
size_t operator()(std::type_index type) const {
if (typeid(T).hash_code() == type.hash_code()) {
return sizeof(T);
} else {
return 0UL;
}
}
};
template <>
struct SizeOfTypeFunctor<> {
size_t operator()(std::type_index type) const { return 0UL; }
};
template <typename HEAD, typename... TAIL>
struct SizeOfTypeFunctor<HEAD, TAIL...> {
size_t operator()(std::type_index type) const {
SizeOfTypeFunctor<HEAD> head;
size_t head_size = head(type);
if (head_size != 0) {
return head_size;
}
SizeOfTypeFunctor<TAIL...> tail;
return tail(type);
}
};
static inline size_t SizeOfType(std::type_index type) {
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t,
platform::float16>
functor;
size_t size = functor(type);
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
return size;
}
inline void Tensor::check_memory_size() const { inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册