Commit 1e1b6622 authored by Yancey1989

update by comment

Parent b084dfab
@@ -46,11 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif

 void AllReduceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);

   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
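
A plausible reason for this simplification (inferred from C++ scoping rules; the commit message does not state it): in the old code each RecordEvent was declared inside an if/else branch, so the RAII object was destroyed at the closing brace of its branch and never covered the body of RunImpl. The sketch below uses a hypothetical ScopedTimer as a stand-in for platform::RecordEvent to show the difference:

#include <cstdio>

// Hypothetical stand-in for platform::RecordEvent: starts a timed region on
// construction and ends it on destruction (RAII).
struct ScopedTimer {
  ScopedTimer() { std::puts("event start"); }
  ~ScopedTimer() { std::puts("event stop"); }
};

void DoWork() { std::puts("actual op work"); }

// Old shape: the timer dies at the end of the if-branch, so "event stop"
// prints before DoWork() even begins -- the op body is never measured.
void RunImplOld(bool have_ctx) {
  if (have_ctx) {
    ScopedTimer t;
  }  // t destroyed here
  DoWork();
}

// New shape: one function-scoped timer spans the whole body.
void RunImplNew() {
  ScopedTimer t;
  DoWork();
}  // t destroyed here, after DoWork()

int main() {
  RunImplOld(true);
  RunImplNew();
  return 0;
}

The same branch-scoped pattern is removed from BroadcastOpHandle::RunImpl and DataBalanceOpHandle::RunImpl in the two hunks below.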
@@ -22,11 +22,7 @@ namespace framework {
 namespace details {

 void BroadcastOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);

   if (places_.size() == 1) return;

@@ -87,12 +87,6 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
 }

 void DataBalanceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
-
   PADDLE_ENFORCE_GT(places_.size(), 1,
                     "Data balance can only be enabled when the number of "
                     "places to run larger than 1.");
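
All three op-handle hunks also drop the nullptr fallback, so the one-line form implicitly assumes dev_ctxes_ is non-empty: calling begin()->second on an empty map is undefined behavior. A minimal sketch of that precondition, using hypothetical stand-in types rather than the real place and device-context classes:

#include <cassert>
#include <map>

// Hypothetical stand-ins; in the real code dev_ctxes_ maps places to
// device-context pointers, as the begin()->second usage in the diff shows.
struct Place {
  int id;
  bool operator<(const Place &o) const { return id < o.id; }
};
struct DeviceContext {};

struct OpHandleSketch {
  std::map<Place, DeviceContext *> dev_ctxes_;

  DeviceContext *FirstCtx() const {
    // begin()->second on an empty map is undefined behavior, so the
    // simplified RunImpl relies on at least one registered context.
    assert(!dev_ctxes_.empty() && "a device context must be registered");
    return dev_ctxes_.begin()->second;
  }
};

int main() {
  DeviceContext ctx;
  OpHandleSketch handle;
  handle.dev_ctxes_[Place{0}] = &ctx;
  return handle.FirstCtx() == &ctx ? 0 : 1;
}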
@@ -431,10 +431,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           CreateReduceOp(&result, g_name, cur_device_id);
           graph->Get<ShardedVarDevice>(kShardedVarDevice)
               .emplace(g_name, cur_device_id);
-          if (!is_dist_train) {
-            // will send gradients directly when distributed training
-            bcast_var_name_set[cur_device_id].emplace(p_name);
-          }
+          bcast_var_name_set[cur_device_id].emplace(p_name);
           break;
         case BuildStrategy::ReduceStrategy::kAllReduce:
           if (IsSparseGradient(g_name)) {
@@ -461,9 +458,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   use_gpu = nccl_ctxs_ != nullptr;
 #endif

-  if ((use_gpu &&
-       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
-      is_dist_train) {
+  if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce &&
+      !is_dist_train) {
     // Insert BCast Ops
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
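
Together with the previous hunk, the !is_dist_train guard moves from the per-variable emplace into the outer condition that decides whether to insert broadcast ops at all (per the removed comment, gradients are sent directly in distributed training). The hypothetical helpers below restate that outer predicate before and after the change; they are illustration only, not functions in the codebase, and reduce_mode abbreviates strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce:

#include <cstdio>

// Hypothetical helpers restating the BCast-insertion predicate from the
// hunk above; not part of the codebase.
static bool InsertBCastBefore(bool use_gpu, bool reduce_mode, bool dist) {
  return (use_gpu && reduce_mode) || dist;
}

static bool InsertBCastAfter(bool use_gpu, bool reduce_mode, bool dist) {
  return use_gpu && reduce_mode && !dist;
}

int main() {
  // The two predicates agree whenever dist is false and differ whenever it
  // is true: distributed training no longer triggers the broadcast pass.
  for (int g = 0; g < 2; ++g)
    for (int r = 0; r < 2; ++r)
      for (int d = 0; d < 2; ++d)
        std::printf("gpu=%d reduce=%d dist=%d -> before=%d after=%d\n", g, r,
                    d, InsertBCastBefore(g, r, d), InsertBCastAfter(g, r, d));
  return 0;
}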