提交 c4d6daac 编写于 作者: Y yuyang18

Polish SizeOfType

上级 711d86bb
......@@ -21,6 +21,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<proto::VarType::Type, std::type_index> proto_to_cpp_;
std::unordered_map<proto::VarType::Type, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_;
};
static DataTypeMap g_data_type_map_;
......@@ -31,11 +32,13 @@ static inline void RegisterType(proto::VarType::Type proto_type,
g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T));
g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type);
g_data_type_map_.proto_to_str_.emplace(proto_type, name);
g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T));
}
static int RegisterAllTypes() {
#define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type)
// NOTE: Add your customize type here.
RegType(platform::float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
......@@ -78,5 +81,14 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type));
}
size_t SizeOfType(std::type_index type) {
std::call_once(register_once_flag_, RegisterAllTypes);
auto it = g_data_type_map_.cpp_to_size_.find(type);
if (it != g_data_type_map_.cpp_to_size_.end()) {
return it->second;
}
PADDLE_THROW("Not support %s as tensor type", type.name());
}
} // namespace framework
} // namespace paddle
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <typeindex>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -24,6 +25,7 @@ namespace framework {
extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) {
......@@ -51,6 +53,7 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
}
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
......
......@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
template <typename... T>
struct SizeOfTypeFunctor;
template <typename T>
struct SizeOfTypeFunctor<T> {
size_t operator()(std::type_index type) const {
if (typeid(T).hash_code() == type.hash_code()) {
return sizeof(T);
} else {
return 0UL;
}
}
};
template <>
struct SizeOfTypeFunctor<> {
size_t operator()(std::type_index type) const { return 0UL; }
};
template <typename HEAD, typename... TAIL>
struct SizeOfTypeFunctor<HEAD, TAIL...> {
size_t operator()(std::type_index type) const {
SizeOfTypeFunctor<HEAD> head;
size_t head_size = head(type);
if (head_size != 0) {
return head_size;
}
SizeOfTypeFunctor<TAIL...> tail;
return tail(type);
}
};
static inline size_t SizeOfType(std::type_index type) {
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t,
platform::float16>
functor;
size_t size = functor(type);
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
return size;
}
extern size_t SizeOfType(std::type_index type);
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册