Unverified commit 2a420036 authored by RedContritio, committed by GitHub

support auto generate for op merged_momentum optimizer (#52708)

* fix error in generator/type_mapping.py

* support auto generate for op merged_momentum optimizer
Parent 410e25fb
@@ -76,7 +76,7 @@ opmaker_attr_types_map = {
'int64_t[]': 'std::vector<int64_t>',
'float[]': 'std::vector<float>',
'double[]': 'std::vector<double>',
- 'str[]': 'std::vector<<std::string>',
+ 'str[]': 'std::vector<std::string>',
}
output_type_map = {'Tensor': 'Tensor', 'Tensor[]': 'std::vector<Tensor>'}
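The extra '<' fixed here mattered because this table is what the generator uses to turn yaml attribute types into C++ parameter types; every op declaring a str[] attribute would otherwise receive the uncompilable type std::vector<<std::string>. A minimal sketch of the lookup (to_cpp_attr_type is a hypothetical helper, not the generator's real code):

opmaker_attr_types_map = {
    'float': 'float',
    'float[]': 'std::vector<float>',
    'str[]': 'std::vector<std::string>',  # corrected entry
}

def to_cpp_attr_type(yaml_type: str) -> str:
    # Fall back to the raw name for types not in the table.
    return opmaker_attr_types_map.get(yaml_type, yaml_type)

# "str[] regularization_method" in the op yaml becomes a
# std::vector<std::string> attribute in the generated OpMaker.
print(to_cpp_attr_type('str[]'))  # std::vector<std::string>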
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class MergedMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto param_dtype =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return phi::KernelKey(param_dtype, ctx.GetPlace());
}
};
class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated")
.AsDuplicable();
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter")
.AsDuplicable();
AddInput("Velocity",
"(Tensor, default Tensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated")
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).")
.AsDuplicable();
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum or not.")
.SetDefault(false);
AddAttr<std::vector<std::string>>(
"regularization_method",
"(string) regularization_method, right now only "
"support l2decay or none")
.SetDefault({});
AddAttr<std::vector<float>>("regularization_coeff",
"(float) regularization_coeff")
.SetDefault({});
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddComment(R"DOC(Merged Momentum Optimizer.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(merged_momentum,
MergedMomentumInferShapeFunctor,
PD_INFER_META(phi::MergedMomentumInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
ops::MergedMomentumOp,
ops::MergedMomentumOpMaker,
MergedMomentumInferShapeFunctor);
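With this commit, the hand-written operator class, OpMaker, and registration above are produced by the yaml-driven generator instead. The real generator renders templates under the generator/ directory mentioned in the commit message; the format-string stand-in below is only a minimal sketch, and render_registration is a hypothetical helper:

# Hypothetical illustration of what the generator emits from a yaml entry:
# the registration boilerplate that used to be written by hand above.
REGISTRATION_TEMPLATE = """\
DECLARE_INFER_SHAPE_FUNCTOR({op_name},
                            {class_name}InferShapeFunctor,
                            PD_INFER_META(phi::{infer_meta}));
REGISTER_OP_WITHOUT_GRADIENT({op_name},
                             ops::{class_name}Op,
                             ops::{class_name}OpMaker,
                             {class_name}InferShapeFunctor);
"""

def render_registration(op_name: str, infer_meta: str) -> str:
    # merged_momentum -> MergedMomentum
    class_name = ''.join(part.capitalize() for part in op_name.split('_'))
    return REGISTRATION_TEMPLATE.format(
        op_name=op_name, class_name=class_name, infer_meta=infer_meta)

print(render_registration('merged_momentum', 'MergedMomentumInferMeta'))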
@@ -884,17 +884,6 @@
data_type : param
inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out)
- op : merged_momentum_
args : (Tensor[] param, Tensor[] grad, Tensor[] velocity, Tensor[] learning_rate, Tensor[] master_param, float mu, bool use_nesterov = false, str[] regularization_method = {}, float[] regularization_coeff = {}, bool multi_precision = false, float rescale_grad = 1.0f)
output : Tensor[](param_out){param.size()}, Tensor[](velocity_out){param.size()}, Tensor[](master_param_out){param.size()}
infer_meta :
func : MergedMomentumInferMeta
optional : master_param
kernel :
func : merged_momentum
data_type : param
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
- op : min
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
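The inplace clause removed here (and re-added in the new yaml entry further down) pairs each input with the output that reuses its buffer. A minimal sketch of reading such a spec (parse_inplace is a hypothetical helper, not the generator's actual parser):

def parse_inplace(spec: str) -> dict:
    """Parse "(param -> param_out), (velocity -> velocity_out)" into pairs."""
    pairs = {}
    for item in spec.split('),'):
        src, dst = item.strip(' ()').split('->')
        pairs[src.strip()] = dst.strip()
    return pairs

print(parse_inplace('(param -> param_out), (velocity -> velocity_out), '
                    '(master_param -> master_param_out)'))
# {'param': 'param_out', 'velocity': 'velocity_out',
#  'master_param': 'master_param_out'}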
@@ -1421,6 +1421,12 @@
outputs :
out : Out
- op : merged_momentum_
inputs :
{param : Param, grad : Grad, velocity : Velocity, learning_rate : LearningRate, master_param : MasterParam}
outputs :
{param_out : ParamOut, velocity_out : VelocityOut, master_param_out : MasterParamOut}
- op : meshgrid
backward : meshgrid_grad
inputs :
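This compatibility entry maps the snake_case argument names used by the generated op to the legacy Proto names (Param, Grad, ...) so that existing serialized programs still resolve. A minimal sketch of the lookup it enables (the dict mirrors the entry above; legacy_name is a hypothetical helper):

compat = {
    'merged_momentum_': {
        'inputs': {'param': 'Param', 'grad': 'Grad', 'velocity': 'Velocity',
                   'learning_rate': 'LearningRate',
                   'master_param': 'MasterParam'},
        'outputs': {'param_out': 'ParamOut', 'velocity_out': 'VelocityOut',
                    'master_param_out': 'MasterParamOut'},
    },
}

def legacy_name(op: str, arg: str) -> str:
    entry = compat[op]
    return entry['inputs'].get(arg) or entry['outputs'].get(arg, arg)

print(legacy_name('merged_momentum_', 'learning_rate'))  # LearningRate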
@@ -1190,6 +1190,17 @@
kernel :
func : merge_selected_rows {selected_rows -> selected_rows}
- op : merged_momentum_
args : (Tensor[] param, Tensor[] grad, Tensor[] velocity, Tensor[] learning_rate, Tensor[] master_param, float mu, bool use_nesterov = false, str[] regularization_method = {}, float[] regularization_coeff = {}, bool multi_precision = false, float rescale_grad = 1.0f)
output : Tensor[](param_out){param.size()}, Tensor[](velocity_out){param.size()}, Tensor[](master_param_out){param.size()}
infer_meta :
func : MergedMomentumInferMeta
kernel :
func : merged_momentum
data_type : param
optional : master_param, master_param_out
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
- op : meshgrid
args : (Tensor[] inputs)
output : Tensor[]{inputs.size()}
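In the new entry, an output spec such as Tensor[](param_out){param.size()} declares the output's type, name, and list length, tying each output list to the size of the param input list. A rough sketch of how the spec decomposes (hypothetical regex, for illustration only):

import re

# Matches e.g. "Tensor[](param_out){param.size()}" and plain "Tensor(out)".
OUTPUT_SPEC = re.compile(r'(Tensor(?:\[\])?)\((\w+)\)(?:\{(.+)\})?')

for spec in ('Tensor[](param_out){param.size()}', 'Tensor(out)'):
    dtype, name, size_expr = OUTPUT_SPEC.match(spec).groups()
    print(dtype, name, size_expr)
# Tensor[] param_out param.size()
# Tensor out None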
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MergedMomentumOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"merged_momentum",
{"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
{"mu",
"use_nesterov",
"regularization_method",
"regularization_coeff",
"multi_precision",
"rescale_grad"},
{
"ParamOut",
"VelocityOut",
"MasterParamOut",
});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(merged_momentum,
phi::MergedMomentumOpArgumentMapping);
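The argument mapping deleted above is pure data: the phi kernel name plus ordered input, attribute, and output names. After this commit the same information lives in the yaml entries, from which an equivalent signature can be assembled mechanically; below is a minimal Python stand-in for the C++ KernelSignature (field layout assumed):

from typing import NamedTuple, Tuple

class KernelSignature(NamedTuple):
    kernel: str
    inputs: Tuple[str, ...]
    attrs: Tuple[str, ...]
    outputs: Tuple[str, ...]

# The same mapping the deleted merged_momentum signature file expressed in C++.
MERGED_MOMENTUM_SIG = KernelSignature(
    kernel='merged_momentum',
    inputs=('Param', 'Grad', 'Velocity', 'LearningRate', 'MasterParam'),
    attrs=('mu', 'use_nesterov', 'regularization_method',
           'regularization_coeff', 'multi_precision', 'rescale_grad'),
    outputs=('ParamOut', 'VelocityOut', 'MasterParamOut'),
)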