diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 9f089ef1e8bb2a42c378155d8e27f25a1af4adc4..94d171d83d8a14119fcd1f27db6b484fe8b067b9 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -249,6 +249,15 @@ class Optimizer(object): def _process_distribute_lookuptable(self, param_grads, loss, startup_program): + """ + Because distribute lookup table only support SGD optimizer for now, not support + other optimizer and regularization, so we should find the table parameter out, + and avoid to add regularization and other op for it, and add sgd optimize op + for it independently. + :param param_grads(list((Var, Var))): list of (param, grad) pair. + :param loss: the loss variable. + :param startup_program: the startup program + """ program = loss.block.program table_name = find_distributed_lookup_table(program) table_param = None diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 0957b97980e7234f480546698b8eee9da4876dab..f08b6ac035e63168691864403bc0607f2718d2fd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -641,7 +641,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer, _ = self.get_trainer(config) + trainer, trainer_startup = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', @@ -655,6 +655,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): 'recv', 'concat' ] self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + startup_ops = [ + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'uniform_random', + 'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat', + 'fake_init' + ] + self.assertEqual([op.type for op in trainer_startup.blocks[0].ops], + startup_ops) class TestDistLookupTableSliceSize(TestDistLookupTableBase):