提交 f4126618 编写于 作者: L liuruilong

fix confict

上级 45de1e84
...@@ -21,9 +21,10 @@ SOFTWARE. ...@@ -21,9 +21,10 @@ SOFTWARE.
#include "node.h" #include "node.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
Node &Node::operator>(const Node &out) {
std::shared_ptr<Node> node = std::make_shared<Node>(Node(out)); Node &Node::operator>(std::shared_ptr<Node> node) {
outputs_.push_back(node); outputs_.push_back(node);
return *node; return *node;
} }
...@@ -47,7 +48,7 @@ bool Node::operator==(const Node &in) { ...@@ -47,7 +48,7 @@ bool Node::operator==(const Node &in) {
std::string Node::ToString(std::string blank) const { std::string Node::ToString(std::string blank) const {
std::stringstream ss; std::stringstream ss;
ss << type_ << ": \n"; ss << type_ << "-> \n";
for (int i = 0; i < outputs_.size(); ++i) { for (int i = 0; i < outputs_.size(); ++i) {
ss << blank << outputs_[i]->ToString(blank + " ") << ""; ss << blank << outputs_[i]->ToString(blank + " ") << "";
} }
...@@ -56,9 +57,31 @@ std::string Node::ToString(std::string blank) const { ...@@ -56,9 +57,31 @@ std::string Node::ToString(std::string blank) const {
std::string Node::ToString() const { return this->ToString(" "); } std::string Node::ToString() const { return this->ToString(" "); }
Node &Node::To(int index) {
if (index == 0) {
this->outputs_.clear();
}
for (int j = 0; j < this->outputs_.size(); ++j) {
outputs_[j]->To(index - 1);
}
return *this;
}
uint Node::depth(uint begin) {
uint depth = 0;
begin++;
for (int i = 0; i < outputs_.size(); ++i) {
uint output_depth = outputs_[i]->depth(begin);
depth = output_depth > depth ? output_depth : depth;
}
return begin > depth ? begin : depth;
}
Print &operator<<(Print &printer, const Node &node) { Print &operator<<(Print &printer, const Node &node) {
printer << node.ToString(); printer << node.ToString();
return printer; return printer;
} }
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,6 +22,7 @@ SOFTWARE. ...@@ -22,6 +22,7 @@ SOFTWARE.
#include <vector> #include <vector>
#include "common/log.h" #include "common/log.h"
#include "framework/op_desc.h"
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
namespace paddle_mobile { namespace paddle_mobile {
...@@ -30,12 +31,16 @@ namespace framework { ...@@ -30,12 +31,16 @@ namespace framework {
class Node : PaddleMobileObject { class Node : PaddleMobileObject {
public: public:
Node(const std::string &type) : type_(type) {} Node(const std::string &type) : type_(type) {}
Node(std::shared_ptr<OpDesc> op_desc)
Node &operator>(const Node &out); : op_desc_(op_desc), type_(op_desc->Type()){};
Node &operator>(std::shared_ptr<Node> node);
bool operator==(const Node &in); bool operator==(const Node &in);
std::string ToString() const; std::string ToString() const;
Node &To(int index);
uint depth(uint begin = 0);
private: private:
std::shared_ptr<OpDesc> op_desc_;
std::string ToString(std::string blank) const; std::string ToString(std::string blank) const;
std::vector<std::shared_ptr<Node>> outputs_; std::vector<std::shared_ptr<Node>> outputs_;
std::string type_; std::string type_;
......
...@@ -19,7 +19,31 @@ SOFTWARE. ...@@ -19,7 +19,31 @@ SOFTWARE.
#include "program_optimize.h" #include "program_optimize.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {} std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
} // namespace framework
} // namespace paddle_mobile std::shared_ptr<ProgramDesc>
ProgramOptimize::FushionOptimize(std::shared_ptr<ProgramDesc> ori_des) {
for (int i = 0; i < ori_des->Blocks().size(); ++i) {
std::unordered_map<std::string, std::shared_ptr<Node>> output_nodes;
auto block = ori_des->Block(i);
for (int j = 0; j < block->Ops().size(); ++j) {
auto op = block->Ops()[j];
std::shared_ptr<Node> node = std::make_shared<Node>(op);
auto op_outputs = op->Output(op_input_output_key.at(op->Type())[1]);
for (int k = 0; k < op_outputs.size(); ++k) {
output_nodes[op_outputs[k]] = node;
}
auto op_iutputs = op->Output(op_input_output_key.at(op->Type())[0]);
for (int l = 0; l < op_iutputs.size(); ++l) {
auto input_node = output_nodes[op_iutputs[l]];
*input_node > node;
}
}
DLOG << output_nodes["feed"];
}
}
}
}
...@@ -18,18 +18,24 @@ SOFTWARE. ...@@ -18,18 +18,24 @@ SOFTWARE.
#pragma once #pragma once
#include "framework/operator.h"
#include "framework/program_desc.h" #include "framework/program_desc.h"
#include "node.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class ProgramOptimize { class ProgramOptimize {
public: public:
ProgramOptimize(std::shared_ptr<ProgramDesc> ori_desc) ProgramOptimize() {}
: ori_desc_(ori_desc) {}
std::shared_ptr<ProgramDesc> Optimize(); std::shared_ptr<ProgramDesc> Optimize();
std::shared_ptr<ProgramDesc>
FushionOptimize(std::shared_ptr<ProgramDesc> ori_des);
private: private:
std::shared_ptr<ProgramDesc> ori_desc_; // std::shared_ptr<ProgramDesc> ori_desc_;
std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>>
outputs_nodes_;
}; };
} // namespace framework }
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -17,21 +17,33 @@ SOFTWARE. ...@@ -17,21 +17,33 @@ SOFTWARE.
==============================================================================*/ ==============================================================================*/
#include "framework/program-optimize/node.h" #include "framework/program-optimize/node.h"
#include <iostream> #include "framework/program-optimize/program_optimize.h"
#include "io.h"
using namespace paddle_mobile;
using namespace paddle_mobile::framework; using namespace paddle_mobile::framework;
int main() { int main() {
Loader<paddle_mobile::CPU> loader;
// "../../../test/models/googlenet"
auto program = loader.Load("../../../test/models/googlenet");
ProgramOptimize optimize;
optimize.FushionOptimize(program.originProgram);
Node node("conv"); Node node("conv");
node > Node("add") > Node("relu");
Node node1("conv"); node > std::make_shared<Node>("add") > std::make_shared<Node>("relu") >
node1 > Node("add") > Node("relu"); std::make_shared<Node>("lrn");
node > std::make_shared<Node>("batch normal");
DLOG << "depath of node " << node.depth();
if (node == node1) { // Node node1("conv");
DLOG << "equal"; // node1 > Node("add") > Node("relu");
}
DLOG << "\n" << node1; Node node2 = node.To(4);
// DLOG << node; DLOG << "\n" << node2;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册