From 9ee1b7bc045091522c53cc69d174bebda979667e Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Thu, 13 Sep 2018 21:17:30 +0800
Subject: [PATCH] add some comments

---
 paddle/fluid/framework/details/multi_devices_graph_pass.cc | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 5781936cb3..7e7f1234c2 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -431,7 +431,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
                     CreateReduceOp(&result, g_name, cur_device_id);
                     graph->Get<ShardedVarDevice>(kShardedVarDevice)
                         .emplace(g_name, cur_device_id);
-                    bcast_var_name_set[cur_device_id].emplace(p_name);
+                    if (!is_dist_train) {
+                      bcast_var_name_set[cur_device_id].emplace(p_name);
+                    }
                     break;
                   case BuildStrategy::ReduceStrategy::kAllReduce:
                     if (IsSparseGradient(g_name)) {
@@ -461,7 +463,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   if ((use_gpu &&
        strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
       is_dist_train) {
-    // Insert BCast Ops
+    // always broadcast received parameters for distributed training
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
       for (auto &bcast_name : to_bcast_set) {
--
GitLab
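
For context, the control flow this patch changes can be sketched in isolation. Below is a minimal, self-contained C++ sketch, not the pass itself: the enum and flag values are simplified stand-ins for BuildStrategy and the pass's member state, the parameter name "fc_0.w_0" is hypothetical, and printing stands in for CreateBroadcastOp. It illustrates the patched behavior: when is_dist_train is set, a reduced parameter is no longer queued for a local broadcast, because the broadcast loop now always runs for distributed training and broadcasts the parameters received from the parameter server instead.

#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Simplified stand-in for BuildStrategy::ReduceStrategy.
enum class ReduceStrategy { kAllReduce, kReduce };

int main() {
  const bool use_gpu = true;
  const bool is_dist_train = true;  // trainer in a distributed job
  const ReduceStrategy reduce = ReduceStrategy::kReduce;
  const std::size_t num_devices = 2;

  // One set of parameter names per device, mirroring bcast_var_name_set.
  std::vector<std::unordered_set<std::string>> bcast_var_name_set(num_devices);

  // After the reduce op places gradient g on cur_device_id, the patch only
  // records its parameter for a later local broadcast when training is NOT
  // distributed; in distributed training the updated parameter comes back
  // from the parameter server and is broadcast on receipt instead.
  const std::string p_name = "fc_0.w_0";  // hypothetical parameter name
  const std::size_t cur_device_id = 0;
  if (!is_dist_train) {
    bcast_var_name_set[cur_device_id].emplace(p_name);
  }

  // Broadcast insertion point: GPU reduce-mode training broadcasts locally
  // reduced parameters; distributed training always enters this branch, but
  // the set built above stays empty for it in this sketch.
  if ((use_gpu && reduce == ReduceStrategy::kReduce) || is_dist_train) {
    for (std::size_t dev_id = 0; dev_id < bcast_var_name_set.size();
         ++dev_id) {
      for (const auto &bcast_name : bcast_var_name_set[dev_id]) {
        // Stand-in for CreateBroadcastOp(&result, bcast_name, dev_id).
        std::cout << "broadcast " << bcast_name << " from dev " << dev_id
                  << "\n";
      }
    }
  }
  return 0;
}

With is_dist_train set to true the sketch prints nothing, matching the patched behavior where the received parameters (not shown here) are what get broadcast; flipping it to false restores the pre-patch local broadcast of "fc_0.w_0" from device 0.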