未验证 提交 bd06be00 编写于 作者: R RedContritio 提交者: GitHub

support auto generate for op rmsprop optimizer (#52709)

上级 9246b93c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class RmspropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
AddInput("MeanSquare",
"(Tensor, default Tensor<float>)"
" The mean square value that gets updated.");
AddInput("MeanGrad",
"(Tensor, default Tensor<float>)"
" The moving average of gradient")
.AsDispensable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1.");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter.");
AddInput("Moment",
"(Tensor, default Tensor<float>) The moment that gets updated.");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
AddOutput("MomentOut", "(Tensor) Output updated moment.");
AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value.");
AddOutput("MeanGradOut",
"(Tensor) Output moving average of gradient updated value.");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1.0e-10f);
AddAttr<float>("decay",
"(float, default 0.9) "
"Discounting factor for coming gradient.")
.SetDefault(0.9f);
AddAttr<float>("momentum", "(float, default 0.0) Constant value.")
.SetDefault(0.0f);
AddAttr<bool>("centered", "(bool, default false) use centered rmsprop.")
.SetDefault(false);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddComment(R"DOC(
Rmsprop Optimizer.
$$
MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad \\
MomentOut = momentum * Moment +
\frac{LearningRate * Grad}{\sqrt{MeanSquareOut + epsilon}} \\
ParamOut = Param - MomentOut
$$
if centered is true:
mean_grad = decay * mean_square{t-1} + (1-decay) * gradient
mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
mom = momentum * mom{t-1} + learning_rate * g_t /
sqrt(mean_square - mean_grad**2 + epsilon)
param -= mom
The original slides that proposed Rmsprop: Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(rmsprop,
RmspropInferShapeFunctor,
PD_INFER_META(phi::RmspropInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(rmsprop,
ops::RmspropOp,
ops::RmspropOpMaker,
RmspropInferShapeFunctor);
......@@ -14,8 +14,7 @@ register_unity_group(
proximal_gd_op.cc
decayed_adagrad_op.cc
adadelta_op.cc
dpsgd_op.cc
rmsprop_op.cc)
dpsgd_op.cc)
register_unity_group(
cu
ftrl_op.cu
......@@ -27,5 +26,4 @@ register_unity_group(
adam_op.cu
decayed_adagrad_op.cu
adadelta_op.cu
lamb_op.cu
rmsprop_op.cu)
lamb_op.cu)
......@@ -1070,18 +1070,6 @@
intermediate : xshape
backward: reshape_grad
- op : rmsprop_
args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon, float decay, float momentum, bool centered, bool multi_precision)
output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_out)
infer_meta :
func : RmspropInferMeta
kernel :
func : rmsprop {dense, dense, dense, dense, dense, dense, dense-> dense, dense, dense, dense, dense}
rmsprop_dense_param_sparse_grad {dense, dense, selected_rows, dense, dense, dense, dense-> dense, dense, dense, dense, dense}
data_type : param
optional : mean_grad, master_param
inplace : (param -> param_out), (moment -> moment_out), (mean_square -> mean_square_out), (mean_grad -> mean_grad_out), (master_param->master_param_out)
- op : rnn
args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false)
output: Tensor(out), Tensor(dropout_state_out), Tensor[](state){pre_state.size()}, Tensor(reserve)
......
......@@ -1790,6 +1790,12 @@
support_tensor : true
manual_signature : [reverse]
- op : rmsprop_
inputs :
{param: Param, mean_square: MeanSquare, mean_grad: MeanGrad, learning_rate: LearningRate, grad: Grad, moment: Moment, master_param: MasterParam}
outputs :
{param_out: ParamOut, moment_out: MomentOut, mean_square_out: MeanSquareOut, mean_grad_out: MeanGradOut, master_param_outs: MasterParamOut}
- op : roll
backward : roll_grad
inputs :
......
......@@ -1511,6 +1511,18 @@
data_type : x
backward : reverse_grad
- op : rmsprop_
args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_outs)
infer_meta :
func : RmspropInferMeta
kernel :
func : rmsprop {dense, dense, dense, dense, dense, dense, dense-> dense, dense, dense, dense, dense}
rmsprop_dense_param_sparse_grad {dense, dense, selected_rows, dense, dense, dense, dense-> dense, dense, dense, dense, dense}
data_type : param
optional : mean_grad, master_param, master_param_outs
inplace : (param -> param_out), (moment -> moment_out), (mean_square -> mean_square_out), (mean_grad -> mean_grad_out), (master_param->master_param_outs)
- op : roll
args : (Tensor x, IntArray shifts={}, int64_t[] axis={})
output : Tensor(out)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RmspropOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("Grad")) {
return KernelSignature(
"rmsprop",
{"Param",
"MeanSquare",
"Grad",
"Moment",
"LearningRate",
"MeanGrad",
"MasterParam"},
{"epsilon", "decay", "momentum", "centered", "multi_precision"},
{"ParamOut",
"MomentOut",
"MeanSquareOut",
"MeanGradOut",
"MasterParamOut"});
} else if (ctx.IsSelectedRowsInput("Grad")) {
return KernelSignature(
"rmsprop_dense_param_sparse_grad",
{"Param",
"MeanSquare",
"Grad",
"Moment",
"LearningRate",
"MeanGrad",
"MasterParam"},
{"epsilon", "decay", "momentum", "centered", "multi_precision"},
{"ParamOut",
"MomentOut",
"MeanSquareOut",
"MeanGradOut",
"MasterParamOut"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(rmsprop, phi::RmspropOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册