From f196b84eacb7488756b68d3dc497517bc45de33b Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 8 Apr 2022 10:41:32 +0800 Subject: [PATCH] fix bugs of reshape double grad infermeta (#41459) (#41493) --- paddle/fluid/operators/reshape_op.cc | 9 +++++---- paddle/phi/infermeta/backward.cc | 8 ++++++++ paddle/phi/infermeta/backward.h | 4 ++++ paddle/phi/kernels/reshape_grad_kernel.cc | 1 + paddle/phi/kernels/reshape_grad_kernel.h | 1 + paddle/phi/ops/compat/reshape_sig.cc | 2 +- 6 files changed, 20 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 0befc873ed..8ccd1b26a3 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -441,26 +441,27 @@ class ReshapeDoubleGradKernel { public: void operator()(const framework::ExecutionContext &ctx) const { auto *dd_x = ctx.Input("DDX"); + auto *d_out = ctx.Input("DOut"); auto *dd_out = ctx.Output("DDOut"); dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); phi::ReshapeDoubleGradKernel( - static_cast(dev_ctx), *dd_x, dd_out); + static_cast(dev_ctx), *d_out, *dd_x, dd_out); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); phi::ReshapeDoubleGradKernel( - static_cast(dev_ctx), *dd_x, dd_out); + static_cast(dev_ctx), *d_out, *dd_x, dd_out); } #endif #ifdef PADDLE_WITH_XPU if (platform::is_xpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); phi::ReshapeDoubleGradKernel( - static_cast(dev_ctx), *dd_x, dd_out); + static_cast(dev_ctx), *d_out, *dd_x, dd_out); } #endif } @@ -658,7 +659,7 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad, Reshape2DoubleGradInferShapeFunctor, - PD_INFER_META(phi::GeneralUnaryGradInferMeta)); + PD_INFER_META(phi::ReshapeDoubleGradInferMeta)); REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp, ops::ReshapeDoubleGradInplaceInferer, diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 43d7d0393d..49e416fd01 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -409,6 +409,14 @@ void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) { dx->set_layout(out_grad.layout()); } +void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& x_grad_grad, + MetaTensor* out_grad_grad) { + if (out_grad_grad != nullptr) { + out_grad_grad->share_dims(out_grad); + } +} + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 432c1aacfc..eff3731bf2 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -176,6 +176,10 @@ void PoolGradInferMeta(const MetaTensor& x, void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx); +void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& x_grad_grad, + MetaTensor* out_grad_grad); + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/kernels/reshape_grad_kernel.cc b/paddle/phi/kernels/reshape_grad_kernel.cc index 3813296640..129a69d4e4 100644 --- a/paddle/phi/kernels/reshape_grad_kernel.cc +++ b/paddle/phi/kernels/reshape_grad_kernel.cc @@ -30,6 +30,7 @@ void ReshapeGradKernel(const Context& dev_ctx, template void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, const DenseTensor& x_grad_grad, DenseTensor* out_grad_grad) { ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad); diff --git a/paddle/phi/kernels/reshape_grad_kernel.h b/paddle/phi/kernels/reshape_grad_kernel.h index 4eb3f68337..06ec3de15a 100644 --- a/paddle/phi/kernels/reshape_grad_kernel.h +++ b/paddle/phi/kernels/reshape_grad_kernel.h @@ -25,6 +25,7 @@ void ReshapeGradKernel(const Context& dev_ctx, template void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, const DenseTensor& x_grad_grad, DenseTensor* out_grad_grad); diff --git a/paddle/phi/ops/compat/reshape_sig.cc b/paddle/phi/ops/compat/reshape_sig.cc index 6b528efe6d..04f64e4035 100644 --- a/paddle/phi/ops/compat/reshape_sig.cc +++ b/paddle/phi/ops/compat/reshape_sig.cc @@ -47,7 +47,7 @@ KernelSignature ReshapeGradOpArgumentMapping( KernelSignature ReshapeDoubleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"}); + return KernelSignature("reshape_double_grad", {"DOut", "DDX"}, {}, {"DDOut"}); } } // namespace phi -- GitLab