from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from six import reraise
from tblib import Traceback
from multiprocessing import Manager, Process
import posix_ipc, mmap

import numpy as np


def to_lodtensor(data, place):
    """Convert a batch of variable-length sequences into a fluid LoDTensor.

    Args:
        data: list of sequences; each sequence is an array-like of ids
            (cast to int64).
        place: device place on which to set the tensor.

    Returns:
        A fluid.LoDTensor holding the flattened data, with one LoD level
        recording the cumulative sequence boundaries.
    """
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    # BUG FIX: numpy is imported as `np` in this file; `numpy.concatenate`
    # raised NameError at runtime.
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    # NOTE(review): `fluid` is not imported in this file — presumably
    # paddle.fluid supplied by the caller's environment; confirm.
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res


def lodtensor_to_ndarray(lod_tensor):
    """Convert a LoDTensor to a numpy ndarray.

    Args:
        lod_tensor: object exposing get_dims(), get_float_element(i)
            and lod().

    Returns:
        Tuple of (float32 ndarray with the tensor's shape, the tensor's
        LoD info).
    """
    dims = lod_tensor.get_dims()
    ret = np.zeros(shape=dims).astype('float32')
    # BUG FIX: `xrange` does not exist on Python 3 (file otherwise targets
    # 2/3 compat via six); `np.product` was removed in NumPy 2.0 -> np.prod.
    for i in range(int(np.prod(dims))):
        # element-wise copy; LoDTensor only exposes scalar access here
        ret.ravel()[i] = lod_tensor.get_float_element(i)
    return ret, lod_tensor.lod()


def batch_to_ndarray(batch_samples, lod):
    """Pack a batch of (feature, label) samples into two flat arrays.

    Args:
        batch_samples: sequence of samples where sample[0] is a
            (frame_num, frame_dim) float feature array and sample[1] is
            the per-frame label array.
        lod: cumulative frame offsets; lod[-1] is the total frame count.

    Returns:
        Tuple (features, labels): float32 array of shape
        (total_frames, frame_dim) and int64 array of shape
        (total_frames, 1).
    """
    feature_dim = batch_samples[0][0].shape[1]
    features = np.zeros((lod[-1], feature_dim), dtype="float32")
    labels = np.zeros((lod[-1], 1), dtype="int64")
    offset = 0
    for sample in batch_samples:
        end = offset + sample[0].shape[0]
        features[offset:end, :] = sample[0]
        labels[offset:end, :] = sample[1]
        offset = end
    return (features, labels)


class DaemonProcessGroup(object):
    """Manage a fixed-size group of worker processes that all run the same
    target with the same arguments, started as daemons (so they die with
    the parent process)."""

    def __init__(self, proc_num, target, args):
        self._proc_num = proc_num
        # BUG FIX: `xrange` does not exist on Python 3; use `range`
        # (loop count is small, so the Py2 list is harmless).
        self._workers = [
            Process(
                target=target, args=args) for _ in range(self._proc_num)
        ]

    def start_all(self):
        """Mark every worker as a daemon and start it."""
        for w in self._workers:
            w.daemon = True  # must be set before start()
            w.start()

    @property
    def proc_num(self):
        # number of worker processes in the group
        return self._proc_num


class EpochEndSignal(object):
    """Marker type signalling that a data epoch has ended; carries no data."""


class CriticalException(Exception):
    """Error severe enough that `suppress_complaints` always re-raises it
    instead of swallowing it."""


