Commit 96a2e44a authored by: Q qiaolongfei

optimize seq2seq-dataset

Parent 37806792
3 merge requests: !11636 [IMPORTANT] MKLDNN layout: Support for sum operator, !2081 Release/0.10.0, !1560 optimize Seq2seq dataset
@@ -12,22 +12,176 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+import sys
+import random
import operator
+import numpy as np
+from subprocess import Popen, PIPE
+from os.path import join as join_path
from optparse import OptionParser
-from os.path import join as join_path
-from subprocess import Popen, PIPE
-import numpy as np
from paddle.utils.preprocess_util import *
+from paddle.utils.preprocess_util import save_list, DatasetCreater
""" """
Usage: run following command to show help message. Usage: run following command to show help message.
python preprocess.py -h python preprocess.py -h
""" """

class SeqToSeqDatasetCreater(DatasetCreater):
    """
    A class to process data for sequence-to-sequence applications.
    """

    def __init__(self, data_path, output_path):
        """
        data_path: the path to store the train data, test data and gen data
        output_path: the path to store the processed dataset
        """
        DatasetCreater.__init__(self, data_path)
        self.gen_dir_name = 'gen'
        self.gen_list_name = 'gen.list'
        self.output_path = output_path

    def concat_file(self, file_path, file1, file2, output_path, output):
        """
        Concatenate file1 and file2 into one output file
        The i-th line of output = i-th line of file1 + '\t' + i-th line of file2
        file_path: the path to store file1 and file2
        output_path: the path to store output file
        """
        file1 = os.path.join(file_path, file1)
        file2 = os.path.join(file_path, file2)
        output = os.path.join(output_path, output)
        if not os.path.exists(output):
            os.system('paste ' + file1 + ' ' + file2 + ' > ' + output)

    def cat_file(self, dir_path, suffix, output_path, output):
        """
        Concatenate all the files in dir_path that end with suffix into one output file
        dir_path: the base directory to store input file
        suffix: suffix of file name
        output_path: the path to store output file
        """
        cmd = 'cat '
        file_list = os.listdir(dir_path)
        file_list.sort()
        for file in file_list:
            if file.endswith(suffix):
                cmd += os.path.join(dir_path, file) + ' '
        output = os.path.join(output_path, output)
        if not os.path.exists(output):
            os.system(cmd + '> ' + output)

    def build_dict(self, file_path, dict_path, dict_size=-1):
        """
        Create the dictionary for the file. Note that
        1. Valid characters include all printable characters
        2. There is a distinction between uppercase and lowercase letters
        3. There are 3 special tokens:
           <s>: the start of a sequence
           <e>: the end of a sequence
           <unk>: a word not included in the dictionary
        file_path: the path to store file
        dict_path: the path to store dictionary
        dict_size: word count of dictionary
                   if it is -1, the dictionary will contain all the words in the file
        """
        if not os.path.exists(dict_path):
            dictory = dict()
            with open(file_path, "r") as fdata:
                for line in fdata:
                    line = line.split('\t')
                    for line_split in line:
                        words = line_split.strip().split()
                        for word in words:
                            if word not in dictory:
                                dictory[word] = 1
                            else:
                                dictory[word] += 1
            output = open(dict_path, "w+")
            output.write('<s>\n<e>\n<unk>\n')
            count = 3
            for key, value in sorted(
                    dictory.items(), key=lambda d: d[1], reverse=True):
                output.write(key + "\n")
                count += 1
                if count == dict_size:
                    break
            self.dict_size = count

    def create_dataset(self,
                       dict_size=-1,
                       mergeDict=False,
                       suffixes=['.src', '.trg']):
        """
        Create seqToseq dataset
        """
        # dataset_list and dir_list have a one-to-one relationship
        train_dataset = os.path.join(self.data_path, self.train_dir_name)
        test_dataset = os.path.join(self.data_path, self.test_dir_name)
        gen_dataset = os.path.join(self.data_path, self.gen_dir_name)
        dataset_list = [train_dataset, test_dataset, gen_dataset]

        train_dir = os.path.join(self.output_path, self.train_dir_name)
        test_dir = os.path.join(self.output_path, self.test_dir_name)
        gen_dir = os.path.join(self.output_path, self.gen_dir_name)
        dir_list = [train_dir, test_dir, gen_dir]

        # create directory
        for dir in dir_list:
            if not os.path.exists(dir):
                os.makedirs(dir)

        # check that each dataset is a parallel corpus
        suffix_len = len(suffixes[0])
        for dataset in dataset_list:
            file_list = os.listdir(dataset)
            if len(file_list) % 2 == 1:
                raise RuntimeError("dataset should be parallel corpora")
            file_list.sort()
            for i in range(0, len(file_list), 2):
                if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]:
                    raise RuntimeError(
                        "source and target file name should be equal")

        # cat all the files with the same suffix in dataset
        for suffix in suffixes:
            for dataset in dataset_list:
                outname = os.path.basename(dataset) + suffix
                self.cat_file(dataset, suffix, dataset, outname)

        # concat parallel corpora and create file.list
        print 'concat parallel corpora for dataset'
        id = 0
        list = ['train.list', 'test.list', 'gen.list']
        for dataset in dataset_list:
            outname = os.path.basename(dataset)
            self.concat_file(dataset, outname + suffixes[0],
                             outname + suffixes[1], dir_list[id], outname)
            save_list([os.path.join(dir_list[id], outname)],
                      os.path.join(self.output_path, list[id]))
            id += 1

        # build dictionary for train data
        dict = ['src.dict', 'trg.dict']
        dict_path = [
            os.path.join(self.output_path, dict[0]),
            os.path.join(self.output_path, dict[1])
        ]
        if mergeDict:
            outname = os.path.join(train_dir, train_dataset.split('/')[-1])
            print 'build src dictionary for train data'
            self.build_dict(outname, dict_path[0], dict_size)
            print 'build trg dictionary for train data'
            os.system('cp ' + dict_path[0] + ' ' + dict_path[1])
        else:
            outname = os.path.join(train_dataset, self.train_dir_name)
            for id in range(0, 2):
                suffix = suffixes[id]
                print 'build ' + suffix[1:] + ' dictionary for train data'
                self.build_dict(outname + suffix, dict_path[id], dict_size)
        print 'dictionary size is', self.dict_size

def save_dict(dict, filename, is_reverse=True):
    """
    Save dictionary into file.
...
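For context, here is a minimal sketch of how the SeqToSeqDatasetCreater added above is typically driven. The directory paths, the 30000 vocabulary size, and importing the class from the preprocess module are illustrative assumptions; only the constructor and create_dataset() signature come from the code in this diff.

# Hypothetical usage of the class added above (paths are assumptions).
# Assumes ./data/wmt14 holds train/, test/ and gen/ sub-directories of
# parallel *.src / *.trg files, as create_dataset() expects.
from preprocess import SeqToSeqDatasetCreater

creator = SeqToSeqDatasetCreater('./data/wmt14', './data/pre-wmt14')
# Writes train.list / test.list / gen.list plus src.dict and trg.dict
# under ./data/pre-wmt14; dict_size=-1 would keep every word.
creator.create_dataset(dict_size=30000, mergeDict=False)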
@@ -14,103 +14,92 @@
"""
wmt14 dataset
"""
-import os
-import os.path
import tarfile
import paddle.v2.dataset.common
-from wmt14_util import SeqToSeqDatasetCreater

__all__ = ['train', 'test', 'build_dict']

URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be added later.
-URL_TRAIN = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
+URL_TRAIN = 'http://localhost:8989/wmt14.tgz'
-MD5_TRAIN = '7373473f86016f1f48037c9c340a2d5b'
+MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'

START = "<s>"
END = "<e>"
UNK = "<unk>"
UNK_IDX = 2

-DEFAULT_DATA_DIR = "./data"
-ORIGIN_DATA_DIR = "wmt14"
-INNER_DATA_DIR = "pre-wmt14"
-SRC_DICT = INNER_DATA_DIR + "/src.dict"
-TRG_DICT = INNER_DATA_DIR + "/trg.dict"
-TRAIN_FILE = INNER_DATA_DIR + "/train/train"
-
-def __process_data__(data_path, dict_size=None):
-    downloaded_data = os.path.join(data_path, ORIGIN_DATA_DIR)
-    if not os.path.exists(downloaded_data):
-        # 1. download and extract tgz.
-        with tarfile.open(
-                paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14',
-                                                  MD5_TRAIN)) as tf:
-            tf.extractall(data_path)
-    # 2. process data file to intermediate format.
-    processed_data = os.path.join(data_path, INNER_DATA_DIR)
-    if not os.path.exists(processed_data):
-        dict_size = dict_size or -1
-        data_creator = SeqToSeqDatasetCreater(downloaded_data, processed_data)
-        data_creator.create_dataset(dict_size, mergeDict=False)
-
-def __read_to_dict__(dict_path, count):
-    with open(dict_path, "r") as fin:
-        out_dict = dict()
-        for line_count, line in enumerate(fin):
-            if line_count <= count:
-                out_dict[line.strip()] = line_count
-            else:
-                break
-    return out_dict
-
-def __reader__(file_name, src_dict, trg_dict):
-    with open(file_name, 'r') as f:
-        for line_count, line in enumerate(f):
-            line_split = line.strip().split('\t')
-            if len(line_split) != 2:
-                continue
-            src_seq = line_split[0]  # one source sequence
-            src_words = src_seq.split()
-            src_ids = [
-                src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END]
-            ]
-            trg_seq = line_split[1]  # one target sequence
-            trg_words = trg_seq.split()
-            trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
-            # remove sequence whose length > 80 in training mode
-            if len(src_ids) > 80 or len(trg_ids) > 80:
-                continue
-            trg_ids_next = trg_ids + [trg_dict[END]]
-            trg_ids = [trg_dict[START]] + trg_ids
-            yield src_ids, trg_ids, trg_ids_next
-
-def train(data_dir=None, dict_size=None):
-    data_dir = data_dir or DEFAULT_DATA_DIR
-    __process_data__(data_dir, dict_size)
-    src_lang_dict = os.path.join(data_dir, SRC_DICT)
-    trg_lang_dict = os.path.join(data_dir, TRG_DICT)
-    train_file_name = os.path.join(data_dir, TRAIN_FILE)
-    default_dict_size = len(open(src_lang_dict, "r").readlines())
-    if dict_size > default_dict_size:
-        raise ValueError("dict_dim should not be larger then the "
-                         "length of word dict")
-    real_dict_dim = dict_size or default_dict_size
-    src_dict = __read_to_dict__(src_lang_dict, real_dict_dim)
-    trg_dict = __read_to_dict__(trg_lang_dict, real_dict_dim)
-    return lambda: __reader__(train_file_name, src_dict, trg_dict)

+def __read_to_dict__(tar_file, dict_size):
+    def __to_dict__(fd, size):
+        out_dict = dict()
+        for line_count, line in enumerate(fd):
+            if line_count < size:
+                out_dict[line.strip()] = line_count
+            else:
+                break
+        return out_dict
+
+    with tarfile.open(tar_file, mode='r') as f:
+        names = [
+            each_item.name for each_item in f
+            if each_item.name.endswith("src.dict")
+        ]
+        assert len(names) == 1
+        src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        names = [
+            each_item.name for each_item in f
+            if each_item.name.endswith("trg.dict")
+        ]
+        assert len(names) == 1
+        trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+    return src_dict, trg_dict
+
+def reader_creator(tar_file, file_name, dict_size):
+    def reader():
+        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+        with tarfile.open(tar_file, mode='r') as f:
+            names = [
+                each_item.name for each_item in f
+                if each_item.name.endswith(file_name)
+            ]
+            for name in names:
+                for line in f.extractfile(name):
+                    line_split = line.strip().split('\t')
+                    if len(line_split) != 2:
+                        continue
+                    src_seq = line_split[0]  # one source sequence
+                    src_words = src_seq.split()
+                    src_ids = [
+                        src_dict.get(w, UNK_IDX)
+                        for w in [START] + src_words + [END]
+                    ]
+                    trg_seq = line_split[1]  # one target sequence
+                    trg_words = trg_seq.split()
+                    trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
+                    # remove sequence whose length > 80 in training mode
+                    if len(src_ids) > 80 or len(trg_ids) > 80:
+                        continue
+                    trg_ids_next = trg_ids + [trg_dict[END]]
+                    trg_ids = [trg_dict[START]] + trg_ids
+                    yield src_ids, trg_ids, trg_ids_next
+
+    return reader
+
+def train(dict_size):
+    return reader_creator(
+        paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
+        'train/train', dict_size)
+
+def test(dict_size):
+    return reader_creator(
+        paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
+        'test/test', dict_size)
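A short sketch of how the rewritten readers above are consumed. The module path paddle.v2.dataset.wmt14 and the dict_size value are assumptions for illustration; the (src_ids, trg_ids, trg_ids_next) interface comes from reader_creator() above.

# Hypothetical consumption of the new tar-based readers.
import paddle.v2.dataset.wmt14 as wmt14

train_reader = wmt14.train(dict_size=30000)  # returns a callable reader
for src_ids, trg_ids, trg_ids_next in train_reader():
    # src_ids / trg_ids are lists of word ids; trg_ids starts with the
    # <s> id and trg_ids_next ends with the <e> id (next-word labels).
    assert len(trg_ids) == len(trg_ids_next)
    break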
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.utils.preprocess_util import save_list, DatasetCreater

class SeqToSeqDatasetCreater(DatasetCreater):
    """
    A class to process data for sequence-to-sequence applications.
    """

    def __init__(self, data_path, output_path):
        """
        data_path: the path to store the train data, test data and gen data
        output_path: the path to store the processed dataset
        """
        DatasetCreater.__init__(self, data_path)
        self.gen_dir_name = 'gen'
        self.gen_list_name = 'gen.list'
        self.output_path = output_path

    def concat_file(self, file_path, file1, file2, output_path, output):
        """
        Concatenate file1 and file2 into one output file
        The i-th line of output = i-th line of file1 + '\t' + i-th line of file2
        file_path: the path to store file1 and file2
        output_path: the path to store output file
        """
        file1 = os.path.join(file_path, file1)
        file2 = os.path.join(file_path, file2)
        output = os.path.join(output_path, output)
        if not os.path.exists(output):
            os.system('paste ' + file1 + ' ' + file2 + ' > ' + output)

    def cat_file(self, dir_path, suffix, output_path, output):
        """
        Concatenate all the files in dir_path that end with suffix into one output file
        dir_path: the base directory to store input file
        suffix: suffix of file name
        output_path: the path to store output file
        """
        cmd = 'cat '
        file_list = os.listdir(dir_path)
        file_list.sort()
        for file in file_list:
            if file.endswith(suffix):
                cmd += os.path.join(dir_path, file) + ' '
        output = os.path.join(output_path, output)
        if not os.path.exists(output):
            os.system(cmd + '> ' + output)

    def build_dict(self, file_path, dict_path, dict_size=-1):
        """
        Create the dictionary for the file. Note that
        1. Valid characters include all printable characters
        2. There is a distinction between uppercase and lowercase letters
        3. There are 3 special tokens:
           <s>: the start of a sequence
           <e>: the end of a sequence
           <unk>: a word not included in the dictionary
        file_path: the path to store file
        dict_path: the path to store dictionary
        dict_size: word count of dictionary
                   if it is -1, the dictionary will contain all the words in the file
        """
        if not os.path.exists(dict_path):
            dictory = dict()
            with open(file_path, "r") as fdata:
                for line in fdata:
                    line = line.split('\t')
                    for line_split in line:
                        words = line_split.strip().split()
                        for word in words:
                            if word not in dictory:
                                dictory[word] = 1
                            else:
                                dictory[word] += 1
            output = open(dict_path, "w+")
            output.write('<s>\n<e>\n<unk>\n')
            count = 3
            for key, value in sorted(
                    dictory.items(), key=lambda d: d[1], reverse=True):
                output.write(key + "\n")
                count += 1
                if count == dict_size:
                    break
            self.dict_size = count

    def create_dataset(self,
                       dict_size=-1,
                       mergeDict=False,
                       suffixes=['.src', '.trg']):
        """
        Create seqToseq dataset
        """
        # dataset_list and dir_list have a one-to-one relationship
        train_dataset = os.path.join(self.data_path, self.train_dir_name)
        test_dataset = os.path.join(self.data_path, self.test_dir_name)
        gen_dataset = os.path.join(self.data_path, self.gen_dir_name)
        dataset_list = [train_dataset, test_dataset, gen_dataset]

        train_dir = os.path.join(self.output_path, self.train_dir_name)
        test_dir = os.path.join(self.output_path, self.test_dir_name)
        gen_dir = os.path.join(self.output_path, self.gen_dir_name)
        dir_list = [train_dir, test_dir, gen_dir]

        # create directory
        for dir in dir_list:
            if not os.path.exists(dir):
                os.makedirs(dir)

        # check that each dataset is a parallel corpus
        suffix_len = len(suffixes[0])
        for dataset in dataset_list:
            file_list = os.listdir(dataset)
            if len(file_list) % 2 == 1:
                raise RuntimeError("dataset should be parallel corpora")
            file_list.sort()
            for i in range(0, len(file_list), 2):
                if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]:
                    raise RuntimeError(
                        "source and target file name should be equal")

        # cat all the files with the same suffix in dataset
        for suffix in suffixes:
            for dataset in dataset_list:
                outname = os.path.basename(dataset) + suffix
                self.cat_file(dataset, suffix, dataset, outname)

        # concat parallel corpora and create file.list
        print 'concat parallel corpora for dataset'
        id = 0
        list = ['train.list', 'test.list', 'gen.list']
        for dataset in dataset_list:
            outname = os.path.basename(dataset)
            self.concat_file(dataset, outname + suffixes[0],
                             outname + suffixes[1], dir_list[id], outname)
            save_list([os.path.join(dir_list[id], outname)],
                      os.path.join(self.output_path, list[id]))
            id += 1

        # build dictionary for train data
        dict = ['src.dict', 'trg.dict']
        dict_path = [
            os.path.join(self.output_path, dict[0]),
            os.path.join(self.output_path, dict[1])
        ]
        if mergeDict:
            outname = os.path.join(train_dir, train_dataset.split('/')[-1])
            print 'build src dictionary for train data'
            self.build_dict(outname, dict_path[0], dict_size)
            print 'build trg dictionary for train data'
            os.system('cp ' + dict_path[0] + ' ' + dict_path[1])
        else:
            outname = os.path.join(train_dataset, self.train_dir_name)
            for id in range(0, 2):
                suffix = suffixes[id]
                print 'build ' + suffix[1:] + ' dictionary for train data'
                self.build_dict(outname + suffix, dict_path[id], dict_size)
        print 'dictionary size is', self.dict_size
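One detail worth calling out: build_dict() writes the three special tokens first, which is what the hard-coded UNK_IDX = 2 in wmt14.py relies on. Below is a small sketch of the resulting dictionary layout; the example word past the special tokens is made up for illustration.

# Hypothetical head of a generated src.dict / trg.dict file:
#   <s>    -> id 0  (start of sequence)
#   <e>    -> id 1  (end of sequence)
#   <unk>  -> id 2  (out-of-vocabulary word), matching UNK_IDX = 2
#   the    -> id 3  (most frequent corpus word, illustrative)
special_tokens = ['<s>', '<e>', '<unk>']
word_to_id = {w: i for i, w in enumerate(special_tokens)}
assert word_to_id['<unk>'] == 2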