未验证 提交 53409bcd 编写于 作者: Y YuanRisheng 提交者: GitHub

fix bugs of reshape double grad infermeta (#41459)

上级 c31386ef
...@@ -441,26 +441,27 @@ class ReshapeDoubleGradKernel { ...@@ -441,26 +441,27 @@ class ReshapeDoubleGradKernel {
public: public:
void operator()(const framework::ExecutionContext &ctx) const { void operator()(const framework::ExecutionContext &ctx) const {
auto *dd_x = ctx.Input<framework::Tensor>("DDX"); auto *dd_x = ctx.Input<framework::Tensor>("DDX");
auto *d_out = ctx.Input<framework::Tensor>("DOut");
auto *dd_out = ctx.Output<framework::Tensor>("DDOut"); auto *dd_out = ctx.Output<framework::Tensor>("DDOut");
dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
phi::ReshapeDoubleGradKernel( phi::ReshapeDoubleGradKernel(
static_cast<const phi::CPUContext &>(dev_ctx), *dd_x, dd_out); static_cast<const phi::CPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
phi::ReshapeDoubleGradKernel( phi::ReshapeDoubleGradKernel(
static_cast<const phi::GPUContext &>(dev_ctx), *dd_x, dd_out); static_cast<const phi::GPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
} }
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) { if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
phi::ReshapeDoubleGradKernel( phi::ReshapeDoubleGradKernel(
static_cast<const phi::XPUContext &>(dev_ctx), *dd_x, dd_out); static_cast<const phi::XPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
} }
#endif #endif
} }
...@@ -658,7 +659,7 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, ...@@ -658,7 +659,7 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad, DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad,
Reshape2DoubleGradInferShapeFunctor, Reshape2DoubleGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta)); PD_INFER_META(phi::ReshapeDoubleGradInferMeta));
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp, REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInferer, ops::ReshapeDoubleGradInplaceInferer,
......
...@@ -409,6 +409,14 @@ void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) { ...@@ -409,6 +409,14 @@ void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) {
dx->set_layout(out_grad.layout()); dx->set_layout(out_grad.layout());
} }
void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x_grad_grad,
MetaTensor* out_grad_grad) {
if (out_grad_grad != nullptr) {
out_grad_grad->share_dims(out_grad);
}
}
void ScatterGradInferMeta(const MetaTensor& index, void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates, const MetaTensor& updates,
const MetaTensor& out_grad, const MetaTensor& out_grad,
......
...@@ -176,6 +176,10 @@ void PoolGradInferMeta(const MetaTensor& x, ...@@ -176,6 +176,10 @@ void PoolGradInferMeta(const MetaTensor& x,
void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx); void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx);
void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x_grad_grad,
MetaTensor* out_grad_grad);
void ScatterGradInferMeta(const MetaTensor& index, void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates, const MetaTensor& updates,
const MetaTensor& out_grad, const MetaTensor& out_grad,
......
...@@ -30,6 +30,7 @@ void ReshapeGradKernel(const Context& dev_ctx, ...@@ -30,6 +30,7 @@ void ReshapeGradKernel(const Context& dev_ctx,
template <typename Context> template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx, void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x_grad_grad, const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad) { DenseTensor* out_grad_grad) {
ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad); ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad);
......
...@@ -25,6 +25,7 @@ void ReshapeGradKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void ReshapeGradKernel(const Context& dev_ctx,
template <typename Context> template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx, void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x_grad_grad, const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad); DenseTensor* out_grad_grad);
......
...@@ -47,7 +47,7 @@ KernelSignature ReshapeGradOpArgumentMapping( ...@@ -47,7 +47,7 @@ KernelSignature ReshapeGradOpArgumentMapping(
KernelSignature ReshapeDoubleGradOpArgumentMapping( KernelSignature ReshapeDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"}); return KernelSignature("reshape_double_grad", {"DOut", "DDX"}, {}, {"DDOut"});
} }
} // namespace phi } // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册