diff --git a/paddle/fluid/operators/eigen/pad.cc b/paddle/fluid/operators/eigen/pad.cc index 421c9eaf5cde2bbbca56512685903ee3dc28fc49..9db4571357a78781669951d4c672344d2555cde4 100644 --- a/paddle/fluid/operators/eigen/pad.cc +++ b/paddle/fluid/operators/eigen/pad.cc @@ -51,6 +51,7 @@ struct EigenPad { template struct FUNCTOR; \ template struct FUNCTOR; \ template struct FUNCTOR +INSTANTIATION(EigenPad, bool); INSTANTIATION(EigenPad, int); INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, float); diff --git a/paddle/fluid/operators/eigen/pad.cu b/paddle/fluid/operators/eigen/pad.cu index 4cf88712d95cbb2e526068ebdfca9999e5fda449..e028a8aef18cfc62c1541cc1931f95b772df8768 100644 --- a/paddle/fluid/operators/eigen/pad.cu +++ b/paddle/fluid/operators/eigen/pad.cu @@ -53,6 +53,7 @@ struct EigenPad { template struct FUNCTOR; \ template struct FUNCTOR; \ template struct FUNCTOR +INSTANTIATION(EigenPad, bool); INSTANTIATION(EigenPad, int); INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, float); diff --git a/paddle/fluid/operators/eigen/slice.cu b/paddle/fluid/operators/eigen/slice.cu index dc51fa722202bb2d8b7fb168255a13916f3dc157..3dfd0500cc954f3990ed12d2be5b1a653c733d74 100644 --- a/paddle/fluid/operators/eigen/slice.cu +++ b/paddle/fluid/operators/eigen/slice.cu @@ -53,6 +53,7 @@ struct EigenSlice { template struct FUNCTOR; \ template struct FUNCTOR; \ template struct FUNCTOR +INSTANTIATION(EigenSlice, bool); INSTANTIATION(EigenSlice, int); INSTANTIATION(EigenSlice, int64_t); INSTANTIATION(EigenSlice, float); diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 2a13998b4f60151d8082fe5f5c4820158b046bc1..a5513ba648776c1906d2a67bd51890ca51dc01fd 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -434,7 +434,8 @@ REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad, ops::SliceOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL( - slice, ops::SliceKernel, + slice, ops::SliceKernel, + ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, @@ -444,7 +445,8 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( - slice_grad, ops::SliceGradKernel, + slice_grad, ops::SliceGradKernel, + ops::SliceGradKernel, ops::SliceGradKernel, ops::SliceGradKernel, ops::SliceGradKernel, @@ -454,7 +456,8 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( - slice, ops::SliceKernel, + slice, ops::SliceKernel, + ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, @@ -466,7 +469,7 @@ REGISTER_OP_CUDA_KERNEL( paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( - slice_grad, + slice_grad, ops::SliceGradKernel, ops::SliceGradKernel, ops::SliceGradKernel, ops::SliceGradKernel, diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index a80dc87525ab803e56f02ad217d7431b1a18c7bc..57d5453ec968ea07a54bc2436fbacabac542d732 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -553,6 +553,22 @@ class TestSliceApiWithTensor(unittest.TestCase): self.assertTrue(np.array_equal(a_1.numpy(), a_2.numpy())) + def test_bool_tensor(self): + with paddle.fluid.dygraph.guard(): + array = (np.arange(60).reshape([3, 4, 5]) % 3).astype('bool') + tt = paddle.to_tensor(array) + tt.stop_gradient = False + + starts = [0, 1, 2] + ends = [3, 5, 4] + axes = [0, 1, 2] + + y_paddle = paddle.slice(tt, axes, starts, ends) + y_np = tt[0:3, 1:5, 2:4] + + self.assertTrue(paddle.bool == y_paddle.dtype) + self.assertTrue(np.array_equal(y_paddle.numpy(), y_np)) + class TestSliceApiWithLoDTensorArray(unittest.TestCase): def setUp(self):