class SharedNDArray(object):
    """A numpy ndarray backed by a named POSIX shared-memory segment, so
    its data can be shared between processes without copying.

    The wrapper is picklable: ``__getstate__``/``__setstate__`` transfer
    only the segment name and array metadata; the receiving process maps
    the same named segment again.
    """

    def __init__(self, name, is_verify=False):
        # name: POSIX shared-memory segment name (e.g. '/deep_asr_0')
        # is_verify: when True, copy() re-opens the segment afterwards and
        #            asserts the data round-trips correctly
        self._name = name
        self._shm = None
        self._buf = None
        self._array = np.zeros(1, dtype=np.float32)
        self._inited = False
        self._is_verify = is_verify

    def zeros_like(self, shape, dtype):
        """(Re)map the shared segment as an ndarray of `shape`/`dtype`.

        If the segment already exists (``self._inited``) it is only
        re-opened; otherwise it is created (O_CREAT) at the required size.
        """
        size = int(np.prod(shape)) * np.dtype(dtype).itemsize
        if self._inited:
            self._shm = posix_ipc.SharedMemory(self._name)
        else:
            self._shm = posix_ipc.SharedMemory(
                self._name, posix_ipc.O_CREAT, size=size)
        # map the whole segment and expose it as a C-ordered ndarray view
        self._buf = mmap.mmap(self._shm.fd, size)
        self._array = np.ndarray(shape, dtype, self._buf, order='C')

    def copy(self, ndarray):
        """Copy `ndarray` into the shared segment (remapping it to match
        the source's shape/dtype) and flush the mapping."""
        size = int(np.prod(ndarray.shape)) * np.dtype(ndarray.dtype).itemsize
        self.zeros_like(ndarray.shape, ndarray.dtype)
        self._array[:] = ndarray
        self._buf.flush()
        self._inited = True

        if self._is_verify:
            # paranoia check: open a second mapping and compare contents
            shm = posix_ipc.SharedMemory(self._name)
            buf = mmap.mmap(shm.fd, size)
            array = np.ndarray(ndarray.shape, ndarray.dtype, buf, order='C')
            np.testing.assert_array_equal(array, ndarray)

    @property
    def ndarray(self):
        # live view onto the shared segment (writes are visible to peers)
        return self._array

    def recycle(self, pool):
        """Release this process's mapping and return self to `pool`
        (a dict keyed by segment name) for reuse."""
        self._buf.close()
        self._shm.close_fd()
        self._inited = False
        pool[self._name] = self

    def __getstate__(self):
        # pickle only metadata; the payload stays in shared memory
        return (self._name, self._array.shape, self._array.dtype, self._inited,
                self._is_verify)

    def __setstate__(self, state):
        # NOTE: _inited must be restored BEFORE zeros_like() so the
        # receiving process re-opens rather than re-creates the segment
        self._name = state[0]
        self._inited = state[3]
        self.zeros_like(state[1], state[2])
        self._is_verify = state[4]


class SharedMemoryPoolManager(object):
    """Own a pool of named SharedNDArray wrappers published through a
    multiprocessing Manager dict, and unlink the underlying POSIX
    shared-memory segments on destruction.

    Args:
        pool_size: number of shared arrays to create.
        manager: object providing ``dict()`` (e.g. multiprocessing.Manager()).
        name_prefix: prefix for the POSIX segment names.
    """

    def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
        self._names = []
        self._dict = manager.dict()

        # BUG FIX: `xrange` does not exist on Python 3; use `range`.
        for i in range(pool_size):
            name = name_prefix + '_' + str(i)
            self._dict[name] = SharedNDArray(name)
            self._names.append(name)

    @property
    def pool(self):
        # the manager-backed dict mapping segment name -> SharedNDArray
        return self._dict

    def __del__(self):
        for name in self._names:
            # have to unlink the shared memory, otherwise the named
            # segments outlive the process
            posix_ipc.unlink_shared_memory(name)


def suppress_signal(signo, stack_frame):
    """No-op signal handler: install to ignore the given signal."""
    return None


def suppress_complaints(verbose, notify=None):
    """Decorator factory that swallows exceptions raised by the wrapped
    function (e.g. so worker processes fail quietly).

    Args:
        verbose: when 1, every caught exception is re-raised (with its
            original traceback) after notifying.
        notify: optional callback invoked as
            ``notify(except_type=..., except_value=..., traceback=...)``
            for every caught exception.

    CriticalException is always re-raised regardless of `verbose`.
    """

    def decorator_maker(func):
        def suppress_wrapper(*args, **kwargs):
            try:
                # BUG FIX: propagate the wrapped function's return value
                # (previously always returned None).
                return func(*args, **kwargs)
            # deliberate catch-all (was a bare `except:`) — made explicit
            except BaseException:
                et, ev, tb = sys.exc_info()

                if notify is not None:
                    notify(except_type=et, except_value=ev, traceback=tb)

                if verbose == 1 or isinstance(ev, CriticalException):
                    # tblib makes the traceback picklable across processes
                    reraise(et, ev, Traceback(tb).as_traceback())

        return suppress_wrapper

    return decorator_maker


class ForceExitWrapper(object):
    """Callable that flips a shared exit flag to True; any exception it
    raises is swallowed by the suppress_complaints decorator (verbose=0)."""

    def __init__(self, exit_flag):
        # exit_flag: object exposing a writable .value attribute —
        # presumably a multiprocessing shared Value; confirm at call site
        self._exit_flag = exit_flag

    @suppress_complaints(verbose=0)
    def __call__(self, *args, **kwargs):
        # extra args are accepted and ignored so this can be used as a
        # generic callback (e.g. a notify hook)
        self._exit_flag.value = True

    def __eq__(self, flag):
        # allow `wrapper == True/False` checks against the flag's value
        return self._exit_flag.value == flag