From 609077e9a6a755176eeec6b6eb34d6a4b09c62b1 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Fri, 25 Mar 2022 11:32:56 +0800
Subject: [PATCH] move mul op infershape (#40917)

---
 paddle/fluid/operators/mul_op.cc | 99 ++++----------------------------
 paddle/phi/infermeta/binary.cc   | 75 ++++++++++++++++++++++++
 paddle/phi/infermeta/binary.h    |  6 ++
 3 files changed, 93 insertions(+), 87 deletions(-)

diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc
index 6738f15ef7..ef04d5582d 100644
--- a/paddle/fluid/operators/mul_op.cc
+++ b/paddle/fluid/operators/mul_op.cc
@@ -21,6 +21,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
 namespace paddle {
 namespace operators {
 
@@ -34,72 +38,6 @@ class MulOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Mul");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Mul");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
-    int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
-
-    VLOG(3) << "mul operator x.shape=" << x_dims << " y.shape=" << y_dims
-            << " x_num_col_dims=" << x_num_col_dims
-            << " y_num_col_dims=" << y_num_col_dims;
-
-    PADDLE_ENFORCE_NE(phi::product(y_dims), 0,
-                      platform::errors::PreconditionNotMet(
-                          "The Input variable Y(%s) has not "
-                          "been initialized. You may need to confirm "
-                          "if you put exe.run(startup_program) "
-                          "after optimizer.minimize function.",
-                          ctx->Inputs("Y").front()));
-    PADDLE_ENFORCE_GT(
-        x_dims.size(), x_num_col_dims,
-        platform::errors::InvalidArgument(
-            "The input tensor X's dimensions of MulOp "
-            "should be larger than x_num_col_dims. But received X's "
-            "dimensions = %d, X's shape = [%s], x_num_col_dims = %d.",
-            x_dims.size(), x_dims, x_num_col_dims));
-    PADDLE_ENFORCE_GT(
-        y_dims.size(), y_num_col_dims,
-        platform::errors::InvalidArgument(
-            "The input tensor Y's dimensions of MulOp "
-            "should be larger than y_num_col_dims. But received Y's "
-            "dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.",
-            y_dims.size(), y_dims, y_num_col_dims));
-
-    auto x_mat_dims = phi::flatten_to_2d(x_dims, x_num_col_dims);
-    auto y_mat_dims = phi::flatten_to_2d(y_dims, y_num_col_dims);
-
-    PADDLE_ENFORCE_EQ(
-        x_mat_dims[1], y_mat_dims[0],
-        platform::errors::InvalidArgument(
-            "After flatten the input tensor X and Y to 2-D dimensions matrix "
-            "X1 and Y1, the matrix X1's width must be equal with matrix Y1's "
-            "height. But received X's shape = [%s], X1's shape = [%s], X1's "
-            "width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = "
-            "%s.",
-            x_dims, x_mat_dims, x_mat_dims[1], y_dims, y_mat_dims,
-            y_mat_dims[0]));
-    std::vector<int64_t> output_dims;
-    output_dims.reserve(
-        static_cast<size_t>(x_num_col_dims + y_dims.size() - y_num_col_dims));
-
-    for (int i = 0; i < x_num_col_dims; ++i) {
-      output_dims.push_back(x_dims[i]);
-    }
-
-    for (int i = y_num_col_dims; i < y_dims.size(); ++i) {
-      output_dims.push_back(y_dims[i]);
-    }
-
-    ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
     framework::LibraryType library = framework::LibraryType::kPlain;
@@ -225,25 +163,6 @@ class MulGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mul");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "mul");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
-                   "Out@GRAD", "mul");
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    auto x_grad_name = framework::GradVarName("X");
-    auto y_grad_name = framework::GradVarName("Y");
-
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-    if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
-    }
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
     framework::LibraryType library = framework::LibraryType::kPlain;
@@ -348,12 +267,18 @@ class MulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(mul, MulInferShapeFunctor,
+                            PD_INFER_META(phi::MatmulWithFlattenInferMeta));
 REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
                   ops::MulOpGradMaker<paddle::framework::OpDesc>,
-                  ops::MulOpGradMaker<paddle::imperative::OpBase>);
+                  ops::MulOpGradMaker<paddle::imperative::OpBase>,
+                  MulInferShapeFunctor);
 
+DECLARE_INFER_SHAPE_FUNCTOR(mul_grad, MulGradInferShapeFunctor,
+                            PD_INFER_META(phi::GeneralBinaryGradInferMeta));
 REGISTER_OPERATOR(mul_grad, ops::MulGradOp,
                   ops::MulDoubleGradMaker<paddle::framework::OpDesc>,
-                  ops::MulDoubleGradMaker<paddle::imperative::OpBase>);
+                  ops::MulDoubleGradMaker<paddle::imperative::OpBase>,
+                  MulGradInferShapeFunctor);
 
 REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp);
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 5221076f10..a0310e1c2e 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -1267,6 +1267,81 @@ void MatmulInferMeta(const MetaTensor& x,
   out->set_layout(x.layout());
 }
 
+void MatmulWithFlattenInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                int x_num_col_dims,
+                                int y_num_col_dims,
+                                MetaTensor* out) {
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+
+  VLOG(3) << "mul operator x.shape=" << x_dims << " y.shape=" << y_dims
+          << " x_num_col_dims=" << x_num_col_dims
+          << " y_num_col_dims=" << y_num_col_dims;
+
+  PADDLE_ENFORCE_NE(phi::product(y_dims),
+                    0,
+                    phi::errors::PreconditionNotMet(
+                        "The Input variable Y has not "
+                        "been initialized. You may need to confirm "
+                        "if you put exe.run(startup_program) "
+                        "after optimizer.minimize function."));
+  PADDLE_ENFORCE_GT(
+      x_dims.size(),
+      x_num_col_dims,
+      phi::errors::InvalidArgument(
+          "The input tensor X's dimensions of MulOp "
+          "should be larger than x_num_col_dims. But received X's "
+          "dimensions = %d, X's shape = [%s], x_num_col_dims = %d.",
+          x_dims.size(),
+          x_dims,
+          x_num_col_dims));
+  PADDLE_ENFORCE_GT(
+      y_dims.size(),
+      y_num_col_dims,
+      phi::errors::InvalidArgument(
+          "The input tensor Y's dimensions of MulOp "
+          "should be larger than y_num_col_dims. But received Y's "
+          "dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.",
+          y_dims.size(),
+          y_dims,
+          y_num_col_dims));
+
+  auto x_mat_dims = phi::flatten_to_2d(x_dims, x_num_col_dims);
+  auto y_mat_dims = phi::flatten_to_2d(y_dims, y_num_col_dims);
+
+  PADDLE_ENFORCE_EQ(
+      x_mat_dims[1],
+      y_mat_dims[0],
+      phi::errors::InvalidArgument(
+          "After flatten the input tensor X and Y to 2-D dimensions matrix "
+          "X1 and Y1, the matrix X1's width must be equal with matrix Y1's "
+          "height. But received X's shape = [%s], X1's shape = [%s], X1's "
+          "width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = "
+          "%s.",
+          x_dims,
+          x_mat_dims,
+          x_mat_dims[1],
+          y_dims,
+          y_mat_dims,
+          y_mat_dims[0]));
+  std::vector<int64_t> output_dims;
+  output_dims.reserve(
+      static_cast<size_t>(x_num_col_dims + y_dims.size() - y_num_col_dims));
+
+  for (int i = 0; i < x_num_col_dims; ++i) {
+    output_dims.push_back(x_dims[i]);
+  }
+
+  for (int i = y_num_col_dims; i < y_dims.size(); ++i) {
+    output_dims.push_back(y_dims[i]);
+  }
+
+  out->set_dims(phi::make_ddim(output_dims));
+  out->set_dtype(x.dtype());
+  out->share_lod(x);
+}
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
   auto dim_x = x.dims();
   auto dim_vec = vec.dims();
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index f9a9398437..1c5dbf1d9f 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -186,6 +186,12 @@ void MatmulInferMeta(const MetaTensor& x,
                      bool trans_y,
                      MetaTensor* out);
 
+void MatmulWithFlattenInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                int x_num_col_dims,
+                                int y_num_col_dims,
+                                MetaTensor* out);
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
 
 void PReluInferMeta(const MetaTensor& x,
-- 
GitLab
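
Note on the shape rule this patch moves: MatmulWithFlattenInferMeta flattens X into a 2-D matrix X1 whose height is the product of its first x_num_col_dims dimensions, flattens Y into Y1 using its first y_num_col_dims dimensions, requires X1's width to equal Y1's height, and builds the output shape from X's leading dimensions followed by Y's trailing dimensions. Below is a minimal standalone C++ sketch of that rule; it uses no Paddle headers, and the helper name infer_mul_out_dims is invented for illustration only.

// Standalone sketch of the shape rule enforced by phi::MatmulWithFlattenInferMeta.
// Not part of the patch; infer_mul_out_dims is a hypothetical helper name.
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

std::vector<int64_t> infer_mul_out_dims(const std::vector<int64_t>& x_dims,
                                        const std::vector<int64_t>& y_dims,
                                        int x_num_col_dims,
                                        int y_num_col_dims) {
  // flatten_to_2d: [d0, ..., d{n-1}] -> [prod(d0..d{k-1}), prod(dk..d{n-1})].
  auto flatten_to_2d = [](const std::vector<int64_t>& dims, int num_col_dims) {
    int64_t h = std::accumulate(dims.begin(), dims.begin() + num_col_dims,
                                int64_t{1}, std::multiplies<int64_t>());
    int64_t w = std::accumulate(dims.begin() + num_col_dims, dims.end(),
                                int64_t{1}, std::multiplies<int64_t>());
    return std::make_pair(h, w);
  };

  auto x_mat = flatten_to_2d(x_dims, x_num_col_dims);
  auto y_mat = flatten_to_2d(y_dims, y_num_col_dims);
  // X1's width must equal Y1's height for the flattened matmul to be defined.
  assert(x_mat.second == y_mat.first);

  // Output: X's first x_num_col_dims dims, then Y's dims after y_num_col_dims.
  std::vector<int64_t> out_dims(x_dims.begin(),
                                x_dims.begin() + x_num_col_dims);
  out_dims.insert(out_dims.end(), y_dims.begin() + y_num_col_dims,
                  y_dims.end());
  return out_dims;
}

int main() {
  // X = [2, 3, 4] with x_num_col_dims = 2 -> X1 = [6, 4]
  // Y = [4, 5]    with y_num_col_dims = 1 -> Y1 = [4, 5]
  // Out = [2, 3, 5]
  for (int64_t d : infer_mul_out_dims({2, 3, 4}, {4, 5}, 2, 1)) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
  return 0;
}

On the gradient side, the patch wires mul_grad to phi::GeneralBinaryGradInferMeta, which simply gives X@GRAD and Y@GRAD the metadata of X and Y, mirroring what the removed MulGradOp::InferShape did with SetOutputDim.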