util.py 6.0 KB
Newer Older
1 2 3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
4 5 6
import sys
from six import reraise
from tblib import Traceback
7 8
from multiprocessing import Manager, Process
import posix_ipc, mmap
9

10 11
import numpy as np

12

Z
zhxfl 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
def to_lodtensor(data, place):
    """Pack a batch of variable-length sequences into one LoDTensor.

    Args:
        data: Iterable of array-like sequences. They are concatenated along
            axis 0, cast to int64 and reshaped into a single column of shape
            ``(total_len, 1)``.
        place: Passed through unchanged as the second argument of
            ``LoDTensor.set``.

    Returns:
        The populated ``fluid.LoDTensor``. Its LoD has a single level that
        holds the cumulative sequence offsets, e.g. sequence lengths
        ``[2, 3]`` produce ``[[0, 2, 5]]``.
    """
    # NOTE(review): `fluid` is not imported in the visible part of this
    # module; it presumably refers to `paddle.fluid` -- confirm and import it.
    lod = [0]
    for seq in data:
        lod.append(lod[-1] + len(seq))
    # The module binds numpy as `np`; the bare name `numpy` was undefined here.
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res


def lodtensor_to_ndarray(lod_tensor):
    """Convert a LoDTensor into a float32 ndarray plus its LoD.

    Args:
        lod_tensor: Object exposing ``get_dims()``, ``get_float_element(i)``
            and ``lod()``.

    Returns:
        A tuple ``(array, lod)`` where ``array`` is a float32 ndarray of shape
        ``get_dims()`` filled element by element in flat (row-major) order,
        and ``lod`` is whatever ``lod_tensor.lod()`` returns.
    """
    dims = lod_tensor.get_dims()
    ret = np.zeros(shape=dims, dtype='float32')
    # `ret` is freshly allocated and C-contiguous, so ravel() is a view and
    # writes through it land in `ret`.
    flat = ret.ravel()
    # `range` works on both Python 2 and 3 (`xrange` is Python 2 only), and
    # `flat.size` equals the product of `dims` without relying on the
    # removed `np.product` alias.
    for i in range(flat.size):
        flat[i] = lod_tensor.get_float_element(i)
    return ret, lod_tensor.lod()
38 39


40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
def batch_to_ndarray(batch_samples, lod):
    """Stack the samples of a batch into one feature and one label array.

    Args:
        batch_samples: Sequence of samples. ``sample[0]`` is a 2-D feature
            array of shape ``(frame_num, frame_dim)`` and ``sample[1]`` is
            assigned into a ``(frame_num, 1)`` label slice.
        lod: Sequence whose last element is the total number of rows to
            allocate.

    Returns:
        A tuple ``(batch_feature, batch_label)``: a float32 array of shape
        ``(lod[-1], frame_dim)`` and an int64 array of shape ``(lod[-1], 1)``,
        with the samples written back to back in input order.
    """
    feature_dim = batch_samples[0][0].shape[1]
    total_rows = lod[-1]
    features = np.zeros((total_rows, feature_dim), dtype="float32")
    labels = np.zeros((total_rows, 1), dtype="int64")

    begin = 0
    for item in batch_samples:
        end = begin + item[0].shape[0]
        features[begin:end, :] = item[0]
        labels[begin:end, :] = item[1]
        begin = end
    return (features, labels)


class DaemonProcessGroup(object):
    """A fixed-size group of worker processes that are started as daemons.

    Args:
        proc_num (int): Number of worker processes to create.
        target (callable): Callable each worker runs.
        args (tuple): Positional arguments passed to ``target`` in every
            worker.
    """

    def __init__(self, proc_num, target, args):
        self._proc_num = proc_num
        # `range` instead of `xrange` so the class also works on Python 3.
        self._workers = [
            Process(
                target=target, args=args) for _ in range(self._proc_num)
        ]

    def start_all(self):
        """Mark every worker as a daemon and start it."""
        for w in self._workers:
            # The daemon flag has to be set before start().
            w.daemon = True
            w.start()

    @property
    def proc_num(self):
        """Number of worker processes in the group."""
        return self._proc_num


