未验证 提交 b0e7226e 编写于 作者: W wanghuancoder 提交者: GitHub

fix rmsprop_ yaml bug (#49026)

* fix rmsprop_ yaml bug
上级 77ed03d6
......@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from ..fluid import framework
from ..fluid.framework import in_dygraph_mode
from .optimizer import Optimizer
__all__ = []
......@@ -216,6 +219,22 @@ class RMSProp(Optimizer):
mean_grad_acc = self._get_accumulator(
self._mean_grad_acc_str, param_and_grad[0]
)
if in_dygraph_mode():
_C_ops.rmsprop_(
param_and_grad[0],
mean_square_acc,
param_and_grad[1],
momentum_acc,
self._create_param_lr(param_and_grad),
mean_grad_acc,
self._epsilon,
self._rho,
self._momentum,
self._centered,
)
return None
else:
rmsprop_op = block.append_op(
type=self.type,
inputs={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册