未验证 提交 f196b84e 编写于 作者: Y YuanRisheng 提交者: GitHub

fix bugs of reshape double grad infermeta (#41459) (#41493)

上级 57fe4fc9
......@@ -441,26 +441,27 @@ class ReshapeDoubleGradKernel {
public:
// Executes reshape's double-grad: forwards DDX (grad of grad of X) into
// DDOut, giving DDOut the shape of DOut via phi::ReshapeDoubleGradKernel.
// NOTE(review): the rendered diff contained both the pre-change (DDX-only)
// and post-change (DOut + DDX) call lines back to back, which is not valid
// C++; only the post-change calls are kept, matching this commit's intent
// ("fix bugs of reshape double grad infermeta", which adds DOut as an input).
void operator()(const framework::ExecutionContext &ctx) const {
  auto *dd_x = ctx.Input<framework::Tensor>("DDX");
  auto *d_out = ctx.Input<framework::Tensor>("DOut");
  auto *dd_out = ctx.Output<framework::Tensor>("DDOut");
  // Allocate DDOut with DDX's dtype on the execution place before dispatch.
  dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
  // Dispatch to the phi kernel for whichever device the op runs on; the
  // device-specific branches differ only in the context type they cast to.
  if (platform::is_cpu_place(ctx.GetPlace())) {
    auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
    phi::ReshapeDoubleGradKernel(
        static_cast<const phi::CPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(ctx.GetPlace())) {
    auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
    phi::ReshapeDoubleGradKernel(
        static_cast<const phi::GPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
  }
#endif
#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(ctx.GetPlace())) {
    auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
    phi::ReshapeDoubleGradKernel(
        static_cast<const phi::XPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
  }
#endif
}
......@@ -658,7 +659,7 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
// Infer-shape functor for reshape2_grad_grad. This commit switches it from
// the generic GeneralUnaryGradInferMeta to ReshapeDoubleGradInferMeta, which
// shares DOut's dims into DDOut. The diff render left both PD_INFER_META
// arguments in place (making the macro call malformed); only the post-change
// argument is kept.
DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad,
                            Reshape2DoubleGradInferShapeFunctor,
                            PD_INFER_META(phi::ReshapeDoubleGradInferMeta));
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInferer,
......
......@@ -409,6 +409,14 @@ void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) {
dx->set_layout(out_grad.layout());
}
// Shape inference for reshape's double grad: the second-order output
// gradient (out_grad_grad / DDOut) takes its dims from out_grad (DOut).
// x_grad_grad participates in the kernel signature but contributes nothing
// to shape inference here.
void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
                                const MetaTensor& x_grad_grad,
                                MetaTensor* out_grad_grad) {
  // Guard clause: nothing to do when the caller did not request this output.
  if (out_grad_grad == nullptr) {
    return;
  }
  out_grad_grad->share_dims(out_grad);
}
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
......
......@@ -176,6 +176,10 @@ void PoolGradInferMeta(const MetaTensor& x,
void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx);
// Shape inference for reshape double grad: out_grad_grad (DDOut), when
// non-null, shares out_grad's (DOut's) dims; x_grad_grad (DDX) is unused
// by the meta function (see its definition in the corresponding .cc file).
void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x_grad_grad,
MetaTensor* out_grad_grad);
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
......
......@@ -30,6 +30,7 @@ void ReshapeGradKernel(const Context& dev_ctx,
template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad) {
ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad);
......
......@@ -25,6 +25,7 @@ void ReshapeGradKernel(const Context& dev_ctx,
// Reshape double-grad kernel declaration. Per the sibling definition, it
// forwards to ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad);
// out_grad is carried to supply DDOut's target shape via infermeta.
template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad);
......
......@@ -47,7 +47,7 @@ KernelSignature ReshapeGradOpArgumentMapping(
// Maps the fluid reshape2_grad_grad op onto the phi "reshape_double_grad"
// kernel signature: inputs {DOut, DDX}, no attributes, output {DDOut}.
// The diff render kept the stale pre-change return ({"DDX"} only) ahead of
// the post-change one, leaving the real return unreachable; only the
// post-change signature is kept, matching the kernel's new DOut parameter.
KernelSignature ReshapeDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "reshape_double_grad", {"DOut", "DDX"}, {}, {"DDOut"});
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册