Commit 1af7172d authored by Yanhui Liang, committed by A. Unique TensorFlower

Remove 'num_workers' arg from get_distribution_strategy() method.

PiperOrigin-RevId: 291810091
Parent 569ec532
@@ -673,12 +673,11 @@ class ExecutorBuilder(object):
   """
   def __init__(self, strategy_type=None, strategy_config=None):
-    num_workers = distribution_utils.configure_cluster(
+    _ = distribution_utils.configure_cluster(
         strategy_config.worker_hosts, strategy_config.task_index)
     self._strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=strategy_type,
         num_gpus=strategy_config.num_gpus,
-        num_workers=num_workers,
         all_reduce_alg=strategy_config.all_reduce_alg,
         num_packs=strategy_config.num_packs,
         tpu_address=strategy_config.tpu)
......
@@ -563,7 +563,6 @@ def resnet_main(
   distribution_strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_core.get_num_gpus(flags_obj),
-      num_workers=num_workers,
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs)
......
@@ -83,7 +83,6 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
 def get_distribution_strategy(distribution_strategy="mirrored",
                               num_gpus=0,
-                              num_workers=1,
                               all_reduce_alg=None,
                               num_packs=1,
                               tpu_address=None):
......@@ -96,7 +95,6 @@ def get_distribution_strategy(distribution_strategy="mirrored",
'off' means not to use Distribution Strategy; 'tpu' means to use
TPUStrategy using `tpu_address`.
num_gpus: Number of GPUs to run this model.
num_workers: Number of workers to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
"hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
......@@ -120,8 +118,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
if distribution_strategy == "off":
if num_gpus > 1:
raise ValueError(
"When {} GPUs and {} workers are specified, distribution_strategy "
"flag cannot be set to 'off'.".format(num_gpus, num_workers))
"When {} GPUs are specified, distribution_strategy "
"flag cannot be set to 'off'.".format(num_gpus))
return None
if distribution_strategy == "tpu":
......
@@ -104,7 +104,6 @@ def run(flags_obj):
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=distribution_utils.configure_cluster(),
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs)
......
@@ -212,7 +212,6 @@ def run(flags_obj):
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=distribution_utils.configure_cluster(),
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs,
       tpu_address=flags_obj.tpu)
......
@@ -84,13 +84,12 @@ def run(flags_obj):
   tf.keras.backend.set_image_data_format(data_format)
   # Configures cluster spec for distribution strategy.
-  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
-                                                     flags_obj.task_index)
+  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
+                                           flags_obj.task_index)
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=num_workers,
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs,
       tpu_address=flags_obj.tpu)
......
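For reference, here is a minimal sketch of the calling convention after this change. The import path `official.utils.misc.distribution_utils` and the specific flag values are assumptions for illustration, not taken from this diff; only the function names and remaining parameters are confirmed by the hunks above.

```python
# Sketch of the post-change calling convention (import path is an assumption).
from official.utils.misc import distribution_utils

# configure_cluster() is still called to set up the multi-worker cluster spec,
# but its return value is no longer forwarded: num_workers has been removed
# from get_distribution_strategy().
distribution_utils.configure_cluster()

strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy="mirrored",  # placeholder value
    num_gpus=1,                        # placeholder value
    all_reduce_alg=None,
    num_packs=1,
    tpu_address=None)
```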