Commit 152adcbd authored by A. Unique TensorFlower

Enabled performance-related parameters for Transformer: all_reduce_alg, enable_eager, tf_gpu_thread_mode and datasets_num_private_threads.

PiperOrigin-RevId: 285157129
Parent 26bbda73
@@ -71,6 +71,9 @@ def define_transformer_flags():
       dtype=True,
       loss_scale=True,
       all_reduce_alg=True,
+      num_packs=True,
+      tf_gpu_thread_mode=True,
+      datasets_num_private_threads=True,
       enable_xla=True,
       force_v2_in_keras_compile=True,
       fp16_implementation=True
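The booleans in this hunk act as gates: the flag-definition helper only registers the command-line flags whose gate is set to True. Below is a minimal sketch of that gate pattern using absl.flags; the helper name, defaults, and help strings are assumptions for illustration, not the actual flags_core.define_performance implementation.

# Illustrative sketch only: gated registration of performance flags.
# Names, defaults, and help text are assumptions, not the real
# official.utils.flags.core code.
from absl import flags


def define_performance_flags(all_reduce_alg=False, num_packs=False,
                             tf_gpu_thread_mode=False,
                             datasets_num_private_threads=False):
  if all_reduce_alg:
    flags.DEFINE_string(
        'all_reduce_alg', None,
        'Cross-device all-reduce implementation, e.g. "nccl" or '
        '"hierarchical_copy".')
  if num_packs:
    flags.DEFINE_integer(
        'num_packs', 1, 'Number of gradient packs used by cross-device ops.')
  if tf_gpu_thread_mode:
    flags.DEFINE_string(
        'tf_gpu_thread_mode', None,
        'GPU thread scheduling mode, e.g. "gpu_private".')
  if datasets_num_private_threads:
    flags.DEFINE_integer(
        'datasets_num_private_threads', None,
        'Private threadpool size for tf.data input processing.')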
@@ -86,7 +89,7 @@ def define_transformer_flags():
           'convolutions and batch normalizations, and this flag allows to '
           'disable it.'
       )
   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)
@@ -164,6 +164,8 @@ class TransformerTask(object):
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
         num_gpus=num_gpus,
+        all_reduce_alg=flags_obj.all_reduce_alg,
+        num_packs=flags_obj.num_packs,
         tpu_address=flags_obj.tpu or "")
     if self.use_tpu:
       params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
@@ -465,6 +467,14 @@ def main(_):
   with logger.benchmark_context(flags_obj):
     task = TransformerTask(flags_obj)

+    # Execute flag override logic for better model performance
+    if flags_obj.tf_gpu_thread_mode:
+      keras_utils.set_gpu_thread_mode_and_count(
+          per_gpu_thread_count=flags_obj.per_gpu_thread_count,
+          gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
+          num_gpus=flags_obj.num_gpus,
+          datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+
     if flags_obj.mode == "train":
       task.train()
     elif flags_obj.mode == "predict":
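For context, a helper like set_gpu_thread_mode_and_count typically sets the TF_GPU_THREAD_MODE and TF_GPU_THREAD_COUNT environment variables (read by the TensorFlow runtime at GPU initialization) and sizes a private tf.data threadpool so the input pipeline does not compete with GPU runtime threads. The sketch below shows that idea under stated assumptions; the function names and the core-count heuristic are illustrative, not the actual keras_utils implementation.

# Illustrative sketch only: GPU thread-mode tuning plus a private tf.data
# threadpool. Heuristics and names are assumptions, not keras_utils code.
import multiprocessing
import os

import tensorflow as tf


def apply_gpu_thread_tuning(gpu_thread_mode, num_gpus,
                            per_gpu_thread_count=0,
                            datasets_num_private_threads=None):
  # TF_GPU_THREAD_MODE / TF_GPU_THREAD_COUNT are read by the TensorFlow
  # runtime when GPU devices are initialized.
  per_gpu_thread_count = per_gpu_thread_count or 2
  os.environ["TF_GPU_THREAD_MODE"] = gpu_thread_mode
  os.environ["TF_GPU_THREAD_COUNT"] = str(per_gpu_thread_count)
  # Reserve the remaining CPU cores for the host-side input pipeline
  # (assumed heuristic).
  if datasets_num_private_threads is None:
    total_cpus = multiprocessing.cpu_count()
    datasets_num_private_threads = max(
        total_cpus - per_gpu_thread_count * num_gpus, 1)
  return datasets_num_private_threads


def with_private_threads(dataset, num_private_threads):
  # Give the input pipeline its own threadpool so it does not contend with
  # the GPU runtime threads configured above.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = num_private_threads
  return dataset.with_options(options)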