提交 2d278cdc 编写于 作者: C chenxuyi

bugfix: predict_hangs

+ sanity check in `in_tokens` mode
上级 abd48478
...@@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.") ...@@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.") data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.") data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.") data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("predict_batch_size", int, 8, "Total examples' number in batch for predict. see also --in_tokens.") data_g.add_arg("predict_batch_size", int, None, "Total examples' number in batch for predict. see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False, data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. " "If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.") "Otherwise, it will be the maximum number of examples in one batch.")
......
...@@ -92,6 +92,8 @@ def main(args): ...@@ -92,6 +92,8 @@ def main(args):
num_train_examples = reader.get_num_examples(args.train_set) num_train_examples = reader.get_num_examples(args.train_set)
if args.in_tokens: if args.in_tokens:
if args.batch_size < args.max_seq_len:
            raise ValueError('if in_tokens=True, batch_size should be greater than or equal to max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
max_train_steps = args.epoch * num_train_examples // ( max_train_steps = args.epoch * num_train_examples // (
args.batch_size // args.max_seq_len) // dev_count args.batch_size // args.max_seq_len) // dev_count
else: else:
...@@ -376,11 +378,12 @@ def main(args): ...@@ -376,11 +378,12 @@ def main(args):
def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps): epoch, steps):
# evaluate dev set # evaluate dev set
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for ds in args.dev_set.split(','): for ds in args.dev_set.split(','):
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_tensor_provider(
reader.data_generator( reader.data_generator(
ds, ds,
batch_size=args.predict_batch_size, batch_size=batch_size,
epoch=1, epoch=1,
dev_count=1, dev_count=1,
shuffle=False)) shuffle=False))
...@@ -403,12 +406,13 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, ...@@ -403,12 +406,13 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
test_sets = args.test_set.split(',') test_sets = args.test_set.split(',')
save_dirs = args.test_save.split(',') save_dirs = args.test_save.split(',')
assert len(test_sets) == len(save_dirs) assert len(test_sets) == len(save_dirs)
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for test_f, save_f in zip(test_sets, save_dirs): for test_f, save_f in zip(test_sets, save_dirs):
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_tensor_provider(
reader.data_generator( reader.data_generator(
test_f, test_f,
batch_size=args.predict_batch_size, batch_size=batch_size,
epoch=1, epoch=1,
dev_count=1, dev_count=1,
shuffle=False)) shuffle=False))
......
...@@ -95,6 +95,8 @@ def main(args): ...@@ -95,6 +95,8 @@ def main(args):
num_train_examples = reader.get_num_examples("train") num_train_examples = reader.get_num_examples("train")
if args.in_tokens: if args.in_tokens:
if args.batch_size < args.max_seq_len:
            raise ValueError('if in_tokens=True, batch_size should be greater than or equal to max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
max_train_steps = args.epoch * num_train_examples // ( max_train_steps = args.epoch * num_train_examples // (
args.batch_size // args.max_seq_len) // dev_count args.batch_size // args.max_seq_len) // dev_count
else: else:
......
...@@ -85,6 +85,9 @@ def main(args): ...@@ -85,6 +85,9 @@ def main(args):
num_train_examples = reader.get_num_examples(args.train_set) num_train_examples = reader.get_num_examples(args.train_set)
if args.in_tokens: if args.in_tokens:
if args.batch_size < args.max_seq_len:
            raise ValueError('if in_tokens=True, batch_size should be greater than or equal to max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
max_train_steps = args.epoch * num_train_examples // ( max_train_steps = args.epoch * num_train_examples // (
args.batch_size // args.max_seq_len) // dev_count args.batch_size // args.max_seq_len) // dev_count
else: else:
...@@ -297,11 +300,12 @@ def main(args): ...@@ -297,11 +300,12 @@ def main(args):
def evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars, def evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps): epoch, steps):
# evaluate dev set # evaluate dev set
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for ds in args.dev_set.split(','): #single card eval for ds in args.dev_set.split(','): #single card eval
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_tensor_provider(
reader.data_generator( reader.data_generator(
ds, ds,
batch_size=args.predict_batch_size, batch_size=batch_size,
epoch=1, epoch=1,
dev_count=1, dev_count=1,
shuffle=False)) shuffle=False))
...@@ -318,10 +322,11 @@ def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars, ...@@ -318,10 +322,11 @@ def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
save_dirs = args.test_save.split(',') save_dirs = args.test_save.split(',')
assert len(test_sets) == len(save_dirs), 'number of test_sets & test_save not match, got %d vs %d' % (len(test_sets), len(save_dirs)) assert len(test_sets) == len(save_dirs), 'number of test_sets & test_save not match, got %d vs %d' % (len(test_sets), len(save_dirs))
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for test_f, save_f in zip(test_sets, save_dirs): for test_f, save_f in zip(test_sets, save_dirs):
test_pyreader.decorate_tensor_provider(reader.data_generator( test_pyreader.decorate_tensor_provider(reader.data_generator(
test_f, test_f,
batch_size=args.predict_batch_size, batch_size=batch_size,
epoch=1, epoch=1,
dev_count=1, dev_count=1,
shuffle=False)) shuffle=False))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册