From 8da651057cdcdf266b8205a023473f4c8c1e7189 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Thu, 9 Aug 2018 14:46:28 +0800
Subject: [PATCH] add TestAsyncDistLookupTable

---
 .../tests/unittests/test_dist_transpiler.py   | 41 +++++++++++++++++++
 .../fluid/transpiler/distribute_transpiler.py | 20 +++++----
 2 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index b24036326d5..91b80ca75f7 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -464,5 +464,46 @@ class TestDistLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestAsyncDistLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+        config.sync_mode = False
+
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["lookup_sparse_table"])
+        # 4 prefetch -> lookup_sparse_table for data1
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        print([op.type for op in trainer.blocks[0].ops])
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
+            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_ids', 'send', 'recv', 'recv'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index b0a100e1db3..be96c20db1a 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -124,6 +124,7 @@ class DistributeTranspilerConfig(object):
     slice_var_up = True
     split_method = None
     min_block_size = 8192
+    sync_mode = True
 
 
 class DistributeTranspiler(object):
@@ -197,7 +198,7 @@ class DistributeTranspiler(object):
             program = default_main_program()
         self.origin_program = program
         self.trainer_num = trainers
-        self.sync_mode = sync_mode
+        self.sync_mode = sync_mode and self.config.sync_mode
         self.trainer_id = trainer_id
         pserver_endpoints = pservers.split(",")
         self.pserver_endpoints = pserver_endpoints
@@ -293,14 +294,15 @@ class DistributeTranspiler(object):
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
-        program.global_block().append_op(
-            type="fetch_barrier",
-            inputs={},
-            outputs={},
-            attrs={
-                "endpoints": pserver_endpoints,
-                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
-            })
+        if self.sync_mode:
+            program.global_block().append_op(
+                type="fetch_barrier",
+                inputs={},
+                outputs={},
+                attrs={
+                    "endpoints": pserver_endpoints,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
 
         for varname, splited_var in self.param_var_mapping.iteritems():
             if len(splited_var) <= 1:
--
GitLab
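
A minimal usage sketch, not part of the patch above: it illustrates how the new DistributeTranspilerConfig.sync_mode flag combines with the sync_mode argument of transpile(); per this patch, fetch_barrier ops are only emitted when both are True. The endpoints and trainer count are placeholder values, and the transpile() / get_trainer_program() / get_pserver_program() calls are assumed from the Fluid API of this period.

    import paddle.fluid as fluid

    config = fluid.DistributeTranspilerConfig()
    config.sync_mode = False  # request async training; fetch_barrier ops are then skipped

    t = fluid.DistributeTranspiler(config=config)
    # placeholder cluster layout: one trainer, two parameter servers
    t.transpile(trainer_id=0,
                pservers="127.0.0.1:6174,127.0.0.1:6175",
                trainers=1)

    trainer_prog = t.get_trainer_program()
    pserver_prog = t.get_pserver_program("127.0.0.1:6174")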