Unverified commit a6ae1e35, authored by RedContritio, committed by GitHub

support auto generate for op momentum optimizer (#52611)

* support auto generate for op momentum optimizer

* remove momentum_op.* and update signature

* fix dgc momentum op maker error
Parent 757aa470
@@ -19,9 +19,9 @@
namespace paddle {
namespace operators {
class DGCMomentumOp : public MomentumOp {
class DGCMomentumOp : public framework::OperatorWithKernel {
public:
using MomentumOp::MomentumOp;
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
@@ -32,7 +32,82 @@ class DGCMomentumOp : public MomentumOp {
OP_INOUT_CHECK(ctx->HasInput("nranks"), "Input", "nranks", "DGCMomentumOp");
OP_INOUT_CHECK(
ctx->HasOutput("Grad_out"), "Output", "Grad_out", "DGCMomentumOp");
return MomentumOp::InferShape(ctx);
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"),
true,
platform::errors::NotFound(
"Input(param) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"),
true,
platform::errors::NotFound(
"Input(grad) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"),
true,
platform::errors::NotFound(
"Input(velocity) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"),
true,
platform::errors::NotFound(
"Input(LearningRate) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be phi::DenseTensor, "
"but the received is %s",
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"),
true,
platform::errors::NotFound(
"Output(ParamOut) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VelocityOut"),
true,
platform::errors::NotFound(
"Output(VelocityOut) of Momentum should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(phi::product(lr_dims),
0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and Grad input of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Grad's dim [%s].",
param_dim,
ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Velocity"),
platform::errors::InvalidArgument(
"Param and Velocity of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dim,
ctx->GetInputDim("Velocity")));
}
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut")) {
ctx->SetOutputDim("MasterParamOut", param_dim);
}
}
phi::KernelKey GetKernelTypeForVar(
@@ -49,22 +49,75 @@ class DGCMomentumOp : public MomentumOp {
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
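Since DGCMomentumOp no longer inherits from MomentumOp, the momentum shape checks are inlined in InferShape above. The following is a minimal standalone sketch of the invariants those checks enforce, not Paddle code; Dims, product, and check_momentum_shapes are illustrative names, and the Param/Grad/Velocity checks apply only when Grad is a dense tensor.

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using Dims = std::vector<int64_t>;

// Number of elements implied by a dim vector (mirrors phi::product).
int64_t product(const Dims& d) {
  return std::accumulate(d.begin(), d.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

void check_momentum_shapes(const Dims& param, const Dims& grad,
                           const Dims& velocity, const Dims& lr) {
  assert(product(lr) != 0 && "LearningRate may be uninitialized");
  assert(product(lr) == 1 && "LearningRate must be a scalar");
  assert(param == grad && "Param and Grad must have the same dims");
  assert(param == velocity && "Param and Velocity must have the same dims");
  // ParamOut, VelocityOut, and MasterParamOut all take Param's shape.
}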
class DGCMomentumOpMaker : public MomentumOpMaker {
class DGCMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input parameter that has to be updated");
AddInput("Grad",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input gradient of the parameter");
AddInput("Velocity",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
AddInput("LearningRate",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddInput("current_step", "(Tensor) Current step.");
AddInput("nranks", "(Tensor) The number of trainers.");
AddOutput("ParamOut",
"(phi::DenseTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(phi::DenseTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();
AddOutput("Grad_out", "(Tensor) Output grad gradient");
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);
AddAttr<std::string>("regularization_method",
"(string) regularization_method, right now only "
"support l2decay or none")
.SetDefault("");
AddAttr<float>("regularization_coeff", "(float) regularization_coeff")
.SetDefault(0.0f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddAttr<float>("rampup_begin_step",
"(float, -1.0)"
"The period when begin DGC.")
.SetDefault(-1.0);
return MomentumOpMaker::Make();
AddComment(R"DOC(
DGC Momentum Operator.
)DOC");
}
};
@@ -16,7 +16,8 @@
#include <memory>
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/momentum_kernel.h"
#include "paddle/phi/kernels/sgd_kernel.h"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
class MomentumOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto in_var_type = ctx->GetInputType("Param");
PADDLE_ENFORCE_EQ(
in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
true,
platform::errors::InvalidArgument(
"Only support LodTensor and SelectedRows, Unexpected Input Type."));
ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
}
};
void MomentumOpMaker::Make() {
AddInput("Param",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input parameter that has to be updated");
AddInput("Grad",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input gradient of the parameter");
AddInput("Velocity",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
AddInput("LearningRate",
"(phi::DenseTensor, default phi::DenseTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut",
"(phi::DenseTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(phi::DenseTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);
AddAttr<std::string>(
"regularization_method",
"(string) regularization_method, right now only support l2decay or none")
.SetDefault("");
AddAttr<float>("regularization_coeff", "(float) regularization_coeff")
.SetDefault(0.0f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddComment(R"DOC(
Momentum Optimizer.
This optimizer has a flag for Nesterov Momentum.
The update equations are as follows:
$$
velocity = mu * velocity + gradient \\
if (use\_nesterov): \\
param = param - (gradient + mu * velocity) * learning\_rate \\
else: \\
param = param - learning\_rate * velocity. \\
$$
)DOC");
}
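For reference, the update equations in the DOC block above translate directly into elementwise code. The sketch below applies them to plain float arrays and deliberately omits the regularization, rescale_grad, and multi-precision paths; momentum_update is an illustrative name, not a Paddle API.

#include <cstddef>

void momentum_update(float* param, const float* grad, float* velocity,
                     float lr, float mu, bool use_nesterov, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    // velocity = mu * velocity + gradient
    velocity[i] = mu * velocity[i] + grad[i];
    // Nesterov looks ahead along the updated velocity; plain momentum
    // steps along the velocity itself.
    param[i] -= use_nesterov ? (grad[i] + mu * velocity[i]) * lr
                             : lr * velocity[i];
  }
}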
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
momentum,
ops::MomentumOp,
ops::MomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::MomentumOpInferVarType);
REGISTER_OP_VERSION(momentum).AddCheckpoint(
R"ROC(
Upgrade momentum, add 4 attributes [regularization_method, regularization_coeff,
multi_precision, rescale_grad].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("MasterParam", "FP32 master weight for AMP.")
.NewOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.NewAttr("regularization_method",
"(string) regularization_method, right now only support "
"l2decay or none",
std::string(""))
.NewAttr("regularization_coeff", "(float) regularization_coeff", 0.0f)
.NewAttr(
"multi_precision",
"(bool) Whether to use multi-precision during weight updating.",
false)
.NewAttr("rescale_grad",
"(float) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.",
1.0f));
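Conceptually, the checkpoint registered above keeps programs saved before this version loadable: attributes missing from an old program description are filled with the defaults declared in each NewAttr. Below is a rough sketch of that default-filling step, assuming attributes live in a name-to-value map; AttrMap and fill_momentum_defaults are illustrative names, not Paddle APIs.

#include <map>
#include <string>
#include <variant>

using Attr = std::variant<std::string, float, bool>;
using AttrMap = std::map<std::string, Attr>;

// map::emplace inserts only when the key is absent, so attributes that an
// old program already sets are left untouched.
void fill_momentum_defaults(AttrMap& attrs) {
  attrs.emplace("regularization_method", std::string(""));
  attrs.emplace("regularization_coeff", 0.0f);
  attrs.emplace("multi_precision", false);
  attrs.emplace("rescale_grad", 1.0f);
}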
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
namespace paddle {
namespace operators {
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class MomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"),
true,
platform::errors::NotFound(
"Input(param) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"),
true,
platform::errors::NotFound(
"Input(grad) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"),
true,
platform::errors::NotFound(
"Input(velocity) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"),
true,
platform::errors::NotFound(
"Input(LearningRate) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be phi::DenseTensor, "
"but the received is %s",
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"),
true,
platform::errors::NotFound(
"Output(ParamOut) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VelocityOut"),
true,
platform::errors::NotFound(
"Output(VelocityOut) of Momentum should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(phi::product(lr_dims),
0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and Grad input of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Grad's dim [%s].",
param_dim,
ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Velocity"),
platform::errors::InvalidArgument(
"Param and Velocity of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dim,
ctx->GetInputDim("Velocity")));
}
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut")) {
ctx->SetOutputDim("MasterParamOut", param_dim);
}
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
@@ -8,7 +8,6 @@ register_unity_group(
cc
ftrl_op.cc
lars_momentum_op.cc
momentum_op.cc
proximal_adagrad_op.cc
adagrad_op.cc
adam_op.cc
@@ -948,17 +948,6 @@
func : mish
backward : mish_grad
- op : momentum_
args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0, bool multi_precision = false, float rescale_grad = 1.0f)
output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)
infer_meta:
func : MomentumInferMeta
kernel :
func : momentum
data_type : param
optional : master_param
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
- op : multiclass_nms3
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
@@ -1431,6 +1431,12 @@
outputs :
{out : Out, indices : Indices}
- op : momentum_
inputs :
{param : Param, grad : Grad, velocity : Velocity, learning_rate : LearningRate, master_param : MasterParam}
outputs :
{param_out : ParamOut, velocity_out : VelocityOut, master_param_out : MasterParamOut}
- op : multi_dot
backward : multi_dot_grad
inputs :
@@ -153,6 +153,27 @@
- add_output : RoisNum
comment : The number of RoIs in each image.
- op : momentum
version :
- checkpoint : Upgrade momentum, add 4 attributes [regularization_method, regularization_coeff, multi_precision, rescale_grad].
action :
- add_input : MasterParam
comment : FP32 master weight for AMP.
- add_output : MasterParamOut
comment : The updated FP32 master weight for AMP. It shares memory with Input(MasterParam).
- add_attr : regularization_method
comment : (string) regularization_method, right now only supports l2decay or none
default : std::string("")
- add_attr : regularization_coeff
comment : (float) regularization_coeff
default : 0.0
- add_attr : multi_precision
comment : (bool) Whether to use multi-precision during weight updating.
default : "false"
- add_attr : rescale_grad
comment : (float) Multiply the gradient with `rescale_grad` before updating. Often chosen to be `1.0/batch_size`.
default : 1.0
- op : not_equal
version :
- checkpoint : Upgrade compare ops, add a new attribute [force_cpu]
@@ -1184,6 +1184,18 @@
func : mode
backward : mode_grad
- op : momentum_
args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f)
output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)
infer_meta:
func : MomentumInferMeta
kernel :
func : momentum {dense, dense, dense, dense, dense -> dense, dense, dense},
momentum_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense -> dense, dense, dense}
data_type : param
optional : master_param, master_param_out
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
- op : multi_dot
args : (Tensor[] x)
output : Tensor
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MomentumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("Grad")) {
return KernelSignature(
"momentum",
{"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
{"mu",
"use_nesterov",
"regularization_method",
"regularization_coeff",
"multi_precision",
"rescale_grad"},
{"ParamOut", "VelocityOut", "MasterParamOut"});
} else if (ctx.IsSelectedRowsInput("Grad")) {
return KernelSignature(
"momentum_dense_param_sparse_grad",
{"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
{"mu",
"use_nesterov",
"regularization_method",
"regularization_coeff",
"multi_precision",
"rescale_grad"},
{"ParamOut", "VelocityOut", "MasterParamOut"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(momentum, phi::MomentumOpArgumentMapping);
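The mapping above selects a kernel solely from Grad's variable type, mirroring the two dispatch entries in the YAML kernel block. Below is a hypothetical demonstration of that rule; MockCtx is a stand-in for phi::ArgumentMappingContext that answers only the two queries the mapping function makes.

#include <cassert>
#include <string>

struct MockCtx {
  bool dense_grad;
  bool IsDenseTensorInput(const std::string&) const { return dense_grad; }
  bool IsSelectedRowsInput(const std::string&) const { return !dense_grad; }
};

std::string pick_momentum_kernel(const MockCtx& ctx) {
  if (ctx.IsDenseTensorInput("Grad")) return "momentum";
  if (ctx.IsSelectedRowsInput("Grad"))
    return "momentum_dense_param_sparse_grad";
  return "unregistered";
}

int main() {
  assert(pick_momentum_kernel({true}) == "momentum");
  assert(pick_momentum_kernel({false}) == "momentum_dense_param_sparse_grad");
  return 0;
}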