提交 8b032e91 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 396717391
上级 e453835a
......@@ -43,10 +43,11 @@ def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
rng = np.random.default_rng(37)
for _ in range(num_examples):
features = {}
padding = np.zeros(shape=(max_seq_length - seq_length), dtype=np.int32)
input_ids = np.random.randint(low=1, high=100, size=(seq_length))
input_ids = rng.integers(low=1, high=100, size=(seq_length))
features['input_ids'] = create_int_feature(
np.concatenate((input_ids, padding)))
features['input_mask'] = create_int_feature(
......@@ -56,9 +57,9 @@ def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
features['position_ids'] = create_int_feature(
np.concatenate((np.ones_like(input_ids), padding)))
features['masked_lm_positions'] = create_int_feature(
np.random.randint(60, size=(num_masked_tokens), dtype=np.int64))
rng.integers(60, size=(num_masked_tokens), dtype=np.int64))
features['masked_lm_ids'] = create_int_feature(
np.random.randint(100, size=(num_masked_tokens), dtype=np.int64))
rng.integers(100, size=(num_masked_tokens), dtype=np.int64))
features['masked_lm_weights'] = create_float_feature(
np.ones((num_masked_tokens,), dtype=np.float32))
features['next_sentence_labels'] = create_int_feature(np.array([0]))
......@@ -156,6 +157,7 @@ class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(dynamic_metrics[key], static_metrics[key])
def test_load_dataset(self):
tf.random.set_seed(0)
max_seq_length = 128
batch_size = 2
input_path_1 = os.path.join(self.get_temp_dir(), 'train_1.tf_record')
......@@ -178,7 +180,8 @@ class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
input_path=input_paths,
seq_bucket_lengths=[64, 128],
use_position_id=True,
global_batch_size=batch_size)
global_batch_size=batch_size,
deterministic=True)
dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
data_config).load()
dataset_it = iter(dataset)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册