提交 2ae32f0b 编写于 作者: Q qiaolongfei

revert the change of api

上级 1b690213
...@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase): ...@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
self.origin_prog = main.clone() self.origin_prog = main.clone()
return main return main
def get_trainer(self, config=None): def get_trainer(self, config=None, sync_mode=True):
t = self._transpiler_instance(config) t = self._transpiler_instance(config, sync_mode)
return t.get_trainer_program() return t.get_trainer_program()
def get_pserver(self, ep, config=None): def get_pserver(self, ep, config=None, sync_mode=True):
t = self._transpiler_instance(config) t = self._transpiler_instance(config, sync_mode)
pserver = t.get_pserver_program(ep) pserver = t.get_pserver_program(ep)
startup = t.get_startup_program(ep, pserver) startup = t.get_startup_program(ep, pserver)
return pserver, startup return pserver, startup
def _transpiler_instance(self, config=None): def _transpiler_instance(self, config=None, sync_mode=True):
if not self.transpiler: if not self.transpiler:
main = self.get_main_program() main = self.get_main_program()
self.transpiler = fluid.DistributeTranspiler(config=config) self.transpiler = fluid.DistributeTranspiler(config=config)
...@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase): ...@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
self.trainer_id, self.trainer_id,
program=main, program=main,
pservers=self.pserver_eps, pservers=self.pserver_eps,
trainers=self.trainers) trainers=self.trainers,
sync_mode=sync_mode)
return self.transpiler return self.transpiler
...@@ -470,8 +471,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): ...@@ -470,8 +471,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.sync_mode = False pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
self.assertEqual(len(pserver1.blocks), 3) self.assertEqual(len(pserver1.blocks), 3)
# 0 listen_and_serv # 0 listen_and_serv
...@@ -503,9 +503,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -503,9 +503,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.sync_mode = False
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config) pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 6) self.assertEqual(len(pserver1.blocks), 6)
# 0 listen_and_serv # 0 listen_and_serv
...@@ -525,7 +524,6 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -525,7 +524,6 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
trainer = self.get_trainer(config) trainer = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
print([op.type for op in trainer.blocks[0].ops])
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', 'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
......
...@@ -124,7 +124,6 @@ class DistributeTranspilerConfig(object): ...@@ -124,7 +124,6 @@ class DistributeTranspilerConfig(object):
slice_var_up = True slice_var_up = True
split_method = None split_method = None
min_block_size = 8192 min_block_size = 8192
sync_mode = True
class DistributeTranspiler(object): class DistributeTranspiler(object):
...@@ -198,7 +197,7 @@ class DistributeTranspiler(object): ...@@ -198,7 +197,7 @@ class DistributeTranspiler(object):
program = default_main_program() program = default_main_program()
self.origin_program = program self.origin_program = program
self.trainer_num = trainers self.trainer_num = trainers
self.sync_mode = sync_mode and self.config.sync_mode self.sync_mode = sync_mode
self.trainer_id = trainer_id self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",") pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints self.pserver_endpoints = pserver_endpoints
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册