diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
index bb926256676761f107ab386ff5815fabbd088664..ad14a3c5cb4625fa121cad2daed389c441e78771 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
@@ -39,7 +39,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
 
   // Delay and buffer nccl_all_reduce together can significantly increase
   // performance. Disable this feature by returning false.
-  bool IsDelayedOp() override { return true; };
+  bool IsMultiDeviceTransfer() override { return true; };
 
  protected:
   void RunImpl() override;
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
index 54c2d627ff30461a93e765241652d42e204c6689..d7a541ac4bb83625060db337446d03a1afda3ed0 100644
--- a/paddle/fluid/framework/details/op_handle_base.h
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -55,7 +55,9 @@ class OpHandleBase {
 
   void AddOutput(VarHandleBase *out);
 
-  virtual bool IsDelayedOp() { return false; }
+  // If the Op involves data transfer of multiple devices that
+  // will likely block other computations.
+  virtual bool IsMultiDeviceTransfer() { return false; }
 
  protected:
   virtual void RunImpl() = 0;
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 32fc9100ab13c3a442c5a24e9ad750328c50de13..65fbfb65e1656c4bdd8f54cdeeeec9da3e94decf 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -50,7 +50,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   // together since we currently cannot overlap computation and memcpy streams.
   // Should revisit it if overlapping is available.
   std::unordered_set<OpHandleBase *> delayed_ops;
-  std::unordered_set<OpHandleBase *> after_delayed_ops;
+  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
   std::unordered_set<VarHandleBase *> delayed_vars;
 
   auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
@@ -119,7 +119,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 
   auto run_all_ready_ops = [&] {
     for (auto *op : ready_ops) {
-      if (op->IsDelayedOp()) {
+      if (op->IsMultiDeviceTransfer()) {
         delayed_ops.insert(op);
         delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
         ready_vars.Extend(op->outputs_);
@@ -162,20 +162,22 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         --deps;
         if (deps == 0) {
           if (delayed_vars.find(ready_var) != delayed_vars.end()) {
-            after_delayed_ops.insert(op);
+            blocked_by_delayed_ops.insert(op);
          } else {
             ready_ops.insert(op);
           }
         }
       }
     }
+    // When there are no other ops to schedule, schedule buffered delayed
+    // ops and unblock other ops.
     if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
       RunDelayedOps(delayed_ops);
       delayed_ops.clear();
-      for (auto *op : after_delayed_ops) {
+      for (auto *op : blocked_by_delayed_ops) {
         ready_ops.insert(op);
       }
-      after_delayed_ops.clear();
+      blocked_by_delayed_ops.clear();
     }
     // Keep loop until all vars are ready.
   }
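
For context on the renamed hook: any ready op whose IsMultiDeviceTransfer() returns true is buffered in delayed_ops instead of running immediately, ops that consume its outputs are parked in blocked_by_delayed_ops, and only when no other op can make progress are the buffered transfers flushed together and the parked ops released. Below is a minimal, self-contained C++ sketch of that policy. Op, Schedule(), consumers, and pending_inputs are hypothetical stand-ins introduced only for illustration; the real executor tracks readiness through VarHandleBase and running_ops_ and parks consumers in blocked_by_delayed_ops as soon as their dependencies (including the still-unproduced delayed outputs) are marked ready, whereas this sketch folds that bookkeeping into the flush step.

// Minimal sketch of the delayed-transfer scheduling policy (illustrative
// only; Op and Schedule() are hypothetical stand-ins, not Paddle types).
#include <unordered_set>
#include <utility>
#include <vector>

struct Op {
  bool is_multi_device_transfer = false;  // what IsMultiDeviceTransfer() reports
  std::vector<Op *> consumers;            // ops that read this op's outputs
  int pending_inputs = 0;                 // producers that have not run yet
  void Run() {}                           // computation or NCCL all-reduce
};

void Schedule(const std::vector<Op *> &sources) {
  std::unordered_set<Op *> ready_ops(sources.begin(), sources.end());
  std::unordered_set<Op *> delayed_ops;             // buffered transfers
  std::unordered_set<Op *> blocked_by_delayed_ops;  // released after the flush

  while (!ready_ops.empty() || !delayed_ops.empty()) {
    std::unordered_set<Op *> next_ready;
    for (Op *op : ready_ops) {
      if (op->is_multi_device_transfer) {
        delayed_ops.insert(op);  // buffer instead of running right away
        continue;
      }
      op->Run();
      for (Op *c : op->consumers) {
        if (--c->pending_inputs == 0) next_ready.insert(c);
      }
    }
    ready_ops = std::move(next_ready);

    // Only when nothing else can make progress, flush the buffered transfers
    // together and release the ops that were waiting on their outputs.
    if (ready_ops.empty() && !delayed_ops.empty()) {
      for (Op *op : delayed_ops) {
        op->Run();
        for (Op *c : op->consumers) {
          if (--c->pending_inputs == 0) blocked_by_delayed_ops.insert(c);
        }
      }
      delayed_ops.clear();
      ready_ops.insert(blocked_by_delayed_ops.begin(),
                       blocked_by_delayed_ops.end());
      blocked_by_delayed_ops.clear();
    }
  }
}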