From b1b244630ce7fa270a97cc3fb0bd50ee43dcbc13 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 17 Mar 2022 10:20:25 +0800 Subject: [PATCH] move grid sample op infershape (#40625) --- paddle/fluid/operators/grid_sampler_op.cc | 63 +++++------------------ paddle/phi/infermeta/binary.cc | 42 +++++++++++++++ paddle/phi/infermeta/binary.h | 5 ++ 3 files changed, 59 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 6ee9582dac..f6d3fd8984 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -15,9 +15,13 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -27,43 +31,6 @@ using Tensor = framework::Tensor; class GridSampleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GridSampler"); - OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid", "GridSampler"); - OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "GridSampler"); - - auto x_dims = ctx->GetInputDim("X"); - auto grid_dims = ctx->GetInputDim("Grid"); - PADDLE_ENFORCE_EQ(x_dims.size(), 4, - platform::errors::InvalidArgument( - "Input(X) of GridSampleOp should be 4-D Tensor, but " - "received X dimension size(%d)", - x_dims.size())); - PADDLE_ENFORCE_EQ(grid_dims.size(), 4, - platform::errors::InvalidArgument( - "Input(Grid) of GridSampleOp should be 4-D Tensor, " - "but received X dimension size(%d)", - grid_dims.size())); - if (ctx->IsRuntime() || grid_dims[3] > 0) { - PADDLE_ENFORCE_EQ( - grid_dims[3], 2, - platform::errors::InvalidArgument( - "Input(Grid) dimension[3] should be 2, but received %d", - grid_dims[3])); - } - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - grid_dims[0], x_dims[0], - platform::errors::InvalidArgument( - "Input(X) and Input(Grid) dimension[0] should be equal, but " - "received X dimension[0](%d) != Grid dimension[0](%d)", - x_dims[0], grid_dims[0])); - } - - ctx->SetOutputDim("Output", - {x_dims[0], x_dims[1], grid_dims[1], grid_dims[2]}); - ctx->ShareLoD("X", "Output"); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -173,18 +140,6 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { class GridSampleOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", - framework::GradVarName("X"), "grid_sampler"); - auto input_dims = ctx->GetInputDim("X"); - auto grid_dims = ctx->GetInputDim("Grid"); - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), input_dims); - } - if (ctx->HasOutput(framework::GradVarName("Grid"))) { - ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims); - } - } protected: framework::OpKernelType GetExpectedKernelType( @@ -224,10 +179,16 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(grid_sampler, GridSamplerInferShapeFunctor, + PD_INFER_META(phi::GridSampleBaseInferMeta)); REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker, ops::GridSampleGradMaker, - ops::GridSampleGradMaker); -REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad); + ops::GridSampleGradMaker, + GridSamplerInferShapeFunctor); +DECLARE_INFER_SHAPE_FUNCTOR(grid_sampler_grad, GridSamplerGradInferShapeFunctor, + PD_INFER_META(phi::GeneralBinaryGradInferMeta)); +REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad, + GridSamplerGradInferShapeFunctor); REGISTER_OP_VERSION(grid_sampler) .AddCheckpoint( diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 38dce0dc69..521f2a9bf0 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -571,6 +571,48 @@ void GatherTreeMeta(const MetaTensor& ids, out->set_dims(ids_dims); } +void GridSampleBaseInferMeta(const MetaTensor& x, + const MetaTensor& grid, + MetaTensor* out, + MetaConfig config) { + auto x_dims = x.dims(); + auto grid_dims = grid.dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input(X) of GridSampleOp should be 4-D Tensor, but " + "received X dimension size(%d)", + x_dims.size())); + PADDLE_ENFORCE_EQ(grid_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input(Grid) of GridSampleOp should be 4-D Tensor, " + "but received X dimension size(%d)", + grid_dims.size())); + if (config.is_runtime || grid_dims[3] > 0) { + PADDLE_ENFORCE_EQ( + grid_dims[3], + 2, + phi::errors::InvalidArgument( + "Input(Grid) dimension[3] should be 2, but received %d", + grid_dims[3])); + } + if (config.is_runtime) { + PADDLE_ENFORCE_EQ( + grid_dims[0], + x_dims[0], + phi::errors::InvalidArgument( + "Input(X) and Input(Grid) dimension[0] should be equal, but " + "received X dimension[0](%d) != Grid dimension[0](%d)", + x_dims[0], + grid_dims[0])); + } + + out->set_dims({x_dims[0], x_dims[1], grid_dims[1], grid_dims[2]}); + out->set_dtype(x.dtype()); + out->share_lod(x); +} + void HuberLossInferMeta(const MetaTensor& input, const MetaTensor& label, float delta, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 8cf7ce3930..9e1a35640a 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -103,6 +103,11 @@ void GatherTreeMeta(const MetaTensor& ids, const MetaTensor& parents, MetaTensor* out); +void GridSampleBaseInferMeta(const MetaTensor& x, + const MetaTensor& grid, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void HuberLossInferMeta(const MetaTensor& input_meta, const MetaTensor& label_meta, float delta, -- GitLab