From 2d278cdce8ab24a38a8cd4a86983c9769a39ce09 Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Sun, 29 Sep 2019 11:01:54 +0800 Subject: [PATCH] bugfix: predict_hangs + sanity check in `in_tokens` mode --- finetune_args.py | 2 +- run_classifier.py | 8 ++++++-- run_mrc.py | 2 ++ run_sequence_labeling.py | 9 +++++++-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/finetune_args.py b/finetune_args.py index 60100f8..15a2038 100644 --- a/finetune_args.py +++ b/finetune_args.py @@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.") data_g.add_arg("vocab_path", str, None, "Vocabulary path.") data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.") data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.") -data_g.add_arg("predict_batch_size", int, 8, "Total examples' number in batch for predict. see also --in_tokens.") +data_g.add_arg("predict_batch_size", int, None, "Total examples' number in batch for predict. see also --in_tokens.") data_g.add_arg("in_tokens", bool, False, "If set, the batch size will be the maximum number of tokens in one batch. 
" "Otherwise, it will be the maximum number of examples in one batch.") diff --git a/run_classifier.py b/run_classifier.py index 47286f5..7e1dd82 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -92,6 +92,8 @@ def main(args): num_train_examples = reader.get_num_examples(args.train_set) if args.in_tokens: + if args.batch_size < args.max_seq_len: + raise ValueError('if in_tokens=True, batch_size should greater than max_sqelen, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len)) max_train_steps = args.epoch * num_train_examples // ( args.batch_size // args.max_seq_len) // dev_count else: @@ -376,11 +378,12 @@ def main(args): def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, epoch, steps): # evaluate dev set + batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size for ds in args.dev_set.split(','): test_pyreader.decorate_tensor_provider( reader.data_generator( ds, - batch_size=args.predict_batch_size, + batch_size=batch_size, epoch=1, dev_count=1, shuffle=False)) @@ -403,12 +406,13 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, test_sets = args.test_set.split(',') save_dirs = args.test_save.split(',') assert len(test_sets) == len(save_dirs) + batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size for test_f, save_f in zip(test_sets, save_dirs): test_pyreader.decorate_tensor_provider( reader.data_generator( test_f, - batch_size=args.predict_batch_size, + batch_size=batch_size, epoch=1, dev_count=1, shuffle=False)) diff --git a/run_mrc.py b/run_mrc.py index 487b6ba..51e5efd 100644 --- a/run_mrc.py +++ b/run_mrc.py @@ -95,6 +95,8 @@ def main(args): num_train_examples = reader.get_num_examples("train") if args.in_tokens: + if args.batch_size < args.max_seq_len: + raise ValueError('if in_tokens=True, batch_size should greater than max_sqelen, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len)) 
max_train_steps = args.epoch * num_train_examples // ( args.batch_size // args.max_seq_len) // dev_count else: diff --git a/run_sequence_labeling.py b/run_sequence_labeling.py index 756f6ab..ce8b27e 100644 --- a/run_sequence_labeling.py +++ b/run_sequence_labeling.py @@ -85,6 +85,9 @@ def main(args): num_train_examples = reader.get_num_examples(args.train_set) if args.in_tokens: + if args.batch_size < args.max_seq_len: + raise ValueError('if in_tokens=True, batch_size should be at least max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len)) + max_train_steps = args.epoch * num_train_examples // ( args.batch_size // args.max_seq_len) // dev_count else: @@ -297,11 +300,12 @@ def main(args): def evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars, epoch, steps): # evaluate dev set + batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size for ds in args.dev_set.split(','): #single card eval test_pyreader.decorate_tensor_provider( reader.data_generator( ds, - batch_size=args.predict_batch_size, + batch_size=batch_size, epoch=1, dev_count=1, shuffle=False)) @@ -318,10 +322,11 @@ def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars, save_dirs = args.test_save.split(',') assert len(test_sets) == len(save_dirs), 'number of test_sets & test_save not match, got %d vs %d' % (len(test_sets), len(save_dirs)) + batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size for test_f, save_f in zip(test_sets, save_dirs): test_pyreader.decorate_tensor_provider(reader.data_generator( test_f, - batch_size=args.predict_batch_size, + batch_size=batch_size, epoch=1, dev_count=1, shuffle=False)) -- GitLab