Commit d5ef7923 authored by Megvii Engine Team

perf(lite): optimized lite tensor get data by share

GitOrigin-RevId: 62e48ca53926514b87df35e5e08df904949518be
Parent ce9ad07a
@@ -173,7 +173,7 @@ LITE_API int LITE_register_memory_pair(
  * clear the physical and virtual address pair in mge.
  */
 LITE_API int LITE_clear_memory_pair(
-        void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend);
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend);
 #ifdef __cplusplus
 }
...
@@ -198,7 +198,7 @@ int LITE_register_memory_pair(
 }
 int LITE_clear_memory_pair(
-        void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) {
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
     LITE_CAPI_BEGIN();
     lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
     LITE_CAPI_END();
...
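Both hunks swap the first two parameters so the C declaration and definition agree with the `lite::clear_memory_pair(vir_ptr, phy_ptr, ...)` call in the body. A minimal sketch of a caller updated for the new order, via ctypes; the shared-library name and the raw enum values are assumptions, not taken from this diff:

```python
# Minimal caller sketch; "liblite_shared.so" and the integer values
# standing in for LiteDeviceType/LiteBackend are assumptions.
import ctypes

lib = ctypes.CDLL("liblite_shared.so")
lib.LITE_clear_memory_pair.restype = ctypes.c_int
lib.LITE_clear_memory_pair.argtypes = [
    ctypes.c_void_p,  # vir_ptr -- now first, matching lite::clear_memory_pair
    ctypes.c_void_p,  # phy_ptr -- now second
    ctypes.c_int,     # LiteDeviceType
    ctypes.c_int,     # LiteBackend
]

def clear_pair(vir_ptr, phy_ptr, device=0, backend=0):
    ret = lib.LITE_clear_memory_pair(vir_ptr, phy_ptr, device, backend)
    assert ret == 0, "LITE_clear_memory_pair failed"
```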
@@ -225,6 +225,7 @@ class LiteTensor(object):
         tensor_desc.device_id = device_id
         tensor_desc.is_pinned_host = is_pinned_host
         self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
+        self.update()

     def __del__(self):
         self._api.LITE_destroy_tensor(self._tensor)
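With `self.update()` now called from `__init__`, a freshly constructed tensor already carries its cached layout and the precomputed ctypes ndarray type (added to `update()` in the next hunk), so `get_data_by_share` works without a manual refresh. A hypothetical usage sketch, assuming the `megenginelite` package exposes the pylite classes shown in this diff:

```python
# Hypothetical usage sketch; the package name `megenginelite` and exact
# constructor signatures are assumptions, not taken from this diff.
import numpy as np
from megenginelite import LiteLayout, LiteTensor

t = LiteTensor(LiteLayout([4, 16], "float32"))
t.set_data_by_copy(np.arange(64, dtype=np.float32))

# No explicit t.update() needed: __init__ now calls it, so the cached
# np_array_type is ready before the first get_data_by_share().
view = t.get_data_by_share()
view[0, 0] = 1.0  # writes through to the tensor's own memory
```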
@@ -318,6 +319,11 @@ class LiteTensor(object):
             self._device_type = device_type
         self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))

+        c_types = _lite_dtypes_to_ctype[self._layout.data_type]
+        self.np_array_type = np.ctypeslib._ctype_ndarray(
+            c_types, list(self._layout.shapes)[0 : self._layout.ndim]
+        )
+
     def copy_from(self, src_tensor):
         """
         copy memory from the src_tensor
@@ -447,15 +453,11 @@ class LiteTensor(object):
         return the numpy array; be careful, the data in numpy is only
         valid until the tensor memory is written again, e.g. when
         LiteNetwork runs forward the next time.
         """
-        assert self.is_continue, "get_data_by_share can only apply in continue tensor."
-        assert (
-            self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
-        ), "get_data_by_share can only apply in CPU tensor or cpu pinned tensor."
-        memory = self.get_ctypes_memory()
-        c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
-        pnt = cast(memory, POINTER(c_type))
-        return np.ctypeslib.as_array(pnt, self._layout.shapes)
+        buffer = c_void_p()
+        self._api.LITE_get_tensor_memory(self._tensor, byref(buffer))
+        buffer = self.np_array_type.from_address(buffer.value)
+        return np.ctypeslib.as_array(buffer)

     def to_numpy(self):
         """
...
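This is the optimization the commit title refers to: the old path looked up the ctype, cast the pointer, and rebuilt the shape on every call, while the new path reuses the `np_array_type` precomputed once in `update()` and only reinterprets the raw address. A standalone sketch of that technique; the shape and buffer below are illustrative, not from the diff:

```python
# Standalone sketch of the precomputed-ndarray-type technique used above.
import ctypes
import numpy as np

shape = (4, 16)  # illustrative; update() derives this from the tensor layout

# Done once (the diff does this in LiteTensor.update()); equivalent to the
# private helper np.ctypeslib._ctype_ndarray(ctypes.c_float, shape).
np_array_type = (ctypes.c_float * shape[1]) * shape[0]

buf = np_array_type()          # stand-in for the tensor's own memory
addr = ctypes.addressof(buf)   # what LITE_get_tensor_memory would return

# Done per call: just reinterpret the address -- no dtype lookup, no cast,
# no shape tuple rebuilt.
view = np.ctypeslib.as_array(np_array_type.from_address(addr))
view[0, 0] = 3.0
assert buf[0][0] == 3.0        # the numpy view shares the same memory
```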