# utils.py (5.9 KB), recovered from a GitLab blame view.
# Blame (frankwhzhang): lines 1-7.
import sys
import collections
import six
import time
import numpy as np
import paddle.fluid as fluid
import paddle
# Blame (frankwhzhang): line 8.
import os
# Blame (frankwhzhang): lines 9-25.


def to_lodtensor(data, place):
    """Pack a batch of sequences into one int64 column tensor with a 1-level LoD.

    Args:
        data: list of per-example sequences (anything ``np.concatenate``
            accepts along axis 0); every element is cast to int64.
        place: device placement handed to ``LoDTensor.set`` (presumably a
            fluid CPUPlace/CUDAPlace -- confirm at call sites).

    Returns:
        A ``fluid.LoDTensor`` holding a ``[total_len, 1]`` int64 column whose
        LoD is the cumulative offsets ``[0, len(s0), len(s0)+len(s1), ...]``.
    """
    # Cumulative sequence boundaries; tolist() yields plain Python ints.
    offsets = [0] + np.cumsum([len(seq) for seq in data]).tolist()

    column = np.concatenate(data, axis=0).astype("int64")
    column = column.reshape([len(column), 1])

    tensor = fluid.LoDTensor()
    tensor.set(column, place)
    tensor.set_lod([offsets])
    return tensor

# Blame: lines 26-47.

def to_lodtensor_bpr(raw_data, neg_size, vocab_size, place):
    """ convert to LODtensor """
    data = [dat[0] for dat in raw_data]
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])

    data = [dat[1] for dat in raw_data]
    pos_data = np.concatenate(data, axis=0).astype("int64")
    length = np.size(pos_data)
    neg_data = np.tile(pos_data, neg_size)
    np.random.shuffle(neg_data)
    for ii in range(length * neg_size):
