From be1373dcf9c233b6a0c870232adb0e66df64f80c Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Sun, 1 Apr 2018 18:11:52 -0700
Subject: [PATCH] Polish

---
 .../framework/details/nccl_all_reduce_op_handle.h     |  2 +-
 paddle/fluid/framework/details/op_handle_base.h       |  4 +++-
 .../framework/details/threaded_ssa_graph_executor.cc  | 12 +++++++-----
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
index bb92625667..ad14a3c5cb 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
@@ -39,7 +39,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
 
   // Delay and buffer nccl_all_reduce together can significantly increase
   // performance. Disable this feature by returning false.
-  bool IsDelayedOp() override { return true; };
+  bool IsMultiDeviceTransfer() override { return true; };
 
  protected:
   void RunImpl() override;
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
index 54c2d627ff..d7a541ac4b 100644
--- a/paddle/fluid/framework/details/op_handle_base.h
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -55,7 +55,9 @@ class OpHandleBase {
 
   void AddOutput(VarHandleBase *out);
 
-  virtual bool IsDelayedOp() { return false; }
+  // Returns true if the Op involves data transfer across multiple devices
+  // that would likely block other computations.
+  virtual bool IsMultiDeviceTransfer() { return false; }
 
  protected:
   virtual void RunImpl() = 0;
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 32fc9100ab..65fbfb65e1 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -50,7 +50,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   // together since we currently cannot overlap computation and memcpy streams.
   // Should revisit it if overlapping is available.
   std::unordered_set<OpHandleBase *> delayed_ops;
-  std::unordered_set<OpHandleBase *> after_delayed_ops;
+  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
   std::unordered_set<VarHandleBase *> delayed_vars;
 
   auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
@@ -119,7 +119,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 
   auto run_all_ready_ops = [&] {
     for (auto *op : ready_ops) {
-      if (op->IsDelayedOp()) {
+      if (op->IsMultiDeviceTransfer()) {
         delayed_ops.insert(op);
         delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
         ready_vars.Extend(op->outputs_);
@@ -162,20 +162,22 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         --deps;
         if (deps == 0) {
           if (delayed_vars.find(ready_var) != delayed_vars.end()) {
-            after_delayed_ops.insert(op);
+            blocked_by_delayed_ops.insert(op);
           } else {
             ready_ops.insert(op);
           }
         }
       }
     }
+    // When there are no other ops to schedule, schedule buffered delayed
+    // ops and unblock other ops.
     if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
       RunDelayedOps(delayed_ops);
       delayed_ops.clear();
-      for (auto *op : after_delayed_ops) {
+      for (auto *op : blocked_by_delayed_ops) {
         ready_ops.insert(op);
       }
-      after_delayed_ops.clear();
+      blocked_by_delayed_ops.clear();
     }
     // Keep loop until all vars are ready.
   }
-- 
GitLab