utils.py 4.4 KB
Newer Older
L
Li Fuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

Y
Yibing Liu 已提交
15 16 17 18 19 20 21 22 23 24
"""
EmoTect utilities.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import os
import sys
25
import six
Y
Yibing Liu 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
import random

import paddle
import paddle.fluid as fluid
import numpy as np

def init_checkpoint(exe, init_checkpoint_path, main_program):
    """
    Restore variables of ``main_program`` from a checkpoint path.

    A variable is loaded only when it is persistable and a file carrying the
    variable's name exists under ``init_checkpoint_path``; every other
    variable is skipped.

    Args:
        exe: executor handed to ``fluid.io.load_vars``.
        init_checkpoint_path: path under which saved variables are looked up
            by name.
        main_program: program whose variables are restored.

    Raises:
        AssertionError: if ``init_checkpoint_path`` does not exist.
    """
    checkpoint_found = os.path.exists(init_checkpoint_path)
    assert checkpoint_found, "[%s] cann't be found." % init_checkpoint_path

    def _has_saved_file(var):
        """
        Tell whether ``var`` is persistable and has a saved file
        """
        if fluid.io.is_persistable(var):
            saved_file = os.path.join(init_checkpoint_path, var.name)
            return os.path.exists(saved_file)
        return False

    fluid.io.load_vars(
        exe,
        init_checkpoint_path,
        predicate=_has_saved_file,
        main_program=main_program)
    print("Load model from {}".format(init_checkpoint_path))


55
def word2id(word_dict, query):
    """
    Map a space-separated query string onto a list of word ids.

    The query is stripped and then split on single spaces. A token that is
    missing from ``word_dict`` receives the id ``len(word_dict)``.
    """
    fallback_id = len(word_dict)
    tokens = query.strip().split(" ")
    ids = []
    for token in tokens:
        if token in word_dict:
            ids.append(word_dict[token])
        else:
            ids.append(fallback_id)
    return ids


def pad_wid(wids, max_seq_len=128, pad_id=0):
    """
    Pad or truncate an id list to exactly ``max_seq_len`` entries.

    Args:
        wids: list of word ids; it is not modified.
        max_seq_len: target length of the returned list.
        pad_id: id appended to sequences shorter than ``max_seq_len``.

    Returns:
        A tuple ``(padded_wids, seq_len)`` where ``padded_wids`` is a new
        list of length ``max_seq_len`` and ``seq_len`` is the number of real
        (non-padding) ids it contains, i.e. ``min(len(wids), max_seq_len)``.
    """
    # Report the real length: overwriting it with max_seq_len for short
    # inputs would throw away the only information the caller cannot recompute.
    seq_len = min(len(wids), max_seq_len)
    # Slicing copies, so the caller's list is never padded in place.
    padded = list(wids[:max_seq_len])
    padded.extend([pad_id] * (max_seq_len - len(padded)))
    return padded, seq_len


def data_reader(file_path, word_dict, num_examples, phrase, epoch, max_seq_len):
    """
    Build a reader over a tab-separated text file, converting words to ids.

    The whole file is parsed eagerly when this function is called. Lines
    starting with "label" are skipped (presumably a header row).

    Args:
        file_path: path of a UTF-8 text file, one sample per line.
        word_dict: word -> id mapping forwarded to ``word2id``.
        num_examples: dict updated in place; ``num_examples[phrase]`` is set
            to the number of parsed samples.
        phrase: "infer" selects the unlabeled format, where the last
            tab-separated column is the query. Any other value selects the
            labeled format ``label<TAB>query``; malformed lines are reported
            on stderr and skipped. "train" additionally enables shuffling.
        epoch: number of passes the labeled reader makes over the data.
        max_seq_len: target length forwarded to ``pad_wid``.

    Returns:
        A zero-argument generator function. For "infer" it yields
        ``(wids, seq_len)`` in a single pass; otherwise it yields
        ``(wids, label, seq_len)`` for ``epoch`` passes.

    Raises:
        ValueError: if the label column of a labeled line is not an integer.
    """
    is_infer = phrase == "infer"
    all_data = []
    with io.open(file_path, "r", encoding='utf8') as fin:
        for line in fin:
            if line.startswith("label"):
                continue
            cols = line.strip().split("\t")
            if is_infer:
                # The query is always the last column; this also covers
                # single-column lines, so no length check is needed.
                query = cols[-1]
                wids = word2id(word_dict, query)
                wids, seq_len = pad_wid(wids, max_seq_len)
                all_data.append((wids, seq_len))
                continue
            if len(cols) != 2:
                # Terminate the notice so consecutive ones do not run together.
                sys.stderr.write("[NOTICE] Error Format Line!\n")
                continue
            label = int(cols[0])
            query = cols[1].strip()
            wids = word2id(word_dict, query)
            wids, seq_len = pad_wid(wids, max_seq_len)
            all_data.append((wids, label, seq_len))
    num_examples[phrase] = len(all_data)

    if is_infer:

        def reader():
            """
            Infer reader function
            """
            for wids, seq_len in all_data:
                yield wids, seq_len

        return reader

    def reader():
        """
        Reader function
        """
        for _ in range(epoch):
            # Shuffling is disabled when 'ce_mode' is set in the environment.
            if phrase == "train" and 'ce_mode' not in os.environ:
                random.shuffle(all_data)
            for wids, label, seq_len in all_data:
                yield wids, label, seq_len

    return reader


def load_vocab(file_path):
    """
    Read a vocabulary file (one entry per line) into a word -> id dict.

    Each line is stripped; ids follow first-occurrence order starting at 0
    and repeated entries keep their first id. Afterwards "<unk>" is set to
    the size of the dict at that point.
    """
    word_to_id = {}
    with io.open(file_path, 'r', encoding='utf8') as fin:
        for raw_line in fin:
            word = raw_line.strip()
            if word in word_to_id:
                continue
            # The next free id always equals the current number of entries.
            word_to_id[word] = len(word_to_id)
    word_to_id["<unk>"] = len(word_to_id)
    return word_to_id
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161


def print_arguments(args):
    """
    Print every attribute of ``args`` as a ``name: value`` line.

    The attributes are taken from ``vars(args)`` and printed in sorted
    order between a header and a footer line.
    """
    header = '-----------  Configuration Arguments -----------'
    footer = '------------------------------------------------'
    settings = sorted(six.iteritems(vars(args)))
    print(header)
    for name, setting in settings:
        print('%s: %s' % (name, setting))
    print(footer)


def query2ids(vocab_path, query):
    """
    Turn a query string into word ids using the vocabulary at ``vocab_path``.

    The vocabulary file is loaded anew on every call.
    """
    return word2id(load_vocab(vocab_path), query)