Commit 9b8b13e8 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 330982547
Parent 3a18c6ab
@@ -406,8 +406,6 @@ class MobileBERTEncoder(tf.keras.Model):
                num_feedforward_networks=4,
                normalization_type='no_norm',
                classifier_activation=False,
-               return_all_layers=False,
-               return_attention_score=False,
                **kwargs):
     """Class initialization.

@@ -438,8 +436,6 @@ class MobileBERTEncoder(tf.keras.Model):
         MobileBERT paper. 'layer_norm' is used for the teacher model.
       classifier_activation: If using the tanh activation for the final
         representation of the [CLS] token in fine-tuning.
-      return_all_layers: If return all layer outputs.
-      return_attention_score: If return attention scores for each layer.
       **kwargs: Other keyworded and arguments.
     """
     self._self_setattr_tracking = False
@@ -513,12 +509,11 @@ class MobileBERTEncoder(tf.keras.Model):
     else:
       self._pooler_layer = None

-    if return_all_layers:
-      outputs = [all_layer_outputs, first_token]
-    else:
-      outputs = [prev_output, first_token]
-    if return_attention_score:
-      outputs.append(all_attention_scores)
+    outputs = dict(
+        sequence_output=prev_output,
+        pooled_output=first_token,
+        encoder_outputs=all_layer_outputs,
+        attention_scores=all_attention_scores)

     super(MobileBERTEncoder, self).__init__(
         inputs=self.inputs, outputs=outputs, **kwargs)
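With this change, MobileBERTEncoder always exposes all of its outputs under fixed keys instead of varying the output arity through constructor flags. A minimal sketch of the new calling convention follows (the import path is an assumption for illustration; constructor arguments not shown keep their defaults):

import tensorflow as tf

# Import path assumed; adjust to wherever mobile_bert_encoder lives
# in your checkout of tensorflow/models.
from official.nlp.modeling.networks import mobile_bert_encoder

encoder = mobile_bert_encoder.MobileBERTEncoder(
    word_vocab_size=100, hidden_size=32, num_blocks=2)

seq_len = 16
word_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)

# Every output is always present under a stable key; the former
# return_all_layers / return_attention_score flags are gone.
outputs = encoder([word_ids, mask, type_ids])
sequence_output = outputs['sequence_output']    # [batch, seq_len, hidden]
pooled_output = outputs['pooled_output']        # [batch, hidden]
all_layer_outputs = outputs['encoder_outputs']  # list of num_blocks + 1 tensors
attention_scores = outputs['attention_scores']  # list of num_blocks tensors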
@@ -32,7 +32,7 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
   return fake_input


-class ModelingTest(parameterized.TestCase, tf.test.TestCase):
+class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):

   def test_embedding_layer_with_token_type(self):
     layer = mobile_bert_encoder.MobileBertEmbedding(10, 8, 2, 16)
@@ -116,7 +116,9 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    layer_output, pooler_output = test_network([word_ids, mask, type_ids])
+    outputs = test_network([word_ids, mask, type_ids])
+    layer_output, pooler_output = outputs['sequence_output'], outputs[
+        'pooled_output']

     self.assertIsInstance(test_network.transformer_layers, list)
     self.assertLen(test_network.transformer_layers, num_blocks)
@@ -134,13 +136,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     test_network = mobile_bert_encoder.MobileBERTEncoder(
         word_vocab_size=100,
         hidden_size=hidden_size,
-        num_blocks=num_blocks,
-        return_all_layers=True)
+        num_blocks=num_blocks)

     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    all_layer_output, _ = test_network([word_ids, mask, type_ids])
+    outputs = test_network([word_ids, mask, type_ids])
+    all_layer_output = outputs['encoder_outputs']

     self.assertIsInstance(all_layer_output, list)
     self.assertLen(all_layer_output, num_blocks + 1)
@@ -153,16 +155,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     test_network = mobile_bert_encoder.MobileBERTEncoder(
         word_vocab_size=vocab_size,
         hidden_size=hidden_size,
-        num_blocks=num_blocks,
-        return_all_layers=False)
+        num_blocks=num_blocks)

     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    layer_out_tensor, pooler_out_tensor = test_network(
-        [word_ids, mask, type_ids])
-    model = tf.keras.Model([word_ids, mask, type_ids],
-                           [layer_out_tensor, pooler_out_tensor])
+    outputs = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)

     input_seq = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
@@ -170,13 +169,12 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
         batch_size=1, seq_len=sequence_length, vocab_size=2)
     token_type = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=2)
-    layer_output, pooler_output = model.predict(
-        [input_seq, input_mask, token_type])
+    outputs = model.predict([input_seq, input_mask, token_type])

-    layer_output_shape = [1, sequence_length, hidden_size]
-    self.assertAllEqual(layer_output.shape, layer_output_shape)
-    pooler_output_shape = [1, hidden_size]
-    self.assertAllEqual(pooler_output.shape, pooler_output_shape)
+    sequence_output_shape = [1, sequence_length, hidden_size]
+    self.assertAllEqual(outputs['sequence_output'].shape, sequence_output_shape)
+    pooled_output_shape = [1, hidden_size]
+    self.assertAllEqual(outputs['pooled_output'].shape, pooled_output_shape)

   def test_mobilebert_encoder_invocation_with_attention_score(self):
     vocab_size = 100
@@ -186,18 +184,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     test_network = mobile_bert_encoder.MobileBERTEncoder(
         word_vocab_size=vocab_size,
         hidden_size=hidden_size,
-        num_blocks=num_blocks,
-        return_all_layers=False,
-        return_attention_score=True)
+        num_blocks=num_blocks)

     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    layer_out_tensor, pooler_out_tensor, attention_out_tensor = test_network(
-        [word_ids, mask, type_ids])
-    model = tf.keras.Model(
-        [word_ids, mask, type_ids],
-        [layer_out_tensor, pooler_out_tensor, attention_out_tensor])
+    outputs = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)

     input_seq = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
@@ -205,9 +198,8 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
         batch_size=1, seq_len=sequence_length, vocab_size=2)
     token_type = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=2)
-    _, _, attention_score_output = model.predict(
-        [input_seq, input_mask, token_type])
-    self.assertLen(attention_score_output, num_blocks)
+    outputs = model.predict([input_seq, input_mask, token_type])
+    self.assertLen(outputs['attention_scores'], num_blocks)

   @parameterized.named_parameters(
       ('sequence_classification', models.BertClassifier, [None, 5]),
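As the updated tests show, downstream code migrates by selecting tensors by key instead of unpacking by position. As a hedged illustration (the classification head below is hypothetical and not part of this commit), a fine-tuning model over the pooled [CLS] representation now reads:

import tensorflow as tf

from official.nlp.modeling.networks import mobile_bert_encoder  # path assumed

seq_len = 16
encoder = mobile_bert_encoder.MobileBERTEncoder(
    word_vocab_size=100, hidden_size=32, num_blocks=2)

word_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
outputs = encoder([word_ids, mask, type_ids])

# Select the needed output by name rather than by position, so this
# code is unaffected by how many outputs the encoder exposes.
logits = tf.keras.layers.Dense(2)(outputs['pooled_output'])
classifier = tf.keras.Model([word_ids, mask, type_ids], logits)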