未验证 提交 f5e430c5 编写于 作者: W WeiXin 提交者: GitHub

slice_op support bool tensor. (#35586)

上级 97a73e1d
...@@ -51,6 +51,7 @@ struct EigenPad<Eigen::DefaultDevice, T, Rank> { ...@@ -51,6 +51,7 @@ struct EigenPad<Eigen::DefaultDevice, T, Rank> {
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 4>; \ template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 4>; \
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 5>; \ template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 5>; \
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6> template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>
INSTANTIATION(EigenPad, bool);
INSTANTIATION(EigenPad, int); INSTANTIATION(EigenPad, int);
INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, int64_t);
INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, float);
......
...@@ -53,6 +53,7 @@ struct EigenPad<Eigen::GpuDevice, T, Rank> { ...@@ -53,6 +53,7 @@ struct EigenPad<Eigen::GpuDevice, T, Rank> {
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 4>; \ template struct FUNCTOR<Eigen::GpuDevice, TYPE, 4>; \
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \ template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6> template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6>
INSTANTIATION(EigenPad, bool);
INSTANTIATION(EigenPad, int); INSTANTIATION(EigenPad, int);
INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, int64_t);
INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, float);
......
...@@ -53,6 +53,7 @@ struct EigenSlice<Eigen::GpuDevice, T, Rank> { ...@@ -53,6 +53,7 @@ struct EigenSlice<Eigen::GpuDevice, T, Rank> {
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 4>; \ template struct FUNCTOR<Eigen::GpuDevice, TYPE, 4>; \
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \ template struct FUNCTOR<Eigen::GpuDevice, TYPE, 5>; \
template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6> template struct FUNCTOR<Eigen::GpuDevice, TYPE, 6>
INSTANTIATION(EigenSlice, bool);
INSTANTIATION(EigenSlice, int); INSTANTIATION(EigenSlice, int);
INSTANTIATION(EigenSlice, int64_t); INSTANTIATION(EigenSlice, int64_t);
INSTANTIATION(EigenSlice, float); INSTANTIATION(EigenSlice, float);
......
...@@ -434,7 +434,8 @@ REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad, ...@@ -434,7 +434,8 @@ REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad,
ops::SliceOpGradVarTypeInference); ops::SliceOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>, slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, bool>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>, ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>, ops::SliceKernel<paddle::platform::CPUDeviceContext, double>,
...@@ -444,7 +445,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -444,7 +445,8 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<double>>); paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>, slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>,
...@@ -454,7 +456,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -454,7 +456,8 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<double>>); paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>, slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, bool>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>, ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>, ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
...@@ -466,7 +469,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -466,7 +469,7 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<double>>); paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
slice_grad, slice_grad, ops::SliceGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>, ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>, ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>, ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
......
...@@ -553,6 +553,22 @@ class TestSliceApiWithTensor(unittest.TestCase): ...@@ -553,6 +553,22 @@ class TestSliceApiWithTensor(unittest.TestCase):
self.assertTrue(np.array_equal(a_1.numpy(), a_2.numpy())) self.assertTrue(np.array_equal(a_1.numpy(), a_2.numpy()))
def test_bool_tensor(self):
with paddle.fluid.dygraph.guard():
array = (np.arange(60).reshape([3, 4, 5]) % 3).astype('bool')
tt = paddle.to_tensor(array)
tt.stop_gradient = False
starts = [0, 1, 2]
ends = [3, 5, 4]
axes = [0, 1, 2]
y_paddle = paddle.slice(tt, axes, starts, ends)
y_np = tt[0:3, 1:5, 2:4]
self.assertTrue(paddle.bool == y_paddle.dtype)
self.assertTrue(np.array_equal(y_paddle.numpy(), y_np))
class TestSliceApiWithLoDTensorArray(unittest.TestCase): class TestSliceApiWithLoDTensorArray(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册