Unverified commit e26cced7, authored by Wu Yi and committed by GitHub

refine batch merge pass (#14777)

* refine batch merge pass

* refine batch merge pass test=develop
Parent 4048cfa9
@@ -75,6 +75,7 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
   std::vector<Node*> optimize_ops;
   std::vector<Node*> lr_ops;  // ops other than forward/backward/optimize
   std::unordered_set<std::string> grad_names;
+  std::unordered_map<std::string, std::string> gradname2paramname;
   std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
   auto origin_nodes = graph->ReleaseNodes();
@@ -99,6 +100,7 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
       auto op_role_vars = boost::get<std::vector<std::string>>(op_role_var);
       for (size_t i = 0; i < op_role_vars.size(); i += 2) {
         grad_names.insert(op_role_vars[i + 1]);
+        gradname2paramname[op_role_vars[i + 1]] = op_role_vars[i];
       }
     } else if (op_role & static_cast<int>(framework::OpRole::kLRSched)) {
       lr_ops.push_back(node);
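
For context, the op_role_var attribute flattens (parameter, gradient) name pairs into a single list, which is why the loop above strides by two. Below is a minimal standalone sketch of how the new gradname2paramname map is populated; the variable names mirror the diff, while the sample variable names such as fc_0.w are hypothetical:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // op_role_var is laid out as [param0, grad0, param1, grad1, ...].
  // The names below are illustrative only.
  std::vector<std::string> op_role_vars = {"fc_0.w", "fc_0.w@GRAD",
                                           "fc_0.b", "fc_0.b@GRAD"};
  std::unordered_set<std::string> grad_names;
  std::unordered_map<std::string, std::string> gradname2paramname;
  for (size_t i = 0; i < op_role_vars.size(); i += 2) {
    grad_names.insert(op_role_vars[i + 1]);                      // gradient name
    gradname2paramname[op_role_vars[i + 1]] = op_role_vars[i];   // grad -> param
  }
  std::cout << gradname2paramname.at("fc_0.w@GRAD") << "\n";  // prints fc_0.w
  return 0;
}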
@@ -109,7 +111,7 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
   // 2. copy forward backward
   ir::Node* prev_repeat_last_op_node = nullptr;
-  // record origin_grad -> repeated grad list map.
+  // record origin_grad -> repeated_grad_list map.
   std::map<ir::Node*, std::vector<ir::Node*>> grad_repeated_map;
   std::map<std::string, std::vector<ir::Node*>> created;
   std::unordered_set<std::string> bn_vars_need_rename;
@@ -124,10 +126,16 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
         if (grad_names.find(outname) != grad_names.end()) {
           std::string new_gname = string::Sprintf("%s.repeat.%d", outname, i);
           repeated_op.RenameOutput(outname, new_gname);
+          // remove op_role_var for backward ops that output the grad of a
+          // parameter.
+          repeated_op.SetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName(),
+                              std::vector<std::string>());
         }
       }
       // 3.5 let batch_norm ops use independent vars, note batch_norm_grad do
-      // not need this update
+      // not need this update, because only the moving mean and variance
+      // should differ per repeat; the trainable scale and bias parameters
+      // are shared, the same as other parameters.
       if (node->Name() == "batch_norm") {
         // NOTE: assume bn op created by layers use save var as output mean and
         // variance
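
The renaming above gives each repeated copy its own gradient variable, e.g. fc_0.w@GRAD becomes fc_0.w@GRAD.repeat.0 for the first copy. A small sketch of the naming scheme, using std::snprintf in place of the framework's string::Sprintf (the helper name and sample gradient name are hypothetical):

#include <cstdio>
#include <string>

// Hypothetical helper mirroring string::Sprintf("%s.repeat.%d", outname, i).
std::string RepeatName(const std::string& outname, int i) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), "%s.repeat.%d", outname.c_str(), i);
  return std::string(buf);
}

int main() {
  for (int i = 0; i < 3; ++i) {
    std::printf("%s\n", RepeatName("fc_0.w@GRAD", i).c_str());
  }
  // prints fc_0.w@GRAD.repeat.0, fc_0.w@GRAD.repeat.1, fc_0.w@GRAD.repeat.2
  return 0;
}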
@@ -224,16 +232,25 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
           var->inputs.push_back(repeated_node);
         }
       }
-  }
+  }  // end copy forward backward
-  // 5. create GRAD merge op node
+  // 5. create GRAD merge op node: sum(repeat.0...repeat.n) ->
+  // scale(1/num_repeats)
   for (auto kv : grad_repeated_map) {
     OpDesc sum_op;
     sum_op.SetType("sum");
     std::vector<std::string> repeated_grad_names;
+    std::vector<std::string> param_grad_op_role_var;
     for (auto r : kv.second) {
       repeated_grad_names.push_back(r->Var()->Name());
     }
+    // NOTE: use op_role_var to control allreduce op appending in
+    // multi_devices_graph_pass; we want to append op_role_var
+    // only once for the merged gradient, so break after the first call.
+    param_grad_op_role_var.push_back(
+        gradname2paramname.at(kv.first->Var()->Name()));  // param
+    param_grad_op_role_var.push_back(kv.first->Var()->Name());  // grad
     sum_op.SetInput("X", repeated_grad_names);
     sum_op.SetOutput("Out", {kv.first->Var()->Name()});
     sum_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
@@ -256,6 +273,10 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
     scale_op.SetAttr("scale", static_cast<float>(1.0f / num_repeats));
     scale_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                      static_cast<int>(OpRole::kBackward));
+    scale_op.SetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName(),
+                     param_grad_op_role_var);
     auto scale_op_node = result.CreateOpNode(&scale_op);
     scale_op_node->inputs.push_back(sum_out_var_node);
     sum_out_var_node->outputs.push_back(scale_op_node);
......
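
Taken together, the sum op and the scale op compute an averaged gradient, merged = (1/num_repeats) * (grad.repeat.0 + ... + grad.repeat.n-1), and the op_role_var attribute newly set on the scale op lets multi_devices_graph_pass append a single allreduce for the merged result rather than one per repeat. A minimal numeric sketch of that merge semantics, independent of the framework (the gradient values are hypothetical):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Gradients produced by num_repeats repeated forward/backward copies,
  // i.e. grad.repeat.0 ... grad.repeat.2 for one parameter.
  std::vector<std::vector<float>> repeated_grads = {
      {1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}};
  const float num_repeats = static_cast<float>(repeated_grads.size());

  // sum op: sum(repeat.0 ... repeat.n)
  std::vector<float> merged(repeated_grads[0].size(), 0.0f);
  for (const auto& g : repeated_grads) {
    for (size_t j = 0; j < g.size(); ++j) merged[j] += g[j];
  }

  // scale op: multiply by 1/num_repeats to average the summed gradients.
  for (auto& v : merged) v *= 1.0f / num_repeats;

  for (auto v : merged) std::cout << v << " ";  // prints: 3 4
  std::cout << "\n";
  return 0;
}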