提交 ce6f2825 编写于 作者: Y Yibing Liu

Disable all sorts in data reader

上级 950914d8
...@@ -58,22 +58,6 @@ def parse_args(): ...@@ -58,22 +58,6 @@ def parse_args():
type=int, type=int,
default=10000, default=10000,
help="The buffer size to pool data.") help="The buffer size to pool data.")
parser.add_argument(
"--sort_type",
default="pool",
choices=("global", "pool", "none"),
help="The grain to sort by length: global for all instances; pool for "
"instances in pool; none for no sort.")
parser.add_argument(
"--shuffle",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to shuffle instances.")
parser.add_argument(
"--shuffle_batch",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to shuffle the data batches.")
parser.add_argument( parser.add_argument(
"--special_token", "--special_token",
type=str, type=str,
...@@ -161,8 +145,7 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count, ...@@ -161,8 +145,7 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count,
(batch_id, total_sum_cost, total_avg_cost, (batch_id, total_sum_cost, total_avg_cost,
np.exp([min(total_avg_cost, 100)]))) np.exp([min(total_avg_cost, 100)])))
init = True init = True
total_time = time.time() - start_time return time.time() - start_time, exec_time
return total_time, exec_time
def profile(args): def profile(args):
...@@ -205,6 +188,7 @@ def profile(args): ...@@ -205,6 +188,7 @@ def profile(args):
else: else:
exe.run(fluid.framework.default_startup_program()) exe.run(fluid.framework.default_startup_program())
# Disable all sorts for they will be done in the 1st batch.
train_data = reader.DataReader( train_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath, src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath,
...@@ -212,9 +196,9 @@ def profile(args): ...@@ -212,9 +196,9 @@ def profile(args):
use_token_batch=args.use_token_batch, use_token_batch=args.use_token_batch,
batch_size=args.batch_size * (1 if args.use_token_batch else dev_count), batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
pool_size=args.pool_size, pool_size=args.pool_size,
sort_type=args.sort_type, sort_type='none',
shuffle=args.shuffle, shuffle=False,
shuffle_batch=args.shuffle_batch, shuffle_batch=False,
start_mark=args.special_token[0], start_mark=args.special_token[0],
end_mark=args.special_token[1], end_mark=args.special_token[1],
unk_mark=args.special_token[2], unk_mark=args.special_token[2],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册