# utils.py
import logging
import paddle

# Id that sent2ids() assigns to words missing from the vocabulary.
UNK = 0

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def mode_attr_name(mode):
    """Return the class-attribute name used for ``mode``.

    Example: ``"train"`` -> ``"TRAIN_MODE"``.
    """
    return mode.upper() + "_MODE"


def create_attrs(cls):
    """Give every mode in ``cls.modes`` an integer id class attribute.

    The mode at position ``i`` of ``cls.modes`` becomes
    ``cls.<MODE>_MODE = i``; the attribute name comes from
    ``mode_attr_name``.

    Args:
        cls: a class exposing a ``modes`` sequence of mode-name strings.
    """
    for index, mode in enumerate(cls.modes):
        setattr(cls, mode_attr_name(mode), index)


def make_check_method(cls):
    """Attach an ``is_<mode>()`` predicate method to ``cls`` for each mode.

    For each name in ``cls.modes`` this defines ``cls.is_<mode>(self)``.
    The method returns True when ``self.mode`` equals the class attribute
    named by ``mode_attr_name(mode)`` (e.g. ``TRAIN_MODE``). That attribute
    must exist by the time the predicate is called (``create_attrs`` sets it).

    Args:
        cls: a class exposing a ``modes`` sequence of mode-name strings.
    """

    def make_predicate(mode):
        # Build each predicate in its own call so it captures its own
        # ``mode``; a closure defined directly in the loop would late-bind
        # to the last mode.
        attr_name = mode_attr_name(mode)

        def predicate(self):
            return self.mode == getattr(cls, attr_name)

        return predicate

    for mode in cls.modes:
        setattr(cls, "is_" + mode, make_predicate(mode))


def make_create_method(cls):
    """Attach a static ``create_<mode>()`` factory to ``cls`` for each mode.

    ``cls.create_<mode>()`` returns ``cls(key)``, where ``key`` is the value
    of the class attribute named by ``mode_attr_name(mode)``. That value is
    looked up each time the factory is called.

    Args:
        cls: a class exposing a ``modes`` sequence of mode-name strings and
            a constructor accepting the mode id.
    """

    def make_factory(mode):
        # Build each factory in its own call so it captures its own ``mode``.
        @staticmethod
        def factory():
            key = getattr(cls, mode_attr_name(mode))
            return cls(key)

        return factory

    for mode in cls.modes:
        setattr(cls, "create_" + mode, make_factory(mode))


def make_str_method(cls, type_name="unk"):
    """Install ``__str__``/``__repr__``/``__hash__`` on ``cls`` and rename it.

    ``str()`` and ``repr()`` of an instance return the name of the mode
    whose ``<MODE>_MODE`` class attribute equals ``self.mode``. ``hash()``
    returns ``self.mode``. Finally ``cls.__name__`` is set to ``type_name``.

    Args:
        cls: a class exposing a ``modes`` sequence of mode-name strings.
        type_name: value assigned to ``cls.__name__``.
    """

    def _str_(self):
        for mode in cls.modes:
            if self.mode == getattr(cls, mode_attr_name(mode)):
                return mode
        # ``_init_`` accepts any int, so ``self.mode`` may not match any
        # known mode. Returning None here would make str()/repr() raise
        # TypeError, so return a descriptive string instead.
        return "%s(mode=%r)" % (cls.__name__, self.mode)

    def _hash_(self):
        return self.mode

    # NOTE(review): no __eq__ is installed, so ``==`` between instances is
    # still identity-based even though hash() is by mode. Two instances with
    # the same mode are therefore distinct dict/set keys. Add __eq__ if value
    # semantics are intended.
    setattr(cls, "__str__", _str_)
    setattr(cls, "__repr__", _str_)
    setattr(cls, "__hash__", _hash_)
    cls.__name__ = type_name


def _init_(self, mode, cls):
    if isinstance(mode, int):
        self.mode = mode
    elif isinstance(mode, cls):
        self.mode = mode.mode
    else:
C
caoying03 已提交
68
        raise Exception("A wrong mode type, get type: %s, value: %s." %
S
Superjom 已提交
69
                        (type(mode), mode))
S
Superjom 已提交
70 71 72 73 74 75 76 77 78 79


def build_mode_class(cls):
    """Populate a mode class from its ``modes`` list.

    Adds ``<MODE>_MODE`` integer ids, ``str``/``repr``/``hash`` support,
    ``is_<mode>()`` predicates and ``create_<mode>()`` static factories.

    Args:
        cls: a class exposing a ``modes`` sequence of mode-name strings.
    """
    create_attrs(cls)
    # Pass the class's own name; relying on make_str_method's default
    # would rename every mode class to "unk".
    make_str_method(cls, cls.__name__)
    make_check_method(cls)
    make_create_method(cls)


class TaskType(object):
    """Mode class whose modes are ``train``, ``test`` and ``infer``.

    The module-level ``build_mode_class(TaskType)`` call adds the
    ``*_MODE`` ids, ``is_*`` predicates and ``create_*`` factories.
    """

    modes = "train test infer".split()

    def __init__(self, mode):
        # mode: an int mode id or another TaskType instance.
        _init_(self, mode, TaskType)


class ModelType(object):
    """Mode class whose modes are ``classification``, ``rank``, ``regression``.

    The module-level ``build_mode_class(ModelType)`` call adds the
    ``*_MODE`` ids, ``is_*`` predicates and ``create_*`` factories.
    """

    modes = "classification rank regression".split()

    def __init__(self, mode):
        # mode: an int mode id or another ModelType instance.
        _init_(self, mode, ModelType)


class ModelArch(object):
    """Mode class whose modes are ``fc``, ``cnn`` and ``rnn``.

    The module-level ``build_mode_class(ModelArch)`` call adds the
    ``*_MODE`` ids, ``is_*`` predicates and ``create_*`` factories.
    """

    modes = "fc cnn rnn".split()

    def __init__(self, mode):
        # mode: an int mode id or another ModelArch instance.
        _init_(self, mode, ModelArch)


# Attach mode ids, str/repr/hash, is_<mode>() predicates and
# create_<mode>() factories to each mode class.
build_mode_class(TaskType)
build_mode_class(ModelType)
build_mode_class(ModelArch)


def sent2ids(sent, vocab):
    """
    Transform a sentence into a list of word ids.

    ``sent`` is split on whitespace and each token is looked up in
    ``vocab``, a word -> id mapping such as the one ``load_dic`` returns.
    Tokens missing from ``vocab`` map to ``UNK``.
    """
    return [vocab.get(word, UNK) for word in sent.split()]


def load_dic(path):
    """
    Load a word dictionary file into a ``{word: id}`` dict.

    The format of word dictionary: each line is a word. A word's id is its
    0-based line number, and surrounding whitespace is stripped from each
    line. If a word appears on several lines, the last occurrence wins.

    NOTE(review): ids start at 0, the same value as ``UNK``, so the word on
    the first line shares UNK's id. Presumably that line holds the
    unknown-word token; confirm against the dictionary files.
    """
    with open(path) as f:
        return {line.strip(): idx for idx, line in enumerate(f)}


def display_args(args):
    """Log every attribute of ``args`` as ``name:<TAB>value``, sorted by name.

    Args:
        args: any object supporting ``vars()``. Presumably an
            ``argparse.Namespace`` given the log header; confirm with callers.
    """
    logger.info("The arguments passed by command line is :")
    for name, value in sorted(vars(args).items()):
        logger.info("{}:\t{}".format(name, value))