From f936adbd2d9e2a34dd4797ef1769e2c38e8cfae2 Mon Sep 17 00:00:00 2001
From: MRXLT
Date: Mon, 21 Sep 2020 11:16:34 +0800
Subject: [PATCH] fix adam (#27343)

* fix adam

* rmsprop support double
---
 paddle/fluid/operators/optimizers/rmsprop_op.cc |  3 ++-
 paddle/fluid/operators/optimizers/rmsprop_op.cu |  3 ++-
 python/paddle/optimizer/adam.py                 | 11 +++++------
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cc b/paddle/fluid/operators/optimizers/rmsprop_op.cc
index 99d1156ee6..eeee008cdc 100644
--- a/paddle/fluid/operators/optimizers/rmsprop_op.cc
+++ b/paddle/fluid/operators/optimizers/rmsprop_op.cc
@@ -143,4 +143,5 @@ http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
 REGISTER_OP_CPU_KERNEL(
-    rmsprop, ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, float>);
+    rmsprop, ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::RmspropOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cu b/paddle/fluid/operators/optimizers/rmsprop_op.cu
index 8b17d6a020..bf11ee6867 100644
--- a/paddle/fluid/operators/optimizers/rmsprop_op.cu
+++ b/paddle/fluid/operators/optimizers/rmsprop_op.cu
@@ -15,4 +15,5 @@ limitations under the License. */
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
-    rmsprop, ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, float>);
+    rmsprop, ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::RmspropOpKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py
index 708aaa788f..24cebf8e6e 100644
--- a/python/paddle/optimizer/adam.py
+++ b/python/paddle/optimizer/adam.py
@@ -282,14 +282,13 @@ class Adam(Optimizer):
         for param in self._parameter_list:
             if not param.trainable:
                 continue
-            if hasattr(
-                    param, "_is_sparse"
-            ) and param._is_sparse and self.regularization is not None:
-                raise RuntimeError(
-                    "Adam don't support weight_decay with sparse parameters, please set it to None."
-                )
             if param._grad_ivar() is not None:
                 grad_var = param._grad_ivar()
+                if hasattr(grad_var, "_is_sparse") and grad_var._is_sparse(
+                ) and self.regularization is not None:
+                    raise RuntimeError(
+                        "Adam don't support weight_decay with sparse parameters, please set it to None."
+                    )
                 params_grads.append((param, grad_var))
 
         optimize_ops = self._apply_optimize(
--
GitLab
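
Usage note (not part of the patch above): with this change, Adam's dygraph step() inspects the gradient variable rather than the parameter, so weight decay is rejected only when the current gradient actually arrives as a sparse tensor, e.g. from an embedding layer configured for sparse updates. The sketch below is a minimal illustration of the condition the relocated check guards against, assuming Paddle 2.x dygraph APIs (paddle.nn.Embedding(..., sparse=True), paddle.optimizer.Adam(weight_decay=...)); everything outside the diff is an illustrative assumption, not code from this commit.

# Minimal sketch, not part of this commit: demonstrates the case the relocated
# check in Adam.step() rejects. Assumes Paddle 2.x dygraph mode (the default);
# the sparse=True argument and the exact runtime behavior are assumptions.
import paddle

# An embedding with sparse updates produces a sparse gradient on backward.
emb = paddle.nn.Embedding(1000, 64, sparse=True)

# weight_decay populates self.regularization; combined with a sparse gradient,
# this is the configuration the patched check raises RuntimeError for.
# Leaving weight_decay at its default of None avoids the error.
opt = paddle.optimizer.Adam(learning_rate=1e-3,
                            parameters=emb.parameters(),
                            weight_decay=0.01)

ids = paddle.randint(0, 1000, shape=[8, 16])
loss = emb(ids).mean()
loss.backward()

try:
    opt.step()  # expected to raise: weight_decay together with a sparse gradient
except RuntimeError as err:
    print(err)

Compared with the removed check, which tested param._is_sparse without call parentheses and keyed off the parameter rather than its gradient, the new placement calls grad_var._is_sparse() and only rejects weight decay when the gradient produced in the current step is actually sparse.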