提交 8c408bbe 编写于 作者: C Chen Chen 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 311597242
上级 7cdb82e3
......@@ -54,9 +54,11 @@ class EncoderScaffold(tf.keras.Model):
pooler_layer_initializer: The initializer for the classification
layer.
embedding_cls: The class or instance to use to embed the input data. This
class or instance defines the inputs to this encoder. If embedding_cls is
not set, a default embedding network (from the original BERT paper) will
be created.
class or instance defines the inputs to this encoder and outputs
(1) an embeddings tensor with shape [batch_size, seq_length, hidden_size] and
(2) an attention mask tensor with shape [batch_size, seq_length, seq_length].
If embedding_cls is not set, a default embedding network
(from the original BERT paper) will be created.
embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
be instantiated. If embedding_cls is not set, a config dict must be
passed to 'embedding_cfg' with the following values:
......@@ -121,7 +123,7 @@ class EncoderScaffold(tf.keras.Model):
else:
self._embedding_network = embedding_cls
inputs = self._embedding_network.inputs
embeddings, mask = self._embedding_network(inputs)
embeddings, attention_mask = self._embedding_network(inputs)
else:
self._embedding_network = None
word_ids = tf.keras.layers.Input(
......@@ -174,7 +176,8 @@ class EncoderScaffold(tf.keras.Model):
tf.keras.layers.Dropout(
rate=embedding_cfg['dropout_rate'])(embeddings))
attention_mask = layers.SelfAttentionMask()([embeddings, mask])
attention_mask = layers.SelfAttentionMask()([embeddings, mask])
data = embeddings
layer_output_data = []
......
......@@ -211,8 +211,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
"kernel_initializer":
tf.keras.initializers.TruncatedNormal(stddev=0.02),
}
print(hidden_cfg)
print(embedding_cfg)
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
......@@ -347,7 +345,9 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name="word_embeddings")
word_embeddings = embedding_layer(word_ids)
network = tf.keras.Model([word_ids, mask], [word_embeddings, mask])
attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
network = tf.keras.Model([word_ids, mask],
[word_embeddings, attention_mask])
hidden_cfg = {
"num_attention_heads":
......@@ -414,7 +414,9 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name="word_embeddings")
word_embeddings = embedding_layer(word_ids)
network = tf.keras.Model([word_ids, mask], [word_embeddings, mask])
attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
network = tf.keras.Model([word_ids, mask],
[word_embeddings, attention_mask])
hidden_cfg = {
"num_attention_heads":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册