diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index 5fc5506ec3a3294c18d7592446f2a18b03a96dce..111b2720c86687f141169c963ec69f48cfe01df2 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -104,7 +104,7 @@ class Momentum(Optimizer): raise ValueError("learning_rate is not set") if momentum is None: raise ValueError("momentum is not set") - predicate = lambda regular: isinstance(regular, L2DecayRegularizer) + predicate = lambda regular: isinstance(regular, (L2DecayRegularizer, float)) py_regular = None if predicate(weight_decay) else weight_decay super(Momentum, self).__init__( learning_rate=learning_rate, @@ -120,6 +120,9 @@ class Momentum(Optimizer): if (isinstance(weight_decay, L2DecayRegularizer)): self._regularization_method = "l2_decay" self._regularization_coeff = weight_decay._regularization_coeff + if (isinstance(weight_decay, float)): + self._regularization_method = "l2_decay" + self._regularization_coeff = weight_decay self._multi_precision = multi_precision self._rescale_grad = rescale_grad self._master_weights = {}