# utils.py: vocabulary, data-reader, version-check and inference helpers.
import io
import logging
import os
import sys

import numpy as np
import paddle
import paddle.fluid as fluid

import reader as reader


def get_vocab_size(vocab_path):
    """Return the vocabulary size stored on the first line of a file.

    Args:
        vocab_path: Path to a UTF-8 text file whose first line is an
            integer.

    Returns:
        The integer parsed from the first line; surrounding whitespace
        is ignored.

    Raises:
        ValueError: If the first line is empty or is not a valid integer.
    """
    with io.open(vocab_path, "r", encoding='utf-8') as rf:
        # Only the first line is consulted; anything after it is ignored.
        return int(rf.readline().strip())


def construct_train_data(file_dir, vocab_path, batch_size):
    """Build a shuffled, batched training reader over a directory of files.

    Args:
        file_dir: Directory whose entries are all handed to the dataset.
        vocab_path: File whose first line holds the vocabulary size.
        batch_size: Number of samples per batch. The shuffle buffer holds
            ``batch_size * 100`` samples.

    Returns:
        A ``(train_reader, vocab_size)`` tuple.
    """
    vocab_size = get_vocab_size(vocab_path)
    # os.path.join avoids doubled separators when file_dir already ends
    # with one and uses the platform's separator.
    files = [os.path.join(file_dir, f) for f in os.listdir(file_dir)]
    y_data = reader.YoochooseDataset(vocab_size)
    train_reader = fluid.io.batch(
        fluid.io.shuffle(
            y_data.train(files), buf_size=batch_size * 100),
        batch_size=batch_size)
    return train_reader, vocab_size


def construct_test_data(file_dir, vocab_path, batch_size):
    """Build a batched (unshuffled) test reader over a directory of files.

    Args:
        file_dir: Directory whose entries are all handed to the dataset.
        vocab_path: File whose first line holds the vocabulary size.
        batch_size: Number of samples per batch.

    Returns:
        A ``(test_reader, vocab_size)`` tuple.
    """
    vocab_size = get_vocab_size(vocab_path)
    # os.path.join avoids doubled separators when file_dir already ends
    # with one and uses the platform's separator.
    files = [os.path.join(file_dir, f) for f in os.listdir(file_dir)]
    y_data = reader.YoochooseDataset(vocab_size)
    test_reader = fluid.io.batch(y_data.test(files), batch_size=batch_size)
    return test_reader, vocab_size


# Environment check: verify the installed PaddlePaddle version.
def check_version():
    """Log an error and exit if the installed PaddlePaddle is unsuitable.

    Calls ``fluid.require_version('1.6.0')``. If that call raises, the
    error message is logged and the process exits with status 1.
    """
    err = ("PaddlePaddle version 1.6 or higher is required, "
           "or a suitable develop version is satisfied as well. \n"
           "Please make sure the version is good with your code.")

    try:
        fluid.require_version('1.6.0')
    except Exception:
        # NOTE(review): require_version is documented to signal an
        # unsatisfied version with a plain Exception, hence the broad
        # catch -- confirm against the Paddle version in use.
        # The logger is looked up here because the module defines no
        # module-level `logger`.
        logging.getLogger(__name__).error(err)
        sys.exit(1)


# Inference input packing.
def infer_data(raw_data, place):
    """Pack variable-length samples into a LoDTensor plus a label column.

    Args:
        raw_data: Collection of samples. Element 0 of each sample is a
            sequence and element 1 is its label.
        place: Passed through to ``LoDTensor.set`` together with the data.

    Returns:
        A ``(tensor, pos_label)`` pair. ``tensor`` holds every sequence
        concatenated into a single int64 column with one level of LoD
        offsets; ``pos_label`` is an int64 array reshaped to
        ``(number_of_samples, 1)``.
    """
    sequences = [sample[0] for sample in raw_data]

    # Running offsets: offsets[i] is where sequence i starts in the
    # flattened column; the last entry is the total length.
    offsets = [0]
    for seq in sequences:
        offsets.append(offsets[-1] + len(seq))

    column = np.concatenate(sequences, axis=0).astype("int64")
    column = column.reshape([len(column), 1])

    tensor = fluid.LoDTensor()
    tensor.set(column, place)
    tensor.set_lod([offsets])

    labels = [sample[1] for sample in raw_data]
    pos_label = np.array(labels).astype("int64").reshape(len(labels), 1)
    return tensor, pos_label