Commit cd39fd58, author: A. Unique TensorFlower

Slice Transformer outputs with a TF Op layer instead of a Lambda layer.

Lambda layers are fundamentally non-portable since they serialize
Python bytecode.

PiperOrigin-RevId: 336917172
Parent cb93c205
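
The portability concern in the commit message can be illustrated without TensorFlow. The following is a minimal sketch using only the Python standard library, not Keras's actual serialization code: `dump_function`, `load_function`, and the nested-list `first_token` lambda are hypothetical stand-ins. It shows what storing a function as bytecode looks like. The `marshal` format is documented as specific to Python and subject to incompatible changes between versions, so a payload like this is only expected to load reliably under the interpreter version that wrote it.

```python
import base64
import marshal
import types

# Hypothetical stand-in for the function wrapped by the removed Lambda layer:
# pick the first element along the sequence axis of a nested list.
first_token = lambda rows: [row[0] for row in rows]


def dump_function(fn):
    """Encode a function's compiled bytecode as base64 text (illustrative)."""
    raw = marshal.dumps(fn.__code__)
    return base64.b64encode(raw).decode("ascii")


def load_function(encoded, name="restored"):
    """Rebuild a callable from text produced by dump_function."""
    code = marshal.loads(base64.b64decode(encoded))
    return types.FunctionType(code, globals(), name)


# Round trip within the same interpreter works; a different Python version
# is not guaranteed to understand the marshalled bytes.
encoded = dump_function(first_token)
restored = load_function(encoded)

batch = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # shape (batch=2, seq=2, hidden=2)
result = restored(batch)
```

An op expressed through subscript notation, as in the diff below, avoids this because no Python code object has to be stored.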
......@@ -163,9 +163,10 @@ class AlbertEncoder(tf.keras.Model):
 data = shared_layer([data, attention_mask])
 encoder_outputs.append(data)
-first_token_tensor = (
-tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
-)
+# Applying a tf.slice op (through subscript notation) to a Keras tensor
+# like this will create a SliceOpLambda layer. This is better than a Lambda
+# layer with Python code, because that is fundamentally less portable.
+first_token_tensor = data[:, 0, :]
 cls_output = tf.keras.layers.Dense(
 units=hidden_size,
 activation='tanh',
......
......@@ -193,9 +193,11 @@ class EncoderScaffold(tf.keras.Model):
 layer_output_data.append(data)
 self._hidden_layers.append(layer)
-first_token_tensor = (
-tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
-layer_output_data[-1]))
+last_layer_output = layer_output_data[-1]
+# Applying a tf.slice op (through subscript notation) to a Keras tensor
+# like this will create a SliceOpLambda layer. This is better than a Lambda
+# layer with Python code, because that is fundamentally less portable.
+first_token_tensor = last_layer_output[:, 0, :]
 self._pooler_layer = tf.keras.layers.Dense(
 units=pooled_output_dim,
 activation='tanh',
......