未验证 提交 f936adbd 编写于 作者: M MRXLT 提交者: GitHub

fix adam (#27343)

* fix adam

* rmsprop support double
上级 d6b54de4
...@@ -143,4 +143,5 @@ http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) ...@@ -143,4 +143,5 @@ http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker); REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
rmsprop, ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, float>); rmsprop, ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -15,4 +15,5 @@ limitations under the License. */ ...@@ -15,4 +15,5 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
rmsprop, ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, float>); rmsprop, ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, double>);
...@@ -282,14 +282,13 @@ class Adam(Optimizer): ...@@ -282,14 +282,13 @@ class Adam(Optimizer):
for param in self._parameter_list: for param in self._parameter_list:
if not param.trainable: if not param.trainable:
continue continue
if hasattr( if param._grad_ivar() is not None:
param, "_is_sparse" grad_var = param._grad_ivar()
) and param._is_sparse and self.regularization is not None: if hasattr(grad_var, "_is_sparse") and grad_var._is_sparse(
) and self.regularization is not None:
raise RuntimeError( raise RuntimeError(
"Adam don't support weight_decay with sparse parameters, please set it to None." "Adam don't support weight_decay with sparse parameters, please set it to None."
) )
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var)) params_grads.append((param, grad_var))
optimize_ops = self._apply_optimize( optimize_ops = self._apply_optimize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册