From 759748dc066aac977ddabc4df6b4112cde7c1e91 Mon Sep 17 00:00:00 2001 From: fary86 Date: Tue, 8 Sep 2020 16:04:02 +0800 Subject: [PATCH] Fix bugs of adam and lamb optimizer --- mindspore/nn/optim/adam.py | 6 +++--- mindspore/nn/optim/lamb.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 365fe8f1c..40cb4b70d 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -51,7 +51,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d Returns: Tensor, the new value of v after updating. """ - success = True if optim_filter: op_mul = P.Mul() op_square = P.Square() @@ -81,8 +80,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) - success = F.depend(success, next_param) - return success + + return op_cast(next_param, F.dtype(param)) + return gradient @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index aa33c3eeb..0197167ad 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -104,11 +104,11 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) - next_param = F.depend(next_param, F.assign(param, next_param)) - next_param = F.depend(next_param, F.assign(m, next_m)) - next_param = F.depend(next_param, F.assign(v, next_v)) + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) + next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) - return next_param + return op_cast(next_param, F.dtype(param)) return gradient -- GitLab