From ea9e150d2bb469b680e4a43795750488d56ce053 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Sat, 8 Feb 2020 10:49:29 +0800 Subject: [PATCH] Implement a class PrecisionTypeTrait to get the PrecisionType of a c++ data type. (#2835) test=develop --- lite/api/paddle_place.h | 28 ++++++++++++++++++++++++++++ lite/core/tensor.h | 17 +---------------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index c8f136ace8..7da52adc7f 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -116,6 +116,34 @@ static size_t PrecisionTypeLength(PrecisionType type) { } } +template +struct PrecisionTypeTrait { + constexpr static PrecisionType Type() { return PrecisionType::kUnk; } +}; + +#define _ForEachPrecisionTypeHelper(callback, cpp_type, precision_type) \ + callback(cpp_type, ::paddle::lite_api::PrecisionType::precision_type); + +#define _ForEachPrecisionType(callback) \ + _ForEachPrecisionTypeHelper(callback, bool, kBool); \ + _ForEachPrecisionTypeHelper(callback, float, kFloat); \ + _ForEachPrecisionTypeHelper(callback, int8_t, kInt8); \ + _ForEachPrecisionTypeHelper(callback, int16_t, kInt16); \ + _ForEachPrecisionTypeHelper(callback, int, kInt32); \ + _ForEachPrecisionTypeHelper(callback, int64_t, kInt64); + +#define DefinePrecisionTypeTrait(cpp_type, precision_type) \ + template <> \ + struct PrecisionTypeTrait { \ + constexpr static PrecisionType Type() { return precision_type; } \ + } + +_ForEachPrecisionType(DefinePrecisionTypeTrait); + +#undef _ForEachPrecisionTypeHelper +#undef _ForEachPrecisionType +#undef DefinePrecisionTypeTrait + #define TARGET(item__) paddle::lite_api::TargetType::item__ #define PRECISION(item__) paddle::lite_api::PrecisionType::item__ #define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__ diff --git a/lite/core/tensor.h b/lite/core/tensor.h index 41a2d16f75..3e334048fa 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -139,22 +139,7 @@ class TensorLite { // For other devices, T and R may be the same type. template R *mutable_data() { - auto type_id = typeid(T).hash_code(); - if (type_id == typeid(bool).hash_code()) { // NOLINT - precision_ = PrecisionType::kBool; - } else if (type_id == typeid(float).hash_code()) { // NOLINT - precision_ = PrecisionType::kFloat; - } else if (type_id == typeid(int8_t).hash_code()) { - precision_ = PrecisionType::kInt8; - } else if (type_id == typeid(int16_t).hash_code()) { - precision_ = PrecisionType::kInt16; - } else if (type_id == typeid(int32_t).hash_code()) { - precision_ = PrecisionType::kInt32; - } else if (type_id == typeid(int64_t).hash_code()) { - precision_ = PrecisionType::kInt64; - } else { - precision_ = PrecisionType::kUnk; - } + precision_ = lite_api::PrecisionTypeTrait::Type(); memory_size_ = dims_.production() * sizeof(T); buffer_->ResetLazy(target_, memory_size_); return reinterpret_cast(static_cast(buffer_->data()) + -- GitLab