From 7715cdc69bd74a9ffd9b8477886986a5fcf95d06 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 17 Aug 2018 11:25:07 +0800 Subject: [PATCH] add unit test in dist transpiler --- .../tests/unittests/test_dist_transpiler.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 2528f82ff88..f8026ae9425 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -393,13 +393,14 @@ class TestDistLookupTableBase(TranspilerTest): def network_with_table(self, is_sparse, is_distributed): self.table_size = 1000 self.emb_size = 64 + self.lookup_table_name = 'shared_w' def emb_pool(ids): emb = fluid.layers.embedding( input=ids, size=[self.table_size, self.emb_size], dtype='float32', - param_attr='shared_w', # share parameter + param_attr=self.lookup_table_name, # share parameter is_sparse=is_sparse, is_distributed=is_distributed) pool = fluid.layers.sequence_pool(input=emb, pool_type='average') @@ -572,7 +573,7 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase): def transpiler_test_impl(self): config = fluid.DistributeTranspilerConfig() - pserver1, startup1 = self.get_pserver(self.pserver1_ep, config) + pserver1, _ = self.get_pserver(self.pserver1_ep, config) self.assertTrue(self.transpiler.has_distributed_lookup_table) lookup_table_var = pserver1.global_block().vars[ @@ -582,6 +583,23 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase): self.assertEqual(row_size, calc_row_size) +class TestDistArgsInProgram(TestDistLookupTableBase): + def net_conf(self): + self.network_with_table(is_sparse=True, is_distributed=True) + + def transpiler_test_impl(self): + config = fluid.DistributeTranspilerConfig() + + trainer, _ = self.get_trainer() + + self.assertTrue(trainer._is_distributed) + self.assertTrue(trainer._is_chief) + self.assertEqual(trainer._distributed_lookup_table, + self.lookup_table_name) + self.assertEqual(trainer._endpoints, + [self.pserver1_ep, self.pserver2_ep]) + + class TestRMSPropOptimizer(TranspilerTest): def net_conf(self): x = fluid.layers.data(name='x', shape=[1000], dtype='float32') -- GitLab