From 59e5c49f850f9e94b49a0a75136efc2e19918a3c Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 16 Mar 2022 10:27:56 +0800 Subject: [PATCH] move gather infershape (#40594) --- paddle/fluid/operators/gather_op.cc | 72 ++++++----------------------- paddle/phi/infermeta/binary.cc | 49 ++++++++++++++++++++ paddle/phi/infermeta/binary.h | 6 +++ 3 files changed, 68 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 7910d94298..9f2b48a24b 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -15,9 +15,14 @@ limitations under the License. */ #include #include #include + +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -26,58 +31,6 @@ class GatherOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Input(X) of GatherOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true, - platform::errors::InvalidArgument( - "Input(Index) of GatherOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output(Out) of GatherOp should not be null.")); - - auto index_dims = ctx->GetInputDim("Index"); - - if (index_dims.size() == 2) { - PADDLE_ENFORCE_EQ( - index_dims[1], 1, - platform::errors::InvalidArgument( - "The last dim of index should be 1 when it is 2D, but we get %d", - index_dims[1])); - } else { - PADDLE_ENFORCE_EQ( - index_dims.size(), 1, - platform::errors::InvalidArgument( - "The index should be 1D, when it is not 2D, but we get %d", - index_dims.size())); - } - - auto axis = ctx->Attrs().Get("axis"); - auto input_dim = ctx->GetInputDim("X"); - if (ctx->HasInput("Axis") || axis == 0) { - // if HasInput("Axis"), we can not obtain correct shape of output - int batch_size = index_dims[0]; - framework::DDim output_dims(input_dim); - output_dims[0] = batch_size; - ctx->SetOutputDim("Out", output_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } else { - int index_size = index_dims[0]; - std::vector out_dim_vec; - for (int i = 0; i < axis; i++) { - out_dim_vec.push_back(input_dim[i]); - } - out_dim_vec.push_back(index_size); - for (int i = axis + 1; i < input_dim.size(); i++) { - out_dim_vec.push_back(input_dim[i]); - } - auto output_dims = phi::make_ddim(out_dim_vec); - ctx->SetOutputDim("Out", output_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -100,11 +53,6 @@ class GatherGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X")); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -193,11 +141,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X"); } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(gather, GatherInferShapeFunctor, + PD_INFER_META(phi::GatherInferMeta)); REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker, ops::GatherGradOpMaker, - ops::GatherGradOpMaker); + ops::GatherGradOpMaker, + GatherInferShapeFunctor); +DECLARE_INFER_SHAPE_FUNCTOR(gather_grad, GatherGradInferShapeFunctor, + PD_INFER_META(phi::GeneralUnaryGradInferMeta)); REGISTER_OPERATOR(gather_grad, ops::GatherGradOp, - ops::GatherGradNoNeedBufferVarInferer); + ops::GatherGradNoNeedBufferVarInferer, + GatherGradInferShapeFunctor); REGISTER_OP_VERSION(gather) .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC", diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index ff2cf81a90..ffb1ed5450 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -431,6 +431,55 @@ void ElementwiseRawInferMeta(const MetaTensor& x, out->share_lod(x); } +void GatherInferMeta(const MetaTensor& x, + const MetaTensor& index, + const Scalar& axis, + MetaTensor* out) { + auto index_dims = index.dims(); + + if (index_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + index_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dim of index should be 1 when it is 2D, but we get %d", + index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + index_dims.size(), + 1, + phi::errors::InvalidArgument( + "The index should be 1D, when it is not 2D, but we get %d", + index_dims.size())); + } + + auto input_dim = x.dims(); + auto axis_v = axis.to(); + if (axis.FromTensor() || axis_v == 0) { + // if axis.FromTensor(), we can not obtain correct shape of output + int batch_size = index_dims[0]; + phi::DDim output_dims(input_dim); + output_dims[0] = batch_size; + out->set_dims(output_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); + } else { + int index_size = index_dims[0]; + std::vector out_dim_vec; + for (int i = 0; i < axis_v; i++) { + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_v + 1; i < input_dim.size(); i++) { + out_dim_vec.push_back(input_dim[i]); + } + auto output_dims = phi::make_ddim(out_dim_vec); + out->set_dims(output_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); + } +} + void GatherNdInferMeta(const MetaTensor& x, const MetaTensor& index, MetaTensor* out) { diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index cfae45cf04..d852db7a84 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/meta_tensor.h" namespace phi { @@ -81,6 +82,11 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta, int axis, MetaTensor* out); +void GatherInferMeta(const MetaTensor& x, + const MetaTensor& index, + const Scalar& axis, + MetaTensor* out); + void GatherNdInferMeta(const MetaTensor& x, const MetaTensor& index, MetaTensor* out); -- GitLab