未验证 提交 e73d4652 编写于 作者: Y Yiqun Liu 提交者: GitHub

Implement a class PrecisionTypeTrait to get the PrecisionType of a c++ data type. (#2835)

test=develop
上级 8cfd96f2
...@@ -116,6 +116,34 @@ static size_t PrecisionTypeLength(PrecisionType type) { ...@@ -116,6 +116,34 @@ static size_t PrecisionTypeLength(PrecisionType type) {
} }
} }
template <typename T>
struct PrecisionTypeTrait {
constexpr static PrecisionType Type() { return PrecisionType::kUnk; }
};
#define _ForEachPrecisionTypeHelper(callback, cpp_type, precision_type) \
callback(cpp_type, ::paddle::lite_api::PrecisionType::precision_type);
#define _ForEachPrecisionType(callback) \
_ForEachPrecisionTypeHelper(callback, bool, kBool); \
_ForEachPrecisionTypeHelper(callback, float, kFloat); \
_ForEachPrecisionTypeHelper(callback, int8_t, kInt8); \
_ForEachPrecisionTypeHelper(callback, int16_t, kInt16); \
_ForEachPrecisionTypeHelper(callback, int, kInt32); \
_ForEachPrecisionTypeHelper(callback, int64_t, kInt64);
#define DefinePrecisionTypeTrait(cpp_type, precision_type) \
template <> \
struct PrecisionTypeTrait<cpp_type> { \
constexpr static PrecisionType Type() { return precision_type; } \
}
_ForEachPrecisionType(DefinePrecisionTypeTrait);
#undef _ForEachPrecisionTypeHelper
#undef _ForEachPrecisionType
#undef DefinePrecisionTypeTrait
#define TARGET(item__) paddle::lite_api::TargetType::item__ #define TARGET(item__) paddle::lite_api::TargetType::item__
#define PRECISION(item__) paddle::lite_api::PrecisionType::item__ #define PRECISION(item__) paddle::lite_api::PrecisionType::item__
#define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__ #define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__
......
...@@ -139,22 +139,7 @@ class TensorLite { ...@@ -139,22 +139,7 @@ class TensorLite {
// For other devices, T and R may be the same type. // For other devices, T and R may be the same type.
template <typename T, typename R = T> template <typename T, typename R = T>
R *mutable_data() { R *mutable_data() {
auto type_id = typeid(T).hash_code(); precision_ = lite_api::PrecisionTypeTrait<T>::Type();
if (type_id == typeid(bool).hash_code()) { // NOLINT
precision_ = PrecisionType::kBool;
} else if (type_id == typeid(float).hash_code()) { // NOLINT
precision_ = PrecisionType::kFloat;
} else if (type_id == typeid(int8_t).hash_code()) {
precision_ = PrecisionType::kInt8;
} else if (type_id == typeid(int16_t).hash_code()) {
precision_ = PrecisionType::kInt16;
} else if (type_id == typeid(int32_t).hash_code()) {
precision_ = PrecisionType::kInt32;
} else if (type_id == typeid(int64_t).hash_code()) {
precision_ = PrecisionType::kInt64;
} else {
precision_ = PrecisionType::kUnk;
}
memory_size_ = dims_.production() * sizeof(T); memory_size_ = dims_.production() * sizeof(T);
buffer_->ResetLazy(target_, memory_size_); buffer_->ResetLazy(target_, memory_size_);
return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) + return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册