From 7701ac0d5cde3e27c7c2c268532ccc686b0f54c3 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Fri, 29 Mar 2019 13:20:17 +0800 Subject: [PATCH] Use enum type rather than string since comparing between enum is more efficent than string --- src/common/type_define.h | 89 ++++++++++++++++--- src/common/variant.h | 22 ++--- src/framework/attribute.h | 25 +++--- src/framework/data_type.cpp | 20 ++--- src/framework/data_type.h | 8 +- src/framework/ddim.h | 20 ++--- src/framework/load_ops.h | 3 + src/framework/mixed_vector.h | 2 +- src/framework/tensor.h | 38 ++++---- src/framework/tensor_base.h | 21 +++-- src/framework/variable.h | 15 ++-- src/operators/kernel/arm/compare_kernel.cpp | 4 +- src/operators/kernel/arm/concat_kernel.cpp | 2 +- .../kernel/arm/convolution/conv_common.cpp | 2 +- .../kernel/arm/transpose2_kernel.cpp | 4 +- .../kernel/central-arm-func/mul_arm_func.h | 2 +- .../kernel/central-arm-func/sum_arm_func.h | 3 +- 17 files changed, 175 insertions(+), 105 deletions(-) diff --git a/src/common/type_define.h b/src/common/type_define.h index e81863f9fb..d91aa35c9c 100644 --- a/src/common/type_define.h +++ b/src/common/type_define.h @@ -14,28 +14,85 @@ limitations under the License. */ #pragma once +#include #include #include namespace paddle_mobile { +typedef enum { + _void = 0, + _float, + _int, + _double, + _int64_t, + _size_t, + _int16_t, + _int8_t, + _uint8_t, + _bool, + _string, + _floats = 100, + _ints, + _int64_ts, + _size_ts, + _bools, + _strings, + _const_float = 200, + _const_int, + _block = 300, + _tensor, + _lod_tensor, + _blocks, + _tensors, + _lod_tensors, + _p_block = 400, + _p_tensor, + _p_lod_tensor, + _p_blocks, + _p_tensors, + _p_lod_tensors, + _scopes = 500, + _selected_rows, + _dim0 = 600, + _dim1, + _dim2, + _dim3, + _dim4, + _dim5, + _dim6, + _dim7, + _dim8, + _dim9, +} kTypeId_t; + template struct TypeIdWrapper { - std::string type(); + inline std::string name(); + inline kTypeId_t hash_code(); }; template struct type_id { - const std::string type_ = TypeIdWrapper().type(); + const kTypeId_t hash_code() const { return TypeIdWrapper().hash_code(); } + const std::string name() const { return TypeIdWrapper().name(); } template - bool operator==(const type_id &operand) { - return this->name() == operand.name(); + bool operator==(const type_id &operand) const { + return this->hash_code() == operand.hash_code(); } - - const std::string name() { return type_; } }; +template +inline bool operator==(const kTypeId_t &t0, const type_id &t1) { + return t0 == t1.hash_code(); +} + +template +inline bool operator==(const type_id &t0, const kTypeId_t &t1) { + return t1 == t0.hash_code(); +} + namespace framework { class BlockDesc; class Tensor; @@ -47,10 +104,11 @@ template struct Dim; } // namespace framework -#define REGISTER_TYPE_ID(Type, TypeName) \ - template <> \ - struct TypeIdWrapper { \ - std::string type() { return std::string(#TypeName); } \ +#define REGISTER_TYPE_ID(Type, TypeName) \ + template <> \ + struct TypeIdWrapper { \ + inline std::string name() { return std::string(#TypeName); } \ + inline kTypeId_t hash_code() { return kTypeId_t::TypeName; } \ }; REGISTER_TYPE_ID(void, _void) @@ -102,3 +160,14 @@ REGISTER_TYPE_ID(framework::Dim<8>, _dim8) REGISTER_TYPE_ID(framework::Dim<9>, _dim9) } // namespace paddle_mobile + +namespace std { + +template <> +struct hash { + size_t operator()(const paddle_mobile::kTypeId_t &t) const { + return std::hash{}(static_cast(t)); + } +}; + +} // namespace std diff --git a/src/common/variant.h b/src/common/variant.h index 5c8e053406..87fd822243 100644 --- a/src/common/variant.h +++ b/src/common/variant.h @@ -34,8 +34,8 @@ struct VariantHelper { ? sizeof(F) : VariantHelper::size; - inline static void Destroy(std::string type, void *data) { - if (type == type_id().name()) { + inline static void Destroy(kTypeId_t type, void *data) { + if (type == type_id()) { reinterpret_cast(data)->~F(); } else { VariantHelper::Destroy(type, data); @@ -46,8 +46,8 @@ struct VariantHelper { template struct VariantHelper { static const size_t size = sizeof(F); - inline static void Destroy(std::string type, void *data) { - if (type == type_id().name()) { + inline static void Destroy(kTypeId_t type, void *data) { + if (type == type_id()) { // reinterpret_cast(data)->~F(); } else { // std::cout << "未匹配到 " << std::endl; @@ -85,17 +85,17 @@ struct Variant { void Set(Args &&... args) { helper::Destroy(type_, data_.data); new (data_.data) T(std::forward(args)...); - type_ = type_id().name(); + type_ = type_id().hash_code(); } void SetString(const std::string &string) { helper::Destroy(type_, data_.data); - type_ = type_id().name(); + type_ = type_id().hash_code(); strcpy(data_.data, string.c_str()); // NOLINT } std::string GetString() const { - if (type_ == type_id().name()) { + if (type_ == type_id()) { return std::string(data_.data); } else { PADDLE_MOBILE_THROW_EXCEPTION( @@ -106,7 +106,7 @@ struct Variant { template T &Get() const { - if (type_ == type_id().name()) { + if (type_ == type_id()) { PADDLE_MOBILE_THROW_EXCEPTION( "Please use getString to get an string (to avoid of an issue with " "gcc " @@ -117,12 +117,12 @@ struct Variant { } } - std::string TypeId() const { return type_; } + kTypeId_t TypeId() const { return type_; } private: - static inline std::string invalid_type() { return type_id().name(); } + static inline kTypeId_t invalid_type() { return type_id().hash_code(); } typedef VariantHelper helper; - std::string type_ = type_id().name(); + kTypeId_t type_ = type_id().hash_code(); // todo use an anto size to suite this. RawData<64> data_; }; diff --git a/src/framework/attribute.h b/src/framework/attribute.h index d809ec4a72..01c4a8d7d4 100644 --- a/src/framework/attribute.h +++ b/src/framework/attribute.h @@ -128,31 +128,30 @@ class Attribute { template static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) { - if (attr.variant_.TypeId() == type_id().name()) { // NOLINT + if (attr.variant_.TypeId() == type_id()) { // NOLINT return vistor(attr.variant_.Get()); - } else if (attr.variant_.TypeId() == type_id().name()) { // NOLINT + } else if (attr.variant_.TypeId() == type_id()) { // NOLINT return vistor(attr.variant_.Get()); - } else if (attr.variant_.TypeId() == type_id().name()) { + } else if (attr.variant_.TypeId() == type_id()) { return vistor(attr.variant_.GetString()); - } else if (attr.variant_.TypeId() == type_id>().name()) { + } else if (attr.variant_.TypeId() == type_id>()) { return vistor(attr.variant_.Get>()); - } else if (attr.variant_.TypeId() == type_id>().name()) { + } else if (attr.variant_.TypeId() == type_id>()) { return vistor(attr.variant_.Get>()); - } else if (attr.variant_.TypeId() == type_id>().name()) { + } else if (attr.variant_.TypeId() == type_id>()) { return vistor(attr.variant_.Get>()); - } else if (attr.variant_.TypeId() == type_id().name()) { // NOLINT + } else if (attr.variant_.TypeId() == type_id()) { // NOLINT return vistor(attr.variant_.Get()); - } else if (attr.variant_.TypeId() == type_id>().name()) { + } else if (attr.variant_.TypeId() == type_id>()) { return vistor(attr.variant_.Get>()); - } else if (attr.variant_.TypeId() == type_id().name()) { + } else if (attr.variant_.TypeId() == type_id()) { return vistor(attr.variant_.Get()); - } else if (attr.variant_.TypeId() == - type_id().name()) { + } else if (attr.variant_.TypeId() == type_id()) { return vistor(attr.variant_.Get()); } else if (attr.variant_.TypeId() == - type_id>().name()) { + type_id>()) { return vistor(attr.variant_.Get>()); - } else if (attr.variant_.TypeId() == type_id>().name()) { + } else if (attr.variant_.TypeId() == type_id>()) { return vistor(attr.variant_.Get>()); } else { PADDLE_MOBILE_THROW_EXCEPTION("type not support"); diff --git a/src/framework/data_type.cpp b/src/framework/data_type.cpp index 94272a16b6..5eaf3ecaf5 100644 --- a/src/framework/data_type.cpp +++ b/src/framework/data_type.cpp @@ -22,12 +22,11 @@ namespace paddle_mobile { namespace framework { struct DataTypeMap { - std::unordered_map + std::unordered_map cpp_to_proto_; - std::unordered_map proto_to_cpp_; + std::unordered_map proto_to_cpp_; std::unordered_map proto_to_str_; - std::unordered_map cpp_to_size_; + std::unordered_map cpp_to_size_; }; static DataTypeMap* InitDataTypeMap(); @@ -43,10 +42,11 @@ template static inline void RegisterType( DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type, const std::string& name) { - map->proto_to_cpp_.emplace(static_cast(proto_type), type_id().name()); - map->cpp_to_proto_.emplace(type_id().name(), proto_type); + map->proto_to_cpp_.emplace(static_cast(proto_type), + type_id().hash_code()); + map->cpp_to_proto_.emplace(type_id().hash_code(), proto_type); map->proto_to_str_.emplace(static_cast(proto_type), name); - map->cpp_to_size_.emplace(type_id().name(), sizeof(T)); + map->cpp_to_size_.emplace(type_id().hash_code(), sizeof(T)); } static DataTypeMap* InitDataTypeMap() { @@ -71,15 +71,15 @@ static DataTypeMap* InitDataTypeMap() { return retv; } -_PaddleMobile__Framework__Proto__VarType__Type ToDataType(std::string type) { +_PaddleMobile__Framework__Proto__VarType__Type ToDataType(kTypeId_t type) { auto it = gDataTypeMap().cpp_to_proto_.find(type); if (it != gDataTypeMap().cpp_to_proto_.end()) { return it->second; } - PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.c_str()); + PADDLE_MOBILE_THROW_EXCEPTION("Not support %d as tensor type", type); } -std::string ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type) { +kTypeId_t ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type) { auto it = gDataTypeMap().proto_to_cpp_.find(static_cast(type)); if (it != gDataTypeMap().proto_to_cpp_.end()) { return it->second; diff --git a/src/framework/data_type.h b/src/framework/data_type.h index ef7a19ab93..bda823ada4 100644 --- a/src/framework/data_type.h +++ b/src/framework/data_type.h @@ -16,16 +16,16 @@ limitations under the License. */ #include #include "common/enforce.h" +#include "common/type_define.h" #include "framework/framework.pb-c.h" namespace paddle_mobile { namespace framework { -extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType( - std::string type); -extern std::string ToTypeIndex( - _PaddleMobile__Framework__Proto__VarType__Type type); +_PaddleMobile__Framework__Proto__VarType__Type ToDataType(kTypeId_t type); + +kTypeId_t ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type); inline _PaddleMobile__Framework__Proto__VarType__Type ToDataType(int type) { return static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(type); diff --git a/src/framework/ddim.h b/src/framework/ddim.h index f35e162507..8ba0756a6b 100644 --- a/src/framework/ddim.h +++ b/src/framework/ddim.h @@ -40,25 +40,25 @@ struct DDim { template static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) { - if (d.var.TypeId() == type_id>().name()) { + if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); - } else if (d.var.TypeId() == type_id>().name()) { + } else if (d.var.TypeId() == type_id>()) { return vistor(d.var.Get>()); } else { PADDLE_MOBILE_ENFORCE(false, " dim not support"); diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index a969a73b9e..983a544cda 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -268,6 +268,9 @@ LOAD_OP1(sequence_expand, CPU); #ifdef SEQUENCE_POOL_OP LOAD_OP1(sequence_pool, CPU); #endif +#ifdef SEQUENCE_SOFTMAX_OP +LOAD_OP1(sequence_softmax, CPU); +#endif #ifdef LOG_OP LOAD_OP1(log, CPU); #endif diff --git a/src/framework/mixed_vector.h b/src/framework/mixed_vector.h index bae96e620c..6e46164fb7 100644 --- a/src/framework/mixed_vector.h +++ b/src/framework/mixed_vector.h @@ -197,7 +197,7 @@ class Vector { } size_t capacity() const { - return cpu_vec_.memory_size() / SizeOfType(type_id().name()); + return cpu_vec_.memory_size() / SizeOfType(type_id().hash_code()); } // reserve data diff --git a/src/framework/tensor.h b/src/framework/tensor.h index c38199b9e2..fe2f266137 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -81,7 +81,7 @@ class Tensor : public TensorBase { return *this; } - inline void *mutable_data(const std::string type) { + inline void *mutable_data(const kTypeId_t type) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -106,7 +106,7 @@ class Tensor : public TensorBase { template inline T *mutable_data() { static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(type_id().name())); + return reinterpret_cast(mutable_data(type_id().hash_code())); } /** @@ -163,9 +163,9 @@ class Tensor : public TensorBase { check_memory_size(); PADDLE_MOBILE_ENFORCE( (std::is_same::value || - holder_->type() == type_id().name()), - "Tensor holds the wrong type, it holds %s, requested %s", - this->holder_->type().c_str(), type_id().name().c_str()); + holder_->type() == type_id().hash_code()), + "Tensor holds the wrong type, it holds %d, requested %d", + this->holder_->type(), type_id().hash_code()); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); @@ -177,9 +177,9 @@ class Tensor : public TensorBase { check_memory_size(); PADDLE_MOBILE_ENFORCE( (std::is_same::value || - holder_->type() == type_id().name()), - "Tensor holds the wrong type, it holds %s, requested %s", - this->holder_->type().c_str(), type_id().name().c_str()); + holder_->type() == type_id().hash_code()), + "Tensor holds the wrong type, it holds %d, requested %d", + this->holder_->type(), type_id().hash_code()); return reinterpret_cast( reinterpret_cast(holder_->ptr()) + offset_); @@ -187,7 +187,7 @@ class Tensor : public TensorBase { private: struct PlaceholderImpl : public Placeholder { - PlaceholderImpl(size_t size, const std::string type) + PlaceholderImpl(size_t size, const kTypeId_t type) : ptr_(static_cast(memory::Alloc(size)), memory::PODDeleter()), size_(size), @@ -201,9 +201,9 @@ class Tensor : public TensorBase { virtual void *ptr() const { return static_cast(ptr_.get()); } - virtual std::string type() const { return type_; } + virtual kTypeId_t type() const { return type_; } - virtual void set_type(const std::string type) { type_ = type; } + virtual void set_type(const kTypeId_t type) { type_ = type; } virtual void resize(size_t size) { if (size > capatity_) { @@ -221,7 +221,7 @@ class Tensor : public TensorBase { size_t capatity_; /* the current type of memory */ - std::string type_; + kTypeId_t type_; }; #ifdef PADDLE_MOBILE_FPGA @@ -229,13 +229,13 @@ class Tensor : public TensorBase { inline void reset_data_ptr(void *p) { ((PlaceholderImpl *)(holder_.get()))->ptr_.reset((uint8_t *)p); // NOLINT } - inline void set_type(const std::string type) { holder_->set_type(type); } + inline void set_type(const kTypeId_t type) { holder_->set_type(type); } inline void *get_data() { return ( void *)(((PlaceholderImpl *)(holder_.get()))->ptr_.get()); // NOLINT } - inline void *init(const std::string type) { + inline void *init(const kTypeId_t type) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -263,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) { stride = stride > 0 ? stride : 1; #ifndef PADDLE_MOBILE_FPGA for (int i = 0; i < tensor.numel(); i += stride) { - if (tensor.type() == type_id().name()) { + if (tensor.type() == type_id()) { printer << tensor.data()[i] << " "; - } else if (tensor.type() == type_id().name()) { + } else if (tensor.type() == type_id()) { printer << tensor.data()[i] << " "; - } else if (tensor.type() == type_id().name()) { + } else if (tensor.type() == type_id()) { printer << tensor.data()[i] << " "; - } else if (tensor.type() == type_id().name()) { + } else if (tensor.type() == type_id()) { printer << static_cast(tensor.data()[i]) << " "; - } else if (tensor.type() == type_id().name()) { + } else if (tensor.type() == type_id()) { printer << tensor.data()[i] << " "; } } diff --git a/src/framework/tensor_base.h b/src/framework/tensor_base.h index 7d76c0eff2..027f1165a0 100644 --- a/src/framework/tensor_base.h +++ b/src/framework/tensor_base.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include #include "common/enforce.h" +#include "common/type_define.h" #include "common/types.h" #include "framework/ddim.h" @@ -27,8 +27,8 @@ struct SizeOfTypeFunctor; template struct SizeOfTypeFunctor { - size_t operator()(const std::string type) const { - if (type_id().name() == type) { + size_t operator()(const kTypeId_t type) const { + if (type_id().hash_code() == type) { return sizeof(T); } else { return 0UL; @@ -38,12 +38,12 @@ struct SizeOfTypeFunctor { template <> struct SizeOfTypeFunctor<> { - size_t operator()(const std::string type) const { return 0UL; } + size_t operator()(const kTypeId_t type) const { return 0UL; } }; template struct SizeOfTypeFunctor { - size_t operator()(const std::string type) const { + size_t operator()(const kTypeId_t type) const { SizeOfTypeFunctor head; size_t head_size = head(type); if (head_size != 0) { @@ -54,14 +54,13 @@ struct SizeOfTypeFunctor { } }; -static inline size_t SizeOfType(std::string type) { +static inline size_t SizeOfType(const kTypeId_t type) { SizeOfTypeFunctor functor; size_t size = functor(type); - PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %s", - type.c_str()); + PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %d", type); return size; } @@ -77,7 +76,7 @@ class TensorBase { /*! Return the numel of the memory block. */ inline int64_t numel() const { return product(dims_); } - std::string type() const { + kTypeId_t type() const { PADDLE_MOBILE_ENFORCE( holder_ != nullptr, "Tensor not initialized yet when Tensor::type() is called.") @@ -113,9 +112,9 @@ class TensorBase { virtual size_t size() const = 0; - virtual std::string type() const = 0; + virtual kTypeId_t type() const = 0; - virtual void set_type(std::string type) = 0; + virtual void set_type(kTypeId_t type) = 0; virtual void resize(size_t size) = 0; }; diff --git a/src/framework/variable.h b/src/framework/variable.h index 51997530e5..30486cb347 100644 --- a/src/framework/variable.h +++ b/src/framework/variable.h @@ -30,7 +30,7 @@ class Variable { template const T GetValue() const { - if (type_id().name() == type_id().name()) { + if (type_id().hash_code() == type_id().hash_code()) { PADDLE_MOBILE_THROW_EXCEPTION( "Please use getString to get an string (to avoid of an issue with " "gcc " @@ -57,31 +57,32 @@ class Variable { template bool IsType() const { - return holder_ != nullptr && holder_->Type() == type_id().name(); + return holder_ != nullptr && holder_->Type() == type_id().hash_code(); } void Clear() { holder_.reset(); } - std::string Type() const { return holder_->Type(); } + kTypeId_t Type() const { return holder_->Type(); } private: struct Placeholder { Placeholder() = default; virtual ~Placeholder() = default; - virtual std::string Type() const = 0; + virtual kTypeId_t Type() const = 0; virtual void *Ptr() const = 0; }; template struct PlaceholderImp : public Placeholder { - explicit PlaceholderImp(T *ptr) : ptr_(ptr), type_(type_id().name()) {} + explicit PlaceholderImp(T *ptr) + : ptr_(ptr), type_(type_id().hash_code()) {} - std::string Type() const override { return type_; } + kTypeId_t Type() const override { return type_; } void *Ptr() const override { return static_cast(ptr_.get()); } std::unique_ptr ptr_; - std::string type_; + kTypeId_t type_; }; friend class Scope; diff --git a/src/operators/kernel/arm/compare_kernel.cpp b/src/operators/kernel/arm/compare_kernel.cpp index e1a0f6f167..35bb13363c 100644 --- a/src/operators/kernel/arm/compare_kernel.cpp +++ b/src/operators/kernel/arm/compare_kernel.cpp @@ -192,10 +192,10 @@ bool LessThanKernel::Init(CompareParam *param) { template <> void LessThanKernel::Compute(const CompareParam ¶m) { - if (param.input_x_->type() == type_id().name()) { + if (param.input_x_->type() == type_id().hash_code()) { CompareCompute()(param.input_x_, param.input_y_, param.axis_, param.output_); - } else if (param.input_x_->type() == type_id().name()) { + } else if (param.input_x_->type() == type_id().hash_code()) { CompareCompute()(param.input_x_, param.input_y_, param.axis_, param.output_); } else { diff --git a/src/operators/kernel/arm/concat_kernel.cpp b/src/operators/kernel/arm/concat_kernel.cpp index efee9cff28..3e585ec721 100644 --- a/src/operators/kernel/arm/concat_kernel.cpp +++ b/src/operators/kernel/arm/concat_kernel.cpp @@ -27,7 +27,7 @@ bool ConcatKernel::Init(ConcatParam *param) { template <> void ConcatKernel::Compute(const ConcatParam ¶m) { - if (param.Inputs()[0]->type() == type_id().name()) { + if (param.Inputs()[0]->type() == type_id().hash_code()) { ConcatCompute(param); } else { ConcatCompute(param); diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index d96eef35c8..5a5c04c656 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam *param) { bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] && param->Input()->dims()[1] == param->Output()->dims()[1]; - if (param->Filter()->type() == type_id().name()) { + if (param->Filter()->type() == type_id().hash_code()) { #ifndef __aarch64__ if (depth3x3 && param->Strides()[0] < 3 && param->Strides()[0] == param->Strides()[1]) { diff --git a/src/operators/kernel/arm/transpose2_kernel.cpp b/src/operators/kernel/arm/transpose2_kernel.cpp index 54f759f016..54c88015cb 100644 --- a/src/operators/kernel/arm/transpose2_kernel.cpp +++ b/src/operators/kernel/arm/transpose2_kernel.cpp @@ -126,13 +126,13 @@ void Transpose2Kernel::Compute(const Transpose2Param ¶m) { const std::vector &axis = param.Axis(); bool shuffle_channel = IsShuffleChannel(axis); if (shuffle_channel) { - if (param.InputX()->type() == type_id().name()) { + if (param.InputX()->type() == type_id().hash_code()) { ShuffleChannelCompute(param); } else { ShuffleChannelCompute(param); } } else { - if (param.InputX()->type() == type_id().name()) { + if (param.InputX()->type() == type_id().hash_code()) { Transpose2Compute(param); } else { Transpose2Compute(param); diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h index 4b697c0d13..01d668021b 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -37,7 +37,7 @@ void MulCompute(const MulParam ¶m) { if (out_dim.size() != 2) { out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - if (param.InputX()->type() == type_id().name()) { + if (param.InputX()->type() == type_id().hash_code()) { out->mutable_data(); math::MatMul(x_matrix, false, y_matrix, false, static_cast(1), out, diff --git a/src/operators/kernel/central-arm-func/sum_arm_func.h b/src/operators/kernel/central-arm-func/sum_arm_func.h index eb1e830849..7d41c898db 100644 --- a/src/operators/kernel/central-arm-func/sum_arm_func.h +++ b/src/operators/kernel/central-arm-func/sum_arm_func.h @@ -144,8 +144,7 @@ void SumCompute(const SumParam ¶m) { } } else { PADDLE_MOBILE_THROW_EXCEPTION( - "Unexpected branch, output variable type is %s", - outvar->Type().c_str()); + "Unexpected branch, output variable type is %d", outvar->Type()); } } } // namespace operators -- GitLab