diff --git a/finetune_args.py b/finetune_args.py
index 60100f87769245fefdc16d52ca7372b5bf18f460..15a20382ea0352c1094dce988f0689b05762472a 100644
--- a/finetune_args.py
+++ b/finetune_args.py
@@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.")
 data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
 data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
 data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
-data_g.add_arg("predict_batch_size", int, 8, "Total examples' number in batch for predict. see also --in_tokens.")
+data_g.add_arg("predict_batch_size", int, None, "Total examples' number in batch for predict. see also --in_tokens.")
 data_g.add_arg("in_tokens", bool, False,
                "If set, the batch size will be the maximum number of tokens in one batch. "
                "Otherwise, it will be the maximum number of examples in one batch.")
diff --git a/run_classifier.py b/run_classifier.py
index 47286f5a5761e57eb5bf3cf6894620f375f7a134..7e1dd826528e6ebe2f1a1106db0c16464b2036a4 100644
--- a/run_classifier.py
+++ b/run_classifier.py
@@ -92,6 +92,8 @@ def main(args):
         num_train_examples = reader.get_num_examples(args.train_set)
 
         if args.in_tokens:
+            if args.batch_size < args.max_seq_len:
+                raise ValueError('if in_tokens=True, batch_size should be greater than max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
             max_train_steps = args.epoch * num_train_examples // (
                 args.batch_size // args.max_seq_len) // dev_count
         else:
@@ -376,11 +378,12 @@ def main(args):
 def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
                      epoch, steps):
     # evaluate dev set
+    batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
     for ds in args.dev_set.split(','):
         test_pyreader.decorate_tensor_provider(
             reader.data_generator(
                 ds,
-                batch_size=args.predict_batch_size,
+                batch_size=batch_size,
                 epoch=1,
                 dev_count=1,
                 shuffle=False))
@@ -403,12 +406,13 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
     test_sets = args.test_set.split(',')
     save_dirs = args.test_save.split(',')
     assert len(test_sets) == len(save_dirs)
+    batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
 
     for test_f, save_f in zip(test_sets, save_dirs):
         test_pyreader.decorate_tensor_provider(
             reader.data_generator(
                 test_f,
-                batch_size=args.predict_batch_size,
+                batch_size=batch_size,
                 epoch=1,
                 dev_count=1,
                 shuffle=False))
diff --git a/run_mrc.py b/run_mrc.py
index 487b6ba60c1213631d438e77ba1b8bff76bfb86b..51e5efd9b20f5902757d83525d0db681d3a78611 100644
--- a/run_mrc.py
+++ b/run_mrc.py
@@ -95,6 +95,8 @@ def main(args):
         num_train_examples = reader.get_num_examples("train")
 
         if args.in_tokens:
+            if args.batch_size < args.max_seq_len:
+                raise ValueError('if in_tokens=True, batch_size should be greater than max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
             max_train_steps = args.epoch * num_train_examples // (
                 args.batch_size // args.max_seq_len) // dev_count
         else:
diff --git a/run_sequence_labeling.py b/run_sequence_labeling.py
index 756f6ab67b21ca6b577bb9dab15105e2accdce18..ce8b27e711c0cba77bf146bc3a708e4fc94683bf 100644
--- a/run_sequence_labeling.py
+++ b/run_sequence_labeling.py
@@ -85,6 +85,9 @@ def main(args):
         num_train_examples = reader.get_num_examples(args.train_set)
 
         if args.in_tokens:
+            if args.batch_size < args.max_seq_len:
+                raise ValueError('if in_tokens=True, batch_size should be greater than max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
+
             max_train_steps = args.epoch * num_train_examples // (
                 args.batch_size // args.max_seq_len) // dev_count
         else:
@@ -297,11 +300,12 @@ def main(args):
 def evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
                      epoch, steps):
     # evaluate dev set
+    batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
     for ds in args.dev_set.split(','): #single card eval
         test_pyreader.decorate_tensor_provider(
             reader.data_generator(
                 ds,
-                batch_size=args.predict_batch_size,
+                batch_size=batch_size,
                 epoch=1,
                 dev_count=1,
                 shuffle=False))
@@ -318,10 +322,11 @@ def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
     save_dirs = args.test_save.split(',')
     assert len(test_sets) == len(save_dirs), 'number of test_sets & test_save not match, got %d vs %d' % (
         len(test_sets), len(save_dirs))
+    batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
     for test_f, save_f in zip(test_sets, save_dirs):
         test_pyreader.decorate_tensor_provider(reader.data_generator(
             test_f,
-            batch_size=args.predict_batch_size,
+            batch_size=batch_size,
            epoch=1,
            dev_count=1,
            shuffle=False))