util.py 6.3 KB
Newer Older
1 2 3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
X
Xinghai Sun 已提交
4
import sys, time
5 6
from six import reraise
from tblib import Traceback
7 8
from multiprocessing import Manager, Process
import posix_ipc, mmap
9

10 11
import numpy as np

12

Z
zhxfl 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
def to_lodtensor(data, place):
    """Pack a batch of variable-length sequences into one LoDTensor.

    The sequences are concatenated along axis 0, cast to ``int64`` and
    reshaped to a column of shape ``(total_len, 1)``. The level-of-detail
    (LoD) entry is the list of cumulative sequence lengths, starting at 0.

    Args:
        data: Sequence of array-likes, one per sample. Each must support
            ``len()`` and be concatenable along axis 0.
        place: Passed through unchanged to ``LoDTensor.set``.

    Returns:
        The populated ``fluid.LoDTensor``.
    """
    # Cumulative offsets: lod[i] is where sequence i starts in the flat data.
    lod = [0]
    for seq in data:
        lod.append(lod[-1] + len(seq))

    # This module imports numpy as ``np``; the bare name ``numpy`` is unbound
    # here and would raise NameError.
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])

    # NOTE(review): ``fluid`` is not imported in the visible part of this
    # module -- confirm that it is bound (e.g. ``import paddle.fluid as
    # fluid``) before this function is called.
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res


def lodtensor_to_ndarray(lod_tensor):
    """Convert a LoDTensor to a float32 ndarray plus its LoD.

    Elements are read one at a time through ``get_float_element`` and
    written, in flat (row-major) order, into an array shaped like
    ``lod_tensor.get_dims()``.

    Args:
        lod_tensor: Object providing ``get_dims()``, ``get_float_element(i)``
            and ``lod()``.

    Returns:
        tuple: ``(ndarray, lod)`` where ``ndarray`` has dtype ``float32`` and
        ``lod`` is whatever ``lod_tensor.lod()`` returns.
    """
    dims = lod_tensor.get_dims()
    ret = np.zeros(shape=dims).astype('float32')
    # ``ret`` is freshly allocated and contiguous, so ravel() is a view and
    # writes through it land in ``ret``.
    flat = ret.ravel()
    # ``range`` works on both Python 2 and 3 (``xrange`` is Python 2 only);
    # ``np.prod`` replaces the ``np.product`` alias removed in NumPy 2.0.
    # int() also covers empty ``dims``, for which np.prod returns a float.
    for i in range(int(np.prod(dims))):
        flat[i] = lod_tensor.get_float_element(i)
    return ret, lod_tensor.lod()
38 39


40 41 42 43 44 45 46 47 48 49 50 51 52
def batch_to_ndarray(batch_samples, lod):
    """Stack per-sample features and labels into two dense batch arrays.

    Args:
        batch_samples: Sequence of samples. ``sample[0]`` is a 2-D feature
            array (frames x feature dim) and ``sample[1]`` holds the labels
            assigned to the matching rows of the label column.
        lod: Offset sequence whose last entry is used as the total number
            of rows to allocate.

    Returns:
        tuple: ``(batch_feature, batch_label)`` with dtypes ``float32`` and
        ``int64``; the label array has a single column.
    """
    feature_width = batch_samples[0][0].shape[1]
    total_rows = lod[-1]
    features = np.zeros((total_rows, feature_width), dtype="float32")
    labels = np.zeros((total_rows, 1), dtype="int64")

    # Copy each sample into its consecutive row range.
    row_begin = 0
    for sample in batch_samples:
        row_end = row_begin + sample[0].shape[0]
        features[row_begin:row_end, :] = sample[0]
        labels[row_begin:row_end, :] = sample[1]
        row_begin = row_end
    return (features, labels)


Y
Yibing Liu 已提交
53 54 55 56 57 58 59
def split_infer_result(infer_seq, lod):
    """Split a flat inference result back into per-sequence pieces.

    Args:
        infer_seq: Sliceable sequence holding the concatenated results.
        lod: Nested offsets; only ``lod[0]`` is used. Consecutive entries
            ``lod[0][i]`` and ``lod[0][i + 1]`` delimit the i-th piece.

    Returns:
        list: One slice of ``infer_seq`` per consecutive offset pair; empty
        when ``lod[0]`` has fewer than two entries.
    """
    offsets = lod[0]
    # ``range`` instead of the Python-2-only ``xrange`` so this also runs on
    # Python 3.
    return [
        infer_seq[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)
    ]


