From 42a7514504c65e6c3892127c42153cfccf1aef31 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 10 Feb 2023 15:56:27 +0800 Subject: [PATCH] Fix inferMefer in transpose2_grad (#50388) * Fix inferMefer in transpose2_grad * fix infershape * fix unittest --- paddle/fluid/operators/transpose_op.cc | 45 +++++++++----------------- 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index d49cbad1147..671997adda8 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -16,6 +16,10 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -179,19 +183,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TransposeOpGrad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - framework::GradVarName("Out"), - "TransposeOpGrad"); - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - } - } - protected: phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -320,21 +311,6 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("XShape"), "Input", "XShape", "Transpose2OpGrad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - framework::GradVarName("Out"), - "Transpose2OpGrad"); - if (ctx->HasOutput(framework::GradVarName("X"))) { - auto xshape_dim = ctx->GetInputDim("XShape"); - auto x_shape_dim = phi::slice_ddim(xshape_dim, 1, xshape_dim.size()); - ctx->SetOutputDim(framework::GradVarName("X"), x_shape_dim); - ctx->ShareLoD("XShape", framework::GradVarName("X")); - } - } - protected: phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -359,6 +335,13 @@ class TransposeGradInferVarType : public framework::VarTypeInference { } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(transpose_grad, + TransposeGradInferShapeFunctor, + PD_INFER_META(phi::TransposeGradInferMeta)); + +DECLARE_INFER_SHAPE_FUNCTOR(transpose2_grad, + Transpose2GradInferShapeFunctor, + PD_INFER_META(phi::TransposeGradInferMeta)); namespace ops = paddle::operators; REGISTER_OPERATOR( transpose, @@ -368,7 +351,8 @@ REGISTER_OPERATOR( paddle::framework::DefaultGradOpMaker); REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad, - ops::TransposeGradInferVarType); + ops::TransposeGradInferVarType, + TransposeGradInferShapeFunctor); REGISTER_OPERATOR(transpose2, ops::Transpose2Op, @@ -379,4 +363,5 @@ REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad, ops::TransposeGradInferVarType, ops::Transpose2DoubleGradMaker, - ops::Transpose2DoubleGradMaker); + ops::Transpose2DoubleGradMaker, + Transpose2GradInferShapeFunctor); -- GitLab