提交 5c886ced 编写于 作者: W wangmeng28

Implement train data generation and preprocess for chinese poetry

上级 fb18316b
此差异已折叠。
#!/bin/bash
git clone https://github.com/chinese-poetry/chinese-poetry.git
if [ ! -d raw ]
then
mkdir raw
fi
mv chinese-poetry/json/poet.tang.* raw/
rm -rf chinese-poetry
# -*- coding: utf-8 -*-
import os
import io
import re
import json
import click
import collections
def build_vocabulary(dataset, cutoff=0):
dictionary = collections.defaultdict(int)
for data in dataset:
for sent in data[2]:
for char in sent:
dictionary[char] += 1
dictionary = filter(lambda x: x[1] >= cutoff, dictionary.items())
dictionary = sorted(dictionary, key=lambda x: (-x[1], x[0]))
vocab, _ = list(zip(*dictionary))
return (u"<unk>", u"<s>", u"<e>") + vocab
@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw data")
@click.option("--outfile", type=str, help="Path to save the training data")
@click.option("--dictfile", type=str, help="Path to save the dictionary file")
def preprocess(datadir, outfile, dictfile):
dataset = []
note_pattern1 = re.compile(u"(.*?)", re.U)
note_pattern2 = re.compile(u"〖.*?〗", re.U)
note_pattern3 = re.compile(u"-.*?-。?", re.U)
note_pattern4 = re.compile(u"(.*$", re.U)
note_pattern5 = re.compile(u"。。.*)$", re.U)
note_pattern6 = re.compile(u"。。", re.U)
note_pattern7 = re.compile(u"[《》「」\[\]]", re.U)
print("Loading raw data...")
for fn in os.listdir(datadir):
with io.open(os.path.join(datadir, fn), "r", encoding="utf8") as f:
for data in json.load(f):
title = data['title']
author = data['author']
p = "".join(data['paragraphs'])
p = "".join(p.split())
p = note_pattern1.sub(u"", p)
p = note_pattern2.sub(u"", p)
p = note_pattern3.sub(u"", p)
p = note_pattern4.sub(u"", p)
p = note_pattern5.sub(u"。", p)
p = note_pattern6.sub(u"。", p)
p = note_pattern7.sub(u"", p)
if (p == u"" or u"{" in p or u"}" in p or u"{" in p or
u"}" in p or u"、" in p or u":" in p or u";" in p or
u"!" in p or u"?" in p or u"●" in p or u"□" in p or
u"囗" in p or u")" in p):
continue
paragraphs = p.split(u"。")
paragraphs = filter(lambda x: len(x), paragraphs)
if len(paragraphs) > 1:
dataset.append((title, author, paragraphs))
print("Finished...")
print("Constructing vocabularies...")
vocab = build_vocabulary(dataset, cutoff=10)
with io.open(dictfile, "w", encoding="utf8") as f:
for v in vocab:
f.write(v + "\n")
print("Finished...")
print("Writing processed data...")
with io.open(outfile, "w", encoding="utf8") as f:
for data in dataset:
title = data[0]
author = data[1]
paragraphs = ".".join(data[2])
f.write("\t".join((title, author, paragraphs)) + "\n")
print("Finished...")
if __name__ == "__main__":
preprocess()
......@@ -44,7 +44,7 @@ def load_initial_model(model_path, parameters):
@click.option(
"--decoder_depth",
default=3,
help="The number of stacked LSTM layers in encoder.")
help="The number of stacked LSTM layers in decoder.")
@click.option(
"--train_data_path", required=True, help="The path of trainning data.")
@click.option(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册