60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
class DaemonProcessGroup(object):
    """A fixed-size group of worker processes started as daemons.

    Args:
        proc_num (int): Number of worker processes to create.
        target (callable): Callable run by every worker.
        args (tuple): Positional arguments given to every worker's target.
    """

    def __init__(self, proc_num, target, args):
        self._proc_num = proc_num
        # ``range`` instead of the Python-2-only ``xrange`` so this also
        # runs on Python 3.
        self._workers = [
            Process(
                target=target, args=args) for _ in range(self._proc_num)
        ]

    def start_all(self):
        """Mark every worker as a daemon and start it."""
        for w in self._workers:
            # The daemon flag has to be set before start().
            w.daemon = True
            w.start()

    @property
    def proc_num(self):
        """int: Number of workers this group was created with."""
        return self._proc_num


class EpochEndSignal(object):
    """Empty marker type with no state or behavior of its own.

    NOTE(review): presumably handed between processes to mark the end of an
    epoch -- confirm against the callers.
    """


82 83 84 85
class CriticalException(Exception):
    """Exception type that ``suppress_complaints`` always re-raises.

    Wrappers built by ``suppress_complaints`` re-raise any exception that is
    an instance of this class, regardless of their ``verbose`` setting.
    """


86
class SharedNDArray(object):
    """SharedNDArray utilizes shared memory to avoid data serialization when
    a data object is shared among different processes. The `ndarray` can be
    reconstructed when the memory address, shape and dtype are provided.

    Args:
        name (str): Address name of shared memory.
        whether_verify (bool): Whether to validate the writing operation.
    """

    def __init__(self, name, whether_verify=False):
        self._name = name
        self._shm = None
        self._buf = None
        # Placeholder so that shape and dtype are always available to
        # __getstate__, even before any data has been copied in.
        self._array = np.zeros(1, dtype=np.float32)
        self._inited = False
        self._whether_verify = whether_verify

    def zeros_like(self, shape, dtype):
        """Map the named segment and expose it as an ndarray view.

        When the instance is marked as initialized the existing segment is
        opened; otherwise it is opened with ``O_CREAT`` and the byte size
        implied by ``shape`` and ``dtype``. Despite the name, this method
        does not itself write zeros into the buffer.

        Args:
            shape: Shape of the array view.
            dtype: Element type of the array view.
        """
        size = int(np.prod(shape)) * np.dtype(dtype).itemsize
        if self._inited:
            self._shm = posix_ipc.SharedMemory(self._name)
        else:
            self._shm = posix_ipc.SharedMemory(
                self._name, posix_ipc.O_CREAT, size=size)
        self._buf = mmap.mmap(self._shm.fd, size)
        self._array = np.ndarray(shape, dtype, self._buf, order='C')

    def copy(self, ndarray):
        """Copy ``ndarray`` into the shared segment.

        Args:
            ndarray (numpy.ndarray): Data to store; its shape and dtype
                define the layout of the shared view.

        Raises:
            AssertionError: If verification is enabled and the data read
                back through a second mapping differs from ``ndarray``.
        """
        size = int(np.prod(ndarray.shape)) * np.dtype(ndarray.dtype).itemsize
        self.zeros_like(ndarray.shape, ndarray.dtype)
        self._array[:] = ndarray
        self._buf.flush()
        self._inited = True

        if self._whether_verify:
            # Read the data back through an independent mapping.
            shm = posix_ipc.SharedMemory(self._name)
            try:
                buf = mmap.mmap(shm.fd, size)
            finally:
                # mmap works on its own duplicate of the descriptor, so the
                # one opened for verification can be closed right away.
                # Leaving it open leaked one descriptor per verified copy.
                shm.close_fd()
            array = np.ndarray(ndarray.shape, ndarray.dtype, buf, order='C')
            np.testing.assert_array_equal(array, ndarray)

    @property
    def ndarray(self):
        """numpy.ndarray: The current array view (placeholder until mapped)."""
        return self._array

    def recycle(self, pool):
        """Close the local mapping and put this object back into ``pool``.

        Args:
            pool: Mapping keyed by segment name; this object is stored under
                its own name.
        """
        self._buf.close()
        self._shm.close_fd()
        self._inited = False
        pool[self._name] = self

    def __getstate__(self):
        # Only metadata is pickled; the mapping is re-established by
        # __setstate__ on the receiving side.
        return (self._name, self._array.shape, self._array.dtype, self._inited,
                self._whether_verify)

    def __setstate__(self, state):
        self._name = state[0]
        # _inited must be restored first: zeros_like() uses it to decide
        # whether to open the existing segment or create it.
        self._inited = state[3]
        self.zeros_like(state[1], state[2])
        self._whether_verify = state[4]
