Commit 3f8132e9 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 329042049
Parent 379f951d
......@@ -82,6 +82,13 @@ def define_common_bert_flags():
'allreduce in optimizer.apply_gradients(). If fp16 mixed '
'precision training is used, this also enables allreduce '
'gradients in fp16.')
flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
'Number of bytes of a gradient pack for allreduce. '
'Should be a non-negative integer; if set to 0, all '
'gradients are in one pack. Breaking gradients into '
'packs could enable overlap between allreduce and '
'backprop computation. This flag only takes effect '
'when explicit_allreduce is set to True.')
flags_core.define_log_steps()
......
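As a quick, illustrative aside (not part of the change), the new flag can be paired with an absl validator so negative pack sizes are rejected at startup. Only the flag definition below mirrors the diff; the validator is an assumption added for the sketch.

```python
from absl import flags

FLAGS = flags.FLAGS

# Mirrors the flag added above; 0 keeps all gradients in a single pack.
flags.DEFINE_integer(
    'allreduce_bytes_per_pack', 0,
    'Number of bytes of a gradient pack for allreduce. If set to 0, all '
    'gradients are in one pack. Only effective when explicit_allreduce is '
    'set to True.')

# Illustrative guard, not in the diff: reject negative values early.
flags.register_validator(
    'allreduce_bytes_per_pack',
    lambda value: value >= 0,
    message='allreduce_bytes_per_pack must be a non-negative integer.')
```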
......@@ -133,7 +133,8 @@ def run_customized_training_loop(
explicit_allreduce=False,
pre_allreduce_callbacks=None,
post_allreduce_callbacks=None,
train_summary_interval=0):
train_summary_interval=0,
allreduce_bytes_per_pack=0):
"""Run BERT pretrain model training using low-level API.
Arguments:
......@@ -201,6 +202,11 @@ def run_customized_training_loop(
when explicit_allreduce=True.
train_summary_interval: Step interval for training summaries. If the value
is a negative number, then training summaries are not enabled.
allreduce_bytes_per_pack: A non-negative integer. Breaks collective
operations into packs of a certain size. If it is zero, all gradients are
in one pack. Breaking gradients into packs could enable overlap between
allreduce and backprop computation. This flag only takes effect when
explicit_allreduce is set to True.
Returns:
Trained model.
......@@ -332,7 +338,8 @@ def run_customized_training_loop(
grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
training_vars,
pre_allreduce_callbacks,
post_allreduce_callbacks)
post_allreduce_callbacks,
allreduce_bytes_per_pack)
else:
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
......
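To make the wiring of the new `allreduce_bytes_per_pack` argument concrete, here is a condensed sketch of the training-step branch shown above. The function name `minimize_step`, the simplified non-explicit branch, and the `grad_utils` import path are assumptions for illustration, not the exact code in the repository.

```python
import tensorflow as tf

from official.staging.training import grad_utils  # import path assumed


def minimize_step(tape, optimizer, loss, training_vars,
                  explicit_allreduce=False,
                  pre_allreduce_callbacks=None,
                  post_allreduce_callbacks=None,
                  allreduce_bytes_per_pack=0):
  """Condensed sketch of the branch touched by this change."""
  if explicit_allreduce:
    # Manual allreduce path: the pack size is forwarded to the helper that
    # builds the collective hints before calling all_reduce().
    grad_utils.minimize_using_explicit_allreduce(
        tape, optimizer, loss, training_vars,
        pre_allreduce_callbacks, post_allreduce_callbacks,
        allreduce_bytes_per_pack)
  else:
    # Default path (simplified): apply_gradients() handles the implicit
    # cross-replica reduction, so the pack size is not used.
    grads = tape.gradient(loss, training_vars)
    optimizer.apply_gradients(zip(grads, training_vars))
```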
......@@ -109,7 +109,8 @@ def run_customized_training(strategy,
custom_callbacks=None,
explicit_allreduce=False,
pre_allreduce_callbacks=None,
post_allreduce_callbacks=None):
post_allreduce_callbacks=None,
allreduce_bytes_per_pack=0):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
......@@ -146,6 +147,7 @@ def run_customized_training(strategy,
explicit_allreduce=explicit_allreduce,
pre_allreduce_callbacks=pre_allreduce_callbacks,
post_allreduce_callbacks=post_allreduce_callbacks,
allreduce_bytes_per_pack=allreduce_bytes_per_pack,
train_summary_interval=train_summary_interval,
custom_callbacks=custom_callbacks)
......@@ -165,10 +167,12 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
performance.set_mixed_precision_policy(common_flags.dtype())
# If explicit_allreduce = True, apply_gradients() no longer implicitly
# allreduce gradients, users manually allreduce gradient and pass the
# allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
# before allreduce, to be consistent with original TF1 model.
# Only when explicit_allreduce = True do post_allreduce_callbacks and
# allreduce_bytes_per_pack take effect. optimizer.apply_gradients() no
# longer implicitly allreduces gradients; users manually allreduce gradients
# and pass the allreduced grads_and_vars to apply_gradients().
# With explicit_allreduce = True, clip_by_global_norm is moved to after
# allreduce.
return run_customized_training(
strategy,
bert_config,
......@@ -191,7 +195,8 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
explicit_allreduce=FLAGS.explicit_allreduce,
pre_allreduce_callbacks=[
model_training_utils.clip_by_global_norm_callback
])
],
allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)
def main(_):
......
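For context, the pre/post-allreduce callbacks referenced above operate on lists of (gradient, variable) pairs. A minimal sketch of a clipping callback in that style is shown below; the clip norm of 1.0 is an assumption, and the real `model_training_utils.clip_by_global_norm_callback` may differ in details.

```python
import tensorflow as tf


def clip_by_global_norm_callback(grads_and_vars):
  """Sketch of a gradient callback: clip gradients by their global norm."""
  grads, variables = zip(*grads_and_vars)
  # clip_norm=1.0 is the conventional BERT setting; assumed here.
  clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
  return zip(clipped_grads, variables)
```

When explicit_allreduce is enabled, such a callback runs either before or after the manual allreduce, depending on which callback list it is passed in.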
......@@ -260,10 +260,12 @@ def train_squad(strategy,
use_graph_rewrite=common_flags.use_graph_rewrite())
return squad_model, core_model
# If explicit_allreduce = True, apply_gradients() no longer implicitly
# allreduce gradients, users manually allreduce gradient and pass the
# allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
# before allreduce, to be consistent with the original TF1 model.
# Only when explicit_allreduce = True do post_allreduce_callbacks and
# allreduce_bytes_per_pack take effect. optimizer.apply_gradients() no
# longer implicitly allreduces gradients; users manually allreduce gradients
# and pass the allreduced grads_and_vars to apply_gradients().
# With explicit_allreduce = True, clip_by_global_norm is moved to after
# allreduce.
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
......@@ -280,7 +282,8 @@ def train_squad(strategy,
explicit_allreduce=FLAGS.explicit_allreduce,
pre_allreduce_callbacks=[
model_training_utils.clip_by_global_norm_callback
])
],
allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)
def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
......
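To get a feel for what a given pack size means for a model of this scale, the back-of-the-envelope calculation below estimates how many packs a BERT-large gradient set would be split into. The parameter count and pack size are illustrative assumptions, not values from this change.

```python
# Illustrative numbers only.
num_params = 340_000_000       # rough BERT-large parameter count (assumed)
bytes_per_grad = 4             # float32 gradients (2 if allreducing in fp16)
pack_bytes = 32 * 1024 * 1024  # example allreduce_bytes_per_pack value

total_grad_bytes = num_params * bytes_per_grad
num_packs = -(-total_grad_bytes // pack_bytes)  # ceiling division
print(num_packs)  # ~41 packs; early packs can allreduce while backprop continues
```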
......@@ -48,7 +48,8 @@ def _filter_grads(grads_and_vars):
def _filter_and_allreduce_gradients(grads_and_vars,
allreduce_precision="float32"):
allreduce_precision="float32",
bytes_per_pack=0):
"""Filter None grads and then allreduce gradients in specified precision.
This utility function is used when users intend to explicitly allreduce
......@@ -59,6 +60,8 @@ def _filter_and_allreduce_gradients(grads_and_vars,
Arguments:
grads_and_vars: gradients and variables pairs.
allreduce_precision: Whether to allreduce gradients in float32 or float16.
bytes_per_pack: A non-negative integer. Breaks collective operations into
packs of certain size. If it's zero, all gradients are in one pack.
Returns:
pairs of allreduced non-None gradients and variables.
......@@ -67,8 +70,10 @@ def _filter_and_allreduce_gradients(grads_and_vars,
(grads, variables) = zip(*filtered_grads_and_vars)
if allreduce_precision == "float16":
grads = [tf.cast(grad, "float16") for grad in grads]
hints = tf.distribute.experimental.CollectiveHints(
bytes_per_pack=bytes_per_pack)
allreduced_grads = tf.distribute.get_replica_context().all_reduce(
tf.distribute.ReduceOp.SUM, grads)
tf.distribute.ReduceOp.SUM, grads, experimental_hints=hints)
if allreduce_precision == "float16":
allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
return allreduced_grads, variables
......@@ -85,7 +90,8 @@ def minimize_using_explicit_allreduce(tape,
loss,
trainable_variables,
pre_allreduce_callbacks=None,
post_allreduce_callbacks=None):
post_allreduce_callbacks=None,
allreduce_bytes_per_pack=0):
"""Minimizes loss for one step by updating `trainable_variables`.
Minimizes loss for one step by updating `trainable_variables`.
......@@ -111,6 +117,9 @@ def minimize_using_explicit_allreduce(tape,
returns new gradients and model variables pairs. The callback
functions will be invoked in the list order and right before gradients
are applied to variables for updates. Default is no callbacks.
allreduce_bytes_per_pack: A non-negative integer. Breaks collective
operations into packs of certain size. If it's zero, all gradients are
in one pack.
"""
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
......@@ -123,7 +132,9 @@ def minimize_using_explicit_allreduce(tape,
grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
(allreduced_scaled_grads,
filtered_training_vars) = _filter_and_allreduce_gradients(
grads_and_vars, allreduce_precision="float16")
grads_and_vars,
allreduce_precision="float16",
bytes_per_pack=allreduce_bytes_per_pack)
allreduced_unscaled_grads = optimizer.get_unscaled_gradients(
allreduced_scaled_grads)
grads_and_vars = zip(allreduced_unscaled_grads, filtered_training_vars)
......@@ -135,7 +146,9 @@ def minimize_using_explicit_allreduce(tape,
grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
(allreduced_grads,
filtered_training_vars) = _filter_and_allreduce_gradients(
grads_and_vars, allreduce_precision="float32")
grads_and_vars,
allreduce_precision="float32",
bytes_per_pack=allreduce_bytes_per_pack)
grads_and_vars = zip(allreduced_grads, filtered_training_vars)
if post_allreduce_callbacks:
grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
......
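Putting the pieces together, the mechanism behind `bytes_per_pack` is the `tf.distribute.experimental.CollectiveHints` object passed to the replica context's `all_reduce`, exactly as in the hunk above. A self-contained sketch, with the helper name chosen for illustration, looks roughly like this:

```python
import tensorflow as tf


def allreduce_with_packing(grads, bytes_per_pack=0):
  """Sums `grads` across replicas, optionally splitting them into packs."""
  # bytes_per_pack=0 keeps all gradients in one pack; a positive value lets
  # the runtime start allreducing finished packs while backprop still runs.
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=bytes_per_pack)
  replica_ctx = tf.distribute.get_replica_context()
  return replica_ctx.all_reduce(
      tf.distribute.ReduceOp.SUM, grads, experimental_hints=hints)
```

This has to run inside a replica context (for example, from a function passed to `strategy.run`); whether the hint changes behavior depends on the strategy's cross-device implementation.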