未验证 提交 89b91021 编写于 作者: W wanghuancoder 提交者: GitHub

add informata for strided grad kernel (#56947)

上级 e08d0646
...@@ -36,6 +36,11 @@ void DiagonalGradStridedKernel(const Context& dev_ctx, ...@@ -36,6 +36,11 @@ void DiagonalGradStridedKernel(const Context& dev_ctx,
dev_ctx, *in_grad, 0, in_grad); dev_ctx, *in_grad, 0, in_grad);
})); }));
DenseTensor tmp; DenseTensor tmp;
tmp.set_layout(out_grad.layout());
tmp.set_lod(out_grad.lod());
tmp.set_type(out_grad.dtype());
tmp.Resize(out_grad.dims());
DiagonalStridedKernel<Context>(dev_ctx, *in_grad, offset, axis1, axis2, &tmp); DiagonalStridedKernel<Context>(dev_ctx, *in_grad, offset, axis1, axis2, &tmp);
PD_VISIT_ALL_TYPES(out_grad.dtype(), "DiagonalGradStridedKernel", ([&] { PD_VISIT_ALL_TYPES(out_grad.dtype(), "DiagonalGradStridedKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>( phi::StridedCopyKernel<data_t, Context>(
......
...@@ -34,6 +34,11 @@ void IndexSelectGradStridedKernel(const Context& dev_ctx, ...@@ -34,6 +34,11 @@ void IndexSelectGradStridedKernel(const Context& dev_ctx,
dev_ctx, *x_grad, 0, x_grad); dev_ctx, *x_grad, 0, x_grad);
})); }));
DenseTensor tmp; DenseTensor tmp;
tmp.set_layout(out_grad.layout());
tmp.set_lod(out_grad.lod());
tmp.set_type(out_grad.dtype());
tmp.Resize(out_grad.dims());
IndexSelectStridedKernel<Context>(dev_ctx, *x_grad, index, dim, &tmp); IndexSelectStridedKernel<Context>(dev_ctx, *x_grad, index, dim, &tmp);
PD_VISIT_ALL_TYPES(out_grad.dtype(), "IndexSelectGradStridedKernel", ([&] { PD_VISIT_ALL_TYPES(out_grad.dtype(), "IndexSelectGradStridedKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>( phi::StridedCopyKernel<data_t, Context>(
......
...@@ -38,6 +38,10 @@ void StridedSliceRawGradStridedKernel(const Context& dev_ctx, ...@@ -38,6 +38,10 @@ void StridedSliceRawGradStridedKernel(const Context& dev_ctx,
dev_ctx, *x_grad, 0, x_grad); dev_ctx, *x_grad, 0, x_grad);
})); }));
DenseTensor tmp; DenseTensor tmp;
tmp.set_layout(out_grad.layout());
tmp.set_lod(out_grad.lod());
tmp.set_type(out_grad.dtype());
tmp.Resize(out_grad.dims());
StridedSliceRawStridedKernel<Context>(dev_ctx, StridedSliceRawStridedKernel<Context>(dev_ctx,
*x_grad, *x_grad,
axes, axes,
......
...@@ -40,6 +40,11 @@ void TensorUnfoldGradKernel(const Context& dev_ctx, ...@@ -40,6 +40,11 @@ void TensorUnfoldGradKernel(const Context& dev_ctx,
})); }));
} }
DenseTensor tmp; DenseTensor tmp;
tmp.set_layout(out_grad.layout());
tmp.set_lod(out_grad.lod());
tmp.set_type(out_grad.dtype());
tmp.Resize(out_grad.dims());
TensorUnfoldKernel<Context>(dev_ctx, *input_grad, axis, size, step, &tmp); TensorUnfoldKernel<Context>(dev_ctx, *input_grad, axis, size, step, &tmp);
PD_VISIT_ALL_TYPES(out_grad.dtype(), "TensorUnfoldGradKernel", ([&] { PD_VISIT_ALL_TYPES(out_grad.dtype(), "TensorUnfoldGradKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>( phi::StridedCopyKernel<data_t, Context>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册