From 373f64986dd41bfacda4d408d138f25f6fa95c2c Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 8 Nov 2018 13:55:02 +0800 Subject: [PATCH] add comment and unit test test=develop --- python/paddle/fluid/optimizer.py | 9 +++++++++ .../fluid/tests/unittests/test_dist_transpiler.py | 12 +++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 9f089ef1e8..94d171d83d 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -249,6 +249,15 @@ class Optimizer(object): def _process_distribute_lookuptable(self, param_grads, loss, startup_program): + """ + Because distribute lookup table only support SGD optimizer for now, not support + other optimizer and regularization, so we should find the table parameter out, + and avoid to add regularization and other op for it, and add sgd optimize op + for it independently. + :param param_grads(list((Var, Var))): list of (param, grad) pair. + :param loss: the loss variable. + :param startup_program: the startup program + """ program = loss.block.program table_name = find_distributed_lookup_table(program) table_param = None diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 0957b97980..f08b6ac035 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -641,7 +641,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer, _ = self.get_trainer(config) + trainer, trainer_startup = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', @@ -655,6 +655,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): 'recv', 'concat' ] self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + startup_ops = [ + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'uniform_random', + 'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat', + 'fake_init' + ] + self.assertEqual([op.type for op in trainer_startup.blocks[0].ops], + startup_ops) class TestDistLookupTableSliceSize(TestDistLookupTableBase): -- GitLab