提交 49ac5072 编写于 作者: Y Yibing Liu

Yield batches in groups of dev_count during fine-tuning so that training exits normally

上级 2ed85a39
......@@ -118,7 +118,12 @@ class DataProcessor(object):
"""Gets progress for training phase."""
return self.current_train_example, self.current_train_epoch
def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True):
def data_generator(self,
batch_size,
phase='train',
epoch=1,
dev_count=1,
shuffle=True):
"""
Generate data for train, dev or test.
......@@ -178,6 +183,7 @@ class DataProcessor(object):
yield batch, total_token_num
def wrapper():
all_dev_batches = []
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
batch_data = self.generate_batch_data(
......@@ -188,7 +194,12 @@ class DataProcessor(object):
return_input_mask=True,
return_max_len=False,
return_num_token=False)
yield batch_data
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
else:
for batch in all_dev_batches:
yield batch
all_dev_batches = [batch_data]
return wrapper
......
......@@ -488,6 +488,7 @@ class DataProcessor(object):
batch_size,
phase='train',
shuffle=False,
dev_count=1,
version_2_with_negative=False,
epoch=1):
if phase == 'train':
......@@ -549,9 +550,10 @@ class DataProcessor(object):
else:
features = self.get_features(examples, is_training=False)
all_dev_batches = []
for batch_data, total_token_num in batch_reader(
features, batch_size, self._in_tokens):
yield prepare_batch_data(
batch_data = prepare_batch_data(
batch_data,
total_token_num,
voc_size=-1,
......@@ -562,6 +564,12 @@ class DataProcessor(object):
return_input_mask=True,
return_max_len=False,
return_num_token=False)
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
else:
for batch in all_dev_batches:
yield batch
all_dev_batches = [batch_data]
return wrapper
......
......@@ -148,6 +148,7 @@ def main(args):
batch_size=args.batch_size,
phase='train',
epoch=args.epoch,
dev_count=dev_count,
shuffle=True)
num_train_examples = processor.get_num_examples(phase='train')
......@@ -330,6 +331,7 @@ def main(args):
batch_size=args.batch_size,
phase='dev',
epoch=1,
dev_count=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name],
......@@ -341,6 +343,7 @@ def main(args):
batch_size=args.batch_size,
phase='test',
epoch=1,
dev_count=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name],
......@@ -355,7 +358,7 @@ def main(args):
if args.do_val:
test_pyreader.decorate_tensor_provider(
processor.data_generator(
batch_size=args.batch_size, phase='dev', epoch=1,
batch_size=args.batch_size, phase='dev', epoch=1, dev_count=1,
shuffle=False))
print("Final validation result:")
evaluate(exe, test_prog, test_pyreader,
......@@ -368,6 +371,7 @@ def main(args):
batch_size=args.batch_size,
phase='test',
epoch=1,
dev_count=1,
shuffle=False))
print("Final test result:")
evaluate(exe, test_prog, test_pyreader,
......
......@@ -242,6 +242,7 @@ def train(args):
batch_size=args.batch_size,
phase='train',
shuffle=False,
dev_count=dev_count,
version_2_with_negative=args.version_2_with_negative,
epoch=args.epoch)
......@@ -413,6 +414,7 @@ def train(args):
batch_size=args.batch_size,
phase='predict',
shuffle=False,
dev_count=1,
epoch=1))
predict(exe, test_prog, test_pyreader, [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册