diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 74f920a10a2e556e2681f872c778730141c20073..8ab9efce1a275ea9539b05c0b959dee42d83c759 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -5,7 +5,7 @@ class TrainTaskConfig(object):
     # the number of sequences contained in a mini-batch.
     batch_size = 32
     # the hyper parameters for Adam optimizer.
-    # This static learning_rate will multiply LearningRateScheduler
-    # derived learning rate the to get the final learning rate.
+    # This static learning_rate will be multiplied by the LearningRateScheduler
+    # derived learning rate to get the final learning rate.
     learning_rate = 1
     beta1 = 0.9
diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py
index 5daa70a2336cd5b3d8e1c9568174832219d3a9c6..1b9b6b6962ecf4d46ba1d87103b94a3747573a49 100644
--- a/fluid/neural_machine_translation/transformer/reader.py
+++ b/fluid/neural_machine_translation/transformer/reader.py
@@ -34,7 +34,8 @@ class Pool(object):
 
         if self._sort:
             self._pool.sort(
-                key=lambda sample: max(len(sample[0]), len(sample[1])) if len(sample) > 1 else len(sample[0])
+                key=lambda sample: max(len(sample[0]), len(sample[1])) \
+                if len(sample) > 1 else len(sample[0])
             )
 
         if self._end and len(self._pool) < self._pool_size:
@@ -63,8 +64,28 @@ class Pool(object):
 class DataReader(object):
     """
     The data reader loads all data from files and produces batches of data
-    in the way corresponding to settings.
-    number of tokens or number of sequences.
+    in the way corresponding to the settings. See the doc of the __init__
+    function for more details on the settings.
+
+    An example of getting a generator that produces data batches which are
+    shuffled in each pass and sorted within each pool:
+
+    ```
+    train_data = DataReader(
+        src_vocab_fpath='data/src_vocab_file',
+        trg_vocab_fpath='data/trg_vocab_file',
+        fpattern='data/part-*',
+        use_token_batch=True,
+        batch_size=2000,
+        pool_size=10000,
+        sort_type=SortType.POOL,
+        shuffle=True,
+        shuffle_batch=True,
+        start_mark='<s>',
+        end_mark='<e>',
+        unk_mark='<unk>',
+        clip_last_batch=False).batch_generator
+    ```
     """
 
     def __init__(self,
@@ -99,14 +120,11 @@ class DataReader(object):
             or the maximum number of tokens (include paddings) contained in a
             mini-batch.
         :type batch_size: int
-        :param pool_size: The buffer size to pool data.
+        :param pool_size: The size of the pool buffer.
        :type pool_size: int
         :param sort_type: The grain to sort by length: 'global' for all
             instances; 'pool' for instances in pool; 'none' for no sort.
         :type sort_type: basestring
-        :param sort_type: The grain to sort by length: 'global' for all
-            instances; 'pool' for instances in pool; 'none' for no sort.
-        :type sort_type: basestring
         :param clip_last_batch: Whether to clip the last uncompleted batch.
         :type clip_last_batch: bool
         :param tar_fname: The data file in tar if fpattern matches a tar file.
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 90fa75728679f1153daf5e9533e63b9c1ae64b9d..d097a07973c398669af5b811e3f85303aefca730 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -208,7 +208,7 @@ def train(args):
 
     def read_multiple(reader,
                       count=dev_count if args.use_token_batch else 1,
                      clip_last=False):
+                      clip_last=True):
         """
         Stack data from reader for multi-devices.
         """
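
Note on the train.py hunk: read_multiple wraps the batch reader so that each training step yields count batches, one per device, and the new clip_last=True default drops a trailing group that has fewer than count batches, presumably so that no device is left without data on the last multi-device step. The function body is not part of this diff, so the sketch below is an assumption reconstructed from the signature read_multiple(reader, count, clip_last) and its docstring, not the actual implementation:

```
# Hypothetical sketch of a read_multiple-style wrapper (not the code from
# this diff): stack `count` consecutive batches so that each device in a
# multi-device step receives one batch.
def read_multiple(reader, count=1, clip_last=True):
    def __impl__():
        batch_group = []
        for batch in reader():
            batch_group.append(batch)
            if len(batch_group) == count:
                yield batch_group
                batch_group = []
        # A leftover group smaller than `count` is yielded only when
        # clip_last=False; the new default clips it so every device
        # always gets a full batch.
        if not clip_last and batch_group:
            yield batch_group

    return __impl__

# Usage: stack pairs of batches for two devices; the lone third batch
# is clipped under the new default.
doubled = read_multiple(lambda: iter([[1], [2], [3]]), count=2)
print(list(doubled()))  # [[[1], [2]]]
```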