From 49ac507214362224c5eceba781ac72c5e133f739 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 10 Apr 2019 07:28:29 +0000 Subject: [PATCH] Yield dev_count times batches in finetuning for exiting training normally --- BERT/reader/cls.py | 15 +++++++++++++-- BERT/reader/squad.py | 10 +++++++++- BERT/run_classifier.py | 6 +++++- BERT/run_squad.py | 2 ++ 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/BERT/reader/cls.py b/BERT/reader/cls.py index 767d817..cebd583 100644 --- a/BERT/reader/cls.py +++ b/BERT/reader/cls.py @@ -118,7 +118,12 @@ class DataProcessor(object): """Gets progress for training phase.""" return self.current_train_example, self.current_train_epoch - def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True): + def data_generator(self, + batch_size, + phase='train', + epoch=1, + dev_count=1, + shuffle=True): """ Generate data for train, dev or test. @@ -178,6 +183,7 @@ class DataProcessor(object): yield batch, total_token_num def wrapper(): + all_dev_batches = [] for batch_data, total_token_num in batch_reader( instance_reader, batch_size, self.in_tokens): batch_data = self.generate_batch_data( @@ -188,7 +194,12 @@ class DataProcessor(object): return_input_mask=True, return_max_len=False, return_num_token=False) - yield batch_data + if len(all_dev_batches) < dev_count: + all_dev_batches.append(batch_data) + else: + for batch in all_dev_batches: + yield batch + all_dev_batches = [batch_data] return wrapper diff --git a/BERT/reader/squad.py b/BERT/reader/squad.py index 0c9deea..4c69763 100644 --- a/BERT/reader/squad.py +++ b/BERT/reader/squad.py @@ -488,6 +488,7 @@ class DataProcessor(object): batch_size, phase='train', shuffle=False, + dev_count=1, version_2_with_negative=False, epoch=1): if phase == 'train': @@ -549,9 +550,10 @@ class DataProcessor(object): else: features = self.get_features(examples, is_training=False) + all_dev_batches = [] for batch_data, total_token_num in batch_reader( features, batch_size, self._in_tokens): - yield prepare_batch_data( + batch_data = prepare_batch_data( batch_data, total_token_num, voc_size=-1, @@ -562,6 +564,12 @@ class DataProcessor(object): return_input_mask=True, return_max_len=False, return_num_token=False) + if len(all_dev_batches) < dev_count: + all_dev_batches.append(batch_data) + else: + for batch in all_dev_batches: + yield batch + all_dev_batches = [batch_data] return wrapper diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index f28bdcc..2835a8b 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -148,6 +148,7 @@ def main(args): batch_size=args.batch_size, phase='train', epoch=args.epoch, + dev_count=dev_count, shuffle=True) num_train_examples = processor.get_num_examples(phase='train') @@ -330,6 +331,7 @@ def main(args): batch_size=args.batch_size, phase='dev', epoch=1, + dev_count=1, shuffle=False)) evaluate(exe, test_prog, test_pyreader, [loss.name, accuracy.name, num_seqs.name], @@ -341,6 +343,7 @@ def main(args): batch_size=args.batch_size, phase='test', epoch=1, + dev_count=1, shuffle=False)) evaluate(exe, test_prog, test_pyreader, [loss.name, accuracy.name, num_seqs.name], @@ -355,7 +358,7 @@ def main(args): if args.do_val: test_pyreader.decorate_tensor_provider( processor.data_generator( - batch_size=args.batch_size, phase='dev', epoch=1, + batch_size=args.batch_size, phase='dev', epoch=1, dev_count=1, shuffle=False)) print("Final validation result:") evaluate(exe, test_prog, test_pyreader, @@ -368,6 +371,7 @@ def main(args): batch_size=args.batch_size, phase='test', epoch=1, + dev_count=1, shuffle=False)) print("Final test result:") evaluate(exe, test_prog, test_pyreader, diff --git a/BERT/run_squad.py b/BERT/run_squad.py index 1df3a83..9607a8b 100644 --- a/BERT/run_squad.py +++ b/BERT/run_squad.py @@ -242,6 +242,7 @@ def train(args): batch_size=args.batch_size, phase='train', shuffle=False, + dev_count=dev_count, version_2_with_negative=args.version_2_with_negative, epoch=args.epoch) @@ -413,6 +414,7 @@ def train(args): batch_size=args.batch_size, phase='predict', shuffle=False, + dev_count=1, epoch=1)) predict(exe, test_prog, test_pyreader, [ -- GitLab