From efe6e2840c6a043005e35a28394685011f69ca5b Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Thu, 22 Oct 2020 01:53:47 -0500 Subject: [PATCH] fix strided_slice_op's GetExpectedKernelType (#28192) * fix strided_slice_op's GetExpectedKernelType when input tensor is at CUDAPinnedPlace * add unittest for tensors in cuda pinned place * skip test for cuda pinned place on cpu machines --- paddle/fluid/operators/strided_slice_op.cc | 7 ++++++- .../fluid/tests/unittests/test_strided_slice_op.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index f8e5d917108..94a0576b772 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -154,9 +154,14 @@ class StridedSliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { + // NOTE: cuda pinned tensor need to copy its data to target place + auto in_tensor = ctx.Input("Input"); + if (platform::is_cuda_pinned_place(in_tensor->place())) { + return framework::OpKernelType(in_tensor->type(), ctx.device_context()); + } return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.Input("Input")->place()); + in_tensor->place()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index 0fe6cd5e7e7..71550c8f247 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -511,6 +511,17 @@ class TestStridedSliceAPI(unittest.TestCase): x, axes=axes, starts=starts, ends=ends, strides=strides_1) assert sliced_1.shape == (3, 2, 2, 2) + @unittest.skipIf(not paddle.is_compiled_with_cuda(), + "Cannot use CUDAPinnedPlace in CPU only version") + def test_cuda_pinned_place(self): + with paddle.fluid.dygraph.guard(): + x = paddle.to_tensor( + np.random.randn(2, 10), place=paddle.CUDAPinnedPlace()) + self.assertTrue(x.place.is_cuda_pinned_place()) + y = x[:, ::2] + self.assertFalse(x.place.is_cuda_pinned_place()) + self.assertFalse(y.place.is_cuda_pinned_place()) + if __name__ == "__main__": unittest.main() -- GitLab