Commit 9ee1b7bc authored by: Yancey1989

add some comments

Parent: bad4ea19
@@ -431,7 +431,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           CreateReduceOp(&result, g_name, cur_device_id);
           graph->Get<ShardedVarDevice>(kShardedVarDevice)
               .emplace(g_name, cur_device_id);
+          if (!is_dist_train) {
             bcast_var_name_set[cur_device_id].emplace(p_name);
+          }
           break;
         case BuildStrategy::ReduceStrategy::kAllReduce:
           if (IsSparseGradient(g_name)) {
@@ -461,7 +463,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   if ((use_gpu &&
        strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
       is_dist_train) {
-    // Insert BCast Ops
+    // always broadcast received parameters for distributed training
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
       for (auto &bcast_name : to_bcast_set) {
...
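Taken together, the two hunks change the broadcast policy: under the kReduce strategy each parameter is reduced onto a single device, so local multi-GPU training queues it in bcast_var_name_set and later inserts broadcast ops, while distributed training skips the local queueing and instead always broadcasts received parameters, per the new comment. Below is a minimal standalone sketch of that decision logic; the ReduceStrategy enum, device count, and variable name are assumptions standing in for state the real pass computes, not the actual PaddlePaddle API.

// Sketch only: the enum, device count, and variable name below are
// stand-ins for state the real MultiDevSSAGraphBuilder pass computes.
#include <cstdio>
#include <set>
#include <string>
#include <vector>

enum class ReduceStrategy { kAllReduce, kReduce };

int main() {
  const bool is_dist_train = false;              // hypothetical: local run
  const bool use_gpu = true;                     // hypothetical: GPU build
  const ReduceStrategy reduce = ReduceStrategy::kReduce;
  std::vector<std::set<std::string>> bcast_var_name_set(2);  // 2 devices

  // First hunk: only local training queues the reduced parameter here;
  // for distributed training the set is presumably filled with received
  // parameters elsewhere in the pass (outside the hunks shown).
  if (!is_dist_train) {
    bcast_var_name_set[0].emplace("fc_0.w_0");   // hypothetical var name
  }

  // Second hunk: broadcast ops are inserted for GPU + kReduce local
  // training, and always for distributed training.
  if ((use_gpu && reduce == ReduceStrategy::kReduce) || is_dist_train) {
    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
      for (const auto &name : bcast_var_name_set[dev_id]) {
        std::printf("insert broadcast op for %s on device %zu\n",
                    name.c_str(), dev_id);
      }
    }
  }
  return 0;
}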