提交 b8117214 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #209 from codeWorm2015/develop

fix #208 program optimize generate graph
......@@ -51,8 +51,8 @@ paddle-mobile.cbp
.idea
compile_commands.json
cmake-build-debug/
test/models/
\ No newline at end of file
......@@ -34,15 +34,19 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {
static std::unordered_map<std::string, std::vector<std::string>>
op_input_output_key = {
{"conv2d", {"Input", "Output"}}, {"relu", {"X", "Out"}},
{"softmax", {"X", "Out"}}, {"mul", {"X", "Out"}},
{"elementwise_add", {"X", "Out"}}, {"pool2d", {"X", "Out"}},
{"batch_norm", {"X", "Y"}}, {"lrn", {"X", "Out"}},
{"concat", {"X", "Out"}},
};
static std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key = {{"conv2d", {{"Input"}, {"Output"}}},
{"relu", {{"X"}, {"Out"}}},
{"softmax", {{"X"}, {"Out"}}},
{"mul", {{"X"}, {"Out"}}},
{"elementwise_add", {{"X", "Y"}, {"Out"}}},
{"pool2d", {{"X"}, {"Out"}}},
{"batch_norm", {{"X"}, {"Y"}}},
{"lrn", {{"X"}, {"Out"}}},
{"concat", {{"X"}, {"Out"}}},
{"feed", {{"X"}, {"Out"}}},
{"fetch", {{"X"}, {"Out"}}}};
template <typename Dtype> class OperatorBase : PaddleMobileObject {
public:
......
......@@ -21,10 +21,13 @@ SOFTWARE.
#include "node.h"
namespace paddle_mobile {
namespace framework {
Node &Node::operator>(const Node &out) {
std::shared_ptr<Node> node = std::make_shared<Node>(Node(out));
Node &Node::operator>(std::shared_ptr<Node> node) {
outputs_.push_back(node);
std::shared_ptr<Node> this_node;
node->inputs_.push_back(this);
return *node;
}
......@@ -45,20 +48,49 @@ bool Node::operator==(const Node &in) {
return true;
}
std::string Node::ToString(std::string blank) const {
std::string Node::ToString(std::string blank, const Node *node) const {
std::stringstream ss;
ss << type_ << ": \n";
ss << type_ << "-> \n";
if (inputs_.size() > 1 && node != inputs_.back()) {
return ss.str();
} else if (inputs_.size() > 1 && node == inputs_.back()) {
ss << "\n" << blank << type_ << "\n";
}
for (int i = 0; i < outputs_.size(); ++i) {
ss << blank << outputs_[i]->ToString(blank + " ") << "";
ss << blank << outputs_[i]->ToString(blank + " ", this) << "";
}
return ss.str();
}
std::string Node::ToString() const { return this->ToString(" "); }
std::string Node::ToString() const { return this->ToString(" ", this); }
Node &Node::To(int index) {
if (index == 0) {
this->outputs_.clear();
}
for (int j = 0; j < this->outputs_.size(); ++j) {
outputs_[j]->To(index - 1);
}
return *this;
}
uint Node::depth(uint begin) {
uint depth = 0;
begin++;
for (int i = 0; i < outputs_.size(); ++i) {
uint output_depth = outputs_[i]->depth(begin);
depth = output_depth > depth ? output_depth : depth;
}
return begin > depth ? begin : depth;
}
Print &operator<<(Print &printer, const Node &node) {
printer << node.ToString();
return printer;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -22,6 +22,7 @@ SOFTWARE.
#include <vector>
#include "common/log.h"
#include "framework/op_desc.h"
#include "framework/paddle_mobile_object.h"
namespace paddle_mobile {
......@@ -30,14 +31,19 @@ namespace framework {
class Node : PaddleMobileObject {
public:
Node(const std::string &type) : type_(type) {}
Node &operator>(const Node &out);
Node(std::shared_ptr<OpDesc> op_desc)
: op_desc_(op_desc), type_(op_desc->Type()){};
Node &operator>(std::shared_ptr<Node> node);
bool operator==(const Node &in);
std::string ToString() const;
Node &To(int index);
uint depth(uint begin = 0);
private:
std::string ToString(std::string blank) const;
std::shared_ptr<OpDesc> op_desc_;
std::string ToString(std::string blank, const Node *node) const;
std::vector<std::shared_ptr<Node>> outputs_;
std::vector<Node *> inputs_;
std::string type_;
};
......
......@@ -19,7 +19,56 @@ SOFTWARE.
#include "program_optimize.h"
namespace paddle_mobile {
namespace framework {
std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
std::shared_ptr<ProgramDesc>
ProgramOptimize::FushionOptimize(std::shared_ptr<ProgramDesc> ori_des) {
for (int i = 0; i < ori_des->Blocks().size(); ++i) {
std::unordered_map<std::string, std::shared_ptr<Node>> output_nodes;
std::shared_ptr<Node> begin_node;
auto block = ori_des->Block(i);
// DLOG << " ops size: " << block->Ops().size();
for (int j = 0; j < block->Ops().size(); ++j) {
auto op = block->Ops()[j];
auto op_type = op->Type();
// DLOG << "op type: " << op_type << " index: " << j;
if (op_input_output_key.find(op->Type()) ==
op_input_output_key.end()) {
return NULL;
}
std::shared_ptr<Node> node = std::make_shared<Node>(op);
if (j == 0) {
begin_node = node;
}
auto input_keys = op_input_output_key.at(op->Type()).first;
for (auto input_key : input_keys) {
auto op_inputs = op->Input(input_key);
for (int l = 0; l < op_inputs.size(); ++l) {
std::string input_key = op_inputs[l];
if (output_nodes.find(input_key) != output_nodes.end()) {
auto input_node = output_nodes[input_key];
*input_node > node;
}
}
}
auto output_keys = op_input_output_key.at(op_type).second;
for (auto output_key : output_keys) {
auto op_outputs = op->Output(output_key);
for (int k = 0; k < op_outputs.size(); ++k) {
output_nodes[op_outputs[k]] = node;
}
}
}
DLOG << "node: \n" << *begin_node;
}
return ori_des;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -18,18 +18,24 @@ SOFTWARE.
#pragma once
#include "framework/operator.h"
#include "framework/program_desc.h"
#include "node.h"
namespace paddle_mobile {
namespace framework {
class ProgramOptimize {
public:
ProgramOptimize(std::shared_ptr<ProgramDesc> ori_desc)
: ori_desc_(ori_desc) {}
ProgramOptimize() {}
std::shared_ptr<ProgramDesc> Optimize();
std::shared_ptr<ProgramDesc>
FushionOptimize(std::shared_ptr<ProgramDesc> ori_des);
private:
std::shared_ptr<ProgramDesc> ori_desc_;
// std::shared_ptr<ProgramDesc> ori_desc_;
std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>>
outputs_nodes_;
};
} // namespace framework
} // namespace paddle_mobile
......@@ -195,76 +195,55 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
LOG(kLOG_DEBUG) << "block: " << block.idx();
for (int j = 0; j < block.ops().size(); ++j) {
framework::proto::OpDesc op = block.ops()[j];
LOG(kLOG_DEBUG1) << " op: " << op.type();
LOG(kLOG_DEBUG1) << "op: " << op.type();
for (int m = 0; m < op.inputs_size(); ++m) {
const framework::proto::OpDesc::Var &var = op.inputs(m);
LOG(kLOG_DEBUG2) << " input parameter: " << var.parameter();
LOG(kLOG_DEBUG2) << "input parameter: " << var.parameter();
for (int n = 0; n < var.arguments().size(); ++n) {
LOG(kLOG_DEBUG3) << " argument - " << var.arguments()[n];
LOG(kLOG_DEBUG3) << "argument - " << var.arguments()[n];
}
}
for (int y = 0; y < op.outputs_size(); ++y) {
const framework::proto::OpDesc::Var &var = op.outputs(y);
LOG(kLOG_DEBUG2) << " out parameter: " << var.parameter();
LOG(kLOG_DEBUG2) << "out parameter: " << var.parameter();
for (int z = 0; z < var.arguments().size(); ++z) {
LOG(kLOG_DEBUG3) << " argument - " << var.arguments()[z];
LOG(kLOG_DEBUG3) << "argument - " << var.arguments()[z];
}
}
for (int x = 0; x < op.attrs().size(); ++x) {
const framework::proto::OpDesc_Attr attr = op.attrs()[x];
// std::cout << " attr name: " << attr.name() <<
// std::endl;
// std::cout << " attr type: " << attr.type() <<
// std::endl;
LOG(kLOG_DEBUG2) << "attr name: " << attr.name();
switch (attr.type()) {
case framework::proto::AttrType::BOOLEAN:
// std::cout << " boolen: " << attr.b() <<
// std::endl;
LOG(kLOG_DEBUG3) << "boolen: " << attr.b();
break;
case framework::proto::AttrType::INT:
// std::cout << " int: " << attr.i() <<
// std::endl;
LOG(kLOG_DEBUG3) << "int: " << attr.i();
break;
case framework::proto::AttrType::FLOAT:
// std::cout << " float: " << attr.f() <<
// std::endl;
LOG(kLOG_DEBUG3) << "float: " << attr.f();
case framework::proto::AttrType::STRING:
// std::cout << " string: " << attr.s() <<
// std::endl;
LOG(kLOG_DEBUG3) << "string: " << attr.s();
case framework::proto::AttrType::BOOLEANS:
// std::vector<bool>
// bools(attr.bools_size());
for (int y = 0; y < attr.bools_size(); ++y) {
// std::cout << " bool - " <<
// attr.bools(y) <<
// std::endl;
LOG(kLOG_DEBUG3) << "bools: " << attr.bools(y);
}
case framework::proto::AttrType::LONG:
// std::cout << " long: " << attr.l() <<
// std::endl;
LOG(kLOG_DEBUG3) << "long: " << attr.l();
case framework::proto::AttrType::FLOATS:
for (int y = 0; y < attr.floats_size(); ++y) {
// std::cout << " float - " << y <<
// ": " <<
// attr.floats(y)
// << std::endl;
LOG(kLOG_DEBUG3) << "floats: " << attr.floats(y);
}
case framework::proto::AttrType::INTS:
for (int y = 0; y < attr.ints_size(); ++y) {
// std::cout << " int - " << y << ":
// " <<
// attr.ints(y)
// << std::endl;
LOG(kLOG_DEBUG3) << "ints: " << attr.ints(y);
}
case framework::proto::AttrType::STRINGS:
for (int y = 0; y < attr.strings_size(); ++y) {
// std::cout << " string - " << y <<
// ": " <<
// attr.strings(y)
// << std::endl;
LOG(kLOG_DEBUG3) << "strings: " << attr.strings(y);
}
}
}
......@@ -273,19 +252,15 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
for (int k = 0; k < block.vars().size(); ++k) {
framework::proto::VarDesc var = block.vars()[k];
if (var.type().type() == framework::proto::VarType::LOD_TENSOR) {
// std::cout << " var name: " << var.name() <<
// std::endl;
LOG(kLOG_DEBUG1) << "var name: " << var.name();
const framework::proto::VarType::TensorDesc &tensor_desc =
var.type().lod_tensor().tensor();
// std::cout << " in var tensor desc dims size "
// << tensor_desc.dims().size() <<
// std::endl;
LOG(kLOG_DEBUG2) << "in var tensor desc dims size: "
<< tensor_desc.dims().size();
int memory_size = 1;
for (int l = 0; l < tensor_desc.dims().size(); ++l) {
// std::cout << " var tensor desc dim " << l
// << " value: " <<
// tensor_desc.dims()[l] <<
// std::endl;
LOG(kLOG_DEBUG3) << "var tensor desc dim " << l
<< " value: " << tensor_desc.dims()[l];
}
}
......
......@@ -31,7 +31,6 @@ target_link_libraries(test-log paddle-mobile)
ADD_EXECUTABLE(test-load framework/test_load.cpp)
target_link_libraries(test-load paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp)
target_link_libraries(test-optimize paddle-mobile)
\ No newline at end of file
target_link_libraries(test-optimize paddle-mobile)
......@@ -23,6 +23,6 @@ int main() {
//../../../test/models/googlenet
//../../../test/models/mobilenet
auto program = loader.Load(std::string("../../../test/models/googlenet"));
auto program = loader.Load(std::string("../models/googlenet"));
return 0;
}
\ No newline at end of file
......@@ -17,21 +17,24 @@ SOFTWARE.
==============================================================================*/
#include "framework/program-optimize/node.h"
#include <iostream>
#include "framework/program-optimize/program_optimize.h"
#include "io.h"
using namespace paddle_mobile;
using namespace paddle_mobile::framework;
int main() {
Node node("conv");
node > Node("add") > Node("relu");
Node node1("conv");
node1 > Node("add") > Node("relu");
Loader<paddle_mobile::CPU> loader;
// "../../../test/models/googlenet"
auto program = loader.Load("../models/googlenet");
if (node == node1) {
DLOG << "equal";
}
ProgramOptimize optimize;
auto optimize_program = optimize.FushionOptimize(program.originProgram);
if (optimize_program) {
DLOG << "\n" << node1;
// DLOG << node;
} else {
DLOG << "optimize_program is null";
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册