提交 2ae32f0b 编写于 作者: Q qiaolongfei

revert the change of api

上级 1b690213
......@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
self.origin_prog = main.clone()
return main
def get_trainer(self, config=None):
t = self._transpiler_instance(config)
def get_trainer(self, config=None, sync_mode=True):
t = self._transpiler_instance(config, sync_mode)
return t.get_trainer_program()
def get_pserver(self, ep, config=None):
t = self._transpiler_instance(config)
def get_pserver(self, ep, config=None, sync_mode=True):
t = self._transpiler_instance(config, sync_mode)
pserver = t.get_pserver_program(ep)
startup = t.get_startup_program(ep, pserver)
return pserver, startup
def _transpiler_instance(self, config=None):
def _transpiler_instance(self, config=None, sync_mode=True):
if not self.transpiler:
main = self.get_main_program()
self.transpiler = fluid.DistributeTranspiler(config=config)
......@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
self.trainer_id,
program=main,
pservers=self.pserver_eps,
trainers=self.trainers)
trainers=self.trainers,
sync_mode=sync_mode)
return self.transpiler
......@@ -470,8 +471,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = False
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 3)
# 0 listen_and_serv
......@@ -503,9 +503,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
config.sync_mode = False
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 6)
# 0 listen_and_serv
......@@ -525,7 +524,6 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
trainer = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1)
print([op.type for op in trainer.blocks[0].ops])
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
......
......@@ -124,7 +124,6 @@ class DistributeTranspilerConfig(object):
slice_var_up = True
split_method = None
min_block_size = 8192
sync_mode = True
class DistributeTranspiler(object):
......@@ -198,7 +197,7 @@ class DistributeTranspiler(object):
program = default_main_program()
self.origin_program = program
self.trainer_num = trainers
self.sync_mode = sync_mode and self.config.sync_mode
self.sync_mode = sync_mode
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册