# utils.py -- last committed by Peng Li.
import argparse
import gzip
import logging
import sys
import numpy

# Public API of this module.
__all__ = [
    "open_file",
    "cumsum",
    "logger",
    "DotBar",
    "load_dict",
    "load_wordvecs",
]

# Shared module logger, registered under the "paddle" name and set to
# INFO level.
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def open_file(filename, *args1, **args2):
    """
    Open a plain or gzip-compressed file.

    A name ending in ".gz" is opened with gzip.open; any other name is
    opened with the builtin open. All extra positional and keyword
    arguments (for example the mode) are passed through unchanged to
    whichever opener is selected.

    :param filename: name of the file
    :type filename: str
    :return: a file handler
    """
    opener = gzip.open if filename.endswith(".gz") else open
    return opener(filename, *args1, **args2)


def cumsum(array):
    """
    Calculate the accumulated (prefix) sum of array. For example, for
    array=[1, 2, 3] the result is [1, 1+2, 1+2+3] == [1, 3, 6].

    :param array: input values; any iterable is accepted (it is copied
                  into a new list first), e.g. a python list or numpy array
    :type array: iterable
    :return: a new list whose i-th element is array[0] + ... + array[i];
             an empty input yields an empty list. The input is not modified.
    :rtype: list
    """
    ret = list(array)
    # range (not xrange) so this runs on both Python 2 and Python 3.
    for i in range(1, len(ret)):
        ret[i] += ret[i - 1]
    return ret


class DotBar(object):
    """
    A simple dot-style progress bar wrapped around an iterator.

    Iterating over a DotBar yields the items of the wrapped iterator
    unchanged, while writing one "." to ``f`` every ``step`` items and a
    newline after every ``dots_per_line`` dots. Used as a context manager,
    it enters/exits the wrapped object (when that object supports it) and
    writes a final newline on exit.
    """

    def __init__(self, obj, step=200, dots_per_line=50, f=sys.stderr):
        """
        :param obj: the iterator to wrap, e.g. an open file handler. If it
                    is also a context manager it is entered and exited with
                    the DotBar (sys.stdin / sys.stdout are never exited)
        :type obj: a python iterator
        :param step: print a dot every step iterations
        :type step: int
        :param dots_per_line: dots each line
        :type dots_per_line: int
        :param f: print dot to f, default value is sys.stderr
        :type f: a file handler
        """
        self.obj = obj
        self.step = step
        self.dots_per_line = dots_per_line
        self.f = f
        # Reset again in __enter__; initialised here so that iterating
        # without a ``with`` block does not fail with AttributeError.
        self.idx = 0

    def __enter__(self):
        if hasattr(self.obj, "__enter__"):
            self.obj.__enter__()
        self.idx = 0
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.f.write("\n")
        # Never close the process-wide standard streams.
        if self.obj is sys.stdin or self.obj is sys.stdout:
            return
        if hasattr(self.obj, "__exit__"):
            self.obj.__exit__(exc_type, exc_value, traceback)

    def __iter__(self):
        return self

    def next(self):
        # Fetch first: when the wrapped iterator is exhausted, StopIteration
        # propagates before the counter moves, so no spurious dot is printed.
        # Builtin next() works on Python 2.6+ and Python 3 (which has no
        # file.next()).
        item = next(self.obj)
        self.idx += 1
        if self.idx % self.step == 0:
            self.f.write(".")
        if self.idx % (self.step * self.dots_per_line) == 0:
            self.f.write("\n")
        return item

    # Python 3 iterator protocol name.
    __next__ = next


def load_dict(word_dict_path):
    """
    Load a vocabulary file into a word -> index mapping.

    The first whitespace-separated token of each line is the word, and its
    0-based line number is its index. If a word occurs on several lines,
    the last occurrence wins. Every line must contain at least one token;
    a blank line raises IndexError.

    :param word_dict_path: path of the dictionary file; a ".gz" name is
                           read through gzip (see open_file)
    :type word_dict_path: str
    :return: mapping from word to its line index
    :rtype: dict
    """
    vocab = {}
    with open_file(word_dict_path) as f:
        # the first word must be OOV (it receives index 0)
        for i, line in enumerate(f):
            # split() already discards the trailing newline, and works on
            # both bytes and text lines.
            word = line.split()[0]
            # Python 2 lines and gzip lines are bytes: decode them as UTF-8.
            # Text lines (Python 3 text mode) are already decoded.
            if isinstance(word, bytes):
                word = word.decode("utf-8")
            vocab[word] = i
    return vocab


def load_wordvecs(word_dict_path, wordvecs_path):
    """
    Load a vocabulary together with its word-vector matrix.

    :param word_dict_path: path of the dictionary file (see load_dict)
    :type word_dict_path: str
    :param wordvecs_path: path of a comma-separated text file with one
                          vector per line; row i is presumably the vector
                          of the word with index i (TODO confirm)
    :type wordvecs_path: str
    :return: (vocab, wordvecs), where vocab maps word -> index and
             wordvecs is a 2-D float32 numpy array with one row per line
    :raises ValueError: if the vocabulary size differs from the number of
                        vector rows
    """
    vocab = load_dict(word_dict_path)
    # ndmin=2 keeps a single-row file 2-D; without it loadtxt returns a
    # 1-D array and shape[0] would be the vector dimension, not row count.
    wordvecs = numpy.loadtxt(
        wordvecs_path, delimiter=",", dtype="float32", ndmin=2)
    # Explicit check rather than assert, which is stripped under python -O.
    if len(vocab) != wordvecs.shape[0]:
        raise ValueError(
            "vocabulary size (%d) does not match number of word vectors "
            "(%d)" % (len(vocab), wordvecs.shape[0]))
    return vocab, wordvecs