From 89b91021b77d9f1ba54efe97410a842363e79a09 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 5 Sep 2023 17:26:12 +0800 Subject: [PATCH] add informata for strided grad kernel (#56947) --- paddle/phi/kernels/stride/diagonal_grad_kernel.cc | 5 +++++ paddle/phi/kernels/stride/index_select_grad_kernel.cc | 5 +++++ paddle/phi/kernels/stride/strided_slice_grad_kernel.cc | 4 ++++ paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc | 5 +++++ 4 files changed, 19 insertions(+) diff --git a/paddle/phi/kernels/stride/diagonal_grad_kernel.cc b/paddle/phi/kernels/stride/diagonal_grad_kernel.cc index 21f4b7564e6..d5ebcd6f4ab 100644 --- a/paddle/phi/kernels/stride/diagonal_grad_kernel.cc +++ b/paddle/phi/kernels/stride/diagonal_grad_kernel.cc @@ -36,6 +36,11 @@ void DiagonalGradStridedKernel(const Context& dev_ctx, dev_ctx, *in_grad, 0, in_grad); })); DenseTensor tmp; + tmp.set_layout(out_grad.layout()); + tmp.set_lod(out_grad.lod()); + tmp.set_type(out_grad.dtype()); + tmp.Resize(out_grad.dims()); + DiagonalStridedKernel(dev_ctx, *in_grad, offset, axis1, axis2, &tmp); PD_VISIT_ALL_TYPES(out_grad.dtype(), "DiagonalGradStridedKernel", ([&] { phi::StridedCopyKernel( diff --git a/paddle/phi/kernels/stride/index_select_grad_kernel.cc b/paddle/phi/kernels/stride/index_select_grad_kernel.cc index 977c5c51e49..15ab602fe53 100644 --- a/paddle/phi/kernels/stride/index_select_grad_kernel.cc +++ b/paddle/phi/kernels/stride/index_select_grad_kernel.cc @@ -34,6 +34,11 @@ void IndexSelectGradStridedKernel(const Context& dev_ctx, dev_ctx, *x_grad, 0, x_grad); })); DenseTensor tmp; + tmp.set_layout(out_grad.layout()); + tmp.set_lod(out_grad.lod()); + tmp.set_type(out_grad.dtype()); + tmp.Resize(out_grad.dims()); + IndexSelectStridedKernel(dev_ctx, *x_grad, index, dim, &tmp); PD_VISIT_ALL_TYPES(out_grad.dtype(), "IndexSelectGradStridedKernel", ([&] { phi::StridedCopyKernel( diff --git a/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc b/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc index 5e55f38d342..9b2d03a00e8 100644 --- a/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc @@ -38,6 +38,10 @@ void StridedSliceRawGradStridedKernel(const Context& dev_ctx, dev_ctx, *x_grad, 0, x_grad); })); DenseTensor tmp; + tmp.set_layout(out_grad.layout()); + tmp.set_lod(out_grad.lod()); + tmp.set_type(out_grad.dtype()); + tmp.Resize(out_grad.dims()); StridedSliceRawStridedKernel(dev_ctx, *x_grad, axes, diff --git a/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc b/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc index 5f9e03ac533..620d7bbb46d 100644 --- a/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc +++ b/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc @@ -40,6 +40,11 @@ void TensorUnfoldGradKernel(const Context& dev_ctx, })); } DenseTensor tmp; + tmp.set_layout(out_grad.layout()); + tmp.set_lod(out_grad.lod()); + tmp.set_type(out_grad.dtype()); + tmp.Resize(out_grad.dims()); + TensorUnfoldKernel(dev_ctx, *input_grad, axis, size, step, &tmp); PD_VISIT_ALL_TYPES(out_grad.dtype(), "TensorUnfoldGradKernel", ([&] { phi::StridedCopyKernel( -- GitLab