Commit 2ce050bb authored by: Megvii Engine Team

fix(lite): add warning to TensorBatchCollector

GitOrigin-RevId: ba45e6a5a48a2ea3c5a0554f6b4e63665954150d
Parent ce119ef5
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import threading
+import warnings
 import numpy as np
@@ -51,15 +52,24 @@ class TensorBatchCollector:
         )

     def collect_id(self, array, batch_id):
+        # get the batch index
+        with self._mutex:
+            if batch_id in self._free_list:
+                self._free_list.remove(batch_id)
+            else:
+                warnings.warn(
+                    "batch {} has been collected, please call free before collecting it again.".format(
+                        batch_id
+                    )
+                )
+        self._collect_with_id(array, batch_id)
+
+    def _collect_with_id(self, array, batch_id):
         if isinstance(array, np.ndarray):
             shape = array.shape
             assert list(shape) == self.shape[1:]
             in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
             assert in_dtype == self.dtype
-            # get the batch index
-            with self._mutex:
-                if batch_id in self._free_list:
-                    self._free_list.remove(batch_id)
             # get the subtensor
             subtensor = self._tensor.slice([batch_id], [batch_id + 1])
             if subtensor.device_type == LiteDeviceType.LITE_CPU:
@@ -77,10 +87,6 @@
             assert list(shape) == self.shape[1:]
             in_dtype = array.layout.data_type
             assert in_dtype == self.dtype
-            # get the batch index
-            with self._mutex:
-                if batch_id in self._free_list:
-                    self._free_list.remove(batch_id)
             # get the subtensor
             subtensor = self._tensor.slice([batch_id], [batch_id + 1])
             subtensor.copy_from(array)
@@ -90,9 +96,12 @@
     def collect(self, array):
         with self._mutex:
             if len(self._free_list) == 0:
+                warnings.warn(
+                    "all batches have been collected, please call free before collecting again."
+                )
                 return -1
             idx = self._free_list.pop(0)
-        return self.collect_id(array, idx)
+        return self._collect_with_id(array, idx)

     def collect_by_ctypes(self, data, length):
         """
@@ -115,6 +124,12 @@
     def free(self, indexes):
         with self._mutex:
+            for i in indexes:
+                if i in self._free_list:
+                    warnings.warn(
+                        "batch id {} has not been collected before freeing it.".format(i)
+                    )
+                    self._free_list.remove(i)
             self._free_list.extend(indexes)

     def get(self):
...
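For context, here is a minimal usage sketch of how the new warnings would surface. Only the collect, collect_id, free, and get calls appear in the patch above; the import path, the constructor arguments (shape list, dtype, device_type), and the LiteDataType/LiteDeviceType names are assumptions about the megenginelite API, not taken from this commit.

```python
# Hypothetical usage sketch; constructor and import path are assumed,
# only collect/collect_id/free/get come from the patch above.
import numpy as np
from megenginelite import LiteDataType, LiteDeviceType, TensorBatchCollector

# Assumed constructor: a collector holding 4 batches, each of shape (1, 8).
collector = TensorBatchCollector(
    [4, 1, 8], dtype=LiteDataType.LITE_UINT8, device_type=LiteDeviceType.LITE_CPU
)

data = np.ones((1, 8), dtype=np.uint8)

collector.collect_id(data, 0)   # fills batch 0
collector.collect_id(data, 0)   # batch 0 already collected -> warns (this patch)

for _ in range(3):
    collector.collect(data)     # fills the remaining free batches 1..3
ret = collector.collect(data)   # free list empty -> warns and returns -1 (this patch)

collector.free([0])             # returns batch 0 to the free list
collector.free([0])             # batch 0 was not collected -> warns (this patch)

whole = collector.get()         # the full (4, 1, 8) batched tensor
```

If the constructor in your megenginelite version differs, only the warning-related calls matter for this patch: collecting an already-collected batch id, collecting when no free batch remains, and freeing a batch that was never collected.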