提交 cfef35b1 编写于 作者: R Renjie Liu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 464685466
上级 4647a7aa
......@@ -41,3 +41,5 @@ class PretrainerConfig(base_config.Config):
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
mlm_activation: str = "gelu"
mlm_initializer_range: float = 0.02
# Currently only used for MobileBERT.
mlm_output_weights_use_proj: bool = False
......@@ -447,6 +447,7 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
activation=None,
initializer='glorot_uniform',
output='logits',
output_weights_use_proj=False,
**kwargs):
"""Class initialization.
......@@ -457,6 +458,9 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
uniform initializer.
output: The output style for this layer. Can be either `logits` or
`predictions`.
output_weights_use_proj: Use a projection instead of concatenating extra
output weights. This may reduce the MLM task accuracy but also reduces the
number of model parameters.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
......@@ -469,6 +473,7 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
self._output_weights_use_proj = output_weights_use_proj
def build(self, input_shape):
self._vocab_size, embedding_width = self.embedding_table.shape
......@@ -480,11 +485,18 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
name='transform/dense')
if hidden_size > embedding_width:
self.extra_output_weights = self.add_weight(
'extra_output_weights',
shape=(self._vocab_size, hidden_size - embedding_width),
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
if self._output_weights_use_proj:
self.extra_output_weights = self.add_weight(
'output_weights_proj',
shape=(embedding_width, hidden_size),
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
else:
self.extra_output_weights = self.add_weight(
'extra_output_weights',
shape=(self._vocab_size, hidden_size - embedding_width),
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
elif hidden_size == embedding_width:
self.extra_output_weights = None
else:
......@@ -509,10 +521,16 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
if self.extra_output_weights is None:
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
else:
lm_data = tf.matmul(
lm_data,
tf.concat([self.embedding_table, self.extra_output_weights], axis=1),
transpose_b=True)
if self._output_weights_use_proj:
lm_data = tf.matmul(
lm_data, self.extra_output_weights, transpose_b=True)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
else:
lm_data = tf.matmul(
lm_data,
tf.concat([self.embedding_table, self.extra_output_weights],
axis=1),
transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
masked_positions_length = masked_positions.shape.as_list()[1] or tf.shape(
......
......@@ -63,6 +63,7 @@ student_model:
type: mobilebert
mlm_activation: relu
mlm_initializer_range: 0.02
mlm_output_weights_use_proj: true
teacher_model:
cls_heads: []
encoder:
......
......@@ -85,6 +85,7 @@ def build_bert_pretrainer(pretrainer_cfg: params.PretrainerModelParams,
activation=tf_utils.get_activation(pretrainer_cfg.mlm_activation),
initializer=tf.keras.initializers.TruncatedNormal(
stddev=pretrainer_cfg.mlm_initializer_range),
output_weights_use_proj=pretrainer_cfg.mlm_output_weights_use_proj,
name='cls/predictions')
pretrainer = edgetpu_pretrainer.MobileBERTEdgeTPUPretrainer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册