提交 7701ac0d 编写于 作者: H hjchen2

Use enum type rather than string, since comparing enums is more efficient than comparing strings

上级 2852680a
......@@ -14,28 +14,85 @@ limitations under the License. */
#pragma once
#include <functional>
#include <string>
#include <vector>
namespace paddle_mobile {
// Enumerated identifier for every type registered via REGISTER_TYPE_ID.
// Used in place of std::string type names so that type comparisons are a
// cheap integer compare. Values are partitioned into numeric ranges by
// category so new entries can be appended to a group without renumbering.
typedef enum {
// 0..: scalar / primitive types.
_void = 0,
_float,
_int,
_double,
_int64_t,
_size_t,
_int16_t,
_int8_t,
_uint8_t,
_bool,
_string,
// 100..: vector-of-scalar types.
_floats = 100,
_ints,
_int64_ts,
_size_ts,
_bools,
_strings,
// 200..: const-qualified scalar types.
_const_float = 200,
_const_int,
// 300..: framework object types (BlockDesc, Tensor, ... — presumably the
// forward-declared framework classes below; registrations not all visible
// here).
_block = 300,
_tensor,
_lod_tensor,
_blocks,
_tensors,
_lod_tensors,
// 400..: pointer-to-framework-object types.
_p_block = 400,
_p_tensor,
_p_lod_tensor,
_p_blocks,
_p_tensors,
_p_lod_tensors,
// 500..: scope and SelectedRows types.
_scopes = 500,
_selected_rows,
// 600..: framework::Dim<0> through framework::Dim<9>.
_dim0 = 600,
_dim1,
_dim2,
_dim3,
_dim4,
_dim5,
_dim6,
_dim7,
_dim8,
_dim9,
} kTypeId_t;
template <typename T>
struct TypeIdWrapper {
std::string type();
inline std::string name();
inline kTypeId_t hash_code();
};
template <typename T>
struct type_id {
const std::string type_ = TypeIdWrapper<T>().type();
const kTypeId_t hash_code() const { return TypeIdWrapper<T>().hash_code(); }
const std::string name() const { return TypeIdWrapper<T>().name(); }
template <typename OtherType>
bool operator==(const type_id<OtherType> &operand) {
return this->name() == operand.name();
bool operator==(const type_id<OtherType> &operand) const {
return this->hash_code() == operand.hash_code();
}
const std::string name() { return type_; }
};
// Symmetric comparison helpers so a raw kTypeId_t can be compared against a
// type_id<T> wrapper with either operand on the left; both overloads reduce
// to comparing against type_id<T>::hash_code().
template <typename T>
inline bool operator==(const kTypeId_t &t0, const type_id<T> &t1) {
  return t1.hash_code() == t0;
}
template <typename T>
inline bool operator==(const type_id<T> &t0, const kTypeId_t &t1) {
  return t0.hash_code() == t1;
}
namespace framework {
class BlockDesc;
class Tensor;
......@@ -47,10 +104,11 @@ template <int>
struct Dim;
} // namespace framework
#define REGISTER_TYPE_ID(Type, TypeName) \
template <> \
struct TypeIdWrapper<Type> { \
std::string type() { return std::string(#TypeName); } \
#define REGISTER_TYPE_ID(Type, TypeName) \
template <> \
struct TypeIdWrapper<Type> { \
inline std::string name() { return std::string(#TypeName); } \
inline kTypeId_t hash_code() { return kTypeId_t::TypeName; } \
};
REGISTER_TYPE_ID(void, _void)
......@@ -102,3 +160,14 @@ REGISTER_TYPE_ID(framework::Dim<8>, _dim8)
REGISTER_TYPE_ID(framework::Dim<9>, _dim9)
} // namespace paddle_mobile
namespace std {
// Specialization of std::hash for kTypeId_t so the enum can serve as a key
// in std::unordered_map / std::unordered_set (the enum is hashed through its
// underlying integer value). NOTE(review): since C++14, unscoped enums are
// hashable via std::hash of the underlying type on conforming stdlibs; this
// explicit specialization keeps older toolchains working — confirm target
// compilers before removing.
template <>
struct hash<paddle_mobile::kTypeId_t> {
size_t operator()(const paddle_mobile::kTypeId_t &t) const {
return std::hash<int>{}(static_cast<int>(t));
}
};
} // namespace std
......@@ -34,8 +34,8 @@ struct VariantHelper {
? sizeof(F)
: VariantHelper<Ts...>::size;
inline static void Destroy(std::string type, void *data) {
if (type == type_id<F>().name()) {
inline static void Destroy(kTypeId_t type, void *data) {
if (type == type_id<F>()) {
reinterpret_cast<F *>(data)->~F();
} else {
VariantHelper<Ts...>::Destroy(type, data);
......@@ -46,8 +46,8 @@ struct VariantHelper {
template <typename F>
struct VariantHelper<F> {
static const size_t size = sizeof(F);
inline static void Destroy(std::string type, void *data) {
if (type == type_id<F>().name()) {
inline static void Destroy(kTypeId_t type, void *data) {
if (type == type_id<F>()) {
// reinterpret_cast<F*>(data)->~F();
} else {
// std::cout << "未匹配到 " << std::endl;
......@@ -85,17 +85,17 @@ struct Variant {
void Set(Args &&... args) {
helper::Destroy(type_, data_.data);
new (data_.data) T(std::forward<Args>(args)...);
type_ = type_id<T>().name();
type_ = type_id<T>().hash_code();
}
void SetString(const std::string &string) {
helper::Destroy(type_, data_.data);
type_ = type_id<std::string>().name();
type_ = type_id<std::string>().hash_code();
strcpy(data_.data, string.c_str()); // NOLINT
}
std::string GetString() const {
if (type_ == type_id<std::string>().name()) {
if (type_ == type_id<std::string>()) {
return std::string(data_.data);
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
......@@ -106,7 +106,7 @@ struct Variant {
template <typename T>
T &Get() const {
if (type_ == type_id<std::string>().name()) {
if (type_ == type_id<std::string>()) {
PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
......@@ -117,12 +117,12 @@ struct Variant {
}
}
std::string TypeId() const { return type_; }
kTypeId_t TypeId() const { return type_; }
private:
static inline std::string invalid_type() { return type_id<void>().name(); }
static inline kTypeId_t invalid_type() { return type_id<void>().hash_code(); }
typedef VariantHelper<Ts...> helper;
std::string type_ = type_id<void>().name();
kTypeId_t type_ = type_id<void>().hash_code();
// todo use an anto size to suite this.
RawData<64> data_;
};
......
......@@ -128,31 +128,30 @@ class Attribute {
template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) {
if (attr.variant_.TypeId() == type_id<int>().name()) { // NOLINT
if (attr.variant_.TypeId() == type_id<int>()) { // NOLINT
return vistor(attr.variant_.Get<int>());
} else if (attr.variant_.TypeId() == type_id<float>().name()) { // NOLINT
} else if (attr.variant_.TypeId() == type_id<float>()) { // NOLINT
return vistor(attr.variant_.Get<float>());
} else if (attr.variant_.TypeId() == type_id<string>().name()) {
} else if (attr.variant_.TypeId() == type_id<string>()) {
return vistor(attr.variant_.GetString());
} else if (attr.variant_.TypeId() == type_id<vector<int>>().name()) {
} else if (attr.variant_.TypeId() == type_id<vector<int>>()) {
return vistor(attr.variant_.Get<vector<int>>());
} else if (attr.variant_.TypeId() == type_id<vector<float>>().name()) {
} else if (attr.variant_.TypeId() == type_id<vector<float>>()) {
return vistor(attr.variant_.Get<vector<float>>());
} else if (attr.variant_.TypeId() == type_id<vector<string>>().name()) {
} else if (attr.variant_.TypeId() == type_id<vector<string>>()) {
return vistor(attr.variant_.Get<vector<string>>());
} else if (attr.variant_.TypeId() == type_id<bool>().name()) { // NOLINT
} else if (attr.variant_.TypeId() == type_id<bool>()) { // NOLINT
return vistor(attr.variant_.Get<bool>());
} else if (attr.variant_.TypeId() == type_id<vector<bool>>().name()) {
} else if (attr.variant_.TypeId() == type_id<vector<bool>>()) {
return vistor(attr.variant_.Get<vector<bool>>());
} else if (attr.variant_.TypeId() == type_id<int64_t>().name()) {
} else if (attr.variant_.TypeId() == type_id<int64_t>()) {
return vistor(attr.variant_.Get<int64_t>());
} else if (attr.variant_.TypeId() ==
type_id<framework::BlockDesc *>().name()) {
} else if (attr.variant_.TypeId() == type_id<framework::BlockDesc *>()) {
return vistor(attr.variant_.Get<framework::BlockDesc *>());
} else if (attr.variant_.TypeId() ==
type_id<vector<framework::BlockDesc *>>().name()) {
type_id<vector<framework::BlockDesc *>>()) {
return vistor(attr.variant_.Get<vector<framework::BlockDesc *>>());
} else if (attr.variant_.TypeId() == type_id<vector<int64_t>>().name()) {
} else if (attr.variant_.TypeId() == type_id<vector<int64_t>>()) {
return vistor(attr.variant_.Get<vector<int64_t>>());
} else {
PADDLE_MOBILE_THROW_EXCEPTION("type not support");
......
......@@ -22,12 +22,11 @@ namespace paddle_mobile {
namespace framework {
struct DataTypeMap {
std::unordered_map<std::string,
_PaddleMobile__Framework__Proto__VarType__Type>
std::unordered_map<kTypeId_t, _PaddleMobile__Framework__Proto__VarType__Type>
cpp_to_proto_;
std::unordered_map<int, std::string> proto_to_cpp_;
std::unordered_map<int, kTypeId_t> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::string, size_t> cpp_to_size_;
std::unordered_map<kTypeId_t, size_t> cpp_to_size_;
};
static DataTypeMap* InitDataTypeMap();
......@@ -43,10 +42,11 @@ template <typename T>
static inline void RegisterType(
DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type,
const std::string& name) {
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), type_id<T>().name());
map->cpp_to_proto_.emplace(type_id<T>().name(), proto_type);
map->proto_to_cpp_.emplace(static_cast<int>(proto_type),
type_id<T>().hash_code());
map->cpp_to_proto_.emplace(type_id<T>().hash_code(), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(type_id<T>().name(), sizeof(T));
map->cpp_to_size_.emplace(type_id<T>().hash_code(), sizeof(T));
}
static DataTypeMap* InitDataTypeMap() {
......@@ -71,15 +71,15 @@ static DataTypeMap* InitDataTypeMap() {
return retv;
}
_PaddleMobile__Framework__Proto__VarType__Type ToDataType(std::string type) {
_PaddleMobile__Framework__Proto__VarType__Type ToDataType(kTypeId_t type) {
auto it = gDataTypeMap().cpp_to_proto_.find(type);
if (it != gDataTypeMap().cpp_to_proto_.end()) {
return it->second;
}
PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.c_str());
PADDLE_MOBILE_THROW_EXCEPTION("Not support %d as tensor type", type);
}
std::string ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type) {
kTypeId_t ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type) {
auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_cpp_.end()) {
return it->second;
......
......@@ -16,16 +16,16 @@ limitations under the License. */
#include <string>
#include "common/enforce.h"
#include "common/type_define.h"
#include "framework/framework.pb-c.h"
namespace paddle_mobile {
namespace framework {
extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType(
std::string type);
extern std::string ToTypeIndex(
_PaddleMobile__Framework__Proto__VarType__Type type);
_PaddleMobile__Framework__Proto__VarType__Type ToDataType(kTypeId_t type);
kTypeId_t ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type);
inline _PaddleMobile__Framework__Proto__VarType__Type ToDataType(int type) {
return static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(type);
......
......@@ -40,25 +40,25 @@ struct DDim {
template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) {
if (d.var.TypeId() == type_id<Dim<0>>().name()) {
if (d.var.TypeId() == type_id<Dim<0>>()) {
return vistor(d.var.Get<Dim<0>>());
} else if (d.var.TypeId() == type_id<Dim<1>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<1>>()) {
return vistor(d.var.Get<Dim<1>>());
} else if (d.var.TypeId() == type_id<Dim<2>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<2>>()) {
return vistor(d.var.Get<Dim<2>>());
} else if (d.var.TypeId() == type_id<Dim<3>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<3>>()) {
return vistor(d.var.Get<Dim<3>>());
} else if (d.var.TypeId() == type_id<Dim<4>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<4>>()) {
return vistor(d.var.Get<Dim<4>>());
} else if (d.var.TypeId() == type_id<Dim<5>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<5>>()) {
return vistor(d.var.Get<Dim<5>>());
} else if (d.var.TypeId() == type_id<Dim<6>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<6>>()) {
return vistor(d.var.Get<Dim<6>>());
} else if (d.var.TypeId() == type_id<Dim<7>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<7>>()) {
return vistor(d.var.Get<Dim<7>>());
} else if (d.var.TypeId() == type_id<Dim<8>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<8>>()) {
return vistor(d.var.Get<Dim<8>>());
} else if (d.var.TypeId() == type_id<Dim<9>>().name()) {
} else if (d.var.TypeId() == type_id<Dim<9>>()) {
return vistor(d.var.Get<Dim<9>>());
} else {
PADDLE_MOBILE_ENFORCE(false, " dim not support");
......
......@@ -268,6 +268,9 @@ LOAD_OP1(sequence_expand, CPU);
#ifdef SEQUENCE_POOL_OP
LOAD_OP1(sequence_pool, CPU);
#endif
#ifdef SEQUENCE_SOFTMAX_OP
LOAD_OP1(sequence_softmax, CPU);
#endif
#ifdef LOG_OP
LOAD_OP1(log, CPU);
#endif
......
......@@ -197,7 +197,7 @@ class Vector {
}
size_t capacity() const {
return cpu_vec_.memory_size() / SizeOfType(type_id<T>().name());
return cpu_vec_.memory_size() / SizeOfType(type_id<T>().hash_code());
}
// reserve data
......
......@@ -81,7 +81,7 @@ class Tensor : public TensorBase {
return *this;
}
inline void *mutable_data(const std::string type) {
inline void *mutable_data(const kTypeId_t type) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
......@@ -106,7 +106,7 @@ class Tensor : public TensorBase {
template <typename T>
inline T *mutable_data() {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T *>(mutable_data(type_id<T>().name()));
return reinterpret_cast<T *>(mutable_data(type_id<T>().hash_code()));
}
/**
......@@ -163,9 +163,9 @@ class Tensor : public TensorBase {
check_memory_size();
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type() == type_id<T>().name()),
"Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().c_str(), type_id<T>().name().c_str());
holder_->type() == type_id<T>().hash_code()),
"Tensor holds the wrong type, it holds %d, requested %d",
this->holder_->type(), type_id<T>().hash_code());
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
......@@ -177,9 +177,9 @@ class Tensor : public TensorBase {
check_memory_size();
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type() == type_id<T>().name()),
"Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().c_str(), type_id<T>().name().c_str());
holder_->type() == type_id<T>().hash_code()),
"Tensor holds the wrong type, it holds %d, requested %d",
this->holder_->type(), type_id<T>().hash_code());
return reinterpret_cast<const T *>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
......@@ -187,7 +187,7 @@ class Tensor : public TensorBase {
private:
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t size, const std::string type)
PlaceholderImpl(size_t size, const kTypeId_t type)
: ptr_(static_cast<uint8_t *>(memory::Alloc(size)),
memory::PODDeleter<uint8_t>()),
size_(size),
......@@ -201,9 +201,9 @@ class Tensor : public TensorBase {
virtual void *ptr() const { return static_cast<void *>(ptr_.get()); }
virtual std::string type() const { return type_; }
virtual kTypeId_t type() const { return type_; }
virtual void set_type(const std::string type) { type_ = type; }
virtual void set_type(const kTypeId_t type) { type_ = type; }
virtual void resize(size_t size) {
if (size > capatity_) {
......@@ -221,7 +221,7 @@ class Tensor : public TensorBase {
size_t capatity_;
/* the current type of memory */
std::string type_;
kTypeId_t type_;
};
#ifdef PADDLE_MOBILE_FPGA
......@@ -229,13 +229,13 @@ class Tensor : public TensorBase {
inline void reset_data_ptr(void *p) {
((PlaceholderImpl *)(holder_.get()))->ptr_.reset((uint8_t *)p); // NOLINT
}
inline void set_type(const std::string type) { holder_->set_type(type); }
inline void set_type(const kTypeId_t type) { holder_->set_type(type); }
inline void *get_data() {
return (
void *)(((PlaceholderImpl *)(holder_.get()))->ptr_.get()); // NOLINT
}
inline void *init(const std::string type) {
inline void *init(const kTypeId_t type) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
......@@ -263,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
stride = stride > 0 ? stride : 1;
#ifndef PADDLE_MOBILE_FPGA
for (int i = 0; i < tensor.numel(); i += stride) {
if (tensor.type() == type_id<float>().name()) {
if (tensor.type() == type_id<float>()) {
printer << tensor.data<float>()[i] << " ";
} else if (tensor.type() == type_id<int32_t>().name()) {
} else if (tensor.type() == type_id<int32_t>()) {
printer << tensor.data<int32_t>()[i] << " ";
} else if (tensor.type() == type_id<int64_t>().name()) {
} else if (tensor.type() == type_id<int64_t>()) {
printer << tensor.data<int64_t>()[i] << " ";
} else if (tensor.type() == type_id<int8_t>().name()) {
} else if (tensor.type() == type_id<int8_t>()) {
printer << static_cast<int>(tensor.data<int8_t>()[i]) << " ";
} else if (tensor.type() == type_id<int32_t>().name()) {
} else if (tensor.type() == type_id<int32_t>()) {
printer << tensor.data<int32_t>()[i] << " ";
}
}
......
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h"
#include "framework/ddim.h"
......@@ -27,8 +27,8 @@ struct SizeOfTypeFunctor;
template <typename T>
struct SizeOfTypeFunctor<T> {
size_t operator()(const std::string type) const {
if (type_id<T>().name() == type) {
size_t operator()(const kTypeId_t type) const {
if (type_id<T>().hash_code() == type) {
return sizeof(T);
} else {
return 0UL;
......@@ -38,12 +38,12 @@ struct SizeOfTypeFunctor<T> {
template <>
struct SizeOfTypeFunctor<> {
size_t operator()(const std::string type) const { return 0UL; }
size_t operator()(const kTypeId_t type) const { return 0UL; }
};
template <typename HEAD, typename... TAIL>
struct SizeOfTypeFunctor<HEAD, TAIL...> {
size_t operator()(const std::string type) const {
size_t operator()(const kTypeId_t type) const {
SizeOfTypeFunctor<HEAD> head;
size_t head_size = head(type);
if (head_size != 0) {
......@@ -54,14 +54,13 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
}
};
static inline size_t SizeOfType(std::string type) {
static inline size_t SizeOfType(const kTypeId_t type) {
SizeOfTypeFunctor<int8_t, int, half, float, double, int16_t, int64_t, bool,
size_t>
functor;
size_t size = functor(type);
PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %s",
type.c_str());
PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %d", type);
return size;
}
......@@ -77,7 +76,7 @@ class TensorBase {
/*! Return the numel of the memory block. */
inline int64_t numel() const { return product(dims_); }
std::string type() const {
kTypeId_t type() const {
PADDLE_MOBILE_ENFORCE(
holder_ != nullptr,
"Tensor not initialized yet when Tensor::type() is called.")
......@@ -113,9 +112,9 @@ class TensorBase {
virtual size_t size() const = 0;
virtual std::string type() const = 0;
virtual kTypeId_t type() const = 0;
virtual void set_type(std::string type) = 0;
virtual void set_type(kTypeId_t type) = 0;
virtual void resize(size_t size) = 0;
};
......
......@@ -30,7 +30,7 @@ class Variable {
template <typename T>
const T GetValue() const {
if (type_id<T>().name() == type_id<std::string>().name()) {
if (type_id<T>().hash_code() == type_id<std::string>().hash_code()) {
PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
......@@ -57,31 +57,32 @@ class Variable {
template <typename T>
bool IsType() const {
return holder_ != nullptr && holder_->Type() == type_id<T>().name();
return holder_ != nullptr && holder_->Type() == type_id<T>().hash_code();
}
void Clear() { holder_.reset(); }
std::string Type() const { return holder_->Type(); }
kTypeId_t Type() const { return holder_->Type(); }
private:
struct Placeholder {
Placeholder() = default;
virtual ~Placeholder() = default;
virtual std::string Type() const = 0;
virtual kTypeId_t Type() const = 0;
virtual void *Ptr() const = 0;
};
template <typename T>
struct PlaceholderImp : public Placeholder {
explicit PlaceholderImp(T *ptr) : ptr_(ptr), type_(type_id<T>().name()) {}
explicit PlaceholderImp(T *ptr)
: ptr_(ptr), type_(type_id<T>().hash_code()) {}
std::string Type() const override { return type_; }
kTypeId_t Type() const override { return type_; }
void *Ptr() const override { return static_cast<void *>(ptr_.get()); }
std::unique_ptr<T> ptr_;
std::string type_;
kTypeId_t type_;
};
friend class Scope;
......
......@@ -192,10 +192,10 @@ bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) {
template <>
void LessThanKernel<CPU, float>::Compute(const CompareParam<CPU> &param) {
if (param.input_x_->type() == type_id<int64_t>().name()) {
if (param.input_x_->type() == type_id<int64_t>().hash_code()) {
CompareCompute<int64_t, LESS_THAN>()(param.input_x_, param.input_y_,
param.axis_, param.output_);
} else if (param.input_x_->type() == type_id<float>().name()) {
} else if (param.input_x_->type() == type_id<float>().hash_code()) {
CompareCompute<float, LESS_THAN>()(param.input_x_, param.input_y_,
param.axis_, param.output_);
} else {
......
......@@ -27,7 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam<CPU> &param) {
if (param.Inputs()[0]->type() == type_id<int8_t>().name()) {
if (param.Inputs()[0]->type() == type_id<int8_t>().hash_code()) {
ConcatCompute<int8_t>(param);
} else {
ConcatCompute<float>(param);
......
......@@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == type_id<int8_t>().name()) {
if (param->Filter()->type() == type_id<int8_t>().hash_code()) {
#ifndef __aarch64__
if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
......
......@@ -126,13 +126,13 @@ void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> &param) {
const std::vector<int> &axis = param.Axis();
bool shuffle_channel = IsShuffleChannel(axis);
if (shuffle_channel) {
if (param.InputX()->type() == type_id<int8_t>().name()) {
if (param.InputX()->type() == type_id<int8_t>().hash_code()) {
ShuffleChannelCompute<int8_t>(param);
} else {
ShuffleChannelCompute<float>(param);
}
} else {
if (param.InputX()->type() == type_id<int8_t>().name()) {
if (param.InputX()->type() == type_id<int8_t>().hash_code()) {
Transpose2Compute<int8_t>(param);
} else {
Transpose2Compute<float>(param);
......
......@@ -37,7 +37,7 @@ void MulCompute(const MulParam<CPU> &param) {
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
if (param.InputX()->type() == type_id<int8_t>().name()) {
if (param.InputX()->type() == type_id<int8_t>().hash_code()) {
out->mutable_data<int32_t>();
math::MatMul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
......
......@@ -144,8 +144,7 @@ void SumCompute(const SumParam<CPU> &param) {
}
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
"Unexpected branch, output variable type is %s",
outvar->Type().c_str());
"Unexpected branch, output variable type is %d", outvar->Type());
}
}
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册