Commit 5c15ce77, authored by Hongkun Yu, committed by A. Unique TensorFlower

Fix a mistake in previous change

PiperOrigin-RevId: 281409019
Parent 252e6384
@@ -59,12 +59,10 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
   """Returns input dataset from input file string."""
   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
-    input_files = []
-    for input_pattern in input_file_pattern.split(','):
-      input_files.extend(tf.io.gfile.glob(input_pattern))
+    input_patterns = input_file_pattern.split(',')
     batch_size = ctx.get_per_replica_batch_size(global_batch_size)
     train_dataset = input_pipeline.create_pretrain_dataset(
-        input_files,
+        input_patterns,
         seq_length,
         max_predictions_per_seq,
         batch_size,
......
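
For readers comparing the two sides of the hunk, the sketch below contrasts what the removed lines computed (a list of matched file paths) with what the added line computes (a list of unexpanded patterns). It is only an illustration under stated assumptions: it uses the standard-library `glob` module in place of `tf.io.gfile.glob`, and the helper names `split_patterns` and `expand_patterns` are hypothetical and do not appear in the commit. The diff does not show whether `input_pipeline.create_pretrain_dataset` expands the patterns itself, so that is presumed rather than confirmed here.

```python
import glob


def split_patterns(input_file_pattern):
    """Mirrors the added line: split the comma-separated string, no expansion."""
    return input_file_pattern.split(',')


def expand_patterns(input_file_pattern):
    """Mirrors the removed lines: expand every pattern into matching paths.

    Assumption: stdlib glob.glob stands in for tf.io.gfile.glob here.
    """
    input_files = []
    for input_pattern in split_patterns(input_file_pattern):
        # sorted() only makes the illustration deterministic; the removed
        # code did not sort its matches.
        input_files.extend(sorted(glob.glob(input_pattern)))
    return input_files
```

With the change, the caller hands the pattern list on unchanged, so any expansion would have to happen downstream.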