提交 d0342f12 编写于 作者: Y Yu Yang

Simplify DelayOps Logic

上级 494c262a
...@@ -51,8 +51,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -51,8 +51,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// together since we currently cannot overlap computation and memcpy streams. // together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available. // Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var); pending_vars.insert(&var);
...@@ -122,24 +120,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -122,24 +120,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
InsertPendingOp(*op); InsertPendingOp(*op);
} }
auto run_all_ready_ops = [&] { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : ready_ops) { for (auto *op : set) {
if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
delayed_ops.insert(op);
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
ready_vars.Extend(op->outputs_);
continue;
}
running_ops_++; running_ops_++;
RunOp(&ready_vars, op); RunOp(&ready_vars, op);
} }
ready_ops.clear(); set.clear();
}; };
// Step 3. Execution // Step 3. Execution
while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) { while (!pending_vars.empty()) {
// 1. Run All Ready ops // 1. Run All Ready ops
run_all_ready_ops(); // Keep loop until all vars are ready.
//
// NOTE: DelayedOps have a lower priority. It will be scheduled after all
// ready_ops have been performed.
if (ready_ops.empty() && allow_op_delay_) {
run_all_ops(delayed_ops);
} else {
run_all_ops(ready_ops);
}
// 2. Find ready variable // 2. Find ready variable
bool timeout; bool timeout;
...@@ -160,29 +160,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -160,29 +160,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto &deps = pending_ops[op]; auto &deps = pending_ops[op];
--deps; --deps;
if (deps == 0) { if (deps == 0) {
if (delayed_vars.find(ready_var) != delayed_vars.end()) { if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
blocked_by_delayed_ops.insert(op); delayed_ops.insert(op);
} else { } else {
ready_ops.insert(op); ready_ops.insert(op);
} }
} }
} }
} }
// When there are no other ops to schedule, schedule buffered delayed
// ops and unblock other ops.
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
RunDelayedOps(delayed_ops);
delayed_ops.clear();
for (auto *op : blocked_by_delayed_ops) {
ready_ops.insert(op);
}
blocked_by_delayed_ops.clear();
}
// Keep loop until all vars are ready.
} }
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
PADDLE_ENFORCE(delayed_ops.empty());
PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
// Wait FetchOps. // Wait FetchOps.
if (!fetch_ops.empty()) { if (!fetch_ops.empty()) {
......
...@@ -206,18 +206,19 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -206,18 +206,19 @@ class TestParallelExecutorBase(unittest.TestCase):
feed_dict={}): feed_dict={}):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
startup.random_seed = 1 # Fix random seed
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
loss = method(use_feed=len(feed_dict) > 0) loss = method(use_feed=len(feed_dict) > 0)
adam = fluid.optimizer.Adam() adam = fluid.optimizer.Adam()
adam.minimize(loss) adam.minimize(loss)
if memory_opt: if memory_opt:
fluid.memory_optimize(main) fluid.memory_optimize(main)
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
startup_exe = fluid.Executor(place) startup_exe = fluid.Executor(place)
startup_exe.run(startup) startup_exe.run(startup)
exe = fluid.ParallelExecutor(True, loss_name=loss.name) exe = fluid.ParallelExecutor(
True, loss_name=loss.name, allow_op_delay=allow_op_delay)
if batch_size is not None: if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count() batch_size *= fluid.core.get_cuda_device_count()
begin = time.time() begin = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册