未验证 提交 1e1c7275 编写于 作者: W will-jl944 提交者: GitHub

slice op supports uint8_t (#47067)

上级 be3908a3
...@@ -460,6 +460,7 @@ REGISTER_OPERATOR(slice_grad, ...@@ -460,6 +460,7 @@ REGISTER_OPERATOR(slice_grad,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
slice, slice,
ops::SliceKernel<phi::CPUContext, bool>, ops::SliceKernel<phi::CPUContext, bool>,
ops::SliceKernel<phi::CPUContext, uint8_t>,
ops::SliceKernel<phi::CPUContext, int>, ops::SliceKernel<phi::CPUContext, int>,
ops::SliceKernel<phi::CPUContext, int64_t>, ops::SliceKernel<phi::CPUContext, int64_t>,
ops::SliceKernel<phi::CPUContext, float>, ops::SliceKernel<phi::CPUContext, float>,
...@@ -471,6 +472,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -471,6 +472,7 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
slice_grad, slice_grad,
ops::SliceGradKernel<phi::CPUContext, bool>, ops::SliceGradKernel<phi::CPUContext, bool>,
ops::SliceGradKernel<phi::CPUContext, uint8_t>,
ops::SliceGradKernel<phi::CPUContext, int>, ops::SliceGradKernel<phi::CPUContext, int>,
ops::SliceGradKernel<phi::CPUContext, int64_t>, ops::SliceGradKernel<phi::CPUContext, int64_t>,
ops::SliceGradKernel<phi::CPUContext, float>, ops::SliceGradKernel<phi::CPUContext, float>,
...@@ -482,6 +484,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -482,6 +484,7 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
slice, slice,
ops::SliceKernel<phi::GPUContext, bool>, ops::SliceKernel<phi::GPUContext, bool>,
ops::SliceKernel<phi::GPUContext, uint8_t>,
ops::SliceKernel<phi::GPUContext, float>, ops::SliceKernel<phi::GPUContext, float>,
ops::SliceKernel<phi::GPUContext, double>, ops::SliceKernel<phi::GPUContext, double>,
ops::SliceKernel<phi::GPUContext, int>, ops::SliceKernel<phi::GPUContext, int>,
...@@ -494,6 +497,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -494,6 +497,7 @@ REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
slice_grad, slice_grad,
ops::SliceGradKernel<phi::GPUContext, bool>, ops::SliceGradKernel<phi::GPUContext, bool>,
ops::SliceGradKernel<phi::GPUContext, uint8_t>,
ops::SliceGradKernel<phi::GPUContext, float>, ops::SliceGradKernel<phi::GPUContext, float>,
ops::SliceGradKernel<phi::GPUContext, double>, ops::SliceGradKernel<phi::GPUContext, double>,
ops::SliceGradKernel<phi::GPUContext, int>, ops::SliceGradKernel<phi::GPUContext, int>,
......
...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice_grad, ...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::SliceGradRawKernel, phi::SliceGradRawKernel,
bool, bool,
uint8_t,
int, int,
int64_t, int64_t,
float, float,
......
...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice, ...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice,
ALL_LAYOUT, ALL_LAYOUT,
phi::SliceRawKernel, phi::SliceRawKernel,
bool, bool,
uint8_t,
int, int,
int64_t, int64_t,
float, float,
......
...@@ -59,6 +59,7 @@ struct EigenPad<Eigen::DefaultDevice, T, Rank> { ...@@ -59,6 +59,7 @@ struct EigenPad<Eigen::DefaultDevice, T, Rank> {
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>; template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>;
INSTANTIATION(EigenPad, bool); INSTANTIATION(EigenPad, bool);
INSTANTIATION(EigenPad, uint8_t);
INSTANTIATION(EigenPad, int); INSTANTIATION(EigenPad, int);
INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, int64_t);
INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, float);
......
...@@ -59,6 +59,7 @@ struct EigenPad<Eigen::GpuDevice, T, Rank> { ...@@ -59,6 +59,7 @@ struct EigenPad<Eigen::GpuDevice, T, Rank> {
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \ template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6> template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6>
INSTANTIATION(EigenPad, bool); INSTANTIATION(EigenPad, bool);
INSTANTIATION(EigenPad, uint8_t);
INSTANTIATION(EigenPad, int); INSTANTIATION(EigenPad, int);
INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, int64_t);
INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, float);
......
...@@ -59,6 +59,7 @@ struct EigenSlice<Eigen::GpuDevice, T, Rank> { ...@@ -59,6 +59,7 @@ struct EigenSlice<Eigen::GpuDevice, T, Rank> {
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \ template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6> template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6>
INSTANTIATION(EigenSlice, bool); INSTANTIATION(EigenSlice, bool);
INSTANTIATION(EigenSlice, uint8_t);
INSTANTIATION(EigenSlice, int); INSTANTIATION(EigenSlice, int);
INSTANTIATION(EigenSlice, int64_t); INSTANTIATION(EigenSlice, int64_t);
INSTANTIATION(EigenSlice, float); INSTANTIATION(EigenSlice, float);
......
...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice_grad, ...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::SliceGradRawKernel, phi::SliceGradRawKernel,
bool, bool,
uint8_t,
int, int,
int64_t, int64_t,
float, float,
......
...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice, ...@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(slice,
ALL_LAYOUT, ALL_LAYOUT,
phi::SliceRawKernel, phi::SliceRawKernel,
bool, bool,
uint8_t,
int, int,
int64_t, int64_t,
float, float,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册