Commit fb370f44, authored by yuyang18

Refine code

Parent: 03e4da6d
```diff
@@ -163,27 +163,34 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     if (static_cast<bool>(boost::get<int>(op->GetAttr(
             OpProtoAndCheckerMaker::OpRoleAttrName())) &
         static_cast<int>(OpRole::kBackward))) {
-      auto backward_vars = boost::get<std::vector<std::string>>(
-          op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(),
-                               std::vector<std::string>()));
-      for (auto &og : backward_vars) {
-        switch (strategy_.reduce_) {
-          case BuildStrategy::ReduceStrategy::kReduce:
-            CreateReduceOp(&result, og, cur_device_id);
-            var_name_on_devices[cur_device_id].emplace(og);
-            bcast_var_name_set[cur_device_id].emplace(
-                og.substr(0, og.size() - strlen(kGradVarSuffix)));
-            cur_device_id = (cur_device_id + 1) % places_.size();
-            break;
-          case BuildStrategy::ReduceStrategy::kAllReduce:
-            if (IsSparseGradient(var_types, og)) {
-              CreateReduceOp(&result, og, 0);
-              CreateBroadcastOp(&result, og, 0);
-            } else {
-              InsertNCCLAllReduceOp(&result, og);
-            }
-            break;
+      try {
+        auto backward_vars =
+            boost::get<std::vector<std::string>>(op->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+        PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+
+        for (size_t i = 0; i < backward_vars.size(); ++i) {
+          auto &p_name = backward_vars[i];
+          auto &g_name = backward_vars[i + 1];
+          switch (strategy_.reduce_) {
+            case BuildStrategy::ReduceStrategy::kReduce:
+              CreateReduceOp(&result, g_name, cur_device_id);
+              var_name_on_devices[cur_device_id].emplace(g_name);
+              bcast_var_name_set[cur_device_id].emplace(p_name);
+              cur_device_id = (cur_device_id + 1) % places_.size();
+              break;
+            case BuildStrategy::ReduceStrategy::kAllReduce:
+              if (IsSparseGradient(var_types, g_name)) {
+                CreateReduceOp(&result, g_name, 0);
+                CreateBroadcastOp(&result, g_name, 0);
+              } else {
+                InsertNCCLAllReduceOp(&result, g_name);
+              }
+              break;
+          }
         }
+      } catch (boost::bad_get e) {
       }
     }
   }
```
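Read together with the Python change below, this hunk switches the OpRoleVar attribute from a list of gradient names to a flat list of (parameter, gradient) name pairs, so the builder no longer derives a parameter name by stripping kGradVarSuffix from the gradient name. A minimal standalone sketch of consuming that pairwise layout (the concrete names and the step-by-two traversal are illustrative, not the Paddle code itself):

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical OpRoleVar value: names laid out flat as p0, g0, p1, g1, ...
  std::vector<std::string> backward_vars = {"fc_0.w", "fc_0.w@GRAD",
                                            "fc_0.b", "fc_0.b@GRAD"};
  assert(backward_vars.size() % 2 == 0);  // mirrors the PADDLE_ENFORCE_EQ check

  // Visit each (parameter, gradient) pair once by stepping two names at a time.
  for (size_t i = 0; i + 1 < backward_vars.size(); i += 2) {
    const std::string &p_name = backward_vars[i];      // parameter name
    const std::string &g_name = backward_vars[i + 1];  // its gradient's name
    std::cout << p_name << " -> " << g_name << "\n";
  }
  return 0;
}
```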
```diff
@@ -223,13 +223,12 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
   return it->second;
 }
 
-Attribute OpDesc::GetAttrOrDefault(
-    const std::string &name, paddle::framework::Attribute default_attr) const {
+Attribute OpDesc::GetNullableAttr(const std::string &name) const {
   auto it = attrs_.find(name);
   if (it != attrs_.end()) {
     return it->second;
   } else {
-    return default_attr;
+    return Attribute();
   }
 }
```
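GetNullableAttr returns a default-constructed Attribute when the name is missing, instead of taking a caller-supplied default. Since Attribute is a boost::variant, calling boost::get for a non-matching alternative throws boost::bad_get, which is what the try/catch in the graph-builder hunk above relies on. A self-contained sketch of that mechanism (the variant alternatives here are a reduced stand-in for Paddle's actual Attribute typedef):

```cpp
#include <boost/blank.hpp>
#include <boost/variant.hpp>
#include <iostream>
#include <string>
#include <vector>

// Reduced stand-in: a default-constructed boost::variant holds its first
// alternative (boost::blank here), which models an "attribute not set" result.
using Attribute =
    boost::variant<boost::blank, int, float, std::vector<std::string>>;

int main() {
  Attribute missing;  // what a GetNullableAttr-style lookup yields on a miss
  try {
    // The held type is boost::blank, so this boost::get throws boost::bad_get.
    auto names = boost::get<std::vector<std::string>>(missing);
    std::cout << "got " << names.size() << " names\n";
  } catch (boost::bad_get &e) {
    std::cout << "attribute not set: " << e.what() << "\n";
  }
  return 0;
}
```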
```diff
@@ -78,8 +78,7 @@ class OpDesc {
   Attribute GetAttr(const std::string &name) const;
 
-  Attribute GetAttrOrDefault(const std::string &name,
-                             Attribute default_attr) const;
+  Attribute GetNullableAttr(const std::string &name) const;
 
   int GetBlockAttr(const std::string &name) const;
```
```diff
@@ -536,7 +536,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
         if g.op is None:
             raise ValueError("Unexpected branch")
-        attr_val = [p.name]
+        attr_val = [p.name, g.name]
         if g.op.has_attr(op_role_var_attr_name):
             attr_val.extend(g.op.attr(op_role_var_attr_name))
         g.op.set_attr(op_role_var_attr_name, attr_val)
```
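On the Python side, append_backward now records both the parameter and its gradient, placing the new pair ahead of any pairs already stored on the op, so the attribute stays a flat pair list. A sketch of the resulting accumulation order (in C++ for consistency with the snippets above; the names are made up):

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Pairs already stored on the op from an earlier append.
  std::vector<std::string> existing = {"fc_0.b", "fc_0.b@GRAD"};
  // New (parameter, gradient) pair goes first, existing pairs follow,
  // mirroring attr_val = [p.name, g.name]; attr_val.extend(...).
  std::vector<std::string> attr_val = {"fc_0.w", "fc_0.w@GRAD"};
  attr_val.insert(attr_val.end(), existing.begin(), existing.end());
  for (const auto &name : attr_val) std::cout << name << "\n";
  return 0;
}
```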