utils.py
import os
import random

import numpy as np

import paddle.fluid as fluid
import paddle


def get_predict_label(pos_prob):
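    """
    Map the positive-class probability to a 3-class label
    (0: negative, 1: neutral, 2: positive) and a 2-class label
    (0: negative, 2: positive).
    """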
    neg_prob = 1 - pos_prob
    # the neutral threshold should lie in (0.5, 1)
    neu_threshold = 0.55
    if neg_prob > neu_threshold:
        class3_label = 0
    elif pos_prob > neu_threshold:
        class3_label = 2
    else:
        class3_label = 1
    if pos_prob >= neg_prob:
        class2_label = 2
    else:
        class2_label = 0
    return class3_label, class2_label


def to_lodtensor(data, place):
    """
    convert ot LODtensor
    """
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = fluid.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res


def data2tensor(data, place):
    """
    data2tensor
    """
    input_seq = to_lodtensor([x[0] for x in data], place)
    return {"words": input_seq}


def data_reader(file_path, word_dict, is_shuffle=True):
    """
    Convert word sequence into slot
    """
    unk_id = len(word_dict)
    all_data = []
    with open(file_path, "r") as fin:
        for line in fin:
            cols = line.strip().split("\t")
            label = int(cols[0])
            wids = [word_dict.get(x, unk_id) for x in cols[1].split(" ")]
            all_data.append((wids, label))
    if is_shuffle:
        random.shuffle(all_data)

    def reader():
        for doc, label in all_data:
            yield doc, label

    return reader


def load_vocab(file_path):
    """
    load the given vocabulary
    """
    vocab = {}
    with open(file_path) as f:
        for wid, line in enumerate(f):
            vocab[line.strip()] = wid
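    # reserve the next id for the out-of-vocabulary token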
    vocab["<unk>"] = len(vocab)
    return vocab


def prepare_data(data_path, word_dict_path, batch_size, mode):
    """
    prepare data
    """
    assert os.path.exists(
        word_dict_path), "The given word dictionary does not exist."
    if mode == "train":
        assert os.path.exists(
            data_path), "The given training data does not exist."
    if mode == "eval" or mode == "infer":
        assert os.path.exists(data_path), "The given test data does not exist."

    word_dict = load_vocab(word_dict_path)
    if mode == "train":
        train_reader = paddle.batch(
            data_reader(data_path, word_dict, True), batch_size)
        return word_dict, train_reader
    else:
        test_reader = paddle.batch(
            data_reader(data_path, word_dict, False), batch_size)
        return word_dict, test_reader
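

if __name__ == "__main__":
    # Minimal usage sketch. The file paths below are hypothetical and only
    # illustrate how prepare_data and the returned batched reader fit together.
    word_dict, train_reader = prepare_data(
        "data/train.tsv", "data/vocab.txt", batch_size=128, mode="train")
    print("vocabulary size: %d" % len(word_dict))
    for batch in train_reader():
        # each batch is a list of (word_ids, label) pairs
        print("first batch has %d examples" % len(batch))
        break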