diff --git a/paddle/fluid/operators/fake_init_op.cc b/paddle/fluid/operators/fake_init_op.cc index 2b3a54115656acef9f44975bbdac1dc76e5feb2d..05aa492410499c6ba384ae126b39793721966c98 100644 --- a/paddle/fluid/operators/fake_init_op.cc +++ b/paddle/fluid/operators/fake_init_op.cc @@ -68,9 +68,10 @@ class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) Tensor of specified shape will be filled " "with the specified value"); AddComment(R"DOC( -FakeInitBatchSizeLike Operator. +FakeInit Operator. -Init an op but not alloc tensor for it, it is used for distributed lookup table. +Init an variable but not alloc memory for it, it is used for init the +table parameter at trainer side in distributed lookup table. )DOC"); } diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 54a1c68a37f6929890aab697b48d621e6effb7d8..2b7227a6460f7f9e1ada54c28eb8a88230cce6e1 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -497,7 +497,7 @@ class TestDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer, _ = self.get_trainer() + trainer, trainer_startup = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', @@ -511,6 +511,16 @@ class TestDistLookupTable(TestDistLookupTableBase): ] self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + startup_ops = [ + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', + 'fetch_barrier', 'fake_init' + ] + self.assertEqual([op.type for op in trainer_startup.blocks[0].ops], + startup_ops) + class TestAsyncLocalLookupTable(TestDistLookupTableBase): def net_conf(self):