Unverified commit bb29800a, authored by chengduo, committed by GitHub

small refine (#11460)

Parent ab0c2e1d
@@ -207,53 +207,56 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       is_forwarding = false;
     } else {
       int op_dev_id = GetOpDeviceID(*op);
-      if (op_dev_id == -1) {  // var on all device
-        CreateComputationalOps(&result, *op, places_.size());
-      } else {
+      if (op_dev_id != -1) {  // This op only runs on one specific device.
         CreateComputationalOp(&result, *op, op_dev_id);
         for (auto &var_name : op->OutputArgumentNames()) {
           var_name_on_devices_.emplace(var_name, op_dev_id);
         }
-      }
-      if (!is_forwarding && places_.size() > 1) {
-        // Currently, we assume that once gradient is generated, it can be
-        // broadcast, and each gradient is only broadcast once.
-        if (static_cast<bool>(boost::get<int>(op->GetAttr(
-                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
-                              static_cast<int>(OpRole::kBackward))) {
-          try {
-            auto backward_vars =
-                boost::get<std::vector<std::string>>(op->GetNullableAttr(
-                    OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-
-            PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
-
-            for (size_t i = 0; i < backward_vars.size(); i += 2) {
-              auto &p_name = backward_vars[i];
-              auto &g_name = backward_vars[i + 1];
-              VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
-
-              switch (strategy_.reduce_) {
-                case BuildStrategy::ReduceStrategy::kReduce:
-                  cur_device_id = GetAppropriateDeviceID({g_name});
-                  CreateReduceOp(&result, g_name, cur_device_id);
-                  var_name_on_devices_.emplace(g_name, cur_device_id);
-                  bcast_var_name_set[cur_device_id].emplace(p_name);
-                  break;
-                case BuildStrategy::ReduceStrategy::kAllReduce:
-                  if (IsSparseGradient(g_name)) {
-                    CreateReduceOp(&result, g_name, 0);
-                    CreateBroadcastOp(&result, g_name, 0);
-                  } else {
-                    InsertAllReduceOp(&result, g_name);
-                  }
-                  break;
-                default:
-                  LOG(FATAL) << "Unknown reduce strategy ";
-                  break;
-              }
-            }
-          } catch (boost::bad_get e) {
-          }
-        }
-      }
+      } else {
+        // This op runs on all devices, and its output may have parameter's
+        // gradients.
+        CreateComputationalOps(&result, *op, places_.size());
+
+        if (!is_forwarding && places_.size() > 1) {
+          // Currently, we assume that once gradient is generated, it can be
+          // broadcast, and each gradient is only broadcast once.
+          if (static_cast<bool>(boost::get<int>(op->GetAttr(
+                                    OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                                static_cast<int>(OpRole::kBackward))) {
+            try {
+              auto backward_vars =
+                  boost::get<std::vector<std::string>>(op->GetNullableAttr(
+                      OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+              PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+
+              for (size_t i = 0; i < backward_vars.size(); i += 2) {
+                auto &p_name = backward_vars[i];
+                auto &g_name = backward_vars[i + 1];
+                VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+
+                switch (strategy_.reduce_) {
+                  case BuildStrategy::ReduceStrategy::kReduce:
+                    cur_device_id = GetAppropriateDeviceID({g_name});
+                    CreateReduceOp(&result, g_name, cur_device_id);
+                    var_name_on_devices_.emplace(g_name, cur_device_id);
+                    bcast_var_name_set[cur_device_id].emplace(p_name);
+                    break;
+                  case BuildStrategy::ReduceStrategy::kAllReduce:
+                    if (IsSparseGradient(g_name)) {
+                      CreateReduceOp(&result, g_name, 0);
+                      CreateBroadcastOp(&result, g_name, 0);
+                    } else {
+                      InsertAllReduceOp(&result, g_name);
+                    }
+                    break;
+                  default:
+                    LOG(FATAL) << "Unknown reduce strategy ";
+                    break;
+                }
+              }
+            } catch (boost::bad_get e) {
+            }
+          }
+        }
+      }
     }
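Stripped of the diff noise, the refine inverts the placement branch (the single-device case now comes first) and moves the gradient reduce/broadcast block inside the all-devices branch, where, as the new comment notes, parameter gradients can actually appear. The following is a minimal, compilable sketch of the resulting control flow, not Paddle code: Op, BuildOp, HandleGradient, and the other names are invented stand-ins for this example, while the real builder works on an SSAGraph with the helpers shown in the hunk above.

// Hypothetical stand-alone illustration of the refactored branch order.
#include <iostream>
#include <string>
#include <vector>

enum class ReduceStrategy { kReduce, kAllReduce };

struct Op {
  int device_id;                        // -1 means "runs on every device"
  std::vector<std::string> grad_names;  // gradients this op may produce
};

// Stand-ins for MultiDevSSAGraphBuilder's helper methods.
void CreateComputationalOp(const Op &, int dev) {
  std::cout << "  compute on device " << dev << "\n";
}
void CreateComputationalOps(const Op &, size_t n) {
  std::cout << "  compute replicated across " << n << " devices\n";
}
void HandleGradient(const std::string &g, ReduceStrategy s) {
  if (s == ReduceStrategy::kReduce) {
    std::cout << "  reduce " << g << " onto one device, broadcast it later\n";
  } else {
    std::cout << "  all-reduce " << g << " across all devices\n";
  }
}

// Mirrors the new hunk: the pinned, single-device case is handled first, and
// gradient handling is reached only for ops replicated on all devices during
// the backward pass.
void BuildOp(const Op &op, size_t num_places, bool is_forwarding,
             ReduceStrategy strategy) {
  if (op.device_id != -1) {  // This op only runs on one specific device.
    CreateComputationalOp(op, op.device_id);
  } else {
    // This op runs on all devices; its outputs may be parameter gradients.
    CreateComputationalOps(op, num_places);
    if (!is_forwarding && num_places > 1) {
      for (const auto &g : op.grad_names) HandleGradient(g, strategy);
    }
  }
}

int main() {
  BuildOp({0, {}}, 4, true, ReduceStrategy::kAllReduce);  // op pinned to dev 0
  BuildOp({-1, {"w@GRAD"}}, 4, false, ReduceStrategy::kAllReduce);  // backward
}

In the real code, the kReduce branch places each gradient on one device chosen by GetAppropriateDeviceID and records the parameter for a later broadcast, while kAllReduce all-reduces dense gradients and, for sparse gradients, falls back to a reduce onto device 0 followed by a broadcast.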