提交 47e5978f 编写于 作者: 朔-望's avatar 朔-望

add clang-format clang-tidy hook

上级 ab5a35c2
...@@ -20,11 +20,20 @@ repos: ...@@ -20,11 +20,20 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
files: (src).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|cu|h|hpp|hxx)$ files: (src).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|cu|h|hpp|hxx)$
- repo: local
hooks:
- id: clang-format
name: clang-format
description: Format files with ClangFormat.
entry: bash ./tools/pre-commit.hooks/.clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|h|hpp|hxx)$
- repo: local - repo: local
hooks: hooks:
- id: clang-tidy - id: clang-tidy
name: clang-tidy name: clang-tidy
description: Format files with tidy. description: Check C++ code style using clang-tidy.
entry: bash ./tools/pre-commit.hooks/.clang-tidy.hook -i entry: bash ./tools/pre-commit.hooks/.clang-tidy.hook -i
language: system language: system
files: (src).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ files: (src).*\.(c|cc|cxx|cpp|h|hpp|hxx)$
......
...@@ -27,146 +27,145 @@ SOFTWARE. ...@@ -27,146 +27,145 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
enum LogLevel { enum LogLevel {
kNO_LOG, kNO_LOG,
kLOG_ERROR, kLOG_ERROR,
kLOG_WARNING, kLOG_WARNING,
kLOG_INFO, kLOG_INFO,
kLOG_DEBUG, kLOG_DEBUG,
kLOG_DEBUG1, kLOG_DEBUG1,
kLOG_DEBUG2, kLOG_DEBUG2,
kLOG_DEBUG3, kLOG_DEBUG3,
kLOG_DEBUG4 kLOG_DEBUG4
}; };
// log level // log level
static LogLevel log_level = kLOG_DEBUG4; static LogLevel log_level = kLOG_DEBUG4;
static std::vector<std::string> logs{"NO", "ERROR ", "WARNING", static std::vector<std::string> logs{"NO", "ERROR ", "WARNING",
"INFO ", "DEBUG ", "DEBUG1 ", "INFO ", "DEBUG ", "DEBUG1 ",
"DEBUG2 ", "DEBUG3 ", "DEBUG4 "}; "DEBUG2 ", "DEBUG3 ", "DEBUG4 "};
struct ToLog; struct ToLog;
struct Print; struct Print;
struct Print { struct Print {
friend struct ToLog; friend struct ToLog;
template <typename T> Print &operator<<(T const &value) { template <typename T> Print &operator<<(T const &value) {
buffer_ << value; buffer_ << value;
return *this; return *this;
}
private:
void print(LogLevel level) {
buffer_ << std::endl;
if (level == kLOG_ERROR) {
std::cerr << buffer_.str();
} else {
std::cout << buffer_.str();
} }
}
private: std::ostringstream buffer_;
void print(LogLevel level) { };
buffer_ << std::endl;
if (level == kLOG_ERROR) { struct ToLog {
std::cerr << buffer_.str(); ToLog(LogLevel level = kLOG_DEBUG, const std::string &info = "")
} else { : level_(level) {
std::cout << buffer_.str(); unsigned blanks =
} (unsigned)(level > kLOG_DEBUG ? (level - kLOG_DEBUG) * 4 : 1);
} printer_ << logs[level] << " " << info << ":"
std::ostringstream buffer_; << std::string(blanks, ' ');
}; }
struct ToLog { template <typename T> ToLog &operator<<(T const &value) {
ToLog(LogLevel level = kLOG_DEBUG, const std::string &info = "") printer_ << value;
: level_(level) { return *this;
unsigned blanks = }
(unsigned)(level > kLOG_DEBUG ? (level - kLOG_DEBUG) * 4 : 1);
printer_ << logs[level] << " " << info << ":" ~ToLog() { printer_.print(level_); }
<< std::string(blanks, ' ');
} private:
LogLevel level_;
template <typename T> ToLog &operator<<(T const &value) { Print printer_;
printer_ << value; };
return *this;
}
~ToLog() { printer_.print(level_); }
private:
LogLevel level_;
Print printer_;
};
#define LOG(level) \ #define LOG(level) \
if (level > paddle_mobile::log_level) { \ if (level > paddle_mobile::log_level) { \
} else \ } else \
paddle_mobile::ToLog( \ paddle_mobile::ToLog( \
level, \ level, (std::stringstream() \
(std::stringstream() \ << "[file: " \
<< "[file: " \ << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \ : __FILE__) \
<< "] [line: " << __LINE__ << "] ") \ << "] [line: " << __LINE__ << "] ") \
.str()) .str())
#define DLOG \ #define DLOG \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \ if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
} else \ } else \
paddle_mobile::ToLog( \ paddle_mobile::ToLog( \
paddle_mobile::kLOG_DEBUG, \ paddle_mobile::kLOG_DEBUG, \
(std::stringstream() \ (std::stringstream() \
<< "[file: " \ << "[file: " \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \ << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \
<< "] [line: " << __LINE__ << "] ") \ : __FILE__) \
.str()) << "] [line: " << __LINE__ << "] ") \
} .str())
} // namespace paddle_mobile
#define LOGF(level, format, ...) \ #define LOGF(level, format, ...) \
if (level > paddle_mobile::log_level) { \ if (level > paddle_mobile::log_level) { \
} else \ } else \
printf(format, ##__VA_ARGS__) printf(format, ##__VA_ARGS__)
#define DLOGF(format, ...) \ #define DLOGF(format, ...) \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \ if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
} else \ } else \
printf(format, ##__VA_ARGS__) printf(format, ##__VA_ARGS__)
#else #else
namespace paddle_mobile { namespace paddle_mobile {
enum LogLevel { enum LogLevel {
kNO_LOG, kNO_LOG,
kLOG_ERROR, kLOG_ERROR,
kLOG_WARNING, kLOG_WARNING,
kLOG_INFO, kLOG_INFO,
kLOG_DEBUG, kLOG_DEBUG,
kLOG_DEBUG1, kLOG_DEBUG1,
kLOG_DEBUG2, kLOG_DEBUG2,
kLOG_DEBUG3, kLOG_DEBUG3,
kLOG_DEBUG4 kLOG_DEBUG4
}; };
struct ToLog; struct ToLog;
struct Print { struct Print {
friend struct ToLog; friend struct ToLog;
template <typename T> Print &operator<<(T const &value) {} template <typename T> Print &operator<<(T const &value) {}
private: private:
}; };
struct ToLog { struct ToLog {
ToLog(LogLevel level) {} ToLog(LogLevel level) {}
template <typename T> ToLog &operator<<(T const &value) { template <typename T> ToLog &operator<<(T const &value) { return *this; }
return *this; };
}
};
#define LOG(level) \ #define LOG(level) \
if (true) { \ if (true) { \
} else \ } else \
paddle_mobile::ToLog(level) paddle_mobile::ToLog(level)
#define DLOG \ #define DLOG \
if (true) { \ if (true) { \
} else \ } else \
paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG) paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG)
#define LOGF(level, format, ...) #define LOGF(level, format, ...)
#define DLOGF(format, ...) #define DLOGF(format, ...)
} } // namespace paddle_mobile
#endif #endif
...@@ -24,30 +24,29 @@ SOFTWARE. ...@@ -24,30 +24,29 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template<typename Dtype> class OperatorBase; template <typename Dtype> class OperatorBase;
class OpDesc; class OpDesc;
class BlockDesc; class BlockDesc;
class InferShapeContext; class InferShapeContext;
} } // namespace framework
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
template<typename Dtype> template <typename Dtype>
using OpCreator = std::function<framework::OperatorBase<Dtype> *( using OpCreator = std::function<framework::OperatorBase<Dtype> *(
const std::string & /*type*/, const VariableNameMap & /*inputs*/, const std::string & /*type*/, const VariableNameMap & /*inputs*/,
const VariableNameMap & /*outputs*/, const VariableNameMap & /*outputs*/,
const framework::AttributeMap & /*attrs*/)>; const framework::AttributeMap & /*attrs*/)>;
using GradOpMakerFN = using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<framework::OpDesc>>( std::function<std::vector<std::unique_ptr<framework::OpDesc>>(
const framework::OpDesc &, const framework::OpDesc &,
const std::unordered_set<std::string> & /*no_grad_set*/, const std::unordered_set<std::string> & /*no_grad_set*/,
std::unordered_map<std::string, std::string> * /*grad_to_var*/, std::unordered_map<std::string, std::string> * /*grad_to_var*/,
const std::vector<framework::BlockDesc *> &grad_block)>; const std::vector<framework::BlockDesc *> &grad_block)>;
using InferVarTypeFN = using InferVarTypeFN = std::function<void(const framework::OpDesc & /*op_desc*/,
std::function<void(const framework::OpDesc & /*op_desc*/, framework::BlockDesc * /*block*/)>;
framework::BlockDesc * /*block*/)>;
using InferShapeFN = std::function<void(framework::InferShapeContext *)>; using InferShapeFN = std::function<void(framework::InferShapeContext *)>;
}; }; // namespace paddle_mobile
...@@ -24,7 +24,7 @@ enum class Precision : int { FP32 = 0 }; ...@@ -24,7 +24,7 @@ enum class Precision : int { FP32 = 0 };
//! device type //! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 }; enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };
template<DeviceTypeEnum T> struct DeviceType {}; template <DeviceTypeEnum T> struct DeviceType {};
typedef DeviceType<kCPU> CPU; typedef DeviceType<kCPU> CPU;
typedef DeviceType<kFPGA> FPGA; typedef DeviceType<kFPGA> FPGA;
...@@ -32,32 +32,32 @@ typedef DeviceType<kGPU_MALI> GPU_MALI; ...@@ -32,32 +32,32 @@ typedef DeviceType<kGPU_MALI> GPU_MALI;
//! data type //! data type
enum DataType { enum DataType {
PM_INVALID = -1, PM_INVALID = -1,
PM_HALF = 0, PM_HALF = 0,
PM_FLOAT = 1, PM_FLOAT = 1,
PM_DOUBLE = 2, PM_DOUBLE = 2,
PM_INT8 = 3, PM_INT8 = 3,
PM_INT16 = 4, PM_INT16 = 4,
PM_INT32 = 5, PM_INT32 = 5,
PM_INT64 = 6, PM_INT64 = 6,
PM_UINT8 = 7, PM_UINT8 = 7,
PM_UINT16 = 8, PM_UINT16 = 8,
PM_UINT32 = 9, PM_UINT32 = 9,
PM_STRING = 10, PM_STRING = 10,
PM_BOOL = 11, PM_BOOL = 11,
PM_SHAPE = 12, PM_SHAPE = 12,
PM_TENSOR = 13 PM_TENSOR = 13
}; };
//! //!
enum PMStatus { enum PMStatus {
PMSuccess = 0xFF, /*!< No errors */ PMSuccess = 0xFF, /*!< No errors */
PMNotInitialized = 0x01, /*!< Data not initialized. */ PMNotInitialized = 0x01, /*!< Data not initialized. */
PMInvalidValue = 0x02, /*!< Incorrect variable value. */ PMInvalidValue = 0x02, /*!< Incorrect variable value. */
PMMemAllocFailed = 0x03, /*!< Memory allocation error. */ PMMemAllocFailed = 0x03, /*!< Memory allocation error. */
PMUnKownError = 0x04, /*!< Unknown error. */ PMUnKownError = 0x04, /*!< Unknown error. */
PMOutOfAuthority = 0x05, /*!< Try to modified data not your own*/ PMOutOfAuthority = 0x05, /*!< Try to modified data not your own*/
PMOutOfMem = 0x06, /*!< OOM error*/ PMOutOfMem = 0x06, /*!< OOM error*/
PMUnImplError = 0x07, /*!< Unimplement error. */ PMUnImplError = 0x07, /*!< Unimplement error. */
PMWrongDevice = 0x08 /*!< un-correct device. */ PMWrongDevice = 0x08 /*!< un-correct device. */
}; };
} } // namespace paddle_mobile
...@@ -21,79 +21,79 @@ SOFTWARE. ...@@ -21,79 +21,79 @@ SOFTWARE.
#pragma once #pragma once
namespace paddle_mobile { namespace paddle_mobile {
template<int ID, typename Type> struct IDToType { typedef Type type_t; }; template <int ID, typename Type> struct IDToType { typedef Type type_t; };
template<typename F, typename... Ts> struct VariantHelper { template <typename F, typename... Ts> struct VariantHelper {
static const size_t size = sizeof(F) > VariantHelper<Ts...>::size static const size_t size = sizeof(F) > VariantHelper<Ts...>::size
? sizeof(F) ? sizeof(F)
: VariantHelper<Ts...>::size; : VariantHelper<Ts...>::size;
inline static void Destroy(size_t id, void *data) { inline static void Destroy(size_t id, void *data) {
if (id == typeid(F).hash_code()) { if (id == typeid(F).hash_code()) {
reinterpret_cast<F *>(data)->~F(); reinterpret_cast<F *>(data)->~F();
} else { } else {
VariantHelper<Ts...>::Destroy(id, data); VariantHelper<Ts...>::Destroy(id, data);
}
} }
}
}; };
template<typename F> struct VariantHelper<F> { template <typename F> struct VariantHelper<F> {
static const size_t size = sizeof(F); static const size_t size = sizeof(F);
inline static void Destroy(size_t id, void *data) { inline static void Destroy(size_t id, void *data) {
if (id == typeid(F).hash_code()) { if (id == typeid(F).hash_code()) {
// reinterpret_cast<F*>(data)->~F(); // reinterpret_cast<F*>(data)->~F();
} else { } else {
// std::cout << "未匹配到 " << std::endl; // std::cout << "未匹配到 " << std::endl;
}
} }
}
}; };
template<size_t size> class RawData { template <size_t size> class RawData {
public: public:
char data[size]; char data[size];
RawData() {} RawData() {}
RawData(const RawData &raw_data) { strcpy(data, raw_data.data); } RawData(const RawData &raw_data) { strcpy(data, raw_data.data); }
// void operator=(const RawData &raw_data){ // void operator=(const RawData &raw_data){
// strcpy(data, raw_data.data); // strcpy(data, raw_data.data);
// } // }
}; };
template<typename... Ts> struct Variant { template <typename... Ts> struct Variant {
Variant(const Variant &variant) { Variant(const Variant &variant) {
// std::cout << " 赋值构造函数 " << std::endl; // std::cout << " 赋值构造函数 " << std::endl;
type_id = variant.type_id; type_id = variant.type_id;
data = variant.data; data = variant.data;
} }
Variant() : type_id(invalid_type()) {} Variant() : type_id(invalid_type()) {}
~Variant() { ~Variant() {
// helper::Destroy(type_id, &data); // helper::Destroy(type_id, &data);
} }
template<typename T, typename... Args> void Set(Args &&... args) { template <typename T, typename... Args> void Set(Args &&... args) {
helper::Destroy(type_id, &data); helper::Destroy(type_id, &data);
new(&data) T(std::forward<Args>(args)...); new (&data) T(std::forward<Args>(args)...);
type_id = typeid(T).hash_code(); type_id = typeid(T).hash_code();
} }
template<typename T> T &Get() const { template <typename T> T &Get() const {
if (type_id == typeid(T).hash_code()) { if (type_id == typeid(T).hash_code()) {
return *const_cast<T *>(reinterpret_cast<const T *>(&data)); return *const_cast<T *>(reinterpret_cast<const T *>(&data));
} else { } else {
// std::cout << " bad cast in variant " << std::endl; // std::cout << " bad cast in variant " << std::endl;
throw std::bad_cast(); throw std::bad_cast();
}
} }
}
size_t TypeId() const { return type_id; } size_t TypeId() const { return type_id; }
private: private:
static inline size_t invalid_type() { return typeid(void).hash_code(); } static inline size_t invalid_type() { return typeid(void).hash_code(); }
typedef VariantHelper<Ts...> helper; typedef VariantHelper<Ts...> helper;
size_t type_id; size_t type_id;
RawData<helper::size> data; RawData<helper::size> data;
}; };
template<typename T> struct Vistor { typedef T type_t; }; template <typename T> struct Vistor { typedef T type_t; };
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -27,104 +27,102 @@ namespace framework { ...@@ -27,104 +27,102 @@ namespace framework {
class BlockDesc; class BlockDesc;
class Attribute { class Attribute {
public: public:
static Attribute static Attribute GetAttrValue(const proto::OpDesc::Attr &attr_desc) {
GetAttrValue(const proto::OpDesc::Attr &attr_desc) { // std::cout << "begin get attr value" << std::endl;
// std::cout << "begin get attr value" << std::endl; Attribute attr;
Attribute attr; switch (attr_desc.type()) {
switch (attr_desc.type()) { case proto::AttrType::BOOLEAN: {
case proto::AttrType::BOOLEAN: { attr.Set<bool>(attr_desc.b());
attr.Set<bool>(attr_desc.b()); break;
break; }
case proto::AttrType::INT: {
attr.Set<int>(attr_desc.i());
break;
}
case proto::AttrType::FLOAT: {
attr.Set<float>(attr_desc.f());
break;
}
case proto::AttrType::STRING: {
attr.Set<std::string>(attr_desc.s());
break;
}
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
attr.Set<std::vector<bool>>(val);
break;
}
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
attr.Set<std::vector<int>>(val);
break;
}
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
attr.Set<std::vector<float>>(val);
break;
}
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
}
attr.Set<std::vector<std::string>>(val);
break;
}
case proto::AttrType::LONG: {
attr.Set<int64_t>(attr_desc.l());
break;
}
default:
// std::cout << " not support " << std::endl;
break;
}
// std::cout << "end get attr value" << std::endl;
return attr;
} }
case proto::AttrType::INT: {
attr.Set<int>(attr_desc.i());
break;
}
case proto::AttrType::FLOAT: {
attr.Set<float>(attr_desc.f());
break;
}
case proto::AttrType::STRING: {
attr.Set<std::string>(attr_desc.s());
break;
}
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
attr.Set<std::vector<bool>>(val);
break;
}
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
attr.Set<std::vector<int>>(val);
break;
}
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
attr.Set<std::vector<float>>(val);
break;
}
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
}
attr.Set<std::vector<std::string>>(val);
break;
}
case proto::AttrType::LONG: {
attr.Set<int64_t>(attr_desc.l());
break;
}
default:
// std::cout << " not support " << std::endl;
break;
}
// std::cout << "end get attr value" << std::endl;
return attr;
}
Attribute() {} Attribute() {}
template<typename T, typename... Args> template <typename T, typename... Args> Attribute &Set(Args &&... args) {
Attribute &Set(Args &&... args) { variant_.Set<T>(args...);
variant_.Set<T>(args...); return *this;
return *this; }
}
template<typename T> T &Get() const { return variant_.Get<T>(); } template <typename T> T &Get() const { return variant_.Get<T>(); }
private: private:
Variant<int, float, std::string, std::vector<int>, Variant<int, float, std::string, std::vector<int>, std::vector<float>,
std::vector<float>, std::vector<std::string>, bool, std::vector<std::string>, bool, std::vector<bool>, BlockDesc *,
std::vector<bool>, BlockDesc *, int64_t> int64_t>
variant_; variant_;
}; };
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
class AttrReader { class AttrReader {
public: public:
explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {} explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {}
template<typename T> inline T Get(const std::string &name) const { template <typename T> inline T Get(const std::string &name) const {
// PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should // PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should
// be in // be in
// AttributeMap", // AttributeMap",
// name); // name);
return ((Attribute) attrs_.at(name)).Get<T>(); return ((Attribute)attrs_.at(name)).Get<T>();
} }
private: private:
const AttributeMap &attrs_; const AttributeMap &attrs_;
}; };
} // namespace framework } // namespace framework
......
...@@ -22,28 +22,28 @@ namespace paddle_mobile { ...@@ -22,28 +22,28 @@ namespace paddle_mobile {
namespace framework { namespace framework {
std::vector<std::shared_ptr<VarDesc>> BlockDesc::Vars() const { std::vector<std::shared_ptr<VarDesc>> BlockDesc::Vars() const {
std::vector<std::shared_ptr<VarDesc>> res; std::vector<std::shared_ptr<VarDesc>> res;
for (const auto &p : vars_) { for (const auto &p : vars_) {
res.push_back(p.second); res.push_back(p.second);
} }
return res; return res;
} }
std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const { std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const {
std::vector<std::shared_ptr<OpDesc>> res; std::vector<std::shared_ptr<OpDesc>> res;
for (const auto &op : ops_) { for (const auto &op : ops_) {
res.push_back(op); res.push_back(op);
} }
return res; return res;
} }
BlockDesc::BlockDesc(const proto::BlockDesc &desc) : desc_(desc) { BlockDesc::BlockDesc(const proto::BlockDesc &desc) : desc_(desc) {
for (const proto::VarDesc &var_desc : desc_.vars()) { for (const proto::VarDesc &var_desc : desc_.vars()) {
vars_[var_desc.name()].reset(new VarDesc(var_desc)); vars_[var_desc.name()].reset(new VarDesc(var_desc));
} }
for (const proto::OpDesc &op_desc : desc_.ops()) { for (const proto::OpDesc &op_desc : desc_.ops()) {
ops_.emplace_back(new framework::OpDesc(op_desc)); ops_.emplace_back(new framework::OpDesc(op_desc));
} }
} }
} // namespace framework } // namespace framework
......
...@@ -27,32 +27,29 @@ namespace paddle_mobile { ...@@ -27,32 +27,29 @@ namespace paddle_mobile {
namespace framework { namespace framework {
class BlockDesc : PaddleMobileObject { class BlockDesc : PaddleMobileObject {
public: public:
BlockDesc(const proto::BlockDesc &desc); BlockDesc(const proto::BlockDesc &desc);
const int &ID() const { return desc_.idx(); } const int &ID() const { return desc_.idx(); }
const int &Parent() const { return desc_.parent_idx(); } const int &Parent() const { return desc_.parent_idx(); }
bool operator==( bool operator==(const paddle_mobile::framework::BlockDesc &in_block) const {
const paddle_mobile::framework::BlockDesc &in_block) const { return this->ID() == in_block.ID() &&
return this->ID() == in_block.ID() && this->Parent() == in_block.Parent();
this->Parent() == in_block.Parent(); }
}
bool operator<( bool operator<(const paddle_mobile::framework::BlockDesc &in_block) const {
const paddle_mobile::framework::BlockDesc &in_block) const { return this->ID() < in_block.ID() && this->Parent() < in_block.Parent();
return this->ID() < in_block.ID() && }
this->Parent() < in_block.Parent();
}
std::vector<std::shared_ptr<VarDesc>> Vars() const; std::vector<std::shared_ptr<VarDesc>> Vars() const;
std::vector<std::shared_ptr<OpDesc>> Ops() const; std::vector<std::shared_ptr<OpDesc>> Ops() const;
private: private:
proto::BlockDesc desc_; proto::BlockDesc desc_;
std::vector<std::shared_ptr<OpDesc>> ops_; std::vector<std::shared_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_; std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_;
}; };
} // namespace framework } // namespace framework
...@@ -60,14 +57,14 @@ private: ...@@ -60,14 +57,14 @@ private:
namespace std { namespace std {
template<> struct hash<paddle_mobile::framework::BlockDesc> { template <> struct hash<paddle_mobile::framework::BlockDesc> {
typedef paddle_mobile::framework::BlockDesc argument_type; typedef paddle_mobile::framework::BlockDesc argument_type;
typedef std::size_t result_type; typedef std::size_t result_type;
result_type operator()(argument_type const &s) const noexcept { result_type operator()(argument_type const &s) const noexcept {
result_type const h1(std::hash<int>{}(s.ID())); result_type const h1(std::hash<int>{}(s.ID()));
result_type const h2(std::hash<int>{}(s.ID())); result_type const h2(std::hash<int>{}(s.ID()));
return h1 ^ (h2 << 1); return h1 ^ (h2 << 1);
} }
}; };
} // namespace std } // namespace std
...@@ -22,42 +22,45 @@ namespace paddle_mobile { ...@@ -22,42 +22,45 @@ namespace paddle_mobile {
namespace framework { namespace framework {
enum class DataLayout { enum class DataLayout {
kNHWC = 0, kNHWC = 0,
kNCHW = 1, kNCHW = 1,
kAnyLayout = 2, kAnyLayout = 2,
}; };
inline DataLayout StringToDataLayout(const std::string &str) { inline DataLayout StringToDataLayout(const std::string &str) {
std::string s(str); std::string s(str);
for (size_t i = 0; i < s.size(); ++i) { for (size_t i = 0; i < s.size(); ++i) {
s[i] = toupper(s[i]); s[i] = toupper(s[i]);
} }
if (s == "NHWC") { if (s == "NHWC") {
return DataLayout::kNHWC; return DataLayout::kNHWC;
} else if (s == "NCHW") { } else if (s == "NCHW") {
return DataLayout::kNCHW; return DataLayout::kNCHW;
} else if (s == "ANYLAYOUT") { } else if (s == "ANYLAYOUT") {
return DataLayout::kAnyLayout; return DataLayout::kAnyLayout;
} else { } else {
// std::cout << "Unknown storage order string: %s", s; // std::cout << "Unknown storage order string: %s", s;
} }
} }
inline std::string DataLayoutToString(const DataLayout &data_layout) { inline std::string DataLayoutToString(const DataLayout &data_layout) {
switch (data_layout) { switch (data_layout) {
case DataLayout::kNHWC:return "NHWC"; case DataLayout::kNHWC:
case DataLayout::kNCHW:return "NCHW"; return "NHWC";
case DataLayout::kAnyLayout:return "ANY_LAYOUT"; case DataLayout::kNCHW:
default:break; return "NCHW";
// std::cout << "unknown DataLayou %d", data_layout; case DataLayout::kAnyLayout:
} return "ANY_LAYOUT";
default:
break;
// std::cout << "unknown DataLayou %d", data_layout;
}
} }
inline std::ostream &operator<<(std::ostream &out, inline std::ostream &operator<<(std::ostream &out, const DataLayout &l) {
const DataLayout &l) { out << DataLayoutToString(l);
out << DataLayoutToString(l); return out;
return out;
} }
} // namespace framework } // namespace framework
......
...@@ -24,68 +24,68 @@ namespace paddle_mobile { ...@@ -24,68 +24,68 @@ namespace paddle_mobile {
namespace framework { namespace framework {
static void PassTensorData(Tensor *from, Tensor *to) { static void PassTensorData(Tensor *from, Tensor *to) {
to->ShareDataWith(*from); to->ShareDataWith(*from);
*from = Tensor(); *from = Tensor();
} }
void DataTransform(const OpKernelType &expected_kernel_type, void DataTransform(const OpKernelType &expected_kernel_type,
const OpKernelType &kernel_type_for_var, const OpKernelType &kernel_type_for_var,
const Tensor &input_tensor, Tensor *output_tensor) { const Tensor &input_tensor, Tensor *output_tensor) {
bool transformed = false; bool transformed = false;
Tensor in; Tensor in;
in.ShareDataWith(input_tensor); in.ShareDataWith(input_tensor);
Tensor out; Tensor out;
// // do layout transform // // do layout transform
// if (NeedTransformLayout(expected_kernel_type.data_layout_, // if (NeedTransformLayout(expected_kernel_type.data_layout_,
// kernel_type_for_var.data_layout_)) { // kernel_type_for_var.data_layout_)) {
// TransDataLayout(kernel_type_for_var, expected_kernel_type, in, // TransDataLayout(kernel_type_for_var, expected_kernel_type, in,
// &out); // &out);
// transformed = true; // transformed = true;
// PassTensorData(&out, &in); // PassTensorData(&out, &in);
// } // }
// //
// // do data type transform // // do data type transform
// if (expected_kernel_type.data_type_ != // if (expected_kernel_type.data_type_ !=
// kernel_type_for_var.data_type_) { // kernel_type_for_var.data_type_) {
// TransDataType(kernel_type_for_var, expected_kernel_type, in, // TransDataType(kernel_type_for_var, expected_kernel_type, in,
// &out); // &out);
// transformed = true; // transformed = true;
// PassTensorData(&out, &in); // PassTensorData(&out, &in);
// } // }
// //
// // do device transform // // do device transform
// if (!platform::is_same_place(kernel_type_for_var.place_, // if (!platform::is_same_place(kernel_type_for_var.place_,
// expected_kernel_type.place_)) { // expected_kernel_type.place_)) {
// TransDataDevice(in, expected_kernel_type.place_, &out); // TransDataDevice(in, expected_kernel_type.place_, &out);
// transformed = true; // transformed = true;
// PassTensorData(&out, &in); // PassTensorData(&out, &in);
// } // }
// //
// PADDLE_ENFORCE(transformed, "No transform is applied, please // PADDLE_ENFORCE(transformed, "No transform is applied, please
// check!"); // check!");
// get output data // get output data
output_tensor->ShareDataWith(in); output_tensor->ShareDataWith(in);
} }
void CopyVariableWithTensor(const Variable &in_var, void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
const Tensor &tensor, Variable &out_var) { Variable &out_var) {
// if (in_var.IsType<LoDTensor>()) { // if (in_var.IsType<LoDTensor>()) {
// auto& in_lod_tensor = in_var.Get<LoDTensor>(); // auto& in_lod_tensor = in_var.Get<LoDTensor>();
// auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>(); // auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
// tran_lod_tensor->set_lod(in_lod_tensor.lod()); // tran_lod_tensor->set_lod(in_lod_tensor.lod());
// tran_lod_tensor->set_layout(in_lod_tensor.layout()); // tran_lod_tensor->set_layout(in_lod_tensor.layout());
// tran_lod_tensor->ShareDataWith(tensor); // tran_lod_tensor->ShareDataWith(tensor);
// } else if (in_var.IsType<SelectedRows>()) { // } else if (in_var.IsType<SelectedRows>()) {
// auto& in_selected_rows = in_var.Get<SelectedRows>(); // auto& in_selected_rows = in_var.Get<SelectedRows>();
// auto* trans_selected_rows = // auto* trans_selected_rows =
// out_var.GetMutable<SelectedRows>(); // out_var.GetMutable<SelectedRows>();
// trans_selected_rows->set_height(in_selected_rows.height()); // trans_selected_rows->set_height(in_selected_rows.height());
// trans_selected_rows->set_rows(in_selected_rows.rows()); // trans_selected_rows->set_rows(in_selected_rows.rows());
// trans_selected_rows->mutable_value()->ShareDataWith(tensor); // trans_selected_rows->mutable_value()->ShareDataWith(tensor);
// } else { // } else {
// PADDLE_THROW("unknown var type"); // PADDLE_THROW("unknown var type");
// } // }
} }
} // namespace framework } // namespace framework
......
...@@ -28,14 +28,14 @@ SOFTWARE. ...@@ -28,14 +28,14 @@ SOFTWARE.
#include "variable.h" #include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
void DataTransform(const OpKernelType &expected_kernel_type, void DataTransform(const OpKernelType &expected_kernel_type,
const OpKernelType &kernel_type_for_var, const OpKernelType &kernel_type_for_var,
const Tensor &input_tensor, Tensor *out); const Tensor &input_tensor, Tensor *out);
void CopyVariableWithTensor(const Variable &in_var, void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
const Tensor &tensor, Variable &out_var); Variable &out_var);
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -21,23 +21,23 @@ SOFTWARE. ...@@ -21,23 +21,23 @@ SOFTWARE.
#include "framework.pb.h" #include "framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
// inline proto::VarType::Type ToDataType(std::type_index type) { // inline proto::VarType::Type ToDataType(std::type_index type) {
// using namespace paddle_mobile::framework::proto; // using namespace paddle_mobile::framework::proto;
// if (typeid(float).hash_code() == type.hash_code()) { // if (typeid(float).hash_code() == type.hash_code()) {
// return proto::VarType::FP32; // return proto::VarType::FP32;
// } else if (typeid(double).hash_code() == type.hash_code()) { // } else if (typeid(double).hash_code() == type.hash_code()) {
// return proto::VarType::FP64; // return proto::VarType::FP64;
// } else if (typeid(int).hash_code() == type.hash_code()) { // } else if (typeid(int).hash_code() == type.hash_code()) {
// return proto::VarType::INT32; // return proto::VarType::INT32;
// } else if (typeid(int64_t).hash_code() == type.hash_code()) { // } else if (typeid(int64_t).hash_code() == type.hash_code()) {
// return proto::VarType::INT64; // return proto::VarType::INT64;
// } else if (typeid(bool).hash_code() == type.hash_code()) { // } else if (typeid(bool).hash_code() == type.hash_code()) {
// return proto::VarType::BOOL; // return proto::VarType::BOOL;
// } else { // } else {
//// PADDLE_THROW("Not supported"); //// PADDLE_THROW("Not supported");
// } // }
// } // }
} }
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -15,320 +15,318 @@ limitations under the License. */ ...@@ -15,320 +15,318 @@ limitations under the License. */
#include "ddim.h" #include "ddim.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
/// @cond HIDDEN /// @cond HIDDEN
template <int i> Dim<i> make_dim(const int64_t *d) { template <int i> Dim<i> make_dim(const int64_t *d) {
return Dim<i>(*d, make_dim<i - 1>(d + 1)); return Dim<i>(*d, make_dim<i - 1>(d + 1));
} }
template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); } template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); }
void make_ddim(DDim &ddim, const int64_t *dims, int n) { void make_ddim(DDim &ddim, const int64_t *dims, int n) {
switch (n) { switch (n) {
case 0: case 0:
ddim = make_dim<0>(dims); ddim = make_dim<0>(dims);
break; break;
case 1: case 1:
ddim = make_dim<1>(dims); ddim = make_dim<1>(dims);
break; break;
case 2: case 2:
ddim = make_dim<2>(dims); ddim = make_dim<2>(dims);
break; break;
case 3: case 3:
ddim = make_dim<3>(dims); ddim = make_dim<3>(dims);
break; break;
case 4: case 4:
ddim = make_dim<4>(dims); ddim = make_dim<4>(dims);
break; break;
case 5: case 5:
ddim = make_dim<5>(dims); ddim = make_dim<5>(dims);
break; break;
case 6: case 6:
ddim = make_dim<6>(dims); ddim = make_dim<6>(dims);
break; break;
case 7: case 7:
ddim = make_dim<7>(dims); ddim = make_dim<7>(dims);
break; break;
case 8: case 8:
ddim = make_dim<8>(dims); ddim = make_dim<8>(dims);
break; break;
case 9: case 9:
ddim = make_dim<9>(dims); ddim = make_dim<9>(dims);
break; break;
default: default:
// std::cout << "Dynamic dimensions must have between [1, // std::cout << "Dynamic dimensions must have between [1,
// 9] // 9]
// dimensions."; // dimensions.";
break; break;
} }
} }
/// @endcond /// @endcond
DDim make_ddim(std::initializer_list<int64_t> dims) { DDim make_ddim(std::initializer_list<int64_t> dims) {
DDim result(make_dim(0)); DDim result(make_dim(0));
make_ddim(result, dims.begin(), dims.size()); make_ddim(result, dims.begin(), dims.size());
return result; return result;
} }
DDim make_ddim(const std::vector<int64_t> &dims) { DDim make_ddim(const std::vector<int64_t> &dims) {
DDim result(make_dim(0)); DDim result(make_dim(0));
make_ddim(result, &dims[0], dims.size()); make_ddim(result, &dims[0], dims.size());
return result; return result;
} }
DDim make_ddim(const std::vector<int> &dims) { DDim make_ddim(const std::vector<int> &dims) {
std::vector<int64_t> res(dims.size()); std::vector<int64_t> res(dims.size());
std::transform(dims.begin(), dims.end(), res.begin(), std::transform(dims.begin(), dims.end(), res.begin(),
[](int d) { return static_cast<int64_t>(d); }); [](int d) { return static_cast<int64_t>(d); });
return make_ddim(res); return make_ddim(res);
} }
/// @cond HIDDEN /// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes // XXX For some reason, putting this in an anonymous namespace causes
// errors // errors
struct DynamicMutableIndexer : Vistor<int64_t &> { struct DynamicMutableIndexer : Vistor<int64_t &> {
public: public:
explicit DynamicMutableIndexer(int idx) : idx_(idx) {} explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
template <int D> int64_t &operator()(Dim<D> &dim) const { template <int D> int64_t &operator()(Dim<D> &dim) const {
return dim[idx_]; return dim[idx_];
} }
private: private:
int idx_; int idx_;
}; };
struct DynamicConstIndexer : public Vistor<int64_t> { struct DynamicConstIndexer : public Vistor<int64_t> {
public: public:
explicit DynamicConstIndexer(int idx) : idx_(idx) {} explicit DynamicConstIndexer(int idx) : idx_(idx) {}
template <int D> int64_t operator()(const Dim<D> &dim) const { template <int D> int64_t operator()(const Dim<D> &dim) const {
return dim[idx_]; return dim[idx_];
} }
private: private:
int idx_; int idx_;
}; };
/// @endcond /// @endcond
int64_t &DDim::operator[](int idx) { int64_t &DDim::operator[](int idx) {
return DDim::ApplyVistor(DynamicMutableIndexer(idx), *this); return DDim::ApplyVistor(DynamicMutableIndexer(idx), *this);
} }
int64_t DDim::operator[](int idx) const { int64_t DDim::operator[](int idx) const {
return DDim::ApplyVistor(DynamicConstIndexer(idx), *this); return DDim::ApplyVistor(DynamicConstIndexer(idx), *this);
} }
int DDim::size() const { return arity(*this); } int DDim::size() const { return arity(*this); }
bool DDim::operator==(DDim d) const { bool DDim::operator==(DDim d) const {
// if (var.which() != d.getVar().which()) { // if (var.which() != d.getVar().which()) {
// return false; // return false;
// } else { // } else {
std::vector<int64_t> v1 = vectorize(*this); std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d); std::vector<int64_t> v2 = vectorize(d);
for (unsigned int i = 0; i < v1.size(); i++) { for (unsigned int i = 0; i < v1.size(); i++) {
if (v1[i] != v2[i]) { if (v1[i] != v2[i]) {
return false; return false;
}
}
return true;
// }
}
bool DDim::operator!=(DDim d) const { return !(*this == d); }
DDim DDim::operator+(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] + v2[i]);
}
return make_ddim(v3);
}
DDim DDim::operator*(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] * v2[i]);
}
return make_ddim(v3);
} }
}
int64_t get(const DDim &ddim, int idx) { return ddim[idx]; } return true;
// }
void set(DDim &ddim, int idx, int value) { ddim[idx] = value; } }
/// @cond HIDDEN bool DDim::operator!=(DDim d) const { return !(*this == d); }
struct VectorizeVisitor : Vistor<void> {
std::vector<int64_t> &vector; DDim DDim::operator+(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {} std::vector<int64_t> v2 = vectorize(d);
template <typename T> void operator()(const T &t) { std::vector<int64_t> v3;
vector.push_back(t.head);
this->operator()(t.tail); assert(v1.size() == v2.size());
}
for (unsigned int i = 0; i < v1.size(); i++) {
void operator()(const Dim<0> &t) {} v3.push_back(v1[i] + v2[i]);
}; }
/// @endcond
return make_ddim(v3);
std::vector<int64_t> vectorize(const DDim &ddim) { }
std::vector<int64_t> result;
VectorizeVisitor visitor(result); DDim DDim::operator*(DDim d) const {
DDim::ApplyVistor(visitor, ddim); std::vector<int64_t> v1 = vectorize(*this);
return result; std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] * v2[i]);
}
return make_ddim(v3);
}
int64_t get(const DDim &ddim, int idx) { return ddim[idx]; }
void set(DDim &ddim, int idx, int value) { ddim[idx] = value; }
/// @cond HIDDEN
struct VectorizeVisitor : Vistor<void> {
std::vector<int64_t> &vector;
explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}
template <typename T> void operator()(const T &t) {
vector.push_back(t.head);
this->operator()(t.tail);
}
void operator()(const Dim<0> &t) {}
};
/// @endcond
std::vector<int64_t> vectorize(const DDim &ddim) {
std::vector<int64_t> result;
VectorizeVisitor visitor(result);
DDim::ApplyVistor(visitor, ddim);
return result;
}
// NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
std::vector<int64_t> temp = vectorize(ddim);
std::vector<int> result(temp.begin(), temp.end());
return result;
}
struct ProductVisitor : Vistor<int64_t> {
template <int D> int64_t operator()(const Dim<D> &dim) {
return product(dim);
}
};
int64_t product(const DDim &ddim) {
ProductVisitor visitor;
return DDim::ApplyVistor(visitor, ddim);
}
struct SliceVectorizeVisitor : Vistor<void> {
std::vector<int64_t> &vector;
int begin;
int end;
SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
: vector(v), begin(b), end(e) {
// PADDLE_ENFORCE(begin < end,
// "Begin index must be less than end index in
// ddim
// slice.");
// PADDLE_ENFORCE(begin >= 0,
// "Begin index can't be less than zero in
// ddim slice.");
}
template <int S> void operator()(const Dim<S> &dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
--begin;
} }
--end;
// NOTE: framework::vectorize converts to type int64_t if (end > 0) {
// which does not fit cudnn inputs. this->operator()(dim.tail);
std::vector<int> vectorize2int(const DDim &ddim) {
std::vector<int64_t> temp = vectorize(ddim);
std::vector<int> result(temp.begin(), temp.end());
return result;
} }
}
struct ProductVisitor : Vistor<int64_t> {
template <int D> int64_t operator()(const Dim<D> &dim) { void operator()(const Dim<0> &dim) {
return product(dim); // PADDLE_ENFORCE(end == 0, "End index in ddim slice is out
} // of bound.");
}; }
};
int64_t product(const DDim &ddim) {
ProductVisitor visitor; DDim slice_ddim(const DDim &ddim, int begin, int end) {
return DDim::ApplyVistor(visitor, ddim); std::vector<int64_t> vec;
} vec.reserve(end - begin);
SliceVectorizeVisitor visitor(vec, begin, end);
struct SliceVectorizeVisitor : Vistor<void> { // boost::apply_visitor(visitor, dim);
std::vector<int64_t> &vector; DDim::ApplyVistor(visitor, ddim);
int begin; // visitor(ddim.var.Get<Dim<4>>());
int end; return make_ddim(vec);
}
SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
: vector(v), begin(b), end(e) { /// \cond HIDDEN
// PADDLE_ENFORCE(begin < end,
// "Begin index must be less than end index in struct ArityVisitor : Vistor<int> {
// ddim template <int D> int operator()(Dim<D>) const { return D; }
// slice."); };
// PADDLE_ENFORCE(begin >= 0,
// "Begin index can't be less than zero in /// \endcond
// ddim slice.");
} int arity(const DDim &d) {
ArityVisitor arityVisitor = ArityVisitor();
template <int S> void operator()(const Dim<S> &dim) { return DDim::ApplyVistor(arityVisitor, d);
if (begin == 0) { // return arityVisitor(d.var.Get<Dim<4>>());
vector.push_back(dim.head); // return boost::apply_visitor(ArityVisitor(), d); }
} else { }
--begin; /// \cond HIDDEN
}
--end; /// \endcond
if (end > 0) {
this->operator()(dim.tail); struct OSVistor : Vistor<std::ostream &> {
} OSVistor(std::ostream &os) : os_(os) {}
}
template <int D> std::ostream &operator()(Dim<D> dim) const {
void operator()(const Dim<0> &dim) { return os_ << dim;
// PADDLE_ENFORCE(end == 0, "End index in ddim slice is out }
// of bound.");
} private:
}; std::ostream &os_;
};
DDim slice_ddim(const DDim &ddim, int begin, int end) {
std::vector<int64_t> vec; std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
vec.reserve(end - begin); auto vistor = OSVistor(os);
SliceVectorizeVisitor visitor(vec, begin, end); DDim::ApplyVistor(vistor, ddim);
// boost::apply_visitor(visitor, dim); return os;
DDim::ApplyVistor(visitor, ddim); }
// visitor(ddim.var.Get<Dim<4>>());
return make_ddim(vec); DDim::DDim(std::initializer_list<int64_t> init_list) {
} *this = make_ddim(init_list);
}
/// \cond HIDDEN
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
struct ArityVisitor : Vistor<int> { int rank = src.size();
template <int D> int operator()(Dim<D>) const { return D; } return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
}; product(slice_ddim(src, num_col_dims, rank))});
}
/// \endcond
DDim flatten_to_1d(const DDim &src) { return make_ddim({product(src)}); }
int arity(const DDim &d) {
ArityVisitor arityVisitor = ArityVisitor(); DDim stride(const DDim &ddim) {
return DDim::ApplyVistor(arityVisitor, d); std::vector<int64_t> strides(ddim.size());
// return arityVisitor(d.var.Get<Dim<4>>()); strides[ddim.size() - 1] = 1;
// return boost::apply_visitor(ArityVisitor(), d); } for (int i = ddim.size() - 2; i >= 0; --i) {
} strides[i] = strides[i + 1] * ddim[i + 1];
/// \cond HIDDEN }
return framework::make_ddim(strides);
/// \endcond }
struct OSVistor : Vistor<std::ostream &> { DDim stride_numel(const framework::DDim &ddim) {
OSVistor(std::ostream &os) : os_(os) {} std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
template <int D> std::ostream &operator()(Dim<D> dim) const { for (int i = ddim.size() - 2; i >= 0; --i) {
return os_ << dim; strides[i] = strides[i + 1] * ddim[i];
} }
return framework::make_ddim(strides);
private: }
std::ostream &os_;
}; } // namespace framework
std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
auto vistor = OSVistor(os);
DDim::ApplyVistor(vistor, ddim);
return os;
}
DDim::DDim(std::initializer_list<int64_t> init_list) {
*this = make_ddim(init_list);
}
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
int rank = src.size();
return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, rank))});
}
DDim flatten_to_1d(const DDim &src) {
return make_ddim({product(src)});
}
DDim stride(const DDim &ddim) {
std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1];
}
return framework::make_ddim(strides);
}
DDim stride_numel(const framework::DDim &ddim) {
std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i];
}
return framework::make_ddim(strides);
}
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,145 +22,142 @@ limitations under the License. */ ...@@ -22,145 +22,142 @@ limitations under the License. */
#include <vector> #include <vector>
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
/** /**
* \brief A dynamically sized dimension. * \brief A dynamically sized dimension.
* *
* The number of dimensions must be between [1, 9]. * The number of dimensions must be between [1, 9].
*/ */
struct DDim { struct DDim {
typedef Variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, typedef Variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
Dim<6>, Dim<7>, Dim<8>, Dim<9>> Dim<7>, Dim<8>, Dim<9>>
DDimVar; DDimVar;
DDimVar var; DDimVar var;
template <typename Vistor> template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) {
const DDim &d) { if (d.var.TypeId() == typeid(Dim<0>).hash_code()) {
if (d.var.TypeId() == typeid(Dim<0>).hash_code()) { return vistor(d.var.Get<Dim<0>>());
return vistor(d.var.Get<Dim<0>>()); } else if (d.var.TypeId() == typeid(Dim<1>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<1>).hash_code()) { return vistor(d.var.Get<Dim<1>>());
return vistor(d.var.Get<Dim<1>>()); } else if (d.var.TypeId() == typeid(Dim<2>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<2>).hash_code()) { return vistor(d.var.Get<Dim<2>>());
return vistor(d.var.Get<Dim<2>>()); } else if (d.var.TypeId() == typeid(Dim<3>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<3>).hash_code()) { return vistor(d.var.Get<Dim<3>>());
return vistor(d.var.Get<Dim<3>>()); } else if (d.var.TypeId() == typeid(Dim<4>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<4>).hash_code()) { return vistor(d.var.Get<Dim<4>>());
return vistor(d.var.Get<Dim<4>>()); } else if (d.var.TypeId() == typeid(Dim<5>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<5>).hash_code()) { return vistor(d.var.Get<Dim<5>>());
return vistor(d.var.Get<Dim<5>>()); } else if (d.var.TypeId() == typeid(Dim<6>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<6>).hash_code()) { return vistor(d.var.Get<Dim<6>>());
return vistor(d.var.Get<Dim<6>>()); } else if (d.var.TypeId() == typeid(Dim<7>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<7>).hash_code()) { return vistor(d.var.Get<Dim<7>>());
return vistor(d.var.Get<Dim<7>>()); } else if (d.var.TypeId() == typeid(Dim<8>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<8>).hash_code()) { return vistor(d.var.Get<Dim<8>>());
return vistor(d.var.Get<Dim<8>>()); } else if (d.var.TypeId() == typeid(Dim<9>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<9>).hash_code()) { return vistor(d.var.Get<Dim<9>>());
return vistor(d.var.Get<Dim<9>>()); } else {
} else { printf(" dim not support \n");
printf(" dim not support \n"); throw std::bad_exception();
throw std::bad_exception(); // return typename Vistor::type_t();
// return typename Vistor::type_t(); }
} }
}
DDim() { var.Set<Dim<1>>(Dim<1>()); }
DDim() { var.Set<Dim<1>>(Dim<1>()); }
template <int D> explicit DDim(const Dim<D> &in) { var.Set<Dim<D>>(in); }
template <int D> explicit DDim(const Dim<D> &in) {
var.Set<Dim<D>>(in); /*implicit*/ DDim(std::initializer_list<int64_t> init_list);
}
template <int D> DDim &operator=(const Dim<D> &in) {
/*implicit*/ DDim(std::initializer_list<int64_t> init_list); var.Set<Dim<D>>(in);
return *this;
template <int D> DDim &operator=(const Dim<D> &in) { }
var.Set<Dim<D>>(in);
return *this; int64_t &operator[](int idx);
}
int64_t operator[](int idx) const;
int64_t &operator[](int idx);
// template <typename Visitor>
int64_t operator[](int idx) const; // typename Visitor::result_type apply_visitor(Visitor& visitor) {
// return var.apply_visitor(visitor);
// template <typename Visitor> // }
// typename Visitor::result_type apply_visitor(Visitor& visitor) { //
// return var.apply_visitor(visitor); // template <typename Visitor>
// } // typename Visitor::result_type apply_visitor(Visitor& visitor)
// // const {
// template <typename Visitor> // return var.apply_visitor(visitor);
// typename Visitor::result_type apply_visitor(Visitor& visitor) // }
// const {
// return var.apply_visitor(visitor); DDimVar getVar() { return var; }
// }
bool operator==(DDim d) const;
DDimVar getVar() { return var; }
bool operator!=(DDim d) const;
bool operator==(DDim d) const;
DDim operator+(DDim d) const;
bool operator!=(DDim d) const;
DDim operator*(DDim d) const;
DDim operator+(DDim d) const;
int size() const;
DDim operator*(DDim d) const; };
int size() const; /**
}; * \brief Make a DDim from std::vector<int64_t>
*
/** * \param dims An vector of ints. Must be sized between [1, 9]
* \brief Make a DDim from std::vector<int64_t> */
* DDim make_ddim(const std::vector<int64_t> &dims);
* \param dims An vector of ints. Must be sized between [1, 9]
*/ DDim make_ddim(const std::vector<int> &dims);
DDim make_ddim(const std::vector<int64_t> &dims);
/**
DDim make_ddim(const std::vector<int> &dims); * \brief Make a DDim from an initializer list
*
/** * \param dims An initializer list of ints. Must be sized between [1, 9]
* \brief Make a DDim from an initializer list *
* */
* \param dims An initializer list of ints. Must be sized between [1, 9] DDim make_ddim(std::initializer_list<int64_t> dims);
*
*/ int64_t get(const DDim &dim, int idx);
DDim make_ddim(std::initializer_list<int64_t> dims);
void set(DDim &dim, int idx, int val);
int64_t get(const DDim &dim, int idx);
std::vector<int64_t> vectorize(const DDim &ddim);
void set(DDim &dim, int idx, int val);
std::vector<int> vectorize2int(const DDim &ddim);
std::vector<int64_t> vectorize(const DDim &ddim);
int64_t product(const DDim &ddim);
std::vector<int> vectorize2int(const DDim &ddim);
/**
int64_t product(const DDim &ddim); * \brief Slice a ddim
*
/** * Slice dim with [begin, end).
* \brief Slice a ddim * e.g. DDim d = make_ddim({1,2,3,4,5});
* * slice_ddim(d, 1, 3); ====> {2,3}
* Slice dim with [begin, end). */
* e.g. DDim d = make_ddim({1,2,3,4,5}); DDim slice_ddim(const DDim &dim, int begin, int end);
* slice_ddim(d, 1, 3); ====> {2,3}
*/ /**
DDim slice_ddim(const DDim &dim, int begin, int end); * \brief What is the length of this dimension?
*
/** * \param Dynamic dimension to inspect
* \brief What is the length of this dimension? */
*
* \param Dynamic dimension to inspect
*/
int arity(const DDim &ddim); int arity(const DDim &ddim);
std::ostream &operator<<(std::ostream &, const DDim &); std::ostream &operator<<(std::ostream &, const DDim &);
// Reshape a tensor to a matrix. The matrix's first dimension(column // Reshape a tensor to a matrix. The matrix's first dimension(column
// length) // length)
// will be the product of tensor's first `num_col_dims` dimensions. // will be the product of tensor's first `num_col_dims` dimensions.
DDim flatten_to_2d(const DDim &src, int num_col_dims); DDim flatten_to_2d(const DDim &src, int num_col_dims);
DDim flatten_to_1d(const DDim &src); DDim flatten_to_1d(const DDim &src);
DDim stride(const DDim &ddim); DDim stride(const DDim &ddim);
DDim stride_numel(const DDim &ddim); DDim stride_numel(const DDim &ddim);
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
...@@ -23,75 +23,72 @@ SOFTWARE. ...@@ -23,75 +23,72 @@ SOFTWARE.
#include "variable.h" #include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> template <typename Dtype>
Executor<Dtype>::Executor(const Program<Dtype> p) : program_(p) { Executor<Dtype>::Executor(const Program<Dtype> p) : program_(p) {
if (use_optimize_) { if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram; to_predict_program_ = program_.optimizeProgram;
} else { } else {
to_predict_program_ = program_.originProgram; to_predict_program_ = program_.originProgram;
} }
const std::vector<std::shared_ptr<BlockDesc>> blocks = const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks(); to_predict_program_->Blocks();
for (int i = 0; i < blocks.size(); ++i) { for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i]; std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops(); std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (int j = 0; j < ops.size(); ++j) { for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j]; std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "conv2d" && if (op->Type() == "conv2d" && op->Input("Input")[0] == "pixel") {
op->Input("Input")[0] == "pixel") { Attribute strides_attr = op->GetAttrMap().at("strides");
Attribute strides_attr = op->GetAttrMap().at("strides"); std::vector<int> stride = strides_attr.Get<std::vector<int>>();
std::vector<int> stride = for (int k = 0; k < stride.size(); ++k) {
strides_attr.Get<std::vector<int>>();
for (int k = 0; k < stride.size(); ++k) {
}
std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
std::make_shared<operators::ConvOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(conv);
}
} }
std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
std::make_shared<operators::ConvOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(conv);
} }
} }
}
}
template <typename Dtype> template <typename Dtype>
std::shared_ptr<Tensor> Executor<Dtype>::predict(Tensor &t) { std::shared_ptr<Tensor> Executor<Dtype>::predict(Tensor &t) {
// feed // feed
auto scope = program_.scope; auto scope = program_.scope;
Variable *g_feed_value = scope->Var("pixel"); Variable *g_feed_value = scope->Var("pixel");
auto tensor = g_feed_value->GetMutable<Tensor>(); auto tensor = g_feed_value->GetMutable<Tensor>();
tensor->ShareDataWith(t); tensor->ShareDataWith(t);
Variable *con_output = scope->Var("conv2d_0.tmp_0"); Variable *con_output = scope->Var("conv2d_0.tmp_0");
Tensor *output_tensor = con_output->GetMutable<Tensor>(); Tensor *output_tensor = con_output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({1, 16, 32, 32}); output_tensor->mutable_data<float>({1, 16, 32, 32});
// std::cout << typeid(output_tensor).name() << std::endl; // std::cout << typeid(output_tensor).name() << std::endl;
// std::cout << "output_tensor dims: " << output_tensor->dims() << // std::cout << "output_tensor dims: " << output_tensor->dims() <<
// std::endl; // std::endl;
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>(); std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor); out_tensor.reset(output_tensor);
predict(t, 0); predict(t, 0);
return out_tensor; return out_tensor;
} }
template <typename Dtype> template <typename Dtype>
void Executor<Dtype>::predict(const Tensor &t, int block_id) { void Executor<Dtype>::predict(const Tensor &t, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block = std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id); to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
++j) { auto op = ops_of_block_[*to_predict_block.get()][j];
auto op = ops_of_block_[*to_predict_block.get()][j]; // std::cout << "开始run" << std::endl;
// std::cout << "开始run" << std::endl; op->Run();
op->Run(); }
} }
}
template class Executor<CPU>; template class Executor<CPU>;
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -32,22 +32,22 @@ SOFTWARE. ...@@ -32,22 +32,22 @@ SOFTWARE.
#include "variable.h" #include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> class Executor { template <typename Dtype> class Executor {
public: public:
Executor(const Program<Dtype> p); Executor(const Program<Dtype> p);
std::shared_ptr<Tensor> predict(Tensor &t); std::shared_ptr<Tensor> predict(Tensor &t);
private: private:
const framework::Program<Dtype> program_; const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_; std::shared_ptr<ProgramDesc> to_predict_program_;
void predict(const Tensor &t, int block_id); void predict(const Tensor &t, int block_id);
std::map<framework::BlockDesc, std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>> std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_; ops_of_block_;
bool use_optimize_ = false; bool use_optimize_ = false;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -23,190 +23,186 @@ limitations under the License. */ ...@@ -23,190 +23,186 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
/* /*
* LoD is short for Level of Details. * LoD is short for Level of Details.
* *
* - in a level, each element indicates relative offset of the lower * - in a level, each element indicates relative offset of the lower
* level * level
* - the first element should be 0 and that indicates that this sequence * - the first element should be 0 and that indicates that this sequence
* start * start
* from 0 * from 0
* - each sequence's begin and end(no-inclusive) is level[id, id+1] * - each sequence's begin and end(no-inclusive) is level[id, id+1]
* *
* For example: * For example:
* 3-level LoD stores * 3-level LoD stores
* *
* 0 2 3 * 0 2 3
* 0 2 4 7 * 0 2 4 7
* 0 2 5 7 10 12 15 20 * 0 2 5 7 10 12 15 20
*/ */
using LoD = std::vector<std::vector<size_t>>; using LoD = std::vector<std::vector<size_t>>;
std::ostream &operator<<(std::ostream &os, const LoD &lod); std::ostream &operator<<(std::ostream &os, const LoD &lod);
std::ostream &operator<<(std::ostream &os, const LoDTensor &t); std::ostream &operator<<(std::ostream &os, const LoDTensor &t);
std::string LoDToString(const LoD &lod); std::string LoDToString(const LoD &lod);
LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin, LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
size_t elem_end); size_t elem_end);
/* /*
* Transform an LoD from relative offsets to absolute offsets. * Transform an LoD from relative offsets to absolute offsets.
*/ */
LoD ToAbsOffset(const LoD &in); LoD ToAbsOffset(const LoD &in);
bool operator==(const LoD &a, const LoD &b); bool operator==(const LoD &a, const LoD &b);
/* /*
* Check whether this lod's format is valid. * Check whether this lod's format is valid.
* *
* ATTENTION: * ATTENTION:
* - Empty lod is treated as valid. * - Empty lod is treated as valid.
* *
* It will check two things: * It will check two things:
* *
* 1. all the offsets in a level should be ascending(no same items * 1. all the offsets in a level should be ascending(no same items
* allows). * allows).
* 2. there should be more than 2 offsets existing in each level. * 2. there should be more than 2 offsets existing in each level.
* 3. the higher level's last offset should equals the lower level's * 3. the higher level's last offset should equals the lower level's
* size-1. * size-1.
* 4. the first offset(the begin offset) of each level should be 0. * 4. the first offset(the begin offset) of each level should be 0.
* 5. the lowest level's last offset should equals `tensor_height` if * 5. the lowest level's last offset should equals `tensor_height` if
* tensor_height>0. * tensor_height>0.
*/ */
bool CheckLoD(const LoD &in, int tensor_height = -1); bool CheckLoD(const LoD &in, int tensor_height = -1);
/* /*
* Check whether this absolute lod's format is valid. * Check whether this absolute lod's format is valid.
* *
* ATTENTION: * ATTENTION:
* - Empty lod is treated as valid. * - Empty lod is treated as valid.
* *
* It will check two things: * It will check two things:
* 1. all the offsets in a level should be ascending(no same items * 1. all the offsets in a level should be ascending(no same items
* allows) * allows)
* 2. there should be more than 2 offsets existing in each level. * 2. there should be more than 2 offsets existing in each level.
* 3. the first offset of each level should be 0, and the last should * 3. the first offset of each level should be 0, and the last should
* be the * be the
* same(the height of underlying tensor) or `tensor_height` if * same(the height of underlying tensor) or `tensor_height` if
* tensor_height>0. * tensor_height>0.
*/ */
bool CheckAbsLoD(const LoD &in, int tensor_height = -1); bool CheckAbsLoD(const LoD &in, int tensor_height = -1);
/* /*
* LoDTensor (Level of details Tensor) * LoDTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference. * see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/ */
class LoDTensor : public Tensor { class LoDTensor : public Tensor {
public: public:
LoDTensor() : Tensor() {} LoDTensor() : Tensor() {}
explicit LoDTensor(const LoD &lod) : lod_(lod) {} explicit LoDTensor(const LoD &lod) : lod_(lod) {}
void set_lod(const LoD &lod) { lod_ = lod; } void set_lod(const LoD &lod) { lod_ = lod; }
const LoD &lod() const { return lod_; } const LoD &lod() const { return lod_; }
LoD *mutable_lod() { return &lod_; } LoD *mutable_lod() { return &lod_; }
/* /*
* Get the start offset and end offset of an element from LoD. * Get the start offset and end offset of an element from LoD.
*/ */
std::pair<size_t, size_t> lod_element(size_t level, std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const {
size_t elem) const { // PADDLE_ENFORCE_LT(level, NumLevels());
// PADDLE_ENFORCE_LT(level, NumLevels()); // PADDLE_ENFORCE_LT(elem, NumElements(level));
// PADDLE_ENFORCE_LT(elem, NumElements(level)); return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]);
return std::make_pair((lod_)[level][elem], }
(lod_)[level][elem + 1]);
} /*
* Number of LoDTensor's levels, each level has units of data, for
/* * example,
* Number of LoDTensor's levels, each level has units of data, for * in the sentence's view, article, paragraph, sentence are 3
* example, * levels.
* in the sentence's view, article, paragraph, sentence are 3 */
* levels. size_t NumLevels() const { return lod_.size(); }
*/
size_t NumLevels() const { return lod_.size(); } /*
* Number of elements in a level.
/* */
* Number of elements in a level. size_t NumElements(size_t level = 0) const {
*/ // PADDLE_ENFORCE_LT(level, NumLevels());
size_t NumElements(size_t level = 0) const { // the last offset is the end of last element
// PADDLE_ENFORCE_LT(level, NumLevels()); return (lod_)[level].size() - 1;
// the last offset is the end of last element }
return (lod_)[level].size() - 1;
} private:
LoD lod_;
private: };
LoD lod_;
}; /*
* Expand the `source` to fit the LoD of `lod`. For example, a `source`
/* * LoDTensor is
* Expand the `source` to fit the LoD of `lod`. For example, a `source` * - LoD: [0, 2]
* LoDTensor is * - tensor: [a0, a1]
* - LoD: [0, 2] * a `lod` is
* - tensor: [a0, a1] * - LoD: [0 3 5]
* a `lod` is * returns a new LoDTensor
* - LoD: [0 3 5] * - [a0 a0 a0 a1 a1]
* returns a new LoDTensor */
* - [a0 a0 a0 a1 a1] template <typename T>
*/ LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, size_t level) {
template <typename T> LoD abs_lod = ToAbsOffset(lod);
LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, const auto &lod_level = lod[level];
size_t level) { size_t num_instances = source.dims()[0];
LoD abs_lod = ToAbsOffset(lod);
const auto &lod_level = lod[level]; // new tensor
size_t num_instances = source.dims()[0]; LoDTensor tensor;
tensor.set_lod(lod);
// new tensor auto dims = source.dims();
LoDTensor tensor; dims[0] = lod_level.back();
tensor.set_lod(lod); tensor.Resize(dims);
auto dims = source.dims(); tensor.mutable_data<T>();
dims[0] = lod_level.back();
tensor.Resize(dims); // PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
tensor.mutable_data<T>(); for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
// PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1); auto slice = tensor.Slice(elem, elem + 1);
for (size_t ins = 0; ins < num_instances; ins++) { TensorCopy(source.Slice(ins, ins + 1), &slice);
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1];
elem++) {
auto slice = tensor.Slice(elem, elem + 1);
TensorCopy(source.Slice(ins, ins + 1), &slice);
}
}
return tensor;
} }
}
// Get the absolute offset of a lod[start_level][start_idx:end_idx] and return tensor;
// relative length of details for every levels(i.e., [start_level: ]). }
//
// For example, // Get the absolute offset of a lod[start_level][start_idx:end_idx] and
// lod = [[0, 3, 4, 8], [0, 9, 10, 11, 13, 17, 19, 22, 24]] // relative length of details for every levels(i.e., [start_level: ]).
// start_level = 0 //
// start_idx = 1 // For example,
// end_idx = 3 // lod = [[0, 3, 4, 8], [0, 9, 10, 11, 13, 17, 19, 22, 24]]
// // start_level = 0
// Returns: // start_idx = 1
// LoD = [[1, 4], [2, 4, 2, 3, 2]] // end_idx = 3
// pair<size_t, size_t> = {11, 24} //
std::pair<LoD, std::pair<size_t, size_t>> // Returns:
GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, // LoD = [[1, 4], [2, 4, 2, 3, 2]]
size_t end_idx, size_t start_level); // pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>>
void AppendLoD(LoD *lod, const LoD &lod_length); GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, size_t end_idx,
size_t start_level);
/*
* Serialize/Desiralize LoDTensor to std::ostream void AppendLoD(LoD *lod, const LoD &lod_length);
* You can pass ofstream or ostringstream to serilize to file
* or to a in memory string. GPU tensor will be copied to CPU. /*
*/ * Serialize/Desiralize LoDTensor to std::ostream
void SerializeToStream(std::ostream &os, const LoDTensor &tensor); * You can pass ofstream or ostringstream to serilize to file
* or to a in memory string. GPU tensor will be copied to CPU.
void DeserializeFromStream(std::istream &is, LoDTensor *tensor); */
void SerializeToStream(std::ostream &os, const LoDTensor &tensor);
} // namespace framework
void DeserializeFromStream(std::istream &is, LoDTensor *tensor);
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -5,58 +5,55 @@ ...@@ -5,58 +5,55 @@
#include "op_desc.h" #include "op_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
OpDesc::OpDesc(const proto::OpDesc &desc) : desc_(desc) { OpDesc::OpDesc(const proto::OpDesc &desc) : desc_(desc) {
for (int i = 0; i < desc_.inputs_size(); ++i) { for (int i = 0; i < desc_.inputs_size(); ++i) {
const proto::OpDesc::Var &var = desc_.inputs(i); const proto::OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()]; std::vector<std::string> &args = inputs_[var.parameter()];
int arg_size = var.arguments_size(); int arg_size = var.arguments_size();
for (int j = 0; j < arg_size; ++j) { for (int j = 0; j < arg_size; ++j) {
args.push_back(var.arguments(j)); args.push_back(var.arguments(j));
}
}
for (int i = 0; i < desc_.outputs_size(); ++i) {
const proto::OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int arg_size = var.arguments_size();
for (int j = 0; j < arg_size; ++j) {
args.push_back(var.arguments(j));
}
}
for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = Attribute::GetAttrValue(attr);
// if (attr.type() == proto::AttrType::INT){
// std::cout << " attrName " << attr_name << " " <<
// attrs_[attr_name].Get<int>() << std::endl;
// }
}
}
} }
}
const std::vector<std::string> &
OpDesc::Input(const std::string &name) const { for (int i = 0; i < desc_.outputs_size(); ++i) {
return inputs_.find(name)->second; const proto::OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int arg_size = var.arguments_size();
for (int j = 0; j < arg_size; ++j) {
args.push_back(var.arguments(j));
} }
}
const std::vector<std::string> &
OpDesc::Output(const std::string &name) const { for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
return outputs_.find(name)->second; std::string attr_name = attr.name();
if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = Attribute::GetAttrValue(attr);
// if (attr.type() == proto::AttrType::INT){
// std::cout << " attrName " << attr_name << " " <<
// attrs_[attr_name].Get<int>() << std::endl;
// }
} }
}
}
Attribute OpDesc::GetAttr(const std::string &name) const { const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
auto it = attrs_.find(name); return inputs_.find(name)->second;
return it->second; }
}
const std::unordered_map<std::string, Attribute> & const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
OpDesc::GetAttrMap() const { return outputs_.find(name)->second;
return attrs_; }
}
Attribute OpDesc::GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
return it->second;
}
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
return attrs_;
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -23,31 +23,29 @@ SOFTWARE. ...@@ -23,31 +23,29 @@ SOFTWARE.
#include "paddle_mobile_object.h" #include "paddle_mobile_object.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class OpDesc : PaddleMobileObject { class OpDesc : PaddleMobileObject {
public: public:
OpDesc(const proto::OpDesc &desc); OpDesc(const proto::OpDesc &desc);
const std::vector<std::string> & const std::vector<std::string> &Input(const std::string &name) const;
Input(const std::string &name) const; const std::vector<std::string> &Output(const std::string &name) const;
const std::vector<std::string> & Attribute GetAttr(const std::string &name) const;
Output(const std::string &name) const;
Attribute GetAttr(const std::string &name) const;
const VariableNameMap &GetInputs() { return inputs_; } const VariableNameMap &GetInputs() { return inputs_; }
const VariableNameMap &GetOutputs() { return outputs_; } const VariableNameMap &GetOutputs() { return outputs_; }
const AttributeMap &GetAttrMap() const; const AttributeMap &GetAttrMap() const;
const std::string &Type() { return desc_.type(); }; const std::string &Type() { return desc_.type(); };
private: private:
proto::OpDesc desc_; proto::OpDesc desc_;
VariableNameMap inputs_; VariableNameMap inputs_;
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,74 +22,73 @@ SOFTWARE. ...@@ -22,74 +22,73 @@ SOFTWARE.
#include "framework.pb.h" #include "framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> struct OpInfo { template <typename Dtype> struct OpInfo {
OpCreator<Dtype> creator_; OpCreator<Dtype> creator_;
const OpCreator<Dtype> &Creator() const { const OpCreator<Dtype> &Creator() const {
// PADDLE_ENFORCE_NOT_NULL(creator_, // PADDLE_ENFORCE_NOT_NULL(creator_,
// "Operator Creator has not been // "Operator Creator has not been
// registered"); // registered");
return creator_; return creator_;
} }
}; };
template <typename Dtype> class OpInfoMap; template <typename Dtype> class OpInfoMap;
template <typename Dtype> template <typename Dtype> static OpInfoMap<Dtype> *g_op_info_map = nullptr;
static OpInfoMap<Dtype> *g_op_info_map = nullptr;
template <typename Dtype> class OpInfoMap {
template <typename Dtype> class OpInfoMap { public:
public: static OpInfoMap &Instance() {
static OpInfoMap &Instance() { if (g_op_info_map<Dtype> == nullptr) {
if (g_op_info_map<Dtype> == nullptr) { g_op_info_map<Dtype> = new OpInfoMap();
g_op_info_map<Dtype> = new OpInfoMap(); }
} return *g_op_info_map<Dtype>;
return *g_op_info_map<Dtype>; };
};
bool Has(const std::string &op_type) const {
bool Has(const std::string &op_type) const { return map_.find(op_type) != map_.end();
return map_.find(op_type) != map_.end(); }
}
void Insert(const std::string &type, const OpInfo<Dtype> &info) {
void Insert(const std::string &type, const OpInfo<Dtype> &info) { // PADDLE_ENFORCE(!Has(type), "Operator %s has been
// PADDLE_ENFORCE(!Has(type), "Operator %s has been // registered", type);
// registered", type); map_.insert({type, info});
map_.insert({type, info}); }
}
const OpInfo<Dtype> &Get(const std::string &type) const {
const OpInfo<Dtype> &Get(const std::string &type) const { auto op_info_ptr = GetNullable(type);
auto op_info_ptr = GetNullable(type); // PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not
// PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not // been
// been // registered",
// registered", // type);
// type); return *op_info_ptr;
return *op_info_ptr; }
}
const OpInfo<Dtype> *GetNullable(const std::string &type) const {
const OpInfo<Dtype> *GetNullable(const std::string &type) const { auto it = map_.find(type);
auto it = map_.find(type); if (it == map_.end()) {
if (it == map_.end()) { return nullptr;
return nullptr; } else {
} else { return &it->second;
return &it->second; }
} }
}
const std::unordered_map<std::string, OpInfo<Dtype>> &map() const {
const std::unordered_map<std::string, OpInfo<Dtype>> &map() const { return map_;
return map_; }
}
std::unordered_map<std::string, OpInfo<Dtype>> *mutable_map() {
std::unordered_map<std::string, OpInfo<Dtype>> *mutable_map() { return &map_;
return &map_; }
}
private:
private: OpInfoMap() = default;
OpInfoMap() = default; std::unordered_map<std::string, OpInfo<Dtype>> map_;
std::unordered_map<std::string, OpInfo<Dtype>> map_;
// DISABLE_COPY_AND_ASSIGN(OpInfoMap);
// DISABLE_COPY_AND_ASSIGN(OpInfoMap); };
};
} // namespace framework
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,51 +22,44 @@ SOFTWARE. ...@@ -22,51 +22,44 @@ SOFTWARE.
#include "framework.pb.h" #include "framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
struct OpKernelType { struct OpKernelType {
struct Hash { struct Hash {
size_t operator()(const OpKernelType &key) const { size_t operator()(const OpKernelType &key) const {
int data_type = static_cast<int>(key.data_type_) int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
<< LEFT_SHIFT; int data_layout = static_cast<int>(key.data_layout_)
int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);
<< (LEFT_SHIFT * 2);
std::hash<int> hasher; std::hash<int> hasher;
return hasher(data_type + data_layout); return hasher(data_type + data_layout);
} }
}; };
// place, data_type, library_type kinds less than 2^8 // place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8; constexpr static int LEFT_SHIFT = 8;
proto::VarType::Type data_type_; proto::VarType::Type data_type_;
DataLayout data_layout_; DataLayout data_layout_;
OpKernelType(proto::VarType::Type data_type, OpKernelType(proto::VarType::Type data_type,
DataLayout data_layout = DataLayout::kAnyLayout) DataLayout data_layout = DataLayout::kAnyLayout)
: data_type_(data_type), data_layout_(data_layout) {} : data_type_(data_type), data_layout_(data_layout) {}
bool operator==(const OpKernelType &o) const { bool operator==(const OpKernelType &o) const {
return data_type_ == o.data_type_ && return data_type_ == o.data_type_ && data_layout_ == o.data_layout_;
data_layout_ == o.data_layout_; }
}
bool operator!=(const OpKernelType &o) const { bool operator!=(const OpKernelType &o) const { return !(*this == o); }
return !(*this == o); };
}
};
inline bool NeedTransformLayout(const DataLayout &l, inline bool NeedTransformLayout(const DataLayout &l, const DataLayout &r) {
const DataLayout &r) { return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && }
l != r;
}
inline bool TransFromNeeded(const OpKernelType &l, inline bool TransFromNeeded(const OpKernelType &l, const OpKernelType &r) {
const OpKernelType &r) { return (l.data_type_ != r.data_type_) ||
return (l.data_type_ != r.data_type_) || NeedTransformLayout(l.data_layout_, r.data_layout_);
NeedTransformLayout(l.data_layout_, r.data_layout_); }
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -19,8 +19,8 @@ SOFTWARE. ...@@ -19,8 +19,8 @@ SOFTWARE.
#pragma once #pragma once
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {}; class OpProtoAndCheckerMaker {};
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -20,23 +20,23 @@ SOFTWARE. ...@@ -20,23 +20,23 @@ SOFTWARE.
#include "op_info.h" #include "op_info.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> template <typename Dtype>
OperatorBase<Dtype>::OperatorBase(const std::string &type, OperatorBase<Dtype>::OperatorBase(const std::string &type,
const VariableNameMap &inputs, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope) std::shared_ptr<Scope> scope)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs), : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs),
scope_(scope) { scope_(scope) {
CheckAllInputOutputSet(); CheckAllInputOutputSet();
} }
template <typename Dtype> template <typename Dtype>
void OperatorBase<Dtype>::CheckAllInputOutputSet() const {} void OperatorBase<Dtype>::CheckAllInputOutputSet() const {}
template class OperatorBase<CPU>; template class OperatorBase<CPU>;
template class OperatorWithKernel<CPU>; template class OperatorWithKernel<CPU>;
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -33,68 +33,64 @@ SOFTWARE. ...@@ -33,68 +33,64 @@ SOFTWARE.
#include "variable.h" #include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
static std::unordered_map<std::string, std::vector<std::string>> static std::unordered_map<std::string, std::vector<std::string>>
op_input_output_key = { op_input_output_key = {
{"conv2d", {"Input", "Output"}}, {"relu", {"X", "Out"}}, {"conv2d", {"Input", "Output"}}, {"relu", {"X", "Out"}},
{"softmax", {"X", "Out"}}, {"mul", {"X", "Out"}}, {"softmax", {"X", "Out"}}, {"mul", {"X", "Out"}},
{"elementwise_add", {"X", "Out"}}, {"pool2d", {"X", "Out"}}, {"elementwise_add", {"X", "Out"}}, {"pool2d", {"X", "Out"}},
{"batch_norm", {"X", "Y"}}, {"lrn", {"X", "Out"}}, {"batch_norm", {"X", "Y"}}, {"lrn", {"X", "Out"}},
{"concat", {"X", "Out"}}, {"concat", {"X", "Out"}},
}; };
template <typename Dtype> class OperatorBase : PaddleMobileObject { template <typename Dtype> class OperatorBase : PaddleMobileObject {
public: public:
OperatorBase(const std::string &type, const VariableNameMap &inputs, OperatorBase(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs, const AttributeMap &attrs,
const AttributeMap &attrs, std::shared_ptr<Scope> scope);
std::shared_ptr<Scope> scope); virtual ~OperatorBase() {}
virtual ~OperatorBase() {} virtual void Run() const = 0;
virtual void Run() const = 0;
const VariableNameMap &Inputs() const { return inputs_; } const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; } const VariableNameMap &Outputs() const { return outputs_; }
const std::string &Type() const { return type_; } const std::string &Type() const { return type_; }
const AttributeMap &Attrs() const { return attrs_; } const AttributeMap &Attrs() const { return attrs_; }
void ClearVariables() const { void ClearVariables() const {
if (this->scope_) { if (this->scope_) {
this->scope_->EraseVars(this->inputs_.at("Filter")); this->scope_->EraseVars(this->inputs_.at("Filter"));
this->scope_->EraseVars(this->inputs_.at("Input")); this->scope_->EraseVars(this->inputs_.at("Input"));
} }
} }
protected: protected:
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
std::string type_; std::string type_;
VariableNameMap inputs_; VariableNameMap inputs_;
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
private: private:
void CheckAllInputOutputSet() const; void CheckAllInputOutputSet() const;
}; };
template <typename Dtype> template <typename Dtype>
class OperatorWithKernel : public OperatorBase<Dtype> { class OperatorWithKernel : public OperatorBase<Dtype> {
public: public:
OperatorWithKernel(const std::string &type, OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &inputs, const VariableNameMap &outputs,
const VariableNameMap &outputs, const AttributeMap &attrs, std::shared_ptr<Scope> scope)
const AttributeMap &attrs, : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {}
std::shared_ptr<Scope> scope) virtual void InferShape() const = 0;
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {} virtual void Run() const = 0;
virtual void InferShape() const = 0; };
virtual void Run() const = 0;
};
template <typename Dtype, typename P> template <typename Dtype, typename P> class OpKernelBase : PaddleMobileObject {
class OpKernelBase : PaddleMobileObject { public:
public: virtual void Compute(const P &para) const = 0;
virtual void Compute(const P &para) const = 0;
virtual ~OpKernelBase() = default; virtual ~OpKernelBase() = default;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -23,14 +23,14 @@ SOFTWARE. ...@@ -23,14 +23,14 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
class PaddleMobileObject { class PaddleMobileObject {
public: public:
virtual inline const std::string &ToString() { virtual inline const std::string &ToString() {
char address[128] = {0}; char address[128] = {0};
sprintf(address, "%p", this); sprintf(address, "%p", this);
return std::string(address); return std::string(address);
} }
private: private:
}; };
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -17,5 +17,5 @@ SOFTWARE. ...@@ -17,5 +17,5 @@ SOFTWARE.
==============================================================================*/ ==============================================================================*/
namespace paddle_mobile { namespace paddle_mobile {
namespace framework {} namespace framework {}
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -24,17 +24,17 @@ SOFTWARE. ...@@ -24,17 +24,17 @@ SOFTWARE.
#include "scope.h" #include "scope.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype, Precision P = Precision::FP32> template <typename Dtype, Precision P = Precision::FP32>
class Program : PaddleMobileObject { class Program : PaddleMobileObject {
public: public:
std::shared_ptr<ProgramDesc> originProgram; std::shared_ptr<ProgramDesc> originProgram;
std::shared_ptr<ProgramDesc> optimizeProgram; std::shared_ptr<ProgramDesc> optimizeProgram;
std::shared_ptr<Scope> scope; std::shared_ptr<Scope> scope;
private: private:
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -5,18 +5,18 @@ ...@@ -5,18 +5,18 @@
#include "program_desc.h" #include "program_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) : desc_(desc) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) : desc_(desc) {
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
// new framework::BlockDesc(block_desc) // new framework::BlockDesc(block_desc)
blocks_.emplace_back(std::make_shared<BlockDesc>(block_desc)); blocks_.emplace_back(std::make_shared<BlockDesc>(block_desc));
} }
} }
std::shared_ptr<BlockDesc> ProgramDesc::Block(size_t idx) { std::shared_ptr<BlockDesc> ProgramDesc::Block(size_t idx) {
return blocks_[idx]; return blocks_[idx];
} }
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -25,20 +25,18 @@ SOFTWARE. ...@@ -25,20 +25,18 @@ SOFTWARE.
#include "paddle_mobile_object.h" #include "paddle_mobile_object.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class ProgramDesc : PaddleMobileObject { class ProgramDesc : PaddleMobileObject {
public: public:
ProgramDesc(const proto::ProgramDesc &desc); ProgramDesc(const proto::ProgramDesc &desc);
std::shared_ptr<BlockDesc> Block(size_t idx); std::shared_ptr<BlockDesc> Block(size_t idx);
const std::vector<std::shared_ptr<BlockDesc>> &Blocks() { const std::vector<std::shared_ptr<BlockDesc>> &Blocks() { return blocks_; };
return blocks_;
};
private: private:
std::vector<std::shared_ptr<BlockDesc>> blocks_; std::vector<std::shared_ptr<BlockDesc>> blocks_;
proto::ProgramDesc desc_; proto::ProgramDesc desc_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -4,116 +4,116 @@ ...@@ -4,116 +4,116 @@
#include <vector> #include <vector>
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
Scope &Scope::NewScope() const { Scope &Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this)); kids_.push_back(new Scope(this));
return *kids_.back(); return *kids_.back();
} }
Variable *Scope::Var(const std::string &name) { Variable *Scope::Var(const std::string &name) {
auto *pvar = FindVarLocally(name); auto *pvar = FindVarLocally(name);
if (pvar != nullptr) { if (pvar != nullptr) {
return pvar; return pvar;
}; };
pvar = new Variable; pvar = new Variable;
vars_[name] = pvar; vars_[name] = pvar;
pvar->name_ = &(vars_.find(name)->first); pvar->name_ = &(vars_.find(name)->first);
return pvar; return pvar;
} }
// Variable* Scope::Var(std::string* name) { // Variable* Scope::Var(std::string* name) {
// auto var_name = string::Sprintf("%p.%d", this, // auto var_name = string::Sprintf("%p.%d", this,
// vars_.size()); // vars_.size());
// if (name != nullptr) { // if (name != nullptr) {
// *name = var_name; // *name = var_name;
// } // }
// return Var(var_name); // return Var(var_name);
// } // }
Variable *Scope::FindVar(const std::string &name) const { Variable *Scope::FindVar(const std::string &name) const {
auto *pvar = FindVarLocally(name); auto *pvar = FindVarLocally(name);
if (pvar != nullptr) { if (pvar != nullptr) {
return pvar; return pvar;
} }
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
} }
const Scope *Scope::FindScope(const Variable *var) const { const Scope *Scope::FindScope(const Variable *var) const {
for (auto &name_var : vars_) { for (auto &name_var : vars_) {
if (name_var.second == var) { if (name_var.second == var) {
return this; return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
} }
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
void Scope::DropKids() { void Scope::DropKids() {
for (Scope *s : kids_) { for (Scope *s : kids_) {
delete s; delete s;
} }
kids_.clear(); kids_.clear();
} }
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> known_vars; std::vector<std::string> known_vars;
known_vars.reserve(vars_.size()); known_vars.reserve(vars_.size());
for (auto &name_var : vars_) { for (auto &name_var : vars_) {
known_vars.emplace_back(name_var.first); known_vars.emplace_back(name_var.first);
} }
return known_vars; return known_vars;
} }
void Scope::DeleteScope(Scope *scope) const { void Scope::DeleteScope(Scope *scope) const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(kids_.begin(), kids_.end(), scope); auto it = std::find(kids_.begin(), kids_.end(), scope);
kids_.erase(it); kids_.erase(it);
delete scope; delete scope;
// deferent // deferent
} }
void Scope::EraseVars(const std::vector<std::string> &var_names) { void Scope::EraseVars(const std::vector<std::string> &var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end()); std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) { for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) { if (var_set.find(it->first) != var_set.end()) {
delete it->second; delete it->second;
it = vars_.erase(it); it = vars_.erase(it);
} else { } else {
++it; ++it;
}
}
} }
}
}
void Scope::Rename(const std::string &origin_name, void Scope::Rename(const std::string &origin_name,
const std::string &new_name) const { const std::string &new_name) const {
auto origin_it = vars_.find(origin_name); auto origin_it = vars_.find(origin_name);
if (origin_it == vars_.end()) { if (origin_it == vars_.end()) {
return; return;
} }
auto new_it = vars_.find(new_name); auto new_it = vars_.find(new_name);
if (new_it != vars_.end()) { if (new_it != vars_.end()) {
return; return;
} }
vars_[new_name] = origin_it->second; vars_[new_name] = origin_it->second;
vars_.erase(origin_it); vars_.erase(origin_it);
} }
// //
// std::string Scope::Rename(const std::string& origin_name) // std::string Scope::Rename(const std::string& origin_name)
// const { // const {
// auto var_name = string::Sprintf("%p.%d", this, // auto var_name = string::Sprintf("%p.%d", this,
// vars_.size()); // vars_.size());
// Rename(origin_name, var_name); // Rename(origin_name, var_name);
// return var_name; // return var_name;
// } // }
Variable *Scope::FindVarLocally(const std::string &name) const { Variable *Scope::FindVarLocally(const std::string &name) const {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) { if (it != vars_.end()) {
return it->second; return it->second;
} }
return nullptr; return nullptr;
} }
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -24,58 +24,58 @@ SOFTWARE. ...@@ -24,58 +24,58 @@ SOFTWARE.
#include <unordered_map> //std::unordered_map #include <unordered_map> //std::unordered_map
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class Scope { class Scope {
public: public:
Scope() {} Scope() {}
~Scope() {} ~Scope() {}
Scope &NewScope() const; Scope &NewScope() const;
/// Create a variable with given name if it doesn't exist. /// Create a variable with given name if it doesn't exist.
Variable *Var(const std::string &name); Variable *Var(const std::string &name);
/// Create a variable with a scope-unique name. /// Create a variable with a scope-unique name.
Variable *Var(std::string *name = nullptr); Variable *Var(std::string *name = nullptr);
void EraseVars(const std::vector<std::string> &var_names); void EraseVars(const std::vector<std::string> &var_names);
/// Find a variable in the scope or any of its ancestors. Returns /// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find. /// nullptr if cannot find.
Variable *FindVar(const std::string &name) const; Variable *FindVar(const std::string &name) const;
const Scope *parent() const { return parent_; } const Scope *parent() const { return parent_; }
/// Find the scope or an ancestor scope that contains the given /// Find the scope or an ancestor scope that contains the given
/// variable. /// variable.
const Scope *FindScope(const Variable *var) const; const Scope *FindScope(const Variable *var) const;
void DeleteScope(Scope *scope) const; void DeleteScope(Scope *scope) const;
/// Drop all kids scopes belonged to this scope. /// Drop all kids scopes belonged to this scope.
void DropKids(); void DropKids();
// enumerate all the variables current contains. // enumerate all the variables current contains.
std::vector<std::string> LocalVarNames() const; std::vector<std::string> LocalVarNames() const;
// Rename variable to a new name // Rename variable to a new name
void Rename(const std::string &origin_name, void Rename(const std::string &origin_name,
const std::string &new_name) const; const std::string &new_name) const;
// Rename variable to a new name and return the new name // Rename variable to a new name and return the new name
std::string Rename(const std::string &origin_name) const; std::string Rename(const std::string &origin_name) const;
Variable *FindVarLocally(const std::string &name) const; Variable *FindVarLocally(const std::string &name) const;
private: private:
// Call Scope::NewScope for a sub-scope. // Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const *parent) : parent_(parent) {} explicit Scope(Scope const *parent) : parent_(parent) {}
mutable std::unordered_map<std::string, Variable *> vars_; mutable std::unordered_map<std::string, Variable *> vars_;
mutable std::list<Scope *> kids_; mutable std::list<Scope *> kids_;
Scope const *parent_{nullptr}; Scope const *parent_{nullptr};
mutable std::mutex mutex_; mutable std::mutex mutex_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -24,59 +24,58 @@ SOFTWARE. ...@@ -24,59 +24,58 @@ SOFTWARE.
#include "tensor.h" #include "tensor.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class SelectedRows { class SelectedRows {
public: public:
SelectedRows(const std::vector<int64_t> &rows, SelectedRows(const std::vector<int64_t> &rows, const int64_t &height)
const int64_t &height) : rows_(rows), height_(height) {
: rows_(rows), height_(height) { value_.reset(new Tensor());
value_.reset(new Tensor()); }
}
SelectedRows() { SelectedRows() {
height_ = 0; height_ = 0;
value_.reset(new Tensor()); value_.reset(new Tensor());
} }
const Tensor &value() const { return *value_; } const Tensor &value() const { return *value_; }
Tensor *mutable_value() { return value_.get(); } Tensor *mutable_value() { return value_.get(); }
int64_t height() const { return height_; } int64_t height() const { return height_; }
void set_height(int64_t height) { height_ = height; } void set_height(int64_t height) { height_ = height; }
const std::vector<int64_t> &rows() const { return rows_; } const std::vector<int64_t> &rows() const { return rows_; }
std::vector<int64_t> *mutable_rows() { return &rows_; } std::vector<int64_t> *mutable_rows() { return &rows_; }
void set_rows(const std::vector<int64_t> &rows) { rows_ = rows; } void set_rows(const std::vector<int64_t> &rows) { rows_ = rows; }
/** /**
* get the index of id in rows * get the index of id in rows
*/ */
int64_t index(int64_t id) const { int64_t index(int64_t id) const {
auto it = std::find(rows_.begin(), rows_.end(), id); auto it = std::find(rows_.begin(), rows_.end(), id);
// PADDLE_ENFORCE(it != rows_.end(), "id should be in rows"); // PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
return static_cast<int64_t>(std::distance(rows_.begin(), it)); return static_cast<int64_t>(std::distance(rows_.begin(), it));
} }
DDim GetCompleteDims() const { DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims()); std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_; dims[0] = height_;
return make_ddim(dims); return make_ddim(dims);
} }
private: private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9}
// here. // here.
// SelectedRows are simply concated when adding together. Until a // SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled. // SelectedRows add a Tensor, will the duplicate rows be handled.
std::vector<int64_t> rows_; std::vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr}; std::unique_ptr<Tensor> value_{nullptr};
int64_t height_; int64_t height_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -20,9 +20,9 @@ SOFTWARE. ...@@ -20,9 +20,9 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
VarDesc::VarDesc(const proto::VarDesc &desc) : desc_(desc) {} VarDesc::VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
...@@ -23,17 +23,17 @@ SOFTWARE. ...@@ -23,17 +23,17 @@ SOFTWARE.
#include "variable.h" #include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
inline proto::VarType::Type ToVarType(std::type_index type) { inline proto::VarType::Type ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) { if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return proto::VarType_Type_LOD_TENSOR; return proto::VarType_Type_LOD_TENSOR;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) { } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return proto::VarType_Type_SELECTED_ROWS; return proto::VarType_Type_SELECTED_ROWS;
} else { } else {
// PADDLE_THROW("ToVarType:Unsupported type %s", // PADDLE_THROW("ToVarType:Unsupported type %s",
// type.name()); // type.name());
} }
} }
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
...@@ -27,14 +27,13 @@ SOFTWARE. ...@@ -27,14 +27,13 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
template <typename Dtype, Precision P = Precision::FP32> template <typename Dtype, Precision P = Precision::FP32>
class Loader : PaddleMobileObject { class Loader : PaddleMobileObject {
public: public:
const framework::Program<Dtype, P> Load(const std::string &dirname); const framework::Program<Dtype, P> Load(const std::string &dirname);
private: private:
void LoadVar(framework::LoDTensor *tensor, void LoadVar(framework::LoDTensor *tensor, const std::string &file_path);
const std::string &file_path); };
};
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -25,4 +25,4 @@ namespace operators { ...@@ -25,4 +25,4 @@ namespace operators {
// //
// template class ConvKernel<FPGA, float>; // template class ConvKernel<FPGA, float>;
} }
} } // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册