Commit 33a4c207 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Clean up] Consolidate distribution utils.

PiperOrigin-RevId: 331359058
Parent 41a1e1d6
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""
import json
import os
import random
import string
from absl import logging
import tensorflow as tf
def _collective_communication(all_reduce_alg):
"""Return a CollectiveCommunication based on all_reduce_alg.
Args:
all_reduce_alg: a string specifying which collective communication to pick,
or None.
Returns:
tf.distribute.experimental.CollectiveCommunication object
Raises:
ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
"""
collective_communication_options = {
None: tf.distribute.experimental.CollectiveCommunication.AUTO,
"ring": tf.distribute.experimental.CollectiveCommunication.RING,
"nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
}
if all_reduce_alg not in collective_communication_options:
raise ValueError(
"When used with `multi_worker_mirrored`, valid values for "
"all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
all_reduce_alg))
return collective_communication_options[all_reduce_alg]
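# Illustrative sketch, not part of this commit: the helper above maps an
# algorithm name to a CollectiveCommunication member, defaulting to AUTO for
# None and rejecting anything else with the ValueError shown.
assert (_collective_communication("nccl") ==
        tf.distribute.experimental.CollectiveCommunication.NCCL)
assert (_collective_communication(None) ==
        tf.distribute.experimental.CollectiveCommunication.AUTO)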
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
"""Return a CrossDeviceOps based on all_reduce_alg and num_packs.
Args:
all_reduce_alg: a string specifying which cross device op to pick, or None.
num_packs: an integer specifying number of packs for the cross device op.
Returns:
tf.distribute.CrossDeviceOps object or None.
Raises:
ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
"""
if all_reduce_alg is None:
return None
mirrored_all_reduce_options = {
"nccl": tf.distribute.NcclAllReduce,
"hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
}
if all_reduce_alg not in mirrored_all_reduce_options:
raise ValueError(
"When used with `mirrored`, valid values for all_reduce_alg are "
"[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
all_reduce_alg))
cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
return cross_device_ops_class(num_packs=num_packs)
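# Illustrative sketch, not part of this commit: a valid algorithm name yields a
# concrete CrossDeviceOps instance, while None leaves the choice to the
# MirroredStrategy default.
assert isinstance(_mirrored_cross_device_ops("nccl", num_packs=2),
                  tf.distribute.NcclAllReduce)
assert _mirrored_cross_device_ops(None, num_packs=1) is None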
def tpu_initialize(tpu_address):
"""Initializes TPU for TF 2.x training.
Args:
tpu_address: string, bns address of master TPU worker.
Returns:
A TPUClusterResolver.
"""
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address)
if tpu_address not in ("", "local"):
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
return cluster_resolver
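# Illustrative sketch, not part of this commit (requires TPU hardware, so it is
# left commented): an empty or "local" address skips the remote-cluster
# connection and only initializes the locally attached TPU system.
#
#   resolver = tpu_initialize("")  # or tpu_initialize("grpc://<tpu-address>")
#   tpu_strategy = tf.distribute.experimental.TPUStrategy(resolver)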
def get_distribution_strategy(distribution_strategy="mirrored",
num_gpus=0,
all_reduce_alg=None,
num_packs=1,
tpu_address=None):
"""Return a DistributionStrategy for running the model.
Args:
distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are "off", "one_device", "mirrored",
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
insensitive. "off" means not to use Distribution Strategy; "tpu" means to
use TPUStrategy using `tpu_address`.
num_gpus: Number of GPUs to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
"hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
"ring" and "nccl". If None, DistributionStrategy will choose based on
device topology.
num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
tpu_address: Optional. String that represents TPU to connect to. Must not be
None if `distribution_strategy` is set to `tpu`.
Returns:
tf.distribute.DistributionStrategy object.
Raises:
ValueError: if `distribution_strategy` is "off" or "one_device" and
`num_gpus` is larger than 1; or `num_gpus` is negative or if
`distribution_strategy` is `tpu` but `tpu_address` is not specified.
"""
if num_gpus < 0:
raise ValueError("`num_gpus` can not be negative.")
distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off":
if num_gpus > 1:
raise ValueError("When {} GPUs are specified, distribution_strategy "
"flag cannot be set to `off`.".format(num_gpus))
return None
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
cluster_resolver = tpu_initialize(tpu_address)
return tf.distribute.experimental.TPUStrategy(cluster_resolver)
if distribution_strategy == "multi_worker_mirrored":
return tf.distribute.experimental.MultiWorkerMirroredStrategy(
communication=_collective_communication(all_reduce_alg))
if distribution_strategy == "one_device":
if num_gpus == 0:
return tf.distribute.OneDeviceStrategy("device:CPU:0")
if num_gpus > 1:
raise ValueError("`OneDeviceStrategy` can not be used for more than "
"one device.")
return tf.distribute.OneDeviceStrategy("device:GPU:0")
if distribution_strategy == "mirrored":
if num_gpus == 0:
devices = ["device:CPU:0"]
else:
devices = ["device:GPU:%d" % i for i in range(num_gpus)]
return tf.distribute.MirroredStrategy(
devices=devices,
cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
if distribution_strategy == "parameter_server":
return tf.distribute.experimental.ParameterServerStrategy()
raise ValueError("Unrecognized Distribution Strategy: %r" %
distribution_strategy)
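# Illustrative usage sketch, not part of this commit. On a CPU-only host,
# "mirrored" falls back to a single CPU device and "off" returns None; the
# wrapper function below exists only for this example.
def _example_strategies():
  mirrored = get_distribution_strategy("mirrored", num_gpus=0)
  assert mirrored.num_replicas_in_sync == 1
  assert get_distribution_strategy("off", num_gpus=0) is None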
def configure_cluster(worker_hosts=None, task_index=-1):
"""Set multi-worker cluster spec in TF_CONFIG environment variable.
Args:
worker_hosts: comma-separated list of worker ip:port pairs.
task_index: index of this worker within the cluster; required when there is
more than one worker.
Returns:
Number of workers in the cluster.
"""
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if tf_config:
num_workers = (
len(tf_config["cluster"].get("chief", [])) +
len(tf_config["cluster"].get("worker", [])))
elif worker_hosts:
workers = worker_hosts.split(",")
num_workers = len(workers)
if num_workers > 1 and task_index < 0:
raise ValueError("Must specify task_index when number of workers > 1")
task_index = 0 if num_workers == 1 else task_index
os.environ["TF_CONFIG"] = json.dumps({
"cluster": {
"worker": workers
},
"task": {
"type": "worker",
"index": task_index
}
})
else:
num_workers = 1
return num_workers
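# Illustrative sketch, not part of this commit: with two workers and this
# process as index 0, the helper would export a TF_CONFIG roughly like the
# commented value below (host:port pairs are placeholders).
#
#   configure_cluster("10.0.0.1:2222,10.0.0.2:2222", task_index=0)
#   # os.environ["TF_CONFIG"] ==
#   #   '{"cluster": {"worker": ["10.0.0.1:2222", "10.0.0.2:2222"]},
#   #     "task": {"type": "worker", "index": 0}}'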
def get_strategy_scope(strategy):
if strategy:
strategy_scope = strategy.scope()
else:
strategy_scope = DummyContextManager()
return strategy_scope
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
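# Illustrative usage sketch, not part of this commit: get_strategy_scope lets
# callers write one `with` block whether or not a strategy is in use; the tiny
# Keras model is a hypothetical stand-in for any model built under the scope.
def _example_build_under_scope():
  strategy = get_distribution_strategy("one_device", num_gpus=0)
  with get_strategy_scope(strategy):  # Also valid when strategy is None.
    return tf.keras.Sequential([tf.keras.layers.Dense(1)])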
......@@ -14,32 +14,28 @@
# ==============================================================================
""" Tests for distribution util functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v2 as tf
from official.utils.misc import distribution_utils
from official.common import distribute_utils
class GetDistributionStrategyTest(tf.test.TestCase):
"""Tests for get_distribution_strategy."""
def test_one_device_strategy_cpu(self):
ds = distribution_utils.get_distribution_strategy(num_gpus=0)
ds = distribute_utils.get_distribution_strategy(num_gpus=0)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('CPU', ds.extended.worker_devices[0])
def test_one_device_strategy_gpu(self):
ds = distribution_utils.get_distribution_strategy(num_gpus=1)
ds = distribute_utils.get_distribution_strategy(num_gpus=1)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('GPU', ds.extended.worker_devices[0])
def test_mirrored_strategy(self):
ds = distribution_utils.get_distribution_strategy(num_gpus=5)
ds = distribute_utils.get_distribution_strategy(num_gpus=5)
self.assertEquals(ds.num_replicas_in_sync, 5)
self.assertEquals(len(ds.extended.worker_devices), 5)
for device in ds.extended.worker_devices:
......
......@@ -31,7 +31,7 @@ import tensorflow as tf
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.common import distribute_utils
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
......@@ -745,8 +745,8 @@ class ExecutorBuilder(object):
"""
def __init__(self, strategy_type=None, strategy_config=None):
_ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
_ = distribute_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
"""Constructor.
Args:
......@@ -756,7 +756,7 @@ class ExecutorBuilder(object):
strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure.
"""
self._strategy = distribution_utils.get_distribution_strategy(
self._strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=strategy_type,
num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg,
......
......@@ -26,11 +26,10 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import bert_models
from official.nlp.bert import run_classifier as run_classifier_bert
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
......@@ -77,7 +76,7 @@ def main(_):
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
......
......@@ -27,12 +27,11 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
flags.DEFINE_string(
'sp_model_file', None,
......@@ -104,9 +103,8 @@ def main(_):
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
_ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
......
......@@ -25,8 +25,8 @@ import tempfile
from absl import logging
import tensorflow as tf
from tensorflow.python.util import deprecation
from official.common import distribute_utils
from official.staging.training import grad_utils
from official.utils.misc import distribution_utils
_SUMMARY_TXT = 'training_summary.txt'
_MIN_SUMMARY_STEPS = 10
......@@ -266,7 +266,7 @@ def run_customized_training_loop(
train_iterator = _get_input_iterator(train_input_fn, strategy)
eval_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
with distribution_utils.get_strategy_scope(strategy):
with distribute_utils.get_strategy_scope(strategy):
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model, sub_model = model_fn()
......
......@@ -28,6 +28,7 @@ from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
......@@ -35,7 +36,6 @@ from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_enum(
......@@ -447,7 +447,7 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
FLAGS.model_dir)
return
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
......
......@@ -23,6 +23,7 @@ from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
......@@ -30,7 +31,6 @@ from official.nlp.bert import common_flags
from official.nlp.bert import configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils
flags.DEFINE_string('input_files', None,
......@@ -205,9 +205,8 @@ def main(_):
FLAGS.model_dir = '/tmp/bert20/'
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
_ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
......
......@@ -28,12 +28,11 @@ from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.common import distribute_utils
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
......@@ -105,9 +104,8 @@ def main(_):
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
_ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
......
......@@ -27,13 +27,13 @@ from absl import flags
from absl import logging
from six.moves import zip
import tensorflow as tf
from official.common import distribute_utils
from official.modeling.hyperparams import params_dict
from official.nlp.nhnet import evaluation
from official.nlp.nhnet import input_pipeline
from official.nlp.nhnet import models
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
......@@ -185,7 +185,7 @@ def run():
if FLAGS.enable_mlir_bridge:
tf.config.experimental.enable_mlir_bridge()
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
if strategy:
logging.info("***** Number of cores used : %d",
......
......@@ -23,11 +23,11 @@ from official.core import train_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.modeling import performance
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
......@@ -48,7 +48,7 @@ def main(_):
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
......
......@@ -28,13 +28,13 @@ import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.hyperparams import config_definitions
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
......@@ -77,7 +77,7 @@ def run_continuous_finetune(
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
......
......@@ -29,7 +29,7 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.nlp.transformer import compute_bleu
from official.nlp.transformer import data_pipeline
......@@ -40,7 +40,6 @@ from official.nlp.transformer import transformer
from official.nlp.transformer import translate
from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
INF = int(1e9)
......@@ -161,7 +160,7 @@ class TransformerTask(object):
params["steps_between_evals"] = flags_obj.steps_between_evals
params["enable_checkpointing"] = flags_obj.enable_checkpointing
self.distribution_strategy = distribution_utils.get_distribution_strategy(
self.distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
......@@ -197,7 +196,7 @@ class TransformerTask(object):
keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
_ensure_dir(flags_obj.model_dir)
with distribution_utils.get_strategy_scope(self.distribution_strategy):
with distribute_utils.get_strategy_scope(self.distribution_strategy):
model = transformer.create_model(params, is_train=True)
opt = self._create_optimizer()
......@@ -376,7 +375,7 @@ class TransformerTask(object):
# We only want to create the model under DS scope for TPU case.
# When 'distribution_strategy' is None, a no-op DummyContextManager will
# be used.
with distribution_utils.get_strategy_scope(distribution_strategy):
with distribute_utils.get_strategy_scope(distribution_strategy):
if not self.predict_model:
self.predict_model = transformer.create_model(self.params, False)
self._load_weights_if_possible(
......
......@@ -23,13 +23,13 @@ from absl import logging
import numpy as np
import tensorflow as tf
# pylint: disable=unused-import
from official.common import distribute_utils
from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import distribution_utils
flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string(
......@@ -130,7 +130,7 @@ def get_metric_fn():
def main(unused_argv):
del unused_argv
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.strategy_type,
tpu_address=FLAGS.tpu)
if strategy:
......
......@@ -23,13 +23,13 @@ from absl import flags
from absl import logging
import tensorflow as tf
# pylint: disable=unused-import
from official.common import distribute_utils
from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import distribution_utils
flags.DEFINE_integer(
"num_predict",
......@@ -72,7 +72,7 @@ def get_pretrainxlnet_model(model_config, run_config):
def main(unused_argv):
del unused_argv
num_hosts = 1
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.strategy_type,
tpu_address=FLAGS.tpu)
if FLAGS.strategy_type == "tpu":
......
......@@ -27,6 +27,7 @@ from absl import logging
import tensorflow as tf
# pylint: disable=unused-import
import sentencepiece as spm
from official.common import distribute_utils
from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization
......@@ -34,7 +35,6 @@ from official.nlp.xlnet import squad_utils
from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import distribution_utils
flags.DEFINE_string(
"test_feature_path", default=None, help="Path to feature of test set.")
......@@ -212,7 +212,7 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
def main(unused_argv):
del unused_argv
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.strategy_type,
tpu_address=FLAGS.tpu)
if strategy:
......
......@@ -21,20 +21,17 @@ from __future__ import print_function
import json
import os
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.common import distribute_utils
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import data_preprocessing
from official.recommendation import movielens
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
......@@ -142,7 +139,7 @@ def get_v1_distribution_strategy(params):
tpu_cluster_resolver, steps_per_run=100)
else:
distribution = distribution_utils.get_distribution_strategy(
distribution = distribute_utils.get_distribution_strategy(
num_gpus=params["num_gpus"])
return distribution
......
......@@ -33,13 +33,13 @@ from absl import logging
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
from official.common import distribute_utils
from official.recommendation import constants as rconst
from official.recommendation import movielens
from official.recommendation import ncf_common
from official.recommendation import ncf_input_pipeline
from official.recommendation import neumf_model
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
......@@ -225,7 +225,7 @@ def run_ncf(_):
loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
tf.keras.mixed_precision.experimental.set_policy(policy)
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
......@@ -271,7 +271,7 @@ def run_ncf(_):
params, producer, input_meta_data, strategy))
steps_per_epoch = None if generate_input_online else num_train_steps
with distribution_utils.get_strategy_scope(strategy):
with distribute_utils.get_strategy_scope(strategy):
keras_model = _get_keras_model(params)
optimizer = tf.keras.optimizers.Adam(
learning_rate=params["learning_rate"],
......
......@@ -13,197 +13,5 @@
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import random
import string
from absl import logging
import tensorflow.compat.v2 as tf
from official.utils.misc import tpu_lib
def _collective_communication(all_reduce_alg):
"""Return a CollectiveCommunication based on all_reduce_alg.
Args:
all_reduce_alg: a string specifying which collective communication to pick,
or None.
Returns:
tf.distribute.experimental.CollectiveCommunication object
Raises:
ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
"""
collective_communication_options = {
None: tf.distribute.experimental.CollectiveCommunication.AUTO,
"ring": tf.distribute.experimental.CollectiveCommunication.RING,
"nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
}
if all_reduce_alg not in collective_communication_options:
raise ValueError(
"When used with `multi_worker_mirrored`, valid values for "
"all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
all_reduce_alg))
return collective_communication_options[all_reduce_alg]
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
"""Return a CrossDeviceOps based on all_reduce_alg and num_packs.
Args:
all_reduce_alg: a string specifying which cross device op to pick, or None.
num_packs: an integer specifying number of packs for the cross device op.
Returns:
tf.distribute.CrossDeviceOps object or None.
Raises:
ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
"""
if all_reduce_alg is None:
return None
mirrored_all_reduce_options = {
"nccl": tf.distribute.NcclAllReduce,
"hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
}
if all_reduce_alg not in mirrored_all_reduce_options:
raise ValueError(
"When used with `mirrored`, valid values for all_reduce_alg are "
"[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
all_reduce_alg))
cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
return cross_device_ops_class(num_packs=num_packs)
def get_distribution_strategy(distribution_strategy="mirrored",
num_gpus=0,
all_reduce_alg=None,
num_packs=1,
tpu_address=None):
"""Return a DistributionStrategy for running the model.
Args:
distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are "off", "one_device", "mirrored",
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
insensitive. "off" means not to use Distribution Strategy; "tpu" means to
use TPUStrategy using `tpu_address`.
num_gpus: Number of GPUs to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
"hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
"ring" and "nccl". If None, DistributionStrategy will choose based on
device topology.
num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
tpu_address: Optional. String that represents TPU to connect to. Must not be
None if `distribution_strategy` is set to `tpu`.
Returns:
tf.distribute.DistributionStrategy object.
Raises:
ValueError: if `distribution_strategy` is "off" or "one_device" and
`num_gpus` is larger than 1; or `num_gpus` is negative or if
`distribution_strategy` is `tpu` but `tpu_address` is not specified.
"""
if num_gpus < 0:
raise ValueError("`num_gpus` can not be negative.")
distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off":
if num_gpus > 1:
raise ValueError("When {} GPUs are specified, distribution_strategy "
"flag cannot be set to `off`.".format(num_gpus))
return None
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
return tf.distribute.experimental.TPUStrategy(cluster_resolver)
if distribution_strategy == "multi_worker_mirrored":
return tf.distribute.experimental.MultiWorkerMirroredStrategy(
communication=_collective_communication(all_reduce_alg))
if distribution_strategy == "one_device":
if num_gpus == 0:
return tf.distribute.OneDeviceStrategy("device:CPU:0")
if num_gpus > 1:
raise ValueError("`OneDeviceStrategy` can not be used for more than "
"one device.")
return tf.distribute.OneDeviceStrategy("device:GPU:0")
if distribution_strategy == "mirrored":
if num_gpus == 0:
devices = ["device:CPU:0"]
else:
devices = ["device:GPU:%d" % i for i in range(num_gpus)]
return tf.distribute.MirroredStrategy(
devices=devices,
cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
if distribution_strategy == "parameter_server":
return tf.distribute.experimental.ParameterServerStrategy()
raise ValueError("Unrecognized Distribution Strategy: %r" %
distribution_strategy)
def configure_cluster(worker_hosts=None, task_index=-1):
"""Set multi-worker cluster spec in TF_CONFIG environment variable.
Args:
worker_hosts: comma-separated list of worker ip:port pairs.
task_index: index of this worker within the cluster; required when there is
more than one worker.
Returns:
Number of workers in the cluster.
"""
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if tf_config:
num_workers = (
len(tf_config["cluster"].get("chief", [])) +
len(tf_config["cluster"].get("worker", [])))
elif worker_hosts:
workers = worker_hosts.split(",")
num_workers = len(workers)
if num_workers > 1 and task_index < 0:
raise ValueError("Must specify task_index when number of workers > 1")
task_index = 0 if num_workers == 1 else task_index
os.environ["TF_CONFIG"] = json.dumps({
"cluster": {
"worker": workers
},
"task": {
"type": "worker",
"index": task_index
}
})
else:
num_workers = 1
return num_workers
def get_strategy_scope(strategy):
if strategy:
strategy_scope = strategy.scope()
else:
strategy_scope = DummyContextManager()
return strategy_scope
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
# pylint: disable=wildcard-import
from official.common.distribute_utils import *
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Initializes TPU system for TF 2.0."""
import tensorflow as tf
def tpu_initialize(tpu_address):
"""Initializes TPU for TF 2.0 training.
Args:
tpu_address: string, bns address of master TPU worker.
Returns:
A TPUClusterResolver.
"""
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address)
if tpu_address not in ('', 'local'):
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
return cluster_resolver
......@@ -19,13 +19,13 @@ from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.common import registry_imports # pylint: disable=unused-import
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
......@@ -46,7 +46,7 @@ def main(_):
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
......
......@@ -14,28 +14,19 @@
# ==============================================================================
"""Main function to train various object detection models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
import pprint
# pylint: disable=g-bad-import-order
# Import libraries
import tensorflow as tf
from absl import app
from absl import flags
from absl import logging
# pylint: enable=g-bad-import-order
import tensorflow as tf
from official.common import distribute_utils
from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
......@@ -87,9 +78,9 @@ def run_executor(params,
strategy = prebuilt_strategy
else:
strategy_config = params.strategy_config
distribution_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribute_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.strategy_type,
num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg,
......
......@@ -23,11 +23,10 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import hyperparams
from official.modeling import performance
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import callbacks as custom_callbacks
from official.vision.image_classification import dataset_factory
......@@ -291,17 +290,17 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.')
distribution_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
distribute_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
# Note: for TPUs, strategy and scope should be created before the dataset
strategy = strategy_override or distribution_utils.get_distribution_strategy(
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
logging.info('Detected %d devices.',
strategy.num_replicas_in_sync if strategy else 1)
......
......@@ -25,9 +25,8 @@ from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
from official.common import distribute_utils
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
from official.vision.image_classification.resnet import common
......@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
Returns:
Dictionary of training and eval stats.
"""
strategy = strategy_override or distribution_utils.get_distribution_strategy(
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
tpu_address=flags_obj.tpu)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
if flags_obj.download:
......
......@@ -23,10 +23,9 @@ from absl import flags
from absl import logging
import orbit
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification.resnet import common
......@@ -117,7 +116,7 @@ def run(flags_obj):
else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
......@@ -144,7 +143,7 @@ def run(flags_obj):
flags_obj.batch_size,
flags_obj.log_steps,
logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
with distribution_utils.get_strategy_scope(strategy):
with distribute_utils.get_strategy_scope(strategy):
runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
per_epoch_steps)
......