From ad4a0466a53a96ce3d23908cafb94378e71ea403 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 10 Aug 2020 11:06:39 +0800 Subject: [PATCH] Add cuda pinned place branch in slice op GetExpectedKernelType (#26027) * add cuda pinned place branch * add unittest * add skip when not gpu --- paddle/fluid/operators/slice_op.cc | 4 ++++ .../fluid/tests/unittests/test_slice_op.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 8f5df7b6d5d..d147ec3e407 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -155,6 +155,10 @@ class SliceOp : public framework::OperatorWithKernel { in_tensor.IsInitialized(), true, platform::errors::InvalidArgument( "The tensor Input (Input) of Slice op is not initialized.")); + // NOTE: cuda pinned tensor need to copy its data to target place + if (platform::is_cuda_pinned_place(in_tensor.place())) { + return framework::OpKernelType(in_tensor.type(), ctx.device_context()); + } return framework::OpKernelType(in_tensor.type(), in_tensor.place()); } return framework::OpKernelType( diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 5ccc6677993..fdcd2d350a6 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -663,5 +663,21 @@ class TestImperativeVarBaseGetItem(unittest.TestCase): self.assertRaises(Exception, test_float_in_index) +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestImperativeCUDAPinnedInput(unittest.TestCase): + def test_input_cuda_pinned_var(self): + with fluid.dygraph.guard(): + data = np.random.random((2, 80, 16128)).astype('float32') + var = core.VarBase( + value=data, + name='', + persistable=False, + place=fluid.CUDAPinnedPlace(), + zero_copy=False) + sliced = var[:, 10:, :var.shape[1]] + self.assertEqual(sliced.shape, [2, 70, 80]) + + if __name__ == '__main__': unittest.main() -- GitLab