未验证 提交 ad4a0466 编写于 作者: C Chen Weihang 提交者: GitHub

Add cuda pinned place branch in slice op GetExpectedKernelType (#26027)

* add cuda pinned place branch

* add unittest

* add skip when not gpu
上级 86794ccc
......@@ -155,6 +155,10 @@ class SliceOp : public framework::OperatorWithKernel {
in_tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"The tensor Input (Input) of Slice op is not initialized."));
// NOTE: cuda pinned tensor need to copy its data to target place
if (platform::is_cuda_pinned_place(in_tensor.place())) {
return framework::OpKernelType(in_tensor.type(), ctx.device_context());
}
return framework::OpKernelType(in_tensor.type(), in_tensor.place());
}
return framework::OpKernelType(
......
......@@ -663,5 +663,21 @@ class TestImperativeVarBaseGetItem(unittest.TestCase):
self.assertRaises(Exception, test_float_in_index)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestImperativeCUDAPinnedInput(unittest.TestCase):
def test_input_cuda_pinned_var(self):
with fluid.dygraph.guard():
data = np.random.random((2, 80, 16128)).astype('float32')
var = core.VarBase(
value=data,
name='',
persistable=False,
place=fluid.CUDAPinnedPlace(),
zero_copy=False)
sliced = var[:, 10:, :var.shape[1]]
self.assertEqual(sliced.shape, [2, 70, 80])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册