From 68aeb4e7e9caa87469ffbcd39af2e25bcff35710 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 26 Oct 2018 22:25:58 +0800 Subject: [PATCH] add fake init test in test_dist_transpiler --- paddle/fluid/operators/fake_init_op.cc | 5 +++-- .../fluid/tests/unittests/test_dist_transpiler.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/fake_init_op.cc b/paddle/fluid/operators/fake_init_op.cc index 2b3a5411565..05aa4924104 100644 --- a/paddle/fluid/operators/fake_init_op.cc +++ b/paddle/fluid/operators/fake_init_op.cc @@ -68,9 +68,10 @@ class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) Tensor of specified shape will be filled " "with the specified value"); AddComment(R"DOC( -FakeInitBatchSizeLike Operator. +FakeInit Operator. -Init an op but not alloc tensor for it, it is used for distributed lookup table. +Init an variable but not alloc memory for it, it is used for init the +table parameter at trainer side in distributed lookup table. )DOC"); } diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 54a1c68a37f..2b7227a6460 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -497,7 +497,7 @@ class TestDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer, _ = self.get_trainer() + trainer, trainer_startup = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', @@ -511,6 +511,16 @@ class TestDistLookupTable(TestDistLookupTableBase): ] self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + startup_ops = [ + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', + 'fetch_barrier', 'fake_init' + ] + self.assertEqual([op.type for op in trainer_startup.blocks[0].ops], + startup_ops) + class TestAsyncLocalLookupTable(TestDistLookupTableBase): def net_conf(self): -- GitLab