diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 1f37e4e91d55035846f14b6e2bf019c78e304eea..abd372126848c5779cf7d989dc03e421dc94b1cf 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase): self.origin_prog = main.clone() return main - def get_trainer(self, config=None): - t = self._transpiler_instance(config) + def get_trainer(self, config=None, sync_mode=True): + t = self._transpiler_instance(config, sync_mode) return t.get_trainer_program() - def get_pserver(self, ep, config=None): - t = self._transpiler_instance(config) + def get_pserver(self, ep, config=None, sync_mode=True): + t = self._transpiler_instance(config, sync_mode) pserver = t.get_pserver_program(ep) startup = t.get_startup_program(ep, pserver) return pserver, startup - def _transpiler_instance(self, config=None): + def _transpiler_instance(self, config=None, sync_mode=True): if not self.transpiler: main = self.get_main_program() self.transpiler = fluid.DistributeTranspiler(config=config) @@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase): self.trainer_id, program=main, pservers=self.pserver_eps, - trainers=self.trainers) + trainers=self.trainers, + sync_mode=sync_mode) return self.transpiler @@ -470,8 +471,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): def transpiler_test_impl(self): config = fluid.DistributeTranspilerConfig() - config.sync_mode = False - pserver1, startup1 = self.get_pserver(self.pserver1_ep, config) + pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False) self.assertEqual(len(pserver1.blocks), 3) # 0 listen_and_serv @@ -503,9 +503,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): def transpiler_test_impl(self): config = fluid.DistributeTranspilerConfig() - config.sync_mode = False - pserver1, startup1 = self.get_pserver(self.pserver1_ep, config) + pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False) self.assertEqual(len(pserver1.blocks), 6) # 0 listen_and_serv @@ -525,7 +524,6 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): trainer = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) - print([op.type for op in trainer.blocks[0].ops]) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index be96c20db1ac7c5d1e667ee1e0aa213f12bc20b0..820509bbcc4679cadb06554476798c76e6869eb5 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -124,7 +124,6 @@ class DistributeTranspilerConfig(object): slice_var_up = True split_method = None min_block_size = 8192 - sync_mode = True class DistributeTranspiler(object): @@ -198,7 +197,7 @@ class DistributeTranspiler(object): program = default_main_program() self.origin_program = program self.trainer_num = trainers - self.sync_mode = sync_mode and self.config.sync_mode + self.sync_mode = sync_mode self.trainer_id = trainer_id pserver_endpoints = pservers.split(",") self.pserver_endpoints = pserver_endpoints