# reader.py -- reader interfaces for training and testing data.
import os


def train_reader(data_dir, word_dict, label_dict):
    """
    Reader interface for training data

    Every file in ``data_dir`` is read line by line. A line is expected to
    hold tab-separated columns: the label in column 0 and the
    whitespace-separated words in column 1. Lines that do not provide both
    columns (for example blank lines) are skipped.

    :param data_dir: data directory
    :type data_dir: str
    :param word_dict: dictionary used to look up the id of each word,
        the dictionary must have a "<UNK>" key in it.
    :type word_dict: Python dict
    :param label_dict: dictionary used to look up the id of each label
    :type label_dict: Python dict
    :return: a no-argument generator function yielding
        ``(word_ids, label_id)`` for every well-formed line
    :rtype: callable
    :raises KeyError: while iterating, if "<UNK>" is missing from
        ``word_dict`` or a label is missing from ``label_dict``
    """

    def reader():
        # Looked up lazily so a missing "<UNK>" only fails once the reader
        # is actually iterated.
        unk_id = word_dict["<UNK>"]
        word_col = 1
        lbl_col = 0
        min_cols = max(word_col, lbl_col) + 1

        for file_name in os.listdir(data_dir):
            with open(os.path.join(data_dir, file_name), "r") as f:
                for line in f:
                    line_split = line.strip().split("\t")
                    # Skip blank or malformed lines instead of failing with
                    # an IndexError on the missing column.
                    if len(line_split) < min_cols:
                        continue
                    word_ids = [
                        word_dict.get(w, unk_id)
                        for w in line_split[word_col].split()
                    ]
                    yield word_ids, label_dict[line_split[lbl_col]]

    return reader


def test_reader(data_dir, word_dict):
    """
    Reader interface for testing data

    :param data_dir: data directory.
    :type data_dir: str
    :param word_dict: path of word dictionary,
        the dictionary must has a "UNK" in it.
    :type word_dict: Python dict
    """

    def reader():
        UNK_ID = word_dict["<UNK>"]
        word_col = 1

        for file_name in os.listdir(data_dir):
            with open(os.path.join(data_dir, file_name), "r") as f:
                for line in f:
                    line_split = line.strip().split("\t")
                    if len(line_split) < word_col: continue
                    word_ids = [
                        word_dict.get(w, UNK_ID)
                        for w in line_split[word_col].split()
                    ]
                    yield word_ids, line_split[word_col]

    return reader