From 7e4bdf6add9e16e9f7e55aebf30ea5f64862843b Mon Sep 17 00:00:00 2001
From: lilei
Date: Mon, 15 Jun 2020 11:29:08 +0800
Subject: [PATCH] proximal_ada_grad optimizer

---
 mindspore/nn/optim/proximal_ada_grad.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py
index e7e11b21c..6e6196a21 100644
--- a/mindspore/nn/optim/proximal_ada_grad.py
+++ b/mindspore/nn/optim/proximal_ada_grad.py
@@ -31,15 +31,13 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, gradient, weight, accum):
     return success
 
 
-def _check_param_value(accum, learning_rate, l1, l2, use_locking, prim_name=None):
+def _check_param_value(accum, l1, l2, use_locking, prim_name=None):
     """Check inputs param."""
     validator.check_value_type("accum", accum, [float], prim_name)
-    validator.check_value_type("learning_rate", learning_rate, [float], prim_name)
     validator.check_value_type("l1", l1, [float], prim_name)
     validator.check_value_type("l2", l2, [float], prim_name)
     validator.check_value_type("use_locking", use_locking, [bool], prim_name)
     validator.check_number_range("accum", accum, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
-    validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
     validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
     validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
 
@@ -79,10 +77,10 @@ class ProximalAdagrad(Optimizer):
 
     def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0,
                  use_locking=False, loss_scale=1.0, weight_decay=0.0):
-        super(ProximalAdagrad, self).__init__(0.0, params, weight_decay, loss_scale)
+        super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale)
         if self.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param_value(accum, learning_rate, l1, l2, use_locking, self.cls_name)
+        _check_param_value(accum, l1, l2, use_locking, self.cls_name)
        self.accum = self.parameters.clone(prefix="accum", init=accum)
         self.l1 = Tensor(l1, mstype.float32)
         self.l2 = Tensor(l2, mstype.float32)
-- 
GitLab
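
A minimal usage sketch of what this patch changes, assuming the MindSpore nn
API of this era (TinyNet is a made-up example network, not part of the patch):
ProximalAdagrad now forwards learning_rate to the Optimizer base class instead
of passing a hard-coded 0.0 and range-checking learning_rate locally as a bare
float, so the base class's learning-rate handling applies uniformly.

import mindspore.nn as nn

# A tiny made-up network so the snippet is self-contained; any nn.Cell
# with trainable parameters would do.
class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Dense(4, 2)

    def construct(self, x):
        return self.fc(x)

net = TinyNet()

# After this patch, learning_rate=0.001 reaches the Optimizer base class
# constructor; before, the base class was built with a fixed 0.0 and the
# value was validated only inside _check_param_value.
opt = nn.ProximalAdagrad(net.trainable_params(), accum=0.1,
                         learning_rate=0.001, l1=0.0, l2=0.0)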