Commit 6fd788f7, authored by: H hjchen2

Replace C++ typeid with a self-defined type_id to avoid implementation differences

Parent 288148fe
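Background for the change: both `typeid(T).name()` and `typeid(T).hash_code()` are implementation-defined, so type tags produced by one compiler/standard library are not guaranteed to match another's, or even another build's. Registering a fixed string per type removes that variability. A minimal illustration of the problem — this snippet is not part of the patch:

#include <iostream>
#include <typeinfo>

struct Tensor {};

int main() {
  // Implementation-defined output: "6Tensor" under the Itanium ABI
  // (GCC/Clang), "struct Tensor" under MSVC; hash_code() may likewise
  // differ across standard libraries and builds.
  std::cout << typeid(Tensor).name() << std::endl;
  return 0;
}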
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_EXECUTOR_MULTITHREAD

#include <algorithm>  // for std::find
#include <memory>     // for std::shared_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "framework/operator.h"

namespace paddle_mobile {

class depCore {
 public:
  template <typename Dtype>
  void analysisDep(
      const std::vector<std::shared_ptr<framework::OperatorBase<Dtype>>>&
          ops) {
    std::unordered_map<std::string, int> vars;
    size_t nop = ops.size();
    deps.resize(nop);
    next.resize(nop);
    for (size_t i = 0; i < nop; i++) {
      const auto& op = ops[i];
      // An op depends on the latest writer of each of its inputs.
      for (const auto& kv : op->Inputs()) {
        for (const auto& v : kv.second) {
          if (vars.find(v) == vars.end()) {
            continue;
          }
          int di = vars[v];
          if (di == i) {
            continue;
          }
          if (std::find(deps[i].begin(), deps[i].end(), di) !=
              deps[i].end()) {
            continue;
          }
          deps[i].push_back(di);
          next[di].push_back(i);
        }
      }
      // Record this op as the latest writer of each of its outputs.
      for (const auto& kv : op->Outputs()) {
        for (const auto& v : kv.second) {
          vars[v] = i;
        }
      }
    }
  }
  const std::vector<int>& getNext(int i) { return next[i]; }
  const std::vector<int>& getDeps(int i) { return deps[i]; }
  std::vector<std::vector<int>> deps;
  std::vector<std::vector<int>> next;
};

}  // namespace paddle_mobile

#endif
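For reference, the analysis above builds two adjacency lists: deps[i] holds the ops whose outputs op i consumes, and next[i] holds the ops unblocked once op i finishes. A standalone sketch of the same bookkeeping, with a hypothetical MiniOp standing in for framework::OperatorBase:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for OperatorBase: just named inputs/outputs.
struct MiniOp {
  std::vector<std::string> inputs, outputs;
};

int main() {
  // feed -> conv -> relu: each op consumes what the previous one produces.
  std::vector<MiniOp> ops = {{{}, {"x"}}, {{"x"}, {"y"}}, {{"y"}, {"z"}}};
  std::unordered_map<std::string, int> last_writer;
  std::vector<std::vector<int>> deps(ops.size()), next(ops.size());
  for (int i = 0; i < static_cast<int>(ops.size()); ++i) {
    for (const auto& v : ops[i].inputs) {
      auto it = last_writer.find(v);
      if (it != last_writer.end() && it->second != i) {
        deps[i].push_back(it->second);  // op i waits for the writer
        next[it->second].push_back(i);  // the writer unblocks op i
      }
    }
    for (const auto& v : ops[i].outputs) last_writer[v] = i;
  }
  std::cout << "op 2 depends on op " << deps[2][0] << std::endl;  // prints 1
  return 0;
}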
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#define EXPORT __attribute__((visibility("default")))
@@ -14,33 +14,91 @@ limitations under the License. */
 #pragma once

-#include <functional>
-#include <map>
 #include <string>
 #include <vector>
-#include "framework/attribute.h"
-#include "framework/scope.h"

 namespace paddle_mobile {

+template <typename T>
+struct TypeIdWrapper {
+  std::string type();
+};
+
+template <typename T>
+struct type_id {
+  const std::string type_ = TypeIdWrapper<T>().type();
+
+  template <typename OtherType>
+  bool operator==(const type_id<OtherType> &operand) {
+    return this->name() == operand.name();
+  }
+
+  const std::string name() { return type_; }
+};
+
 namespace framework {
-template <typename Dtype>
-class OperatorBase;
-class OpDesc;
 class BlockDesc;
-class InferShapeContext;
+class Tensor;
+class LoDTensor;
+class SelectedRows;
+class Scope;
+template <int>
+struct Dim;
 }  // namespace framework

-using VariableNameMap = std::map<std::string, std::vector<std::string>>;
-
-template <typename Dtype>
-using OpCreator = std::function<framework::OperatorBase<Dtype> *(
-    const std::string & /*type*/, const VariableNameMap & /*inputs*/,
-    const VariableNameMap & /*outputs*/,
-    const framework::AttributeMap & /*attrs*/, framework::Scope * /*scope*/)>;
-
-using InferVarTypeFN = std::function<void(const framework::OpDesc & /*op_desc*/,
-                                          framework::BlockDesc * /*block*/)>;
-
-using InferShapeFN = std::function<void(framework::InferShapeContext *)>;
-};  // namespace paddle_mobile
+#define REGISTER_TYPE_ID(Type, TypeName)                  \
+  template <>                                             \
+  struct TypeIdWrapper<Type> {                            \
+    std::string type() { return std::string(#TypeName); } \
+  };
+
+REGISTER_TYPE_ID(void, _void)
+REGISTER_TYPE_ID(float, _float)
+REGISTER_TYPE_ID(int, _int)
+REGISTER_TYPE_ID(double, _double)
+REGISTER_TYPE_ID(int64_t, _int64_t)
+REGISTER_TYPE_ID(size_t, _size_t)
+REGISTER_TYPE_ID(int16_t, _int16_t)
+REGISTER_TYPE_ID(int8_t, _int8_t)
+REGISTER_TYPE_ID(uint8_t, _uint8_t)
+REGISTER_TYPE_ID(bool, _bool)
+REGISTER_TYPE_ID(std::string, _string)
+REGISTER_TYPE_ID(std::vector<float>, _floats)
+REGISTER_TYPE_ID(std::vector<int>, _ints)
+REGISTER_TYPE_ID(std::vector<int64_t>, _int64_ts)
+REGISTER_TYPE_ID(std::vector<size_t>, _size_ts)
+REGISTER_TYPE_ID(std::vector<bool>, _bools)
+REGISTER_TYPE_ID(std::vector<std::string>, _strings)
+REGISTER_TYPE_ID(float const, _const_float)
+REGISTER_TYPE_ID(int const, _const_int)
+REGISTER_TYPE_ID(framework::BlockDesc, _block)
+REGISTER_TYPE_ID(framework::Tensor, _tensor)
+REGISTER_TYPE_ID(framework::LoDTensor, _lod_tensor)
+REGISTER_TYPE_ID(std::vector<framework::BlockDesc>, _blocks)
+REGISTER_TYPE_ID(std::vector<framework::Tensor>, _tensors)
+REGISTER_TYPE_ID(std::vector<framework::LoDTensor>, _lod_tensors)
+REGISTER_TYPE_ID(framework::BlockDesc *, _p_block)
+REGISTER_TYPE_ID(framework::Tensor *, _p_tensor)
+REGISTER_TYPE_ID(framework::LoDTensor *, _p_lod_tensor)
+REGISTER_TYPE_ID(std::vector<framework::BlockDesc *>, _p_blocks)
+REGISTER_TYPE_ID(std::vector<framework::Tensor *>, _p_tensors)
+REGISTER_TYPE_ID(std::vector<framework::LoDTensor *>, _p_lod_tensors)
+REGISTER_TYPE_ID(std::vector<framework::Scope *>, _scopes);
+REGISTER_TYPE_ID(framework::SelectedRows, _selected_rows)
+REGISTER_TYPE_ID(framework::Dim<0>, _dim0)
+REGISTER_TYPE_ID(framework::Dim<1>, _dim1)
+REGISTER_TYPE_ID(framework::Dim<2>, _dim2)
+REGISTER_TYPE_ID(framework::Dim<3>, _dim3)
+REGISTER_TYPE_ID(framework::Dim<4>, _dim4)
+REGISTER_TYPE_ID(framework::Dim<5>, _dim5)
+REGISTER_TYPE_ID(framework::Dim<6>, _dim6)
+REGISTER_TYPE_ID(framework::Dim<7>, _dim7)
+REGISTER_TYPE_ID(framework::Dim<8>, _dim8)
+REGISTER_TYPE_ID(framework::Dim<9>, _dim9)
+
+}  // namespace paddle_mobile
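The mechanism above, reduced to a runnable sketch (registrations written out by hand instead of via the macro; name() is marked const here so the comparison compiles through a const reference): the primary TypeIdWrapper only declares type(), so two type_id values compare stable, hand-assigned strings, and an unregistered type fails at link time rather than silently differing between toolchains.

#include <iostream>
#include <string>

template <typename T>
struct TypeIdWrapper {
  std::string type();  // declared only: unregistered types fail to link
};
template <>
struct TypeIdWrapper<float> {
  std::string type() { return "_float"; }
};
template <>
struct TypeIdWrapper<int> {
  std::string type() { return "_int"; }
};

template <typename T>
struct type_id {
  const std::string type_ = TypeIdWrapper<T>().type();
  template <typename Other>
  bool operator==(const type_id<Other> &other) const {
    return this->name() == other.name();
  }
  std::string name() const { return type_; }
};

int main() {
  std::cout << type_id<float>().name() << std::endl;               // _float
  std::cout << (type_id<float>() == type_id<int>()) << std::endl;  // 0
  return 0;
}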
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

+#include <map>
 #include <string>
 #include <unordered_map>
 #include <utility>
@@ -211,4 +212,6 @@ extern std::unordered_map<
     std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
     op_input_output_key;

+typedef std::map<std::string, std::vector<std::string>> VariableNameMap;
+
 }  // namespace paddle_mobile
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <string>
 #include "common/enforce.h"
 #include "common/log.h"
+#include "common/type_define.h"

 namespace paddle_mobile {
@@ -33,11 +34,11 @@ struct VariantHelper {
                                 ? sizeof(F)
                                 : VariantHelper<Ts...>::size;

-  inline static void Destroy(size_t id, void *data) {
-    if (id == typeid(F).hash_code()) {
+  inline static void Destroy(std::string type, void *data) {
+    if (type == type_id<F>().name()) {
       reinterpret_cast<F *>(data)->~F();
     } else {
-      VariantHelper<Ts...>::Destroy(id, data);
+      VariantHelper<Ts...>::Destroy(type, data);
     }
   }
 };
@@ -45,11 +46,11 @@ struct VariantHelper {
 template <typename F>
 struct VariantHelper<F> {
   static const size_t size = sizeof(F);
-  inline static void Destroy(size_t id, void *data) {
-    if (id == typeid(F).hash_code()) {
+  inline static void Destroy(std::string type, void *data) {
+    if (type == type_id<F>().name()) {
       // reinterpret_cast<F*>(data)->~F();
     } else {
       // std::cout << "no match found" << std::endl;
     }
   }
 };
@@ -57,7 +58,7 @@ struct VariantHelper<F> {
 template <size_t size>
 class RawData {
  public:
-  char data[size];
+  char data[size];  // NOLINT
   RawData() {}
   RawData(const RawData &raw_data) { memcpy(data, raw_data.data, size); }
@@ -69,32 +70,33 @@ class RawData {
 template <typename... Ts>
 struct Variant {
+  Variant() : type_(invalid_type()) {}
+
   Variant(const Variant &variant) {
-    type_id = variant.type_id;
-    data = variant.data;
+    type_ = variant.type_;
+    data_ = variant.data_;
   }

-  Variant() : type_id(invalid_type()) {}
-  ~Variant() {
+  virtual ~Variant() {
     // helper::Destroy(type_id, &data);
   }

   template <typename T, typename... Args>
   void Set(Args &&... args) {
-    helper::Destroy(type_id, data.data);
-    new (data.data) T(std::forward<Args>(args)...);
-    type_id = typeid(T).hash_code();
+    helper::Destroy(type_, data_.data);
+    new (data_.data) T(std::forward<Args>(args)...);
+    type_ = type_id<T>().name();
   }

-  void SetString(std::string &string) {
-    helper::Destroy(type_id, data.data);
-    type_id = typeid(std::string).hash_code();
-    strcpy(data.data, string.c_str());
+  void SetString(const std::string &string) {
+    helper::Destroy(type_, data_.data);
+    type_ = type_id<std::string>().name();
+    strcpy(data_.data, string.c_str());  // NOLINT
   }

   std::string GetString() const {
-    if (type_id == typeid(std::string).hash_code()) {
-      return std::string(data.data);
+    if (type_ == type_id<std::string>().name()) {
+      return std::string(data_.data);
     } else {
       PADDLE_MOBILE_THROW_EXCEPTION(
           " bad cast in variant data type not a string ");
@@ -104,28 +106,25 @@ struct Variant {
   template <typename T>
   T &Get() const {
-    if (type_id == typeid(std::string).hash_code()) {
+    if (type_ == type_id<std::string>().name()) {
       PADDLE_MOBILE_THROW_EXCEPTION(
           "Please use getString to get an string (to avoid of an issue with "
           "gcc "
           "stl lib with string copy)");
       exit(0);
-    } else if (type_id == typeid(T).hash_code()) {
-      return *const_cast<T *>(reinterpret_cast<const T *>(data.data));
     } else {
-      PADDLE_MOBILE_THROW_EXCEPTION(" bad cast in variant");
-      exit(0);
+      return *const_cast<T *>(reinterpret_cast<const T *>(data_.data));
     }
   }

-  size_t TypeId() const { return type_id; }
+  std::string TypeId() const { return type_; }

  private:
-  static inline size_t invalid_type() { return typeid(void).hash_code(); }
+  static inline std::string invalid_type() { return type_id<void>().name(); }
   typedef VariantHelper<Ts...> helper;
-  size_t type_id;
+  std::string type_ = type_id<void>().name();
   // todo: use an auto size to suit this.
-  RawData<64> data;
+  RawData<64> data_;
 };

 template <typename T>
......
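Usage of the reworked Variant, as a hedged sketch (it assumes the repository's include paths and the patched headers above):

#include <iostream>
#include <string>
#include "common/variant.h"  // paddle_mobile's Variant, patched as above

int main() {
  paddle_mobile::Variant<int, float, std::string> v;
  v.Set<int>(42);
  // The tag is now a registered name string, not typeid(...).hash_code().
  if (v.TypeId() == paddle_mobile::type_id<int>().name()) {
    std::cout << v.Get<int>() << std::endl;  // prints 42
  }
  return 0;
}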
@@ -128,31 +128,31 @@ class Attribute {
   template <typename Vistor>
   static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) {
-    if (attr.variant_.TypeId() == typeid(int).hash_code()) {  // NOLINT
+    if (attr.variant_.TypeId() == type_id<int>().name()) {  // NOLINT
       return vistor(attr.variant_.Get<int>());
-    } else if (attr.variant_.TypeId() == typeid(float).hash_code()) {  // NOLINT
+    } else if (attr.variant_.TypeId() == type_id<float>().name()) {  // NOLINT
       return vistor(attr.variant_.Get<float>());
-    } else if (attr.variant_.TypeId() == typeid(string).hash_code()) {
+    } else if (attr.variant_.TypeId() == type_id<string>().name()) {
       return vistor(attr.variant_.GetString());
-    } else if (attr.variant_.TypeId() == typeid(vector<int>).hash_code()) {
+    } else if (attr.variant_.TypeId() == type_id<vector<int>>().name()) {
       return vistor(attr.variant_.Get<vector<int>>());
-    } else if (attr.variant_.TypeId() == typeid(vector<float>).hash_code()) {
+    } else if (attr.variant_.TypeId() == type_id<vector<float>>().name()) {
       return vistor(attr.variant_.Get<vector<float>>());
-    } else if (attr.variant_.TypeId() == typeid(vector<string>).hash_code()) {
+    } else if (attr.variant_.TypeId() == type_id<vector<string>>().name()) {
       return vistor(attr.variant_.Get<vector<string>>());
-    } else if (attr.variant_.TypeId() == typeid(bool).hash_code()) {  // NOLINT
+    } else if (attr.variant_.TypeId() == type_id<bool>().name()) {  // NOLINT
       return vistor(attr.variant_.Get<bool>());
-    } else if (attr.variant_.TypeId() == typeid(vector<bool>).hash_code()) {
+    } else if (attr.variant_.TypeId() == type_id<vector<bool>>().name()) {
       return vistor(attr.variant_.Get<vector<bool>>());
-    } else if (attr.variant_.TypeId() == typeid(int64_t).hash_code()) {
+    } else if (attr.variant_.TypeId() == type_id<int64_t>().name()) {
       return vistor(attr.variant_.Get<int64_t>());
     } else if (attr.variant_.TypeId() ==
-               typeid(framework::BlockDesc *).hash_code()) {
+               type_id<framework::BlockDesc *>().name()) {
       return vistor(attr.variant_.Get<framework::BlockDesc *>());
     } else if (attr.variant_.TypeId() ==
-               typeid(vector<framework::BlockDesc *>).hash_code()) {
+               type_id<vector<framework::BlockDesc *>>().name()) {
       return vistor(attr.variant_.Get<vector<framework::BlockDesc *>>());
-    } else if (attr.variant_.TypeId() == typeid(vector<int64_t>).hash_code()) {
+    } else if (attr.variant_.TypeId() == type_id<vector<int64_t>>().name()) {
       return vistor(attr.variant_.Get<vector<int64_t>>());
     } else {
       PADDLE_MOBILE_THROW_EXCEPTION("type not support");
......
@@ -53,12 +53,12 @@ class CLTensor : TensorBase {
     int64_t size = numel() * sizeof(T);
     holder_.reset(new PlaceholderImpl(
-        size, reinterpret_cast<void *>(const_cast<T *>(data)), typeid(T),
+        size, reinterpret_cast<void *>(const_cast<T *>(data)), type_id<T>(),
         context_, command_queue_));
     return reinterpret_cast<cl_mem>(holder_->ptr());
   }

-  inline cl_mem mutable_data(std::type_index type) {
+  inline cl_mem mutable_data(std::string type) {
     if (holder_ != nullptr) {
       holder_->set_type(type);
     }
@@ -77,7 +77,7 @@ class CLTensor : TensorBase {
    */
   template <typename T>
   inline cl_mem mutable_data() {
-    return reinterpret_cast<cl_mem>(mutable_data(typeid(T)));
+    return reinterpret_cast<cl_mem>(mutable_data(type_id<T>()));
   }

   /**
@@ -132,7 +132,7 @@ class CLTensor : TensorBase {
   void *host_ptr_ = nullptr;

   struct PlaceholderImpl : public Placeholder {
-    PlaceholderImpl(size_t size, void *input, std::type_index type,
+    PlaceholderImpl(size_t size, void *input, std::string type,
                     cl_context context, cl_command_queue command_queue)
         : ptr_(clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
                               size, reinterpret_cast<void *>(input), NULL)),
@@ -142,7 +142,7 @@ class CLTensor : TensorBase {
           context_(context),
           command_queue_(command_queue) {}

-    PlaceholderImpl(size_t size, std::type_index type, cl_context context,
+    PlaceholderImpl(size_t size, std::string type, cl_context context,
                     cl_command_queue command_queue)
         : ptr_(clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, NULL)),
           size_(size),
@@ -155,9 +155,9 @@ class CLTensor : TensorBase {
     virtual void *ptr() const { return static_cast<void *>(ptr_.get()); }

-    virtual std::type_index type() const { return type_; }
+    virtual std::string type() const { return type_; }

-    virtual void set_type(std::type_index type) { type_ = type; }
+    virtual void set_type(std::string type) { type_ = type; }

     virtual void resize(size_t size) {
       if (size > capatity_) {
@@ -175,7 +175,7 @@ class CLTensor : TensorBase {
     size_t capatity_;

     /* the current type of memory */
-    std::type_index type_;
+    std::string type_;

     cl_context context_;
     cl_command_queue command_queue_;
......
@@ -16,17 +16,18 @@ limitations under the License. */
 #include <stdint.h>
 #include <string>
 #include <unordered_map>
+#include "common/type_define.h"

 namespace paddle_mobile {
 namespace framework {

 struct DataTypeMap {
-  std::unordered_map<std::type_index,
+  std::unordered_map<std::string,
                      _PaddleMobile__Framework__Proto__VarType__Type>
       cpp_to_proto_;
-  std::unordered_map<int, std::type_index> proto_to_cpp_;
+  std::unordered_map<int, std::string> proto_to_cpp_;
   std::unordered_map<int, std::string> proto_to_str_;
-  std::unordered_map<std::type_index, size_t> cpp_to_size_;
+  std::unordered_map<std::string, size_t> cpp_to_size_;
 };

 static DataTypeMap* InitDataTypeMap();
@@ -42,10 +43,10 @@ template <typename T>
 static inline void RegisterType(
     DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type,
     const std::string& name) {
-  map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
-  map->cpp_to_proto_.emplace(typeid(T), proto_type);
+  map->proto_to_cpp_.emplace(static_cast<int>(proto_type), type_id<T>().name());
+  map->cpp_to_proto_.emplace(type_id<T>().name(), proto_type);
   map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
-  map->cpp_to_size_.emplace(typeid(T), sizeof(T));
+  map->cpp_to_size_.emplace(type_id<T>().name(), sizeof(T));
 }

 static DataTypeMap* InitDataTypeMap() {
@@ -70,17 +71,15 @@ static DataTypeMap* InitDataTypeMap() {
   return retv;
 }

-_PaddleMobile__Framework__Proto__VarType__Type ToDataType(
-    std::type_index type) {
+_PaddleMobile__Framework__Proto__VarType__Type ToDataType(std::string type) {
   auto it = gDataTypeMap().cpp_to_proto_.find(type);
   if (it != gDataTypeMap().cpp_to_proto_.end()) {
     return it->second;
   }
-  PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.name());
+  PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.c_str());
 }

-std::type_index ToTypeIndex(
-    _PaddleMobile__Framework__Proto__VarType__Type type) {
+std::string ToTypeIndex(_PaddleMobile__Framework__Proto__VarType__Type type) {
   auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
   if (it != gDataTypeMap().proto_to_cpp_.end()) {
     return it->second;
......
@@ -15,7 +15,6 @@ limitations under the License. */
 #pragma once

 #include <string>
-#include <typeindex>

 #include "common/enforce.h"
 #include "framework/framework.pb-c.h"
@@ -24,8 +23,8 @@ namespace paddle_mobile {
 namespace framework {

 extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType(
-    std::type_index type);
-extern std::type_index ToTypeIndex(
+    std::string type);
+extern std::string ToTypeIndex(
     _PaddleMobile__Framework__Proto__VarType__Type type);

 inline _PaddleMobile__Framework__Proto__VarType__Type ToDataType(int type) {
......
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "common/enforce.h"
 #include "common/variant.h"
-#include "dim.h"
+#include "framework/dim.h"

 namespace paddle_mobile {
 namespace framework {
@@ -40,25 +40,25 @@ struct DDim {
   template <typename Vistor>
   static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) {
-    if (d.var.TypeId() == typeid(Dim<0>).hash_code()) {
+    if (d.var.TypeId() == type_id<Dim<0>>().name()) {
       return vistor(d.var.Get<Dim<0>>());
-    } else if (d.var.TypeId() == typeid(Dim<1>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<1>>().name()) {
       return vistor(d.var.Get<Dim<1>>());
-    } else if (d.var.TypeId() == typeid(Dim<2>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<2>>().name()) {
      return vistor(d.var.Get<Dim<2>>());
-    } else if (d.var.TypeId() == typeid(Dim<3>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<3>>().name()) {
       return vistor(d.var.Get<Dim<3>>());
-    } else if (d.var.TypeId() == typeid(Dim<4>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<4>>().name()) {
       return vistor(d.var.Get<Dim<4>>());
-    } else if (d.var.TypeId() == typeid(Dim<5>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<5>>().name()) {
       return vistor(d.var.Get<Dim<5>>());
-    } else if (d.var.TypeId() == typeid(Dim<6>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<6>>().name()) {
       return vistor(d.var.Get<Dim<6>>());
-    } else if (d.var.TypeId() == typeid(Dim<7>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<7>>().name()) {
       return vistor(d.var.Get<Dim<7>>());
-    } else if (d.var.TypeId() == typeid(Dim<8>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<8>>().name()) {
       return vistor(d.var.Get<Dim<8>>());
-    } else if (d.var.TypeId() == typeid(Dim<9>).hash_code()) {
+    } else if (d.var.TypeId() == type_id<Dim<9>>().name()) {
       return vistor(d.var.Get<Dim<9>>());
     } else {
       PADDLE_MOBILE_ENFORCE(false, " dim not support");
......
@@ -63,7 +63,7 @@ Executor<Device, T>::Executor(const Program<Device> &program,
   PADDLE_MOBILE_ENFORCE(program_desc_ != nullptr,
                         "program_desc_ should not be nullptr");
 #ifndef PADDLE_MOBILE_FPGA
-  pass::MemoryOptPass()(program_desc_.get(), program_.scope.get());
+  // pass::MemoryOptPass()(program_desc_.get(), program_.scope.get());
 #endif
   // resize feed and fetch list
   // should init feed and fetch variables before infer shape
@@ -302,25 +302,9 @@ bool Executor<Device, T>::varInputMemory(
     const std::shared_ptr<VarDesc> &var_desc, Variable *var) const {
 #ifdef PADDLE_MOBILE_FPGA
   framework::LoDTensor *tensor = var->template GetMutable<LoDTensor>();
-  tensor->init(typeid(float));
+  tensor->init(type_id<float>());
   return true;
 #endif
-  auto TypeId = [](const VarType_Type &type) -> std::type_index {
-    switch (type) {
-      case VARTYPE_TYPE_BOOL:
-        return typeid(bool);
-      case VARTYPE_TYPE_FP32:
-        return typeid(float);
-      case VARTYPE_TYPE_INT8:
-        return typeid(int8_t);
-      case VARTYPE_TYPE_INT32:
-        return typeid(int);
-      case VARTYPE_TYPE_INT64:
-        return typeid(int64_t);
-      default:
-        PADDLE_MOBILE_THROW_EXCEPTION("got unhandled var type `%d`", type);
-    }
-  };

   auto type = var_desc->Type();
   if (type == VARTYPE_TYPE_LOD_TENSOR) {
@@ -390,13 +374,6 @@ void Executor<Device, T>::SetInput(const Tensor &input,
   framework::LoDTensor &target =
       feed_var->template GetMutable<framework::LoDTensorArray>()->at(index);

-  if (config_.load_when_predict) {
-    if (input_dim_last_ != input.dims()) {
-      InitNoPersistableMemory(input);
-      input_dim_last_ = input.dims();
-    }
-  }
-
   target.Resize(input.dims());
   target.ShareDataWith(input);
 }
@@ -412,13 +389,6 @@ void Executor<Device, T>::SetInput(const LoDTensor &input,
   framework::LoDTensor &target =
       feed_var->template GetMutable<framework::LoDTensorArray>()->at(index);

-  if (config_.load_when_predict) {
-    if (input_dim_last_ != input.dims()) {
-      InitNoPersistableMemory(input);
-      input_dim_last_ = input.dims();
-    }
-  }
-
   target.Resize(input.dims());
   target.ShareDataWith(input);
   target.set_lod(input.lod());
......
@@ -12,52 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "lod_tensor.h"
+#include "framework/lod_tensor.h"
 #include <algorithm>

 namespace paddle_mobile {
 namespace framework {

-// std::ostream &operator<<(std::ostream &os, const LoD &lod) {
-//   os << "{";
-//   for (auto &v : lod) {
-//     os << "{";
-//     bool is_first = true;
-//     for (auto &i : v) {
-//       if (is_first) {
-//         os << i;
-//         is_first = false;
-//       } else {
-//         os << ", " << i;
-//       }
-//     }
-//     os << "}";
-//   }
-//   os << "}";
-//
-//   return os;
-//}
-//
-// std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
-//   PADDLE_MOBILE_ENFORCE(t.type().hash_code() == typeid(float).hash_code(),
-//                         "t.type() is not float");
-//   os << "dim: " << t.dims() << "\n";
-//   os << "lod: " << t.lod() << "\n";
-//   // only print first ten elements
-//   int64_t size = t.numel() < 10 ? t.numel() : 10;
-//   for (int64_t i = 0; i < size; ++i) {
-//     os << t.data<float>()[i] << " ";
-//   }
-//
-//   return os;
-//}
-// std::string LoDToString(const LoD &lod) {
-//   std::ostringstream stream;
-//   stream << lod;
-//   return stream.str();
-//}

 LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
                  size_t elem_end) {
   PADDLE_MOBILE_ENFORCE(level < in.size(), "level should >= in.size()");
......
@@ -211,17 +211,17 @@ inline Print &operator<<(Print &printer, const LoDTensor &tensor) {
   stride = stride > 0 ? stride : 1;
 #ifndef PADDLE_MOBILE_FPGA
   for (int i = 0; i < tensor.numel(); i += stride) {
-    if (tensor.type() == typeid(float)) {
+    if (tensor.type() == type_id<float>()) {
       printer << tensor.data<float>()[i] << " ";
-    } else if (tensor.type() == typeid(int32_t)) {
+    } else if (tensor.type() == type_id<int32_t>()) {
       printer << tensor.data<int32_t>()[i] << " ";
-    } else if (tensor.type() == typeid(int64_t)) {
+    } else if (tensor.type() == type_id<int64_t>()) {
       printer << tensor.data<int64_t>()[i] << " ";
-    } else if (tensor.type() == typeid(int8_t)) {
+    } else if (tensor.type() == type_id<int8_t>()) {
       printer << static_cast<int>(tensor.data<int8_t>()[i]) << " ";
-    } else if (tensor.type() == typeid(int32_t)) {
+    } else if (tensor.type() == type_id<int32_t>()) {
       printer << tensor.data<int32_t>()[i] << " ";
-    } else if (tensor.type() == typeid(bool)) {
+    } else if (tensor.type() == type_id<bool>()) {
       printer << tensor.data<bool>()[i] << " ";
     }
   }
......
@@ -17,7 +17,6 @@
 #include <algorithm>
 #include <initializer_list>
 #include <vector>
-
 #include "framework/tensor.h"
 #include "framework/tensor_util.h"
@@ -198,7 +197,7 @@ class Vector {
   }

   size_t capacity() const {
-    return cpu_vec_.memory_size() / SizeOfType(typeid(T));
+    return cpu_vec_.memory_size() / SizeOfType(type_id<T>().name());
   }

   // reserve data
......
@@ -14,13 +14,24 @@ limitations under the License. */
 #pragma once

+#include <functional>
 #include <string>
 #include "common/log.h"
 #include "common/type_define.h"
+#include "framework/scope.h"

 namespace paddle_mobile {
 namespace framework {

+template <typename Dtype>
+class OperatorBase;
+
+template <typename Dtype>
+using OpCreator = std::function<framework::OperatorBase<Dtype> *(
+    const std::string & /*type*/, const VariableNameMap & /*inputs*/,
+    const VariableNameMap & /*outputs*/,
+    const framework::AttributeMap & /*attrs*/, framework::Scope * /*scope*/)>;
+
 template <typename Dtype>
 struct OpInfo {
   OpCreator<Dtype> creator_;
@@ -79,8 +90,6 @@ class OpInfoMap {
  private:
   OpInfoMap() = default;
   std::unordered_map<std::string, OpInfo<Dtype>> map_;
-
-  // DISABLE_COPY_AND_ASSIGN(OpInfoMap);
 };

 }  // namespace framework
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

+#include <functional>
 #include <map>
 #include <string>
 #include <utility>
......
@@ -18,7 +18,8 @@ limitations under the License. */
 #include <vector>

 #include "common/log.h"
-#include "common/type_define.h"
+#include "common/types.h"
+#include "framework/attribute.h"
 #include "framework/framework.pb-c.h"

 namespace paddle_mobile {
......
@@ -72,7 +72,7 @@ void ProgramDesc::Description(std::string header) const {
       }
     }
     for (auto &attr : op->GetAttrMap()) {
-      if (attr.first == "op_callstack") continue;
+      if (attr.first == "op_callstack" || attr.first == "sub_block") continue;
       LOG(kLOG_DEBUG2) << "attr name: " << attr.first;
       LOG(kLOG_DEBUG3) << "argument - " << attr.second;
     }
......
@@ -19,8 +19,6 @@ limitations under the License. */
 #include <fstream>
 #include <memory>
 #include <string>
-#include <type_traits>
-#include <typeindex>
 #include <vector>

 #include "common/enforce.h"
@@ -83,7 +81,7 @@ class Tensor : public TensorBase {
     return *this;
   }

-  inline void *mutable_data(std::type_index type) {
+  inline void *mutable_data(const std::string type) {
     if (holder_ != nullptr) {
       holder_->set_type(type);
     }
@@ -108,7 +106,7 @@ class Tensor : public TensorBase {
   template <typename T>
   inline T *mutable_data() {
     static_assert(std::is_pod<T>::value, "T must be POD");
-    return reinterpret_cast<T *>(mutable_data(typeid(T)));
+    return reinterpret_cast<T *>(mutable_data(type_id<T>().name()));
   }

   /**
@@ -165,9 +163,9 @@ class Tensor : public TensorBase {
     check_memory_size();
     PADDLE_MOBILE_ENFORCE(
         (std::is_same<T, void>::value ||
-         holder_->type().hash_code() == typeid(T).hash_code()),
+         holder_->type() == type_id<T>().name()),
         "Tensor holds the wrong type, it holds %s, requested %s",
-        this->holder_->type().name(), typeid(T).name());
+        this->holder_->type().c_str(), type_id<T>().name().c_str());

     return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                                  offset_);
@@ -179,9 +177,9 @@ class Tensor : public TensorBase {
     check_memory_size();
     PADDLE_MOBILE_ENFORCE(
         (std::is_same<T, void>::value ||
-         holder_->type().hash_code() == typeid(T).hash_code()),
+         holder_->type() == type_id<T>().name()),
         "Tensor holds the wrong type, it holds %s, requested %s",
-        this->holder_->type().name(), typeid(T).name());
+        this->holder_->type().c_str(), type_id<T>().name().c_str());

     return reinterpret_cast<const T *>(
         reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
@@ -189,7 +187,7 @@ class Tensor : public TensorBase {
  private:
   struct PlaceholderImpl : public Placeholder {
-    PlaceholderImpl(size_t size, std::type_index type)
+    PlaceholderImpl(size_t size, const std::string type)
         : ptr_(static_cast<uint8_t *>(memory::Alloc(size)),
                memory::PODDeleter<uint8_t>()),
           size_(size),
@@ -203,9 +201,9 @@ class Tensor : public TensorBase {
     virtual void *ptr() const { return static_cast<void *>(ptr_.get()); }

-    virtual std::type_index type() const { return type_; }
+    virtual std::string type() const { return type_; }

-    virtual void set_type(std::type_index type) { type_ = type; }
+    virtual void set_type(const std::string type) { type_ = type; }

     virtual void resize(size_t size) {
       if (size > capatity_) {
@@ -223,7 +221,7 @@ class Tensor : public TensorBase {
     size_t capatity_;

     /* the current type of memory */
-    std::type_index type_;
+    std::string type_;
   };

 #ifdef PADDLE_MOBILE_FPGA
@@ -231,13 +229,13 @@ class Tensor : public TensorBase {
   inline void reset_data_ptr(void *p) {
     ((PlaceholderImpl *)(holder_.get()))->ptr_.reset((uint8_t *)p);  // NOLINT
   }

-  inline void set_type(std::type_index type) { holder_->set_type(type); }
+  inline void set_type(const std::string type) { holder_->set_type(type); }

   inline void *get_data() {
     return (
         void *)(((PlaceholderImpl *)(holder_.get()))->ptr_.get());  // NOLINT
   }

-  inline void *init(std::type_index type) {
+  inline void *init(const std::string type) {
     if (holder_ != nullptr) {
       holder_->set_type(type);
     }
@@ -265,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
   stride = stride > 0 ? stride : 1;
 #ifndef PADDLE_MOBILE_FPGA
   for (int i = 0; i < tensor.numel(); i += stride) {
-    if (tensor.type() == typeid(float)) {
+    if (tensor.type() == type_id<float>().name()) {
       printer << tensor.data<float>()[i] << " ";
-    } else if (tensor.type() == typeid(int32_t)) {
+    } else if (tensor.type() == type_id<int32_t>().name()) {
       printer << tensor.data<int32_t>()[i] << " ";
-    } else if (tensor.type() == typeid(int64_t)) {
+    } else if (tensor.type() == type_id<int64_t>().name()) {
       printer << tensor.data<int64_t>()[i] << " ";
-    } else if (tensor.type() == typeid(int8_t)) {
+    } else if (tensor.type() == type_id<int8_t>().name()) {
       printer << static_cast<int>(tensor.data<int8_t>()[i]) << " ";
-    } else if (tensor.type() == typeid(int32_t)) {
+    } else if (tensor.type() == type_id<int32_t>().name()) {
       printer << tensor.data<int32_t>()[i] << " ";
     }
   }
......
@@ -14,9 +14,7 @@ limitations under the License. */
 #pragma once

-#include <type_traits>
-#include <typeindex>
+#include <string>

 #include "common/enforce.h"
 #include "common/types.h"
 #include "framework/ddim.h"
@@ -29,8 +27,8 @@ struct SizeOfTypeFunctor;

 template <typename T>
 struct SizeOfTypeFunctor<T> {
-  size_t operator()(std::type_index type) const {
-    if (typeid(T).hash_code() == type.hash_code()) {
+  size_t operator()(const std::string type) const {
+    if (type_id<T>().name() == type) {
       return sizeof(T);
     } else {
       return 0UL;
@@ -40,12 +38,12 @@ struct SizeOfTypeFunctor<T> {
 template <>
 struct SizeOfTypeFunctor<> {
-  size_t operator()(std::type_index type) const { return 0UL; }
+  size_t operator()(const std::string type) const { return 0UL; }
 };

 template <typename HEAD, typename... TAIL>
 struct SizeOfTypeFunctor<HEAD, TAIL...> {
-  size_t operator()(std::type_index type) const {
+  size_t operator()(const std::string type) const {
     SizeOfTypeFunctor<HEAD> head;
     size_t head_size = head(type);
     if (head_size != 0) {
@@ -56,13 +54,14 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
   }
 };

-static inline size_t SizeOfType(std::type_index type) {
+static inline size_t SizeOfType(std::string type) {
   SizeOfTypeFunctor<int8_t, int, half, float, double, int16_t, int64_t, bool,
                     size_t>
       functor;
   size_t size = functor(type);

-  PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
+  PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %s",
+                        type.c_str());
   return size;
 }
@@ -78,7 +77,7 @@ class TensorBase {
   /*! Return the numel of the memory block. */
   inline int64_t numel() const { return product(dims_); }

-  std::type_index type() const {
+  std::string type() const {
     PADDLE_MOBILE_ENFORCE(
         holder_ != nullptr,
         "Tensor not initialized yet when Tensor::type() is called.")
@@ -114,9 +113,9 @@ class TensorBase {
     virtual size_t size() const = 0;

-    virtual std::type_index type() const = 0;
+    virtual std::string type() const = 0;

-    virtual void set_type(std::type_index type) = 0;
+    virtual void set_type(std::string type) = 0;

     virtual void resize(size_t size) = 0;
   };
......
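SizeOfType dispatch in the patched form, as a usage sketch (the include path and enclosing namespace are assumptions; within the repository this is the header shown above):

#include <cassert>
#include <cstdint>
#include "framework/tensor_base.h"  // assumed path for the header above

int main() {
  using namespace paddle_mobile;
  // The functor walks its type list and returns sizeof(T) for the first
  // registered name that matches; unknown names trip the enforce above.
  assert(framework::SizeOfType(type_id<float>().name()) == sizeof(float));
  assert(framework::SizeOfType(type_id<int8_t>().name()) == sizeof(int8_t));
  return 0;
}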
@@ -16,13 +16,10 @@ limitations under the License. */
 #include <memory>
 #include <string>
-#include <typeindex>
-#include <typeinfo>
-#include "../common/variant.h"
+#include "common/variant.h"

 namespace paddle_mobile {
 namespace framework {
-using std::string;

 class Variable {
  public:
@@ -33,7 +30,7 @@ class Variable {
   template <typename T>
   const T GetValue() const {
-    if (typeid(T) == typeid(std::string)) {
+    if (type_id<T>().name() == type_id<std::string>().name()) {
       PADDLE_MOBILE_THROW_EXCEPTION(
           "Please use getString to get an string (to avoid of an issue with "
           "gcc "
@@ -60,38 +57,39 @@ class Variable {
   template <typename T>
   bool IsType() const {
-    return holder_ != nullptr && holder_->Type() == typeid(T);
+    return holder_ != nullptr && holder_->Type() == type_id<T>().name();
   }

   void Clear() { holder_.reset(); }

-  std::type_index Type() const { return holder_->Type(); }
+  std::string Type() const { return holder_->Type(); }

  private:
   struct Placeholder {
     Placeholder() = default;
     virtual ~Placeholder() = default;

-    virtual const std::type_info &Type() const = 0;
+    virtual std::string Type() const = 0;
     virtual void *Ptr() const = 0;
   };

   template <typename T>
   struct PlaceholderImp : public Placeholder {
-    explicit PlaceholderImp(T *ptr) : ptr_(ptr), type_(typeid(T)) {}
+    explicit PlaceholderImp(T *ptr) : ptr_(ptr), type_(type_id<T>().name()) {}

-    virtual const std::type_info &Type() const { return type_; }
-    virtual void *Ptr() const override {
-      return static_cast<void *>(ptr_.get());
-    }
+    std::string Type() const override { return type_; }
+    void *Ptr() const override { return static_cast<void *>(ptr_.get()); }

     std::unique_ptr<T> ptr_;
-    const std::type_info &type_;
+    std::string type_;
   };

-  Variant<int, bool, string, float, double> variant;
-  std::unique_ptr<Placeholder> holder_;
   friend class Scope;
-  string name_;
+
+  Variant<int, bool, std::string, float, double> variant;
+  std::unique_ptr<Placeholder> holder_;
+  std::string name_;
 };

 }  // namespace framework
 }  // namespace paddle_mobile
@@ -128,7 +128,7 @@ void ConvertTensors(const framework::Tensor &src, PaddleTensor *des) {
   des->layout = src.layout == framework::LAYOUT_HWC ? LAYOUT_HWC : LAYOUT_CHW;

   auto num = src.numel();
-  if (src.type() == typeid(float)) {
+  if (src.type() == type_id<float>()) {
     des->data.Reset(const_cast<float *>(src.data<float>()),
                     num * sizeof(float));
   } else {
@@ -143,7 +143,7 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors(
   auto num = inputs.size();
   std::vector<framework::Tensor> tensors(num, framework::Tensor());
   for (int i = 0; i < num; i++) {
-    tensors[i].init(typeid(float));
+    tensors[i].init(type_id<float>());
     ConvertPaddleTensors(inputs[i], &tensors[i]);
   }
   paddle_mobile_->FeedTensorData(tensors);
......
@@ -24,7 +24,6 @@ limitations under the License. */
 #include <cassert>
 #include <memory>
 #include <string>
-#include <typeindex>
 #include <vector>

 namespace paddle_mobile {
@@ -88,7 +87,6 @@ struct PaddleTensor {
   // TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
   PaddleBuf data;  // blob of data.
   PaddleDType dtype;
-  std::type_index dtypeid = typeid(float);
   LayoutType layout;
 };
......
@@ -14,8 +14,6 @@ limitations under the License. */
 #ifdef BEAM_SEARCH_DECODE_OP

-#pragma once
-
 #include "operators/beam_search_decode_op.h"

 namespace paddle_mobile {
......
@@ -14,8 +14,6 @@ limitations under the License. */
 #ifdef BEAM_SEARCH_OP

-#pragma once
-
 #include "operators/beam_search_op.h"

 namespace paddle_mobile {
......
@@ -192,10 +192,10 @@ bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) {
 template <>
 void LessThanKernel<CPU, float>::Compute(const CompareParam<CPU> &param) {
-  if (param.input_x_->type() == typeid(int64_t)) {
+  if (param.input_x_->type() == type_id<int64_t>().name()) {
     CompareCompute<int64_t, LESS_THAN>()(param.input_x_, param.input_y_,
                                          param.axis_, param.output_);
-  } else if (param.input_x_->type() == typeid(float)) {
+  } else if (param.input_x_->type() == type_id<float>().name()) {
     CompareCompute<float, LESS_THAN>()(param.input_x_, param.input_y_,
                                        param.axis_, param.output_);
   } else {
......
@@ -27,7 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
 template <>
 void ConcatKernel<CPU, float>::Compute(const ConcatParam<CPU> &param) {
-  if (param.Inputs()[0]->type() == typeid(int8_t)) {
+  if (param.Inputs()[0]->type() == type_id<int8_t>().name()) {
     ConcatCompute<int8_t>(param);
   } else {
     ConcatCompute<float>(param);
......
@@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
   bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
                   param->Input()->dims()[1] == param->Output()->dims()[1];

-  if (param->Filter()->type() == typeid(int8_t)) {
+  if (param->Filter()->type() == type_id<int8_t>().name()) {
 #ifndef __aarch64__
     if (depth3x3 && param->Strides()[0] < 3 &&
         param->Strides()[0] == param->Strides()[1]) {
......
@@ -28,7 +28,9 @@ void WriteToArrayKernel<CPU, float>::Compute(
     const WriteToArrayParam<CPU> &param) {
   int64_t offset = param.index_->data<int64_t>()[0];
   if (offset >= param.output_->size()) {
-    param.output_->resize(offset + 1);
+    while (param.output_->size() <= offset) {
+      param.output_->emplace_back();
+    }
   }

   framework::LoDTensor *out_tensor = &(param.output_->at(offset));
......
@@ -126,13 +126,13 @@ void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> &param) {
   const std::vector<int> &axis = param.Axis();
   bool shuffle_channel = IsShuffleChannel(axis);
   if (shuffle_channel) {
-    if (param.InputX()->type() == typeid(int8_t)) {
+    if (param.InputX()->type() == type_id<int8_t>().name()) {
       ShuffleChannelCompute<int8_t>(param);
     } else {
       ShuffleChannelCompute<float>(param);
     }
   } else {
-    if (param.InputX()->type() == typeid(int8_t)) {
+    if (param.InputX()->type() == type_id<int8_t>().name()) {
       Transpose2Compute<int8_t>(param);
     } else {
       Transpose2Compute<float>(param);
......
@@ -35,6 +35,7 @@ class StepExecutor {
       auto op_handler = framework::OpRegistry<CPU>::CreateOp(
           op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(),
           op_desc->GetAttrMap(), scope_);
+      op_handler->Init();
       ops_of_block_[i] = op_handler;
     }
   }
......
@@ -37,7 +37,7 @@ void MulCompute(const MulParam<CPU> &param) {
   if (out_dim.size() != 2) {
     out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
   }
-  if (param.InputX()->type() == typeid(int8_t)) {
+  if (param.InputX()->type() == type_id<int8_t>().name()) {
     out->mutable_data<int32_t>();
     math::MatMul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
                                   static_cast<float>(1), out,
......
@@ -144,7 +144,8 @@ void SumCompute(const SumParam<CPU> &param) {
     }
   } else {
     PADDLE_MOBILE_THROW_EXCEPTION(
-        "Unexpected branch, output variable type is %s", outvar->Type().name());
+        "Unexpected branch, output variable type is %s",
+        outvar->Type().c_str());
   }
 }
 }  // namespace operators
......
@@ -14,8 +14,6 @@ limitations under the License. */
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)

-#pragma once
-
 #include "operators/math/gemm/cblas.h"
 #include "operators/math/gemm/executor.h"
 #include "operators/math/gemm/strategy.h"
......
@@ -14,8 +14,6 @@ limitations under the License. */
 #ifdef ONE_HOT_OP

-#pragma once
-
 #include "operators/one_hot_op.h"

 namespace paddle_mobile {
......
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "common/log.h"
 #include "common/type_define.h"
 #include "common/types.h"
+#include "framework/attribute.h"
 #include "framework/lod_tensor.h"
 #include "framework/scope.h"
 #include "framework/tensor.h"
......