From ce6f2825778d861b99e9772a0625b09802277c75 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 16 Jul 2018 07:06:11 +0000 Subject: [PATCH] Disable all sorts in data reader --- .../transformer/profile.py | 26 ++++--------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/profile.py b/fluid/neural_machine_translation/transformer/profile.py index 7857fb9a..caf3125b 100644 --- a/fluid/neural_machine_translation/transformer/profile.py +++ b/fluid/neural_machine_translation/transformer/profile.py @@ -58,22 +58,6 @@ def parse_args(): type=int, default=10000, help="The buffer size to pool data.") - parser.add_argument( - "--sort_type", - default="pool", - choices=("global", "pool", "none"), - help="The grain to sort by length: global for all instances; pool for " - "instances in pool; none for no sort.") - parser.add_argument( - "--shuffle", - type=ast.literal_eval, - default=True, - help="The flag indicating whether to shuffle instances.") - parser.add_argument( - "--shuffle_batch", - type=ast.literal_eval, - default=True, - help="The flag indicating whether to shuffle the data batches.") parser.add_argument( "--special_token", type=str, @@ -161,8 +145,7 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count, (batch_id, total_sum_cost, total_avg_cost, np.exp([min(total_avg_cost, 100)]))) init = True - total_time = time.time() - start_time - return total_time, exec_time + return time.time() - start_time, exec_time def profile(args): @@ -205,6 +188,7 @@ def profile(args): else: exe.run(fluid.framework.default_startup_program()) + # Disable all sorts for they will be done in the 1st batch. train_data = reader.DataReader( src_vocab_fpath=args.src_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath, @@ -212,9 +196,9 @@ def profile(args): use_token_batch=args.use_token_batch, batch_size=args.batch_size * (1 if args.use_token_batch else dev_count), pool_size=args.pool_size, - sort_type=args.sort_type, - shuffle=args.shuffle, - shuffle_batch=args.shuffle_batch, + sort_type='none', + shuffle=False, + shuffle_batch=False, start_mark=args.special_token[0], end_mark=args.special_token[1], unk_mark=args.special_token[2], -- GitLab