From f0fc20ee770c1b0539e2462ae7bd3c7e7cf80ced Mon Sep 17 00:00:00 2001
From: guosheng
Date: Fri, 11 May 2018 00:35:09 +0800
Subject: [PATCH] Refine docs and codes in Transformer by following comments

---
 .../transformer/config.py |  2 +-
 .../transformer/reader.py | 32 +++++++++++++++----
 .../transformer/train.py  |  2 +-
 3 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 74f920a1..8ab9efce 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -5,7 +5,7 @@ class TrainTaskConfig(object):
     # the number of sequences contained in a mini-batch.
     batch_size = 32
     # the hyper parameters for Adam optimizer.
-    # This static learning_rate will multiply LearningRateScheduler
+    # This static learning_rate will be multiplied by the LearningRateScheduler
     # derived learning rate to get the final learning rate.
     learning_rate = 1
     beta1 = 0.9
diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py
index 5daa70a2..1b9b6b69 100644
--- a/fluid/neural_machine_translation/transformer/reader.py
+++ b/fluid/neural_machine_translation/transformer/reader.py
@@ -34,7 +34,8 @@ class Pool(object):
 
         if self._sort:
             self._pool.sort(
-                key=lambda sample: max(len(sample[0]), len(sample[1])) if len(sample) > 1 else len(sample[0])
+                key=lambda sample: max(len(sample[0]), len(sample[1])) \
+                    if len(sample) > 1 else len(sample[0])
             )
 
         if self._end and len(self._pool) < self._pool_size:
@@ -63,8 +64,28 @@
 class DataReader(object):
     """
     The data reader loads all data from files and produces batches of data
-    in the way corresponding to settings.
-    number of tokens or number of sequences.
+    in the way corresponding to the settings. See the docstring of the
+    __init__ function for details on these settings.
+
+    An example of creating a generator that produces data batches which
+    are shuffled in each pass and sorted within each pool:
+
+    ```
+    train_data = DataReader(
+        src_vocab_fpath='data/src_vocab_file',
+        trg_vocab_fpath='data/trg_vocab_file',
+        fpattern='data/part-*',
+        use_token_batch=True,
+        batch_size=2000,
+        pool_size=10000,
+        sort_type=SortType.POOL,
+        shuffle=True,
+        shuffle_batch=True,
+        start_mark='<s>',
+        end_mark='<e>',
+        unk_mark='<unk>',
+        clip_last_batch=False).batch_generator
+    ```
     """
 
     def __init__(self,
@@ -99,14 +120,11 @@
             or the maximum number of tokens (including paddings) contained in
             a mini-batch.
         :type batch_size: int
-        :param pool_size: The buffer size to pool data.
+        :param pool_size: The size of the pool buffer.
         :type pool_size: int
         :param sort_type: The grain to sort by length: 'global' for all
             instances; 'pool' for instances in pool; 'none' for no sort.
         :type sort_type: basestring
-        :param sort_type: The grain to sort by length: 'global' for all
-            instances; 'pool' for instances in pool; 'none' for no sort.
-        :type sort_type: basestring
         :param clip_last_batch: Whether to clip the last incomplete batch.
         :type clip_last_batch: bool
         :param tar_fname: The data file in tar if fpattern matches a tar file.
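[Editor's note] To make the reworded config.py comment concrete: the final learning rate is the static `learning_rate` multiplied by the rate the scheduler derives each step. Below is a minimal sketch, assuming the scheduler follows the warmup rule from "Attention Is All You Need"; the function name and the `d_model`/`warmup_steps` values are illustrative stand-ins, not this repo's exact `LearningRateScheduler` code.

```python
# Sketch of combining the static learning_rate with a scheduler-derived
# rate; the warmup schedule here is an assumption, not the repo's code.
def scheduler_derived_lr(step, d_model=512, warmup_steps=4000):
    # d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

static_lr = 1.0  # TrainTaskConfig.learning_rate
final_lr = static_lr * scheduler_derived_lr(step=1000)  # rate used this step
```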
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 90fa7572..d097a079 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -208,7 +208,7 @@ def train(args):
 
     def read_multiple(reader,
                       count=dev_count if args.use_token_batch else 1,
-                      clip_last=False):
+                      clip_last=True):
        """
        Stack data from reader for multiple devices.
        """
--
GitLab
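[Editor's note] For context on the `clip_last` default change: `read_multiple` groups `count` consecutive batches so that each device receives one batch. A minimal sketch of that grouping, under the assumption that clipping simply drops a trailing group smaller than `count` (a simplified stand-in, not the repo's exact implementation):

```python
def read_multiple(reader, count, clip_last=True):
    """Yield lists of `count` batches, one batch per device."""
    def __impl__():
        group = []
        for batch in reader():
            group.append(batch)
            if len(group) == count:
                yield group
                group = []
        if group and not clip_last:
            yield group  # trailing short group, kept only when not clipping
    return __impl__

# Usage sketch with a toy reader of 10 "batches" grouped for 4 devices:
print(list(read_multiple(lambda: iter(range(10)), count=4)()))
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]  (the trailing [8, 9] is clipped)
```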