提交 d6825bf3 编写于 作者: 朔-望's avatar 朔-望 提交者: GitHub

Merge pull request #218 from allonli/develop

update code style to Google
---
Language: Cpp
BasedOnStyle: LLVM
BasedOnStyle: Google
Standard: Cpp11
...
......@@ -51,12 +51,13 @@ struct Print;
struct Print {
friend struct ToLog;
template <typename T> Print &operator<<(T const &value) {
template <typename T>
Print &operator<<(T const &value) {
buffer_ << value;
return *this;
}
private:
private:
void print(LogLevel level) {
buffer_ << std::endl;
if (level == kLOG_ERROR) {
......@@ -76,14 +77,15 @@ struct ToLog {
printer_ << logs[level] << " " << info << ":" << std::string(blanks, ' ');
}
template <typename T> ToLog &operator<<(T const &value) {
template <typename T>
ToLog &operator<<(T const &value) {
printer_ << value;
return *this;
}
~ToLog() { printer_.print(level_); }
private:
private:
LogLevel level_;
Print printer_;
};
......@@ -109,16 +111,16 @@ private:
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \
<< "] [line: " << __LINE__ << "] ") \
.str())
} // namespace paddle_mobile
} // namespace paddle_mobile
#define LOGF(level, format, ...) \
if (level > paddle_mobile::log_level) { \
} else \
#define LOGF(level, format, ...) \
if (level > paddle_mobile::log_level) { \
} else \
printf(format, ##__VA_ARGS__)
#define DLOGF(format, ...) \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
} else \
#define DLOGF(format, ...) \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
} else \
printf(format, ##__VA_ARGS__)
#else
......@@ -140,30 +142,34 @@ enum LogLevel {
struct ToLog;
struct Print {
friend struct ToLog;
template <typename T> Print &operator<<(T const &value) {}
template <typename T>
Print &operator<<(T const &value) {}
private:
private:
};
struct ToLog {
ToLog(LogLevel level) {}
template <typename T> ToLog &operator<<(T const &value) { return *this; }
template <typename T>
ToLog &operator<<(T const &value) {
return *this;
}
};
#define LOG(level) \
if (true) { \
} else \
#define LOG(level) \
if (true) { \
} else \
paddle_mobile::ToLog(level)
#define DLOG \
if (true) { \
} else \
#define DLOG \
if (true) { \
} else \
paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG)
#define LOGF(level, format, ...)
#define DLOGF(format, ...)
} // namespace paddle_mobile
} // namespace paddle_mobile
#endif
......@@ -17,18 +17,19 @@ SOFTWARE.
==============================================================================*/
#pragma once;
#include "framework/attribute.h"
#include <map>
#include <string>
#include "framework/attribute.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype> class OperatorBase;
template <typename Dtype>
class OperatorBase;
class OpDesc;
class BlockDesc;
class InferShapeContext;
} // namespace framework
} // namespace framework
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
......@@ -49,4 +50,4 @@ using InferVarTypeFN = std::function<void(const framework::OpDesc & /*op_desc*/,
framework::BlockDesc * /*block*/)>;
using InferShapeFN = std::function<void(framework::InferShapeContext *)>;
}; // namespace paddle_mobile
}; // namespace paddle_mobile
......@@ -24,7 +24,8 @@ enum class Precision : int { FP32 = 0 };
//! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };
template <DeviceTypeEnum T> struct DeviceType {};
template <DeviceTypeEnum T>
struct DeviceType {};
typedef DeviceType<kCPU> CPU;
typedef DeviceType<kFPGA> FPGA;
......@@ -60,4 +61,4 @@ enum PMStatus {
PMUnImplError = 0x07, /*!< Unimplement error. */
PMWrongDevice = 0x08 /*!< un-correct device. */
};
} // namespace paddle_mobile
} // namespace paddle_mobile
......@@ -21,9 +21,13 @@ SOFTWARE.
#pragma once
namespace paddle_mobile {
template <int ID, typename Type> struct IDToType { typedef Type type_t; };
template <int ID, typename Type>
struct IDToType {
typedef Type type_t;
};
template <typename F, typename... Ts> struct VariantHelper {
template <typename F, typename... Ts>
struct VariantHelper {
static const size_t size = sizeof(F) > VariantHelper<Ts...>::size
? sizeof(F)
: VariantHelper<Ts...>::size;
......@@ -37,7 +41,8 @@ template <typename F, typename... Ts> struct VariantHelper {
}
};
template <typename F> struct VariantHelper<F> {
template <typename F>
struct VariantHelper<F> {
static const size_t size = sizeof(F);
inline static void Destroy(size_t id, void *data) {
if (id == typeid(F).hash_code()) {
......@@ -48,8 +53,9 @@ template <typename F> struct VariantHelper<F> {
}
};
template <size_t size> class RawData {
public:
template <size_t size>
class RawData {
public:
char data[size];
RawData() {}
RawData(const RawData &raw_data) { strcpy(data, raw_data.data); }
......@@ -58,7 +64,8 @@ public:
// }
};
template <typename... Ts> struct Variant {
template <typename... Ts>
struct Variant {
Variant(const Variant &variant) {
// std::cout << " 赋值构造函数 " << std::endl;
type_id = variant.type_id;
......@@ -70,13 +77,15 @@ template <typename... Ts> struct Variant {
// helper::Destroy(type_id, &data);
}
template <typename T, typename... Args> void Set(Args &&... args) {
template <typename T, typename... Args>
void Set(Args &&... args) {
helper::Destroy(type_id, &data);
new (&data) T(std::forward<Args>(args)...);
type_id = typeid(T).hash_code();
}
template <typename T> T &Get() const {
template <typename T>
T &Get() const {
if (type_id == typeid(T).hash_code()) {
return *const_cast<T *>(reinterpret_cast<const T *>(&data));
} else {
......@@ -87,13 +96,16 @@ template <typename... Ts> struct Variant {
size_t TypeId() const { return type_id; }
private:
private:
static inline size_t invalid_type() { return typeid(void).hash_code(); }
typedef VariantHelper<Ts...> helper;
size_t type_id;
RawData<helper::size> data;
};
template <typename T> struct Vistor { typedef T type_t; };
template <typename T>
struct Vistor {
typedef T type_t;
};
} // namespace paddle_mobile
} // namespace paddle_mobile
......@@ -18,4 +18,4 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {}
} // namespace paddle_mobile
} // namespace paddle_mobile
......@@ -27,80 +27,84 @@ namespace framework {
class BlockDesc;
class Attribute {
public:
public:
static Attribute GetAttrValue(const proto::OpDesc::Attr &attr_desc) {
// std::cout << "begin get attr value" << std::endl;
Attribute attr;
switch (attr_desc.type()) {
case proto::AttrType::BOOLEAN: {
attr.Set<bool>(attr_desc.b());
break;
}
case proto::AttrType::INT: {
attr.Set<int>(attr_desc.i());
break;
}
case proto::AttrType::FLOAT: {
attr.Set<float>(attr_desc.f());
break;
}
case proto::AttrType::STRING: {
attr.Set<std::string>(attr_desc.s());
break;
}
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
case proto::AttrType::BOOLEAN: {
attr.Set<bool>(attr_desc.b());
break;
}
attr.Set<std::vector<bool>>(val);
break;
}
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
case proto::AttrType::INT: {
attr.Set<int>(attr_desc.i());
break;
}
attr.Set<std::vector<int>>(val);
break;
}
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
case proto::AttrType::FLOAT: {
attr.Set<float>(attr_desc.f());
break;
}
attr.Set<std::vector<float>>(val);
break;
}
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
case proto::AttrType::STRING: {
attr.Set<std::string>(attr_desc.s());
break;
}
attr.Set<std::vector<std::string>>(val);
break;
}
case proto::AttrType::LONG: {
attr.Set<int64_t>(attr_desc.l());
break;
}
default:
// std::cout << " not support " << std::endl;
break;
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
attr.Set<std::vector<bool>>(val);
break;
}
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
attr.Set<std::vector<int>>(val);
break;
}
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
attr.Set<std::vector<float>>(val);
break;
}
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
}
attr.Set<std::vector<std::string>>(val);
break;
}
case proto::AttrType::LONG: {
attr.Set<int64_t>(attr_desc.l());
break;
}
default:
// std::cout << " not support " << std::endl;
break;
}
// std::cout << "end get attr value" << std::endl;
return attr;
}
Attribute() {}
template <typename T, typename... Args> Attribute &Set(Args &&... args) {
template <typename T, typename... Args>
Attribute &Set(Args &&... args) {
variant_.Set<T>(args...);
return *this;
}
template <typename T> T &Get() const { return variant_.Get<T>(); }
template <typename T>
T &Get() const {
return variant_.Get<T>();
}
private:
private:
Variant<int, float, std::string, std::vector<int>, std::vector<float>,
std::vector<std::string>, bool, std::vector<bool>, BlockDesc *,
int64_t>
......@@ -110,10 +114,11 @@ private:
using AttributeMap = std::unordered_map<std::string, Attribute>;
class AttrReader {
public:
public:
explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {}
template <typename T> inline T Get(const std::string &name) const {
template <typename T>
inline T Get(const std::string &name) const {
// PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should
// be in
// AttributeMap",
......@@ -121,9 +126,9 @@ public:
return ((Attribute)attrs_.at(name)).Get<T>();
}
private:
private:
const AttributeMap &attrs_;
};
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -46,5 +46,5 @@ BlockDesc::BlockDesc(const proto::BlockDesc &desc) : desc_(desc) {
}
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -27,7 +27,7 @@ namespace paddle_mobile {
namespace framework {
class BlockDesc : PaddleMobileObject {
public:
public:
BlockDesc(const proto::BlockDesc &desc);
const int &ID() const { return desc_.idx(); }
......@@ -45,18 +45,19 @@ public:
std::vector<std::shared_ptr<VarDesc>> Vars() const;
std::vector<std::shared_ptr<OpDesc>> Ops() const;
private:
private:
proto::BlockDesc desc_;
std::vector<std::shared_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_;
};
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
namespace std {
template <> struct hash<paddle_mobile::framework::BlockDesc> {
template <>
struct hash<paddle_mobile::framework::BlockDesc> {
typedef paddle_mobile::framework::BlockDesc argument_type;
typedef std::size_t result_type;
result_type operator()(argument_type const &s) const noexcept {
......@@ -66,4 +67,4 @@ template <> struct hash<paddle_mobile::framework::BlockDesc> {
}
};
} // namespace std
} // namespace std
......@@ -46,15 +46,15 @@ inline DataLayout StringToDataLayout(const std::string &str) {
inline std::string DataLayoutToString(const DataLayout &data_layout) {
switch (data_layout) {
case DataLayout::kNHWC:
return "NHWC";
case DataLayout::kNCHW:
return "NCHW";
case DataLayout::kAnyLayout:
return "ANY_LAYOUT";
default:
break;
// std::cout << "unknown DataLayou %d", data_layout;
case DataLayout::kNHWC:
return "NHWC";
case DataLayout::kNCHW:
return "NCHW";
case DataLayout::kAnyLayout:
return "ANY_LAYOUT";
default:
break;
// std::cout << "unknown DataLayou %d", data_layout;
}
}
......@@ -63,5 +63,5 @@ inline std::ostream &operator<<(std::ostream &out, const DataLayout &l) {
return out;
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -88,5 +88,5 @@ void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
// }
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -37,5 +37,5 @@ void DataTransform(const OpKernelType &expected_kernel_type,
void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
Variable &out_var);
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -40,4 +40,4 @@ namespace framework {
// }
// }
}
} // namespace paddle_mobile
} // namespace paddle_mobile
......@@ -19,49 +19,53 @@ namespace framework {
/// @cond HIDDEN
template <int i> Dim<i> make_dim(const int64_t *d) {
template <int i>
Dim<i> make_dim(const int64_t *d) {
return Dim<i>(*d, make_dim<i - 1>(d + 1));
}
template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); }
template <>
Dim<0> make_dim<0>(const int64_t *d) {
return Dim<0>(*d);
}
void make_ddim(DDim &ddim, const int64_t *dims, int n) {
switch (n) {
case 0:
ddim = make_dim<0>(dims);
break;
case 1:
ddim = make_dim<1>(dims);
break;
case 2:
ddim = make_dim<2>(dims);
break;
case 3:
ddim = make_dim<3>(dims);
break;
case 4:
ddim = make_dim<4>(dims);
break;
case 5:
ddim = make_dim<5>(dims);
break;
case 6:
ddim = make_dim<6>(dims);
break;
case 7:
ddim = make_dim<7>(dims);
break;
case 8:
ddim = make_dim<8>(dims);
break;
case 9:
ddim = make_dim<9>(dims);
break;
default:
// std::cout << "Dynamic dimensions must have between [1,
// 9]
// dimensions.";
break;
case 0:
ddim = make_dim<0>(dims);
break;
case 1:
ddim = make_dim<1>(dims);
break;
case 2:
ddim = make_dim<2>(dims);
break;
case 3:
ddim = make_dim<3>(dims);
break;
case 4:
ddim = make_dim<4>(dims);
break;
case 5:
ddim = make_dim<5>(dims);
break;
case 6:
ddim = make_dim<6>(dims);
break;
case 7:
ddim = make_dim<7>(dims);
break;
case 8:
ddim = make_dim<8>(dims);
break;
case 9:
ddim = make_dim<9>(dims);
break;
default:
// std::cout << "Dynamic dimensions must have between [1,
// 9]
// dimensions.";
break;
}
}
......@@ -90,24 +94,28 @@ DDim make_ddim(const std::vector<int> &dims) {
// XXX For some reason, putting this in an anonymous namespace causes
// errors
struct DynamicMutableIndexer : Vistor<int64_t &> {
public:
public:
explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
template <int D> int64_t &operator()(Dim<D> &dim) const { return dim[idx_]; }
template <int D>
int64_t &operator()(Dim<D> &dim) const {
return dim[idx_];
}
private:
private:
int idx_;
};
struct DynamicConstIndexer : public Vistor<int64_t> {
public:
public:
explicit DynamicConstIndexer(int idx) : idx_(idx) {}
template <int D> int64_t operator()(const Dim<D> &dim) const {
template <int D>
int64_t operator()(const Dim<D> &dim) const {
return dim[idx_];
}
private:
private:
int idx_;
};
......@@ -182,7 +190,8 @@ struct VectorizeVisitor : Vistor<void> {
explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}
template <typename T> void operator()(const T &t) {
template <typename T>
void operator()(const T &t) {
vector.push_back(t.head);
this->operator()(t.tail);
}
......@@ -207,7 +216,8 @@ std::vector<int> vectorize2int(const DDim &ddim) {
}
struct ProductVisitor : Vistor<int64_t> {
template <int D> int64_t operator()(const Dim<D> &dim) {
template <int D>
int64_t operator()(const Dim<D> &dim) {
return product(dim);
}
};
......@@ -233,7 +243,8 @@ struct SliceVectorizeVisitor : Vistor<void> {
// ddim slice.");
}
template <int S> void operator()(const Dim<S> &dim) {
template <int S>
void operator()(const Dim<S> &dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
......@@ -264,7 +275,10 @@ DDim slice_ddim(const DDim &ddim, int begin, int end) {
/// \cond HIDDEN
struct ArityVisitor : Vistor<int> {
template <int D> int operator()(Dim<D>) const { return D; }
template <int D>
int operator()(Dim<D>) const {
return D;
}
};
/// \endcond
......@@ -282,11 +296,12 @@ int arity(const DDim &d) {
struct OSVistor : Vistor<std::ostream &> {
OSVistor(std::ostream &os) : os_(os) {}
template <int D> std::ostream &operator()(Dim<D> dim) const {
template <int D>
std::ostream &operator()(Dim<D> dim) const {
return os_ << dim;
}
private:
private:
std::ostream &os_;
};
......@@ -326,5 +341,5 @@ DDim stride_numel(const framework::DDim &ddim) {
return framework::make_ddim(strides);
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -14,12 +14,12 @@ limitations under the License. */
#pragma once
#include "common/variant.h"
#include "dim.h"
#include <assert.h>
#include <initializer_list>
#include <stdexcept>
#include <vector>
#include "common/variant.h"
#include "dim.h"
namespace paddle_mobile {
namespace framework {
......@@ -66,11 +66,15 @@ struct DDim {
DDim() { var.Set<Dim<1>>(Dim<1>()); }
template <int D> explicit DDim(const Dim<D> &in) { var.Set<Dim<D>>(in); }
template <int D>
explicit DDim(const Dim<D> &in) {
var.Set<Dim<D>>(in);
}
/*implicit*/ DDim(std::initializer_list<int64_t> init_list);
template <int D> DDim &operator=(const Dim<D> &in) {
template <int D>
DDim &operator=(const Dim<D> &in) {
var.Set<Dim<D>>(in);
return *this;
}
......@@ -159,5 +163,5 @@ DDim flatten_to_1d(const DDim &src);
DDim stride(const DDim &ddim);
DDim stride_numel(const DDim &ddim);
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -24,7 +24,8 @@ namespace paddle_mobile {
namespace framework {
// Statically sized, statically indexed dimension
template <int i> struct Dim {
template <int i>
struct Dim {
static constexpr int dimensions = i;
template <typename... Args>
......@@ -70,7 +71,8 @@ template <int i> struct Dim {
};
// Base case specialization
template <> struct Dim<0> {
template <>
struct Dim<0> {
static constexpr int dimensions = 0;
HOSTDEVICE
......@@ -105,28 +107,37 @@ template <> struct Dim<0> {
namespace {
// Helper for accessing Dim classes
template <int i> struct DimGetter {
template <int i>
struct DimGetter {
// Return a copy if Dim is const
template <typename D> HOSTDEVICE static int64_t impl(const D &d) {
template <typename D>
HOSTDEVICE static int64_t impl(const D &d) {
return DimGetter<i - 1>::impl(d.tail);
}
// Return a reference if Dim is mutable
template <typename D> HOSTDEVICE static int64_t &impl(D &d) {
template <typename D>
HOSTDEVICE static int64_t &impl(D &d) {
return DimGetter<i - 1>::impl(d.tail);
}
};
// Eureka! We found the element!
template <> struct DimGetter<0> {
template <>
struct DimGetter<0> {
// Return a copy if Dim is const
template <typename D> HOSTDEVICE static int64_t impl(const D &d) {
template <typename D>
HOSTDEVICE static int64_t impl(const D &d) {
return d.head;
}
// Return a reference if Dim is mutable
template <typename D> HOSTDEVICE static int64_t &impl(D &d) { return d.head; }
template <typename D>
HOSTDEVICE static int64_t &impl(D &d) {
return d.head;
}
};
template <int D> HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
template <int D>
HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
if (idx < 0) {
throw std::invalid_argument("Tried to access a negative dimension");
......@@ -140,7 +151,8 @@ template <int D> HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
return indexer(dim.tail, idx - 1);
}
template <> HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
template <>
HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
#ifndef __CUDA_ARCH__
throw std::invalid_argument("Invalid index");
#else
......@@ -156,7 +168,8 @@ template <> HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
#endif
}
template <int D> HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
template <int D>
HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
if (idx < 0) {
throw std::invalid_argument("Tried to access a negative dimension");
......@@ -170,7 +183,8 @@ template <int D> HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
return indexer(dim.tail, idx - 1);
}
template <> HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
template <>
HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
#ifndef __CUDA_ARCH__
throw std::invalid_argument("Invalid index");
#else
......@@ -186,25 +200,29 @@ template <> HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
#endif
}
} // namespace
} // namespace
// Static access to constant Dim
template <int i, int l> HOSTDEVICE int64_t get(const Dim<l> &d) {
template <int i, int l>
HOSTDEVICE int64_t get(const Dim<l> &d) {
return DimGetter<i>::impl(d);
}
// Static access to mutable Dim
template <int i, int l> HOSTDEVICE int64_t &get(Dim<l> &d) {
template <int i, int l>
HOSTDEVICE int64_t &get(Dim<l> &d) {
return DimGetter<i>::impl(d);
}
// Dynamic access to constant Dim
template <int l> HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
template <int l>
HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
// std::cout << "l: " << l << std::endl;
return indexer(*this, i);
}
// Dynamic access to mutable Dim
template <int l> HOSTDEVICE int64_t &Dim<l>::operator[](int i) {
template <int l>
HOSTDEVICE int64_t &Dim<l>::operator[](int i) {
return indexer(*this, i);
}
......@@ -247,13 +265,15 @@ HOSTDEVICE inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
}
// Product of a Dim
template <int i> HOSTDEVICE int64_t product(const Dim<i> &a, int prod = 1) {
template <int i>
HOSTDEVICE int64_t product(const Dim<i> &a, int prod = 1) {
return prod * a.head * product(a.tail);
}
// Base case product of a Dim
// Notice it is inline because it is no longer a template
template <> HOSTDEVICE inline int64_t product(const Dim<0> &a, int prod) {
template <>
HOSTDEVICE inline int64_t product(const Dim<0> &a, int prod) {
return prod;
}
......@@ -282,7 +302,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
template <> HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
template <>
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
return Dim<0>();
}
///\endcond
......@@ -290,7 +311,8 @@ template <> HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
/**
* Add two dimensions together
*/
template <int i> HOSTDEVICE Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
template <int i>
HOSTDEVICE Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
}
......@@ -308,7 +330,8 @@ HOSTDEVICE Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
/**
* Multiply two dimensions together
*/
template <int i> HOSTDEVICE Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
template <int i>
HOSTDEVICE Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
}
......@@ -365,8 +388,8 @@ HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
typename std::enable_if<(i > 1), std::ostream &>::type
operator<<(std::ostream &os, const Dim<i> &d) {
typename std::enable_if<(i > 1), std::ostream &>::type operator<<(
std::ostream &os, const Dim<i> &d) {
os << d.head << ", " << d.tail;
return os;
}
......@@ -374,8 +397,8 @@ operator<<(std::ostream &os, const Dim<i> &d) {
// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
typename std::enable_if<(i == 1), std::ostream &>::type
operator<<(std::ostream &os, const Dim<i> &d) {
typename std::enable_if<(i == 1), std::ostream &>::type operator<<(
std::ostream &os, const Dim<i> &d) {
os << d.head;
return os;
}
......@@ -384,7 +407,8 @@ inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
return os;
}
template <int i> HOST std::string Dim<i>::to_string() const {
template <int i>
HOST std::string Dim<i>::to_string() const {
std::stringstream stream;
stream << *this;
......@@ -406,5 +430,5 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
return result;
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -90,5 +90,5 @@ void Executor<Dtype>::predict(const Tensor &t, int block_id) {
template class Executor<CPU>;
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -34,15 +34,16 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {
template <typename Dtype> class Executor {
public:
template <typename Dtype>
class Executor {
public:
Executor();
Executor(const Program<Dtype> p);
std::shared_ptr<Tensor> predict(Tensor &t);
public:
public:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
......@@ -54,5 +55,5 @@ public:
bool use_optimize_ = false;
};
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -14,12 +14,12 @@ limitations under the License. */
#pragma once
#include "tensor.h"
#include "tensor_util.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tensor.h"
#include "tensor_util.h"
namespace paddle_mobile {
......@@ -102,7 +102,7 @@ bool CheckAbsLoD(const LoD &in, int tensor_height = -1);
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
class LoDTensor : public Tensor {
public:
public:
LoDTensor() : Tensor() {}
explicit LoDTensor(const LoD &lod) : lod_(lod) {}
......@@ -139,7 +139,7 @@ public:
return (lod_)[level].size() - 1;
}
private:
private:
LoD lod_;
};
......@@ -189,9 +189,8 @@ LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, size_t level) {
// Returns:
// LoD = [[1, 4], [2, 4, 2, 3, 2]]
// pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>>
GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, size_t end_idx,
size_t start_level);
std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset(
const LoD &lod, size_t start_idx, size_t end_idx, size_t start_level);
void AppendLoD(LoD *lod, const LoD &lod_length);
......@@ -204,5 +203,5 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor);
void DeserializeFromStream(std::istream &is, LoDTensor *tensor);
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -55,5 +55,5 @@ const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
return attrs_;
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -26,7 +26,7 @@ namespace paddle_mobile {
namespace framework {
class OpDesc : PaddleMobileObject {
public:
public:
OpDesc(const proto::OpDesc &desc);
const std::vector<std::string> &Input(const std::string &name) const;
const std::vector<std::string> &Output(const std::string &name) const;
......@@ -40,12 +40,12 @@ public:
const std::string &Type() { return desc_.type(); };
private:
private:
proto::OpDesc desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
};
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
此差异已折叠。
......@@ -60,5 +60,5 @@ inline bool TransFromNeeded(const OpKernelType &l, const OpKernelType &r) {
NeedTransformLayout(l.data_layout_, r.data_layout_);
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -22,5 +22,5 @@ namespace paddle_mobile {
namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {};
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
......@@ -92,5 +92,5 @@ Print &operator<<(Print &printer, const Node &node) {
return printer;
}
} // namespace framework
} // namespace paddle_mobile
} // namespace framework
} // namespace paddle_mobile
......@@ -18,4 +18,4 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {}
} // namespace paddle_mobile
} // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册