"""
Reader for auto dialogue evaluation
"""

import random

import numpy as np

import paddle.fluid as fluid

def to_lodtensor(data, place):
    """
    Convert a list of token-id sequences into a fluid LoDTensor.
    """
    # Build the level-of-detail (LoD) info: cumulative offsets marking where
    # each sequence starts and ends inside the flattened tensor.
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    # Flatten all sequences into a single [total_len, 1] int64 array.
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res
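
# Illustrative example (not in the original file): for data = [[1, 2], [3, 4, 5]],
# seq_lens is [2, 3] and lod is [0, 2, 5]; the tensor stores the five ids as a
# [5, 1] int64 array, and the offsets let fluid recover each sequence boundary.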


def reshape_batch(batch, place):
    """
    Convert a batch of (context, response, label) instances into the feed
    format: two LoDTensors plus a list of labels.
    """
    context_reshape = to_lodtensor([dat[0] for dat in batch], place)
    response_reshape = to_lodtensor([dat[1] for dat in batch], place)
    label_reshape = [dat[2] for dat in batch]
    return (context_reshape, response_reshape, label_reshape)
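
# Shape sketch (illustrative): for a batch of size B the return value is
#   (context LoDTensor, response LoDTensor, [[label_1], ..., [label_B]]),
# where each label is a single-element list as built in batch_reader below.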


def batch_reader(data_path,
                 batch_size,
                 place,
                 max_len=50,
                 sample_pro=1):
    """
    Read tab-separated instances from data_path and yield full batches of
    (context, response, label) converted with reshape_batch. Incomplete
    trailing batches are dropped.
    """
    batch = []
    with open(data_path, 'r') as f:
        for line in f:
            # Randomly subsample training data when sample_pro < 1.
            if sample_pro < 1:
                if random.random() > sample_pro:
                    continue

            # Each line holds three tab-separated fields: context token ids,
            # response token ids, and an integer label. Sequences are truncated
            # to max_len tokens.
            tokens = line.strip().split('\t')
            assert len(tokens) == 3
            context = [int(x) for x in tokens[0].split()[:max_len]]
            response = [int(x) for x in tokens[1].split()[:max_len]]
            label = [int(tokens[2])]
            instance = (context, response, label)

            batch.append(instance)
            if len(batch) == batch_size:
                yield reshape_batch(batch, place)
                batch = []
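

# Minimal usage sketch (an assumption, not part of the original file): feeds the
# yielded tuples to an executor whose program exposes feed variables named
# "context", "response", and "label"; the data path and batch size are
# hypothetical.
#
#   place = fluid.CPUPlace()
#   exe = fluid.Executor(place)
#   for context, response, label in batch_reader("data/train.ids", 256, place):
#       exe.run(fluid.default_main_program(),
#               feed={"context": context,
#                     "response": response,
#                     "label": np.array(label).astype("int64")})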