提交 7701ac0d 编写于 作者: H hjchen2

Use enum type rather than string since comparing between enum is more efficent than string

上级 2852680a
...@@ -14,28 +14,85 @@ limitations under the License. */ ...@@ -14,28 +14,85 @@ limitations under the License. */
#pragma once #pragma once
#include <functional>
#include <string> #include <string>
#include <vector> #include <vector>
namespace paddle_mobile { namespace paddle_mobile {
typedef enum {
_void = 0,
_float,
_int,
_double,
_int64_t,
_size_t,
_int16_t,
_int8_t,
_uint8_t,
_bool,
_string,
_floats = 100,
_ints,
_int64_ts,
_size_ts,
_bools,
_strings,
_const_float = 200,
_const_int,
_block = 300,
_tensor,
_lod_tensor,
_blocks,
_tensors,
_lod_tensors,
_p_block = 400,
_p_tensor,
_p_lod_tensor,
_p_blocks,
_p_tensors,
_p_lod_tensors,
_scopes = 500,
_selected_rows,
_dim0 = 600,
_dim1,
_dim2,
_dim3,
_dim4,
_dim5,
_dim6,
_dim7,
_dim8,
_dim9,
} kTypeId_t;
template <typename T> template <typename T>
struct TypeIdWrapper { struct TypeIdWrapper {
std::string type(); inline std::string name();
inline kTypeId_t hash_code();
}; };
template <typename T> template <typename T>
struct type_id { struct type_id {
const std::string type_ = TypeIdWrapper<T>().type(); const kTypeId_t hash_code() const { return TypeIdWrapper<T>().hash_code(); }
const std::string name() const { return TypeIdWrapper<T>().name(); }
template <typename OtherType> template <typename OtherType>
bool operator==(const type_id<OtherType> &operand) { bool operator==(const type_id<OtherType> &operand) const {
return this->name() == operand.name(); return this->hash_code() == operand.hash_code();
} }
const std::string name() { return type_; }
}; };
template <typename T>
inline bool operator==(const kTypeId_t &t0, const type_id<T> &t1) {
return t0 == t1.hash_code();
}
template <typename T>
inline bool operator==(const type_id<T> &t0, const kTypeId_t &t1) {
return t1 == t0.hash_code();
}
namespace framework { namespace framework {
class BlockDesc; class BlockDesc;
class Tensor; class Tensor;
...@@ -50,7 +107,8 @@ struct Dim; ...@@ -50,7 +107,8 @@ struct Dim;
#define REGISTER_TYPE_ID(Type, TypeName) \ #define REGISTER_TYPE_ID(Type, TypeName) \
template <> \ template <> \
struct TypeIdWrapper<Type> { \ struct TypeIdWrapper<Type> { \
std::string type() { return std::string(#TypeName); } \ inline std::string name() { return std::string(#TypeName); } \
inline kTypeId_t hash_code() { return kTypeId_t::TypeName; } \
}; };
REGISTER_TYPE_ID(void, _void) REGISTER_TYPE_ID(void, _void)
...@@ -102,3 +160,14 @@ REGISTER_TYPE_ID(framework::Dim<8>, _dim8) ...@@ -102,3 +160,14 @@ REGISTER_TYPE_ID(framework::Dim<8>, _dim8)
REGISTER_TYPE_ID(framework::Dim<9>, _dim9) REGISTER_TYPE_ID(framework::Dim<9>, _dim9)
} // namespace paddle_mobile } // namespace paddle_mobile
namespace std {
template <>
struct hash<paddle_mobile::kTypeId_t> {
size_t operator()(const paddle_mobile::kTypeId_t &t) const {
return std::hash<int>{}(static_cast<int>(t));
}
};
} // namespace std
...@@ -34,8 +34,8 @@ struct VariantHelper { ...@@ -34,8 +34,8 @@ struct VariantHelper {
? sizeof(F) ? sizeof(F)
: VariantHelper<Ts...>::size; : VariantHelper<Ts...>::size;
inline static void Destroy(std::string type, void *data) { inline static void Destroy(kTypeId_t type, void *data) {
if (type == type_id<F>().name()) { if (type == type_id<F>()) {
reinterpret_cast<F *>(data)->~F(); reinterpret_cast<F *>(data)->~F();
} else { } else {
VariantHelper<Ts...>::Destroy(type, data); VariantHelper<Ts...>::Destroy(type, data);
...@@ -46,8 +46,8 @@ struct VariantHelper { ...@@ -46,8 +46,8 @@ struct VariantHelper {
template <typename F> template <typename F>
struct VariantHelper<F> { struct VariantHelper<F> {
static const size_t size = sizeof(F); static const size_t size = sizeof(F);
inline static void Destroy(std::string type, void *data) { inline static void Destroy(kTypeId_t type, void *data) {
if (type == type_id<F>().name()) { if (type == type_id<F>()) {
// reinterpret_cast<F*>(data)->~F(); // reinterpret_cast<F*>(data)->~F();
} else { } else {
// std::cout << "未匹配到 " << std::endl; // std::cout << "未匹配到 " << std::endl;
...@@ -85,17 +85,17 @@ struct Variant { ...@@ -85,17 +85,17 @@ struct Variant {
void Set(Args &&... args) { void Set(Args &&... args) {
helper::Destroy(type_, data_.data); helper::Destroy(type_, data_.data);
new (data_.data) T(std::forward<Args>(args)...); new (data_.data) T(std::forward<Args>(args)...);
type_ = type_id<T>().name(); type_ = type_id<T>().hash_code();
} }
void SetString(const std::string &string) { void SetString(const std::string &string) {
helper::Destroy(type_, data_.data); helper::Destroy(type_, data_.data);
type_ = type_id<std::string>().name(); type_ = type_id<std::string>().hash_code();
strcpy(data_.data, string.c_str()); // NOLINT strcpy(data_.data, string.c_str()); // NOLINT
} }
std::string GetString() const { std::string GetString() const {
if (type_ == type_id<std::string>().name()) { if (type_ == type_id<std::string>()) {
return std::string(data_.data); return std::string(data_.data);
} else { } else {
PADDLE_MOBILE_THROW_EXCEPTION( PADDLE_MOBILE_THROW_EXCEPTION(
...@@ -106,7 +106,7 @@ struct Variant { ...@@ -106,7 +106,7 @@ struct Variant {
template <typename T> template <typename T>
T &Get() const { T &Get() const {
if (type_ == type_id<std::string>().name()) { if (type_ == type_id<std::string>()) {
PADDLE_MOBILE_THROW_EXCEPTION( PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with " "Please use getString to get an string (to avoid of an issue with "
"gcc " "gcc "
...@@ -117,12 +117,12 @@ struct Variant { ...@@ -117,12 +117,12 @@ struct Variant {
} }
} }
std::string TypeId() const { return type_; } kTypeId_t TypeId() const { return type_; }
private: private:
static inline std::string invalid_type() { return type_id<void>().name(); } static inline kTypeId_t invalid_type() { return type_id<void>().hash_code(); }
typedef VariantHelper<Ts...> helper; typedef VariantHelper<Ts...> helper;
std::string type_ = type_id<void>().name(); kTypeId_t type_ = type_id<void>().hash_code();
// todo use an anto size to suite this. // todo use an anto size to suite this.
RawData<64> data_; RawData<64> data_;
}; };
......
...@@ -128,31 +128,30 @@ class Attribute { ...@@ -128,31 +128,30 @@ class Attribute {
template <typename Vistor> template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) { static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) {
if (attr.variant_.TypeId() == type_id<int>().name()) { // NOLINT if (attr.variant_.TypeId() == type_id<int>()) { // NOLINT
return vistor(attr.variant_.Get<int>()); return vistor(attr.variant_.Get<int>());
} else if (attr.variant_.TypeId() == type_id<float>().name()) { // NOLINT } else if (attr.variant_.TypeId() == type_id<float>()) { // NOLINT
return vistor(attr.variant_.Get<float>()); return vistor(attr.variant_.Get<float>());
} else if (attr.variant_.TypeId() == type_id<string>().name()) { } else if (attr.variant_.TypeId() == type_id<string>()) {
return vistor(attr.variant_.GetString()); return vistor(attr.variant_.GetString());
} else if (attr.variant_.TypeId() == type_id<vector<int>>().name()) { } else if (attr.variant_.TypeId() == type_id<vector<int>>()) {
return vistor(attr.variant_.Get<vector<int>>()); return vistor(attr.variant_.Get<vector<int>>());
} else if (attr.variant_.TypeId() == type_id<vector<float>>().name()) { } else if (attr.variant_.TypeId() == type_id<vector<float>>()) {
return vistor(attr.variant_.Get<vector<float>>()); return vistor(attr.variant_.Get<vector<float>>());
} else if (attr.variant_.TypeId() == type_id<vector<string>>().name()) { } else if (attr.variant_.TypeId() == type_id<vector<string>>()) {
return vistor(attr.variant_.Get<vector<string>>()); return vistor(attr.variant_.Get<vector<string>>());
} else if (attr.variant_.TypeId() == type_id<bool>().name()) { // NOLINT } else if (attr.variant_.TypeId() == type_id<bool>()) { // NOLINT
return vistor(attr.variant_.Get<bool>()); return vistor(attr.variant_.Get<bool>());
} else if (attr.variant_.TypeId() == type_id<vector<bool>>().name()) { } else if (attr.variant_.TypeId() == type_id<vector<bool>>()) {
return vistor(attr.variant_.Get<vector<bool>>()); return vistor(attr.variant_.Get<vector<bool>>());
} else if (attr.variant_.TypeId() == type_id<int64_t>().name()) { } else if (attr.variant_.TypeId() == type_id<int64_t>()) {
return vistor(attr.variant_.Get<int64_t>()); return vistor(attr.variant_.Get<int64_t>());
} else if (attr.variant_.TypeId() == } else if (attr.variant_.TypeId() == type_id<framework::BlockDesc *>()) {
type_id<framework::BlockDesc *>().name()) {
return vistor(attr.variant_.Get<framework::BlockDesc *>()); return vistor(attr.variant_.Get<framework::BlockDesc *>());
} else if (attr.variant_.TypeId() == } else if (attr.variant_.TypeId() ==
type_id<vector<framework::BlockDesc *>>().name()) { type_id<vector<framework::BlockDesc *>>()) {
return vistor(attr.variant_.Get<vector<framework::BlockDesc *>>()); return vistor(attr.variant_.Get<vector<framework::BlockDesc *>>());
} else if (attr.variant_.TypeId() == type_id<vector<int64_t>>().name()) { } else if (attr.variant_.TypeId() == type_id<vector<int64_t>>()) {
return vistor(attr.variant_.Get<vector<int64_t>>()); return vistor(attr.variant_.Get<vector<int64_t>>());
} else { } else {
PADDLE_MOBILE_THROW_EXCEPTION("type not support"); PADDLE_MOBILE_THROW_EXCEPTION("type not support");
......
...@@ -22,12 +22,11 @@ namespace paddle_mobile { ...@@ -22,12 +22,11 @@ namespace paddle_mobile {
namespace framework { namespace framework {
struct DataTypeMap { struct DataTypeMap {
std::unordered_map<std::string, std::unordered_map<kTypeId_t, _PaddleMobile__Framework__Proto__VarType__Type>
_PaddleMobile__Framework__Proto__VarType__Type>
cpp_to_proto_; cpp_to_proto_;
std::unordered_map<int, std::string> proto_to_cpp_; std::unordered_map<int, kTypeId_t> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_; std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::string, size_t> cpp_to_size_; std::unordered_map<kTypeId_t, size_t> cpp_to_size_;
}; };
static DataTypeMap* InitDataTypeMap(); static DataTypeMap* InitDataTypeMap();
...@@ -43,10 +42,11 @@ template <typename T> ...@@ -43,10 +42,11 @@ template <typename T>
static inline void RegisterType( static inline void RegisterType(
DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type, DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type,
const std::string& name) { const std::string& name) {
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), type_id<T>().name()); map->proto_to_cpp_.emplace(static_cast<int>(proto_type),
map->cpp_to_proto_.emplace(type_id<T>().name(), proto_type); type_id<T>().hash_code());
map->cpp_to_proto_.emplace(type_id<T>().hash_code(), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name); map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(type_id<T>().name(), sizeof(T)); map->cpp_to_size_.emplace(type_id<T>().hash_code(), sizeof(T));
} }
static DataTypeMap* InitDataTypeMap() { static DataTypeMap* InitDataTypeMap() {
...@@ -71,15 +71,15 @@ static DataTypeMap* InitDataTypeMap() { ...@@ -71,15 +71,15 @@ static DataTypeMap* InitDataTypeMap() {
return retv; return retv;
} }
_PaddleMobile__Framework__Proto__VarType__Type ToDataType(std::string type) { _PaddleMobile__Framework__Proto__VarType__Type ToDataType(kTypeId_t type) {
auto it = gDataTypeMap().cpp_to_proto_.find(type); auto it = gDataTypeMap().cpp_to_proto_.find(type);
if (it != gDataTypeMap().cpp_to_proto_.end()) { if (it != gDataTypeMap().cpp_to_proto_.end()) {
return it->second; return it->second;
} }
PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.c_str()); PADDLE_MOBILE_THROW_EXCEPTION("Not support %d as tensor type", type);
} }
std::string ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type) { kTypeId_t ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type) {
auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type)); auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_cpp_.end()) { if (it != gDataTypeMap().proto_to_cpp_.end()) {
return it->second; return it->second;
......
...@@ -16,16 +16,16 @@ limitations under the License. */ ...@@ -16,16 +16,16 @@ limitations under the License. */
#include <string> #include <string>
#include "common/enforce.h" #include "common/enforce.h"
#include "common/type_define.h"
#include "framework/framework.pb-c.h" #include "framework/framework.pb-c.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType( _PaddleMobile__Framework__Proto__VarType__Type ToDataType(kTypeId_t type);
std::string type);
extern std::string ToTypeIndex( kTypeId_t ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type);
_PaddleMobile__Framework__Proto__VarType__Type type);
inline _PaddleMobile__Framework__Proto__VarType__Type ToDataType(int type) { inline _PaddleMobile__Framework__Proto__VarType__Type ToDataType(int type) {
return static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(type); return static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(type);
......
...@@ -40,25 +40,25 @@ struct DDim { ...@@ -40,25 +40,25 @@ struct DDim {
template <typename Vistor> template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) { static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) {
if (d.var.TypeId() == type_id<Dim<0>>().name()) { if (d.var.TypeId() == type_id<Dim<0>>()) {
return vistor(d.var.Get<Dim<0>>()); return vistor(d.var.Get<Dim<0>>());
} else if (d.var.TypeId() == type_id<Dim<1>>().name()) { } else if (d.var.TypeId() == type_id<Dim<1>>()) {
return vistor(d.var.Get<Dim<1>>()); return vistor(d.var.Get<Dim<1>>());
} else if (d.var.TypeId() == type_id<Dim<2>>().name()) { } else if (d.var.TypeId() == type_id<Dim<2>>()) {
return vistor(d.var.Get<Dim<2>>()); return vistor(d.var.Get<Dim<2>>());
} else if (d.var.TypeId() == type_id<Dim<3>>().name()) { } else if (d.var.TypeId() == type_id<Dim<3>>()) {
return vistor(d.var.Get<Dim<3>>()); return vistor(d.var.Get<Dim<3>>());
} else if (d.var.TypeId() == type_id<Dim<4>>().name()) { } else if (d.var.TypeId() == type_id<Dim<4>>()) {
return vistor(d.var.Get<Dim<4>>()); return vistor(d.var.Get<Dim<4>>());
} else if (d.var.TypeId() == type_id<Dim<5>>().name()) { } else if (d.var.TypeId() == type_id<Dim<5>>()) {
return vistor(d.var.Get<Dim<5>>()); return vistor(d.var.Get<Dim<5>>());
} else if (d.var.TypeId() == type_id<Dim<6>>().name()) { } else if (d.var.TypeId() == type_id<Dim<6>>()) {
return vistor(d.var.Get<Dim<6>>()); return vistor(d.var.Get<Dim<6>>());
} else if (d.var.TypeId() == type_id<Dim<7>>().name()) { } else if (d.var.TypeId() == type_id<Dim<7>>()) {
return vistor(d.var.Get<Dim<7>>()); return vistor(d.var.Get<Dim<7>>());
} else if (d.var.TypeId() == type_id<Dim<8>>().name()) { } else if (d.var.TypeId() == type_id<Dim<8>>()) {
return vistor(d.var.Get<Dim<8>>()); return vistor(d.var.Get<Dim<8>>());
} else if (d.var.TypeId() == type_id<Dim<9>>().name()) { } else if (d.var.TypeId() == type_id<Dim<9>>()) {
return vistor(d.var.Get<Dim<9>>()); return vistor(d.var.Get<Dim<9>>());
} else { } else {
PADDLE_MOBILE_ENFORCE(false, " dim not support"); PADDLE_MOBILE_ENFORCE(false, " dim not support");
......
...@@ -268,6 +268,9 @@ LOAD_OP1(sequence_expand, CPU); ...@@ -268,6 +268,9 @@ LOAD_OP1(sequence_expand, CPU);
#ifdef SEQUENCE_POOL_OP #ifdef SEQUENCE_POOL_OP
LOAD_OP1(sequence_pool, CPU); LOAD_OP1(sequence_pool, CPU);
#endif #endif
#ifdef SEQUENCE_SOFTMAX_OP
LOAD_OP1(sequence_softmax, CPU);
#endif
#ifdef LOG_OP #ifdef LOG_OP
LOAD_OP1(log, CPU); LOAD_OP1(log, CPU);
#endif #endif
......
...@@ -197,7 +197,7 @@ class Vector { ...@@ -197,7 +197,7 @@ class Vector {
} }
size_t capacity() const { size_t capacity() const {
return cpu_vec_.memory_size() / SizeOfType(type_id<T>().name()); return cpu_vec_.memory_size() / SizeOfType(type_id<T>().hash_code());
} }
// reserve data // reserve data
......
...@@ -81,7 +81,7 @@ class Tensor : public TensorBase { ...@@ -81,7 +81,7 @@ class Tensor : public TensorBase {
return *this; return *this;
} }
inline void *mutable_data(const std::string type) { inline void *mutable_data(const kTypeId_t type) {
if (holder_ != nullptr) { if (holder_ != nullptr) {
holder_->set_type(type); holder_->set_type(type);
} }
...@@ -106,7 +106,7 @@ class Tensor : public TensorBase { ...@@ -106,7 +106,7 @@ class Tensor : public TensorBase {
template <typename T> template <typename T>
inline T *mutable_data() { inline T *mutable_data() {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T *>(mutable_data(type_id<T>().name())); return reinterpret_cast<T *>(mutable_data(type_id<T>().hash_code()));
} }
/** /**
...@@ -163,9 +163,9 @@ class Tensor : public TensorBase { ...@@ -163,9 +163,9 @@ class Tensor : public TensorBase {
check_memory_size(); check_memory_size();
PADDLE_MOBILE_ENFORCE( PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value || (std::is_same<T, void>::value ||
holder_->type() == type_id<T>().name()), holder_->type() == type_id<T>().hash_code()),
"Tensor holds the wrong type, it holds %s, requested %s", "Tensor holds the wrong type, it holds %d, requested %d",
this->holder_->type().c_str(), type_id<T>().name().c_str()); this->holder_->type(), type_id<T>().hash_code());
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
...@@ -177,9 +177,9 @@ class Tensor : public TensorBase { ...@@ -177,9 +177,9 @@ class Tensor : public TensorBase {
check_memory_size(); check_memory_size();
PADDLE_MOBILE_ENFORCE( PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value || (std::is_same<T, void>::value ||
holder_->type() == type_id<T>().name()), holder_->type() == type_id<T>().hash_code()),
"Tensor holds the wrong type, it holds %s, requested %s", "Tensor holds the wrong type, it holds %d, requested %d",
this->holder_->type().c_str(), type_id<T>().name().c_str()); this->holder_->type(), type_id<T>().hash_code());
return reinterpret_cast<const T *>( return reinterpret_cast<const T *>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
...@@ -187,7 +187,7 @@ class Tensor : public TensorBase { ...@@ -187,7 +187,7 @@ class Tensor : public TensorBase {
private: private:
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t size, const std::string type) PlaceholderImpl(size_t size, const kTypeId_t type)
: ptr_(static_cast<uint8_t *>(memory::Alloc(size)), : ptr_(static_cast<uint8_t *>(memory::Alloc(size)),
memory::PODDeleter<uint8_t>()), memory::PODDeleter<uint8_t>()),
size_(size), size_(size),
...@@ -201,9 +201,9 @@ class Tensor : public TensorBase { ...@@ -201,9 +201,9 @@ class Tensor : public TensorBase {
virtual void *ptr() const { return static_cast<void *>(ptr_.get()); } virtual void *ptr() const { return static_cast<void *>(ptr_.get()); }
virtual std::string type() const { return type_; } virtual kTypeId_t type() const { return type_; }
virtual void set_type(const std::string type) { type_ = type; } virtual void set_type(const kTypeId_t type) { type_ = type; }
virtual void resize(size_t size) { virtual void resize(size_t size) {
if (size > capatity_) { if (size > capatity_) {
...@@ -221,7 +221,7 @@ class Tensor : public TensorBase { ...@@ -221,7 +221,7 @@ class Tensor : public TensorBase {
size_t capatity_; size_t capatity_;
/* the current type of memory */ /* the current type of memory */
std::string type_; kTypeId_t type_;
}; };
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
...@@ -229,13 +229,13 @@ class Tensor : public TensorBase { ...@@ -229,13 +229,13 @@ class Tensor : public TensorBase {
inline void reset_data_ptr(void *p) { inline void reset_data_ptr(void *p) {
((PlaceholderImpl *)(holder_.get()))->ptr_.reset((uint8_t *)p); // NOLINT ((PlaceholderImpl *)(holder_.get()))->ptr_.reset((uint8_t *)p); // NOLINT
} }
inline void set_type(const std::string type) { holder_->set_type(type); } inline void set_type(const kTypeId_t type) { holder_->set_type(type); }
inline void *get_data() { inline void *get_data() {
return ( return (
void *)(((PlaceholderImpl *)(holder_.get()))->ptr_.get()); // NOLINT void *)(((PlaceholderImpl *)(holder_.get()))->ptr_.get()); // NOLINT
} }
inline void *init(const std::string type) { inline void *init(const kTypeId_t type) {
if (holder_ != nullptr) { if (holder_ != nullptr) {
holder_->set_type(type); holder_->set_type(type);
} }
...@@ -263,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) { ...@@ -263,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
stride = stride > 0 ? stride : 1; stride = stride > 0 ? stride : 1;
#ifndef PADDLE_MOBILE_FPGA #ifndef PADDLE_MOBILE_FPGA
for (int i = 0; i < tensor.numel(); i += stride) { for (int i = 0; i < tensor.numel(); i += stride) {
if (tensor.type() == type_id<float>().name()) { if (tensor.type() == type_id<float>()) {
printer << tensor.data<float>()[i] << " "; printer << tensor.data<float>()[i] << " ";
} else if (tensor.type() == type_id<int32_t>().name()) { } else if (tensor.type() == type_id<int32_t>()) {
printer << tensor.data<int32_t>()[i] << " "; printer << tensor.data<int32_t>()[i] << " ";
} else if (tensor.type() == type_id<int64_t>().name()) { } else if (tensor.type() == type_id<int64_t>()) {
printer << tensor.data<int64_t>()[i] << " "; printer << tensor.data<int64_t>()[i] << " ";
} else if (tensor.type() == type_id<int8_t>().name()) { } else if (tensor.type() == type_id<int8_t>()) {
printer << static_cast<int>(tensor.data<int8_t>()[i]) << " "; printer << static_cast<int>(tensor.data<int8_t>()[i]) << " ";
} else if (tensor.type() == type_id<int32_t>().name()) { } else if (tensor.type() == type_id<int32_t>()) {
printer << tensor.data<int32_t>()[i] << " "; printer << tensor.data<int32_t>()[i] << " ";
} }
} }
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "common/enforce.h" #include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h" #include "common/types.h"
#include "framework/ddim.h" #include "framework/ddim.h"
...@@ -27,8 +27,8 @@ struct SizeOfTypeFunctor; ...@@ -27,8 +27,8 @@ struct SizeOfTypeFunctor;
template <typename T> template <typename T>
struct SizeOfTypeFunctor<T> { struct SizeOfTypeFunctor<T> {
size_t operator()(const std::string type) const { size_t operator()(const kTypeId_t type) const {
if (type_id<T>().name() == type) { if (type_id<T>().hash_code() == type) {
return sizeof(T); return sizeof(T);
} else { } else {
return 0UL; return 0UL;
...@@ -38,12 +38,12 @@ struct SizeOfTypeFunctor<T> { ...@@ -38,12 +38,12 @@ struct SizeOfTypeFunctor<T> {
template <> template <>
struct SizeOfTypeFunctor<> { struct SizeOfTypeFunctor<> {
size_t operator()(const std::string type) const { return 0UL; } size_t operator()(const kTypeId_t type) const { return 0UL; }
}; };
template <typename HEAD, typename... TAIL> template <typename HEAD, typename... TAIL>
struct SizeOfTypeFunctor<HEAD, TAIL...> { struct SizeOfTypeFunctor<HEAD, TAIL...> {
size_t operator()(const std::string type) const { size_t operator()(const kTypeId_t type) const {
SizeOfTypeFunctor<HEAD> head; SizeOfTypeFunctor<HEAD> head;
size_t head_size = head(type); size_t head_size = head(type);
if (head_size != 0) { if (head_size != 0) {
...@@ -54,14 +54,13 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> { ...@@ -54,14 +54,13 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
} }
}; };
static inline size_t SizeOfType(std::string type) { static inline size_t SizeOfType(const kTypeId_t type) {
SizeOfTypeFunctor<int8_t, int, half, float, double, int16_t, int64_t, bool, SizeOfTypeFunctor<int8_t, int, half, float, double, int16_t, int64_t, bool,
size_t> size_t>
functor; functor;
size_t size = functor(type); size_t size = functor(type);
PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %s", PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %d", type);
type.c_str());
return size; return size;
} }
...@@ -77,7 +76,7 @@ class TensorBase { ...@@ -77,7 +76,7 @@ class TensorBase {
/*! Return the numel of the memory block. */ /*! Return the numel of the memory block. */
inline int64_t numel() const { return product(dims_); } inline int64_t numel() const { return product(dims_); }
std::string type() const { kTypeId_t type() const {
PADDLE_MOBILE_ENFORCE( PADDLE_MOBILE_ENFORCE(
holder_ != nullptr, holder_ != nullptr,
"Tensor not initialized yet when Tensor::type() is called.") "Tensor not initialized yet when Tensor::type() is called.")
...@@ -113,9 +112,9 @@ class TensorBase { ...@@ -113,9 +112,9 @@ class TensorBase {
virtual size_t size() const = 0; virtual size_t size() const = 0;
virtual std::string type() const = 0; virtual kTypeId_t type() const = 0;
virtual void set_type(std::string type) = 0; virtual void set_type(kTypeId_t type) = 0;
virtual void resize(size_t size) = 0; virtual void resize(size_t size) = 0;
}; };
......
...@@ -30,7 +30,7 @@ class Variable { ...@@ -30,7 +30,7 @@ class Variable {
template <typename T> template <typename T>
const T GetValue() const { const T GetValue() const {
if (type_id<T>().name() == type_id<std::string>().name()) { if (type_id<T>().hash_code() == type_id<std::string>().hash_code()) {
PADDLE_MOBILE_THROW_EXCEPTION( PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with " "Please use getString to get an string (to avoid of an issue with "
"gcc " "gcc "
...@@ -57,31 +57,32 @@ class Variable { ...@@ -57,31 +57,32 @@ class Variable {
template <typename T> template <typename T>
bool IsType() const { bool IsType() const {
return holder_ != nullptr && holder_->Type() == type_id<T>().name(); return holder_ != nullptr && holder_->Type() == type_id<T>().hash_code();
} }
void Clear() { holder_.reset(); } void Clear() { holder_.reset(); }
std::string Type() const { return holder_->Type(); } kTypeId_t Type() const { return holder_->Type(); }
private: private:
struct Placeholder { struct Placeholder {
Placeholder() = default; Placeholder() = default;
virtual ~Placeholder() = default; virtual ~Placeholder() = default;
virtual std::string Type() const = 0; virtual kTypeId_t Type() const = 0;
virtual void *Ptr() const = 0; virtual void *Ptr() const = 0;
}; };
template <typename T> template <typename T>
struct PlaceholderImp : public Placeholder { struct PlaceholderImp : public Placeholder {
explicit PlaceholderImp(T *ptr) : ptr_(ptr), type_(type_id<T>().name()) {} explicit PlaceholderImp(T *ptr)
: ptr_(ptr), type_(type_id<T>().hash_code()) {}
std::string Type() const override { return type_; } kTypeId_t Type() const override { return type_; }
void *Ptr() const override { return static_cast<void *>(ptr_.get()); } void *Ptr() const override { return static_cast<void *>(ptr_.get()); }
std::unique_ptr<T> ptr_; std::unique_ptr<T> ptr_;
std::string type_; kTypeId_t type_;
}; };
friend class Scope; friend class Scope;
......
...@@ -192,10 +192,10 @@ bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) { ...@@ -192,10 +192,10 @@ bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) {
template <> template <>
void LessThanKernel<CPU, float>::Compute(const CompareParam<CPU> &param) { void LessThanKernel<CPU, float>::Compute(const CompareParam<CPU> &param) {
if (param.input_x_->type() == type_id<int64_t>().name()) { if (param.input_x_->type() == type_id<int64_t>().hash_code()) {
CompareCompute<int64_t, LESS_THAN>()(param.input_x_, param.input_y_, CompareCompute<int64_t, LESS_THAN>()(param.input_x_, param.input_y_,
param.axis_, param.output_); param.axis_, param.output_);
} else if (param.input_x_->type() == type_id<float>().name()) { } else if (param.input_x_->type() == type_id<float>().hash_code()) {
CompareCompute<float, LESS_THAN>()(param.input_x_, param.input_y_, CompareCompute<float, LESS_THAN>()(param.input_x_, param.input_y_,
param.axis_, param.output_); param.axis_, param.output_);
} else { } else {
......
...@@ -27,7 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) { ...@@ -27,7 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
template <> template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam<CPU> &param) { void ConcatKernel<CPU, float>::Compute(const ConcatParam<CPU> &param) {
if (param.Inputs()[0]->type() == type_id<int8_t>().name()) { if (param.Inputs()[0]->type() == type_id<int8_t>().hash_code()) {
ConcatCompute<int8_t>(param); ConcatCompute<int8_t>(param);
} else { } else {
ConcatCompute<float>(param); ConcatCompute<float>(param);
......
...@@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) { ...@@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] && bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1]; param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == type_id<int8_t>().name()) { if (param->Filter()->type() == type_id<int8_t>().hash_code()) {
#ifndef __aarch64__ #ifndef __aarch64__
if (depth3x3 && param->Strides()[0] < 3 && if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) { param->Strides()[0] == param->Strides()[1]) {
......
...@@ -126,13 +126,13 @@ void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> &param) { ...@@ -126,13 +126,13 @@ void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> &param) {
const std::vector<int> &axis = param.Axis(); const std::vector<int> &axis = param.Axis();
bool shuffle_channel = IsShuffleChannel(axis); bool shuffle_channel = IsShuffleChannel(axis);
if (shuffle_channel) { if (shuffle_channel) {
if (param.InputX()->type() == type_id<int8_t>().name()) { if (param.InputX()->type() == type_id<int8_t>().hash_code()) {
ShuffleChannelCompute<int8_t>(param); ShuffleChannelCompute<int8_t>(param);
} else { } else {
ShuffleChannelCompute<float>(param); ShuffleChannelCompute<float>(param);
} }
} else { } else {
if (param.InputX()->type() == type_id<int8_t>().name()) { if (param.InputX()->type() == type_id<int8_t>().hash_code()) {
Transpose2Compute<int8_t>(param); Transpose2Compute<int8_t>(param);
} else { } else {
Transpose2Compute<float>(param); Transpose2Compute<float>(param);
......
...@@ -37,7 +37,7 @@ void MulCompute(const MulParam<CPU> &param) { ...@@ -37,7 +37,7 @@ void MulCompute(const MulParam<CPU> &param) {
if (out_dim.size() != 2) { if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
} }
if (param.InputX()->type() == type_id<int8_t>().name()) { if (param.InputX()->type() == type_id<int8_t>().hash_code()) {
out->mutable_data<int32_t>(); out->mutable_data<int32_t>();
math::MatMul<int8_t, int32_t>(x_matrix, false, y_matrix, false, math::MatMul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out, static_cast<float>(1), out,
......
...@@ -144,8 +144,7 @@ void SumCompute(const SumParam<CPU> &param) { ...@@ -144,8 +144,7 @@ void SumCompute(const SumParam<CPU> &param) {
} }
} else { } else {
PADDLE_MOBILE_THROW_EXCEPTION( PADDLE_MOBILE_THROW_EXCEPTION(
"Unexpected branch, output variable type is %s", "Unexpected branch, output variable type is %d", outvar->Type());
outvar->Type().c_str());
} }
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册