diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index c2561fa2bf3aa0992f32ed1295c6640d55e6322b..a00458ea068dd703d2c7f362511ed08bc212d2a8 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -64,6 +64,12 @@ static void SplitTensorAndMoveTensorToScopes(
   }
 }
 
+void WaitOnPlace(const platform::Place place) {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(place);
+  dev_ctx.Wait();
+}
+
 void WaitOnPlaces(const std::vector<platform::Place> places) {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
 
@@ -214,6 +220,7 @@ class ParallelDoGradOp : public framework::OperatorBase {
         auto &tensor_to_merge = sub_scopes[i]->FindVar(s)->Get<LoDTensor>();
         if (!(places[i] == places[0])) {
           framework::Copy(tensor_to_merge, places[0], tmp);
+          WaitOnPlace(places[0]);
         } else {
           tmp->ShareDataWith(tensor_to_merge);
         }
@@ -222,12 +229,13 @@ class ParallelDoGradOp : public framework::OperatorBase {
             "sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}},
             framework::AttributeMap{});
         sum_op->Run(*sub_scopes[0], places[0]);
-        WaitOnPlaces(places);
+        WaitOnPlace(places[0]);
       }
 
       VLOG(3) << result;
       framework::Copy(result, place, scope.FindVar(s)->GetMutable<LoDTensor>());
     }
+    WaitOnPlaces(places);
   }
 };
 
diff --git a/python/paddle/v2/fluid/tests/test_parallel_op.py b/python/paddle/v2/fluid/tests/test_parallel_op.py
index 45196ef6fe5230a6b3ead0b64fee09492188da82..d36f7d07ac381d835618af4b420525ff0e607651 100644
--- a/python/paddle/v2/fluid/tests/test_parallel_op.py
+++ b/python/paddle/v2/fluid/tests/test_parallel_op.py
@@ -15,9 +15,6 @@
 import unittest
 import paddle.v2.fluid as fluid
 import numpy
-import sys
-# TODO(dzhwinter): get places op check need to be enhanced.
-sys.exit(0)
 
 
 class BaseParallelForTest(unittest.TestCase):
@@ -165,13 +162,13 @@ class ParallelOpTest(BaseParallelForTest):
             feed={
                 'img': numpy.random.random(size=(51, 784)).astype('float32')
             },
-            fetch='fc1.w@GRAD')
+            fetch=['fc1.w@GRAD'])
 
     def test_fc_with_tiny_data(self):
         self.run_test(
             callback=ParallelOpTest.__network__,
             feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
-            fetch='fc1.w@GRAD')
+            fetch=['fc1.w@GRAD'])
 
 
 if __name__ == '__main__':