# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import threading
import warnings

import numpy as np

from .base import *
from .struct import *
from .tensor import *


class TensorBatchCollector:
    """
    A utility that collects sub-tensors one by one into a contiguous batch tensor.
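
    A rough usage sketch (the shape, dtype, and CPU device below are
    illustrative choices, not defaults of this class):

        collector = TensorBatchCollector(
            [4, 8, 8],
            dtype=LiteDataType.LITE_INT8,
            device_type=LiteDeviceType.LITE_CPU,
        )
        for i in range(4):
            collector.collect(np.ones([8, 8], "int8") * i)
        batch = collector.to_numpy()  # numpy array of shape (4, 8, 8)
        collector.free(range(4))      # release every slot for reuse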
    """

    def __init__(
        self,
        shape,
        dtype=LiteDataType.LITE_INT8,
        device_type=LiteDeviceType.LITE_CUDA,
        device_id=0,
        is_pinned_host=False,
        tensor=None,
    ):
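        """
        :param shape: shape of the whole batch tensor; shape[0] is the batch size.
        :param dtype: data type of the tensor, as a LiteDataType.
        :param device_type: device type where the batch tensor resides.
        :param device_id: id of the target device.
        :param is_pinned_host: whether the tensor memory is pinned host memory.
        :param tensor: optional pre-allocated LiteTensor to collect into; its
            shape must match `shape`.
        """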
        self._mutex = threading.Lock()
        self.dev_type = device_type
        self.is_pinned_host = is_pinned_host
        self.dev_id = device_id
        self.shape = shape
        self.dtype = LiteLayout(dtype=dtype).data_type
        self._free_list = list(range(self.shape[0]))

        if tensor is not None:
            assert list(tensor.layout.shapes[0 : tensor.layout.ndim]) == list(
                shape
            ), "The shape of the tensor passed to TensorBatchCollector does not match `shape`."
            self._tensor = tensor
            self.dtype = tensor.layout.data_type
            self.dev_type = tensor.device_type
            self.dev_id = tensor.device_id
        else:
            self._tensor = LiteTensor(
                LiteLayout(shape, dtype), device_type, device_id, is_pinned_host
            )

    def collect_id(self, array, batch_id):
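        """
        Collect `array` into the slot `batch_id`, warning if that slot was
        already collected and has not been freed.
        """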
        # mark the slot as used before copying into it
        with self._mutex:
            if batch_id in self._free_list:
                self._free_list.remove(batch_id)
            else:
                warnings.warn(
                    "batch {} has already been collected, please call free before collecting it again.".format(
                        batch_id
                    )
                )
        self._collect_with_id(array, batch_id)

    def _collect_with_id(self, array, batch_id):
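        """
        Copy `array` (a numpy ndarray or a LiteTensor) into the sub-tensor at
        `batch_id` and return the index; a numpy array bound for a device
        tensor is staged through a pinned host tensor first.
        """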
        if isinstance(array, np.ndarray):
            shape = array.shape
            assert list(shape) == list(self.shape[1:])
            in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
            assert in_dtype == self.dtype
            # get the subtensor
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            if subtensor.device_type == LiteDeviceType.LITE_CPU:
                subtensor.set_data_by_copy(array)
            else:
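                # the batch tensor lives on a device (e.g. CUDA): share the
                # numpy data into a pinned host tensor, then copy host -> device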
                pinned_tensor = LiteTensor(
                    subtensor.layout, self.dev_type, self.dev_id, True
                )
                pinned_tensor.set_data_by_share(array)
                subtensor.copy_from(pinned_tensor)
        else:
            assert isinstance(array, LiteTensor)
            ndim = array.layout.ndim
            shape = list(array.layout.shapes)[0:ndim]
            assert list(shape) == list(self.shape[1:])
            in_dtype = array.layout.data_type
            assert in_dtype == self.dtype
            # get the subtensor
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            subtensor.copy_from(array)

        return batch_id

    def collect(self, array):
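        """
        Collect `array` into the next free batch slot and return its index, or
        return -1 (with a warning) when every slot is already taken.
        """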
        with self._mutex:
            if len(self._free_list) == 0:
                warnings.warn(
                    "all batches have been collected, please call free before collecting again."
                )
                return -1
            idx = self._free_list.pop(0)
        return self._collect_with_id(array, idx)

    def collect_by_ctypes(self, data, length):
        """
        collect with ctypes data input
        """
        with self._mutex:
            if len(self._free_list) == 0:
                return -1
            idx = self._free_list.pop(0)
        # get the subtensor
        subtensor = self._tensor.slice([idx], [idx + 1])
        if subtensor.device_type == LiteDeviceType.LITE_CPU:
            subtensor.set_data_by_copy(data, length)
        else:
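            # stage the ctypes buffer in pinned host memory before copying it
            # to the device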
            pinned_tensor = LiteTensor(
                subtensor.layout, self.dev_type, self.dev_id, True
            )
            pinned_tensor.set_data_by_share(data, length)
            subtensor.copy_from(pinned_tensor)
        # return the filled slot index, consistent with collect()
        return idx

    def free(self, indexes):
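        """
        Return the slots in `indexes` to the free list so they can be
        collected again; warns on any id that was never collected.
        """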
        with self._mutex:
            for i in indexes:
                if i in self._free_list:
                    warnings.warn(
                        "batch id {} was not collected before freeing it.".format(i)
                    )
                    self._free_list.remove(i)
            self._free_list.extend(indexes)

    def get(self):
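        """Return the underlying batch LiteTensor."""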
        return self._tensor

    def to_numpy(self):
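        """Copy the collected batch tensor out as a numpy array."""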
        return self._tensor.to_numpy()