提交 cf963d96 编写于 作者: L liuruilong

format files

上级 6018d63c
...@@ -14,6 +14,4 @@ limitations under the License. */ ...@@ -14,6 +14,4 @@ limitations under the License. */
#include "log.h" #include "log.h"
namespace paddle_mobile { namespace paddle_mobile {}
}
...@@ -11,4 +11,3 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,4 +11,3 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
...@@ -17,32 +17,30 @@ limitations under the License. */ ...@@ -17,32 +17,30 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
/* /*
* Variant<int, float, std::string, std::vector<int>, std::vector<float>, * Variant<int, float, std::string, std::vector<int>, std::vector<float>,
std::vector<std::string>, bool, std::vector<bool>, BlockDesc *, std::vector<std::string>, bool, std::vector<bool>, BlockDesc *,
int64_t> int64_t>
* */ * */
struct PrintVistor: Vistor<Print &>{ struct PrintVistor : Vistor<Print &> {
PrintVistor(Print &printer):printer_(printer){ PrintVistor(Print &printer) : printer_(printer) {}
}
template <typename T> template <typename T>
Print &operator()(const T &value){ Print &operator()(const T &value) {
printer_ << value; printer_ << value;
return printer_; return printer_;
} }
private: private:
Print &printer_; Print &printer_;
}; };
Print &operator<<(Print &printer, const Attribute &attr) { Print &operator<<(Print &printer, const Attribute &attr) {
Attribute::ApplyVistor(PrintVistor(printer), attr); Attribute::ApplyVistor(PrintVistor(printer), attr);
// std::vector<std::string> v = {"1", "2"}; // std::vector<std::string> v = {"1", "2"};
// printer << (v); // printer << (v);
return printer; return printer;
} }
} } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -111,13 +111,16 @@ class Attribute { ...@@ -111,13 +111,16 @@ class Attribute {
return vistor(attr.variant_.Get<std::string>()); return vistor(attr.variant_.Get<std::string>());
} else if (attr.variant_.TypeId() == typeid(std::vector<int>).hash_code()) { } else if (attr.variant_.TypeId() == typeid(std::vector<int>).hash_code()) {
return vistor(attr.variant_.Get<std::vector<int>>()); return vistor(attr.variant_.Get<std::vector<int>>());
} else if (attr.variant_.TypeId() == typeid(std::vector<float>).hash_code()) { } else if (attr.variant_.TypeId() ==
typeid(std::vector<float>).hash_code()) {
return vistor(attr.variant_.Get<std::vector<float>>()); return vistor(attr.variant_.Get<std::vector<float>>());
} else if (attr.variant_.TypeId() == typeid(std::vector<std::string>).hash_code()) { } else if (attr.variant_.TypeId() ==
typeid(std::vector<std::string>).hash_code()) {
return vistor(attr.variant_.Get<std::vector<std::string>>()); return vistor(attr.variant_.Get<std::vector<std::string>>());
} else if (attr.variant_.TypeId() == typeid(bool).hash_code()) { } else if (attr.variant_.TypeId() == typeid(bool).hash_code()) {
return vistor(attr.variant_.Get<bool>()); return vistor(attr.variant_.Get<bool>());
} else if (attr.variant_.TypeId() == typeid(std::vector<bool>).hash_code()) { } else if (attr.variant_.TypeId() ==
typeid(std::vector<bool>).hash_code()) {
return vistor(attr.variant_.Get<std::vector<bool>>()); return vistor(attr.variant_.Get<std::vector<bool>>());
} else if (attr.variant_.TypeId() == typeid(int64_t).hash_code()) { } else if (attr.variant_.TypeId() == typeid(int64_t).hash_code()) {
return vistor(attr.variant_.Get<int64_t>()); return vistor(attr.variant_.Get<int64_t>());
...@@ -152,9 +155,6 @@ class AttrReader { ...@@ -152,9 +155,6 @@ class AttrReader {
const AttributeMap &attrs_; const AttributeMap &attrs_;
}; };
Print &operator<<(Print &printer, const Attribute &op_desc); Print &operator<<(Print &printer, const Attribute &op_desc);
} // namespace framework } // namespace framework
......
...@@ -18,11 +18,11 @@ limitations under the License. */ ...@@ -18,11 +18,11 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "framework/program/block_desc.h"
#include "framework.pb.h" #include "framework.pb.h"
#include "operator.h" #include "framework/program/block_desc.h"
#include "framework/program/program.h" #include "framework/program/program.h"
#include "framework/program/program_desc.h" #include "framework/program/program_desc.h"
#include "operator.h"
#include "scope.h" #include "scope.h"
#include "tensor.h" #include "tensor.h"
#include "variable.h" #include "variable.h"
......
...@@ -22,15 +22,15 @@ limitations under the License. */ ...@@ -22,15 +22,15 @@ limitations under the License. */
#include "common/type_define.h" #include "common/type_define.h"
#include "common/types.h" #include "common/types.h"
#include "common/variant.h" #include "common/variant.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/op_info.h"
#include "framework/variable.h"
#include "framework/attribute.h" #include "framework/attribute.h"
#include "framework/op_info.h"
#include "framework/op_kernel_type.h" #include "framework/op_kernel_type.h"
#include "framework/program/block_desc.h"
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
#include "framework/program/block_desc.h"
#include "framework/program/program-optimize/node.h" #include "framework/program/program-optimize/node.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
...@@ -104,7 +104,7 @@ class OpKernelBase : PaddleMobileObject { ...@@ -104,7 +104,7 @@ class OpKernelBase : PaddleMobileObject {
std::shared_ptr<::paddle_mobile::framework::Scope> scope) \ std::shared_ptr<::paddle_mobile::framework::Scope> scope) \
: parent_cls<Dtype, T>(type, inputs, outputs, attrs, scope) {} : parent_cls<Dtype, T>(type, inputs, outputs, attrs, scope) {}
class FusionOpMatcher: PaddleMobileObject{ class FusionOpMatcher : PaddleMobileObject {
public: public:
FusionOpMatcher() {} FusionOpMatcher() {}
...@@ -112,16 +112,11 @@ class FusionOpMatcher: PaddleMobileObject{ ...@@ -112,16 +112,11 @@ class FusionOpMatcher: PaddleMobileObject{
virtual void FolderNodes(Node &node) { virtual void FolderNodes(Node &node) {
node.Folder(node_.Depth(), Type(), {}); node.Folder(node_.Depth(), Type(), {});
} }
virtual Node &BeginNode() { virtual Node &BeginNode() { return node_; }
return node_;
}
std::string BeginType() { std::string BeginType() { return node_.BeginType(); }
return node_.BeginType();
}
protected: protected:
Node node_; Node node_;
......
...@@ -33,8 +33,8 @@ std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const { ...@@ -33,8 +33,8 @@ std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const {
return res; return res;
} }
BlockDesc::BlockDesc(const proto::BlockDesc &desc): BlockDesc::BlockDesc(const proto::BlockDesc &desc)
index_(desc.idx()), parent_index_(desc.parent_idx()) { : index_(desc.idx()), parent_index_(desc.parent_idx()) {
for (const proto::VarDesc &var_desc : desc.vars()) { for (const proto::VarDesc &var_desc : desc.vars()) {
vars_[var_desc.name()].reset(new VarDesc(var_desc)); vars_[var_desc.name()].reset(new VarDesc(var_desc));
} }
......
...@@ -15,10 +15,9 @@ limitations under the License. */ ...@@ -15,10 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include "framework/framework.pb.h" #include "framework/framework.pb.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/op_desc.h" #include "framework/program/op_desc.h"
#include "framework/program/var_desc.h" #include "framework/program/var_desc.h"
#include "framework/paddle_mobile_object.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
...@@ -29,9 +28,8 @@ class BlockDesc : PaddleMobileObject { ...@@ -29,9 +28,8 @@ class BlockDesc : PaddleMobileObject {
friend class ProgramOptimize; friend class ProgramOptimize;
BlockDesc(const proto::BlockDesc &desc); BlockDesc(const proto::BlockDesc &desc);
BlockDesc(const BlockDesc &block_desc): BlockDesc(const BlockDesc &block_desc)
index_(block_desc.index_), : index_(block_desc.index_), parent_index_(block_desc.parent_index_) {
parent_index_(block_desc.parent_index_) {
for (auto &op_desc : block_desc.ops_) { for (auto &op_desc : block_desc.ops_) {
std::shared_ptr<OpDesc> copy_op_desc = std::make_shared<OpDesc>(*op_desc); std::shared_ptr<OpDesc> copy_op_desc = std::make_shared<OpDesc>(*op_desc);
ops_.push_back(copy_op_desc); ops_.push_back(copy_op_desc);
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
OpDesc::OpDesc(const proto::OpDesc &desc):type_(desc.type()) { OpDesc::OpDesc(const proto::OpDesc &desc) : type_(desc.type()) {
for (int i = 0; i < desc.inputs_size(); ++i) { for (int i = 0; i < desc.inputs_size(); ++i) {
const proto::OpDesc::Var &var = desc.inputs(i); const proto::OpDesc::Var &var = desc.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()]; std::vector<std::string> &args = inputs_[var.parameter()];
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include <string> #include <string>
#include <vector>
#include "common/log.h" #include "common/log.h"
#include "common/type_define.h" #include "common/type_define.h"
...@@ -32,14 +32,13 @@ class OpDesc : PaddleMobileObject { ...@@ -32,14 +32,13 @@ class OpDesc : PaddleMobileObject {
friend class Node; friend class Node;
explicit OpDesc(const proto::OpDesc &desc); explicit OpDesc(const proto::OpDesc &desc);
OpDesc(const OpDesc &op_desc): type_(op_desc.type_) { OpDesc(const OpDesc &op_desc) : type_(op_desc.type_) {
this->inputs_ = op_desc.inputs_; this->inputs_ = op_desc.inputs_;
this->outputs_ = op_desc.outputs_; this->outputs_ = op_desc.outputs_;
this->attrs_ = op_desc.attrs_; this->attrs_ = op_desc.attrs_;
} }
OpDesc() { OpDesc() {}
}
const std::vector<std::string> &Input(const std::string &name) const; const std::vector<std::string> &Input(const std::string &name) const;
const std::vector<std::string> &Output(const std::string &name) const; const std::vector<std::string> &Output(const std::string &name) const;
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
...@@ -52,17 +51,11 @@ class OpDesc : PaddleMobileObject { ...@@ -52,17 +51,11 @@ class OpDesc : PaddleMobileObject {
const std::string &Type() { return type_; } const std::string &Type() { return type_; }
void SetInputs(VariableNameMap inputs){ void SetInputs(VariableNameMap inputs) { inputs_ = inputs; }
inputs_ = inputs;
}
void SetOutputs(VariableNameMap outputs){ void SetOutputs(VariableNameMap outputs) { outputs_ = outputs; }
outputs_ = outputs;
}
void SetAttrMap(AttributeMap attrs){ void SetAttrMap(AttributeMap attrs) { attrs_ = attrs; }
attrs_ = attrs;
}
private: private:
std::string type_; std::string type_;
......
...@@ -25,8 +25,8 @@ namespace framework { ...@@ -25,8 +25,8 @@ namespace framework {
class FusionOpRegister { class FusionOpRegister {
public: public:
static FusionOpRegister *Instance() { static FusionOpRegister* Instance() {
static FusionOpRegister *regist = nullptr; static FusionOpRegister* regist = nullptr;
if (regist == nullptr) { if (regist == nullptr) {
regist = new FusionOpRegister(); regist = new FusionOpRegister();
} }
...@@ -47,9 +47,9 @@ class FusionOpRegister { ...@@ -47,9 +47,9 @@ class FusionOpRegister {
FusionOpRegister() {} FusionOpRegister() {}
}; };
class FusionOpRegistrar{ class FusionOpRegistrar {
public: public:
explicit FusionOpRegistrar(FusionOpMatcher* matcher){ explicit FusionOpRegistrar(FusionOpMatcher* matcher) {
FusionOpRegister::Instance()->regist(matcher); FusionOpRegister::Instance()->regist(matcher);
} }
}; };
......
...@@ -72,7 +72,8 @@ void Node::OpDescs(uint index, ...@@ -72,7 +72,8 @@ void Node::OpDescs(uint index,
} }
} }
void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *node) { void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *node) {
auto iter = std::find(op_desc->begin(), op_desc->end(), this->op_desc_); auto iter = std::find(op_desc->begin(), op_desc->end(), this->op_desc_);
if (inputs_.size() > 1 && node != inputs_.back()) { if (inputs_.size() > 1 && node != inputs_.back()) {
return; return;
......
...@@ -20,8 +20,8 @@ limitations under the License. */ ...@@ -20,8 +20,8 @@ limitations under the License. */
#include <vector> #include <vector>
#include "common/log.h" #include "common/log.h"
#include "framework/program/op_desc.h"
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
#include "framework/program/op_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
...@@ -42,7 +42,8 @@ class Node : PaddleMobileObject { ...@@ -42,7 +42,8 @@ class Node : PaddleMobileObject {
std::map<std::string, std::pair<std::string, std::string>> change_map); std::map<std::string, std::pair<std::string, std::string>> change_map);
std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(uint size); std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(uint size);
std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(); std::vector<std::shared_ptr<framework::OpDesc>> OpDescs();
void OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *node); void OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *node);
std::shared_ptr<framework::OpDesc> OpDesc() { return op_desc_; } std::shared_ptr<framework::OpDesc> OpDesc() { return op_desc_; }
std::string BeginType() { return type_; } std::string BeginType() { return type_; }
void Description(); void Description();
......
...@@ -23,7 +23,6 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {} ...@@ -23,7 +23,6 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize( std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
std::shared_ptr<ProgramDesc> ori_des) { std::shared_ptr<ProgramDesc> ori_des) {
ProgramDesc *optimize_program = new ProgramDesc(*ori_des); ProgramDesc *optimize_program = new ProgramDesc(*ori_des);
for (int i = 0; i < optimize_program->Blocks().size(); ++i) { for (int i = 0; i < optimize_program->Blocks().size(); ++i) {
...@@ -96,9 +95,8 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize( ...@@ -96,9 +95,8 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
} }
} }
// DLOG << "node: \n" << *begin_node; // DLOG << "node: \n" << *begin_node;
block->ops_ = begin_node->OpDescs(); block->ops_ = begin_node->OpDescs();
} }
std::shared_ptr<ProgramDesc> shared_optimzie(optimize_program); std::shared_ptr<ProgramDesc> shared_optimzie(optimize_program);
return shared_optimzie; return shared_optimzie;
......
...@@ -30,6 +30,7 @@ class ProgramOptimize { ...@@ -30,6 +30,7 @@ class ProgramOptimize {
std::shared_ptr<ProgramDesc> Optimize(); std::shared_ptr<ProgramDesc> Optimize();
std::shared_ptr<ProgramDesc> FushionOptimize( std::shared_ptr<ProgramDesc> FushionOptimize(
std::shared_ptr<ProgramDesc> ori_des); std::shared_ptr<ProgramDesc> ori_des);
private: private:
// std::shared_ptr<ProgramDesc> ori_desc_; // std::shared_ptr<ProgramDesc> ori_desc_;
std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>> std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>>
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <vector>
#include <string> #include <string>
#include <vector>
#include "program_desc.h" #include "program_desc.h"
...@@ -29,7 +29,7 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { ...@@ -29,7 +29,7 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
void ProgramDesc::Description(std::string header) { void ProgramDesc::Description(std::string header) {
#ifdef PADDLE_MOBILE_DEBUG #ifdef PADDLE_MOBILE_DEBUG
if (header.size()){ if (header.size()) {
LOG(kLOG_INFO) << header; LOG(kLOG_INFO) << header;
} }
for (const auto &block : this->blocks_) { for (const auto &block : this->blocks_) {
......
...@@ -17,8 +17,8 @@ limitations under the License. */ ...@@ -17,8 +17,8 @@ limitations under the License. */
#include <vector> #include <vector>
#include "common/types.h" #include "common/types.h"
#include "framework/program/block_desc.h"
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
#include "framework/program/block_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
...@@ -39,6 +39,7 @@ class ProgramDesc : PaddleMobileObject { ...@@ -39,6 +39,7 @@ class ProgramDesc : PaddleMobileObject {
} }
void Description(std::string header = ""); void Description(std::string header = "");
private: private:
std::vector<std::shared_ptr<BlockDesc>> blocks_; std::vector<std::shared_ptr<BlockDesc>> blocks_;
}; };
......
...@@ -24,7 +24,7 @@ class VarDesc { ...@@ -24,7 +24,7 @@ class VarDesc {
public: public:
VarDesc(const proto::VarDesc &desc); VarDesc(const proto::VarDesc &desc);
VarDesc(const VarDesc &var_desc):desc_(var_desc.desc_) {} VarDesc(const VarDesc &var_desc) : desc_(var_desc.desc_) {}
std::string Name() const { return desc_.name(); } std::string Name() const { return desc_.name(); }
......
...@@ -14,13 +14,13 @@ limitations under the License. */ ...@@ -14,13 +14,13 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include "io.h"
#include "common/log.h" #include "common/log.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/lod_tensor.h"
#include "framework/framework.pb.h" #include "framework/framework.pb.h"
#include "framework/lod_tensor.h"
#include "framework/program/program_desc.h" #include "framework/program/program_desc.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "io.h"
namespace paddle_mobile { namespace paddle_mobile {
......
...@@ -24,17 +24,17 @@ class FushionConvAddReluOpMatcher : public framework::FusionOpMatcher { ...@@ -24,17 +24,17 @@ class FushionConvAddReluOpMatcher : public framework::FusionOpMatcher {
public: public:
FushionConvAddReluOpMatcher() { FushionConvAddReluOpMatcher() {
node_ = framework::Node("conv2d"); node_ = framework::Node("conv2d");
node_ > std::make_shared<framework::Node>("elementwise_add") > std::make_shared<framework::Node>("relu"); node_ > std::make_shared<framework::Node>("elementwise_add") >
std::make_shared<framework::Node>("relu");
} }
void FolderNodes(framework::Node &node) { void FolderNodes(framework::Node &node) {
std::vector<std::shared_ptr<framework::OpDesc>> origin_descs = node.OpDescs(node_.Depth()); std::vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node.Folder(node_.Depth(), Type(), {{"elementwise_add" , {"Y", "Z"}}}); node.OpDescs(node_.Depth());
node.Folder(node_.Depth(), Type(), {{"elementwise_add", {"Y", "Z"}}});
} }
std::string Type() { std::string Type() { return "FusionConvAddRelu"; }
return "FusionConvAddRelu";
}
}; };
class FusionFcOp { class FusionFcOp {
...@@ -42,7 +42,8 @@ class FusionFcOp { ...@@ -42,7 +42,8 @@ class FusionFcOp {
private: private:
}; };
static framework::FusionOpRegistrar fc_registrar(new FushionConvAddReluOpMatcher()); static framework::FusionOpRegistrar fc_registrar(
new FushionConvAddReluOpMatcher());
} } // namespace operators
} } // namespace paddle_mobile
...@@ -30,13 +30,12 @@ class FusionFcMatcher : public framework::FusionOpMatcher { ...@@ -30,13 +30,12 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
} }
void FolderNodes(framework::Node &node) { void FolderNodes(framework::Node &node) {
std::vector<std::shared_ptr<framework::OpDesc>> origin_descs = node.OpDescs(node_.Depth()); std::vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node.Folder(node_.Depth(), Type(), {{"elementwise_add" , {"Y", "Z"}}}); node.OpDescs(node_.Depth());
node.Folder(node_.Depth(), Type(), {{"elementwise_add", {"Y", "Z"}}});
} }
std::string Type() { std::string Type() { return "fc"; }
return "fc";
}
}; };
class FusionFcOp { class FusionFcOp {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册