diff --git a/lite/pylite/megenginelite/utils.py b/lite/pylite/megenginelite/utils.py index 670aa910b53d952c9279f74aa575fa173823b301..be0e35c633303e6b6c9632cbe3966a0d4ea08d61 100644 --- a/lite/pylite/megenginelite/utils.py +++ b/lite/pylite/megenginelite/utils.py @@ -7,6 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import threading +import warnings import numpy as np @@ -51,15 +52,24 @@ class TensorBatchCollector: ) def collect_id(self, array, batch_id): + # get the batch index + with self._mutex: + if batch_id in self._free_list: + self._free_list.remove(batch_id) + else: + warnings.warn( + "batch {} has been collected, please call free before collected it again.".format( + batch_id + ) + ) + self._collect_with_id(array, batch_id) + + def _collect_with_id(self, array, batch_id): if isinstance(array, np.ndarray): shape = array.shape assert list(shape) == self.shape[1:] in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)] assert in_dtype == self.dtype - # get the batch index - with self._mutex: - if batch_id in self._free_list: - self._free_list.remove(batch_id) # get the subtensor subtensor = self._tensor.slice([batch_id], [batch_id + 1]) if subtensor.device_type == LiteDeviceType.LITE_CPU: @@ -77,10 +87,6 @@ class TensorBatchCollector: assert list(shape) == self.shape[1:] in_dtype = array.layout.data_type assert in_dtype == self.dtype - # get the batch index - with self._mutex: - if batch_id in self._free_list: - self._free_list.remove(batch_id) # get the subtensor subtensor = self._tensor.slice([batch_id], [batch_id + 1]) subtensor.copy_from(array) @@ -90,9 +96,12 @@ class TensorBatchCollector: def collect(self, array): with self._mutex: if len(self._free_list) == 0: + warnings.warn( + "all batch has been collected, please call free before collect again." + ) return -1 idx = self._free_list.pop(0) - return self.collect_id(array, idx) + return self._collect_with_id(array, idx) def collect_by_ctypes(self, data, length): """ @@ -115,6 +124,12 @@ class TensorBatchCollector: def free(self, indexes): with self._mutex: + for i in indexes: + if i in self._free_list: + warnings.warn( + "batch id {} has not collected before free it.".format(i) + ) + self._free_list.remove(i) self._free_list.extend(indexes) def get(self):