From a2a1b66ffc8020f2786e17f5d85094bc6d9d2358 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 20 Nov 2019 17:29:54 -0800
Subject: [PATCH] Move distribution strategy init to beginning of run

PiperOrigin-RevId: 281641189
---
 official/transformer/v2/transformer_main.py | 32 ++++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/official/transformer/v2/transformer_main.py b/official/transformer/v2/transformer_main.py
index 70b329ae0..fad093db9 100644
--- a/official/transformer/v2/transformer_main.py
+++ b/official/transformer/v2/transformer_main.py
@@ -160,22 +160,6 @@ class TransformerTask(object):
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
 
-    if params["dtype"] == tf.float16:
-      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
-      # like this. What if multiple instances of TransformerTask are created?
-      # We should have a better way in the tf.keras.mixed_precision API of doing
-      # this.
-      loss_scale = flags_core.get_loss_scale(flags_obj,
-                                             default_for_fp16="dynamic")
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_float16", loss_scale=loss_scale)
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-
-    if params["dtype"] == tf.bfloat16:
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_bfloat16")
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
         num_gpus=num_gpus,
@@ -193,6 +177,22 @@ class TransformerTask(object):
     else:
       logging.info("Not using any distribution strategy.")
 
+    if params["dtype"] == tf.float16:
+      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
+      # like this. What if multiple instances of TransformerTask are created?
+      # We should have a better way in the tf.keras.mixed_precision API of doing
+      # this.
+      loss_scale = flags_core.get_loss_scale(flags_obj,
+                                             default_for_fp16="dynamic")
+      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
+          "mixed_float16", loss_scale=loss_scale)
+      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+
+    if params["dtype"] == tf.bfloat16:
+      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
+          "mixed_bfloat16")
+      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+
   @property
   def use_tpu(self):
     if self.distribution_strategy:
--
GitLab
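
Note: the following standalone sketch (not part of the patch) shows the constructor ordering the patch establishes in TransformerTask.__init__: the distribution strategy is created first, and only afterwards is the global Keras mixed-precision policy installed. It assumes a TensorFlow release contemporaneous with this change, where the experimental mixed_precision Policy still accepts a loss_scale argument; the build_task helper, its arguments, and the use of MirroredStrategy in place of the repository's distribution_utils.get_distribution_strategy are illustrative stand-ins, not code from the repository.

import tensorflow as tf

def build_task(dtype=tf.float16, loss_scale="dynamic"):
  """Illustrative helper mirroring the post-patch ordering in __init__."""
  # Step 1: create the distribution strategy first (this is the block the
  # patch moves ahead of the mixed-precision setup).
  strategy = tf.distribute.MirroredStrategy()

  # Step 2: only then install the global mixed-precision policy, matching
  # the order in the second hunk of the diff above.
  if dtype == tf.float16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        "mixed_float16", loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        "mixed_bfloat16")
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  return strategy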