提交 49ac5072 编写于 作者: Y Yibing Liu

Yield dev_count times batches in finetuning for exiting training normally

上级 2ed85a39
...@@ -118,7 +118,12 @@ class DataProcessor(object):
"""Gets progress for training phase."""
return self.current_train_example, self.current_train_epoch
def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True): def data_generator(self,
batch_size,
phase='train',
epoch=1,
dev_count=1,
shuffle=True):
""" """
Generate data for train, dev or test. Generate data for train, dev or test.
...@@ -178,6 +183,7 @@ class DataProcessor(object): ...@@ -178,6 +183,7 @@ class DataProcessor(object):
yield batch, total_token_num yield batch, total_token_num
def wrapper(): def wrapper():
all_dev_batches = []
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
batch_data = self.generate_batch_data(
...@@ -188,7 +194,12 @@ class DataProcessor(object):
return_input_mask=True,
return_max_len=False,
return_num_token=False)
yield batch_data if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
else:
for batch in all_dev_batches:
yield batch
all_dev_batches = [batch_data]
return wrapper
...
...@@ -488,6 +488,7 @@ class DataProcessor(object):
batch_size,
phase='train',
shuffle=False,
dev_count=1,
version_2_with_negative=False,
epoch=1):
if phase == 'train':
...@@ -549,9 +550,10 @@ class DataProcessor(object):
else:
features = self.get_features(examples, is_training=False)
all_dev_batches = []
for batch_data, total_token_num in batch_reader(
features, batch_size, self._in_tokens):
yield prepare_batch_data( batch_data = prepare_batch_data(
batch_data,
total_token_num,
voc_size=-1,
...@@ -562,6 +564,12 @@ class DataProcessor(object):
return_input_mask=True,
return_max_len=False,
return_num_token=False)
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
else:
for batch in all_dev_batches:
yield batch
all_dev_batches = [batch_data]
return wrapper
...
...@@ -148,6 +148,7 @@ def main(args):
batch_size=args.batch_size,
phase='train',
epoch=args.epoch,
dev_count=dev_count,
shuffle=True)
num_train_examples = processor.get_num_examples(phase='train')
...@@ -330,6 +331,7 @@ def main(args):
batch_size=args.batch_size,
phase='dev',
epoch=1,
dev_count=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name],
...@@ -341,6 +343,7 @@ def main(args):
batch_size=args.batch_size,
phase='test',
epoch=1,
dev_count=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name],
...@@ -355,7 +358,7 @@ def main(args):
if args.do_val:
test_pyreader.decorate_tensor_provider(
processor.data_generator(
batch_size=args.batch_size, phase='dev', epoch=1, batch_size=args.batch_size, phase='dev', epoch=1, dev_count=1,
shuffle=False))
print("Final validation result:")
evaluate(exe, test_prog, test_pyreader,
...@@ -368,6 +371,7 @@ def main(args):
batch_size=args.batch_size,
phase='test',
epoch=1,
dev_count=1,
shuffle=False))
print("Final test result:")
evaluate(exe, test_prog, test_pyreader,
...
...@@ -242,6 +242,7 @@ def train(args):
batch_size=args.batch_size,
phase='train',
shuffle=False,
dev_count=dev_count,
version_2_with_negative=args.version_2_with_negative,
epoch=args.epoch)
...@@ -413,6 +414,7 @@ def train(args):
batch_size=args.batch_size,
phase='predict',
shuffle=False,
dev_count=1,
epoch=1))
predict(exe, test_prog, test_pyreader, [
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册