未验证 提交 efe6e284 编写于 作者: F Feiyu Chan 提交者: GitHub

fix strided_slice_op's GetExpectedKernelType (#28192)

* fix strided_slice_op's GetExpectedKernelType when input tensor is at CUDAPinnedPlace

* add unittest for tensors in cuda pinned place

* skip test for cuda pinned place on cpu machines
上级 271ee58f
...@@ -154,9 +154,14 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -154,9 +154,14 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
// NOTE: cuda pinned tensor need to copy its data to target place
auto in_tensor = ctx.Input<Tensor>("Input");
if (platform::is_cuda_pinned_place(in_tensor->place())) {
return framework::OpKernelType(in_tensor->type(), ctx.device_context());
}
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.Input<Tensor>("Input")->place()); in_tensor->place());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
......
...@@ -511,6 +511,17 @@ class TestStridedSliceAPI(unittest.TestCase): ...@@ -511,6 +511,17 @@ class TestStridedSliceAPI(unittest.TestCase):
x, axes=axes, starts=starts, ends=ends, strides=strides_1) x, axes=axes, starts=starts, ends=ends, strides=strides_1)
assert sliced_1.shape == (3, 2, 2, 2) assert sliced_1.shape == (3, 2, 2, 2)
@unittest.skipIf(not paddle.is_compiled_with_cuda(),
"Cannot use CUDAPinnedPlace in CPU only version")
def test_cuda_pinned_place(self):
with paddle.fluid.dygraph.guard():
x = paddle.to_tensor(
np.random.randn(2, 10), place=paddle.CUDAPinnedPlace())
self.assertTrue(x.place.is_cuda_pinned_place())
y = x[:, ::2]
self.assertFalse(x.place.is_cuda_pinned_place())
self.assertFalse(y.place.is_cuda_pinned_place())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册