# reader.py
from utils import load_dict


def train_reader(data_file_path, word_dict_file):
    """Create a reader yielding training samples from adjacent sentences.

    Args:
        data_file_path: Path to a UTF-8 encoded text file. Each line is split
            on tabs; lines with fewer than three columns are skipped. The
            third column is split on "." into sentences.
        word_dict_file: Path handed to ``load_dict``. The returned mapping is
            indexed with "<unk>", "<s>" and "<e>", so those keys must exist.

    Returns:
        A zero-argument generator function. For every pair of consecutive
        non-empty sentences in a line it yields a 3-tuple of id lists:
        ``(current, next[:-1], next[1:])``, where each sentence is encoded
        character by character and wrapped as ``[<s>] + ids + [<e>]``.
    """

    def reader():
        # Loaded on every call of reader(), exactly as the original did.
        word_dict = load_dict(word_dict_file)

        unk_id = word_dict[u"<unk>"]
        bos_id = word_dict[u"<s>"]
        eos_id = word_dict[u"<e>"]

        # Binary mode: every line is a byte string, so the explicit decode
        # below works on Python 3 as well as Python 2 (text-mode lines are
        # already ``str`` on Python 3 and have no ``decode`` method).
        with open(data_file_path, "rb") as f:
            for line in f:
                columns = line.strip().decode(
                    "utf8", errors="ignore").split(u"\t")
                if len(columns) < 3:
                    continue

                # One id list per non-empty sentence; whitespace inside a
                # sentence is dropped before the per-character lookup.
                poetry_ids = [
                    [bos_id] + [
                        word_dict.get(char, unk_id)
                        for char in u"".join(sentence.split())
                    ] + [eos_id]
                    for sentence in columns[2].split(u".")
                    if sentence
                ]

                # zip() produces nothing when there are fewer than two
                # sentences, so such lines contribute no samples.
                for current_ids, next_ids in zip(poetry_ids, poetry_ids[1:]):
                    yield current_ids, next_ids[:-1], next_ids[1:]

    return reader


def gen_reader(data_file_path, word_dict_file):
    """Create a reader yielding one encoded sequence per input line.

    Args:
        data_file_path: Path to a UTF-8 encoded text file, one input per line.
        word_dict_file: Path handed to ``load_dict``. The returned mapping is
            indexed with "<unk>", "<s>" and "<e>", so those keys must exist.

    Returns:
        A zero-argument generator function. For each line it yields a list of
        ids ``[<s>] + ids + [<e>]``, where ``ids`` are the per-character
        lookups of the line with all whitespace removed; characters missing
        from the mapping map to the "<unk>" id.
    """

    def reader():
        # Loaded on every call of reader(), exactly as the original did.
        word_dict = load_dict(word_dict_file)

        unk_id = word_dict[u"<unk>"]
        bos_id = word_dict[u"<s>"]
        eos_id = word_dict[u"<e>"]

        # Binary mode: every line is a byte string, so the explicit decode
        # below works on Python 3 as well as Python 2 (text-mode lines are
        # already ``str`` on Python 3 and have no ``decode`` method).
        with open(data_file_path, "rb") as f:
            for line in f:
                text = line.strip().decode("utf8", errors="ignore")
                input_line = u"".join(text.split())
                # Blank lines are not skipped: they still yield [<s>, <e>],
                # so there is one output per input line.
                yield [bos_id] + [
                    word_dict.get(char, unk_id) for char in input_line
                ] + [eos_id]

    return reader