Commit 2adcf837 authored by Rick Chao, committed by TensorFlower Gardener

Remove usage of...

Remove usage of multi_process_runner_util.try_run_and_except_connection_error() as setting FAIL_FAST=false fixes the connection error issue. Make grpc_fail_fast an arg for MultiProcessRunner's __init__().

PiperOrigin-RevId: 281193957
Change-Id: I1c99cb90a15fdb26892d0ad37c533c186297b09d
Parent 06f9ec34
......@@ -1323,7 +1323,6 @@ py_library(
srcs = ["multi_process_runner.py"],
deps = [
":multi_process_lib",
":multi_process_runner_util",
":multi_worker_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:tf2",
......@@ -1332,12 +1331,6 @@ py_library(
],
)
py_library(
name = "multi_process_runner_util",
srcs = ["multi_process_runner_util.py"],
deps = [],
)
py_library(
name = "multi_process_lib",
srcs = ["multi_process_lib.py"],
......
......@@ -90,6 +90,7 @@ class MultiProcessRunner(object):
cluster_spec,
max_run_time=None,
capture_std_stream=False,
grpc_fail_fast=False,
args=None,
kwargs=None):
"""Creates a multi-process runner.
......@@ -111,6 +112,8 @@ class MultiProcessRunner(object):
level C/C++ code. So it can be delayed for arbitrarily long time.
capture_std_stream: Boolean, whether the messages streamed to stdout and
stderr in subprocesses are captured.
grpc_fail_fast: Whether GRPC connection between processes should fail
without retrying. Defaults to False.
args: Positional arguments to be sent to functions run on processes.
kwargs: Keyword arguments to be sent to functions run on processes.
......@@ -131,6 +134,7 @@ class MultiProcessRunner(object):
self._cluster_spec = cluster_spec
self._max_run_time = max_run_time
self._capture_std_stream = capture_std_stream
self._grpc_fail_fast = grpc_fail_fast
self._args = args or ()
self._kwargs = kwargs or {}
self._processes = []
......@@ -164,6 +168,7 @@ class MultiProcessRunner(object):
def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs):
"""The wrapper function that actually gets run in child process(es)."""
os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast)
os.environ['TF_CONFIG'] = json.dumps({
'cluster': self._cluster_spec,
'task': {
......@@ -331,6 +336,7 @@ def run(proc_func,
cluster_spec,
max_run_time=None,
capture_std_stream=False,
grpc_fail_fast=False,
args=None,
kwargs=None): # pylint: disable=g-doc-args
"""Runs functions in local child processes.
......@@ -347,6 +353,7 @@ def run(proc_func,
cluster_spec,
max_run_time=max_run_time,
capture_std_stream=capture_std_stream,
grpc_fail_fast=grpc_fail_fast,
args=args,
kwargs=kwargs)
runner.start()
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util for multi-process runner."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.framework import errors_impl
@contextlib.contextmanager
def try_run_and_except_connection_error(test_obj):
  """Yields to the caller, skipping the test on transient connection errors.

  Wraps a test body so that known transient GRPC connection failures between
  subprocesses are reported as a skipped test rather than a failure.

  Args:
    test_obj: The test case instance; its `skipTest` method is invoked when a
      recognized transient `UnavailableError` is caught.
  """
  # TODO(b/142074107): Remove this try-except once within-loop fault-tolerance
  # is supported. This is temporarily needed to avoid test flakiness.
  try:
    yield
  except errors_impl.UnavailableError as err:
    message = str(err)
    transient_markers = ('Connection reset by peer', 'Socket closed',
                         'failed to connect to all addresses')
    # Only these specific connection-level failures are treated as skips;
    # any other UnavailableError is a genuine test failure.
    if any(marker in message for marker in transient_markers):
      test_obj.skipTest(
          'Skipping connection error between processes: {}'.format(message))
    else:
      raise
......@@ -26,7 +26,6 @@ import numpy as np
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_process_runner_util
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
......@@ -76,11 +75,9 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
for _ in range(100):
worker_step_fn()
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.run(
worker_fn,
cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))
multi_process_runner.run(
worker_fn,
cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))
if __name__ == '__main__':
......
......@@ -24,7 +24,6 @@ from absl.testing import parameterized
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_process_runner_util
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import multi_worker_testing_utils
......@@ -74,6 +73,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
test_obj, file_format):
model, saving_filepath, train_ds, steps = _model_setup(
test_obj, file_format)
num_epoch = 2
......@@ -104,12 +104,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
training_state.checkpoint_exists(saving_filepath),
test_base.is_chief())
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.run(
proc_model_checkpoint_saves_on_chief_but_not_otherwise,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self, file_format))
multi_process_runner.run(
proc_model_checkpoint_saves_on_chief_but_not_otherwise,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self, file_format))
@combinations.generate(combinations.combine(mode=['eager']))
def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):
......@@ -142,12 +140,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
test_obj.assertEqual(
bool(file_io.list_directory(saving_filepath)), test_base.is_chief())
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.run(
proc_tensorboard_saves_on_chief_but_not_otherwise,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,))
multi_process_runner.run(
proc_tensorboard_saves_on_chief_but_not_otherwise,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,))
@combinations.generate(combinations.combine(mode=['eager']))
def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode):
......@@ -173,12 +169,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
steps_per_epoch=steps,
callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.run(
proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,))
multi_process_runner.run(
proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment