提交 fb370f44 编写于 作者: Y yuyang18

Refine code

上级 03e4da6d
......@@ -163,28 +163,35 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (static_cast<bool>(boost::get<int>(op->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) {
auto backward_vars = boost::get<std::vector<std::string>>(
op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(),
std::vector<std::string>()));
for (auto &og : backward_vars) {
try {
auto backward_vars =
boost::get<std::vector<std::string>>(op->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); ++i) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
CreateReduceOp(&result, og, cur_device_id);
var_name_on_devices[cur_device_id].emplace(og);
bcast_var_name_set[cur_device_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix)));
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name);
bcast_var_name_set[cur_device_id].emplace(p_name);
cur_device_id = (cur_device_id + 1) % places_.size();
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(var_types, og)) {
CreateReduceOp(&result, og, 0);
CreateBroadcastOp(&result, og, 0);
if (IsSparseGradient(var_types, g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
InsertNCCLAllReduceOp(&result, og);
InsertNCCLAllReduceOp(&result, g_name);
}
break;
}
}
} catch (boost::bad_get e) {
}
}
}
}
......
......@@ -223,13 +223,12 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second;
}
Attribute OpDesc::GetAttrOrDefault(
const std::string &name, paddle::framework::Attribute default_attr) const {
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
return it->second;
} else {
return default_attr;
return Attribute();
}
}
......
......@@ -78,8 +78,7 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const;
Attribute GetAttrOrDefault(const std::string &name,
Attribute default_attr) const;
Attribute GetNullableAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
......
......@@ -536,7 +536,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
if g.op is None:
raise ValueError("Unexpected branch")
attr_val = [p.name]
attr_val = [p.name, g.name]
if g.op.has_attr(op_role_var_attr_name):
attr_val.extend(g.op.attr(op_role_var_attr_name))
g.op.set_attr(op_role_var_attr_name, attr_val)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册