未验证 提交 b0e7226e 编写于 作者: W wanghuancoder 提交者: GitHub

fix rmsprop_ yaml bug (#49026)

* fix rmsprop_ yaml bug
上级 77ed03d6
......@@ -1646,7 +1646,7 @@
kernel :
func : rmsprop {dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense}
rmsprop_dense_param_sparse_grad {dense, dense, selected_rows, dense, dense, dense -> dense, dense, dense, dense}
optional : mean_grad
optional : mean_grad
inplace : (param -> param_out), (moment -> moment_out), (mean_square -> mean_square_out), (mean_grad -> mean_grad_out)
- op : rnn
......
......@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from ..fluid import framework
from ..fluid.framework import in_dygraph_mode
from .optimizer import Optimizer
__all__ = []
......@@ -216,32 +219,48 @@ class RMSProp(Optimizer):
mean_grad_acc = self._get_accumulator(
self._mean_grad_acc_str, param_and_grad[0]
)
rmsprop_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"Moment": momentum_acc,
"MeanSquare": mean_square_acc,
"MeanGrad": mean_grad_acc,
"LearningRate": self._create_param_lr(param_and_grad),
},
outputs={
"ParamOut": param_and_grad[0],
"MomentOut": momentum_acc,
"MeanSquareOut": mean_square_acc,
"MeanGradOut": mean_grad_acc,
},
attrs={
"epsilon": self._epsilon,
"decay": self._rho,
"momentum": self._momentum,
"centered": self._centered,
},
stop_gradient=True,
)
return rmsprop_op
if in_dygraph_mode():
_C_ops.rmsprop_(
param_and_grad[0],
mean_square_acc,
param_and_grad[1],
momentum_acc,
self._create_param_lr(param_and_grad),
mean_grad_acc,
self._epsilon,
self._rho,
self._momentum,
self._centered,
)
return None
else:
rmsprop_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"Moment": momentum_acc,
"MeanSquare": mean_square_acc,
"MeanGrad": mean_grad_acc,
"LearningRate": self._create_param_lr(param_and_grad),
},
outputs={
"ParamOut": param_and_grad[0],
"MomentOut": momentum_acc,
"MeanSquareOut": mean_square_acc,
"MeanGradOut": mean_grad_acc,
},
attrs={
"epsilon": self._epsilon,
"decay": self._rho,
"momentum": self._momentum,
"centered": self._centered,
},
stop_gradient=True,
)
return rmsprop_op
def _update_param_group(self, parameters):
self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册