提交 d52ba79d 编写于 作者: M Megvii Engine Team

fix(lite): support set data by copy on device tensor

GitOrigin-RevId: 88b7f73d364d63bb3cad95ec063f283048895f94
上级 275f12c9
......@@ -407,7 +407,7 @@ class LiteTensor(object):
def set_data_by_copy(self, data, data_length=0, layout=None):
"""
copy the data to the tensor
copy the data to the tensor, the memory of the tensor must be contiguous
param data: the data to copy to tensor, it should be list,
numpy.ndarray or ctypes with length
"""
......@@ -415,37 +415,34 @@ class LiteTensor(object):
self.layout = layout
assert self.is_continue, "set_data_by_copy can only apply in continue tensor."
assert (
self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
), "set_data_by_copy can only apply in cpu tensor or pinned tensor."
c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
tensor_memory = c_void_p()
cpu_tensor = LiteTensor(self._layout)
tensor_length = self.nbytes
if type(data) == list:
length = len(data)
self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory))
tensor_length = self.nbytes
assert (
length * sizeof(c_type) <= tensor_length
), "the length of input data to set to the tensor is too large."
arr = (c_type * length)(*data)
memmove(tensor_memory, arr, sizeof(c_type) * length)
cdata = (c_type * length)(*data)
self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, tensor_length)
self.copy_from(cpu_tensor)
elif type(data) == np.ndarray:
if self.nbytes != data.nbytes:
self.layout = LiteLayout(data.shape, data.dtype)
arr = data.ctypes.data_as(POINTER(c_type))
self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory))
assert self.nbytes == data.nbytes
memmove(tensor_memory, arr, self.nbytes)
self.layout = LiteLayout(data.shape, data.dtype)
cpu_tensor.layout = LiteLayout(data.shape, data.dtype)
cdata = data.ctypes.data_as(POINTER(c_type))
self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, self.nbytes)
self.copy_from(cpu_tensor)
else:
assert (
data_length == self.nbytes or layout is not None
), "when input data is ctypes, the length of input data or layout must set"
self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory))
memmove(tensor_memory, data, data_length)
self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, data, tensor_length)
self.copy_from(cpu_tensor)
def get_data_by_share(self):
"""
......@@ -454,6 +451,7 @@ class LiteTensor(object):
the tensor memory is write again, such as LiteNetwok forward next time.
"""
self.update()
buffer = c_void_p()
self._api.LITE_get_tensor_memory(self._tensor, byref(buffer))
buffer = self.np_array_type.from_address(buffer.value)
......
......@@ -323,3 +323,27 @@ def test_tensor_get_memory_by_share():
tensor.set_data_by_copy(arr)
assert test_data[1][18] == 5
assert test_data[3][7] == 345
@require_cuda
def test_tensor_set_data_device():
    """Exercise ``set_data_by_copy`` on a CUDA ``LiteTensor``.

    A 2x16 int8 tensor is filled from a Python list, then from a numpy
    array of ones, then from a list again. After every copy the contents
    are read back with ``to_numpy`` and compared element by element.
    """
    rows, cols = 2, 16
    count = rows * cols

    tensor = LiteTensor(
        LiteLayout([rows, cols], "int8"), device_type=LiteDeviceType.LITE_CUDA
    )
    assert tensor.nbytes == 2 * 16

    def check(expected_of):
        # Read the tensor back and compare each element with the value
        # expected for its flat (row-major) index.
        snapshot = tensor.to_numpy()
        for flat in range(count):
            row, col = divmod(flat, cols)
            assert snapshot[row][col] == expected_of(flat)

    # 1) Copy from a plain Python list holding 0..31.
    tensor.set_data_by_copy(list(range(count)))
    check(lambda flat: flat)

    # 2) Copy from a numpy array with the same shape and dtype.
    ones = np.ones([rows, cols], "int8")
    tensor.set_data_by_copy(ones)
    check(lambda flat: 1)

    # 3) Copy from a list again to confirm the numpy data is overwritten.
    tensor.set_data_by_copy(list(range(count)))
    check(lambda flat: flat)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册