From 58066a1e2e8f7d21df368f361c9f3c1421eadbd0 Mon Sep 17 00:00:00 2001
From: chenxuyi
Date: Wed, 4 Sep 2019 14:36:55 +0800
Subject: [PATCH] default `--predict_batch_size=8`

---
 finetune_args.py  | 2 +-
 run_classifier.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/finetune_args.py b/finetune_args.py
index 15a2038..60100f8 100644
--- a/finetune_args.py
+++ b/finetune_args.py
@@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.")
 data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
 data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
 data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
-data_g.add_arg("predict_batch_size", int, None, "Total examples' number in batch for predict. see also --in_tokens.")
+data_g.add_arg("predict_batch_size", int, 8, "Total examples' number in batch for predict. see also --in_tokens.")
 data_g.add_arg("in_tokens", bool, False,
                "If set, the batch size will be the maximum number of tokens in one batch. "
                "Otherwise, it will be the maximum number of examples in one batch.")
diff --git a/run_classifier.py b/run_classifier.py
index c4d019d..0de6702 100644
--- a/run_classifier.py
+++ b/run_classifier.py
@@ -80,8 +80,6 @@ def main(args):
     if args.random_seed is not None:
         startup_prog.random_seed = args.random_seed

-    if args.predict_batch_size == None:
-        args.predict_batch_size = args.batch_size
     if args.do_train:
         train_data_generator = reader.data_generator(
             input_file=args.train_set,
--
GitLab
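
For context, the net effect of this patch: `predict_batch_size` now gets its default of 8 at argument-definition time in finetune_args.py, so the runtime fallback to `batch_size` in run_classifier.py becomes dead code and is removed. A minimal sketch of the equivalent behavior, using plain `argparse` as a stand-in for the repo's `ArgumentGroup.add_arg` helper (the stand-in is an assumption, not the actual implementation):

    import argparse

    parser = argparse.ArgumentParser()
    # Stand-in for data_g.add_arg(...) in finetune_args.py.
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Total examples' number in batch for training.")
    # After this patch the default is 8; before, it was None and
    # run_classifier.py fell back to --batch_size at runtime.
    parser.add_argument("--predict_batch_size", type=int, default=8,
                        help="Total examples' number in batch for predict.")

    args = parser.parse_args([])
    assert args.predict_batch_size == 8  # new default, no runtime fallback

Note the behavior change this implies: a caller who previously relied on prediction batches silently matching `--batch_size` must now pass `--predict_batch_size` explicitly.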