# utils.py
# Blame annotation: "add ssr", committed by zhangwenhui03.
import numpy as np
import reader as reader
import os
import logging
import paddle.fluid as fluid
import paddle


def get_vocab_size(vocab_path):
    """Return the vocabulary size stored on the first line of ``vocab_path``.

    The first line is stripped of surrounding whitespace and parsed as an
    integer; any further lines in the file are ignored.
    """
    with open(vocab_path, "r") as vocab_file:
        first_line = vocab_file.readline()
    size_text = first_line.strip()
    return int(size_text)


def construct_train_data(file_dir, vocab_path, batch_size):
    """Build a shuffled, batched training reader over the files in ``file_dir``.

    Args:
        file_dir: Directory whose entries are all handed to the dataset as
            training files.
        vocab_path: File whose first line holds the vocabulary size.
        batch_size: Number of samples per batch. The shuffle step is given a
            ``buf_size`` of ``batch_size * 100``.

    Returns:
        A ``(train_reader, vocab_size)`` tuple.
    """
    vocab_size = get_vocab_size(vocab_path)
    # os.path.join copes with a trailing separator on file_dir and uses the
    # platform's separator instead of a hard-coded '/'.
    files = [os.path.join(file_dir, f) for f in os.listdir(file_dir)]
    y_data = reader.YoochooseDataset(vocab_size)
    train_reader = fluid.io.batch(
        paddle.reader.shuffle(
            y_data.train(files), buf_size=batch_size * 100),
        batch_size=batch_size)
    return train_reader, vocab_size


def construct_test_data(file_dir, vocab_path, batch_size):
    """Build a batched (unshuffled) test reader over the files in ``file_dir``.

    Args:
        file_dir: Directory whose entries are all handed to the dataset as
            test files.
        vocab_path: File whose first line holds the vocabulary size.
        batch_size: Number of samples per batch.

    Returns:
        A ``(test_reader, vocab_size)`` tuple.
    """
    vocab_size = get_vocab_size(vocab_path)
    # os.path.join copes with a trailing separator on file_dir and uses the
    # platform's separator instead of a hard-coded '/'.
    files = [os.path.join(file_dir, f) for f in os.listdir(file_dir)]
    y_data = reader.YoochooseDataset(vocab_size)
    test_reader = fluid.io.batch(y_data.test(files), batch_size=batch_size)
    return test_reader, vocab_size

# Blame annotation: committed by frankwhzhang.
def check_version():
    """Log an error and exit if the installed PaddlePaddle is too old.

    Calls ``fluid.require_version('1.6.0')``; if that call raises, the error
    message is logged and the process exits with status 1.

    Raises:
        SystemExit: With exit code 1 when the version requirement is not met.
    """
    # Imported locally: the module does not import ``sys`` at file level.
    import sys

    # The module never defines a ``logger``; obtain one here so the failure
    # path reports the problem instead of raising NameError.
    logger = logging.getLogger(__name__)

    err = "PaddlePaddle version 1.6 or higher is required, " \
          "or a suitable develop version is satisfied as well. \n" \
          "Please make sure the version is good with your code."

    try:
        fluid.require_version('1.6.0')
    except Exception:
        # Any failure of the version check is treated as fatal.
        logger.error(err)
        sys.exit(1)
# Blame annotation: "add ssr", committed by zhangwenhui03.

def infer_data(raw_data, place):
    """Pack a batch of (sequence, label) samples into a LoDTensor and labels.

    Each element of ``raw_data`` is indexed as ``sample[0]`` (the sequence)
    and ``sample[1]`` (the label). The sequences are concatenated into a
    single int64 column, and the cumulative sequence lengths are recorded as
    the tensor's level-of-detail offsets.

    Args:
        raw_data: Iterable batch of samples, re-iterated for sequences and
            for labels.
        place: Device placement forwarded to ``LoDTensor.set``.

    Returns:
        A ``(tensor, pos_label)`` tuple, where ``pos_label`` is an int64
        array reshaped to ``(number of samples, 1)``.
    """
    sequences = [sample[0] for sample in raw_data]

    # Offsets start at 0 and grow by each sequence's length.
    offsets = [0]
    for sequence in sequences:
        offsets.append(offsets[-1] + len(sequence))

    flat = np.concatenate(sequences, axis=0).astype("int64")
    column = flat.reshape([len(flat), 1])

    tensor = fluid.LoDTensor()
    tensor.set(column, place)
    tensor.set_lod([offsets])

    labels = [sample[1] for sample in raw_data]
    label_array = np.array(labels).astype("int64")
    pos_label = label_array.reshape(len(labels), 1)
    return tensor, pos_label