From e97905c5faca1f0a3cf3cdd7f8f48665315e9b8b Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Thu, 4 Feb 2021 13:40:00 +0800 Subject: [PATCH] improve performance of momentum (#30881) --- python/paddle/optimizer/momentum.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index 5fc5506ec3a..111b2720c86 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -104,7 +104,7 @@ class Momentum(Optimizer): raise ValueError("learning_rate is not set") if momentum is None: raise ValueError("momentum is not set") - predicate = lambda regular: isinstance(regular, L2DecayRegularizer) + predicate = lambda regular: isinstance(regular, (L2DecayRegularizer, float)) py_regular = None if predicate(weight_decay) else weight_decay super(Momentum, self).__init__( learning_rate=learning_rate, @@ -120,6 +120,9 @@ class Momentum(Optimizer): if (isinstance(weight_decay, L2DecayRegularizer)): self._regularization_method = "l2_decay" self._regularization_coeff = weight_decay._regularization_coeff + if (isinstance(weight_decay, float)): + self._regularization_method = "l2_decay" + self._regularization_coeff = weight_decay self._multi_precision = multi_precision self._rescale_grad = rescale_grad self._master_weights = {} -- GitLab