提交 7e12e4f4 编写于 作者: C Chen Chen 提交者: A. Unique TensorFlower

Internal Change

PiperOrigin-RevId: 281630227
上级 e16594d1
......@@ -374,7 +374,8 @@ def squad_model(bert_config,
bert_config: BertConfig, the config defines the core Bert model.
max_seq_length: integer, the maximum input sequence length.
float_type: tf.dtype, tf.float32 or tf.bfloat16.
initializer: Initializer for weights in BertSquadLogitsLayer.
initializer: Initializer for the final dense layer in the span labeler.
Defaulted to TruncatedNormal initializer.
hub_module_url: TF-Hub path/url to Bert module.
use_keras_bert: Whether to use keras BERT. Note that when the above
'hub_module_url' is specified, 'use_keras_bert' cannot be True.
......@@ -389,12 +390,14 @@ def squad_model(bert_config,
if hub_module_url and use_keras_bert:
raise ValueError(
'Cannot use hub_module_url and keras BERT at the same time.')
if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
if use_keras_bert:
bert_encoder = _get_transformer_encoder(
bert_config, max_seq_length, float_type)
bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
float_type)
return bert_span_labeler.BertSpanLabeler(
network=bert_encoder), bert_encoder
network=bert_encoder, initializer=initializer), bert_encoder
input_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
......@@ -421,9 +424,6 @@ def squad_model(bert_config,
# has dimensionality (batch_size, sequence_length, num_hidden).
sequence_output = core_model.outputs[1]
if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
squad_logits_layer = BertSquadLogitsLayer(
initializer=initializer, float_type=float_type, name='squad_logits')
start_logits, end_logits = squad_logits_layer(sequence_output)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册