Commit d2f8befa, authored by liuruilong

format files

Parent: c9ffe855
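The reformatting in this commit matches what clang-format produces with an LLVM-based style: short template declarations joined onto one line, `&` and `*` bound to the declarator name, access specifiers outdented, and includes sorted. A minimal configuration that would reproduce it is sketched below; this is an assumption, the commit itself does not add a .clang-format file.

    # .clang-format (assumed; not part of this commit)
    BasedOnStyle: LLVM
    Standard: Cpp11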
@@ -17,15 +17,14 @@ SOFTWARE.
==============================================================================*/

#pragma once

#include <map>
#include <string>

#include "framework/attribute.h"

namespace paddle_mobile {
namespace framework {

template <typename Dtype> class OperatorBase;
class OpDesc;
class BlockDesc;
class InferShapeContext;
@@ -34,20 +33,20 @@ class InferShapeContext;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;

template <typename Dtype>
using OpCreator = std::function<framework::OperatorBase<Dtype> *(
    const std::string & /*type*/, const VariableNameMap & /*inputs*/,
    const VariableNameMap & /*outputs*/,
    const framework::AttributeMap & /*attrs*/)>;

using GradOpMakerFN =
    std::function<std::vector<std::unique_ptr<framework::OpDesc>>(
        const framework::OpDesc &,
        const std::unordered_set<std::string> & /*no_grad_set*/,
        std::unordered_map<std::string, std::string> * /*grad_to_var*/,
        const std::vector<framework::BlockDesc *> &grad_block)>;

using InferVarTypeFN = std::function<void(const framework::OpDesc & /*op_desc*/,
                                          framework::BlockDesc * /*block*/)>;

using InferShapeFN = std::function<void(framework::InferShapeContext *)>;
};
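For orientation, a sketch (not part of the commit) of how the OpCreator alias above can be used: it is an ordinary factory std::function, so a registry can hold one per operator type. MyOp is a hypothetical OperatorBase<Dtype> subclass, and CPU is the DeviceType<kCPU> typedef from the next file.

    // Sketch only; MyOp is an assumed OperatorBase<Dtype> subclass.
    OpCreator<CPU> creator =
        [](const std::string &type, const VariableNameMap &inputs,
           const VariableNameMap &outputs, const framework::AttributeMap &attrs)
        -> framework::OperatorBase<CPU> * {
      return new MyOp<CPU>(type, inputs, outputs, attrs);
    };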
@@ -24,8 +24,7 @@ enum class Precision : int { FP32 = 0 };

//! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };

template <DeviceTypeEnum T> struct DeviceType {};

typedef DeviceType<kCPU> CPU;
typedef DeviceType<kFPGA> FPGA;
...
@@ -21,13 +21,9 @@ SOFTWARE.

#pragma once

namespace paddle_mobile {

template <int ID, typename Type> struct IDToType { typedef Type type_t; };

template <typename F, typename... Ts> struct VariantHelper {
  static const size_t size = sizeof(F) > VariantHelper<Ts...>::size
                                 ? sizeof(F)
                                 : VariantHelper<Ts...>::size;
@@ -41,8 +37,7 @@ struct VariantHelper {
  }
};

template <typename F> struct VariantHelper<F> {
  static const size_t size = sizeof(F);
  inline static void Destroy(size_t id, void *data) {
    if (id == typeid(F).hash_code()) {
@@ -53,9 +48,8 @@ struct VariantHelper<F> {
  }
};

template <size_t size> class RawData {
public:
  char data[size];
  RawData() {}
  RawData(const RawData &raw_data) { memcpy(data, raw_data.data, size); }
@@ -64,8 +58,7 @@ class RawData {
  //  }
};

template <typename... Ts> struct Variant {
  Variant(const Variant &variant) {
    //    std::cout << " copy constructor " << std::endl;
    type_id = variant.type_id;
@@ -77,15 +70,13 @@ struct Variant {
    //    helper::Destroy(type_id, &data);
  }

  template <typename T, typename... Args> void Set(Args &&... args) {
    helper::Destroy(type_id, &data);
    new (&data) T(std::forward<Args>(args)...);
    type_id = typeid(T).hash_code();
  }

  template <typename T> T &Get() const {
    if (type_id == typeid(T).hash_code()) {
      return *const_cast<T *>(reinterpret_cast<const T *>(&data));
    } else {
@@ -96,16 +87,13 @@ struct Variant {
  size_t TypeId() const { return type_id; }

private:
  static inline size_t invalid_type() { return typeid(void).hash_code(); }
  typedef VariantHelper<Ts...> helper;
  size_t type_id;
  RawData<helper::size> data;
};

template <typename T> struct Vistor { typedef T type_t; };

}  // namespace paddle_mobile
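A usage sketch for Variant (not in the commit): Set() destroys the current value, placement-constructs the new one in the shared RawData buffer, and records typeid(T).hash_code(); Get() only reinterprets the buffer when the stored hash matches.

    Variant<int, float, std::string> v;
    v.Set<float>(3.14f);       // stores a float; type_id = hash of float
    float f = v.Get<float>();  // OK: type_id matches
    bool held_int = (v.TypeId() == typeid(int).hash_code());  // false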
@@ -20,4 +20,4 @@ SOFTWARE.

namespace paddle_mobile {
namespace framework {}
}  // namespace paddle_mobile
@@ -27,86 +27,82 @@ namespace framework {
class BlockDesc;

class Attribute {
public:
  static Attribute GetAttrValue(const proto::OpDesc::Attr &attr_desc) {
    //    std::cout << "begin get attr value" << std::endl;
    Attribute attr;
    switch (attr_desc.type()) {
    case proto::AttrType::BOOLEAN: {
      attr.Set<bool>(attr_desc.b());
      break;
    }
    case proto::AttrType::INT: {
      attr.Set<int>(attr_desc.i());
      break;
    }
    case proto::AttrType::FLOAT: {
      attr.Set<float>(attr_desc.f());
      break;
    }
    case proto::AttrType::STRING: {
      attr.Set<std::string>(attr_desc.s());
      break;
    }
    case proto::AttrType::BOOLEANS: {
      std::vector<bool> val(attr_desc.bools_size());
      for (int i = 0; i < attr_desc.bools_size(); ++i) {
        val[i] = attr_desc.bools(i);
      }
      attr.Set<std::vector<bool>>(val);
      break;
    }
    case proto::AttrType::INTS: {
      std::vector<int> val(attr_desc.ints_size());
      for (int i = 0; i < attr_desc.ints_size(); ++i) {
        val[i] = attr_desc.ints(i);
      }
      attr.Set<std::vector<int>>(val);
      break;
    }
    case proto::AttrType::FLOATS: {
      std::vector<float> val(attr_desc.floats_size());
      for (int i = 0; i < attr_desc.floats_size(); ++i) {
        val[i] = attr_desc.floats(i);
      }
      attr.Set<std::vector<float>>(val);
      break;
    }
    case proto::AttrType::STRINGS: {
      std::vector<std::string> val(attr_desc.strings_size());
      for (int i = 0; i < attr_desc.strings_size(); ++i) {
        val[i] = attr_desc.strings(i);
      }
      attr.Set<std::vector<std::string>>(val);
      break;
    }
    case proto::AttrType::LONG: {
      attr.Set<int64_t>(attr_desc.l());
      break;
    }
    default:
      //  std::cout << " not support " << std::endl;
      break;
    }
    //    std::cout << "end get attr value" << std::endl;
    return attr;
  }

  Attribute() {}
  template <typename T, typename... Args> Attribute &Set(Args &&... args) {
    variant_.Set<T>(args...);
    return *this;
  }

  template <typename T> T &Get() const { return variant_.Get<T>(); }

private:
  Variant<int, float, std::string, std::vector<int>, std::vector<float>,
          std::vector<std::string>, bool, std::vector<bool>, BlockDesc *,
          int64_t>
      variant_;
};
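A usage sketch for Attribute (not in the commit): Set() forwards into the underlying Variant and returns *this, and Get() reads the value back under the same type.

    Attribute attr;
    attr.Set<std::vector<int>>(std::vector<int>{1, 1});  // e.g. conv strides
    std::vector<int> &strides = attr.Get<std::vector<int>>();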
@@ -114,20 +110,19 @@ class Attribute {
using AttributeMap = std::unordered_map<std::string, Attribute>;

class AttrReader {
public:
  explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {}

  template <typename T> inline T Get(const std::string &name) const {
    //  PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in
    //  AttributeMap",
    //                 name);
    return ((Attribute)attrs_.at(name)).Get<T>();
  }

private:
  const AttributeMap &attrs_;
};

}  // namespace framework
}  // namespace paddle_mobile
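And a sketch of AttrReader over a filled AttributeMap (the attribute name is illustrative):

    AttributeMap attrs;
    attrs["axis"].Set<int>(1);
    AttrReader reader(attrs);
    int axis = reader.Get<int>("axis");  // copies the stored attribute out as int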
@@ -46,5 +46,5 @@ BlockDesc::BlockDesc(const proto::BlockDesc &desc) : desc_(desc) {
  }
}

}  // namespace framework
}  // namespace paddle_mobile
@@ -27,7 +27,7 @@ namespace paddle_mobile {
namespace framework {

class BlockDesc : PaddleMobileObject {
public:
  BlockDesc(const proto::BlockDesc &desc);

  const int &ID() const { return desc_.idx(); }
@@ -45,19 +45,18 @@ class BlockDesc : PaddleMobileObject {
  std::vector<std::shared_ptr<VarDesc>> Vars() const;
  std::vector<std::shared_ptr<OpDesc>> Ops() const;

private:
  proto::BlockDesc desc_;
  std::vector<std::shared_ptr<OpDesc>> ops_;
  std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_;
};

}  // namespace framework
}  // namespace paddle_mobile

namespace std {

template <> struct hash<paddle_mobile::framework::BlockDesc> {
  typedef paddle_mobile::framework::BlockDesc argument_type;
  typedef std::size_t result_type;
  result_type operator()(argument_type const &s) const noexcept {
@@ -67,4 +66,4 @@ struct hash<paddle_mobile::framework::BlockDesc> {
  }
};

}  // namespace std
@@ -27,7 +27,7 @@ enum class DataLayout {
  kAnyLayout = 2,
};

inline DataLayout StringToDataLayout(const std::string &str) {
  std::string s(str);
  for (size_t i = 0; i < s.size(); ++i) {
    s[i] = toupper(s[i]);
@@ -44,24 +44,24 @@ inline DataLayout StringToDataLayout(const std::string &str) {
  }
}

inline std::string DataLayoutToString(const DataLayout &data_layout) {
  switch (data_layout) {
  case DataLayout::kNHWC:
    return "NHWC";
  case DataLayout::kNCHW:
    return "NCHW";
  case DataLayout::kAnyLayout:
    return "ANY_LAYOUT";
  default:
    break;
    // std::cout << "unknown DataLayout %d", data_layout;
  }
}

inline std::ostream &operator<<(std::ostream &out, const DataLayout &l) {
  out << DataLayoutToString(l);
  return out;
}

}  // namespace framework
}  // namespace paddle_mobile
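Round-trip sketch for the two helpers above: StringToDataLayout upper-cases its argument before matching, so lookups are case-insensitive.

    DataLayout layout = StringToDataLayout("nchw");  // DataLayout::kNCHW
    std::string name = DataLayoutToString(layout);   // "NCHW"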
@@ -23,14 +23,14 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {

static void PassTensorData(Tensor *from, Tensor *to) {
  to->ShareDataWith(*from);
  *from = Tensor();
}

void DataTransform(const OpKernelType &expected_kernel_type,
                   const OpKernelType &kernel_type_for_var,
                   const Tensor &input_tensor, Tensor *output_tensor) {
  bool transformed = false;
  Tensor in;
  in.ShareDataWith(input_tensor);
@@ -64,8 +64,8 @@ void DataTransform(const OpKernelType &expected_kernel_type,
  output_tensor->ShareDataWith(in);
}

void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
                            Variable &out_var) {
  //  if (in_var.IsType<LoDTensor>()) {
  //    auto &in_lod_tensor = in_var.Get<LoDTensor>();
  //    auto *tran_lod_tensor = out_var.GetMutable<LoDTensor>();
@@ -83,5 +83,5 @@ void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
  //  }
}

}  // namespace framework
}  // namespace paddle_mobile
@@ -30,12 +30,12 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {

void DataTransform(const OpKernelType &expected_kernel_type,
                   const OpKernelType &kernel_type_for_var,
                   const Tensor &input_tensor, Tensor *out);

void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
                            Variable &out_var);

}  // namespace framework
}  // namespace paddle_mobile
@@ -40,4 +40,4 @@ namespace framework {
//  }
// }
}
}  // namespace paddle_mobile
@@ -19,52 +19,48 @@ namespace framework {
/// @cond HIDDEN

template <int i> Dim<i> make_dim(const int64_t *d) {
  return Dim<i>(*d, make_dim<i - 1>(d + 1));
}

template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); }

void make_ddim(DDim &ddim, const int64_t *dims, int n) {
  switch (n) {
  case 0:
    ddim = make_dim<0>(dims);
    break;
  case 1:
    ddim = make_dim<1>(dims);
    break;
  case 2:
    ddim = make_dim<2>(dims);
    break;
  case 3:
    ddim = make_dim<3>(dims);
    break;
  case 4:
    ddim = make_dim<4>(dims);
    break;
  case 5:
    ddim = make_dim<5>(dims);
    break;
  case 6:
    ddim = make_dim<6>(dims);
    break;
  case 7:
    ddim = make_dim<7>(dims);
    break;
  case 8:
    ddim = make_dim<8>(dims);
    break;
  case 9:
    ddim = make_dim<9>(dims);
    break;
  default:
    // std::cout << "Dynamic dimensions must have between [1, 9]
    // dimensions.";
    break;
  }
}
@@ -76,13 +72,13 @@ DDim make_ddim(std::initializer_list<int64_t> dims) {
  return result;
}

DDim make_ddim(const std::vector<int64_t> &dims) {
  DDim result(make_dim(0));
  make_ddim(result, &dims[0], dims.size());
  return result;
}

DDim make_ddim(const std::vector<int> &dims) {
  std::vector<int64_t> res(dims.size());
  std::transform(dims.begin(), dims.end(), res.begin(),
                 [](int d) { return static_cast<int64_t>(d); });
@@ -91,35 +87,31 @@ DDim make_ddim(const std::vector<int> &dims) {
/// @cond HIDDEN

// XXX For some reason, putting this in an anonymous namespace causes errors
struct DynamicMutableIndexer : Vistor<int64_t &> {
public:
  explicit DynamicMutableIndexer(int idx) : idx_(idx) {}

  template <int D> int64_t &operator()(Dim<D> &dim) const { return dim[idx_]; }

private:
  int idx_;
};

struct DynamicConstIndexer : public Vistor<int64_t> {
public:
  explicit DynamicConstIndexer(int idx) : idx_(idx) {}

  template <int D> int64_t operator()(const Dim<D> &dim) const {
    return dim[idx_];
  }

private:
  int idx_;
};

/// @endcond

int64_t &DDim::operator[](int idx) {
  return DDim::ApplyVistor(DynamicMutableIndexer(idx), *this);
}
@@ -178,27 +170,26 @@ DDim DDim::operator*(DDim d) const {
  return make_ddim(v3);
}

int64_t get(const DDim &ddim, int idx) { return ddim[idx]; }

void set(DDim &ddim, int idx, int value) { ddim[idx] = value; }

/// @cond HIDDEN
struct VectorizeVisitor : Vistor<void> {
  std::vector<int64_t> &vector;

  explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}

  template <typename T> void operator()(const T &t) {
    vector.push_back(t.head);
    this->operator()(t.tail);
  }

  void operator()(const Dim<0> &t) {}
};
/// @endcond

std::vector<int64_t> vectorize(const DDim &ddim) {
  std::vector<int64_t> result;
  VectorizeVisitor visitor(result);
  DDim::ApplyVistor(visitor, ddim);
@@ -207,30 +198,29 @@ std::vector<int64_t> vectorize(const DDim &ddim) {
// NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
  std::vector<int64_t> temp = vectorize(ddim);
  std::vector<int> result(temp.begin(), temp.end());
  return result;
}

struct ProductVisitor : Vistor<int64_t> {
  template <int D> int64_t operator()(const Dim<D> &dim) {
    return product(dim);
  }
};

int64_t product(const DDim &ddim) {
  ProductVisitor visitor;
  return DDim::ApplyVistor(visitor, ddim);
}

struct SliceVectorizeVisitor : Vistor<void> {
  std::vector<int64_t> &vector;
  int begin;
  int end;

  SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
      : vector(v), begin(b), end(e) {
    // PADDLE_ENFORCE(begin < end,
    //                "Begin index must be less than end index in ddim
@@ -239,8 +229,7 @@ struct SliceVectorizeVisitor : Vistor<void> {
    //                "Begin index can't be less than zero in ddim slice.");
  }

  template <int S> void operator()(const Dim<S> &dim) {
    if (begin == 0) {
      vector.push_back(dim.head);
    } else {
@@ -252,12 +241,12 @@ struct SliceVectorizeVisitor : Vistor<void> {
    }
  }

  void operator()(const Dim<0> &dim) {
    // PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
  }
};

DDim slice_ddim(const DDim &ddim, int begin, int end) {
  std::vector<int64_t> vec;
  vec.reserve(end - begin);
  SliceVectorizeVisitor visitor(vec, begin, end);
@@ -270,15 +259,12 @@ DDim slice_ddim(const DDim &ddim, int begin, int end) {
/// \cond HIDDEN

struct ArityVisitor : Vistor<int> {
  template <int D> int operator()(Dim<D>) const { return D; }
};

/// \endcond

int arity(const DDim &d) {
  ArityVisitor arityVisitor = ArityVisitor();
  return DDim::ApplyVistor(arityVisitor, d);
  // return arityVisitor(d.var.Get<Dim<4>>());
@@ -288,19 +274,18 @@ int arity(const DDim &d) {
/// \endcond

struct OSVistor : Vistor<std::ostream &> {
  OSVistor(std::ostream &os) : os_(os) {}

  template <int D> std::ostream &operator()(Dim<D> dim) const {
    return os_ << dim;
  }

private:
  std::ostream &os_;
};

std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
  auto vistor = OSVistor(os);
  DDim::ApplyVistor(vistor, ddim);
  return os;
@@ -310,15 +295,15 @@ DDim::DDim(std::initializer_list<int64_t> init_list) {
  *this = make_ddim(init_list);
}

DDim flatten_to_2d(const DDim &src, int num_col_dims) {
  int rank = src.size();
  return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
                    product(slice_ddim(src, num_col_dims, rank))});
}

DDim flatten_to_1d(const DDim &src) { return make_ddim({product(src)}); }

DDim stride(const DDim &ddim) {
  std::vector<int64_t> strides(ddim.size());
  strides[ddim.size() - 1] = 1;
  for (int i = ddim.size() - 2; i >= 0; --i) {
@@ -327,7 +312,7 @@ DDim stride(const DDim &ddim) {
  return framework::make_ddim(strides);
}

DDim stride_numel(const framework::DDim &ddim) {
  std::vector<int64_t> strides(ddim.size());
  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
  for (int i = ddim.size() - 2; i >= 0; --i) {
@@ -336,5 +321,5 @@ DDim stride_numel(const framework::DDim &ddim) {
  return framework::make_ddim(strides);
}

}  // namespace framework
}  // namespace paddle_mobile
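A worked sketch of the DDim helpers above (not in the commit). The stride_numel values assume the collapsed loop body accumulates strides[i] = strides[i + 1] * ddim[i], mirroring the visible stride().

    DDim d = make_ddim({2, 3, 4});
    product(d);           // 24
    slice_ddim(d, 1, 3);  // {3, 4}
    stride(d);            // {12, 4, 1}: element step of each axis, row-major
    stride_numel(d);      // {24, 12, 4}: suffix products; first entry is numel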
@@ -14,12 +14,12 @@ limitations under the License. */
#pragma once

#include <assert.h>
#include <initializer_list>
#include <stdexcept>
#include <vector>

#include "common/variant.h"
#include "dim.h"

namespace paddle_mobile {
namespace framework {
@@ -66,15 +66,11 @@ struct DDim {
  DDim() { var.Set<Dim<1>>(Dim<1>()); }

  template <int D> explicit DDim(const Dim<D> &in) { var.Set<Dim<D>>(in); }

  /*implicit*/ DDim(std::initializer_list<int64_t> init_list);

  template <int D> DDim &operator=(const Dim<D> &in) {
    var.Set<Dim<D>>(in);
    return *this;
  }
@@ -161,5 +157,5 @@ DDim flatten_to_1d(const DDim &src);
DDim stride(const DDim &ddim);
DDim stride_numel(const DDim &ddim);

}  // namespace framework
}  // namespace paddle_mobile
@@ -24,8 +24,7 @@ namespace paddle_mobile {
namespace framework {

// Statically sized, statically indexed dimension
template <int i> struct Dim {
  static constexpr int dimensions = i;

  template <typename... Args>
@@ -35,7 +34,7 @@ struct Dim {
  }

  HOSTDEVICE
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}

  HOSTDEVICE
  Dim() : head(0), tail() {}
@@ -43,7 +42,7 @@ struct Dim {
  /** Construct a Dim from a linear index and size. Uses Fortran order
   * indexing. */
  HOSTDEVICE
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}

  /** Construct a Dim with each dimension set to the given index */
@@ -51,15 +50,15 @@ struct Dim {
  Dim(int64_t idx) : head(idx), tail(idx) {}

  HOSTDEVICE
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }

  HOSTDEVICE
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }

  HOSTDEVICE
  int64_t &operator[](int idx);
  HOSTDEVICE
  int64_t operator[](int idx) const;
@@ -70,8 +69,7 @@ struct Dim {
};

// Base case specialization
template <> struct Dim<0> {
  static constexpr int dimensions = 0;

  HOSTDEVICE
@@ -81,7 +79,7 @@ struct Dim<0> {
  Dim() {}

  HOSTDEVICE
  Dim(int idx, const Dim<0> &size) {
#ifndef __CUDA_ARCH__
    if (idx > 0) {
      throw std::invalid_argument("Index out of range.");
@@ -92,13 +90,13 @@ struct Dim<0> {
  }

  HOSTDEVICE
  bool operator==(const Dim<0> &o) const { return true; }

  HOSTDEVICE
  bool operator!=(const Dim<0> &o) const { return false; }

  HOSTDEVICE
  int64_t &operator[](int idx);
  HOSTDEVICE
  int64_t operator[](int idx) const;
};
@@ -106,37 +104,28 @@ struct Dim<0> {
namespace {

// Helper for accessing Dim classes
template <int i> struct DimGetter {
  // Return a copy if Dim is const
  template <typename D> HOSTDEVICE static int64_t impl(const D &d) {
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
  template <typename D> HOSTDEVICE static int64_t &impl(D &d) {
    return DimGetter<i - 1>::impl(d.tail);
  }
};

// Eureka! We found the element!
template <> struct DimGetter<0> {
  // Return a copy if Dim is const
  template <typename D> HOSTDEVICE static int64_t impl(const D &d) {
    return d.head;
  }
  // Return a reference if Dim is mutable
  template <typename D> HOSTDEVICE static int64_t &impl(D &d) { return d.head; }
};

template <int D> HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
@@ -150,8 +139,7 @@ HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
  return indexer(dim.tail, idx - 1);
}

template <> HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
#ifndef __CUDA_ARCH__
  throw std::invalid_argument("Invalid index");
#else
@@ -167,8 +155,7 @@ HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
#endif
}

template <int D> HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
@@ -182,8 +169,7 @@ HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
  return indexer(dim.tail, idx - 1);
}

template <> HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
#ifndef __CUDA_ARCH__
  throw std::invalid_argument("Invalid index");
#else
@@ -199,29 +185,25 @@ HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
#endif
}

}  // namespace

// Static access to constant Dim
template <int i, int l> HOSTDEVICE int64_t get(const Dim<l> &d) {
  return DimGetter<i>::impl(d);
}

// Static access to mutable Dim
template <int i, int l> HOSTDEVICE int64_t &get(Dim<l> &d) {
  return DimGetter<i>::impl(d);
}

// Dynamic access to constant Dim
template <int l> HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
}

// Dynamic access to mutable Dim
template <int l> HOSTDEVICE int64_t &Dim<l>::operator[](int i) {
  return indexer(*this, i);
}
@@ -231,54 +213,52 @@ inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
}

// Dynamic access to mutable Dim
inline HOSTDEVICE int64_t &Dim<0>::operator[](int i) {
  return indexer(*this, i);
}

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d,
                                                               int i) {
  return d[i];
}

// Dynamic access to mutable Dim
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d,
                                                                 int i) {
  return d[i];
}

// Dot product of two dims
template <int i>
HOSTDEVICE int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
  return a.head * b.head + linearize(a.tail, b.tail);
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
  return 0;
}

// Product of a Dim
template <int i> HOSTDEVICE int64_t product(const Dim<i> &a, int prod = 1) {
  return prod * a.head * product(a.tail);
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
template <> HOSTDEVICE inline int64_t product(const Dim<0> &a, int prod) {
  return prod;
}

// Is 0 <= idx_i < size_i for all i?
template <int i>
HOSTDEVICE bool contained(const Dim<i> &idx, const Dim<i> &size) {
  return ((0 <= idx.head) && (idx.head < size.head) &&
          contained(idx.tail, size.tail));
}
@@ -286,7 +266,7 @@ HOSTDEVICE bool contained(const Dim<i> &idx, const Dim<i> &size) {
// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
  return true;
}
@@ -294,15 +274,14 @@ HOSTDEVICE inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 */
template <int i>
HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
template <> HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
  return Dim<0>();
}
///\endcond
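A sketch combining the helpers above for Fortran-order addressing, assuming make_dim(2, 3) (the variadic helper declared later in this header) sets head = 2:

    Dim<2> size = make_dim(2, 3);                       // extents 2 x 3
    Dim<2> strides = ex_prefix_mul(size);               // {1, 2}
    int64_t off = linearize(make_dim(1, 2), strides);   // 1*1 + 2*2 = 5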
@@ -310,38 +289,36 @@ HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
/**
 * Add two dimensions together
 */
template <int i> HOSTDEVICE Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
}

// Base case
template <>
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
  return Dim<0>();
}

template <int i>
HOSTDEVICE Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
  return dim_plus(lhs, rhs);
}

/**
 * Multiply two dimensions together
 */
template <int i> HOSTDEVICE Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
}

// Base case
template <>
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
  return Dim<0>();
}

template <int i>
HOSTDEVICE Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
  return dim_mult(lhs, rhs);
}
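Element-wise sketch for the operators just defined (same make_dim assumption as above):

    Dim<2> a = make_dim(2, 3);
    Dim<2> b = make_dim(4, 5);
    Dim<2> s = a + b;  // dim_plus: {6, 8}
    Dim<2> p = a * b;  // dim_mult: {8, 15}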
@@ -356,7 +333,7 @@ HOSTDEVICE Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
 */

template <int i>
HOSTDEVICE Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
  int norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
}
@@ -364,8 +341,8 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
///\cond HIDDEN

template <>
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0> &size,
                                           const Dim<0> &stride) {
  return Dim<0>();
}
@@ -386,8 +363,8 @@ HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
typename std::enable_if<(i > 1), std::ostream &>::type
operator<<(std::ostream &os, const Dim<i> &d) {
  os << d.head << ", " << d.tail;
  return os;
}
@@ -395,18 +372,17 @@ typename std::enable_if<(i > 1), std::ostream &>::type
// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
typename std::enable_if<(i == 1), std::ostream &>::type
operator<<(std::ostream &os, const Dim<i> &d) {
  os << d.head;
  return os;
}

inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
  return os;
}

template <int i> HOST std::string Dim<i>::to_string() const {
  std::stringstream stream;
  stream << *this;
@@ -428,5 +404,5 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
  return result;
}

}  // namespace framework
}  // namespace paddle_mobile
@@ -102,5 +102,5 @@ void Executor<Dtype>::predict(const Tensor &t, int block_id) {
template class Executor<CPU>;

}  // namespace framework
}  // namespace paddle_mobile
@@ -34,13 +34,12 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {

template <typename Dtype> class Executor {
public:
  Executor(const Program<Dtype> p);
  std::shared_ptr<Tensor> predict(Tensor &t);

private:
  const framework::Program<Dtype> program_;
  std::shared_ptr<ProgramDesc> to_predict_program_;
  void predict(const Tensor &t, int block_id);
@@ -50,5 +49,5 @@ class Executor {
  bool use_optimize_ = false;
};

}  // namespace framework
}  // namespace paddle_mobile
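A hypothetical call site for the interface above; how the Program is obtained (the model loader) is outside this diff, so 'program' and the input Tensor are assumed to exist.

    framework::Executor<paddle_mobile::CPU> executor(program);
    std::shared_ptr<framework::Tensor> output = executor.predict(input);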
This diff is collapsed.
This diff is collapsed.
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "lod_tensor.h"
#include <algorithm>
#include <iterator>
#include <stdint.h>
#include <string.h>

namespace paddle_mobile {
namespace framework {
@@ -103,7 +103,8 @@ LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
LoD ToAbsOffset(const LoD &in) {
  // the lowest level stores relative offsets
  if (in.empty() || in.size() == 1)
    return in;
  LoD result = in;
  for (auto level = static_cast<int>(in.size() - 2); level >= 0; level--) {
    for (size_t i = 0; i < in[level].size(); ++i) {
@@ -135,16 +136,20 @@ bool operator==(const LoD &a, const LoD &b) {
}

bool CheckLoD(const LoD &in, int tensor_height) {
  if (in.empty())
    return true;
  for (const auto &level : in) {
    // check: there should be more than 2 offsets existing in each level.
    if (level.size() < 2)
      return false;
    // check: the first offset (the begin offset) of each level should be 0.
    if (level.front() != 0)
      return false;
    // check: all the offsets in a level should be ascending (no equal items
    // allowed).
    if (!std::is_sorted(level.begin(), level.end(), [](size_t a, size_t b) {
          if (a < b)
            return true;
          return false;
        })) {
      std::cout << "ascending error";
@@ -161,29 +166,34 @@ bool CheckLoD(const LoD &in, int tensor_height) {
  // NOTE LoD stores the levels from top to bottom, so the higher level goes
  // first.
  for (size_t level = 0; level < in.size() - 1; level++) {
    if (in[level].back() != in[level + 1].size() - 1)
      return false;
  }
  return true;
}
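An example (not in the commit) of a LoD that passes these checks for a tensor of height 10: both levels start at 0 and ascend, and the top level's last offset (4) indexes the last entry of the lower level. The final height comparison sits in the collapsed part of this hunk, so that behavior is assumed.

    LoD lod = {{0, 2, 4}, {0, 3, 6, 8, 10}};
    CheckLoD(lod, 10);  // true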
bool CheckAbsLoD(const LoD &in, int tensor_height) {
  if (in.empty())
    return true;
  for (const auto &level : in) {
    // check: all the offsets in a level should be ascending (no equal items
    // allowed).
    if (!std::is_sorted(level.begin(), level.end(), [](size_t a, size_t b) {
          if (a < b)
            return true;
          return false;
        })) {
      return false;
    }

    // check: there should be more than 2 offsets existing in each level.
    if (level.size() < 2)
      return false;

    // check: the first offset of each level should be 0, and the last should
    // be the same (the height of the underlying tensor).
    if (level.front() != 0)
      return false;
    if (tensor_height < 0) {
      tensor_height = level.back();
    } else if ((size_t)tensor_height != level.back()) {
@@ -220,7 +230,7 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
  //                "The lod_length should have the same size as the appended lod.");
  if (lod->empty()) {
    for (size_t i = 0; i < lod_length.size(); ++i) {
      lod->emplace_back(1, 0); // size = 1, value = 0;
    }
    *lod = LoD(lod_length.size(), std::vector<size_t>({0}));
  }
@@ -233,7 +243,7 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
}

void SerializeToStream(std::ostream &os, const LoDTensor &tensor) {
  { // the 1st field, uint32_t version for LoDTensor
    constexpr uint32_t version = 0;
    os.write(reinterpret_cast<const char *>(&version), sizeof(version));
  }
@@ -284,5 +294,5 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
  TensorFromStream(is, static_cast<Tensor *>(tensor));
}

}  // namespace framework
}  // namespace paddle_mobile
@@ -14,12 +14,12 @@ limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensor.h"
#include "tensor_util.h"

namespace paddle_mobile {

@@ -96,7 +96,7 @@ bool CheckAbsLoD(const LoD &in, int tensor_height = -1);
 * see https://en.wikipedia.org/wiki/Level_of_details for reference.
 */
class LoDTensor : public Tensor {
public:
  LoDTensor() : Tensor() {}

  explicit LoDTensor(const LoD &lod) : lod_(lod) {}

@@ -131,7 +131,7 @@ class LoDTensor : public Tensor {
    return (lod_)[level].size() - 1;
  }

private:
  LoD lod_;
};

@@ -181,8 +181,9 @@ LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, size_t level) {
// Returns:
//  LoD = [[1, 4], [2, 4, 2, 3, 2]]
//  pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>>
GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, size_t end_idx,
                           size_t start_level);

void AppendLoD(LoD *lod, const LoD &lod_length);

@@ -195,5 +196,5 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor);

void DeserializeFromStream(std::istream &is, LoDTensor *tensor);

} // namespace framework
} // namespace paddle_mobile
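Reviewer note: a short usage sketch for the offset-based LoD declared here, assuming the repository's include path and that LoD is the nested offset vector used throughout this file.

#include "framework/lod_tensor.h" // include path assumed

using paddle_mobile::framework::LoD;
using paddle_mobile::framework::LoDTensor;

int main() {
  // One level with offsets {0, 2, 5}: two sequences covering rows
  // [0, 2) and [2, 5) of the underlying tensor.
  LoD lod = {{0, 2, 5}};
  LoDTensor t(lod);
  // NumElements(level) is lod[level].size() - 1, i.e. the sequence count.
  return t.NumElements(0) == 2 ? 0 : 1;
}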
@@ -55,5 +55,5 @@ const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
  return attrs_;
}

} // namespace framework
} // namespace paddle_mobile

@@ -26,7 +26,7 @@ namespace paddle_mobile {
namespace framework {

class OpDesc : PaddleMobileObject {
public:
  OpDesc(const proto::OpDesc &desc);
  const std::vector<std::string> &Input(const std::string &name) const;
  const std::vector<std::string> &Output(const std::string &name) const;

@@ -40,12 +40,12 @@ class OpDesc : PaddleMobileObject {
  const std::string &Type() { return desc_.type(); };

private:
  proto::OpDesc desc_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
};

} // namespace framework
} // namespace paddle_mobile
@@ -24,42 +24,38 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {

template <typename Dtype> struct OpInfo {
  OpCreator<Dtype> creator_;
  const OpCreator<Dtype> &Creator() const {
    //    PADDLE_ENFORCE_NOT_NULL(creator_,
    //                            "Operator Creator has not been registered");
    return creator_;
  }
};

template <typename Dtype> class OpInfoMap;

template <typename Dtype> static OpInfoMap<Dtype> *g_op_info_map = nullptr;

template <typename Dtype> class OpInfoMap {
public:
  static OpInfoMap &Instance() {
    if (g_op_info_map<Dtype> == nullptr) {
      g_op_info_map<Dtype> = new OpInfoMap();
    }
    return *g_op_info_map<Dtype>;
  };

  bool Has(const std::string &op_type) const {
    return map_.find(op_type) != map_.end();
  }

  void Insert(const std::string &type, const OpInfo<Dtype> &info) {
    //    PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type);
    map_.insert({type, info});
  }

  const OpInfo<Dtype> &Get(const std::string &type) const {
    auto op_info_ptr = GetNullable(type);
    //    PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been
    //    registered",

@@ -67,7 +63,7 @@ class OpInfoMap {
    return *op_info_ptr;
  }

  const OpInfo<Dtype> *GetNullable(const std::string &type) const {
    auto it = map_.find(type);
    if (it == map_.end()) {
      return nullptr;

@@ -76,20 +72,20 @@ class OpInfoMap {
    }
  }

  const std::unordered_map<std::string, OpInfo<Dtype>> &map() const {
    return map_;
  }

  std::unordered_map<std::string, OpInfo<Dtype>> *mutable_map() {
    return &map_;
  }

private:
  OpInfoMap() = default;
  std::unordered_map<std::string, OpInfo<Dtype>> map_;

  //  DISABLE_COPY_AND_ASSIGN(OpInfoMap);
};

} // namespace framework
} // namespace paddle_mobile
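Reviewer note: a sketch of how the per-Dtype singleton registry above is meant to be driven. The op name "conv2d" and the null creator are illustrative; a real creator returns a heap-allocated concrete operator.

#include "framework/op_info.h" // include path assumed

using namespace paddle_mobile;
using namespace paddle_mobile::framework;

template <typename Dtype> void RegisterConvSketch() {
  OpInfo<Dtype> info;
  // Illustrative creator: a real one would construct a concrete operator.
  info.creator_ = [](const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs,
                     const AttributeMap &attrs) -> OperatorBase<Dtype> * {
    return nullptr;
  };
  OpInfoMap<Dtype>::Instance().Insert("conv2d", info);
}

int main() {
  RegisterConvSketch<CPU>();
  return OpInfoMap<CPU>::Instance().Has("conv2d") ? 0 : 1;
}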
@@ -25,7 +25,7 @@ namespace paddle_mobile {
namespace framework {

struct OpKernelType {
  struct Hash {
    size_t operator()(const OpKernelType &key) const {
      int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
      int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);

@@ -44,21 +44,21 @@ struct OpKernelType {
               DataLayout data_layout = DataLayout::kAnyLayout)
      : data_type_(data_type), data_layout_(data_layout) {}

  bool operator==(const OpKernelType &o) const {
    return data_type_ == o.data_type_ && data_layout_ == o.data_layout_;
  }

  bool operator!=(const OpKernelType &o) const { return !(*this == o); }
};

inline bool NeedTransformLayout(const DataLayout &l, const DataLayout &r) {
  return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
}

inline bool TransFromNeeded(const OpKernelType &l, const OpKernelType &r) {
  return (l.data_type_ != r.data_type_) ||
         NeedTransformLayout(l.data_layout_, r.data_layout_);
}

} // namespace framework
} // namespace paddle_mobile
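Reviewer note: the Hash functor packs the two enums into disjoint bit ranges so OpKernelType can key an unordered_map. A self-contained stand-in is shown below; kLeftShift replaces the header's LEFT_SHIFT constant and plain ints replace the enums, so this only illustrates the pattern, not the exact values.

#include <cstddef>
#include <string>
#include <unordered_map>

struct KernelKey {
  int data_type_;   // stands in for the framework's data type enum
  int data_layout_; // stands in for DataLayout
  bool operator==(const KernelKey &o) const {
    return data_type_ == o.data_type_ && data_layout_ == o.data_layout_;
  }
  struct Hash {
    size_t operator()(const KernelKey &key) const {
      constexpr int kLeftShift = 4; // stand-in for LEFT_SHIFT
      int data_type = key.data_type_ << kLeftShift;
      int data_layout = key.data_layout_ << (kLeftShift * 2);
      return static_cast<size_t>(data_type + data_layout);
    }
  };
};

int main() {
  std::unordered_map<KernelKey, std::string, KernelKey::Hash> registry;
  registry[{0, 1}] = "cpu_fp32_nchw"; // illustrative entry
  return registry.count({0, 1}) == 1 ? 0 : 1;
}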
@@ -22,5 +22,5 @@ namespace paddle_mobile {
namespace framework {

// This class not only makes the op proto but also initializes the
// attribute checkers.
class OpProtoAndCheckerMaker {};
} // namespace framework
} // namespace paddle_mobile
@@ -23,23 +23,17 @@ namespace paddle_mobile {
namespace framework {

template <typename Dtype>
OperatorBase<Dtype>::OperatorBase(const std::string &type,
                                  const VariableNameMap &inputs,
                                  const VariableNameMap &outputs,
                                  const AttributeMap &attrs,
                                  std::shared_ptr<Scope> scope)
    : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs),
      scope_(scope) {
  CheckAllInputOutputSet();
}

template <typename Dtype> void OperatorBase<Dtype>::Run() { RunImpl(); }

template <typename Dtype>
void OperatorBase<Dtype>::CheckAllInputOutputSet() const {}

@@ -47,5 +41,5 @@ void OperatorBase<Dtype>::CheckAllInputOutputSet() const {}
template class OperatorBase<CPU>;
template class OperatorWithKernel<CPU>;

} // namespace framework
} // namespace paddle_mobile
@@ -35,53 +35,51 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {

template <typename Dtype> class OperatorBase : PaddleMobileObject {
public:
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
  virtual void Run();

  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }

protected:
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;

private:
  void CheckAllInputOutputSet() const;
  virtual void RunImpl() const = 0;
};

template <typename Dtype>
class OperatorWithKernel : public OperatorBase<Dtype> {
public:
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {}

  virtual void InferShape() const = 0;

protected:
  virtual void RunImpl() const = 0;

private:
};

template <typename Dtype, typename P> class OpKernelBase : PaddleMobileObject {
public:
  virtual void Compute(const P &para) const = 0;

  virtual ~OpKernelBase() = default;
};

} // namespace framework
} // namespace paddle_mobile
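Reviewer note: a sketch of the intended extension point, modeled on the ConvOp later in this commit; OperatorWithKernel supplies the constructor, and a concrete subclass overrides InferShape() and RunImpl(). ReluOp is an illustrative name, not an operator in the repository.

#include "framework/operator.h" // include path assumed

namespace paddle_mobile {
namespace operators {

template <typename DeviceType>
class ReluOp : public framework::OperatorWithKernel<DeviceType> {
public:
  using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;

  void InferShape() const override {
    // A real operator would set the output dims from the input dims here.
  }

protected:
  void RunImpl() const override {
    // A real operator would pick an OpKernelBase and call Compute() here.
  }
};

} // namespace operators
} // namespace paddle_mobile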
@@ -18,19 +18,19 @@ SOFTWARE.

#pragma once

#include "stdio.h"
#include <string>

namespace paddle_mobile {

class PaddleMobileObject {
public:
  // Note: returns by value; the previous signature returned a
  // const std::string & to a temporary, which dangles.
  virtual inline std::string ToString() {
    char address[128] = {0};
    sprintf(address, "%p", this);
    return std::string(address);
  }

private:
};

} // namespace paddle_mobile
@@ -18,4 +18,4 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {}
} // namespace paddle_mobile

@@ -28,13 +28,13 @@ namespace framework {

template <typename Dtype, Precision P = Precision::FP32>
class Program : PaddleMobileObject {
public:
  std::shared_ptr<ProgramDesc> originProgram;
  std::shared_ptr<ProgramDesc> optimizeProgram;
  std::shared_ptr<Scope> scope;

private:
};

} // namespace framework
} // namespace paddle_mobile

@@ -18,5 +18,5 @@ std::shared_ptr<BlockDesc> ProgramDesc::Block(size_t idx) {
  return blocks_[idx];
}

} // namespace framework
} // namespace paddle_mobile
@@ -28,15 +28,15 @@ namespace paddle_mobile {
namespace framework {

class ProgramDesc : PaddleMobileObject {
public:
  ProgramDesc(const proto::ProgramDesc &desc);
  std::shared_ptr<BlockDesc> Block(size_t idx);
  const std::vector<std::shared_ptr<BlockDesc>> &Blocks() { return blocks_; };

private:
  std::vector<std::shared_ptr<BlockDesc>> blocks_;
  proto::ProgramDesc desc_;
};

} // namespace framework
} // namespace paddle_mobile

@@ -112,5 +112,5 @@ Variable *Scope::FindVarLocally(const std::string &name) const {
  return nullptr;
}

} // namespace framework
} // namespace paddle_mobile
@@ -18,38 +18,38 @@ SOFTWARE.
==============================================================================*/

#pragma once

#include "variable.h"
#include <list>          // std::list
#include <mutex>         // std::mutex
#include <unordered_map> // std::unordered_map

namespace paddle_mobile {
namespace framework {

class Scope {
public:
  Scope() {}
  ~Scope() {}

  Scope &NewScope() const;

  /// Create a variable with given name if it doesn't exist.
  Variable *Var(const std::string &name);

  /// Create a variable with a scope-unique name.
  Variable *Var(std::string *name = nullptr);

  void EraseVars(const std::vector<std::string> &var_names);

  /// Find a variable in the scope or any of its ancestors. Returns
  /// nullptr if cannot find.
  Variable *FindVar(const std::string &name) const;

  const Scope *parent() const { return parent_; }

  /// Find the scope or an ancestor scope that contains the given variable.
  const Scope *FindScope(const Variable *var) const;

  void DeleteScope(Scope *scope) const;

  /// Drop all child scopes belonging to this scope.
  void DropKids();

@@ -58,23 +58,23 @@ class Scope {
  std::vector<std::string> LocalVarNames() const;

  // Rename variable to a new name
  void Rename(const std::string &origin_name,
              const std::string &new_name) const;

  // Rename variable to a new name and return the new name
  std::string Rename(const std::string &origin_name) const;

  Variable *FindVarLocally(const std::string &name) const;

private:
  // Call Scope::NewScope for a sub-scope.
  explicit Scope(Scope const *parent) : parent_(parent) {}

  mutable std::unordered_map<std::string, Variable *> vars_;
  mutable std::list<Scope *> kids_;
  Scope const *parent_{nullptr};

  mutable std::mutex mutex_;
};
} // namespace framework
} // namespace paddle_mobile
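Reviewer note: minimal usage of the Scope API declared above; the variable name "feed" is illustrative and the include path assumed.

#include "framework/scope.h" // include path assumed

using paddle_mobile::framework::Scope;
using paddle_mobile::framework::Variable;

int main() {
  Scope scope;
  Variable *in = scope.Var("feed");        // created on first use
  Variable *again = scope.FindVar("feed"); // found in this scope
  // FindVar returns nullptr when the name is absent everywhere.
  return (in == again && scope.FindVar("missing") == nullptr) ? 0 : 1;
}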
@@ -27,8 +27,8 @@ namespace paddle_mobile {
namespace framework {

class SelectedRows {
public:
  SelectedRows(const std::vector<int64_t> &rows, const int64_t &height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
  }

@@ -38,19 +38,19 @@ class SelectedRows {
    value_.reset(new Tensor());
  }

  const Tensor &value() const { return *value_; }

  Tensor *mutable_value() { return value_.get(); }

  int64_t height() const { return height_; }

  void set_height(int64_t height) { height_ = height; }

  const std::vector<int64_t> &rows() const { return rows_; }

  std::vector<int64_t> *mutable_rows() { return &rows_; }

  void set_rows(const std::vector<int64_t> &rows) { rows_ = rows; }

  /**
   * get the index of id in rows

@@ -67,7 +67,7 @@ class SelectedRows {
    return make_ddim(dims);
  }

private:
  // Notice: rows can contain duplicates, e.g. {0, 4, 7, 0, 5, 7, 9}.
  // SelectedRows are simply concatenated when added together; only when a
  // SelectedRows is added to a Tensor are the duplicate rows merged.

@@ -76,5 +76,5 @@ class SelectedRows {
  int64_t height_;
};

} // namespace framework
} // namespace paddle_mobile
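Reviewer note: a small sketch of the rows/height bookkeeping described in the comment above. The stored rows index into a conceptually dense tensor of `height` rows, while value() holds only the selected rows' payload.

#include <cstdint>
#include <vector>
#include "framework/selected_rows.h" // include path assumed

using paddle_mobile::framework::SelectedRows;

int main() {
  // Three stored rows corresponding to rows 0, 4 and 7 of a height-10
  // dense tensor; value() would hold the actual 3-row payload.
  SelectedRows sr({0, 4, 7}, 10);
  return (sr.height() == 10 && sr.rows().size() == 3) ? 0 : 1;
}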
@@ -26,11 +26,9 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {

template <typename... T> struct SizeOfTypeFunctor;

template <typename T> struct SizeOfTypeFunctor<T> {
  size_t operator()(std::type_index type) const {
    if (typeid(T).hash_code() == type.hash_code()) {
      return sizeof(T);

@@ -40,8 +38,7 @@ struct SizeOfTypeFunctor<T> {
  }
};

template <> struct SizeOfTypeFunctor<> {
  size_t operator()(std::type_index type) const { return 0UL; }
};

@@ -68,12 +65,11 @@ static inline size_t SizeOfType(std::type_index type) {
class LoDTensor;

class Tensor {
public:
  Tensor() : offset_(0) {}

  /*! Return a pointer to mutable memory block. */
  template <typename T> inline T *data() {
    check_memory_size();
    //    PADDLE_ENFORCE(std::is_same<T, void>::value ||
    //                       holder_->type().hash_code() == typeid(T).hash_code(),

@@ -84,8 +80,7 @@ class Tensor {
  }

  /*! Return a pointer to constant memory block. */
  template <typename T> inline const T *data() const {
    check_memory_size();
    //    PADDLE_ENFORCE(std::is_same<T, void>::value ||
    //                       holder_->type().hash_code() == typeid(T).hash_code(),

@@ -102,8 +97,7 @@ class Tensor {
   * @brief   Return a pointer to mutable memory block.
   * @note    If not exist, then allocation.
   */
  template <typename T> inline T *mutable_data() {
    static_assert(std::is_pod<T>::value, "T must be POD");
    return reinterpret_cast<T *>(mutable_data(typeid(T)));
  }

@@ -141,8 +135,7 @@ class Tensor {
   *
   * @note    If not exist, then allocation.
   */
  template <typename T> inline T *mutable_data(DDim dims) {
    static_assert(std::is_pod<T>::value, "T must be POD");
    Resize(dims);
    return mutable_data<T>();

@@ -227,7 +220,7 @@ class Tensor {
  inline void set_layout(const DataLayout layout) { layout_ = layout; }

private:
  /**
   * @note    Placeholder hides type T, so it doesn't appear as a template
   *          parameter of Variable.

@@ -248,8 +241,7 @@ class Tensor {
    PlaceholderImpl(size_t size, std::type_index type)
        : ptr_(static_cast<uint8_t *>(memory::Alloc(size)),
               memory::PODDeleter<uint8_t>()),
          size_(size), type_(type) {
      //      PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s
      //      memory to allocation.",
      //                              (is_cpu_place(place_) ?

@@ -315,5 +307,5 @@ inline Tensor ReshapeToMatrix(const Tensor &src, int num_col_dims) {
  return res;
}

} // namespace framework
} // namespace paddle_mobile
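Reviewer note: the allocate-then-read pattern the Tensor interface above supports; include paths are assumed, and make_ddim is called with an explicit vector as elsewhere in this commit.

#include <cstdint>
#include <vector>
#include "framework/tensor.h" // include path assumed

using namespace paddle_mobile::framework;

int main() {
  Tensor t;
  std::vector<int64_t> dims{2, 3};
  // mutable_data<T>(dims) resizes, allocates on demand, and returns T*.
  float *p = t.mutable_data<float>(make_ddim(dims));
  for (int i = 0; i < 6; ++i) {
    p[i] = static_cast<float>(i);
  }
  // data<T>() re-checks the allocation and returns the same buffer.
  return t.data<float>()[5] == 5.0f ? 0 : 1;
}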
@@ -20,7 +20,7 @@
namespace paddle_mobile {
namespace framework {

void TensorCopy(const Tensor &src, Tensor *dst) {
  //  VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
  //          << dst_place;

@@ -37,7 +37,7 @@ void TensorCopy(const Tensor &src, Tensor *dst) {
  memory::Copy(dst_ptr, src_ptr, size);
}

void TensorCopySync(const Tensor &src, Tensor *dst) {
  //  VLOG(3) << "TensorCopySync " << src.dims() << " from " << src.place()
  //          << " to " << dst_place;
  src.check_memory_size();

@@ -49,17 +49,15 @@ void TensorCopySync(const Tensor &src, Tensor *dst) {
  memory::Copy(dst_ptr, src_ptr, size);
}

template <typename Predicate> struct AnyDTypeVisitor {
  Predicate predicate_;
  const Tensor &tensor_;
  Tensor *out_;

  AnyDTypeVisitor(Predicate predicate, const Tensor &tensor, Tensor *out)
      : predicate_(predicate), tensor_(tensor), out_(out) {}

  template <typename T> void operator()() const {
    //    auto t = EigenVector<T>::Flatten(tensor_);
    //    auto o = EigenScalar<bool>::From(*out_);
    // return any of predicate_(t) is true.

@@ -68,18 +66,17 @@ struct AnyDTypeVisitor {
};

template <typename Predicate>
inline void AnyImpl(Predicate predicate, const Tensor &tensor,
                    framework::Tensor *out) {
  VisitDataType(ToDataType(tensor.type()),
                AnyDTypeVisitor<Predicate>(predicate, tensor, out));
}

template <typename Predicate> struct AnyVisitor {
  const framework::Tensor &tensor_;
  Predicate predicate_;

  AnyVisitor(const framework::Tensor &tensor, Predicate predicate)
      : tensor_(tensor), predicate_(std::move(predicate)) {}

  bool operator()(void) const {

@@ -90,13 +87,13 @@ struct AnyVisitor {
    return this->GetResult(out);
  }

  bool GetResult(const framework::Tensor &out) const {
    return *out.data<bool>();
  }
};

template <typename Predicate>
inline bool Any(const framework::Tensor &tensor, Predicate predicate) {
  AnyVisitor<Predicate> visitor(tensor, predicate);
  //  return platform::VisitPlace(visitor);
  return visitor();

@@ -104,101 +101,100 @@ inline bool Any(const framework::Tensor &tensor, Predicate predicate) {
struct ContainsNANPredicate {
  template <typename T>
  auto operator()(const T &eigen_vec) const
      -> decltype(std::declval<T>().isnan()) {
    // Cast eigen_vector to vector of bool. true if is nan.
    return eigen_vec.isnan();
  }
};

bool TensorContainsNAN(const framework::Tensor &tensor) {
  ContainsNANPredicate predicate;
  return Any(tensor, predicate);
}

struct ContainsInfPredicate {
  template <typename T>
  auto operator()(const T &eigen_vec) const
      -> decltype(std::declval<T>().isinf()) {
    // Cast eigen_vector to vector of bool. true if is inf.
    return eigen_vec.isinf();
  }
};

bool TensorContainsInf(const framework::Tensor &tensor) {
  ContainsInfPredicate predicate;
  return Any(tensor, predicate);
}

void TensorToStream(std::ostream &os, const Tensor &tensor) {
  { // the 1st field, uint32_t version
    constexpr uint32_t version = 0;
    os.write(reinterpret_cast<const char *>(&version), sizeof(version));
  }
  { // the 2nd field, tensor description
    // int32_t  size
    // void*    protobuf message
    proto::VarType::TensorDesc desc;
    desc.set_data_type(framework::ToDataType(tensor.type()));
    auto dims = framework::vectorize(tensor.dims());
    auto *pb_dims = desc.mutable_dims();
    pb_dims->Resize(static_cast<int>(dims.size()), 0);
    std::copy(dims.begin(), dims.end(), pb_dims->begin());
    int32_t size = desc.ByteSize();
    os.write(reinterpret_cast<const char *>(&size), sizeof(size));
    auto out = desc.SerializeAsString();
    os.write(out.data(), size);
  }
  { // the 3rd field, tensor data
    uint64_t size = tensor.memory_size();
    auto *data_ptr = tensor.data<void>();
    //    PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
    //                   "Index overflow when writing tensor");
    os.write(static_cast<const char *>(data_ptr),
             static_cast<std::streamsize>(size));
  }
}

struct DeserializedDataFunctor {
  DeserializedDataFunctor(void **buf, Tensor *tensor)
      : buf_(buf), tensor_(tensor) {}

  template <typename T> void operator()() {
    *buf_ = tensor_->mutable_data<T>();
  }

  void **buf_;
  Tensor *tensor_;
};

void TensorFromStream(std::istream &is, framework::Tensor *tensor) {
  uint32_t version;
  is.read(reinterpret_cast<char *>(&version), sizeof(version));
  //  PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
  proto::VarType::TensorDesc desc;
  { // int32_t size
    // proto buffer
    int32_t size;
    is.read(reinterpret_cast<char *>(&size), sizeof(size));
    std::unique_ptr<char[]> buf(new char[size]);
    is.read(reinterpret_cast<char *>(buf.get()), size);
    // The enforcement macro is unavailable here, but the parse itself must
    // still run, or `desc` stays empty and the dims below are lost:
    desc.ParseFromArray(buf.get(), size);
    //    PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
    //                   "Cannot parse tensor desc");
  }
  { // read tensor
    std::vector<int64_t> dims;
    dims.reserve(static_cast<size_t>(desc.dims().size()));
    std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
    tensor->Resize(framework::make_ddim(dims));
    void *buf;
    framework::VisitDataType(desc.data_type(),
                             DeserializedDataFunctor(&buf, tensor));
    is.read(static_cast<char *>(buf), tensor->memory_size());
  }
}

} // namespace framework
} // namespace paddle_mobile
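Reviewer note: a round-trip sketch through the two stream helpers above, using std::stringstream in place of a real file. It assumes the restored ParseFromArray call noted in TensorFromStream; without it the descriptor deserializes empty.

#include <cstdint>
#include <sstream>
#include <vector>
#include "framework/tensor_util.h" // include path assumed

using namespace paddle_mobile::framework;

int main() {
  Tensor src;
  std::vector<int64_t> dims{4};
  float *p = src.mutable_data<float>(make_ddim(dims));
  for (int i = 0; i < 4; ++i) p[i] = 0.5f * i;

  std::stringstream ss;
  TensorToStream(ss, src);    // version, TensorDesc proto, raw bytes

  Tensor dst;
  TensorFromStream(ss, &dst); // resizes dst and refills the buffer
  return dst.data<float>()[3] == 1.5f ? 0 : 1;
}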
@@ -13,54 +13,54 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "framework.pb.h"
#include "memory/t_malloc.h"
#include "platform/data_type.h"
#include "tensor.h"
#include <vector>

namespace paddle_mobile {
namespace framework {

void TensorCopy(const Tensor &src, Tensor *dst);
void TensorCopySync(const Tensor &src, Tensor *dst);

template <typename T>
void TensorFromVector(const std::vector<T> &src, Tensor *dst);

// Note: this declaration was misspelled "TesnorToVector"; the definition
// below uses "TensorToVector", so the name is corrected to match.
template <typename T>
void TensorToVector(const Tensor &src, std::vector<T> *dst);

bool TensorContainsNAN(const framework::Tensor &tensor);
bool TensorContainsInf(const framework::Tensor &tensor);

void TensorToStream(std::ostream &os, const Tensor &tensor);
void TensorFromStream(std::istream &is, Tensor *tensor);

//
// The implementation of template functions.
//

template <typename T>
void TensorFromVector(const std::vector<T> &src, Tensor *dst) {
  auto src_ptr = static_cast<const void *>(src.data());
  dst->Resize({static_cast<int64_t>(src.size())});
  auto dst_ptr = static_cast<void *>(dst->mutable_data<T>());
  auto size = src.size() * sizeof(T);
  memory::Copy(dst_ptr, src_ptr, size);
}

template <typename T>
void TensorToVector(const Tensor &src, std::vector<T> *dst) {
  auto src_ptr = static_cast<const void *>(src.data<T>());
  auto size = src.numel() * sizeof(T);
  dst->resize(src.numel());
  auto dst_ptr = static_cast<void *>(dst->data());
  memory::Copy(dst_ptr, src_ptr, size);
}

} // namespace framework
} // namespace paddle_mobile
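Reviewer note: the vector round trip spelled out; it relies only on the two templates defined above, with the include path assumed.

#include <vector>
#include "framework/tensor_util.h" // include path assumed

using namespace paddle_mobile::framework;

int main() {
  std::vector<float> in = {1.f, 2.f, 3.f};
  Tensor t;
  TensorFromVector(in, &t); // resizes t to {3} and copies the bytes in

  std::vector<float> out;
  TensorToVector(t, &out);  // resizes out and copies the bytes back
  return out == in ? 0 : 1;
}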
@@ -24,5 +24,5 @@ namespace framework {

VarDesc::VarDesc(const proto::VarDesc &desc) : desc_(desc) {}

} // namespace framework
} // namespace paddle_mobile

@@ -25,7 +25,7 @@ namespace paddle_mobile {
namespace framework {

class VarDesc {
public:
  VarDesc(const proto::VarDesc &desc);

  std::string Name() const { return desc_.name(); }

@@ -36,33 +36,33 @@ class VarDesc {
  const proto::VarType::ChannelDesc &channel_desc() const {
    switch (desc_.type().type()) {
    case proto::VarType::CHANNEL:
      return desc_.type().channel();
    default:
      break;
    }
    // Note: control falls off the end for non-CHANNEL types; a complete
    // implementation would raise an error or return a sentinel here.
  }

  const proto::VarType::TensorDesc &tensor_desc() const {
    switch (desc_.type().type()) {
    case proto::VarType::SELECTED_ROWS:
      return desc_.type().selected_rows();
    case proto::VarType::LOD_TENSOR:
      return desc_.type().lod_tensor().tensor();
    case proto::VarType::LOD_TENSOR_ARRAY:
      return desc_.type().tensor_array().tensor();
    default:
      break;
    }
    // Note: same caveat as channel_desc() for unsupported types.
  }

  proto::VarType::Type GetDataType() const {
    switch (desc_.type().type()) {
    case proto::VarType::CHANNEL:
      return channel_desc().data_type();
      break;
    default:
      return tensor_desc().data_type();
    }
  }

@@ -80,9 +80,9 @@ class VarDesc {
    return this->RepeatedToVector(tensor_desc().dims());
  }

private:
  proto::VarDesc desc_;
};

} // namespace framework
} // namespace paddle_mobile

@@ -34,5 +34,5 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
  }
}

} // namespace framework
} // namespace paddle_mobile
@@ -18,42 +18,39 @@ SOFTWARE.
==============================================================================*/

#pragma once

#include <iostream>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>

#include "paddle_mobile_object.h"

namespace paddle_mobile {
namespace framework {

class Variable : public PaddleMobileObject {
public:
  Variable() {}
  ~Variable() {}

  template <typename T> const T *Get() const {
    return static_cast<const T *>(holder_->Ptr());
  }

  bool IsInitialized() const { return holder_ != nullptr; }

  const std::string *Name() { return name_; }

  template <typename T> T *GetMutable() {
    if (!IsType<T>()) {
      if (*Name() == "pixel") {
        //        std::cout << " reset " << *Name() << std::endl;
      }
      holder_.reset(new PlaceholderImp<T>(new T()));
    }
    return static_cast<T *>(holder_->Ptr());
  }

  template <typename T> bool IsType() const {
    if (holder_) {
      //      printf("not null \n");
      printf(" holder type : %s, this type %s \n", holder_->Type().name(),

@@ -69,33 +66,32 @@ class Variable : public PaddleMobileObject {
  std::type_index Type() const { return holder_->Type(); }

  void SetName(const std::string *name) { name_ = name; }

private:
  struct Placeholder {
    Placeholder() = default;
    virtual ~Placeholder() = default;

    virtual const std::type_info &Type() const = 0;
    virtual void *Ptr() const = 0;
  };

  template <typename T> struct PlaceholderImp : public Placeholder {
    explicit PlaceholderImp(T *ptr) : ptr_(ptr), type_(typeid(T)) {}

    virtual const std::type_info &Type() const { return type_; }
    virtual void *Ptr() const override {
      return static_cast<void *>(ptr_.get());
    }

    std::unique_ptr<T> ptr_;
    const std::type_info &type_;
  };

  std::unique_ptr<Placeholder> holder_;
  friend class Scope;
  const std::string *name_;
};

} // namespace framework
} // namespace paddle_mobile
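Reviewer note: type-erased storage through the Placeholder machinery above. GetMutable() dereferences Name(), so SetName() must be called first; the "scores" name is illustrative and the include path assumed.

#include <string>
#include <vector>
#include "framework/variable.h" // include path assumed

using paddle_mobile::framework::Variable;

int main() {
  Variable v;
  const std::string name = "scores";
  v.SetName(&name); // the Variable stores the pointer, not a copy
  // Default-constructs a std::vector<float> inside the holder on first use.
  auto *vec = v.GetMutable<std::vector<float>>();
  vec->push_back(1.0f);
  return v.IsType<std::vector<float>>() ? 0 : 1;
}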
@@ -45,10 +45,10 @@ void Loader<Dtype, P>::LoadVar(framework::LoDTensor *tensor,
  std::ifstream is(file_path);
  std::streampos pos = is.tellg(); // save current position
  is.seekg(0, std::ios::end);
  //  std::cout << " file length = " << is.tellg() << std::endl;
  is.seekg(pos); // restore saved position

  // 1. version
  uint32_t version;

@@ -106,34 +106,34 @@ void Loader<Dtype, P>::LoadVar(framework::LoDTensor *tensor,
  int type_size = 0;
  //  std::cout << " desc pre type: ";
  switch (desc.data_type()) {
  case framework::proto::VarType::FP16:
    //    std::cout << "FP16" << std::endl;
    type_size = 2;
    break;
  case framework::proto::VarType::FP32:
    type_size = 4;
    memory = tensor->mutable_data<float>();
    //    std::cout << "FP32" << std::endl;
    break;
  case framework::proto::VarType::FP64:
    type_size = 8;
    //    std::cout << "FP64" << std::endl;
    break;
  case framework::proto::VarType::INT32:
    type_size = 4;
    //    std::cout << "INT32" << std::endl;
    break;
  case framework::proto::VarType::INT64:
    type_size = 8;
    //    std::cout << "INT64" << std::endl;
    break;
  case framework::proto::VarType::BOOL:
    type_size = 1;
    //    std::cout << "BOOL" << std::endl;
    break;
  default:
    break;
    //    std::cout << "  not support" << std::endl;
  }

  //  std::cout << " malloc size: " << memory_size * type_size << std::endl;

@@ -143,8 +143,8 @@ void Loader<Dtype, P>::LoadVar(framework::LoDTensor *tensor,
};

template <typename Dtype, Precision P>
const framework::Program<Dtype, P>
Loader<Dtype, P>::Load(const std::string &dirname) {
  std::string model_filename = dirname + "/__model__";
  std::string program_desc_str;
  ReadBinaryFile(model_filename, &program_desc_str);

@@ -217,43 +217,43 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
          //          std::cout << "   attr type: " << attr.type() << std::endl;
          switch (attr.type()) {
          case framework::proto::AttrType::BOOLEAN:
            //            std::cout << "   boolen: " << attr.b() << std::endl;
            break;
          case framework::proto::AttrType::INT:
            //            std::cout << "   int: " << attr.i() << std::endl;
            break;
          case framework::proto::AttrType::FLOAT:
            //            std::cout << "   float: " << attr.f() << std::endl;
          case framework::proto::AttrType::STRING:
            //            std::cout << "   string: " << attr.s() << std::endl;
          case framework::proto::AttrType::BOOLEANS:
            //            std::vector<bool>
            //            bools(attr.bools_size());
            for (int y = 0; y < attr.bools_size(); ++y) {
              //              std::cout << "   bool - " << attr.bools(y) <<
              //              std::endl;
            }
          case framework::proto::AttrType::LONG:
            //            std::cout << "   long: " << attr.l() << std::endl;
          case framework::proto::AttrType::FLOATS:
            for (int y = 0; y < attr.floats_size(); ++y) {
              //              std::cout << "   float - " << y << ": " <<
              //              attr.floats(y)
              //                        << std::endl;
            }
          case framework::proto::AttrType::INTS:
            for (int y = 0; y < attr.ints_size(); ++y) {
              //              std::cout << "   int - " << y << ": " <<
              //              attr.ints(y)
              //                        << std::endl;
            }
          case framework::proto::AttrType::STRINGS:
            for (int y = 0; y < attr.strings_size(); ++y) {
              //              std::cout << "   string - " << y << ": " <<
              //              attr.strings(y)
              //                        << std::endl;
            }
          }
        }
      }

@@ -280,10 +280,10 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
      //      std::cout << " to load " << var.name() << std::endl;
      std::string file_path = dirname + "/" + var.name();
      std::ifstream is(file_path);
      std::streampos pos = is.tellg(); // save current position
      is.seekg(0, std::ios::end);
      //      std::cout << " file length = " << is.tellg() << std::endl;
      is.seekg(pos); // restore saved position

      // 1. version
      uint32_t version;

@@ -333,33 +333,33 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
      int type_size = 0;
      //      std::cout << " desc pre type: ";
      switch (desc.data_type()) {
      case framework::proto::VarType::FP16:
        //        std::cout << "FP16" << std::endl;
        type_size = 2;
        break;
      case framework::proto::VarType::FP32:
        type_size = 4;
        //        std::cout << "FP32" << std::endl;
        break;
      case framework::proto::VarType::FP64:
        type_size = 8;
        //        std::cout << "FP64" << std::endl;
        break;
      case framework::proto::VarType::INT32:
        type_size = 4;
        //        std::cout << "INT32" << std::endl;
        break;
      case framework::proto::VarType::INT64:
        type_size = 8;
        //        std::cout << "INT64" << std::endl;
        break;
      case framework::proto::VarType::BOOL:
        type_size = 1;
        //        std::cout << "BOOL" << std::endl;
        break;
      default:
        break;
        //        std::cout << "  not support" << std::endl;
      }

      //      std::cout << " malloc size: " << memory_size * type_size

@@ -381,4 +381,4 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(

template class Loader<CPU, Precision::FP32>;

} // namespace paddle_mobile
@@ -29,11 +29,11 @@ namespace paddle_mobile {
template <typename Dtype, Precision P = Precision::FP32>
class Loader : PaddleMobileObject {
public:
  const framework::Program<Dtype, P> Load(const std::string &dirname);

private:
  void LoadVar(framework::LoDTensor *tensor, const std::string &file_path);
};

} // namespace paddle_mobile
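Reviewer note: driving the Loader end to end. The directory is a placeholder; as LoadVar and Load above expect, it must contain a `__model__` file plus one file per persistable variable.

#include "io.h" // include path assumed

using namespace paddle_mobile;

int main() {
  Loader<CPU, Precision::FP32> loader;
  // Illustrative path; Load reads dirname + "/__model__" and the variables.
  auto program = loader.Load("./model_dir");
  // program.originProgram holds the parsed ProgramDesc; program.scope holds
  // the loaded variables.
  return program.originProgram != nullptr ? 0 : 1;
}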
@@ -47,5 +47,5 @@ void Free(void *ptr) {
  }
}

} // namespace memory
} // namespace paddle_mobile
@@ -37,11 +37,10 @@ void Free(void *ptr);
 * std::unique_ptr<T> in tensor.h.
 * static_cast
 */
template <typename T> class PODDeleter {
  static_assert(std::is_pod<T>::value, "T must be POD");

public:
  explicit PODDeleter(){};
  void operator()(T *ptr) { Free(static_cast<void *>(ptr)); }

@@ -55,12 +54,11 @@ class PODDeleter {
 * std::unique_ptr<T> in tensor.h.
 * reinterpret_cast
 */
template <typename T> class PlainDeleter {
public:
  explicit PlainDeleter(){};
  void operator()(T *ptr) { Free(reinterpret_cast<void *>(ptr)); }
};

} // namespace memory
} // namespace paddle_mobile
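Reviewer note: the intended pairing of these deleters with std::unique_ptr, as tensor.h does above; memory::Alloc(size_t) returning void* is assumed from its use there, and the include path is assumed.

#include <memory>
#include "memory/t_malloc.h" // include path assumed

using namespace paddle_mobile;

int main() {
  // The deleter casts back to void * and routes the pointer through Free,
  // tying the raw allocation's lifetime to the unique_ptr.
  std::unique_ptr<float, memory::PODDeleter<float>> buf(
      static_cast<float *>(memory::Alloc(64 * sizeof(float))),
      memory::PODDeleter<float>());
  buf.get()[0] = 1.0f;
  return buf.get()[0] == 1.0f ? 0 : 1;
}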
@@ -72,5 +72,5 @@ void ConvOp<Dtype, T>::InferShape() const {

template class ConvOp<CPU, float>;

} // namespace operators
} // namespace paddle_mobile

@@ -28,9 +28,9 @@ using namespace framework;
template <typename DeviceType, typename T>
class ConvOp : public framework::OperatorWithKernel<DeviceType> {
public:
  ConvOp(const std::string &type, const VariableNameMap &inputs,
         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
         std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
                                                  scope),

@@ -39,7 +39,7 @@ class ConvOp : public framework::OperatorWithKernel<DeviceType> {
  using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
  void InferShape() const override;

protected:
  void RunImpl() const {
    operators::ConvKernel<DeviceType, T, ConvParam> kernel;
    kernel.Compute(param_);

@@ -48,5 +48,5 @@ class ConvOp : public framework::OperatorWithKernel<DeviceType> {
  ConvParam param_;
};

} // operators
} // paddle_mobile
@@ -21,9 +21,9 @@ SOFTWARE.
namespace paddle_mobile {
namespace operators {

bool IsExpand(const std::vector<int64_t> &filter_dim,
              const std::vector<int> &strides, const std::vector<int> &paddings,
              const std::vector<int> &dilations) {
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
@@ -35,8 +35,8 @@ bool IsExpand(const std::vector<int64_t> &filter_dim,
}

template <>
void ConvKernel<CPU, float, ConvParam>::Compute(const ConvParam &param) const {
  const Tensor *input = param.Input();
  std::cout << " conv param " << param << std::endl;
@@ -45,7 +45,7 @@ void ConvKernel<CPU, float, ConvParam>::Compute(const ConvParam &param) const {
  // that avoids modifying the variable in the Scope.
  Tensor filter = *param.Filter();
  Tensor *output = param.Output();
  // output->mutable_data<T>(context.GetPlace());

  int groups = param.Groups();
@@ -149,5 +149,5 @@ void ConvKernel<CPU, float, ConvParam>::Compute(const ConvParam &param) const {
template class ConvKernel<CPU, float, ConvParam>;

} // namespace operators
} // namespace paddle_mobile
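IsExpand decides whether the kernel must materialize an im2col/vol2col buffer: a 1x1 filter with stride 1, zero padding and dilation 1 multiplies cleanly with GEMM as-is. A small illustration, mirroring the upstream Paddle semantics this kernel is ported from:

// filter_dim is NCHW-ordered: [out_channels, in_channels, kh, kw]
std::vector<int64_t> filter_dim = {32, 3, 1, 1};
std::vector<int> strides = {1, 1}, paddings = {0, 0}, dilations = {1, 1};
bool needs_expand = IsExpand(filter_dim, strides, paddings, dilations);
// needs_expand == false: the 1x1/stride-1/pad-0/dilation-1 case skips
// im2col. A 3x3 filter, stride 2, padding 1, or dilation 2 flips it to true.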
@@ -31,7 +31,7 @@ using namespace framework;
template <typename DeviceType, typename T, typename P>
class ConvKernel : public framework::OpKernelBase<DeviceType, ConvParam> {
public:
  void Compute(const ConvParam &param) const;
};
}
...
@@ -24,12 +24,11 @@ namespace math {
 * col =
 * [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T> class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
public:
  void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
                  const std::vector<int> &stride,
                  const std::vector<int> &padding, framework::Tensor *col) {
    // PADDLE_ENFORCE(im.dims().size() == 3);
    // PADDLE_ENFORCE(col->dims().size() == 5);
@@ -58,8 +57,8 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
    int channels_col = im_channels * filter_height * filter_width;

    const T *im_data = im.data<T>();
    T *col_data = col->data<T>();
    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
@@ -86,13 +85,12 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
 * col =
 * [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T> class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
public:
  void operator()(const framework::Tensor &col,
                  const std::vector<int> &dilation,
                  const std::vector<int> &stride,
                  const std::vector<int> &padding, framework::Tensor *im) {
    // PADDLE_ENFORCE(im->dims().size() == 3);
    // PADDLE_ENFORCE(col.dims().size() == 5);
    int im_channels = im->dims()[0];
@@ -120,8 +118,8 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
    int channels_col = im_channels * filter_height * filter_width;

    T *im_data = im->data<T>();
    const T *col_data = col.data<T>();
    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
@@ -152,12 +150,11 @@ template class Col2ImFunctor<ColFormat::kCFO, CPU, double>;
 * col =
 * [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T> class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
public:
  void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
                  const std::vector<int> &stride,
                  const std::vector<int> &padding, framework::Tensor *col) {
    // PADDLE_ENFORCE(im.dims().size() == 3);
    // PADDLE_ENFORCE(col->dims().size() == 5);
    int im_channels = im.dims()[0];
@@ -177,8 +174,8 @@ class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
    // 1, col_width, "col_width and padding(padding_left, padding_right)
    // are " "inconsistent.");

    const T *im_data = im.data<T>();
    T *col_data = col->data<T>();

    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
@@ -220,13 +217,12 @@ class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
 * col =
 * [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T> class Col2ImFunctor<ColFormat::kOCF, CPU, T> {
public:
  void operator()(const framework::Tensor &col,
                  const std::vector<int> &dilation,
                  const std::vector<int> &stride,
                  const std::vector<int> &padding, framework::Tensor *im) {
    // PADDLE_ENFORCE(im->dims().size() == 3);
    // PADDLE_ENFORCE(col.dims().size() == 5);
    int im_channels = im->dims()[0];
@@ -246,8 +242,8 @@ class Col2ImFunctor<ColFormat::kOCF, CPU, T> {
    // 1, col_width, "col_width and padding(padding_left, padding_right)
    // are " "inconsistent.");

    T *im_data = im->data<T>();
    const T *col_data = col.data<T>();

    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
@@ -289,6 +285,6 @@ template class Im2ColFunctor<ColFormat::kOCF, CPU, double>;
template class Col2ImFunctor<ColFormat::kOCF, CPU, float>;
template class Col2ImFunctor<ColFormat::kOCF, CPU, double>;

} // namespace math
} // namespace operators
} // namespace paddle_mobile
@@ -79,21 +79,21 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
 */
template <ColFormat Format, typename DeviceType, typename T>
class Im2ColFunctor {
public:
  void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
                  const std::vector<int> &stride,
                  const std::vector<int> &padding, framework::Tensor *col);
};

template <ColFormat Format, typename DeviceType, typename T>
class Col2ImFunctor {
public:
  void operator()(const framework::Tensor &col,
                  const std::vector<int> &dilation,
                  const std::vector<int> &stride,
                  const std::vector<int> &padding, framework::Tensor *im);
};

} // namespace math
} // namespace operators
} // namespace paddle_mobile
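A shape-level sketch of driving Im2ColFunctor — assuming Tensor exposes Resize/mutable_data and framework provides make_ddim, as in the PaddlePaddle code this is derived from:

using paddle_mobile::framework::Tensor;
using paddle_mobile::operators::math::ColFormat;
using paddle_mobile::operators::math::Im2ColFunctor;

void Im2ColExample(const Tensor &im) { // im: [C=3, H=32, W=32]
  // 3x3 filter, stride 1, padding 1, dilation 1:
  // OH = OW = (32 + 2*1 - (1*(3-1)+1)) / 1 + 1 = 32
  Tensor col;
  col.Resize(paddle_mobile::framework::make_ddim({3, 3, 3, 32, 32}));
  col.mutable_data<float>(); // [C, FH, FW, OH, OW] for kCFO
  Im2ColFunctor<ColFormat::kCFO, paddle_mobile::CPU, float> im2col;
  im2col(im, /*dilation=*/{1, 1}, /*stride=*/{1, 1}, /*padding=*/{1, 1}, &col);
}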
@@ -21,7 +21,7 @@ namespace math {
template <>
void gemm<float>(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
                 const int M, const int N, const int K, const float alpha,
                 const float *A, const float *B, const float beta, float *C) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
@@ -32,8 +32,8 @@ void gemm<float>(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
template <>
void gemm<double>(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
                  const int M, const int N, const int K, const double alpha,
                  const double *A, const double *B, const double beta,
                  double *C) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
@@ -43,8 +43,8 @@ void gemm<double>(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
template <>
void gemm<float>(const bool transA, const bool transB, const int M, const int N,
                 const int K, const float alpha, const float *A, const int lda,
                 const float *B, const int ldb, const float beta, float *C,
                 const int ldc) {
  cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
@@ -53,18 +53,18 @@ void gemm<float>(const bool transA, const bool transB, const int M, const int N,
template <>
void gemm<double>(const bool transA, const bool transB, const int M,
                  const int N, const int K, const double alpha, const double *A,
                  const int lda, const double *B, const int ldb,
                  const double beta, double *C, const int ldc) {
  cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
              lda, B, ldb, beta, C, ldc);
}

template <>
void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
                   const framework::Tensor &matrix_b, bool trans_b, float alpha,
                   framework::Tensor *matrix_out, float beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
@@ -89,9 +89,9 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
}

template <>
void matmul<double>(const framework::Tensor &matrix_a, bool trans_a,
                    const framework::Tensor &matrix_b, bool trans_b,
                    double alpha, framework::Tensor *matrix_out, double beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
@@ -115,6 +115,6 @@ void matmul<double>(const framework::Tensor &matrix_a, bool trans_a,
              matrix_b.data<double>(), beta, matrix_out->data<double>());
}

} // namespace math
} // namespace operators
} // namespace paddle_mobile
@@ -14,9 +14,9 @@ limitations under the License. */

#pragma once

#include <cblas.h>
#include <cmath>
#include "framework/tensor.h"

namespace paddle_mobile {
namespace operators {
@@ -24,19 +24,19 @@ namespace math {

template <typename T>
void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
          const int M, const int N, const int K, const T alpha, const T *A,
          const T *B, const T beta, T *C);

template <typename T>
void gemm(const bool transA, const bool transB, const int M, const int N,
          const int K, const T alpha, const T *A, const int lda, const T *B,
          const int ldb, const T beta, T *C, const int ldc);

// matrix multiply with continuous memory
template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
            const framework::Tensor &matrix_b, bool trans_b, T alpha,
            framework::Tensor *matrix_out, T beta);

} // namespace math
} // namespace operators
} // namespace paddle_mobile
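A worked example of the first gemm overload, assuming this translation unit includes the header above; with CblasNoTrans on both sides the leading dimensions resolve to lda = K, ldb = N, ldc = N, exactly as in the definitions:

#include <vector>

void GemmExample() {
  using paddle_mobile::operators::math::gemm;
  const int M = 2, N = 2, K = 3;
  std::vector<float> A = {1, 2, 3, 4, 5, 6}; // 2x3, row-major
  std::vector<float> B = {1, 0, 0, 1, 1, 1}; // 3x2, row-major
  std::vector<float> C(M * N, 0.0f);         // overwritten since beta == 0
  gemm<float>(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, A.data(), B.data(),
              0.0f, C.data());
  // C = alpha * A * B + beta * C = {4, 5, 10, 11}
}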
@@ -25,12 +25,11 @@ using Tensor = paddle_mobile::framework::Tensor;
 * [input_channels, filter_depth, filter_height, filter_width,
 * output_depth, output_height, output_width]
 */
template <typename T> class Vol2ColFunctor<CPU, T> {
public:
  void operator()(const Tensor &vol, const std::vector<int> &dilations,
                  const std::vector<int> &strides,
                  const std::vector<int> &paddings, Tensor *col) const {
    // PADDLE_ENFORCE(vol.dims().size() == 4);
    // PADDLE_ENFORCE(col->dims().size() == 7);
@@ -69,8 +68,8 @@ class Vol2ColFunctor<CPU, T> {
    // "input_width and output_width are "
    // "mismatching.");

    const T *vol_data = vol.data<T>();
    T *col_data = col->data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
@@ -108,12 +107,11 @@ class Vol2ColFunctor<CPU, T> {
 * [input_channels, filter_depth, filter_height, filter_width,
 * output_depth, output_height, output_width]
 */
template <typename T> class Col2VolFunctor<CPU, T> {
public:
  void operator()(const Tensor &col, const std::vector<int> &dilations,
                  const std::vector<int> &strides,
                  const std::vector<int> &paddings, Tensor *vol) const {
    // PADDLE_ENFORCE(vol->dims().size() == 4);
    // PADDLE_ENFORCE(col.dims().size() == 7);
@@ -151,8 +149,8 @@ class Col2VolFunctor<CPU, T> {
    // output_width,
    // "input_width and output_width are "
    // "mismatching.");
    T *vol_data = vol->data<T>();
    const T *col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
@@ -190,6 +188,6 @@ template class Vol2ColFunctor<CPU, double>;
template class Col2VolFunctor<CPU, float>;
template class Col2VolFunctor<CPU, double>;

} // namespace math
} // namespace operators
} // namespace paddle_mobile
@@ -64,22 +64,20 @@ namespace math {
 */
using Tensor = paddle_mobile::framework::Tensor;

template <typename DeviceType, typename T> class Vol2ColFunctor {
public:
  void operator()(const Tensor &vol, const std::vector<int> &dilations,
                  const std::vector<int> &strides,
                  const std::vector<int> &paddings, Tensor *col) const;
};

template <typename DeviceType, typename T> class Col2VolFunctor {
public:
  void operator()(const Tensor &col, const std::vector<int> &dilations,
                  const std::vector<int> &strides,
                  const std::vector<int> &paddings, Tensor *vol) const;
};

} // namespace math
} // namespace operators
} // namespace paddle_mobile
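These are the 3-D analogues of Im2ColFunctor/Col2ImFunctor; each output spatial extent follows the standard convolution formula. Restated as a helper for illustration (not part of the commit):

// Output extent along one spatial axis, matching the shapes the functors
// above expect for output_depth / output_height / output_width.
inline int ConvOutputSize(int input, int filter, int dilation, int padding,
                          int stride) {
  const int dilated_filter = dilation * (filter - 1) + 1;
  return (input + 2 * padding - dilated_filter) / stride + 1;
}
// e.g. input depth 16, filter 3, dilation 1, padding 1, stride 2
//      -> (16 + 2 - 3) / 2 + 1 = 8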
@@ -21,7 +21,7 @@ SOFTWARE.
namespace paddle_mobile {
namespace operators {

std::ostream &operator<<(std::ostream &os, const ConvParam &conv_param) {
  os << "parameter of conv: " << std::endl;
  os << " stride: "
     << " (" << conv_param.Strides()[0] << conv_param.Strides()[1] << ") "
@@ -39,5 +39,5 @@ std::ostream &operator<<(std::ostream &os, const ConvParam &conv_param) {
  return os;
}

} // namespace operators
} // namespace paddle_mobile
@@ -30,8 +30,8 @@ namespace operators {
using namespace framework;

class OpParam : PaddleMobileObject {
public:
protected:
  template <typename T>
  static T *InputFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Input", inputs, scope);
@@ -67,7 +67,7 @@ class OpParam : PaddleMobileObject {
};

class ConvParam : OpParam {
public:
  ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const framework::AttributeMap &attrs,
            const framework::Scope &scope) {
@@ -94,7 +94,7 @@ class ConvParam : OpParam {
  const int &Groups() const { return groups; }

private:
  Tensor *input_;
  Tensor *output_;
  LoDTensor *filter_;
@@ -106,5 +106,5 @@ class ConvParam : OpParam {
std::ostream &operator<<(std::ostream &os, const ConvParam &conv_param);

} // namespace operators
} // namespace paddle_mobile
@@ -14,9 +14,9 @@ limitations under the License. */

#pragma once

#include <string>
#include <typeindex>
#include "framework/framework.pb.h"

namespace paddle_mobile {
namespace framework {
@@ -47,70 +47,70 @@ inline proto::VarType::Type ToDataType(std::type_index type) {

inline std::type_index ToTypeIndex(proto::VarType::Type type) {
  switch (type) {
  // case proto::VarType::FP16:
  //   return typeid(platform::float16);
  case proto::VarType::FP32:
    return typeid(float);
  case proto::VarType::FP64:
    return typeid(double);
  case proto::VarType::INT32:
    return typeid(int);
  case proto::VarType::INT64:
    return typeid(int64_t);
  case proto::VarType::BOOL:
    return typeid(bool);
  default:
    // PADDLE_THROW("Not support type %d", type);
    printf("Not support type %d", type);
    return typeid(void); // keep control from running off a non-void function
  }
}

template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
  switch (type) {
  // case proto::VarType::FP16:
  //   visitor.template operator()<platform::float16>();
  //   break;
  case proto::VarType::FP32:
    visitor.template operator()<float>();
    break;
  case proto::VarType::FP64:
    visitor.template operator()<double>();
    break;
  case proto::VarType::INT32:
    visitor.template operator()<int>();
    break;
  case proto::VarType::INT64:
    visitor.template operator()<int64_t>();
    break;
  case proto::VarType::BOOL:
    visitor.template operator()<bool>();
    break;
  default:
    // PADDLE_THROW("Not supported");
    printf("Not supported");
  }
}

inline std::string DataTypeToString(const proto::VarType::Type type) {
  switch (type) {
  case proto::VarType::FP16:
    return "float16";
  case proto::VarType::FP32:
    return "float32";
  case proto::VarType::FP64:
    return "float64";
  case proto::VarType::INT16:
    return "int16";
  case proto::VarType::INT32:
    return "int32";
  case proto::VarType::INT64:
    return "int64";
  case proto::VarType::BOOL:
    return "bool";
  default:
    // PADDLE_THROW("Not support type %d", type);
    printf("Not support type %d", type);
    return "unknown"; // keep control from running off a non-void function
  }
}
@@ -120,5 +120,5 @@ inline std::ostream &operator<<(std::ostream &out,
  return out;
}

} // namespace framework
} // namespace paddle_mobile
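VisitDataType turns the runtime proto type back into a compile-time template argument. A minimal visitor sketch — the struct name is illustrative:

#include <cstdio>

// Dispatches on the runtime dtype and reports sizeof(T) for the C++ type
// the switch above selects.
struct SizeOfVisitor {
  template <typename T> void operator()() const {
    std::printf("element size: %zu\n", sizeof(T));
  }
};

void VisitExample() {
  using namespace paddle_mobile::framework;
  VisitDataType(proto::VarType::FP32, SizeOfVisitor{}); // prints 4
}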
@@ -16,10 +16,10 @@ limitations under the License. */

// Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname)          \
private:                                            \
  classname(const classname &) = delete;            \
  classname(classname &&) = delete;                 \
  classname &operator=(const classname &) = delete; \
  classname &operator=(classname &&) = delete
#endif
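Typical use of the macro, with an illustrative class name; note the expansion leaves the class in a private section, so it is usually placed last:

class Workspace {
public:
  Workspace() = default;

private:
  DISABLE_COPY_AND_ASSIGN(Workspace);
};

// Workspace b(a);            // error: copy constructor is deleted
// Workspace c(std::move(a)); // error: move constructor is deleted too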