From cb40c33137c7361c70742551a9a8f85c291fe640 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 26 Mar 2018 17:01:39 +0800 Subject: [PATCH] Update unittest --- .../details/computation_op_handle.cc | 2 +- .../details/threaded_ssa_graph_executor.cc | 29 ++++++++ .../details/threaded_ssa_graph_executor.h | 3 + .../tests/unittests/test_parallel_executor.py | 68 ++++++++++--------- 4 files changed, 70 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 348b944cf9..53ab8eb775 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -33,7 +33,7 @@ void ComputationOpHandle::RunImpl() { } } - op_->Run(*scope_, place_); + op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get(), place_); } std::string ComputationOpHandle::Name() const { return op_->Type(); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index f609395d40..dcb611b8b1 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -112,6 +112,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ready_ops.clear(); }; + // Create local scopes. + for (auto &scope : local_scopes_) { + auto &local_scope = scope->NewScope(); + *scope->Var("@TMP_SCOPE@")->GetMutable() = &local_scope; + } + // Step 3. Execution while (!pending_vars.empty()) { // 1. Run All Ready ops @@ -156,9 +162,32 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Keep loop until all vars are ready. } + ++computation_count_; + + auto sync_computation = [&] { + computation_count_ = 0; + // Wait All computational streams + for (auto p : this->places_) { + platform::DeviceContextPool::Instance().Get(p)->Wait(); + } + + // NOTE: the temp scope can be dropped lazily if needed. + // Drop tmp scopes; + for (auto &scope : local_scopes_) { + auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable(); + kid = nullptr; + scope->DropKids(); + } + }; + // Wait FetchOps. for (auto &fetch_op : fetch_ops) { fetch_op.WaitAndMergeCPUTensors(); + sync_computation(); + } + + if (computation_count_ == max_async_computation) { + sync_computation(); } return fetch_data; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 5b099c18c9..805f80e7f7 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -48,6 +48,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { platform::DeviceContextPool fetch_ctxs_; const bool use_event_; std::unique_ptr exception_; + + size_t computation_count_{0}; + size_t max_async_computation{100}; }; } // namespace details diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index d5d2275e4d..106320839c 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -178,7 +178,32 @@ def SE_ResNeXt152(): return loss -class ParallelExecutor(unittest.TestCase): +class TestParallelExecutorBase(unittest.TestCase): + def check_network_convergence(self, method, memory_opt=True, iter=10): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = method() + adam = fluid.optimizer.Adam() + adam.minimize(loss) + if memory_opt: + fluid.memory_optimize(main) + + exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True) + first_loss, = exe.run([loss.name]) + first_loss = numpy.array(first_loss) + + for i in xrange(iter): + exe.run([]) + + last_loss, = exe.run([loss.name]) + last_loss = numpy.array(last_loss) + + print first_loss, last_loss + self.assertGreater(first_loss[0], last_loss[0]) + + +class TestMNIST(TestParallelExecutorBase): @classmethod def setUpClass(cls): # Convert mnist to recordio file @@ -195,6 +220,16 @@ class ParallelExecutor(unittest.TestCase): fluid.recordio_writer.convert_reader_to_recordio_file( './mnist.recordio', reader, feeder) + def test_simple_fc(self): + self.check_network_convergence(simple_fc_net) + + def test_batchnorm_fc(self): + self.check_network_convergence(fc_with_batchnorm) + + +class TestResnet(TestParallelExecutorBase): + @classmethod + def setUpClass(cls): with fluid.program_guard(fluid.Program(), fluid.Program()): reader = paddle.batch(flowers.train(), batch_size=4) feeder = fluid.DataFeeder( @@ -208,34 +243,5 @@ class ParallelExecutor(unittest.TestCase): fluid.recordio_writer.convert_reader_to_recordio_file( "./flowers.recordio", reader, feeder) - def test_simple_fc(self): - self.check_network_convergence(simple_fc_net) - - def test_batchnorm_fc(self): - self.check_network_convergence(fc_with_batchnorm) - - def check_network_convergence(self, method, memory_opt=True, iter=10): - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - loss = method() - adam = fluid.optimizer.Adam() - adam.minimize(loss) - if memory_opt: - fluid.memory_optimize(main) - - exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True) - first_loss, = exe.run([loss.name]) - first_loss = numpy.array(first_loss) - - for i in xrange(iter): - exe.run([]) - - last_loss, = exe.run([loss.name]) - last_loss = numpy.array(last_loss) - - print first_loss, last_loss - self.assertGreater(first_loss[0], last_loss[0]) - def test_resnet(self): - self.check_network_convergence(SE_ResNeXt152, iter=20) + self.check_network_convergence(SE_ResNeXt152, iter=200) -- GitLab