146 147 148


class SharedMemoryPoolManager(object):
    """SharedMemoryPoolManager maintains a multiprocessing.Manager.dict object.
    All available addresses are allocated once and will be reused. Though this
    class is not process-safe, the pool can be shared between processes. All
    shared memory should be unlinked before the main process exits.

    Args:
        pool_size (int): Size of shared memory pool.
        manager (multiprocessing.Manager): Manager whose ``dict()`` proxy
            holds the pool; the pool is maintained by the proxy process.
        name_prefix (str): Address prefix of shared memory.
    """

    def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
        self._names = []
        self._dict = manager.dict()
        # Segment names embed the creation time (one-second resolution).
        self._time_prefix = time.strftime('%Y%m%d%H%M%S')

        # ``range`` instead of the Python-2-only ``xrange`` so this also
        # runs on Python 3.
        for i in range(pool_size):
            name = name_prefix + '_' + self._time_prefix + '_' + str(i)
            self._dict[name] = SharedNDArray(name)
            self._names.append(name)

    @property
    def pool(self):
        """The manager-backed dict mapping segment name to SharedNDArray."""
        return self._dict

    def __del__(self):
        for name in self._names:
            # have to unlink the shared memory
            try:
                posix_ipc.unlink_shared_memory(name)
            except posix_ipc.ExistentialError:
                # The segment does not exist (any more). Keep going so one
                # missing name does not leave the remaining segments linked.
                continue


Y
yangyaming 已提交
181 182 183 184
def suppress_signal(signo, stack_frame):
    """Do nothing; both arguments are accepted and ignored.

    NOTE(review): the ``(signo, stack_frame)`` signature matches what
    ``signal.signal`` handlers receive, so this is presumably installed as a
    no-op signal handler -- confirm against the callers.
    """
    return None


185
def suppress_complaints(verbose, notify=None):
    """Build a decorator that swallows exceptions raised by the wrapped call.

    Args:
        verbose: When equal to ``1``, caught exceptions are re-raised after
            notification instead of being swallowed.
        notify (callable, optional): Called with keyword arguments
            ``except_type``, ``except_value`` and ``traceback`` for every
            caught exception.

    Returns:
        callable: A decorator. The decorated function returns the wrapped
        function's result, or ``None`` when an exception was swallowed.
        ``CriticalException`` instances are always re-raised.
    """
    # Function-scope import so this block does not depend on a change to the
    # module's import section.
    import functools

    def decorator_maker(func):
        @functools.wraps(func)
        def suppress_wrapper(*args, **kwargs):
            try:
                # Hand the result back to the caller; it used to be dropped.
                return func(*args, **kwargs)
            except BaseException:
                # Deliberately as broad as the former bare ``except:``:
                # SystemExit and KeyboardInterrupt are swallowed as well.
                et, ev, tb = sys.exc_info()

                if notify is not None:
                    notify(except_type=et, except_value=ev, traceback=tb)

                if verbose == 1 or isinstance(ev, CriticalException):
                    # The traceback is rebuilt through tblib before being
                    # re-raised with six's ``reraise``.
                    reraise(et, ev, Traceback(tb).as_traceback())

        return suppress_wrapper

    return decorator_maker
202 203 204 205 206 207 208 209 210 211 212 213


class ForceExitWrapper(object):
    """Callable that sets a shared exit flag when invoked.

    Args:
        exit_flag: Object with a writable ``value`` attribute; calling the
            wrapper sets ``exit_flag.value`` to ``True``.
    """

    def __init__(self, exit_flag):
        self._exit_flag = exit_flag

    @suppress_complaints(verbose=0)
    def __call__(self, *args, **kwargs):
        # All arguments are ignored; any error raised while setting the flag
        # is swallowed by the decorator.
        self._exit_flag.value = True

    def __eq__(self, flag):
        """Compare the current flag value (not the wrapper) with ``flag``."""
        return self._exit_flag.value == flag

    def __ne__(self, flag):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this,
        # ``wrapper != flag`` would fall back to an identity comparison.
        return not self.__eq__(flag)