From 27996fd1a89d55a2ce0eda4d0e3d23efde882dbe Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Mon, 28 Mar 2022 10:42:05 +0800
Subject: [PATCH] [Phi] Move backward infershape of Reshape Op (#40914)

* perfect reshape kernel

* fix bugs of sig

* add unittest for reshape_sig

* fix bugs when running coverage
---
 paddle/fluid/operators/flatten_op.cc      | 13 +++++++---
 paddle/fluid/operators/reshape_op.cc      | 30 ++++++++++++-----------
 paddle/phi/infermeta/backward.cc          |  7 ++++++
 paddle/phi/infermeta/backward.h           |  2 ++
 paddle/phi/ops/compat/reshape_sig.cc      |  9 ++++++-
 paddle/phi/tests/ops/test_op_signature.cc | 18 ++++++++++++++
 paddle/phi/tests/ops/test_op_signature.h  |  2 +-
 7 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index b0a7007755..d1ac573b84 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
@@ -365,10 +366,14 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
                    "FlattenContiguousRangeGrad");
     OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "FlattenContiguousRangeGrad");
-    auto xshape_dims = context->GetInputDim("XShape");
-    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    context->SetOutputDim(framework::GradVarName("X"), x_dims);
-    context->ShareLoD("XShape", framework::GradVarName("X"));
+    // Construct MetaTensor for InferMeta Func
+    using CompatMetaTensor = framework::CompatMetaTensor;
+    CompatMetaTensor xshape(context->GetInputVarPtrs("XShape")[0],
+                            context->IsRuntime());
+    CompatMetaTensor dx(
+        context->GetOutputVarPtrs(framework::GradVarName("X"))[0],
+        context->IsRuntime());
+    phi::KernelWithXShapeInferMeta(xshape, &dx);
   }
 
  protected:
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index 8d99a60b12..4a4210845c 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/scalar_array.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"
 #include "paddle/phi/kernels/reshape_grad_kernel.h"
 #include "paddle/phi/kernels/reshape_kernel.h"
@@ -558,10 +559,14 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                       platform::errors::InvalidArgument(
                           "Input(Out@GRAD) shouldn't be null."));
-    auto xshape_dims = ctx->GetInputDim("XShape");
-    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("XShape", framework::GradVarName("X"));
+
+    // Construct MetaTensor for InferMeta Func
+    using CompatMetaTensor = framework::CompatMetaTensor;
+    CompatMetaTensor xshape(ctx->GetInputVarPtrs("XShape")[0],
+                            ctx->IsRuntime());
+    CompatMetaTensor dx(ctx->GetOutputVarPtrs(framework::GradVarName("X"))[0],
+                        ctx->IsRuntime());
+    phi::KernelWithXShapeInferMeta(xshape, &dx);
   }
 
  protected:
@@ -592,15 +597,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
                         const framework::AttributeMap &attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(X@GRAD_GRAD) shouldn't be null."));
-    if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
-      ctx->ShareDim("DOut", "DDOut");
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -659,9 +655,15 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
                   ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
                   ops::ReshapeGradInplaceInferer);
+
+DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad,
+                            Reshape2DoubleGradInferShapeFunctor,
+                            PD_INFER_META(phi::GeneralUnaryGradInferMeta));
+
 REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
                   ops::ReshapeDoubleGradInplaceInferer,
-                  ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
+                  ops::ReshapeDoubleGradOpNoNeedBufferVarInferer,
+                  Reshape2DoubleGradInferShapeFunctor);
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index b680222f86..5d9ed8e9e8 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -161,6 +161,13 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
   dx->share_meta(dout);
 }
 
+void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
+  auto xshape_dims = xshape.dims();
+  auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
+  dx->set_dims(x_dims);
+  dx->share_lod(xshape);
+}
+
 void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& mask,
                                    const MetaTensor& dout,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 5c49a58a71..10b3e7cec7 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -92,6 +92,8 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                 int axis,
                                 MetaTensor* dx);
 
+void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx);
+
 void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& mask,
                                    const MetaTensor& dout,
diff --git a/paddle/phi/ops/compat/reshape_sig.cc b/paddle/phi/ops/compat/reshape_sig.cc
index b6d10dabb1..ccae6aad02 100644
--- a/paddle/phi/ops/compat/reshape_sig.cc
+++ b/paddle/phi/ops/compat/reshape_sig.cc
@@ -28,8 +28,15 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
       return KernelSignature(
           "reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
     }
+  } else {
+    if (ctx.InputSize("ShapeTensor") > 0) {
+      return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
+    } else if (ctx.HasInput("Shape")) {
+      return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
+    } else {
+      return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+    }
   }
-  return KernelSignature("unregistered", {}, {}, {});
 }
 
 KernelSignature ReshapeGradOpArgumentMapping(
diff --git a/paddle/phi/tests/ops/test_op_signature.cc b/paddle/phi/tests/ops/test_op_signature.cc
index 36923972ea..6acf3916a1 100644
--- a/paddle/phi/tests/ops/test_op_signature.cc
+++ b/paddle/phi/tests/ops/test_op_signature.cc
@@ -577,5 +577,23 @@ TEST(ARG_MAP, allclose) {
   ASSERT_EQ(attr_names2[1], "Atol");
 }
 
+TEST(ARG_MAP, reshape) {
+  TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
+  auto signature1 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case1);
+  ASSERT_EQ(signature1.name, "reshape");
+
+  TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
+  auto signature2 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case2);
+  ASSERT_EQ(signature2.name, "reshape");
+
+  TestArgumentMappingContext arg_case3(
+      {"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
+  auto signature3 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case3);
+  ASSERT_EQ(signature3.name, "reshape");
+}
+
 }  // namespace tests
 }  // namespace phi
diff --git a/paddle/phi/tests/ops/test_op_signature.h b/paddle/phi/tests/ops/test_op_signature.h
index 8468dad10e..4a84793527 100644
--- a/paddle/phi/tests/ops/test_op_signature.h
+++ b/paddle/phi/tests/ops/test_op_signature.h
@@ -57,7 +57,7 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
   }
 
   size_t InputSize(const std::string& name) const override {
-    return dense_tensor_inputs.size() + selected_rows_inputs.size();
+    return dense_tensor_inputs.count(name) + selected_rows_inputs.count(name);
   }
 
   size_t OutputSize(const std::string& name) const override {
--
GitLab
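
A note on the XShape convention this patch relies on: reshape2's forward pass emits an auxiliary XShape output whose dims are the input's dims prefixed with a sentinel 0, so reshape2_grad (and flatten_contiguous_range_grad) can recover dX's shape from XShape alone, without needing X itself, and the new phi::KernelWithXShapeInferMeta does exactly that by slicing off the leading dimension. Below is a minimal standalone sketch of that shape arithmetic in plain C++ with no Paddle dependency; make_xshape and infer_dx_dims are illustrative names, not Paddle APIs.

#include <cassert>
#include <cstdint>
#include <vector>

// The forward pass records XShape = [0, x_dims...]; the leading 0 is a
// sentinel marking the entry as a shape record rather than real data.
std::vector<int64_t> make_xshape(const std::vector<int64_t>& x_dims) {
  std::vector<int64_t> xshape{0};
  xshape.insert(xshape.end(), x_dims.begin(), x_dims.end());
  return xshape;
}

// Mirrors phi::slice_ddim(xshape_dims, 1, xshape_dims.size()): dropping
// the sentinel recovers the original input dims for dX.
std::vector<int64_t> infer_dx_dims(const std::vector<int64_t>& xshape_dims) {
  return {xshape_dims.begin() + 1, xshape_dims.end()};
}

int main() {
  const std::vector<int64_t> x_dims{4, 8, 16};
  const auto xshape = make_xshape(x_dims);  // {0, 4, 8, 16}
  assert(infer_dx_dims(xshape) == x_dims);  // dX keeps X's shape
  return 0;
}

Centralizing this slice in phi::KernelWithXShapeInferMeta lets the reshape2_grad and flatten_contiguous_range_grad fluid ops share one InferMeta function instead of each duplicating the slice_ddim call, which is the point of this patch.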