提交 7cd66c46 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #328 from codeWorm2015/develop

fix #327 add split model config
...@@ -151,7 +151,7 @@ class FusionOpMatcher : PaddleMobileObject { ...@@ -151,7 +151,7 @@ class FusionOpMatcher : PaddleMobileObject {
virtual Node &BeginNode() { return node_; } virtual Node &BeginNode() { return node_; }
std::string BeginType() { return node_.BeginType(); } std::string BeginType() { return node_.Type(); }
protected: protected:
Node node_; Node node_;
......
...@@ -25,13 +25,7 @@ std::vector<std::shared_ptr<VarDesc>> BlockDesc::Vars() const { ...@@ -25,13 +25,7 @@ std::vector<std::shared_ptr<VarDesc>> BlockDesc::Vars() const {
return res; return res;
} }
std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const { std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const { return ops_; }
std::vector<std::shared_ptr<OpDesc>> res;
for (const auto &op : ops_) {
res.push_back(op);
}
return res;
}
BlockDesc::BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc) BlockDesc::BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc)
: index_(desc->idx), parent_index_(desc->idx) { : index_(desc->idx), parent_index_(desc->idx) {
......
...@@ -26,6 +26,7 @@ class BlockDesc : PaddleMobileObject { ...@@ -26,6 +26,7 @@ class BlockDesc : PaddleMobileObject {
public: public:
friend class Node; friend class Node;
friend class ProgramOptimize; friend class ProgramOptimize;
BlockDesc() {}
BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc); BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc);
BlockDesc(const BlockDesc &block_desc) BlockDesc(const BlockDesc &block_desc)
: index_(block_desc.index_), parent_index_(block_desc.parent_index_) { : index_(block_desc.index_), parent_index_(block_desc.parent_index_) {
...@@ -43,6 +44,8 @@ class BlockDesc : PaddleMobileObject { ...@@ -43,6 +44,8 @@ class BlockDesc : PaddleMobileObject {
const int &ID() const { return index_; } const int &ID() const { return index_; }
const bool &MultiThread() const { return multi_thread_; }
const int &Parent() const { return parent_index_; } const int &Parent() const { return parent_index_; }
bool operator==(const paddle_mobile::framework::BlockDesc &in_block) const { bool operator==(const paddle_mobile::framework::BlockDesc &in_block) const {
...@@ -58,6 +61,7 @@ class BlockDesc : PaddleMobileObject { ...@@ -58,6 +61,7 @@ class BlockDesc : PaddleMobileObject {
private: private:
int index_; int index_;
bool multi_thread_;
int parent_index_; int parent_index_;
std::vector<std::shared_ptr<OpDesc>> ops_; std::vector<std::shared_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_; std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_;
......
...@@ -45,17 +45,6 @@ bool Node::operator==(const Node &in) { ...@@ -45,17 +45,6 @@ bool Node::operator==(const Node &in) {
return true; return true;
} }
// std::shared_ptr<Node> Node::MatchTheFirstNode(std::string type){
//
// for (const auto &node : outputs_){
// if (node->type_ == type){
// return node;
// }else{
//
// }
// }
//}
std::vector<std::shared_ptr<framework::OpDesc>> Node::OpDescs(uint size) { std::vector<std::shared_ptr<framework::OpDesc>> Node::OpDescs(uint size) {
std::vector<std::shared_ptr<framework::OpDesc>> op_descs; std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
OpDescs(size - 1, &op_descs); OpDescs(size - 1, &op_descs);
...@@ -75,21 +64,40 @@ void Node::OpDescs(uint index, ...@@ -75,21 +64,40 @@ void Node::OpDescs(uint index,
void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *node, bool adding_thread, int thread_num) { Node *node, bool adding_thread, int thread_num) {
bool can_add_split = false;
if (outputs_.size() > 1) { if (outputs_.size() > 1) {
adding_thread = false;
}
bool can_add_split = false;
// 如果当前节点有多个输出 并且 只有当前节点对应的 op_desc_ 输出数为 1 时支持
if (outputs_.size() > 1 &&
op_input_output_key[op_desc_->type_].second.size() == 1) {
can_add_split = true; can_add_split = true;
if (op_input_output_key[op_desc_->type_].second.size() != 1) {
DLOG << "当前 op desc 输出数不为 1 "; // 遍历当前节点的 output 节点
can_add_split = false;
}
for (const auto &output : outputs_) { for (const auto &output : outputs_) {
if (op_input_output_key.find(output->op_desc_->type_) != // 不支持 output 有多个 output 的情况
op_input_output_key.end()) { if (output->outputs_.size() > 0) {
auto inputs_and_outputs = op_input_output_key[output->op_desc_->type_]; can_add_split = false;
auto outputs_of_output = break;
output->op_desc_->Output(inputs_and_outputs.second[0]); }
auto inputs_of_output =
output->op_desc_->Input(inputs_and_outputs.first[0]); //与节点关联的 OpDesc
std::shared_ptr<framework::OpDesc> &op_desc = output->op_desc_;
//获取这个 op 的 inputs key 和 outputs key
auto inputs_and_outputs = op_input_output_key[op_desc->type_];
//判断现在 是否存在这个 op
//判断这个 output 和 input key 的 size 等于 1
if (op_input_output_key.find(op_desc->type_) !=
op_input_output_key.end() &&
inputs_and_outputs.first.size() == 1 &&
inputs_and_outputs.second.size() == 1) {
auto inputs_of_output = op_desc->Input(inputs_and_outputs.first[0]);
auto outputs_of_output = op_desc->Output(inputs_and_outputs.second[0]);
// 判断一下, 如果输入和输出没有同名, 是支持的
for (int i = 0; i < inputs_of_output.size(); ++i) { for (int i = 0; i < inputs_of_output.size(); ++i) {
std::string input_of_output = inputs_of_output[i]; std::string input_of_output = inputs_of_output[i];
for (int j = 0; j < outputs_of_output.size(); ++j) { for (int j = 0; j < outputs_of_output.size(); ++j) {
...@@ -101,7 +109,7 @@ void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, ...@@ -101,7 +109,7 @@ void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
} }
} }
} }
} else { } else { // 如果模型中包含没有的 op, 则不支持添加 split
DLOG << "找不到 这个 op 类型: " << output->op_desc_->type_; DLOG << "找不到 这个 op 类型: " << output->op_desc_->type_;
can_add_split = false; can_add_split = false;
} }
...@@ -124,12 +132,10 @@ void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, ...@@ -124,12 +132,10 @@ void Node::OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
if (can_add_split) { if (can_add_split) {
adding_thread = true; adding_thread = true;
std::shared_ptr<class OpDesc> split_op_desc = std::shared_ptr<OpDesc> split_op_desc = std::make_shared<OpDesc>();
std::make_shared<class OpDesc>();
split_op_desc->type_ = G_OP_TYPE_SPLIT; split_op_desc->type_ = G_OP_TYPE_SPLIT;
auto outputs = this->op_desc_->Output( auto outputs = this->op_desc_->Output(
op_input_output_key[this->op_desc_->Type()].second[0]); op_input_output_key[this->op_desc_->Type()].second[0]);
split_op_desc->inputs_ = { split_op_desc->inputs_ = {
{op_input_output_key[G_OP_TYPE_SPLIT].first[0], outputs}}; {op_input_output_key[G_OP_TYPE_SPLIT].first[0], outputs}};
auto &split_outputs = auto &split_outputs =
...@@ -157,41 +163,12 @@ std::vector<std::shared_ptr<framework::OpDesc>> Node::OpDescs() { ...@@ -157,41 +163,12 @@ std::vector<std::shared_ptr<framework::OpDesc>> Node::OpDescs() {
return op_descs; return op_descs;
} }
std::string Node::ToString(std::string blank, const Node *node) const {
std::stringstream ss;
ss << type_ << "-> \n";
if (inputs_.size() > 1 && node != inputs_.back()) {
return ss.str();
} else if (inputs_.size() > 1 && node == inputs_.back()) {
ss << "\n" << blank << type_ << "\n";
}
for (int i = 0; i < outputs_.size(); ++i) {
ss << blank << outputs_[i]->ToString(blank + " ", this) << "";
}
return ss.str();
}
std::string Node::ToString() const { return this->ToString(" ", this); }
std::shared_ptr<Node> Node::To(int size) { std::shared_ptr<Node> Node::To(int size) {
std::shared_ptr<Node> node = std::make_shared<Node>(); std::shared_ptr<Node> node = std::make_shared<Node>();
this->To(size - 1, node); this->To(size - 1, node);
return node; return node;
} }
// Node &Node::To(int size) {
// if (size == 1) {
// this->outputs_.clear();
// }
//
// for (int j = 0; j < this->outputs_.size(); ++j) {
// outputs_[j]->To(size - 1);
// }
// return *this;
//}
void Node::To(int index, std::shared_ptr<Node> node) { void Node::To(int index, std::shared_ptr<Node> node) {
node->type_ = this->type_; node->type_ = this->type_;
if (index != 0) { if (index != 0) {
...@@ -268,6 +245,24 @@ void Node::Folder( ...@@ -268,6 +245,24 @@ void Node::Folder(
} }
} }
std::string Node::ToString(std::string blank, const Node *node) const {
std::stringstream ss;
ss << type_ << "-> \n";
if (inputs_.size() > 1 && node != inputs_.back()) {
return ss.str();
} else if (inputs_.size() > 1 && node == inputs_.back()) {
ss << "\n" << blank << type_ << "\n";
}
for (int i = 0; i < outputs_.size(); ++i) {
ss << blank << outputs_[i]->ToString(blank + " ", this) << "";
}
return ss.str();
}
std::string Node::ToString() const { return this->ToString(" ", this); }
void Node::Description() { void Node::Description() {
if (op_desc_.get()) { if (op_desc_.get()) {
DLOG << *op_desc_; DLOG << *op_desc_;
......
...@@ -27,6 +27,8 @@ namespace paddle_mobile { ...@@ -27,6 +27,8 @@ namespace paddle_mobile {
namespace framework { namespace framework {
class Node : PaddleMobileObject { class Node : PaddleMobileObject {
friend class ProgramOptimize;
public: public:
Node() {} Node() {}
explicit Node(const std::string &type) : type_(type) {} explicit Node(const std::string &type) : type_(type) {}
...@@ -42,8 +44,8 @@ class Node : PaddleMobileObject { ...@@ -42,8 +44,8 @@ class Node : PaddleMobileObject {
std::map<std::string, std::pair<std::string, std::string>> change_map); std::map<std::string, std::pair<std::string, std::string>> change_map);
std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(uint size); std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(uint size);
std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(); std::vector<std::shared_ptr<framework::OpDesc>> OpDescs();
std::shared_ptr<framework::OpDesc> OpDesc() { return op_desc_; } std::shared_ptr<framework::OpDesc> OpDescOfNode() { return op_desc_; }
std::string BeginType() { return type_; } std::string Type() { return type_; }
void Description(); void Description();
private: private:
......
...@@ -19,11 +19,12 @@ namespace paddle_mobile { ...@@ -19,11 +19,12 @@ namespace paddle_mobile {
namespace framework { namespace framework {
// std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize( std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
std::shared_ptr<ProgramDesc> ori_des) { std::shared_ptr<ProgramDesc> ori_des, bool add_split) {
ProgramDesc *optimize_program = new ProgramDesc(*ori_des); // ProgramDesc *optimize_program = new ProgramDesc(*ori_des);
std::shared_ptr<ProgramDesc> optimize_program =
std::make_shared<ProgramDesc>(*ori_des);
current_block_ = optimize_program->Blocks().size();
for (int i = 0; i < optimize_program->Blocks().size(); ++i) { for (int i = 0; i < optimize_program->Blocks().size(); ++i) {
std::unordered_map<std::string, std::shared_ptr<Node>> output_nodes; std::unordered_map<std::string, std::shared_ptr<Node>> output_nodes;
...@@ -96,10 +97,145 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize( ...@@ -96,10 +97,145 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
} }
// DLOG << "node: \n" << *begin_node; // DLOG << "node: \n" << *begin_node;
block->ops_ = begin_node->OpDescs();
std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
GenerateOps(&op_descs, begin_node.get());
block->ops_ = op_descs;
}
for (int m = 0; m < new_blocks_.size(); ++m) {
std::shared_ptr<BlockDesc> new_block = new_blocks_[m];
new_block->index_ = m + ori_des->blocks_.size();
optimize_program->blocks_.push_back(new_block);
}
return optimize_program;
}
void ProgramOptimize::GenerateOps(
std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *input_node,
Node *current_node, bool adding_thread, int thread_num,
std::shared_ptr<BlockDesc> new_block) {
if (current_node->outputs_.size() > 1) {
adding_thread = false;
}
bool can_add_split = false;
// 如果当前节点有多个输出 并且 只有当前节点对应的 op_desc_ 输出数为 1 时支持
if (current_node->outputs_.size() > 1 &&
op_input_output_key[current_node->op_desc_->type_].second.size() == 1) {
can_add_split = true;
// 遍历当前节点的 output 节点
for (const auto &output : current_node->outputs_) {
// 不支持 output 有多个 output 的情况
if (output->outputs_.size() > 1) {
DLOG << "don't support multi output of output";
can_add_split = false;
break;
}
//与节点关联的 OpDesc
std::shared_ptr<framework::OpDesc> &op_desc = output->op_desc_;
//获取这个 op 的 inputs key 和 outputs key
auto inputs_and_outputs = op_input_output_key[op_desc->type_];
//判断现在 是否存在这个 op
//判断这个 output 和 input key 的 size 等于 1
if (op_input_output_key.find(op_desc->type_) !=
op_input_output_key.end() &&
inputs_and_outputs.first.size() == 1 &&
inputs_and_outputs.second.size() == 1) {
auto inputs_of_output = op_desc->Input(inputs_and_outputs.first[0]);
auto outputs_of_output = op_desc->Output(inputs_and_outputs.second[0]);
// 判断一下, 如果输入和输出没有同名, 是支持的
for (int i = 0; i < inputs_of_output.size(); ++i) {
std::string input_of_output = inputs_of_output[i];
for (int j = 0; j < outputs_of_output.size(); ++j) {
std::string output_of_output = outputs_of_output[j];
if (input_of_output == output_of_output) {
DLOG << "output的 output 包含 input" << input_of_output;
can_add_split = false;
break;
}
}
}
} else { // 如果模型中包含没有的 op, 则不支持添加 split
DLOG << "找不到 这个 op 类型: " << output->op_desc_->type_;
can_add_split = false;
}
}
}
if (current_node->inputs_.size() > 1 &&
input_node != current_node->inputs_.back()) {
return;
} else if (current_node->inputs_.size() > 1 &&
input_node == current_node->inputs_.back()) {
new_block.reset();
adding_thread = false;
op_desc->push_back(current_node->op_desc_);
} else {
if (new_block.get() && adding_thread) {
new_block->ops_.push_back(current_node->op_desc_);
} else {
op_desc->push_back(current_node->op_desc_);
}
}
if (adding_thread) {
Attribute attr;
attr.Set<int>(thread_num);
current_node->op_desc_->attrs_["thread"] = attr;
}
if (can_add_split) {
new_block = std::make_shared<BlockDesc>();
new_block->multi_thread_ = true;
new_block->index_ = current_block_;
new_blocks_.push_back(new_block);
adding_thread = true;
std::shared_ptr<OpDesc> split_op_desc = std::make_shared<OpDesc>();
split_op_desc->type_ = G_OP_TYPE_SPLIT;
auto outputs = current_node->op_desc_->Output(
op_input_output_key[current_node->op_desc_->Type()].second[0]);
split_op_desc->inputs_ = {
{op_input_output_key[G_OP_TYPE_SPLIT].first[0], outputs}};
auto &split_outputs =
split_op_desc->outputs_[op_input_output_key[G_OP_TYPE_SPLIT].second[0]];
for (const auto &output : current_node->outputs_) {
split_outputs.push_back(outputs[0]);
}
Attribute attr;
attr.Set<int>(current_block_);
split_op_desc->attrs_["block_id"] = attr;
op_desc->push_back(split_op_desc);
current_block_++;
}
for (int i = 0; i < current_node->outputs_.size(); ++i) {
auto &output = current_node->outputs_[i];
if (can_add_split) {
GenerateOps(op_desc, current_node, output.get(), adding_thread, i,
new_block);
} else {
GenerateOps(op_desc, current_node, output.get(), adding_thread,
thread_num, new_block);
}
} }
std::shared_ptr<ProgramDesc> shared_optimzie(optimize_program);
return shared_optimzie;
} }
void ProgramOptimize::GenerateOps(
std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
Node *begin_node) {
// std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
// Node *input_node, Node *current_node, bool adding_thread, int
// thread_num
this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -28,12 +28,17 @@ class ProgramOptimize { ...@@ -28,12 +28,17 @@ class ProgramOptimize {
public: public:
ProgramOptimize() {} ProgramOptimize() {}
std::shared_ptr<ProgramDesc> FushionOptimize( std::shared_ptr<ProgramDesc> FushionOptimize(
std::shared_ptr<ProgramDesc> ori_des); std::shared_ptr<ProgramDesc> ori_des, bool add_split = false);
private: private:
// std::shared_ptr<ProgramDesc> ori_desc_; int current_block_;
std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>> std::vector<std::shared_ptr<BlockDesc>> new_blocks_;
outputs_nodes_;
void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
Node *begin_node);
void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *input_node, Node *current_node, bool adding_thread,
int thread_num, std::shared_ptr<BlockDesc> new_block);
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -32,11 +32,13 @@ void ProgramDesc::Description(std::string header) { ...@@ -32,11 +32,13 @@ void ProgramDesc::Description(std::string header) {
if (header.size()) { if (header.size()) {
LOG(kLOG_INFO) << header; LOG(kLOG_INFO) << header;
} }
for (const auto &block : this->blocks_) {
for (int i = 0; i < this->blocks_.size(); ++i) {
auto block = this->blocks_[i];
LOG(kLOG_DEBUG) << "block: " << block->ID(); LOG(kLOG_DEBUG) << "block: " << block->ID();
LOG(kLOG_INFO) << "block ops size: " << block->Ops().size(); LOG(kLOG_INFO) << "block ops size: " << block->Ops().size();
for (int j = 0; j < block->Ops().size(); ++j) { for (int j = 0; j < block->Ops().size(); ++j) {
const auto &op = block->Ops()[j]; auto op = block->Ops()[j];
LOG(kLOG_DEBUG1) << "op: " << op->Type(); LOG(kLOG_DEBUG1) << "op: " << op->Type();
for (auto &input : op->GetInputs()) { for (auto &input : op->GetInputs()) {
LOG(kLOG_DEBUG2) << "input parameter: " << input.first; LOG(kLOG_DEBUG2) << "input parameter: " << input.first;
...@@ -71,6 +73,9 @@ void ProgramDesc::Description(std::string header) { ...@@ -71,6 +73,9 @@ void ProgramDesc::Description(std::string header) {
} }
} }
} }
for (const auto &block : this->blocks_) {
}
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册