utils.py
# -*- coding: utf-8 -*-
import threading
import warnings

import numpy as np

from .base import *
from .struct import *
from .tensor import *


class TensorBatchCollector:
    """
    A tensor utility used to collect many single-batch tensors into one
    multi-batch tensor. Once the multi-batch tensor has been fully collected,
    the result tensor can be fetched and sent to the model input for
    forwarding.

    When collecting a single-batch tensor, it does not need to be on the same
    device_type and device_id as the result tensor, but its dtype must match
    and its shape must match the result tensor except for the highest
    dimension.

    Args:
        shape: the shape of the multi-batch result tensor, i.e. the shape the
            result tensor has after collection.
        dtype(LiteDataType): the data type of both the single-batch tensors and
            the result tensor, default LiteDataType.LITE_INT8.
        device_type(LiteDeviceType): the device type on which the result tensor
            will be allocated, default LiteDeviceType.LITE_CUDA.
        device_id: the device id on which the result tensor will be allocated,
            default 0.
        is_pinned_host: whether the memory is pinned memory (refer to CUDA
            pinned memory), default False.
        tensor(LiteTensor): the result tensor. The user can also create the
            multi-batch tensor first and construct the TensorBatchCollector
            from it (see the second example below). If tensor is not None, all
            the members, such as shape, dtype, device_type, device_id and
            is_pinned_host, are taken from the tensor; if tensor is None, the
            result tensor is created by the TensorBatchCollector. Default None.

    Note:
        When collecting a tensor, the single-batch tensor or array shape must
        match the result tensor shape except for the batch size dimension (the
        highest dimension).

    Examples:

        .. code-block:: python

            import numpy as np
            batch_tensor = TensorBatchCollector([4, 8, 8])
            arr = np.ones([8, 8], "int8")
            for i in range(4):
                batch_tensor.collect(arr)
                arr += 1
            data = batch_tensor.to_numpy()
            assert data.shape[0] == 4
            assert data.shape[1] == 8
            assert data.shape[2] == 8
            for i in range(4):
                for j in range(64):
                    assert data[i][j // 8][j % 8] == i + 1
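
        A minimal sketch, assuming a CPU device, of constructing the collector
        from an existing result tensor (the second usage described for the
        tensor argument above):

        .. code-block:: python

            import numpy as np
            layout = LiteLayout([4, 8, 8], LiteDataType.LITE_INT8)
            result_tensor = LiteTensor(layout, LiteDeviceType.LITE_CPU, 0, False)
            batch_tensor = TensorBatchCollector([4, 8, 8], tensor=result_tensor)
            batch_tensor.collect(np.ones([8, 8], "int8"))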

    """

    def __init__(
        self,
        shape,
        dtype=LiteDataType.LITE_INT8,
        device_type=LiteDeviceType.LITE_CUDA,
        device_id=0,
        is_pinned_host=False,
        tensor=None,
    ):
        self._mutex = threading.Lock()
        self.dev_type = device_type
        self.is_pinned_host = is_pinned_host
        self.dev_id = device_id
        self.shape = shape
        self.dtype = LiteLayout(dtype=dtype).data_type
        self._free_list = list(range(self.shape[0]))

        if tensor is not None:
            assert (
                list(tensor.layout.shapes[0 : tensor.layout.ndim]) == list(shape)
            ), "The tensor set to TensorBatchCollector does not match the given shape."
            self._tensor = tensor
            self.dtype = tensor.layout.data_type
            self.dev_type = tensor.device_type
            self.dev_id = tensor.device_id
            self.is_pinned_host = tensor.is_pinned_host
        else:
            self._tensor = LiteTensor(
                LiteLayout(shape, dtype), device_type, device_id, is_pinned_host
            )

    def collect_id(self, array, batch_id):
        """
        Collect a single batch from an array and store its data at the
        specified batch_id in the result tensor.

        Args:
            array: the input array, either a LiteTensor or a numpy ndarray; its
                shape must match the result tensor shape except for the highest
                dimension.
            batch_id: the batch id at which the array data is stored in the
                result tensor; if this batch_id has already been collected, a
                warning is issued.
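
        A short sketch, assuming a CPU device, of filling batches out of order
        by explicit batch id:

        .. code-block:: python

            import numpy as np
            batch_tensor = TensorBatchCollector(
                [2, 4], device_type=LiteDeviceType.LITE_CPU
            )
            batch_tensor.collect_id(np.ones([4], "int8"), 1)
            batch_tensor.collect_id(np.zeros([4], "int8"), 0)
            data = batch_tensor.to_numpy()  # data[0] is all 0, data[1] is all 1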
        """
        # get the batch index
        with self._mutex:
            if batch_id in self._free_list:
                self._free_list.remove(batch_id)
            else:
                warnings.warn(
                    "batch {} has been collected, please call free before collected it again.".format(
                        batch_id
                    )
                )
        self._collect_with_id(array, batch_id)

    def _collect_with_id(self, array, batch_id):
        if isinstance(array, np.ndarray):
            shape = array.shape
            assert list(shape) == self.shape[1:]
            in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
            assert in_dtype == self.dtype
            # get the subtensor
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            if subtensor.device_type == LiteDeviceType.LITE_CPU:
                subtensor.set_data_by_copy(array)
            else:
                # stage the host data in a pinned tensor, then copy it into the
                # device-resident subtensor
                pinned_tensor = LiteTensor(
                    subtensor.layout, self.dev_type, self.dev_id, True
                )
                pinned_tensor.set_data_by_share(array)
                subtensor.copy_from(pinned_tensor)
        else:
            assert isinstance(array, LiteTensor)
            ndim = array.layout.ndim
            shape = list(array.layout.shapes)[0:ndim]
            assert list(shape) == self.shape[1:]
            in_dtype = array.layout.data_type
            assert in_dtype == self.dtype
            # get the subtensor
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            subtensor.copy_from(array)

        return batch_id

    def collect(self, array):
        """
        Collect a single batch from an array and store its data at an empty
        batch slot; the empty batch is the first batch id in the free list.
        Returns the batch id the data was stored to, or -1 if all batches have
        already been collected.

        Args:
            array: the input array, either a LiteTensor or a numpy ndarray; its
                shape must match the result tensor shape except for the highest
                dimension.
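
        A minimal sketch, assuming a CPU device, showing that collect returns
        the batch id it filled, and -1 once every batch has been collected:

        .. code-block:: python

            import numpy as np
            batch_tensor = TensorBatchCollector(
                [2, 8], device_type=LiteDeviceType.LITE_CPU
            )
            arr = np.ones([8], "int8")
            assert batch_tensor.collect(arr) == 0
            assert batch_tensor.collect(arr) == 1
            assert batch_tensor.collect(arr) == -1  # full, a warning is issued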
        """
        with self._mutex:
            if len(self._free_list) == 0:
                warnings.warn(
                    "all batches have been collected, please call free before collecting again."
                )
                return -1
            idx = self._free_list.pop(0)
        return self._collect_with_id(array, idx)

    def collect_by_ctypes(self, data, length):
        """
        Collect a single batch from a ctypes memory buffer and store its data
        at an empty batch slot; the empty batch is the first batch id in the
        free list. Returns -1 if all batches have already been collected.

        Args:
            data: a ctypes memory buffer holding one batch of data; its layout
                must match the result tensor except for the highest dimension.
            length: the length of the memory buffer in bytes.
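
        A minimal sketch, assuming a CPU device and a ctypes pointer obtained
        from a numpy buffer:

        .. code-block:: python

            import ctypes
            import numpy as np
            batch_tensor = TensorBatchCollector(
                [2, 8], device_type=LiteDeviceType.LITE_CPU
            )
            arr = np.ones([8], "int8")
            data = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int8))
            batch_tensor.collect_by_ctypes(data, arr.nbytes)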
        """
        with self._mutex:
            if len(self._free_list) == 0:
                return -1
            idx = self._free_list.pop(0)
        # get the subtensor
        subtensor = self._tensor.slice([idx], [idx + 1])
        if subtensor.device_type == LiteDeviceType.LITE_CPU:
            subtensor.set_data_by_copy(data, length)
        else:
            # stage the host data in a pinned tensor, then copy it into the
            # device-resident subtensor
            pinned_tensor = LiteTensor(
                subtensor.layout, self.dev_type, self.dev_id, True
            )
            pinned_tensor.set_data_by_share(data, length)
            subtensor.copy_from(pinned_tensor)

    def free(self, indexes):
        """
        Free the given batch ids; after a batch id is freed, it can be
        collected again without a warning.

        Args:
            indexes: a list of batch ids to be freed.
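
        A minimal sketch, assuming a CPU device, of reusing a batch slot after
        freeing it:

        .. code-block:: python

            import numpy as np
            batch_tensor = TensorBatchCollector(
                [2, 8], device_type=LiteDeviceType.LITE_CPU
            )
            arr = np.ones([8], "int8")
            batch_tensor.collect_id(arr, 0)
            batch_tensor.free([0])           # release batch 0
            batch_tensor.collect_id(arr, 0)  # no warning this time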
        """
        with self._mutex:
            for i in indexes:
                if i in self._free_list:
                    warnings.warn(
                        "batch id {} has not been collected before being freed.".format(i)
                    )
                    self._free_list.remove(i)
            self._free_list.extend(indexes)

    def get(self):
        """
        Get the result tensor after the collection has finished.
        """
        return self._tensor

    def to_numpy(self):
        """
        Convert the result tensor to a numpy ndarray
        """
        return self._tensor.to_numpy()