class EpochEndSignal(object):
    """Empty marker type that carries no data or behavior.

    NOTE(review): presumably passed between processes to mark the end of an
    epoch -- confirm against the callers.
    """


75 76 77 78
class CriticalException(Exception):
    """Exception type that ``suppress_complaints`` always re-raises.

    Wrappers built by ``suppress_complaints`` propagate instances of this
    class regardless of their ``verbose`` setting.
    """


79
class SharedNDArray(object):
    """SharedNDArray utilizes shared memory to avoid data serialization when
    an object is shared between different processes. The ndarray can be
    reconstructed in another process once the shared memory name is known.

    Pickling transfers only the name, shape, dtype and the two flags (see
    ``__getstate__``); unpickling maps the named segment again (see
    ``__setstate__``).

    Args:
        name (str): Address name of shared memory.
        is_verify (bool): Whether to do validation for writing operation.
    """

    def __init__(self, name, is_verify=False):
        self._name = name
        self._shm = None
        self._buf = None
        # Placeholder array so that shape and dtype are always available to
        # __getstate__, even before any segment has been mapped.
        self._array = np.zeros(1, dtype=np.float32)
        self._inited = False
        self._is_verify = is_verify

    def zeros_like(self, shape, dtype):
        """Map the named segment and expose it as an ndarray.

        The mapping covers ``prod(shape) * itemsize`` bytes. If the object is
        marked as initialised, the segment is opened by name; otherwise it is
        opened with ``O_CREAT`` and the computed size. This method does not
        write to the mapped memory itself.

        Args:
            shape: Shape of the resulting array.
            dtype: Data type of the resulting array.
        """
        # NOTE(review): every call opens a new descriptor and mapping without
        # releasing the ones already held; only recycle() closes them.
        # Confirm callers never invoke this twice between recycles.
        size = int(np.prod(shape)) * np.dtype(dtype).itemsize
        if self._inited:
            self._shm = posix_ipc.SharedMemory(self._name)
        else:
            self._shm = posix_ipc.SharedMemory(
                self._name, posix_ipc.O_CREAT, size=size)
        self._buf = mmap.mmap(self._shm.fd, size)
        self._array = np.ndarray(shape, dtype, self._buf, order='C')

    def copy(self, ndarray):
        """Copy ``ndarray`` into the shared segment.

        The segment is (re)mapped with the shape and dtype of ``ndarray``,
        the data is written and flushed, and the object is marked as
        initialised. When ``is_verify`` was requested, the segment is mapped
        a second time and compared with the input.

        Args:
            ndarray (numpy.ndarray): Source data.

        Raises:
            AssertionError: If verification is enabled and the data read
                back differs from ``ndarray``.
        """
        size = int(np.prod(ndarray.shape)) * np.dtype(ndarray.dtype).itemsize
        self.zeros_like(ndarray.shape, ndarray.dtype)
        self._array[:] = ndarray
        self._buf.flush()
        self._inited = True

        if self._is_verify:
            shm = posix_ipc.SharedMemory(self._name)
            try:
                buf = mmap.mmap(shm.fd, size)
            finally:
                # posix_ipc does not close descriptors on garbage
                # collection, so release it here; the mapping stays valid
                # after the descriptor is closed.
                shm.close_fd()
            array = np.ndarray(ndarray.shape, ndarray.dtype, buf, order='C')
            np.testing.assert_array_equal(array, ndarray)

    @property
    def ndarray(self):
        """The ndarray currently backed by this object."""
        return self._array

    def recycle(self, pool):
        """Release the mapping and put this object back into ``pool``.

        Closes the mapping that backs ``ndarray`` and the segment's file
        descriptor, clears the initialised flag and stores the object in
        ``pool`` under its name.

        Args:
            pool: Mapping from segment name to ``SharedNDArray``.
        """
        self._buf.close()
        self._shm.close_fd()
        self._inited = False
        pool[self._name] = self

    def __getstate__(self):
        """Return ``(name, shape, dtype, inited, is_verify)`` for pickling."""
        return (self._name, self._array.shape, self._array.dtype, self._inited,
                self._is_verify)

    def __setstate__(self, state):
        """Restore from ``__getstate__`` output and map the segment again."""
        self._name = state[0]
        # Must be set before zeros_like(), which uses it to decide between
        # opening an existing segment and creating one.
        self._inited = state[3]
        self.zeros_like(state[1], state[2])
        self._is_verify = state[4]


