diff --git a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py index 88625a55426c4a794666ae71b04066d19e9ffcd0..7ff6ec0f2d5c6f6d2315e98cf5e7250b118fbadd 100644 --- a/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py @@ -107,16 +107,14 @@ class KubernetesClusterResolver(ClusterResolver): Returns: The name or URL of the session master. """ + task_type = task_type if task_type is not None else self.task_type + task_index = task_index if task_index is not None else self.task_index + if task_type is not None and task_index is not None: return format_master_url( self.cluster_spec().task_address(task_type, task_index), rpc_layer or self.rpc_layer) - if self.task_type is not None and self.task_index is not None: - return format_master_url( - self.cluster_spec().task_address(self.task_type, self.task_index), - rpc_layer or self.rpc_layer) - return '' def cluster_spec(self): diff --git a/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py index 1ab81731b7a111848608068220488a368d9b86ec..9dbe25b613447fde2140585742d005dab82fb018 100644 --- a/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py @@ -23,6 +23,7 @@ import os import subprocess from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url from tensorflow.python.training.server_lib import ClusterSpec @@ -206,10 +207,13 @@ class SlurmClusterResolver(ClusterResolver): """ task_type = task_type if task_type is not None else self.task_type task_index = task_index if task_index is not None else self.task_index - rpc_layer = rpc_layer or self.rpc_layer - master = self.cluster_spec().task_address(task_type, task_index) - return '%s://%s' % (rpc_layer, master) if rpc_layer else master + if task_type is not None and task_index is not None: + return format_master_url( + self.cluster_spec().task_address(task_type, task_index), + rpc_layer or self.rpc_layer) + + return '' @property def environment(self):