diff --git a/lite/lite-c/include/lite-c/global_c.h b/lite/lite-c/include/lite-c/global_c.h index 42eed593f3230d806ab9b9fa22a8b2f3f9878630..b270e5900ff4bdae383db18d53da89799627f03a 100644 --- a/lite/lite-c/include/lite-c/global_c.h +++ b/lite/lite-c/include/lite-c/global_c.h @@ -173,7 +173,7 @@ LITE_API int LITE_register_memory_pair( * clear the physical and virtual address pair in mge. */ LITE_API int LITE_clear_memory_pair( - void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend); + void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend); #ifdef __cplusplus } diff --git a/lite/lite-c/src/global.cpp b/lite/lite-c/src/global.cpp index 8be2644ce2d33ed71a47efa7646d7110eda0fa46..ff14113c937c7d96158c63322adc50b6e69b0837 100644 --- a/lite/lite-c/src/global.cpp +++ b/lite/lite-c/src/global.cpp @@ -198,7 +198,7 @@ int LITE_register_memory_pair( } int LITE_clear_memory_pair( - void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) { + void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) { LITE_CAPI_BEGIN(); lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend); LITE_CAPI_END(); diff --git a/lite/pylite/megenginelite/tensor.py b/lite/pylite/megenginelite/tensor.py index 188a486ee028d1d4116b3b898a779ab5d72fdac9..897ac44b5604102c446c531226fcadc376cf72a0 100644 --- a/lite/pylite/megenginelite/tensor.py +++ b/lite/pylite/megenginelite/tensor.py @@ -225,6 +225,7 @@ class LiteTensor(object): tensor_desc.device_id = device_id tensor_desc.is_pinned_host = is_pinned_host self._api.LITE_make_tensor(tensor_desc, byref(self._tensor)) + self.update() def __del__(self): self._api.LITE_destroy_tensor(self._tensor) @@ -318,6 +319,11 @@ class LiteTensor(object): self._device_type = device_type self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout)) + c_types = _lite_dtypes_to_ctype[self._layout.data_type] + self.np_array_type = np.ctypeslib._ctype_ndarray( + c_types, list(self._layout.shapes)[0 : self._layout.ndim] + ) + def copy_from(self, src_tensor): """ copy memory form the src_tensor @@ -447,15 +453,11 @@ class LiteTensor(object): return the numpy arrray, be careful, the data in numpy is valid before the tensor memory is write again, such as LiteNetwok forward next time. """ - assert self.is_continue, "get_data_by_share can only apply in continue tensor." - assert ( - self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU - ), "get_data_by_share can only apply in CPU tensor or cpu pinned tensor." - memory = self.get_ctypes_memory() - c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)] - pnt = cast(memory, POINTER(c_type)) - return np.ctypeslib.as_array(pnt, self._layout.shapes) + buffer = c_void_p() + self._api.LITE_get_tensor_memory(self._tensor, byref(buffer)) + buffer = self.np_array_type.from_address(buffer.value) + return np.ctypeslib.as_array(buffer) def to_numpy(self): """