From c0e8dd875896038eb3dfb7ff1d21eb0c31aff4b0 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 7 Aug 2018 11:40:29 +0800 Subject: [PATCH] add unit test for dist lookup table --- .../tests/unittests/test_dist_transpiler.py | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 7107cec2bfc..5b8200df5e1 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -359,5 +359,79 @@ class TestL2DecayWithPiecewise(TranspilerTest): ["sum", "scale", "scale", "elementwise_add", "momentum"]) +class TestDistLookupTableBase(TranspilerTest): + def network_with_table(self, is_sparse, is_distributed): + def emb_pool(ids): + table_size = 1000 + emb_size = 64 + emb = fluid.layers.embedding( + input=ids, + size=[table_size, emb_size], + dtype='float32', + param_attr='shared_w', # share parameter + is_sparse=is_sparse, + is_distributed=is_distributed) + pool = fluid.layers.sequence_pool(input=emb, pool_type='average') + return pool + + title_ids = fluid.layers.data( + name='title_ids', shape=[1], dtype='int64', lod_level=1) + brand_ids = fluid.layers.data( + name='brand_ids', shape=[1], dtype='int64', lod_level=1) + title_emb = emb_pool(title_ids) + brand_emb = emb_pool(brand_ids) + fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1) + predict = fluid.layers.fc(input=fc0, + size=2, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(cost) + optimizer = fluid.optimizer.Adam(learning_rate=0.003) + optimizer.minimize(avg_cost) + + +class TestDistLookupTable(TestDistLookupTableBase): + def net_conf(self): + self.network_with_table(is_sparse=True, is_distributed=True) + + def transpiler_test_impl(self): + pserver1, startup1 = self.get_pserver(self.pserver1_ep) + + self.assertEqual(len(pserver1.blocks), 6) + # 0 listen_and_serv + # 1 optimize for fc_w or fc_b adam + self.assertEqual([op.type for op in pserver1.blocks[1].ops], + ["sum", "scale", "adam", "scale", "scale"]) + # 2 optimize for table sgd + self.assertEqual([op.type for op in pserver1.blocks[2].ops], + ["sum", "sgd"]) + # 3 prefetch -> lookup_sparse_table for data0 + self.assertEqual([op.type for op in pserver1.blocks[3].ops], + ["lookup_sparse_table"]) + # 4 prefetch -> lookup_sparse_table for data1 + self.assertEqual([op.type for op in pserver1.blocks[4].ops], + ["lookup_sparse_table"]) + # 5 save table + self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) + + trainer = self.get_trainer() + self.assertEqual(len(trainer.blocks), 1) + ops = [ + 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', + 'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', + 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', + 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', + 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', + 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', + 'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv', + 'fetch_barrier' + ] + self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + + if __name__ == "__main__": unittest.main() -- GitLab