提交 b8117214 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #209 from codeWorm2015/develop

fix #208 program optimize generate graph
...@@ -51,8 +51,8 @@ paddle-mobile.cbp ...@@ -51,8 +51,8 @@ paddle-mobile.cbp
.idea .idea
compile_commands.json
cmake-build-debug/ cmake-build-debug/
test/models/ test/models/
\ No newline at end of file
...@@ -34,15 +34,19 @@ SOFTWARE. ...@@ -34,15 +34,19 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
static std::unordered_map<std::string, std::vector<std::string>> static std::unordered_map<
op_input_output_key = { std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
{"conv2d", {"Input", "Output"}}, {"relu", {"X", "Out"}}, op_input_output_key = {{"conv2d", {{"Input"}, {"Output"}}},
{"softmax", {"X", "Out"}}, {"mul", {"X", "Out"}}, {"relu", {{"X"}, {"Out"}}},
{"elementwise_add", {"X", "Out"}}, {"pool2d", {"X", "Out"}}, {"softmax", {{"X"}, {"Out"}}},
{"batch_norm", {"X", "Y"}}, {"lrn", {"X", "Out"}}, {"mul", {{"X"}, {"Out"}}},
{"concat", {"X", "Out"}}, {"elementwise_add", {{"X", "Y"}, {"Out"}}},
{"pool2d", {{"X"}, {"Out"}}},
}; {"batch_norm", {{"X"}, {"Y"}}},
{"lrn", {{"X"}, {"Out"}}},
{"concat", {{"X"}, {"Out"}}},
{"feed", {{"X"}, {"Out"}}},
{"fetch", {{"X"}, {"Out"}}}};
template <typename Dtype> class OperatorBase : PaddleMobileObject { template <typename Dtype> class OperatorBase : PaddleMobileObject {
public: public:
......
...@@ -21,10 +21,13 @@ SOFTWARE. ...@@ -21,10 +21,13 @@ SOFTWARE.
#include "node.h" #include "node.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
Node &Node::operator>(const Node &out) {
std::shared_ptr<Node> node = std::make_shared<Node>(Node(out)); Node &Node::operator>(std::shared_ptr<Node> node) {
outputs_.push_back(node); outputs_.push_back(node);
std::shared_ptr<Node> this_node;
node->inputs_.push_back(this);
return *node; return *node;
} }
...@@ -45,20 +48,49 @@ bool Node::operator==(const Node &in) { ...@@ -45,20 +48,49 @@ bool Node::operator==(const Node &in) {
return true; return true;
} }
std::string Node::ToString(std::string blank) const { std::string Node::ToString(std::string blank, const Node *node) const {
std::stringstream ss; std::stringstream ss;
ss << type_ << ": \n"; ss << type_ << "-> \n";
if (inputs_.size() > 1 && node != inputs_.back()) {
return ss.str();
} else if (inputs_.size() > 1 && node == inputs_.back()) {
ss << "\n" << blank << type_ << "\n";
}
for (int i = 0; i < outputs_.size(); ++i) { for (int i = 0; i < outputs_.size(); ++i) {
ss << blank << outputs_[i]->ToString(blank + " ") << ""; ss << blank << outputs_[i]->ToString(blank + " ", this) << "";
} }
return ss.str(); return ss.str();
} }
std::string Node::ToString() const { return this->ToString(" "); } std::string Node::ToString() const { return this->ToString(" ", this); }
Node &Node::To(int index) {
if (index == 0) {
this->outputs_.clear();
}
for (int j = 0; j < this->outputs_.size(); ++j) {
outputs_[j]->To(index - 1);
}
return *this;
}
uint Node::depth(uint begin) {
uint depth = 0;
begin++;
for (int i = 0; i < outputs_.size(); ++i) {
uint output_depth = outputs_[i]->depth(begin);
depth = output_depth > depth ? output_depth : depth;
}
return begin > depth ? begin : depth;
}
Print &operator<<(Print &printer, const Node &node) { Print &operator<<(Print &printer, const Node &node) {
printer << node.ToString(); printer << node.ToString();
return printer; return printer;
} }
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,6 +22,7 @@ SOFTWARE. ...@@ -22,6 +22,7 @@ SOFTWARE.
#include <vector> #include <vector>
#include "common/log.h" #include "common/log.h"
#include "framework/op_desc.h"
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
namespace paddle_mobile { namespace paddle_mobile {
...@@ -30,14 +31,19 @@ namespace framework { ...@@ -30,14 +31,19 @@ namespace framework {
class Node : PaddleMobileObject { class Node : PaddleMobileObject {
public: public:
Node(const std::string &type) : type_(type) {} Node(const std::string &type) : type_(type) {}
Node(std::shared_ptr<OpDesc> op_desc)
Node &operator>(const Node &out); : op_desc_(op_desc), type_(op_desc->Type()){};
Node &operator>(std::shared_ptr<Node> node);
bool operator==(const Node &in); bool operator==(const Node &in);
std::string ToString() const; std::string ToString() const;
Node &To(int index);
uint depth(uint begin = 0);
private: private:
std::string ToString(std::string blank) const; std::shared_ptr<OpDesc> op_desc_;
std::string ToString(std::string blank, const Node *node) const;
std::vector<std::shared_ptr<Node>> outputs_; std::vector<std::shared_ptr<Node>> outputs_;
std::vector<Node *> inputs_;
std::string type_; std::string type_;
}; };
......
...@@ -19,7 +19,56 @@ SOFTWARE. ...@@ -19,7 +19,56 @@ SOFTWARE.
#include "program_optimize.h" #include "program_optimize.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {} std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
std::shared_ptr<ProgramDesc>
ProgramOptimize::FushionOptimize(std::shared_ptr<ProgramDesc> ori_des) {
for (int i = 0; i < ori_des->Blocks().size(); ++i) {
std::unordered_map<std::string, std::shared_ptr<Node>> output_nodes;
std::shared_ptr<Node> begin_node;
auto block = ori_des->Block(i);
// DLOG << " ops size: " << block->Ops().size();
for (int j = 0; j < block->Ops().size(); ++j) {
auto op = block->Ops()[j];
auto op_type = op->Type();
// DLOG << "op type: " << op_type << " index: " << j;
if (op_input_output_key.find(op->Type()) ==
op_input_output_key.end()) {
return NULL;
}
std::shared_ptr<Node> node = std::make_shared<Node>(op);
if (j == 0) {
begin_node = node;
}
auto input_keys = op_input_output_key.at(op->Type()).first;
for (auto input_key : input_keys) {
auto op_inputs = op->Input(input_key);
for (int l = 0; l < op_inputs.size(); ++l) {
std::string input_key = op_inputs[l];
if (output_nodes.find(input_key) != output_nodes.end()) {
auto input_node = output_nodes[input_key];
*input_node > node;
}
}
}
auto output_keys = op_input_output_key.at(op_type).second;
for (auto output_key : output_keys) {
auto op_outputs = op->Output(output_key);
for (int k = 0; k < op_outputs.size(); ++k) {
output_nodes[op_outputs[k]] = node;
}
}
}
DLOG << "node: \n" << *begin_node;
}
return ori_des;
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -18,18 +18,24 @@ SOFTWARE. ...@@ -18,18 +18,24 @@ SOFTWARE.
#pragma once #pragma once
#include "framework/operator.h"
#include "framework/program_desc.h" #include "framework/program_desc.h"
#include "node.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class ProgramOptimize { class ProgramOptimize {
public: public:
ProgramOptimize(std::shared_ptr<ProgramDesc> ori_desc) ProgramOptimize() {}
: ori_desc_(ori_desc) {}
std::shared_ptr<ProgramDesc> Optimize(); std::shared_ptr<ProgramDesc> Optimize();
std::shared_ptr<ProgramDesc>
FushionOptimize(std::shared_ptr<ProgramDesc> ori_des);
private: private:
std::shared_ptr<ProgramDesc> ori_desc_; // std::shared_ptr<ProgramDesc> ori_desc_;
std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>>
outputs_nodes_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -195,76 +195,55 @@ Loader<Dtype, P>::Load(const std::string &dirname) { ...@@ -195,76 +195,55 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
LOG(kLOG_DEBUG) << "block: " << block.idx(); LOG(kLOG_DEBUG) << "block: " << block.idx();
for (int j = 0; j < block.ops().size(); ++j) { for (int j = 0; j < block.ops().size(); ++j) {
framework::proto::OpDesc op = block.ops()[j]; framework::proto::OpDesc op = block.ops()[j];
LOG(kLOG_DEBUG1) << " op: " << op.type(); LOG(kLOG_DEBUG1) << "op: " << op.type();
for (int m = 0; m < op.inputs_size(); ++m) { for (int m = 0; m < op.inputs_size(); ++m) {
const framework::proto::OpDesc::Var &var = op.inputs(m); const framework::proto::OpDesc::Var &var = op.inputs(m);
LOG(kLOG_DEBUG2) << " input parameter: " << var.parameter(); LOG(kLOG_DEBUG2) << "input parameter: " << var.parameter();
for (int n = 0; n < var.arguments().size(); ++n) { for (int n = 0; n < var.arguments().size(); ++n) {
LOG(kLOG_DEBUG3) << " argument - " << var.arguments()[n]; LOG(kLOG_DEBUG3) << "argument - " << var.arguments()[n];
} }
} }
for (int y = 0; y < op.outputs_size(); ++y) { for (int y = 0; y < op.outputs_size(); ++y) {
const framework::proto::OpDesc::Var &var = op.outputs(y); const framework::proto::OpDesc::Var &var = op.outputs(y);
LOG(kLOG_DEBUG2) << " out parameter: " << var.parameter(); LOG(kLOG_DEBUG2) << "out parameter: " << var.parameter();
for (int z = 0; z < var.arguments().size(); ++z) { for (int z = 0; z < var.arguments().size(); ++z) {
LOG(kLOG_DEBUG3) << " argument - " << var.arguments()[z]; LOG(kLOG_DEBUG3) << "argument - " << var.arguments()[z];
} }
} }
for (int x = 0; x < op.attrs().size(); ++x) { for (int x = 0; x < op.attrs().size(); ++x) {
const framework::proto::OpDesc_Attr attr = op.attrs()[x]; const framework::proto::OpDesc_Attr attr = op.attrs()[x];
// std::cout << " attr name: " << attr.name() << LOG(kLOG_DEBUG2) << "attr name: " << attr.name();
// std::endl;
// std::cout << " attr type: " << attr.type() <<
// std::endl;
switch (attr.type()) { switch (attr.type()) {
case framework::proto::AttrType::BOOLEAN: case framework::proto::AttrType::BOOLEAN:
// std::cout << " boolen: " << attr.b() << LOG(kLOG_DEBUG3) << "boolen: " << attr.b();
// std::endl;
break; break;
case framework::proto::AttrType::INT: case framework::proto::AttrType::INT:
// std::cout << " int: " << attr.i() << LOG(kLOG_DEBUG3) << "int: " << attr.i();
// std::endl;
break; break;
case framework::proto::AttrType::FLOAT: case framework::proto::AttrType::FLOAT:
// std::cout << " float: " << attr.f() << LOG(kLOG_DEBUG3) << "float: " << attr.f();
// std::endl;
case framework::proto::AttrType::STRING: case framework::proto::AttrType::STRING:
// std::cout << " string: " << attr.s() << LOG(kLOG_DEBUG3) << "string: " << attr.s();
// std::endl;
case framework::proto::AttrType::BOOLEANS: case framework::proto::AttrType::BOOLEANS:
// std::vector<bool>
// bools(attr.bools_size());
for (int y = 0; y < attr.bools_size(); ++y) { for (int y = 0; y < attr.bools_size(); ++y) {
// std::cout << " bool - " << LOG(kLOG_DEBUG3) << "bools: " << attr.bools(y);
// attr.bools(y) <<
// std::endl;
} }
case framework::proto::AttrType::LONG: case framework::proto::AttrType::LONG:
// std::cout << " long: " << attr.l() << LOG(kLOG_DEBUG3) << "long: " << attr.l();
// std::endl;
case framework::proto::AttrType::FLOATS: case framework::proto::AttrType::FLOATS:
for (int y = 0; y < attr.floats_size(); ++y) { for (int y = 0; y < attr.floats_size(); ++y) {
// std::cout << " float - " << y << LOG(kLOG_DEBUG3) << "floats: " << attr.floats(y);
// ": " <<
// attr.floats(y)
// << std::endl;
} }
case framework::proto::AttrType::INTS: case framework::proto::AttrType::INTS:
for (int y = 0; y < attr.ints_size(); ++y) { for (int y = 0; y < attr.ints_size(); ++y) {
// std::cout << " int - " << y << ": LOG(kLOG_DEBUG3) << "ints: " << attr.ints(y);
// " <<
// attr.ints(y)
// << std::endl;
} }
case framework::proto::AttrType::STRINGS: case framework::proto::AttrType::STRINGS:
for (int y = 0; y < attr.strings_size(); ++y) { for (int y = 0; y < attr.strings_size(); ++y) {
// std::cout << " string - " << y << LOG(kLOG_DEBUG3) << "strings: " << attr.strings(y);
// ": " <<
// attr.strings(y)
// << std::endl;
} }
} }
} }
...@@ -273,19 +252,15 @@ Loader<Dtype, P>::Load(const std::string &dirname) { ...@@ -273,19 +252,15 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
for (int k = 0; k < block.vars().size(); ++k) { for (int k = 0; k < block.vars().size(); ++k) {
framework::proto::VarDesc var = block.vars()[k]; framework::proto::VarDesc var = block.vars()[k];
if (var.type().type() == framework::proto::VarType::LOD_TENSOR) { if (var.type().type() == framework::proto::VarType::LOD_TENSOR) {
// std::cout << " var name: " << var.name() << LOG(kLOG_DEBUG1) << "var name: " << var.name();
// std::endl;
const framework::proto::VarType::TensorDesc &tensor_desc = const framework::proto::VarType::TensorDesc &tensor_desc =
var.type().lod_tensor().tensor(); var.type().lod_tensor().tensor();
// std::cout << " in var tensor desc dims size " LOG(kLOG_DEBUG2) << "in var tensor desc dims size: "
// << tensor_desc.dims().size() << << tensor_desc.dims().size();
// std::endl;
int memory_size = 1; int memory_size = 1;
for (int l = 0; l < tensor_desc.dims().size(); ++l) { for (int l = 0; l < tensor_desc.dims().size(); ++l) {
// std::cout << " var tensor desc dim " << l LOG(kLOG_DEBUG3) << "var tensor desc dim " << l
// << " value: " << << " value: " << tensor_desc.dims()[l];
// tensor_desc.dims()[l] <<
// std::endl;
} }
} }
......
...@@ -31,7 +31,6 @@ target_link_libraries(test-log paddle-mobile) ...@@ -31,7 +31,6 @@ target_link_libraries(test-log paddle-mobile)
ADD_EXECUTABLE(test-load framework/test_load.cpp) ADD_EXECUTABLE(test-load framework/test_load.cpp)
target_link_libraries(test-load paddle-mobile) target_link_libraries(test-load paddle-mobile)
# gen test log # gen test log
ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp) ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp)
target_link_libraries(test-optimize paddle-mobile) target_link_libraries(test-optimize paddle-mobile)
\ No newline at end of file
...@@ -23,6 +23,6 @@ int main() { ...@@ -23,6 +23,6 @@ int main() {
//../../../test/models/googlenet //../../../test/models/googlenet
//../../../test/models/mobilenet //../../../test/models/mobilenet
auto program = loader.Load(std::string("../../../test/models/googlenet")); auto program = loader.Load(std::string("../models/googlenet"));
return 0; return 0;
} }
\ No newline at end of file
...@@ -17,21 +17,24 @@ SOFTWARE. ...@@ -17,21 +17,24 @@ SOFTWARE.
==============================================================================*/ ==============================================================================*/
#include "framework/program-optimize/node.h" #include "framework/program-optimize/node.h"
#include <iostream> #include "framework/program-optimize/program_optimize.h"
#include "io.h"
using namespace paddle_mobile;
using namespace paddle_mobile::framework; using namespace paddle_mobile::framework;
int main() { int main() {
Node node("conv");
node > Node("add") > Node("relu");
Node node1("conv"); Loader<paddle_mobile::CPU> loader;
node1 > Node("add") > Node("relu"); // "../../../test/models/googlenet"
auto program = loader.Load("../models/googlenet");
if (node == node1) { ProgramOptimize optimize;
DLOG << "equal";
} auto optimize_program = optimize.FushionOptimize(program.originProgram);
if (optimize_program) {
DLOG << "\n" << node1; } else {
// DLOG << node; DLOG << "optimize_program is null";
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册