reader.py 2.6 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
"""
Senta Reader
"""

import os
import types
import csv
import numpy as np
from utils import load_vocab
from utils import data_reader

import paddle
import paddle.fluid as fluid

class SentaProcessor(object):
    """
    Data processor for Senta.

    Builds readers over the ``train.tsv``, ``dev.tsv`` and ``test.tsv`` files
    of a data directory and wraps them into Paddle batch readers. Parsing of
    the files themselves is delegated to ``utils.data_reader``.
    """

    # Phase names accepted by get_num_examples() and data_generator().
    _PHASES = ("train", "dev", "infer")

    def __init__(self,
                 data_dir,
                 vocab_path,
                 random_seed,
                 max_seq_len):
        """
        Args:
            data_dir: default directory holding the ``*.tsv`` split files.
            vocab_path: vocabulary file, loaded with ``utils.load_vocab``.
            random_seed: seed passed to ``np.random.seed``.
            max_seq_len: maximum sequence length that data_generator()
                forwards to ``data_reader``.
        """
        self.data_dir = data_dir
        self.vocab = load_vocab(vocab_path)
        # -1 means "count not known yet". The dict is handed to data_reader,
        # which presumably fills in the real counts (TODO confirm in utils).
        self.num_examples = dict.fromkeys(self._PHASES, -1)
        # NOTE: this seeds NumPy's *global* RNG, not a private generator.
        np.random.seed(random_seed)
        self.max_seq_len = max_seq_len
        # Initialised here so get_train_progress() is always callable;
        # nothing in this class advances these counters.
        self.current_train_example = 0
        self.current_train_epoch = 0

    def _read(self, data_dir, file_name, phase, epoch, max_seq_len):
        """Build a ``data_reader`` over ``<data_dir>/<file_name>`` for ``phase``.

        A ``data_dir`` of ``None`` falls back to ``self.data_dir``.
        """
        base_dir = self.data_dir if data_dir is None else data_dir
        return data_reader(os.path.join(base_dir, file_name), self.vocab,
                           self.num_examples, phase, epoch, max_seq_len)

    def get_train_examples(self, data_dir, epoch, max_seq_len):
        """
        Load training examples from ``<data_dir>/train.tsv``.

        Args:
            data_dir: directory containing ``train.tsv``; ``None`` means
                ``self.data_dir``.
            epoch: passed through to ``data_reader``.
            max_seq_len: passed through to ``data_reader``.

        Returns:
            Whatever ``utils.data_reader`` returns. data_generator() passes it
            to ``paddle.batch``, so presumably a reader callable.
        """
        return self._read(data_dir, "train.tsv", "train", epoch, max_seq_len)

    def get_dev_examples(self, data_dir, epoch, max_seq_len):
        """
        Load dev examples from ``<data_dir>/dev.tsv``.

        Arguments and return value are as for get_train_examples().
        """
        return self._read(data_dir, "dev.tsv", "dev", epoch, max_seq_len)

    def get_test_examples(self, data_dir, epoch, max_seq_len):
        """
        Load test examples from ``<data_dir>/test.tsv`` (phase ``"infer"``).

        Arguments and return value are as for get_train_examples().
        """
        return self._read(data_dir, "test.tsv", "infer", epoch, max_seq_len)

    def get_labels(self):
        """
        Return the label strings, ``["0", "1"]``.
        """
        return ["0", "1"]

    def get_num_examples(self, phase):
        """
        Return the number of examples recorded for ``phase``.

        Returns -1 if no count has been recorded yet.

        Raises:
            ValueError: if ``phase`` is not one of 'train', 'dev', 'infer'.
        """
        if phase not in self._PHASES:
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'infer'].")
        return self.num_examples[phase]

    def get_train_progress(self):
        """
        Return ``(current_train_example, current_train_epoch)``.
        """
        return self.current_train_example, self.current_train_epoch

    def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True):
        """
        Return a batched reader for the train, dev or infer split.

        Args:
            batch_size: batch size passed to ``paddle.batch``.
            phase: one of 'train', 'dev', 'infer'.
            epoch: passed through to ``data_reader``.
            shuffle: currently unused by this method.

        Raises:
            ValueError: if ``phase`` is not one of 'train', 'dev', 'infer'.
        """
        # Bound methods only; nothing is read until the chosen one is called.
        loaders = {
            "train": self.get_train_examples,
            "dev": self.get_dev_examples,
            "infer": self.get_test_examples,
        }
        if phase not in loaders:
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'infer'].")
        reader = loaders[phase](self.data_dir, epoch, self.max_seq_len)
        return paddle.batch(reader, batch_size)