未验证 提交 f936adbd 编写于 作者: M MRXLT 提交者: GitHub

fix adam (#27343)

* fix adam

* rmsprop support double
上级 d6b54de4
......@@ -143,4 +143,5 @@ http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
REGISTER_OP_CPU_KERNEL(
rmsprop, ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, float>);
rmsprop, ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -15,4 +15,5 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
rmsprop, ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, float>);
rmsprop, ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -282,14 +282,13 @@ class Adam(Optimizer):
for param in self._parameter_list:
if not param.trainable:
continue
if hasattr(
param, "_is_sparse"
) and param._is_sparse and self.regularization is not None:
raise RuntimeError(
"Adam don't support weight_decay with sparse parameters, please set it to None."
)
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
if hasattr(grad_var, "_is_sparse") and grad_var._is_sparse(
) and self.regularization is not None:
raise RuntimeError(
"Adam don't support weight_decay with sparse parameters, please set it to None."
)
params_grads.append((param, grad_var))
optimize_ops = self._apply_optimize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册