From 62fd3209e1b58cf8d5a28dcc2f752d79150a8a04 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 9 Mar 2020 10:03:12 +0800 Subject: [PATCH] Fix dgc param regularizer, test=develop (#22888) --- python/paddle/fluid/optimizer.py | 33 ++++++++++++------- .../tests/unittests/test_dgc_optimizer.py | 16 ++++++++- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 90dee2b11e..6b42b40403 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1149,21 +1149,23 @@ class DGCMomentumOptimizer(Optimizer): self._num_trainers = num_trainers self._clip_norm = local_grad_clip_norm * (num_trainers**-0.5) - self._get_dgc_regularization_param() + self.regular_type, self.regular_coeff = self._get_regularization_param( + self.regularization) - def _get_dgc_regularization_param(self): - self.regular_coeff = 0.0 - self.regular_type = 0 + def _get_regularization_param(self, regularization): + regular_type = 0 + regular_coeff = 0.0 - if self.regularization is not None: - self.regular_coeff = self.regularization._regularization_coeff + if regularization is not None: + regular_coeff = regularization._regularization_coeff from .regularizer import L1Decay, L2Decay - if isinstance(self.regularization, L1Decay): - self.regular_type = 1 - elif isinstance(self.regularization, L2Decay): - self.regular_type = 2 + if isinstance(regularization, L1Decay): + regular_type = 1 + elif isinstance(regularization, L2Decay): + regular_type = 2 else: assert False, 'regularization must be None|L1Decay|L2Deacy' + return regular_type, regular_coeff def _is_use_dgc(self, param_var, grad_var): var_numel = abs(reduce(lambda x, y: x * y, param_var.shape)) @@ -1364,6 +1366,13 @@ class DGCMomentumOptimizer(Optimizer): block = framework.default_main_program().global_block() op_maker = core.op_proto_and_checker_maker + regular_type = self.regular_type + regular_coeff = self.regular_coeff + # The regularizer of the Parameters have higher priority + if param_var.regularizer is not None: + regular_type, regular_coeff = self._get_regularization_param( + param_var.regularizer) + dgc_op = block.append_op( type="dgc", inputs={ @@ -1388,8 +1397,8 @@ class DGCMomentumOptimizer(Optimizer): "use_nesterov": self._use_nesterov, "rampup_begin_step": float(self._rampup_begin_step), "rampup_step": float(self._rampup_step), - "regular_coeff": float(self.regular_coeff), - "regular_type": int(self.regular_type), + "regular_coeff": float(regular_coeff), + "regular_type": int(regular_type), }, stop_gradient=True) diff --git a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py index 07dda4d594..521e498176 100644 --- a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py @@ -44,7 +44,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase): shape=[dims[0], dims[1]], lod_level=0, name="mul.x", - optimize_attr={'learning_rate': 1.1}) + optimize_attr={'learning_rate': 1.1}, + regularizer=None if regularization is not None else + regularizer.L2DecayRegularizer(2e-4)) mul_y = block.create_var( dtype="float32", shape=[dims[1], dims[2]], @@ -102,6 +104,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase): self.assertEqual(init_ops[0].type, "fill_constant") self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate) + # check dgc op regularization coeff + train_ops = program.global_block().ops + for op in train_ops: + if op.type == "dgc": + coeff = 2e-4 if regularization is None else 1e-4 + self.assertAlmostEqual(op.attr('regular_coeff'), coeff) + print("dgc regular_coeff=" + str(coeff)) + with open("test_dgc_optimizer_" + name + ".log", "w") as f: program_to_code(program, fout=f) @@ -116,6 +126,10 @@ class TestDGCMomentumOptimizer(unittest.TestCase): name="dgc_momentum", regularization=regularizer.L2Decay(1e-4)) + # check param.regularizer in dgc + self.check_dgc_momentum_optimizer( + dims=[16, 1024, 8], name="dgc_momentum") + if __name__ == '__main__': unittest.main() -- GitLab