Commit 96a2e44a authored by qiaolongfei

optimize seq2seq-dataset

Parent 37806792
@@ -12,22 +12,176 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import random
import operator
import numpy as np
from subprocess import Popen, PIPE
from os.path import join as join_path
from optparse import OptionParser
from os.path import join as join_path
from subprocess import Popen, PIPE
import numpy as np
from paddle.utils.preprocess_util import *
from paddle.utils.preprocess_util import save_list, DatasetCreater
""" """
Usage: run following command to show help message. Usage: run following command to show help message.
python preprocess.py -h python preprocess.py -h
""" """
class SeqToSeqDatasetCreater(DatasetCreater):
"""
A class to process data for sequence to sequence application.
"""
def __init__(self, data_path, output_path):
"""
data_path: the path to store the train data, test data and gen data
output_path: the path to store the processed dataset
"""
DatasetCreater.__init__(self, data_path)
self.gen_dir_name = 'gen'
self.gen_list_name = 'gen.list'
self.output_path = output_path
def concat_file(self, file_path, file1, file2, output_path, output):
"""
Concat file1 and file2 to be one output file
The i-th line of output = i-th line of file1 + '\t' + i-th line of file2
file_path: the path to store file1 and file2
output_path: the path to store output file
"""
file1 = os.path.join(file_path, file1)
file2 = os.path.join(file_path, file2)
output = os.path.join(output_path, output)
if not os.path.exists(output):
os.system('paste ' + file1 + ' ' + file2 + ' > ' + output)
def cat_file(self, dir_path, suffix, output_path, output):
"""
Cat all the files in dir_path with suffix to be one output file
dir_path: the base directory to store input file
suffix: suffix of file name
output_path: the path to store output file
"""
cmd = 'cat '
file_list = os.listdir(dir_path)
file_list.sort()
for file in file_list:
if file.endswith(suffix):
cmd += os.path.join(dir_path, file) + ' '
output = os.path.join(output_path, output)
if not os.path.exists(output):
os.system(cmd + '> ' + output)
def build_dict(self, file_path, dict_path, dict_size=-1):
"""
Create the dictionary for the file. Note that
1. Valid characters include all printable characters
2. Uppercase and lowercase letters are distinguished
3. There are 3 special tokens:
<s>: the start of a sequence
<e>: the end of a sequence
<unk>: a word not included in the dictionary
file_path: the path of the input file
dict_path: the path to store the dictionary
dict_size: word count of the dictionary;
if it is -1, the dictionary will contain all the words in the file
"""
if not os.path.exists(dict_path):
dictory = dict()
with open(file_path, "r") as fdata:
for line in fdata:
line = line.split('\t')
for line_split in line:
words = line_split.strip().split()
for word in words:
if word not in dictory:
dictory[word] = 1
else:
dictory[word] += 1
output = open(dict_path, "w+")
output.write('<s>\n<e>\n<unk>\n')
count = 3
for key, value in sorted(
dictory.items(), key=lambda d: d[1], reverse=True):
output.write(key + "\n")
count += 1
if count == dict_size:
break
self.dict_size = count
def create_dataset(self,
dict_size=-1,
mergeDict=False,
suffixes=['.src', '.trg']):
"""
Create seqToseq dataset
"""
# dataset_list and dir_list have a one-to-one relationship
train_dataset = os.path.join(self.data_path, self.train_dir_name)
test_dataset = os.path.join(self.data_path, self.test_dir_name)
gen_dataset = os.path.join(self.data_path, self.gen_dir_name)
dataset_list = [train_dataset, test_dataset, gen_dataset]
train_dir = os.path.join(self.output_path, self.train_dir_name)
test_dir = os.path.join(self.output_path, self.test_dir_name)
gen_dir = os.path.join(self.output_path, self.gen_dir_name)
dir_list = [train_dir, test_dir, gen_dir]
# create directory
for dir in dir_list:
if not os.path.exists(dir):
os.makedirs(dir)
# check that each dataset is a parallel corpus
suffix_len = len(suffixes[0])
for dataset in dataset_list:
file_list = os.listdir(dataset)
if len(file_list) % 2 == 1:
raise RuntimeError("dataset should be parallel corpora")
file_list.sort()
for i in range(0, len(file_list), 2):
if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]:
raise RuntimeError(
"source and target file name should be equal")
# cat all the files with the same suffix in dataset
for suffix in suffixes:
for dataset in dataset_list:
outname = os.path.basename(dataset) + suffix
self.cat_file(dataset, suffix, dataset, outname)
# concat parallel corpora and create file.list
print 'concat parallel corpora for dataset'
id = 0
list = ['train.list', 'test.list', 'gen.list']
for dataset in dataset_list:
outname = os.path.basename(dataset)
self.concat_file(dataset, outname + suffixes[0],
outname + suffixes[1], dir_list[id], outname)
save_list([os.path.join(dir_list[id], outname)],
os.path.join(self.output_path, list[id]))
id += 1
# build dictionary for train data
dict = ['src.dict', 'trg.dict']
dict_path = [
os.path.join(self.output_path, dict[0]),
os.path.join(self.output_path, dict[1])
]
if mergeDict:
outname = os.path.join(train_dir, train_dataset.split('/')[-1])
print 'build src dictionary for train data'
self.build_dict(outname, dict_path[0], dict_size)
print 'build trg dictionary for train data'
os.system('cp ' + dict_path[0] + ' ' + dict_path[1])
else:
outname = os.path.join(train_dataset, self.train_dir_name)
for id in range(0, 2):
suffix = suffixes[id]
print 'build ' + suffix[1:] + ' dictionary for train data'
self.build_dict(outname + suffix, dict_path[id], dict_size)
print 'dictionary size is', self.dict_size
def save_dict(dict, filename, is_reverse=True):
"""
Save dictionary into file.
...
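As a usage note for the `SeqToSeqDatasetCreater` class added in the hunk above: it expects the raw data directory to contain `train/`, `test/` and `gen/` subdirectories of parallel files that differ only in their `.src`/`.trg` suffix, and it writes the concatenated files, the `*.list` files and `src.dict`/`trg.dict` into the output directory. Below is a minimal driver sketch; the import path is an assumption for illustration, and the `./data/wmt14` / `./data/pre-wmt14` locations simply mirror the constants removed from the wmt14 dataset module in the next hunk.

```python
# Sketch only: SeqToSeqDatasetCreater is the class from the diff above.
# The import path and the ./data locations are illustrative assumptions.
# from preprocess import SeqToSeqDatasetCreater

raw_dir = './data/wmt14'       # must contain train/, test/, gen/ with *.src and *.trg pairs
out_dir = './data/pre-wmt14'   # receives train/, test/, gen/, *.list files and src.dict/trg.dict

creator = SeqToSeqDatasetCreater(raw_dir, out_dir)
# dict_size=-1 keeps every word; mergeDict=True copies src.dict over trg.dict
creator.create_dataset(dict_size=30000, mergeDict=False)
```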
@@ -14,72 +14,68 @@
 """
 wmt14 dataset
 """
-import os
-import os.path
 import tarfile

 import paddle.v2.dataset.common
-from wmt14_util import SeqToSeqDatasetCreater

 __all__ = ['train', 'test', 'build_dict']

 URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
 MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
 # this is a small set of data for test. The original data is too large and will be add later.
-URL_TRAIN = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
-MD5_TRAIN = '7373473f86016f1f48037c9c340a2d5b'
+URL_TRAIN = 'http://localhost:8989/wmt14.tgz'
+MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'

 START = "<s>"
 END = "<e>"
 UNK = "<unk>"
 UNK_IDX = 2

-DEFAULT_DATA_DIR = "./data"
-ORIGIN_DATA_DIR = "wmt14"
-INNER_DATA_DIR = "pre-wmt14"
-SRC_DICT = INNER_DATA_DIR + "/src.dict"
-TRG_DICT = INNER_DATA_DIR + "/trg.dict"
-TRAIN_FILE = INNER_DATA_DIR + "/train/train"
-
-
-def __process_data__(data_path, dict_size=None):
-    downloaded_data = os.path.join(data_path, ORIGIN_DATA_DIR)
-    if not os.path.exists(downloaded_data):
-        # 1. download and extract tgz.
-        with tarfile.open(
-                paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14',
-                                                  MD5_TRAIN)) as tf:
-            tf.extractall(data_path)
-
-    # 2. process data file to intermediate format.
-    processed_data = os.path.join(data_path, INNER_DATA_DIR)
-    if not os.path.exists(processed_data):
-        dict_size = dict_size or -1
-        data_creator = SeqToSeqDatasetCreater(downloaded_data, processed_data)
-        data_creator.create_dataset(dict_size, mergeDict=False)
-
-
-def __read_to_dict__(dict_path, count):
-    with open(dict_path, "r") as fin:
+
+def __read_to_dict__(tar_file, dict_size):
+    def __to_dict__(fd, size):
         out_dict = dict()
-        for line_count, line in enumerate(fin):
-            if line_count <= count:
+        for line_count, line in enumerate(fd):
+            if line_count < size:
                 out_dict[line.strip()] = line_count
             else:
                 break
-    return out_dict
+        return out_dict
+
+    with tarfile.open(tar_file, mode='r') as f:
+        names = [
+            each_item.name for each_item in f
+            if each_item.name.endswith("src.dict")
+        ]
+        assert len(names) == 1
+        src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        names = [
+            each_item.name for each_item in f
+            if each_item.name.endswith("trg.dict")
+        ]
+        assert len(names) == 1
+        trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        return src_dict, trg_dict


-def __reader__(file_name, src_dict, trg_dict):
-    with open(file_name, 'r') as f:
-        for line_count, line in enumerate(f):
+def reader_creator(tar_file, file_name, dict_size):
+    def reader():
+        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+        with tarfile.open(tar_file, mode='r') as f:
+            names = [
+                each_item.name for each_item in f
+                if each_item.name.endswith(file_name)
+            ]
+            for name in names:
+                for line in f.extractfile(name):
                     line_split = line.strip().split('\t')
                     if len(line_split) != 2:
                         continue
                     src_seq = line_split[0]  # one source sequence
                     src_words = src_seq.split()
                     src_ids = [
-                src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END]
+                        src_dict.get(w, UNK_IDX)
+                        for w in [START] + src_words + [END]
                     ]

                     trg_seq = line_split[1]  # one target sequence
@@ -94,23 +90,16 @@ def __reader__(file_name, src_dict, trg_dict):
             yield src_ids, trg_ids, trg_ids_next

-
-def train(data_dir=None, dict_size=None):
-    data_dir = data_dir or DEFAULT_DATA_DIR
-    __process_data__(data_dir, dict_size)
-    src_lang_dict = os.path.join(data_dir, SRC_DICT)
-    trg_lang_dict = os.path.join(data_dir, TRG_DICT)
-    train_file_name = os.path.join(data_dir, TRAIN_FILE)
-
-    default_dict_size = len(open(src_lang_dict, "r").readlines())
-
-    if dict_size > default_dict_size:
-        raise ValueError("dict_dim should not be larger then the "
-                         "length of word dict")
-
-    real_dict_dim = dict_size or default_dict_size
-
-    src_dict = __read_to_dict__(src_lang_dict, real_dict_dim)
-    trg_dict = __read_to_dict__(trg_lang_dict, real_dict_dim)
-
-    return lambda: __reader__(train_file_name, src_dict, trg_dict)
+    return reader
+
+
+def train(dict_size):
+    return reader_creator(
+        paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
+        'train/train', dict_size)
+
+
+def test(dict_size):
+    return reader_creator(
+        paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
+        'test/test', dict_size)
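With this change the wmt14 dataset module exposes reader creators instead of an on-disk preprocessing step, so callers only pass the dictionary size. A minimal consumption sketch follows, assuming the module is importable as `paddle.v2.dataset.wmt14` (inferred from the surrounding v2 dataset package, not stated in the diff):

```python
# Sketch only: train()/test() are the functions added above; the module path
# paddle.v2.dataset.wmt14 is assumed from the v2 dataset package this file belongs to.
import paddle.v2.dataset.wmt14 as wmt14

dict_size = 30000
train_reader = wmt14.train(dict_size)  # downloads and caches the tgz on first use

for i, (src_ids, trg_ids, trg_ids_next) in enumerate(train_reader()):
    # src_ids: source word ids wrapped in <s> ... <e>
    # trg_ids / trg_ids_next: decoder input and its one-step-shifted target
    print len(src_ids), len(trg_ids), len(trg_ids_next)
    if i >= 2:
        break
```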
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.utils.preprocess_util import save_list, DatasetCreater
class SeqToSeqDatasetCreater(DatasetCreater):
"""
A class to process data for sequence to sequence application.
"""
def __init__(self, data_path, output_path):
"""
data_path: the path to store the train data, test data and gen data
output_path: the path to store the processed dataset
"""
DatasetCreater.__init__(self, data_path)
self.gen_dir_name = 'gen'
self.gen_list_name = 'gen.list'
self.output_path = output_path
def concat_file(self, file_path, file1, file2, output_path, output):
"""
Concat file1 and file2 to be one output file
The i-th line of output = i-th line of file1 + '\t' + i-th line of file2
file_path: the path to store file1 and file2
output_path: the path to store output file
"""
file1 = os.path.join(file_path, file1)
file2 = os.path.join(file_path, file2)
output = os.path.join(output_path, output)
if not os.path.exists(output):
os.system('paste ' + file1 + ' ' + file2 + ' > ' + output)
def cat_file(self, dir_path, suffix, output_path, output):
"""
Cat all the files in dir_path with suffix to be one output file
dir_path: the base directory to store input file
suffix: suffix of file name
output_path: the path to store output file
"""
cmd = 'cat '
file_list = os.listdir(dir_path)
file_list.sort()
for file in file_list:
if file.endswith(suffix):
cmd += os.path.join(dir_path, file) + ' '
output = os.path.join(output_path, output)
if not os.path.exists(output):
os.system(cmd + '> ' + output)
def build_dict(self, file_path, dict_path, dict_size=-1):
"""
Create the dictionary for the file. Note that
1. Valid characters include all printable characters
2. Uppercase and lowercase letters are distinguished
3. There are 3 special tokens:
<s>: the start of a sequence
<e>: the end of a sequence
<unk>: a word not included in the dictionary
file_path: the path of the input file
dict_path: the path to store the dictionary
dict_size: word count of the dictionary;
if it is -1, the dictionary will contain all the words in the file
"""
if not os.path.exists(dict_path):
dictory = dict()
with open(file_path, "r") as fdata:
for line in fdata:
line = line.split('\t')
for line_split in line:
words = line_split.strip().split()
for word in words:
if word not in dictory:
dictory[word] = 1
else:
dictory[word] += 1
output = open(dict_path, "w+")
output.write('<s>\n<e>\n<unk>\n')
count = 3
for key, value in sorted(
dictory.items(), key=lambda d: d[1], reverse=True):
output.write(key + "\n")
count += 1
if count == dict_size:
break
self.dict_size = count
def create_dataset(self,
dict_size=-1,
mergeDict=False,
suffixes=['.src', '.trg']):
"""
Create seqToseq dataset
"""
# dataset_list and dir_list have a one-to-one relationship
train_dataset = os.path.join(self.data_path, self.train_dir_name)
test_dataset = os.path.join(self.data_path, self.test_dir_name)
gen_dataset = os.path.join(self.data_path, self.gen_dir_name)
dataset_list = [train_dataset, test_dataset, gen_dataset]
train_dir = os.path.join(self.output_path, self.train_dir_name)
test_dir = os.path.join(self.output_path, self.test_dir_name)
gen_dir = os.path.join(self.output_path, self.gen_dir_name)
dir_list = [train_dir, test_dir, gen_dir]
# create directory
for dir in dir_list:
if not os.path.exists(dir):
os.makedirs(dir)
# check that each dataset is a parallel corpus
suffix_len = len(suffixes[0])
for dataset in dataset_list:
file_list = os.listdir(dataset)
if len(file_list) % 2 == 1:
raise RuntimeError("dataset should be parallel corpora")
file_list.sort()
for i in range(0, len(file_list), 2):
if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]:
raise RuntimeError(
"source and target file name should be equal")
# cat all the files with the same suffix in dataset
for suffix in suffixes:
for dataset in dataset_list:
outname = os.path.basename(dataset) + suffix
self.cat_file(dataset, suffix, dataset, outname)
# concat parallel corpora and create file.list
print 'concat parallel corpora for dataset'
id = 0
list = ['train.list', 'test.list', 'gen.list']
for dataset in dataset_list:
outname = os.path.basename(dataset)
self.concat_file(dataset, outname + suffixes[0],
outname + suffixes[1], dir_list[id], outname)
save_list([os.path.join(dir_list[id], outname)],
os.path.join(self.output_path, list[id]))
id += 1
# build dictionary for train data
dict = ['src.dict', 'trg.dict']
dict_path = [
os.path.join(self.output_path, dict[0]),
os.path.join(self.output_path, dict[1])
]
if mergeDict:
outname = os.path.join(train_dir, train_dataset.split('/')[-1])
print 'build src dictionary for train data'
self.build_dict(outname, dict_path[0], dict_size)
print 'build trg dictionary for train data'
os.system('cp ' + dict_path[0] + ' ' + dict_path[1])
else:
outname = os.path.join(train_dataset, self.train_dir_name)
for id in range(0, 2):
suffix = suffixes[id]
print 'build ' + suffix[1:] + ' dictionary for train data'
self.build_dict(outname + suffix, dict_path[id], dict_size)
print 'dictionary size is', self.dict_size
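A note on the dictionary files written by `build_dict` above: they are plain text with one token per line, and the three special tokens are written first, which is exactly why `UNK_IDX = 2` holds in the wmt14 dataset module. Below is a small loading sketch that mirrors the `__to_dict__` helper in the new reader code; the `src.dict` path and the `load_dict` helper name are illustrative, not part of the commit.

```python
# Sketch only: load a dictionary produced by build_dict into a word -> id map.
# Line numbers become word ids, so <s>=0, <e>=1, <unk>=2 by construction.
def load_dict(dict_path, size=-1):
    word_to_id = {}
    with open(dict_path, 'r') as f:
        for idx, line in enumerate(f):
            if size != -1 and idx >= size:
                break
            word_to_id[line.strip()] = idx
    return word_to_id

src_dict = load_dict('src.dict', size=30000)  # path is illustrative
assert src_dict['<unk>'] == 2
```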