提交 d5ef7923 编写于 作者: M Megvii Engine Team

perf(lite): optimized lite tensor get data by share

GitOrigin-RevId: 62e48ca53926514b87df35e5e08df904949518be
上级 ce9ad07a
......@@ -173,7 +173,7 @@ LITE_API int LITE_register_memory_pair(
* clear the physical and virtual address pair in mge.
*/
LITE_API int LITE_clear_memory_pair(
void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend);
void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend);
#ifdef __cplusplus
}
......
......@@ -198,7 +198,7 @@ int LITE_register_memory_pair(
}
int LITE_clear_memory_pair(
void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) {
void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
LITE_CAPI_BEGIN();
lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
LITE_CAPI_END();
......
......@@ -225,6 +225,7 @@ class LiteTensor(object):
tensor_desc.device_id = device_id
tensor_desc.is_pinned_host = is_pinned_host
self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
self.update()
def __del__(self):
self._api.LITE_destroy_tensor(self._tensor)
......@@ -318,6 +319,11 @@ class LiteTensor(object):
self._device_type = device_type
self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
c_types = _lite_dtypes_to_ctype[self._layout.data_type]
self.np_array_type = np.ctypeslib._ctype_ndarray(
c_types, list(self._layout.shapes)[0 : self._layout.ndim]
)
def copy_from(self, src_tensor):
"""
copy memory form the src_tensor
......@@ -447,15 +453,11 @@ class LiteTensor(object):
return the numpy arrray, be careful, the data in numpy is valid before
the tensor memory is write again, such as LiteNetwok forward next time.
"""
assert self.is_continue, "get_data_by_share can only apply in continue tensor."
assert (
self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
), "get_data_by_share can only apply in CPU tensor or cpu pinned tensor."
memory = self.get_ctypes_memory()
c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
pnt = cast(memory, POINTER(c_type))
return np.ctypeslib.as_array(pnt, self._layout.shapes)
buffer = c_void_p()
self._api.LITE_get_tensor_memory(self._tensor, byref(buffer))
buffer = self.np_array_type.from_address(buffer.value)
return np.ctypeslib.as_array(buffer)
def to_numpy(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册