diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index c8f136ace8d536f9bcca30c3ab24618b2c0a78e5..7da52adc7fb6fdd70de3b098508e4622496bed7d 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -116,6 +116,34 @@ static size_t PrecisionTypeLength(PrecisionType type) { } } +template +struct PrecisionTypeTrait { + constexpr static PrecisionType Type() { return PrecisionType::kUnk; } +}; + +#define _ForEachPrecisionTypeHelper(callback, cpp_type, precision_type) \ + callback(cpp_type, ::paddle::lite_api::PrecisionType::precision_type); + +#define _ForEachPrecisionType(callback) \ + _ForEachPrecisionTypeHelper(callback, bool, kBool); \ + _ForEachPrecisionTypeHelper(callback, float, kFloat); \ + _ForEachPrecisionTypeHelper(callback, int8_t, kInt8); \ + _ForEachPrecisionTypeHelper(callback, int16_t, kInt16); \ + _ForEachPrecisionTypeHelper(callback, int, kInt32); \ + _ForEachPrecisionTypeHelper(callback, int64_t, kInt64); + +#define DefinePrecisionTypeTrait(cpp_type, precision_type) \ + template <> \ + struct PrecisionTypeTrait { \ + constexpr static PrecisionType Type() { return precision_type; } \ + } + +_ForEachPrecisionType(DefinePrecisionTypeTrait); + +#undef _ForEachPrecisionTypeHelper +#undef _ForEachPrecisionType +#undef DefinePrecisionTypeTrait + #define TARGET(item__) paddle::lite_api::TargetType::item__ #define PRECISION(item__) paddle::lite_api::PrecisionType::item__ #define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__ diff --git a/lite/core/tensor.h b/lite/core/tensor.h index 41a2d16f75f946c9ef8250d3e2af1ac6ee370d60..3e334048fabb0a26951dadd173e4cf42ee4d8099 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -139,22 +139,7 @@ class TensorLite { // For other devices, T and R may be the same type. template R *mutable_data() { - auto type_id = typeid(T).hash_code(); - if (type_id == typeid(bool).hash_code()) { // NOLINT - precision_ = PrecisionType::kBool; - } else if (type_id == typeid(float).hash_code()) { // NOLINT - precision_ = PrecisionType::kFloat; - } else if (type_id == typeid(int8_t).hash_code()) { - precision_ = PrecisionType::kInt8; - } else if (type_id == typeid(int16_t).hash_code()) { - precision_ = PrecisionType::kInt16; - } else if (type_id == typeid(int32_t).hash_code()) { - precision_ = PrecisionType::kInt32; - } else if (type_id == typeid(int64_t).hash_code()) { - precision_ = PrecisionType::kInt64; - } else { - precision_ = PrecisionType::kUnk; - } + precision_ = lite_api::PrecisionTypeTrait::Type(); memory_size_ = dims_.production() * sizeof(T); buffer_->ResetLazy(target_, memory_size_); return reinterpret_cast(static_cast(buffer_->data()) +