Commit f7fd59b8 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 284792715
Parent 558bab5d
@@ -65,6 +65,10 @@ def define_common_bert_flags():
   flags.DEFINE_string(
       'hub_module_url', None, 'TF-Hub path/url to Bert module. '
       'If specified, init_checkpoint flag should not be used.')
+  flags.DEFINE_enum(
+      'model_type', 'bert', ['bert', 'albert'],
+      'Specifies the type of the model. '
+      'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
   # Adds flags for mixed precision training.
   flags_core.define_performance(
......
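For context, a minimal standalone sketch of how an absl enum flag like the new model_type behaves; the script wrapper below is hypothetical and not part of this change.

# Hypothetical, self-contained sketch of the new enum flag (not in this commit).
from absl import app
from absl import flags

flags.DEFINE_enum(
    'model_type', 'bert', ['bert', 'albert'],
    'Specifies the type of the model: canonical BERT or ALBERT.')

FLAGS = flags.FLAGS


def main(_):
  # absl validates the value against ['bert', 'albert'] before main() runs,
  # so anything else (e.g. --model_type=roberta) fails at parse time.
  print('Selected model type:', FLAGS.model_type)


if __name__ == '__main__':
  app.run(main)  # e.g. python sketch.py --model_type=albert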
@@ -287,7 +287,11 @@ def run_bert(strategy,
              train_input_fn=None,
              eval_input_fn=None):
   """Run BERT training."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  if FLAGS.model_type == 'bert':
+    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  else:
+    assert FLAGS.model_type == 'albert'
+    bert_config = modeling.AlbertConfig.from_json_file(FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
     # internally uses model.save_weights() to save checkpoints, we must
......
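As a usage note, the branch above keys the config class off FLAGS.model_type; an equivalent table-driven helper might look like the following sketch (the helper name and module alias are assumptions, not code from this commit).

# Illustrative alternative to the if/else above (not part of this commit).
from official.nlp import bert_modeling as modeling

_CONFIG_CLASSES = {
    'bert': modeling.BertConfig,
    'albert': modeling.AlbertConfig,
}


def get_model_config(model_type, config_file):
  """Loads the BERT or ALBERT config selected by --model_type."""
  if model_type not in _CONFIG_CLASSES:
    raise ValueError('Unsupported model_type: %s' % model_type)
  return _CONFIG_CLASSES[model_type].from_json_file(config_file)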
@@ -22,6 +22,7 @@ import tensorflow as tf
 import tensorflow_hub as hub
 from official.modeling import tf_utils
 from official.nlp import bert_modeling
+from official.nlp.modeling import losses
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier
@@ -139,14 +140,14 @@ def _get_transformer_encoder(bert_config,
   """Gets a 'TransformerEncoder' object.
 
   Args:
-    bert_config: A 'modeling.BertConfig' object.
+    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
     sequence_length: Maximum sequence length of the training data.
     float_dtype: tf.dtype, tf.float32 or tf.float16.
 
   Returns:
     A networks.TransformerEncoder object.
   """
-  return networks.TransformerEncoder(
+  kwargs = dict(
       vocab_size=bert_config.vocab_size,
       hidden_size=bert_config.hidden_size,
       num_layers=bert_config.num_hidden_layers,
@@ -161,6 +162,12 @@
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range),
       float_dtype=float_dtype.name)
+  if isinstance(bert_config, bert_modeling.AlbertConfig):
+    kwargs['embedding_width'] = bert_config.embedding_size
+    return networks.AlbertTransformerEncoder(**kwargs)
+  else:
+    assert isinstance(bert_config, bert_modeling.BertConfig)
+    return networks.TransformerEncoder(**kwargs)
 
 
 def pretrain_model(bert_config,
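The restructuring above first collects the shared constructor arguments into kwargs, then picks the encoder class from the config's concrete type. A stripped-down, self-contained sketch of that dispatch pattern (the class names below are illustrative stand-ins, not the repo's API):

# Generic sketch of type-based encoder dispatch with shared kwargs
# (illustrative stand-ins only; not the official.nlp classes).
from dataclasses import dataclass


@dataclass
class BertLikeConfig:
  vocab_size: int
  hidden_size: int


@dataclass
class AlbertLikeConfig(BertLikeConfig):
  # ALBERT factorizes the embedding matrix, so it carries an extra width.
  embedding_size: int = 128


def build_encoder(config):
  kwargs = dict(vocab_size=config.vocab_size, hidden_size=config.hidden_size)
  if isinstance(config, AlbertLikeConfig):
    # Checked first because AlbertLikeConfig is also a BertLikeConfig.
    kwargs['embedding_width'] = config.embedding_size
    return 'AlbertTransformerEncoder', kwargs
  return 'TransformerEncoder', kwargs


print(build_encoder(AlbertLikeConfig(vocab_size=30000, hidden_size=768)))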
@@ -332,7 +339,8 @@ def classifier_model(bert_config,
   maximum sequence length `max_seq_length`.
 
   Args:
-    bert_config: BertConfig, the config defines the core BERT model.
+    bert_config: BertConfig or AlbertConfig, the config defines the core
+      BERT or ALBERT model.
     float_type: dtype, tf.float32 or tf.bfloat16.
     num_labels: integer, the number of classes.
     max_seq_length: integer, the maximum input sequence length.
......