class SharedMemoryPoolManager(object):
    """SharedMemoryPoolManager maintains a multiprocessing.Manager.dict object.
    All available addresses are allocated once and will be reused. Though this
    class is not process-safe, the pool can be shared between processes. All
    shared memory should be unlinked before the main process exits.

    Args:
        pool_size (int): Size of shared memory pool.
        manager (multiprocessing.Manager): Manager whose ``dict()`` proxy
                        holds the pool; the pool is maintained by the proxy
                        process.
        name_prefix (str): Address prefix of shared memory.
    """

    def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
        self._names = []
        self._dict = manager.dict()

        # `range` instead of `xrange` so the class also works on Python 3.
        for i in range(pool_size):
            name = name_prefix + '_' + str(i)
            self._dict[name] = SharedNDArray(name)
            self._names.append(name)

    @property
    def pool(self):
        """The manager dict mapping segment names to SharedNDArray objects."""
        return self._dict

    def __del__(self):
        for name in self._names:
            # have to unlink the shared memory
            try:
                posix_ipc.unlink_shared_memory(name)
            except posix_ipc.ExistentialError:
                # The segment is already gone. Keep going so the remaining
                # names are still unlinked instead of being left behind.
                pass


Y
yangyaming 已提交
173 174 175 176
def suppress_signal(signo, stack_frame):
    """Do nothing; both arguments are ignored and ``None`` is returned.

    NOTE(review): the ``(signo, stack_frame)`` parameters suggest this is
    installed as a signal handler to ignore a signal -- confirm at the
    call site.
    """
    return None


177
def suppress_complaints(verbose, notify=None):
    """Build a decorator that swallows exceptions raised by the wrapped callable.

    Args:
        verbose: When equal to ``1`` every caught exception is re-raised
            after ``notify`` has run; otherwise only ``CriticalException``
            instances are re-raised and everything else is suppressed.
        notify: Optional callable invoked for every caught exception with the
            keyword arguments ``except_type``, ``except_value`` and
            ``traceback``.

    Returns:
        A decorator. The decorated callable returns whatever the wrapped
        callable returns, or ``None`` when an exception was suppressed.
    """
    # Local import keeps this block self-contained.
    import functools

    def decorator_maker(func):
        @functools.wraps(func)
        def suppress_wrapper(*args, **kwargs):
            try:
                # Pass the result through instead of discarding it.
                return func(*args, **kwargs)
            except:  # noqa: E722 -- deliberately catches everything
                et, ev, tb = sys.exc_info()

                if notify is not None:
                    notify(except_type=et, except_value=ev, traceback=tb)

                if verbose == 1 or isinstance(ev, CriticalException):
                    # Rebuild the traceback through tblib before re-raising
                    # with the original type and value.
                    reraise(et, ev, Traceback(tb).as_traceback())

        return suppress_wrapper

    return decorator_maker
194 195 196 197 198 199 200 201 202 203 204 205


class ForceExitWrapper(object):
    """Callable wrapper around a flag object exposing a ``value`` attribute.

    Calling the instance sets ``exit_flag.value`` to ``True``; comparing the
    instance with ``==`` compares the current ``exit_flag.value`` against the
    other operand.

    Args:
        exit_flag: Object with a readable and writable ``value`` attribute.
            NOTE(review): presumably a ``multiprocessing`` shared value --
            confirm at the call site.
    """

    def __init__(self, exit_flag):
        self._exit_flag = exit_flag

    @suppress_complaints(verbose=0)
    def __call__(self, *args, **kwargs):
        # Arguments are accepted but ignored. Exceptions raised here are
        # suppressed by the decorator unless they are CriticalException.
        self._exit_flag.value = True

    def __eq__(self, flag):
        current = self._exit_flag.value
        return current == flag