提交 f4126618 编写于 作者: L liuruilong

fix confict

上级 45de1e84
......@@ -21,9 +21,10 @@ SOFTWARE.
#include "node.h"
namespace paddle_mobile {
namespace framework {
Node &Node::operator>(const Node &out) {
std::shared_ptr<Node> node = std::make_shared<Node>(Node(out));
Node &Node::operator>(std::shared_ptr<Node> node) {
outputs_.push_back(node);
return *node;
}
......@@ -47,7 +48,7 @@ bool Node::operator==(const Node &in) {
std::string Node::ToString(std::string blank) const {
std::stringstream ss;
ss << type_ << ": \n";
ss << type_ << "-> \n";
for (int i = 0; i < outputs_.size(); ++i) {
ss << blank << outputs_[i]->ToString(blank + " ") << "";
}
......@@ -56,9 +57,31 @@ std::string Node::ToString(std::string blank) const {
std::string Node::ToString() const { return this->ToString(" "); }
Node &Node::To(int index) {
if (index == 0) {
this->outputs_.clear();
}
for (int j = 0; j < this->outputs_.size(); ++j) {
outputs_[j]->To(index - 1);
}
return *this;
}
uint Node::depth(uint begin) {
uint depth = 0;
begin++;
for (int i = 0; i < outputs_.size(); ++i) {
uint output_depth = outputs_[i]->depth(begin);
depth = output_depth > depth ? output_depth : depth;
}
return begin > depth ? begin : depth;
}
Print &operator<<(Print &printer, const Node &node) {
printer << node.ToString();
return printer;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -22,6 +22,7 @@ SOFTWARE.
#include <vector>
#include "common/log.h"
#include "framework/op_desc.h"
#include "framework/paddle_mobile_object.h"
namespace paddle_mobile {
......@@ -30,12 +31,16 @@ namespace framework {
class Node : PaddleMobileObject {
public:
Node(const std::string &type) : type_(type) {}
Node &operator>(const Node &out);
Node(std::shared_ptr<OpDesc> op_desc)
: op_desc_(op_desc), type_(op_desc->Type()){};
Node &operator>(std::shared_ptr<Node> node);
bool operator==(const Node &in);
std::string ToString() const;
Node &To(int index);
uint depth(uint begin = 0);
private:
std::shared_ptr<OpDesc> op_desc_;
std::string ToString(std::string blank) const;
std::vector<std::shared_ptr<Node>> outputs_;
std::string type_;
......
......@@ -19,7 +19,31 @@ SOFTWARE.
#include "program_optimize.h"
namespace paddle_mobile {
namespace framework {
std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
} // namespace framework
} // namespace paddle_mobile
std::shared_ptr<ProgramDesc>
ProgramOptimize::FushionOptimize(std::shared_ptr<ProgramDesc> ori_des) {
for (int i = 0; i < ori_des->Blocks().size(); ++i) {
std::unordered_map<std::string, std::shared_ptr<Node>> output_nodes;
auto block = ori_des->Block(i);
for (int j = 0; j < block->Ops().size(); ++j) {
auto op = block->Ops()[j];
std::shared_ptr<Node> node = std::make_shared<Node>(op);
auto op_outputs = op->Output(op_input_output_key.at(op->Type())[1]);
for (int k = 0; k < op_outputs.size(); ++k) {
output_nodes[op_outputs[k]] = node;
}
auto op_iutputs = op->Output(op_input_output_key.at(op->Type())[0]);
for (int l = 0; l < op_iutputs.size(); ++l) {
auto input_node = output_nodes[op_iutputs[l]];
*input_node > node;
}
}
DLOG << output_nodes["feed"];
}
}
}
}
......@@ -18,18 +18,24 @@ SOFTWARE.
#pragma once
#include "framework/operator.h"
#include "framework/program_desc.h"
#include "node.h"
namespace paddle_mobile {
namespace framework {
class ProgramOptimize {
public:
ProgramOptimize(std::shared_ptr<ProgramDesc> ori_desc)
: ori_desc_(ori_desc) {}
ProgramOptimize() {}
std::shared_ptr<ProgramDesc> Optimize();
std::shared_ptr<ProgramDesc>
FushionOptimize(std::shared_ptr<ProgramDesc> ori_des);
private:
std::shared_ptr<ProgramDesc> ori_desc_;
// std::shared_ptr<ProgramDesc> ori_desc_;
std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>>
outputs_nodes_;
};
} // namespace framework
}
} // namespace paddle_mobile
......@@ -17,21 +17,33 @@ SOFTWARE.
==============================================================================*/
#include "framework/program-optimize/node.h"
#include <iostream>
#include "framework/program-optimize/program_optimize.h"
#include "io.h"
using namespace paddle_mobile;
using namespace paddle_mobile::framework;
int main() {
Loader<paddle_mobile::CPU> loader;
// "../../../test/models/googlenet"
auto program = loader.Load("../../../test/models/googlenet");
ProgramOptimize optimize;
optimize.FushionOptimize(program.originProgram);
Node node("conv");
node > Node("add") > Node("relu");
Node node1("conv");
node1 > Node("add") > Node("relu");
node > std::make_shared<Node>("add") > std::make_shared<Node>("relu") >
std::make_shared<Node>("lrn");
node > std::make_shared<Node>("batch normal");
DLOG << "depath of node " << node.depth();
if (node == node1) {
DLOG << "equal";
}
// Node node1("conv");
// node1 > Node("add") > Node("relu");
DLOG << "\n" << node1;
// DLOG << node;
Node node2 = node.To(4);
DLOG << "\n" << node2;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册