未验证 提交 a533dae3 编写于 作者: H huangjiyi 提交者: GitHub

move dgc_momentum InferShape to phi (#56358)

上级 ee01d78f
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -24,92 +26,6 @@ class DGCMomentumOp : public framework::OperatorWithKernel { ...@@ -24,92 +26,6 @@ class DGCMomentumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("current_step"),
"Input",
"current_step",
"DGCMomentumOp");
OP_INOUT_CHECK(ctx->HasInput("nranks"), "Input", "nranks", "DGCMomentumOp");
OP_INOUT_CHECK(
ctx->HasOutput("Grad_out"), "Output", "Grad_out", "DGCMomentumOp");
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"),
true,
platform::errors::NotFound(
"Input(param) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"),
true,
platform::errors::NotFound(
"Input(grad) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"),
true,
platform::errors::NotFound(
"Input(velocity) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"),
true,
platform::errors::NotFound(
"Input(LearningRate) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be phi::DenseTensor, "
"but the received is %s",
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"),
true,
platform::errors::NotFound(
"Output(ParamOut) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VelocityOut"),
true,
platform::errors::NotFound(
"Output(VelocityOut) of Momentum should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(phi::product(lr_dims),
0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and Grad input of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Grad's dim [%s].",
param_dim,
ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Velocity"),
platform::errors::InvalidArgument(
"Param and Velocity of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dim,
ctx->GetInputDim("Velocity")));
}
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut")) {
ctx->SetOutputDim("MasterParamOut", param_dim);
}
}
phi::KernelKey GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
...@@ -199,7 +115,12 @@ DGC Momentum Operator. ...@@ -199,7 +115,12 @@ DGC Momentum Operator.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(dgc_momentum,
DGCMomentumInferShapeFunctor,
PD_INFER_META(phi::DGCMomentumInferMeta));
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum, REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum,
ops::DGCMomentumOp, ops::DGCMomentumOp,
ops::DGCMomentumOpMaker); ops::DGCMomentumOpMaker,
DGCMomentumInferShapeFunctor);
...@@ -1305,6 +1305,67 @@ void DeformableConvInferMeta(const MetaTensor& x, ...@@ -1305,6 +1305,67 @@ void DeformableConvInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void DGCMomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
const MetaTensor& current_step_tensor,
const MetaTensor& nranks_tensor,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
float rampup_begin_step,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out,
MetaTensor* grad_out) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_NE(phi::product(lr_dims),
0,
phi::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
phi::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));
auto param_dims = param.dims();
auto grad_dims = grad.dims();
auto velocity_dims = velocity.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad_dims,
phi::errors::InvalidArgument(
"Param and Grad input of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Grad's dim [%s].",
param_dims,
grad_dims));
PADDLE_ENFORCE_EQ(
param_dims,
velocity_dims,
phi::errors::InvalidArgument(
"Param and Velocity of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dims,
velocity_dims));
param_out->set_dims(param_dims);
velocity_out->set_dims(param_dims);
if (master_param != nullptr) {
master_param_out->set_dims(param_dims);
}
}
void EditDistanceInferMeta(const MetaTensor& hyps, void EditDistanceInferMeta(const MetaTensor& hyps,
const MetaTensor& refs, const MetaTensor& refs,
const MetaTensor& hypslength, const MetaTensor& hypslength,
......
...@@ -280,6 +280,25 @@ void DeformableConvInferMeta(const MetaTensor& x, ...@@ -280,6 +280,25 @@ void DeformableConvInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void DGCMomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
const MetaTensor& current_step_tensor,
const MetaTensor& nranks_tensor,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
float rampup_begin_step,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out,
MetaTensor* grad_out);
void EditDistanceInferMeta(const MetaTensor& hyps, void EditDistanceInferMeta(const MetaTensor& hyps,
const MetaTensor& refs, const MetaTensor& refs,
const MetaTensor& hypslength, const MetaTensor& hypslength,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册