From cb40c33137c7361c70742551a9a8f85c291fe640 Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Mon, 26 Mar 2018 17:01:39 +0800
Subject: [PATCH] Run ops in per-iteration temporary scopes and restructure
 ParallelExecutor unit tests

ComputationOpHandle now runs each op inside a temporary child scope that
the ThreadedSSAGraphExecutor publishes through the @TMP_SCOPE@ variable of
every local scope. The temporary scopes are created at the start of each
Run() call and dropped once the computation streams have been synchronized,
which happens when fetch ops are waited on or after max_async_computation
runs without a fetch. The ParallelExecutor unit test is split into a
reusable TestParallelExecutorBase with TestMNIST and TestResnet cases, and
the SE_ResNeXt152 test now runs 200 iterations.
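
For reference, the per-run lifecycle introduced here looks roughly like the
sketch below (illustrative only; local_scope, op and place stand in for the
executor's members, and only calls that already appear in this diff are used):

    // Before the iteration: create a child scope and publish it via @TMP_SCOPE@.
    Scope &tmp_scope = local_scope->NewScope();
    *local_scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &tmp_scope;

    // During the iteration: ComputationOpHandle::RunImpl resolves the child
    // scope and runs the wrapped operator inside it.
    op->Run(*local_scope->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place);

    // After the device streams are synchronized: reset the pointer and drop
    // the child scope.
    platform::DeviceContextPool::Instance().Get(place)->Wait();
    *local_scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = nullptr;
    local_scope->DropKids();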
---
 .../details/computation_op_handle.cc          |  2 +-
 .../details/threaded_ssa_graph_executor.cc    | 29 ++++++++
 .../details/threaded_ssa_graph_executor.h     |  3 +
 .../tests/unittests/test_parallel_executor.py | 68 ++++++++++---------
 4 files changed, 70 insertions(+), 32 deletions(-)

diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc
index 348b944cf92..53ab8eb7754 100644
--- a/paddle/fluid/framework/details/computation_op_handle.cc
+++ b/paddle/fluid/framework/details/computation_op_handle.cc
@@ -33,7 +33,7 @@ void ComputationOpHandle::RunImpl() {
     }
   }
 
-  op_->Run(*scope_, place_);
+  op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_);
 }
 
 std::string ComputationOpHandle::Name() const { return op_->Type(); }
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index f609395d40f..dcb611b8b1c 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -112,6 +112,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     ready_ops.clear();
   };
 
+  // Create a temporary child scope under each local scope and record it in @TMP_SCOPE@.
+  for (auto &scope : local_scopes_) {
+    auto &local_scope = scope->NewScope();
+    *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
+  }
+
   // Step 3. Execution
   while (!pending_vars.empty()) {
     // 1. Run All Ready ops
@@ -156,9 +162,32 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     // Keep loop until all vars are ready.
   }
 
+  ++computation_count_;
+
+  auto sync_computation = [&] {
+    computation_count_ = 0;
+    // Wait for all computation streams to finish.
+    for (auto p : this->places_) {
+      platform::DeviceContextPool::Instance().Get(p)->Wait();
+    }
+
+    // NOTE: the temp scope can be dropped lazily if needed.
+    // Drop the temporary scopes.
+    for (auto &scope : local_scopes_) {
+      auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
+      kid = nullptr;
+      scope->DropKids();
+    }
+  };
+
   // Wait FetchOps.
   for (auto &fetch_op : fetch_ops) {
     fetch_op.WaitAndMergeCPUTensors();
+    sync_computation();
+  }
+
+  if (computation_count_ == max_async_computation) {
+    sync_computation();
   }
 
   return fetch_data;
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index 5b099c18c92..805f80e7f73 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -48,6 +48,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   platform::DeviceContextPool fetch_ctxs_;
   const bool use_event_;
   std::unique_ptr<platform::EnforceNotMet> exception_;
+
+  size_t computation_count_{0};
+  size_t max_async_computation{100};
 };
 
 }  // namespace details
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
index d5d2275e4d9..106320839c6 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -178,7 +178,32 @@ def SE_ResNeXt152():
     return loss
 
 
-class ParallelExecutor(unittest.TestCase):
+class TestParallelExecutorBase(unittest.TestCase):
+    def check_network_convergence(self, method, memory_opt=True, iter=10):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = method()
+            adam = fluid.optimizer.Adam()
+            adam.minimize(loss)
+            if memory_opt:
+                fluid.memory_optimize(main)
+
+            exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
+            first_loss, = exe.run([loss.name])
+            first_loss = numpy.array(first_loss)
+
+            for i in xrange(iter):
+                exe.run([])
+
+            last_loss, = exe.run([loss.name])
+            last_loss = numpy.array(last_loss)
+
+            print first_loss, last_loss
+            self.assertGreater(first_loss[0], last_loss[0])
+
+
+class TestMNIST(TestParallelExecutorBase):
     @classmethod
     def setUpClass(cls):
         # Convert mnist to recordio file
@@ -195,6 +220,16 @@ class ParallelExecutor(unittest.TestCase):
             fluid.recordio_writer.convert_reader_to_recordio_file(
                 './mnist.recordio', reader, feeder)
 
+    def test_simple_fc(self):
+        self.check_network_convergence(simple_fc_net)
+
+    def test_batchnorm_fc(self):
+        self.check_network_convergence(fc_with_batchnorm)
+
+
+class TestResnet(TestParallelExecutorBase):
+    @classmethod
+    def setUpClass(cls):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             reader = paddle.batch(flowers.train(), batch_size=4)
             feeder = fluid.DataFeeder(
@@ -208,34 +243,5 @@ class ParallelExecutor(unittest.TestCase):
             fluid.recordio_writer.convert_reader_to_recordio_file(
                 "./flowers.recordio", reader, feeder)
 
-    def test_simple_fc(self):
-        self.check_network_convergence(simple_fc_net)
-
-    def test_batchnorm_fc(self):
-        self.check_network_convergence(fc_with_batchnorm)
-
-    def check_network_convergence(self, method, memory_opt=True, iter=10):
-        main = fluid.Program()
-        startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            loss = method()
-            adam = fluid.optimizer.Adam()
-            adam.minimize(loss)
-            if memory_opt:
-                fluid.memory_optimize(main)
-
-            exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
-            first_loss, = exe.run([loss.name])
-            first_loss = numpy.array(first_loss)
-
-            for i in xrange(iter):
-                exe.run([])
-
-            last_loss, = exe.run([loss.name])
-            last_loss = numpy.array(last_loss)
-
-            print first_loss, last_loss
-            self.assertGreater(first_loss[0], last_loss[0])
-
     def test_resnet(self):
-        self.check_network_convergence(SE_ResNeXt152, iter=20)
+        self.check_network_convergence(SE_ResNeXt152, iter=200)
-- 
GitLab