diff --git a/BERT/reader/cls.py b/BERT/reader/cls.py index 767d817ba31fd7df73466919f848d94338e997ff..cebd583549eb8d138984e2a0d92b05e7dfbda8c5 100644 --- a/BERT/reader/cls.py +++ b/BERT/reader/cls.py @@ -118,7 +118,12 @@ class DataProcessor(object): """Gets progress for training phase.""" return self.current_train_example, self.current_train_epoch - def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True): + def data_generator(self, + batch_size, + phase='train', + epoch=1, + dev_count=1, + shuffle=True): """ Generate data for train, dev or test. @@ -178,6 +183,7 @@ class DataProcessor(object): yield batch, total_token_num def wrapper(): + all_dev_batches = [] for batch_data, total_token_num in batch_reader( instance_reader, batch_size, self.in_tokens): batch_data = self.generate_batch_data( @@ -188,7 +194,12 @@ class DataProcessor(object): return_input_mask=True, return_max_len=False, return_num_token=False) - yield batch_data + if len(all_dev_batches) < dev_count: + all_dev_batches.append(batch_data) + else: + for batch in all_dev_batches: + yield batch + all_dev_batches = [batch_data] return wrapper diff --git a/BERT/reader/squad.py b/BERT/reader/squad.py index 0c9deea53de08171b8294f991ed4d99cd5114662..4c69763fdb7ec3cc72fb1d13ca89956965ee3c68 100644 --- a/BERT/reader/squad.py +++ b/BERT/reader/squad.py @@ -488,6 +488,7 @@ class DataProcessor(object): batch_size, phase='train', shuffle=False, + dev_count=1, version_2_with_negative=False, epoch=1): if phase == 'train': @@ -549,9 +550,10 @@ class DataProcessor(object): else: features = self.get_features(examples, is_training=False) + all_dev_batches = [] for batch_data, total_token_num in batch_reader( features, batch_size, self._in_tokens): - yield prepare_batch_data( + batch_data = prepare_batch_data( batch_data, total_token_num, voc_size=-1, @@ -562,6 +564,12 @@ class DataProcessor(object): return_input_mask=True, return_max_len=False, return_num_token=False) + if len(all_dev_batches) < dev_count: + all_dev_batches.append(batch_data) + else: + for batch in all_dev_batches: + yield batch + all_dev_batches = [batch_data] return wrapper diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index f28bdcc1cb1f453dfad7ac3fbee249fd3db6ba3b..2835a8bef1cdafd07b898f792605851b82ad3ddd 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -148,6 +148,7 @@ def main(args): batch_size=args.batch_size, phase='train', epoch=args.epoch, + dev_count=dev_count, shuffle=True) num_train_examples = processor.get_num_examples(phase='train') @@ -330,6 +331,7 @@ def main(args): batch_size=args.batch_size, phase='dev', epoch=1, + dev_count=1, shuffle=False)) evaluate(exe, test_prog, test_pyreader, [loss.name, accuracy.name, num_seqs.name], @@ -341,6 +343,7 @@ def main(args): batch_size=args.batch_size, phase='test', epoch=1, + dev_count=1, shuffle=False)) evaluate(exe, test_prog, test_pyreader, [loss.name, accuracy.name, num_seqs.name], @@ -355,7 +358,7 @@ def main(args): if args.do_val: test_pyreader.decorate_tensor_provider( processor.data_generator( - batch_size=args.batch_size, phase='dev', epoch=1, + batch_size=args.batch_size, phase='dev', epoch=1, dev_count=1, shuffle=False)) print("Final validation result:") evaluate(exe, test_prog, test_pyreader, @@ -368,6 +371,7 @@ def main(args): batch_size=args.batch_size, phase='test', epoch=1, + dev_count=1, shuffle=False)) print("Final test result:") evaluate(exe, test_prog, test_pyreader, diff --git a/BERT/run_squad.py b/BERT/run_squad.py index 1df3a83cb778fc618657bce40cb96b40984375a1..9607a8be0a79aaf77b4f1f4df9fe91b3f1ff28ce 100644 --- a/BERT/run_squad.py +++ b/BERT/run_squad.py @@ -242,6 +242,7 @@ def train(args): batch_size=args.batch_size, phase='train', shuffle=False, + dev_count=dev_count, version_2_with_negative=args.version_2_with_negative, epoch=args.epoch) @@ -413,6 +414,7 @@ def train(args): batch_size=args.batch_size, phase='predict', shuffle=False, + dev_count=1, epoch=1)) predict(exe, test_prog, test_pyreader, [