Commit 3fc70674 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 328842362
Parent e3a74e5b
......@@ -77,6 +77,11 @@ def define_common_bert_flags():
'sub_model_export_name', None,
'If set, `sub_model` checkpoints are exported into '
'FLAGS.model_dir/FLAGS.sub_model_export_name.')
flags.DEFINE_bool('explicit_allreduce', False,
'True to use explicit allreduce instead of the implicit '
'allreduce in optimizer.apply_gradients(). If fp16 mixed '
                    'precision training is used, this also enables allreducing '
                    'gradients in fp16.')
flags_core.define_log_steps()
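As a reference (not part of the commit), a minimal sketch of how an absl boolean flag like the one above is defined and consumed; apart from the `explicit_allreduce` name and default, every detail below is illustrative:

```python
# Standalone sketch: everything except the flag name/default is an assumption.
from absl import app, flags

flags.DEFINE_bool('explicit_allreduce', False,
                  'True to use explicit allreduce instead of the implicit '
                  'allreduce in optimizer.apply_gradients().')
FLAGS = flags.FLAGS


def main(_):
  # e.g. `python this_script.py --explicit_allreduce=true`
  if FLAGS.explicit_allreduce:
    print('Gradients will be allreduced explicitly before apply_gradients().')
  else:
    print('optimizer.apply_gradients() aggregates gradients implicitly.')


if __name__ == '__main__':
  app.run(main)
```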
......@@ -116,3 +121,10 @@ def use_graph_rewrite():
def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
def clip_by_global_norm_callback(grads_and_vars):
grads, variables = zip(*grads_and_vars)
(clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
return zip(clipped_grads, variables)
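A quick standalone check (outside this change) of what `clip_by_global_norm_callback` does: `tf.clip_by_global_norm` rescales the whole gradient list so its combined norm is at most `clip_norm`, and the callback re-pairs the clipped gradients with their variables. The numbers below are made up:

```python
import tensorflow as tf

# Toy gradients/variables, chosen so the global norm is sqrt(3^2 + 4^2 + 12^2) = 13.
v1 = tf.Variable([1.0, 2.0])
v2 = tf.Variable([3.0])
grads_and_vars = [(tf.constant([3.0, 4.0]), v1), (tf.constant([12.0]), v2)]

grads, variables = zip(*grads_and_vars)
clipped_grads, global_norm = tf.clip_by_global_norm(list(grads), clip_norm=1.0)

print(global_norm.numpy())                   # 13.0 -> every gradient is scaled by 1/13
print([g.numpy() for g in clipped_grads])
print(list(zip(clipped_grads, variables)))   # same structure the callback returns
```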
......@@ -106,7 +106,10 @@ def run_customized_training(strategy,
train_batch_size,
use_next_sentence_label=True,
train_summary_interval=0,
custom_callbacks=None):
custom_callbacks=None,
explicit_allreduce=False,
pre_allreduce_callbacks=None,
post_allreduce_callbacks=None):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
......@@ -140,6 +143,9 @@ def run_customized_training(strategy,
steps_per_loop=steps_per_loop,
epochs=epochs,
sub_model_export_name='pretrained/bert_model',
explicit_allreduce=explicit_allreduce,
pre_allreduce_callbacks=pre_allreduce_callbacks,
post_allreduce_callbacks=post_allreduce_callbacks,
train_summary_interval=train_summary_interval,
custom_callbacks=custom_callbacks)
......@@ -159,6 +165,10 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
performance.set_mixed_precision_policy(common_flags.dtype())
  # If explicit_allreduce = True, apply_gradients() no longer implicitly
  # allreduces gradients; users must manually allreduce gradients and pass the
  # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
  # before allreduce, to be consistent with the original TF1 model.
return run_customized_training(
strategy,
bert_config,
......@@ -177,7 +187,9 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
FLAGS.train_batch_size,
FLAGS.use_next_sentence_label,
FLAGS.train_summary_interval,
custom_callbacks=custom_callbacks)
custom_callbacks=custom_callbacks,
explicit_allreduce=FLAGS.explicit_allreduce,
pre_allreduce_callbacks=[common_flags.clip_by_global_norm_callback])
def main(_):
......
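To make the comments added in this change concrete, here is a hedged sketch of the ordering selected by `explicit_allreduce=True`. The real logic lives in `model_training_utils.run_customized_training_loop`; the helper below, its arguments, and the choice of MEAN reduction are illustrative assumptions, and it would run inside a `tf.distribute.Strategy` replica context (e.g. under `strategy.run`):

```python
import tensorflow as tf


def _explicit_allreduce_step(optimizer, loss_fn, variables,
                             pre_allreduce_callbacks=()):
  """Illustrative per-replica step: clip first, then allreduce, then apply."""
  with tf.GradientTape() as tape:
    loss = loss_fn()                       # replica-local loss
  grads = tape.gradient(loss, variables)
  grads_and_vars = list(zip(grads, variables))

  # Pre-allreduce callbacks (e.g. clip_by_global_norm_callback above) see the
  # replica-local gradients, matching the TF1 BERT order of clip-then-allreduce.
  for callback in pre_allreduce_callbacks:
    grads_and_vars = list(callback(grads_and_vars))

  # Manual cross-replica allreduce, then apply without the optimizer's own
  # implicit aggregation.
  grads, tvars = zip(*grads_and_vars)
  reduced = tf.distribute.get_replica_context().all_reduce(
      tf.distribute.ReduceOp.MEAN, list(grads))
  optimizer.apply_gradients(
      zip(reduced, tvars), experimental_aggregate_gradients=False)
```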
......@@ -262,13 +262,8 @@ def train_squad(strategy,
  # If explicit_allreduce = True, apply_gradients() no longer implicitly
  # allreduces gradients; users must manually allreduce gradients and pass the
# allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
# applied to allreduced gradients.
def clip_by_global_norm_callback(grads_and_vars):
grads, variables = zip(*grads_and_vars)
(clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
return zip(clipped_grads, variables)
# allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
# before allreduce, to be consistent with the original TF1 model.
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
......@@ -282,8 +277,8 @@ def train_squad(strategy,
sub_model_export_name=sub_model_export_name,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks,
explicit_allreduce=False,
post_allreduce_callbacks=[clip_by_global_norm_callback])
explicit_allreduce=FLAGS.explicit_allreduce,
pre_allreduce_callbacks=[common_flags.clip_by_global_norm_callback])
def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
......
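Per the old and new comments above, SQuAD training moves from clipping the already-allreduced gradients (post_allreduce_callbacks) to clipping each replica's gradients before the allreduce (pre_allreduce_callbacks), for consistency with the TF1 model. A toy single-process sketch, with made-up per-replica gradients, of why the order matters:

```python
import tensorflow as tf

# Two fictitious replicas' gradients for a single variable.
replica_grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 10.0])]

# Post-allreduce clipping (old SQuAD path): average first, then clip.
mean_grad = (replica_grads[0] + replica_grads[1]) / 2.0          # [1.5, 7.0]
(post,), _ = tf.clip_by_global_norm([mean_grad], clip_norm=1.0)  # ~[0.21, 0.98]

# Pre-allreduce clipping (the TF1-consistent path this change switches to):
# clip each replica's gradients, then average.
clipped = [tf.clip_by_global_norm([g], clip_norm=1.0)[0][0] for g in replica_grads]
pre = (clipped[0] + clipped[1]) / 2.0                            # [0.3, 0.9]

print(post.numpy(), pre.numpy())  # the two orders generally give different updates
```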