提交 be1373dc 编写于 作者: X Xin Pan

Polish

上级 46f3a39e
......@@ -39,7 +39,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
// Delaying and buffering nccl_all_reduce ops together can significantly
// increase performance. Disable this feature by returning false.
bool IsDelayedOp() override { return true; };
bool IsMultiDeviceTransfer() override { return true; };
protected:
void RunImpl() override;
......
......@@ -55,7 +55,9 @@ class OpHandleBase {
void AddOutput(VarHandleBase *out);
virtual bool IsDelayedOp() { return false; }
// Whether the op involves data transfer across multiple devices, which
// will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; }
protected:
virtual void RunImpl() = 0;
......
......@@ -50,7 +50,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> after_delayed_ops;
std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
......@@ -119,7 +119,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) {
if (op->IsDelayedOp()) {
if (op->IsMultiDeviceTransfer()) {
delayed_ops.insert(op);
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
ready_vars.Extend(op->outputs_);
......@@ -162,20 +162,22 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
--deps;
if (deps == 0) {
if (delayed_vars.find(ready_var) != delayed_vars.end()) {
after_delayed_ops.insert(op);
blocked_by_delayed_ops.insert(op);
} else {
ready_ops.insert(op);
}
}
}
}
// When there are no other ops to schedule, schedule buffered delayed
// ops and unblock other ops.
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
RunDelayedOps(delayed_ops);
delayed_ops.clear();
for (auto *op : after_delayed_ops) {
for (auto *op : blocked_by_delayed_ops) {
ready_ops.insert(op);
}
after_delayed_ops.clear();
blocked_by_delayed_ops.clear();
}
// Keep loop until all vars are ready.
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册