# util.py (1.8 KB)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
4
import sys
5 6
from six import reraise
from tblib import Traceback
7

8 9
import numpy as np

10

# zhxfl committed
def to_lodtensor(data, place):
    """Convert a batch of variable-length sequences into a LoDTensor.

    The sequences are concatenated along axis 0, cast to int64 and reshaped
    into a single column. The cumulative sequence lengths are attached as a
    one-level LoD so the per-sequence boundaries can be recovered.

    Args:
        data: non-empty list of array-likes, one per sequence (presumably
            1-D id sequences -- TODO confirm against callers).
        place: device place forwarded unchanged to ``LoDTensor.set``.

    Returns:
        A ``fluid.LoDTensor`` holding the flattened int64 column with LoD
        ``[[0, len(data[0]), len(data[0]) + len(data[1]), ...]]``.

    Raises:
        ValueError: if ``data`` is empty (``np.concatenate`` needs at least
            one array).
    """
    # Offsets: lod[i] is where sequence i starts, lod[-1] the total length.
    lod = [0]
    for seq in data:
        lod.append(lod[-1] + len(seq))

    # numpy is imported as ``np`` in this module.
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])

    # NOTE(review): ``fluid`` is not imported in the visible part of this
    # file -- confirm ``import paddle.fluid as fluid`` exists at module level.
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res


# Yibing Liu committed
def split_infer_result(infer_seq, lod):
    """Split a flat inference result into one chunk per sequence.

    Args:
        infer_seq: sliceable sequence holding the concatenated results of
            every sequence in the batch.
        lod: level-of-detail offsets; only the first level ``lod[0]`` is
            used. Consecutive offsets ``lod[0][i]`` and ``lod[0][i + 1]``
            bound the i-th sequence.

    Returns:
        A list containing ``infer_seq[lod[0][i]:lod[0][i + 1]]`` for every
        ``i``; empty when ``lod[0]`` has fewer than two offsets.
    """
    offsets = lod[0]
    # Pair each offset with its successor; unlike ``xrange`` this works on
    # both Python 2 and Python 3.
    return [
        infer_seq[start:end]
        for start, end in zip(offsets[:-1], offsets[1:])
    ]


35 36 37 38
class CriticalException(Exception):
    """Error type that must never be silently swallowed.

    ``suppress_complaints`` re-raises instances of this class regardless of
    its ``verbose`` setting.
    """


# yangyaming committed
def suppress_signal(signo, stack_frame):
    """Ignore a delivered signal.

    The ``(signo, stack_frame)`` signature matches the handler signature used
    by ``signal.signal`` (presumably installed that way -- confirm against
    callers). Both arguments are accepted and discarded; returns ``None``.
    """
    del signo, stack_frame  # intentionally unused


43
def suppress_complaints(verbose, notify=None):
    """Build a decorator that swallows exceptions raised by the wrapped call.

    When the wrapped callable raises, ``notify`` (if given) is called with
    the keyword arguments ``except_type``, ``except_value`` and
    ``traceback``. The exception is then re-raised with its original
    traceback when ``verbose == 1`` or when it is a ``CriticalException``;
    otherwise it is suppressed and the wrapper returns ``None``.

    Args:
        verbose: ``1`` re-raises every exception; any other value re-raises
            only ``CriticalException``.
        notify: optional callback invoked before the re-raise/suppress
            decision.

    Returns:
        A decorator. The decorated callable keeps the wrapped function's
        name/docstring and returns its result when it completes normally.
    """
    from functools import wraps

    def decorator_maker(func):
        @wraps(func)
        def suppress_wrapper(*args, **kwargs):
            try:
                # Propagate the wrapped function's result to the caller.
                return func(*args, **kwargs)
            except BaseException:
                # Deliberately as broad as a bare ``except``: this decorator
                # exists to suppress *everything* that is not re-raised below.
                et, ev, tb = sys.exc_info()

                if notify is not None:
                    notify(except_type=et, except_value=ev, traceback=tb)

                if verbose == 1 or isinstance(ev, CriticalException):
                    # Re-raise with a traceback rebuilt by tblib's
                    # ``Traceback(tb).as_traceback()`` via six's ``reraise``.
                    reraise(et, ev, Traceback(tb).as_traceback())
                # Exception suppressed.
                return None

        return suppress_wrapper

    return decorator_maker
60 61 62 63 64 65 66 67 68 69 70 71


class ForceExitWrapper(object):
    """Callable that sets a shared exit flag when invoked.

    Calling an instance (with any arguments) sets ``exit_flag.value`` to
    ``True``; exceptions raised while doing so are suppressed unless they
    are ``CriticalException``. Comparing an instance with ``==`` / ``!=``
    compares the current flag value. The accepted ``*args`` suggest it is
    installed as a signal handler -- confirm against callers.

    Under Python 3, defining ``__eq__`` without ``__hash__`` makes instances
    unhashable.
    """

    def __init__(self, exit_flag):
        # exit_flag: object with a writable ``value`` attribute (presumably a
        # ``multiprocessing.Value`` shared between processes -- TODO confirm).
        self._exit_flag = exit_flag

    @suppress_complaints(verbose=0)
    def __call__(self, *args, **kwargs):
        """Set the exit flag to ``True``; all arguments are ignored."""
        self._exit_flag.value = True

    def __eq__(self, flag):
        """Return whether the current flag value equals ``flag``."""
        return self._exit_flag.value == flag

    def __ne__(self, flag):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this,
        # ``wrapper != True`` would fall back to identity comparison there.
        return not (self == flag)