Z
zhangwenhui03 已提交
48 49
        if neg_data[ii] == pos_data[ii // neg_size]:
            neg_data[ii] = pos_data[length - 1 - ii // neg_size]
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87

    label_data = np.column_stack(
        (pos_data.reshape(length, 1), neg_data.reshape(length, neg_size)))
    res_label = fluid.LoDTensor()
    res_label.set(label_data, place)
    res_label.set_lod([lod])

    res_pos = fluid.LoDTensor()
    res_pos.set(np.zeros([len(flattened_data), 1]).astype("int64"), place)
    res_pos.set_lod([lod])

    return res, res_pos, res_label


def to_lodtensor_bpr_test(raw_data, vocab_size, place):
    """Build the input and positive-id tensors for a BPR evaluation batch.

    Args:
        raw_data: batch of ``(src_seq, trg_seq)`` pairs.
        vocab_size: unused.
        place: device placement passed to ``LoDTensor.set``.

    Returns:
        ``(res, res_pos)``: int64 ``[total_len, 1]`` columns built from the
        ``src_seq`` and ``trg_seq`` items. Both carry the LoD computed from
        the ``src_seq`` lengths.
    """
    def _as_column(seqs):
        # Concatenate the sequences into a single int64 column vector.
        flat = np.concatenate(seqs, axis=0).astype("int64")
        return flat.reshape([len(flat), 1])

    src_seqs = [item[0] for item in raw_data]
    offsets = [0] + np.cumsum([len(seq) for seq in src_seqs]).tolist()
    res = fluid.LoDTensor()
    res.set(_as_column(src_seqs), place)
    res.set_lod([offsets])

    trg_seqs = [item[1] for item in raw_data]
    res_pos = fluid.LoDTensor()
    res_pos.set(_as_column(trg_seqs), place)
    res_pos.set_lod([offsets])
    return res, res_pos


# Blame (frankwhzhang): lines 88-91.
def get_vocab_size(vocab_path):
    """Return the integer stored on the first line of ``vocab_path``.

    Raises ValueError if that line is empty or is not an integer.
    """
    with open(vocab_path, "r") as vocab_file:
        header = vocab_file.readline()
    return int(header.strip())


# Blame (frankwhzhang): lines 92-95.
def prepare_data(file_dir,
                 vocab_path,
                 batch_size,
                 buffer_size=1000,
                 word_freq_threshold=0,
                 is_train=True):
    """Return the vocabulary size and a batched sequence reader over ``file_dir``.

    Args:
        file_dir: directory of text files, one whitespace-separated token
            sequence per line (read through ``train``/``test``).
        vocab_path: file whose first line holds the vocabulary size.
        batch_size: number of sequences per yielded batch.
        buffer_size: used both as the shuffle buffer size (training only) and
            as the reader's ``n``; source sequences longer than ``n`` are
            skipped.
        word_freq_threshold: currently unused; kept for interface
            compatibility.
        is_train: if True, shuffle and length-sort batches within groups of
            ``batch_size * 20``. Otherwise batch the test files in file order
            with ``paddle.batch``.

    Returns:
        ``(vocab_size, reader)``, where ``reader`` is a batched reader
        callable.
    """
    print("start constuct word dict")
    # Both modes read the vocabulary size the same way.
    vocab_size = get_vocab_size(vocab_path)
    if is_train:
        reader = sort_batch(
            paddle.reader.shuffle(
                train(
                    file_dir, buffer_size, data_type=DataType.SEQ),
                buf_size=buffer_size),
            batch_size,
            batch_size * 20)
    else:
        reader = paddle.batch(
            test(
                file_dir, buffer_size, data_type=DataType.SEQ), batch_size)
    return vocab_size, reader
# Blame (frankwhzhang): lines 117-167.


def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
    """
    Create a batched reader that groups instances of similar length.

    Instances are buffered until ``sort_group_size`` of them are pending. The
    buffer is then sorted by ``len(instance[0])`` (longest first) and cut into
    batches of ``batch_size``. Instances that do not fill a whole batch are
    carried into the next group rather than discarded. When the underlying
    reader is exhausted, the remaining instances are sorted and batched the
    same way.

    :param reader: the data reader to read from.
    :type reader: callable
    :param batch_size: size of each mini-batch
    :type batch_size: int
    :param sort_group_size: number of instances sorted together; best chosen
        as a multiple of batch_size.
    :type sort_group_size: int
    :param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
    :type drop_last: bool
    :return: the batched reader.
    :rtype: callable
    :raises ValueError: if batch_size is not a positive integer.
    """
    # Batch size check (eager, so misuse fails at construction time).
    batch_size = int(batch_size)
    if batch_size <= 0:
        raise ValueError("batch_size should be a positive integeral value, "
                         "but got batch_size={}".format(batch_size))

    def split_sorted(group):
        """Sort ``group`` longest-first; return (full batches, leftover list)."""
        ordered = sorted(group, key=lambda x: len(x[0]), reverse=True)
        n_full = len(ordered) - len(ordered) % batch_size
        batches = [ordered[i:i + batch_size]
                   for i in range(0, n_full, batch_size)]
        return batches, ordered[n_full:]

    def batch_reader():
        pending = []
        for instance in reader():
            pending.append(instance)
            # ``>=`` rather than ``==``: carried-over leftovers may already
            # hold sort_group_size or more instances.
            if len(pending) >= sort_group_size:
                batches, pending = split_sorted(pending)
                for batch in batches:
                    yield batch
        if pending:
            batches, rest = split_sorted(pending)
            for batch in batches:
                yield batch
            # Only the final, partially filled batch is subject to drop_last.
            if rest and not drop_last:
                yield rest

    return batch_reader


class DataType(object):
    """Record formats understood by ``reader_creator``."""

    # Each input line is a whitespace-separated token sequence.
    SEQ = 2

# Blame (frankwhzhang): lines 168-169.
def reader_creator(file_dir, n, data_type):
    """Create a reader over every file in ``file_dir``.

    Each line is split on whitespace into string tokens. The reader yields
    ``(src_seq, trg_seq)``, where ``src_seq`` is every token but the last and
    ``trg_seq`` is every token but the first, so ``trg_seq[i]`` is the token
    that follows ``src_seq[i]``.

    :param file_dir: directory whose entries are opened as text files, in
        ``os.listdir`` order (which is arbitrary).
    :param n: if positive, lines whose ``src_seq`` is longer than ``n`` are
        skipped.
    :param data_type: must be ``DataType.SEQ``.
    :return: a zero-argument callable returning a generator. Iterating it
        raises ValueError if ``data_type`` is unsupported.
    """
    def reader():
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would make a bad data_type yield nothing.
        if data_type != DataType.SEQ:
            raise ValueError('error data type')
        for fi in os.listdir(file_dir):
            with open(os.path.join(file_dir, fi), "r") as f:
                for line in f:
                    tokens = line.split()
                    src_seq = tokens[:-1]
                    trg_seq = tokens[1:]
                    if n > 0 and len(src_seq) > n:
                        continue
                    yield src_seq, trg_seq

    return reader


# Blame (frankwhzhang): lines 187-189.
def train(train_dir, n, data_type=DataType.SEQ):
    """Return a reader over the training files in ``train_dir`` (see ``reader_creator``)."""
    return reader_creator(file_dir=train_dir, n=n, data_type=data_type)


# Blame (frankwhzhang): lines 190-193.
def test(test_dir, n, data_type=DataType.SEQ):
    """Return a reader over the evaluation files in ``test_dir`` (see ``reader_creator``)."""
    return reader_creator(file_dir=test_dir, n=n, data_type=data_type)