From b0e7226e6a307b9ed99e6a5cedef34f5fcffb2ed Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 13 Dec 2022 20:28:32 +0800 Subject: [PATCH] fix rmsprop_ yaml bug (#49026) * fix rmsprop_ yaml bug --- paddle/phi/api/yaml/legacy_ops.yaml | 2 +- python/paddle/optimizer/rmsprop.py | 69 ++++++++++++++++++----------- 2 files changed, 45 insertions(+), 26 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index f228a70857e..ba446289f7e 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1646,7 +1646,7 @@ kernel : func : rmsprop {dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense} rmsprop_dense_param_sparse_grad {dense, dense, selected_rows, dense, dense, dense -> dense, dense, dense, dense} - optional : mean_grad + optional : mean_grad inplace : (param -> param_out), (moment -> moment_out), (mean_square -> mean_square_out), (mean_grad -> mean_grad_out) - op : rnn diff --git a/python/paddle/optimizer/rmsprop.py b/python/paddle/optimizer/rmsprop.py index c6e78f538bc..460c5e00ed2 100644 --- a/python/paddle/optimizer/rmsprop.py +++ b/python/paddle/optimizer/rmsprop.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from paddle import _C_ops + from ..fluid import framework +from ..fluid.framework import in_dygraph_mode from .optimizer import Optimizer __all__ = [] @@ -216,32 +219,48 @@ class RMSProp(Optimizer): mean_grad_acc = self._get_accumulator( self._mean_grad_acc_str, param_and_grad[0] ) - rmsprop_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "Moment": momentum_acc, - "MeanSquare": mean_square_acc, - "MeanGrad": mean_grad_acc, - "LearningRate": self._create_param_lr(param_and_grad), - }, - outputs={ - "ParamOut": param_and_grad[0], - "MomentOut": momentum_acc, - "MeanSquareOut": mean_square_acc, - "MeanGradOut": mean_grad_acc, - }, - attrs={ - "epsilon": self._epsilon, - "decay": self._rho, - "momentum": self._momentum, - "centered": self._centered, - }, - stop_gradient=True, - ) - return rmsprop_op + if in_dygraph_mode(): + _C_ops.rmsprop_( + param_and_grad[0], + mean_square_acc, + param_and_grad[1], + momentum_acc, + self._create_param_lr(param_and_grad), + mean_grad_acc, + self._epsilon, + self._rho, + self._momentum, + self._centered, + ) + return None + else: + rmsprop_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Moment": momentum_acc, + "MeanSquare": mean_square_acc, + "MeanGrad": mean_grad_acc, + "LearningRate": self._create_param_lr(param_and_grad), + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": momentum_acc, + "MeanSquareOut": mean_square_acc, + "MeanGradOut": mean_grad_acc, + }, + attrs={ + "epsilon": self._epsilon, + "decay": self._rho, + "momentum": self._momentum, + "centered": self._centered, + }, + stop_gradient=True, + ) + + return rmsprop_op def _update_param_group(self, parameters): self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) -- GitLab