import sys
import collections
import six
import time
import logging
import numpy as np
import paddle.fluid as fluid
import paddle
import os
import io


def to_lodtensor(data, place):
    """ convert to LODtensor """
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res
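
# Example usage of to_lodtensor (a minimal sketch, not part of the original
# file; it assumes the PaddlePaddle 1.x fluid API and a batch given as plain
# Python lists of integer token ids):
#
#     place = fluid.CPUPlace()
#     batch = [[1, 2, 3], [4, 5]]            # two variable-length sequences
#     tensor = to_lodtensor(batch, place)
#     # tensor.lod() == [[0, 3, 5]]; the underlying data has shape [5, 1]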


def to_lodtensor_bpr(raw_data, neg_size, vocab_size, place):
    """ convert to LODtensor """
    data = [dat[0] for dat in raw_data]
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])

    # Build BPR label candidates: keep the true next item (positive) for each
    # position and draw neg_size negatives by tiling the positives and
    # shuffling; draws that collide with their own positive are patched below.
    data = [dat[1] for dat in raw_data]
    pos_data = np.concatenate(data, axis=0).astype("int64")
    length = np.size(pos_data)
    neg_data = np.tile(pos_data, neg_size)
    np.random.shuffle(neg_data)
    for ii in range(length * neg_size):
        if neg_data[ii] == pos_data[ii // neg_size]:
            neg_data[ii] = pos_data[length - 1 - ii // neg_size]

    label_data = np.column_stack(
        (pos_data.reshape(length, 1), neg_data.reshape(length, neg_size)))
    res_label = fluid.LoDTensor()
    res_label.set(label_data, place)
    res_label.set_lod([lod])

    res_pos = fluid.LoDTensor()
    res_pos.set(np.zeros([len(flattened_data), 1]).astype("int64"), place)
    res_pos.set_lod([lod])

    return res, res_pos, res_label


def to_lodtensor_bpr_test(raw_data, vocab_size, place):
    """ convert to LODtensor """
    data = [dat[0] for dat in raw_data]
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])

    data = [dat[1] for dat in raw_data]
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res_pos = fluid.LoDTensor()
    res_pos.set(flattened_data, place)
    res_pos.set_lod([lod])
    return res, res_pos


def get_vocab_size(vocab_path):
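    """ read the vocabulary size stored on the first line of vocab_path """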
    with io.open(vocab_path, "r", encoding='utf-8') as rf:
        line = rf.readline()
        return int(line.strip())


def prepare_data(file_dir,
                 vocab_path,
                 batch_size,
                 buffer_size=1000,
                 word_freq_threshold=0,
                 is_train=True):
    """ prepare the English Pann Treebank (PTB) data """
    print("start constuct word dict")
    if is_train and 'ce_mode' not in os.environ:
        vocab_size = get_vocab_size(vocab_path)
        reader = sort_batch(
            paddle.reader.shuffle(
                train(
                    file_dir, buffer_size, data_type=DataType.SEQ),
                buf_size=buffer_size),
            batch_size,
            batch_size * 20)
    else:
        vocab_size = get_vocab_size(vocab_path)
        reader = fluid.io.batch(
            test(
                file_dir, buffer_size, data_type=DataType.SEQ), batch_size)
    return vocab_size, reader
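
# Example of wiring prepare_data and to_lodtensor together in a feed loop
# (a minimal sketch, not from the original file; the paths and batch size are
# made up, and it assumes each data line holds space-separated integer ids):
#
#     vocab_size, train_reader = prepare_data(
#         "train_data", "vocab.txt", batch_size=5)
#     place = fluid.CPUPlace()
#     for batch in train_reader():
#         src = to_lodtensor([sample[0] for sample in batch], place)
#         trg = to_lodtensor([sample[1] for sample in batch], place)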


def check_version():
    """
    Log an error and exit when the installed version of PaddlePaddle
    is not satisfied.
    """
    err = "PaddlePaddle version 1.6 or higher is required, " \
          "or a suitable develop version is satisfied as well. \n" \
          "Please make sure the version is good with your code."

    try:
        fluid.require_version('1.6.0')
    except Exception:
        logging.error(err)
        sys.exit(1)
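
# check_version() is meant to be called once at startup, before the network
# is built, e.g. (sketch, not from the original file):
#
#     if __name__ == '__main__':
#         check_version()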


def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
    """
    Create a batched reader.
    :param reader: the data reader to read from.
    :type reader: callable
    :param batch_size: size of each mini-batch
    :type batch_size: int
    :param sort_group_size: size of partial sorted batch
    :type sort_group_size: int
    :param drop_last: drop the last batch if its size is not equal to batch_size.
    :type drop_last: bool
    :return: the batched reader.
    :rtype: callable
    """

    def batch_reader():
        r = reader()
        b = []
        for instance in r:
            b.append(instance)
            if len(b) == sort_group_size:
                # sort each group by source-sequence length, longest first
                sortl = sorted(b, key=lambda x: len(x[0]), reverse=True)
                b = []
                c = []
                for sort_i in sortl:
                    c.append(sort_i)
                    if len(c) == batch_size:
                        yield c
                        c = []
        # handle the tail that did not fill a whole sort group
        if not drop_last and len(b) != 0:
            sortl = sorted(b, key=lambda x: len(x[0]), reverse=True)
            c = []
            for sort_i in sortl:
                c.append(sort_i)
                if len(c) == batch_size:
                    yield c
                    c = []
            if len(c) != 0:
                yield c

    # Batch size check
    batch_size = int(batch_size)
    if batch_size <= 0:
        raise ValueError("batch_size should be a positive integer, "
                         "but got batch_size={}".format(batch_size))
    return batch_reader
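
# Note: sorting each sort_group_size window by source-sequence length keeps
# sequences of similar length in the same mini-batch; this is presumably
# intended to make the recurrent computation over a batch more uniform.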


class DataType(object):
    """ supported input data formats; only plain token sequences are used """
    SEQ = 2


def reader_creator(file_dir, n, data_type):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            with io.open(
                    os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
                for line in f:
                    if data_type == DataType.SEQ:
                        tokens = line.strip().split()
                        # next-item prediction: the target sequence is the
                        # source sequence shifted by one position
                        src_seq = tokens[:-1]
                        trg_seq = tokens[1:]
                        if n > 0 and len(src_seq) > n:
                            continue
                        yield src_seq, trg_seq
                    else:
                        assert False, 'error data type'

    return reader


def train(train_dir, n, data_type=DataType.SEQ):
    return reader_creator(train_dir, n, data_type)


def test(test_dir, n, data_type=DataType.SEQ):
    return reader_creator(test_dir, n, data_type)