提交 c2b9d542 编写于 作者: M Megvii Engine Team

fix(lite): fix get_io_tensor possibly returning an incorrect device type for host (pinned) tensors

GitOrigin-RevId: db0c8c071239bc7f3d7afaf97203426550d06402
上级 2f06d580
......@@ -140,7 +140,8 @@ class TestNetwork(TestShuffleNetCuda):
network.load(self.model_path)
input_tensor = network.get_io_tensor("data")
assert input_tensor.device_type == LiteDeviceType.LITE_CPU
# the device type is cuda, but by default, the memory type is pinned memory on the host side, which is not on cuda.
assert input_tensor.device_type == LiteDeviceType.LITE_CUDA
self.do_forward(network)
......
......@@ -102,7 +102,7 @@ TensorImplDft::TensorImplDft(
LiteDeviceType TensorImplDft::get_device_type() const {
if (is_host()) {
return LiteDeviceType::LITE_CPU;
return get_device_from_locator(m_host_tensor->comp_node().locator());
} else {
return get_device_from_locator(m_dev_tensor->comp_node().locator());
}
......
......@@ -571,6 +571,17 @@ TEST(TestTensor, ConcatDevice) {
check(1);
check(2);
}
TEST(TestTensor, CudaOutputDevice) {
    // A CUDA tensor whose storage is pinned host memory: the comp node lives
    // on the CUDA device even though the bytes reside on the host side.
    bool pinned_host = true;
    Layout pinned_layout{{1, 4}, 2};
    Tensor cuda_tensor(LiteDeviceType::LITE_CUDA, pinned_layout, pinned_host);

    // When is_pinned_host is true, update_from_implement() must derive the
    // device type via get_device_from_locator(m_host_tensor->comp_node().locator())
    // instead of unconditionally reporting LITE_CPU for host-side storage.
    cuda_tensor.update_from_implement();
    ASSERT_EQ(cuda_tensor.get_device_type(), LiteDeviceType::LITE_CUDA);
}
#endif
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册