提交 be1373dc 编写于 作者: X Xin Pan

Polish

上级 46f3a39e
...@@ -39,7 +39,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -39,7 +39,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
// Delaying and buffering nccl_all_reduce ops together can significantly increase // Delaying and buffering nccl_all_reduce ops together can significantly increase
// performance. Disable this feature by returning false. // performance. Disable this feature by returning false.
bool IsDelayedOp() override { return true; }; bool IsMultiDeviceTransfer() override { return true; };
protected: protected:
void RunImpl() override; void RunImpl() override;
......
...@@ -55,7 +55,9 @@ class OpHandleBase { ...@@ -55,7 +55,9 @@ class OpHandleBase {
void AddOutput(VarHandleBase *out); void AddOutput(VarHandleBase *out);
virtual bool IsDelayedOp() { return false; } // Returns true if the Op transfers data across multiple devices, which
// will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; }
protected: protected:
virtual void RunImpl() = 0; virtual void RunImpl() = 0;
......
...@@ -50,7 +50,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -50,7 +50,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// together since we currently cannot overlap computation and memcpy streams. // together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available. // Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> after_delayed_ops; std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars; std::unordered_set<VarHandleBase *> delayed_vars;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
...@@ -119,7 +119,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -119,7 +119,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] { auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) { for (auto *op : ready_ops) {
if (op->IsDelayedOp()) { if (op->IsMultiDeviceTransfer()) {
delayed_ops.insert(op); delayed_ops.insert(op);
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end()); delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
ready_vars.Extend(op->outputs_); ready_vars.Extend(op->outputs_);
...@@ -162,20 +162,22 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -162,20 +162,22 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
--deps; --deps;
if (deps == 0) { if (deps == 0) {
if (delayed_vars.find(ready_var) != delayed_vars.end()) { if (delayed_vars.find(ready_var) != delayed_vars.end()) {
after_delayed_ops.insert(op); blocked_by_delayed_ops.insert(op);
} else { } else {
ready_ops.insert(op); ready_ops.insert(op);
} }
} }
} }
} }
// When there are no other ops to schedule, schedule buffered delayed
// ops and unblock other ops.
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) { if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
RunDelayedOps(delayed_ops); RunDelayedOps(delayed_ops);
delayed_ops.clear(); delayed_ops.clear();
for (auto *op : after_delayed_ops) { for (auto *op : blocked_by_delayed_ops) {
ready_ops.insert(op); ready_ops.insert(op);
} }
after_delayed_ops.clear(); blocked_by_delayed_ops.clear();
} }
// Keep looping until all vars are ready.
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册