提交 d3a3f143 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 423199224
上级 8b4d4598
......@@ -79,17 +79,29 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.VarLenFeature(tf.int64),
'input_mask': tf.io.VarLenFeature(tf.int64),
'segment_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
}
if self._params.use_v2_feature_names:
input_ids_key = 'input_word_ids'
segment_key = 'input_type_ids'
name_to_features.update({
input_ids_key: tf.io.VarLenFeature(tf.int64),
segment_key: tf.io.VarLenFeature(tf.int64),
})
else:
input_ids_key = 'input_ids'
segment_key = 'segment_ids'
name_to_features.update({
input_ids_key: tf.io.VarLenFeature(tf.int64),
segment_key: tf.io.VarLenFeature(tf.int64),
})
if self._use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64)
dynamic_keys = ['input_ids', 'input_mask', 'segment_ids']
dynamic_keys = [input_ids_key, 'input_mask', segment_key]
if self._use_position_id:
name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
dynamic_keys.append('position_ids')
......@@ -102,7 +114,7 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
# sequence length dimension.
# Pad before the first non pad from the back should not be removed.
mask = tf.math.greater(
tf.math.cumsum(example['input_ids'], reverse=True), 0)
tf.math.cumsum(example[input_ids_key], reverse=True), 0)
for key in dynamic_keys:
example[key] = tf.boolean_mask(example[key], mask)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册