diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 365fe8f1c5f879c56b3f10393dca62fa6078b255..40cb4b70d6bac780b7a094df1e95aca8ea9df692 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -51,7 +51,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     Returns:
         Tensor, the new value of v after updating.
     """
-    success = True
     if optim_filter:
         op_mul = P.Mul()
         op_square = P.Square()
@@ -81,8 +80,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
         next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
         next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
         next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
-        success = F.depend(success, next_param)
-        return success
+
+        return op_cast(next_param, F.dtype(param))
+    return gradient
 
 
 @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py
index aa33c3eeb9cdb0b6e86a39079b91c3d86155aa5b..0197167adfcdd5892b32d8ebda40c1c6f568be3a 100755
--- a/mindspore/nn/optim/lamb.py
+++ b/mindspore/nn/optim/lamb.py
@@ -104,11 +104,11 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
 
         next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
 
-        next_param = F.depend(next_param, F.assign(param, next_param))
-        next_param = F.depend(next_param, F.assign(m, next_m))
-        next_param = F.depend(next_param, F.assign(v, next_v))
+        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
+        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
+        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
 
-        return next_param
+        return op_cast(next_param, F.dtype(param))
     return gradient
 
 
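
Behaviorally, both hunks change `_update_run_op` to return the updated parameter, cast back to the parameter's stored dtype, instead of a boolean `success` flag, and to return the incoming gradient unchanged when `optim_filter` is false. Below is a minimal NumPy sketch of that return contract only; it is not MindSpore code. The function name `adam_update` and the plain-NumPy arithmetic are illustrative stand-ins for the fused `P.Mul`/`F.assign` graph ops that the diff actually touches.

# Sketch only: mirrors the new return contract of _update_run_op under the
# assumption of an AdamWeightDecay-style update (no bias correction),
# which is what the surrounding helper implements.
import numpy as np

def adam_update(beta1, beta2, eps, lr, weight_decay,
                param, m, v, gradient, decay_flag, optim_filter):
    if optim_filter:
        # Do the math in float32 regardless of the parameter's dtype.
        param_fp32 = param.astype(np.float32)
        grad_fp32 = gradient.astype(np.float32)

        next_m = beta1 * m + (1.0 - beta1) * grad_fp32
        next_v = beta2 * v + (1.0 - beta2) * np.square(grad_fp32)

        update = next_m / (np.sqrt(next_v) + eps)
        if decay_flag:
            update = update + weight_decay * param_fp32

        next_param = param_fp32 - lr * update

        # In-place assignment, casting each piece of state back to its dtype,
        # analogous to the F.assign(..., op_cast(...)) calls in the diff.
        param[...] = next_param.astype(param.dtype)
        m[...] = next_m.astype(m.dtype)
        v[...] = next_v.astype(v.dtype)

        # New contract: hand back the updated parameter in the parameter dtype.
        return next_param.astype(param.dtype)
    # Filtered-out parameters just pass the gradient through.
    return gradient

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    p = rng.standard_normal(4).astype(np.float16)
    m = np.zeros(4, dtype=np.float32)
    v = np.zeros(4, dtype=np.float32)
    g = rng.standard_normal(4).astype(np.float16)
    out = adam_update(0.9, 0.999, 1e-6, 1e-3, 0.01, p, m, v, g,
                      decay_flag=True, optim_filter=True)
    print(out.dtype)  # float16: the return value matches the parameter dtype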