未验证 提交 445ca3db 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #12607 from jacquesqiao/add-unit-test-for-async-transpile

Add unit test for async transpile
......@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
self.origin_prog = main.clone()
return main
def get_trainer(self, config=None):
t = self._transpiler_instance(config)
def get_trainer(self, config=None, sync_mode=True):
t = self._transpiler_instance(config, sync_mode)
return t.get_trainer_program()
def get_pserver(self, ep, config=None):
t = self._transpiler_instance(config)
def get_pserver(self, ep, config=None, sync_mode=True):
t = self._transpiler_instance(config, sync_mode)
pserver = t.get_pserver_program(ep)
startup = t.get_startup_program(ep, pserver)
return pserver, startup
def _transpiler_instance(self, config=None):
def _transpiler_instance(self, config=None, sync_mode=True):
if not self.transpiler:
main = self.get_main_program()
self.transpiler = fluid.DistributeTranspiler(config=config)
......@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
self.trainer_id,
program=main,
pservers=self.pserver_eps,
trainers=self.trainers)
trainers=self.trainers,
sync_mode=sync_mode)
return self.transpiler
......@@ -464,5 +465,76 @@ class TestDistLookupTable(TestDistLookupTableBase):
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
class TestAsyncLocalLookupTable(TestDistLookupTableBase):
def net_conf(self):
self.network_with_table(is_sparse=True, is_distributed=False)
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 3)
# 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
["adam", "scale", "scale"])
# 2 optimize for table adam
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
["adam", "scale", "scale"])
trainer = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1)
ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
'fill_constant', 'mean_grad', 'cross_entropy_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
'recv', 'recv', 'concat'
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
class TestAsyncDistLookupTable(TestDistLookupTableBase):
def net_conf(self):
self.network_with_table(is_sparse=True, is_distributed=True)
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 6)
# 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
["adam", "scale", "scale"])
# 2 optimize for table sgd
self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
# 3 prefetch -> lookup_sparse_table for data0
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["lookup_sparse_table"])
# 4 prefetch -> lookup_sparse_table for data1
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
["lookup_sparse_table"])
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1)
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
'sum', 'split_ids', 'send', 'recv', 'recv'
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
if __name__ == "__main__":
unittest.main()
......@@ -293,6 +293,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
if self.sync_mode:
program.global_block().append_op(
type="fetch_barrier",
inputs={},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册