From 55cf75dfa931ca934d5126e249e538e1b222292f Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Thu, 26 Apr 2018 03:15:16 +0000
Subject: [PATCH] add data_loader

---
 .../transformer_nist_base/config.py    |   4 +
 .../transformer_nist_base/data_util.py | 233 ++++++++++++++++++
 .../transformer_nist_base/nmt_fluid.py |  18 +-
 3 files changed, 252 insertions(+), 3 deletions(-)
 create mode 100644 fluid/neural_machine_translation/transformer_nist_base/data_util.py

diff --git a/fluid/neural_machine_translation/transformer_nist_base/config.py b/fluid/neural_machine_translation/transformer_nist_base/config.py
index 14f4ce0c..0cae6f36 100644
--- a/fluid/neural_machine_translation/transformer_nist_base/config.py
+++ b/fluid/neural_machine_translation/transformer_nist_base/config.py
@@ -5,6 +5,10 @@ class TrainTaskConfig(object):
 
     # the number of sequences contained in a mini-batch.
     batch_size = 56
+    token_batch_size = 0
+    # the flag indicating whether to sort all sequences by length,
+    # which can accelerate training but may damp the accuracy slightly.
+    sort_by_length = True
 
     # the hyper parameters for Adam optimizer.
     learning_rate = 0.001

diff --git a/fluid/neural_machine_translation/transformer_nist_base/data_util.py b/fluid/neural_machine_translation/transformer_nist_base/data_util.py
new file mode 100644
index 00000000..ffa11a7a
--- /dev/null
+++ b/fluid/neural_machine_translation/transformer_nist_base/data_util.py
@@ -0,0 +1,233 @@
+import os
+import tarfile
+import glob
+
+import random
+
+'''
+START_MARK = ""
+END_MARK = ""
+UNK_MARK = ""
+'''
+
+START_MARK = "<_GO>"
+END_MARK = "<_EOS>"
+UNK_MARK = "<_UNK>"
+
+
+class DataLoader(object):
+    def __init__(self,
+                 src_vocab_fpath,
+                 trg_vocab_fpath,
+                 fpattern,
+                 batch_size,
+                 token_batch_size=0,
+                 tar_fname=None,
+                 sort_by_length=True,
+                 shuffle=True,
+                 min_len=0,
+                 max_len=100,
+                 n_batch=1):
+        self._src_vocab = self._load_dict(src_vocab_fpath)
+        self._trg_vocab = self._load_dict(trg_vocab_fpath)
+        self._batch_size = batch_size
+        self._token_batch_size = token_batch_size
+        self._tar_fname = tar_fname
+        self._sort_by_length = sort_by_length
+        self._shuffle = shuffle
+        self._min_len = min_len
+        self._max_len = max_len
+        self._n_batch = n_batch
+
+        src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
+        self._src_seq_words = src_seq_words
+        self._trg_seq_words = trg_seq_words
+        src_seq_ids = [[
+            self._src_vocab.get(word, self._src_vocab.get(UNK_MARK))
+            for word in ([START_MARK] + src_seq + [END_MARK])
+        ] for src_seq in self._src_seq_words]
+        trg_seq_ids = [[
+            self._trg_vocab.get(word, self._trg_vocab.get(UNK_MARK))
+            for word in ([START_MARK] + trg_seq + [END_MARK])
+        ] for trg_seq in self._trg_seq_words]
+        self._src_seq_words = src_seq_ids
+        self._trg_seq_words = trg_seq_ids
+
+        self._ins_cnt = len(self._src_seq_words)
+        assert len(self._trg_seq_words) == self._ins_cnt
+
+        self._ins_idx = [i for i in xrange(self._ins_cnt)]
+
+        if sort_by_length:
+            self._sort_index_by_len()
+
+        # fix the batch
+        self._compose_batch_idx()
+
+        self._epoch_idx = 0
+        self._cur_batch_idx = 0
+
+    def _parse_file(self, f_obj):
+        src_seq_words = []
+        trg_seq_words = []
+        for line in f_obj:
+            fields = line.strip().split('\t')
+            is_valid = True
+            line_words = []
+
+            for i, field in enumerate(fields):
+                words = field.split()
+                if len(words) == 0 or \
+                        len(words) < self._min_len or \
+                        len(words) > self._max_len:
+                    is_valid = False
+                    break
+                line_words.append(words)
+
+            if not is_valid: continue
+
+            assert len(line_words) == 2
+
+            src_seq_words.append(line_words[0])
+            trg_seq_words.append(line_words[1])
+
+        return (src_seq_words, trg_seq_words)
+
+    def _load_data(self, fpattern, tar_fname=None):
+        fpaths = glob.glob(fpattern)
+        src_seq_words = []
+        trg_seq_words = []
+
+        for fpath in fpaths:
+            if tarfile.is_tarfile(fpath):
+                assert tar_fname is not None
+                f = tarfile.open(fpath, 'r')
+                one_file_data = self._parse_file(f.extractfile(tar_fname))
+            else:
+                assert os.path.isfile(fpath)
+                one_file_data = self._parse_file(open(fpath, 'r'))
+
+            part_src_words, part_trg_words = one_file_data
+
+            if len(src_seq_words) == 0:
+                src_seq_words, trg_seq_words = part_src_words, part_trg_words
+                continue
+
+            src_seq_words.extend(part_src_words)
+            trg_seq_words.extend(part_trg_words)
+
+        return src_seq_words, trg_seq_words
+
+    def _load_dict(self, dict_path, reverse=False):
+        word_dict = {}
+        with open(dict_path, "r") as fdict:
+            for idx, line in enumerate(fdict):
+                if reverse:
+                    word_dict[idx] = line.strip()
+                else:
+                    word_dict[line.strip()] = idx
+        return word_dict
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return sum([1 for _ in self])
+
+    def __next__(self):
+        return self.next()
+
+    def _compose_batch_idx(self):
+        self._epoch_batch_idx = []
+
+        idx = 0
+
+        if self._token_batch_size > 0:
+            batch_idx = []
+            max_src_len = 0
+            max_trg_len = 0
+            while idx < self._ins_cnt:
+                max_src_len = max(len(self._src_seq_words[self._ins_idx[idx]]),
+                                  max_src_len)
+                max_trg_len = max(len(self._trg_seq_words[self._ins_idx[idx]]),
+                                  max_trg_len)
+                max_len = max(max_src_len, max_trg_len)
+                if max_len * (len(batch_idx) + 1) > self._token_batch_size:
+                    self._epoch_batch_idx.append(batch_idx)
+                    max_src_len = 0
+                    max_trg_len = 0
+                    batch_idx = []
+                    continue
+                batch_idx.append(self._ins_idx[idx])
+                idx += 1
+            if len(batch_idx) > 0:
+                self._epoch_batch_idx.append(batch_idx)
+        else:
+            while idx < self._ins_cnt:
+                batch_idx = self._ins_idx[idx:idx + self._batch_size]
+                if len(batch_idx) > 0:
+                    self._epoch_batch_idx.append(batch_idx)
+                idx += len(batch_idx)
+
+        if self._shuffle:
+            if not self._sort_by_length and self._token_batch_size == 0:
+                random.shuffle(self._ins_idx)
+                #self._src_seq_words = self._src_seq_words[self._ins_idx]
+                #self._trg_seq_words = self._trg_seq_words[self._ins_idx]
+                self._src_seq_words = [
+                    self._src_seq_words[ins_idx] for ins_idx in self._ins_idx
+                ]
+                self._trg_seq_words = [
+                    self._trg_seq_words[ins_idx] for ins_idx in self._ins_idx
+                ]
+            else:
+                random.shuffle(self._epoch_batch_idx)
+
+    def _sort_index_by_len(self):
+        self._ins_idx.sort(
+            key=lambda idx: max(
+                len(self._src_seq_words[idx]),
+                len(self._trg_seq_words[idx])))
+
+    def next(self):
+        while self._cur_batch_idx < len(self._epoch_batch_idx):
+            batch_idx = self._epoch_batch_idx[self._cur_batch_idx]
+            src_seq_words = [self._src_seq_words[idx] for idx in batch_idx]
+            trg_seq_words = [self._trg_seq_words[idx] for idx in batch_idx]
+            # consider whether drop
+            self._cur_batch_idx += 1
+            return zip(src_seq_words,
+                       [trg_seq[:-1] for trg_seq in trg_seq_words],
+                       [trg_seq[1:] for trg_seq in trg_seq_words])
+
+        if self._cur_batch_idx >= len(self._epoch_batch_idx):
+            self._epoch_idx += 1
+            self._cur_batch_idx = 0
+            if self._shuffle:
+                if not self._sort_by_length and self._token_batch_size == 0:
+                    random.shuffle(self._ins_idx)
+                    #self._src_seq_words = self._src_seq_words[self._ins_idx]
+                    #self._trg_seq_words = self._trg_seq_words[self._ins_idx]
+                    self._src_seq_words = [
+                        self._src_seq_words[ins_idx]
+                        for ins_idx in self._ins_idx
+                    ]
+                    self._trg_seq_words = [
+                        self._trg_seq_words[ins_idx]
+                        for ins_idx in self._ins_idx
+                    ]
+                else:
+                    random.shuffle(self._epoch_batch_idx)
+            raise StopIteration
+
+
+if __name__ == "__main__":
+    '''data_loader = DataLoader("/root/workspace/unify_reader/wmt16/en_10000.dict",
+                                "/root/workspace/unify_reader/wmt16/de_10000.dict",
+                                "/root/workspace/unify_reader/wmt16/wmt16.tar.gz",
+                                2, tar_fname="wmt16/train")'''
+    data_loader = DataLoader(
+        "/root/workspace/unify_reader/nist06n_tiny/cn_30001.dict.unify",
+        "/root/workspace/unify_reader/nist06n_tiny/en_30001.dict.unify",
+        "/root/workspace/unify_reader/nist06n_tiny/data/part-*",
+        2)
+    print data_loader.next()

diff --git a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py
index 258422b9..bedddffb 100644
--- a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py
+++ b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py
@@ -15,7 +15,7 @@ from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
 import paddle.fluid.debuger as debuger
 import nist_data_provider
 import sys
-#from memory_profiler import profile
+import data_util
 
 def str2bool(v):
     if v.lower() in ('yes', 'true', 't', 'y', '1'):
@@ -261,6 +261,7 @@ def main():
         test_ppl = np.exp([min(test_avg_cost, 100)])
         return test_avg_cost, test_ppl
 
+    '''
     def train_loop(exe, trainer_prog):
         # Initialize the parameters.
         """
@@ -272,13 +273,15 @@ def main():
             pos_enc_param.set(
                 position_encoding_init(ModelHyperParams.max_length + 1,
                                        ModelHyperParams.d_model), place)
+    '''
 
     def train_loop(exe, trainer_prog):
         for pass_id in xrange(args.pass_num):
             ts = time.time()
             total = 0
             pass_start_time = time.time()
-            for batch_id, data in enumerate(train_reader()):
+            for batch_id, data in enumerate(train_reader):
+                print len(data)
                 if len(data) != args.batch_size:
                     continue
 
@@ -412,6 +415,16 @@ def main():
             position_encoding_init(ModelHyperParams.max_length + 1,
                                    ModelHyperParams.d_model), place)
 
+    train_reader = data_util.DataLoader(
+        src_vocab_fpath="/root/data/nist06n/cn_30001.dict",
+        trg_vocab_fpath="/root/data/nist06n/en_30001.dict",
+        fpattern="/root/data/nist06/data-%d/part-*" % (args.task_index),
+        batch_size=args.batch_size,
+        token_batch_size=TrainTaskConfig.token_batch_size,
+        sort_by_length=TrainTaskConfig.sort_by_length,
+        shuffle=True)
+
+    '''
     train_reader = paddle.batch(
         paddle.reader.shuffle(
             nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
                                      buf_size=100000),
         batch_size=args.batch_size)
-    '''
     test_reader = paddle.batch(
         nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
                                  ModelHyperParams.trg_vocab_size),
-- 
GitLab
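
For readers who want to exercise the new loader outside of nmt_fluid.py, the sketch below shows one way to drive data_util.DataLoader directly. It is illustrative only: the dictionary and corpus paths are hypothetical placeholders (any tab-separated source/target corpus plus one-token-per-line vocabulary files will do), and it assumes the same Python 2 environment as the rest of this code, since data_util.py relies on xrange and a print statement.

    # Illustrative usage sketch only; not part of the patch.
    # All paths below are hypothetical placeholders.
    import data_util

    loader = data_util.DataLoader(
        src_vocab_fpath="/path/to/src.dict",   # one token per line
        trg_vocab_fpath="/path/to/trg.dict",
        fpattern="/path/to/train/part-*",      # tab-separated "src\ttrg" pairs
        batch_size=32,
        token_batch_size=0,                    # 0 keeps sentence-count batching
        sort_by_length=True,
        shuffle=True)

    # One epoch: next() returns a list of (src_ids, trg_input_ids, trg_label_ids)
    # tuples; the target side is shifted by one position for teacher forcing.
    try:
        while True:
            batch = loader.next()
            src_ids, trg_input_ids, trg_label_ids = batch[0]
            print("%d sentence pairs in this batch" % len(batch))
    except StopIteration:
        pass  # the loader reshuffles itself and can be iterated for another epoch

Each sequence already carries the <_GO>/<_EOS> markers added in __init__, so the target input (trg_seq[:-1]) drops the trailing <_EOS> while the target label (trg_seq[1:]) drops the leading <_GO>.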
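
The new TrainTaskConfig.token_batch_size option switches _compose_batch_idx from a fixed number of sentence pairs per batch to a token budget: a batch is emitted once the longest sequence seen for it, multiplied by the number of pairs it would then hold, exceeds the budget. The standalone sketch below restates that rule in simplified form for illustration only (unlike the loader, it always places each pair into some batch); the lengths and the budget are made up.

    # Simplified illustration of the token-budget batching rule used when
    # token_batch_size > 0; not the loader's exact implementation.
    def compose_token_batches(pair_lens, token_batch_size):
        """Group indices so that max_len_in_batch * batch_size <= token_batch_size."""
        batches, batch, max_len = [], [], 0
        for idx, length in enumerate(pair_lens):
            candidate_max = max(max_len, length)
            if batch and candidate_max * (len(batch) + 1) > token_batch_size:
                batches.append(batch)  # adding this pair would exceed the budget
                batch, max_len, candidate_max = [], 0, length
            batch.append(idx)
            max_len = candidate_max
        if batch:
            batches.append(batch)
        return batches

    # pair_lens holds, per sentence pair, the longer of its source/target lengths.
    print(compose_token_batches([20, 25, 30, 90, 15, 10], 100))
    # -> [[0, 1, 2], [3], [4, 5]]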