提交 759748dc 编写于 作者: F fary86

Fix bugs of adam and lamb optimizer

上级 5a63dac0
...@@ -51,7 +51,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d ...@@ -51,7 +51,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
Returns: Returns:
Tensor, the new value of v after updating. Tensor, the new value of v after updating.
""" """
success = True
if optim_filter: if optim_filter:
op_mul = P.Mul() op_mul = P.Mul()
op_square = P.Square() op_square = P.Square()
...@@ -81,8 +80,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d ...@@ -81,8 +80,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
success = F.depend(success, next_param)
return success return op_cast(next_param, F.dtype(param))
return gradient
@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
......
...@@ -104,11 +104,11 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v ...@@ -104,11 +104,11 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(param, next_param)) next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, next_m)) next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, next_v)) next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_param return op_cast(next_param, F.dtype(param))
return gradient return gradient
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册