diff --git a/nce_cost/train.py b/nce_cost/train.py index 4ab5043725805003cf151c6d0c8af8dbbc8c199f..94b761880a93fa5b7d0d42d7128ffb2625fe24b8 100644 --- a/nce_cost/train.py +++ b/nce_cost/train.py @@ -43,9 +43,12 @@ def train(model_save_dir): parameters.to_tar(f) trainer.train( - paddle.batch(paddle.dataset.imikolov.train(word_dict, 5), 64), - num_passes=1000, - event_handler=event_handler) + paddle.batch( + paddle.reader.shuffle( + lambda: paddle.dataset.imikolov.train(word_dict, 5)(), + buf_size=1000), 64), + num_passes=1000, + event_handler=event_handler) if __name__ == "__main__":