提交 9c20948d 编写于 作者: D Dong Lin 提交者: TensorFlower Gardener

Unit tests should use file in temporary directory as config file for CLIConfig

PiperOrigin-RevId: 258503116
上级 e25e18d6
...@@ -168,6 +168,9 @@ tf_cc_test( ...@@ -168,6 +168,9 @@ tf_cc_test(
size = "small", size = "small",
srcs = ["debug_io_utils_test.cc"], srcs = ["debug_io_utils_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(), linkstatic = tf_kernel_tests_linkstatic(),
tags = [
"no_oss", # TODO(b/137652456): remove when fixed
],
deps = [ deps = [
":debug_callback_registry", ":debug_callback_registry",
":debug_grpc_testlib", ":debug_grpc_testlib",
......
...@@ -1171,6 +1171,7 @@ sh_test( ...@@ -1171,6 +1171,7 @@ sh_test(
":offline_analyzer", ":offline_analyzer",
], ],
tags = [ tags = [
"no_oss", # TODO(b/137652456): remove when fixed
"no_windows", "no_windows",
], ],
) )
...@@ -19,12 +19,14 @@ from __future__ import print_function ...@@ -19,12 +19,14 @@ from __future__ import print_function
import argparse import argparse
import curses import curses
import os
import tempfile import tempfile
import threading import threading
import numpy as np import numpy as np
from six.moves import queue from six.moves import queue
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_test_utils from tensorflow.python.debug.cli import cli_test_utils
from tensorflow.python.debug.cli import curses_ui from tensorflow.python.debug.cli import curses_ui
from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import debugger_cli_common
...@@ -81,7 +83,10 @@ class MockCursesUI(curses_ui.CursesUI): ...@@ -81,7 +83,10 @@ class MockCursesUI(curses_ui.CursesUI):
# Observer for toast messages. # Observer for toast messages.
self.toasts = [] self.toasts = []
curses_ui.CursesUI.__init__(self) curses_ui.CursesUI.__init__(
self,
config=cli_config.CLIConfig(
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config")))
# Override the default path to the command history file to avoid test # Override the default path to the command history file to avoid test
# concurrency issues. # concurrency issues.
......
...@@ -17,6 +17,8 @@ from __future__ import absolute_import ...@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import shutil
import argparse import argparse
import tempfile import tempfile
...@@ -53,6 +55,16 @@ class MockReadlineUI(readline_ui.ReadlineUI): ...@@ -53,6 +55,16 @@ class MockReadlineUI(readline_ui.ReadlineUI):
class CursesTest(test_util.TensorFlowTestCase): class CursesTest(test_util.TensorFlowTestCase):
def setUp(self):
self._tmp_dir = tempfile.mkdtemp()
self._tmp_config_path = os.path.join(self._tmp_dir, ".tfdbg_config")
self.assertFalse(gfile.Exists(self._tmp_config_path))
super(CursesTest, self).setUp()
def tearDown(self):
shutil.rmtree(self._tmp_dir)
super(CursesTest, self).tearDown()
def _babble(self, args, screen_info=None): def _babble(self, args, screen_info=None):
ap = argparse.ArgumentParser( ap = argparse.ArgumentParser(
description="Do babble.", usage=argparse.SUPPRESS) description="Do babble.", usage=argparse.SUPPRESS)
...@@ -70,16 +82,23 @@ class CursesTest(test_util.TensorFlowTestCase): ...@@ -70,16 +82,23 @@ class CursesTest(test_util.TensorFlowTestCase):
return debugger_cli_common.RichTextLines(lines) return debugger_cli_common.RichTextLines(lines)
def testUIFactoryCreatesReadlineUI(self): def testUIFactoryCreatesReadlineUI(self):
ui = ui_factory.get_ui("readline") ui = ui_factory.get_ui(
"readline",
config=cli_config.CLIConfig(config_file_path=self._tmp_config_path))
self.assertIsInstance(ui, readline_ui.ReadlineUI) self.assertIsInstance(ui, readline_ui.ReadlineUI)
def testUIFactoryRaisesExceptionOnInvalidUIType(self): def testUIFactoryRaisesExceptionOnInvalidUIType(self):
with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'foobar'"): with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'foobar'"):
ui_factory.get_ui("foobar") ui_factory.get_ui(
"foobar",
config=cli_config.CLIConfig(config_file_path=self._tmp_config_path))
def testUIFactoryRaisesExceptionOnInvalidUITypeGivenAvailable(self): def testUIFactoryRaisesExceptionOnInvalidUITypeGivenAvailable(self):
with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'readline'"): with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'readline'"):
ui_factory.get_ui("readline", available_ui_types=["curses"]) ui_factory.get_ui(
"readline",
available_ui_types=["curses"],
config=cli_config.CLIConfig(config_file_path=self._tmp_config_path))
def testRunUIExitImmediately(self): def testRunUIExitImmediately(self):
"""Make sure that the UI can exit properly after launch.""" """Make sure that the UI can exit properly after launch."""
......
...@@ -41,7 +41,10 @@ def main(_): ...@@ -41,7 +41,10 @@ def main(_):
z = tf.matmul(m, v, name="z") z = tf.matmul(m, v, name="z")
if FLAGS.debug: if FLAGS.debug:
sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type) sess = tf_debug.LocalCLIDebugWrapperSession(
sess,
ui_type=FLAGS.ui_type,
use_random_config_path=FLAGS.use_random_config_path)
if FLAGS.error == "shape_mismatch": if FLAGS.error == "shape_mismatch":
print(sess.run(y, feed_dict={ph_float: np.array([[0.0], [1.0], [2.0]])})) print(sess.run(y, feed_dict={ph_float: np.array([[0.0], [1.0], [2.0]])}))
...@@ -76,6 +79,14 @@ if __name__ == "__main__": ...@@ -76,6 +79,14 @@ if __name__ == "__main__":
const=True, const=True,
default=False, default=False,
help="Use debugger to track down bad values during training") help="Use debugger to track down bad values during training")
parser.add_argument(
"--use_random_config_path",
type="bool",
nargs="?",
const=True,
default=False,
help="""If set, set config file path to a random file in the temporary
directory.""")
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()
with tf.Graph().as_default(): with tf.Graph().as_default():
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
...@@ -125,7 +125,10 @@ def main(_): ...@@ -125,7 +125,10 @@ def main(_):
"The --debug and --tensorboard_debug_address flags are mutually " "The --debug and --tensorboard_debug_address flags are mutually "
"exclusive.") "exclusive.")
if FLAGS.debug: if FLAGS.debug:
sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type) sess = tf_debug.LocalCLIDebugWrapperSession(
sess,
ui_type=FLAGS.ui_type,
use_random_config_path=FLAGS.use_random_config_path)
elif FLAGS.tensorboard_debug_address: elif FLAGS.tensorboard_debug_address:
sess = tf_debug.TensorBoardDebugWrapperSession( sess = tf_debug.TensorBoardDebugWrapperSession(
sess, FLAGS.tensorboard_debug_address) sess, FLAGS.tensorboard_debug_address)
...@@ -189,6 +192,14 @@ if __name__ == "__main__": ...@@ -189,6 +192,14 @@ if __name__ == "__main__":
help="Connect to the TensorBoard Debugger Plugin backend specified by " help="Connect to the TensorBoard Debugger Plugin backend specified by "
"the gRPC address (e.g., localhost:1234). Mutually exclusive with the " "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
"--debug flag.") "--debug flag.")
parser.add_argument(
"--use_random_config_path",
type="bool",
nargs="?",
const=True,
default=False,
help="""If set, set config file path to a random file in the temporary
directory.""")
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()
with tf.Graph().as_default(): with tf.Graph().as_default():
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
...@@ -71,13 +71,13 @@ run ...@@ -71,13 +71,13 @@ run
exit exit
EOF EOF
cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline --use_random_config_path
run run
ni -a -d -t v/read ni -a -d -t v/read
exit exit
EOF EOF
cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline --use_random_config_path
run -t 1 run -t 1
run --node_name_filter hidden --op_type_filter MatMul run --node_name_filter hidden --op_type_filter MatMul
run -f has_inf_or_nan run -f has_inf_or_nan
......
...@@ -25,6 +25,7 @@ import tempfile ...@@ -25,6 +25,7 @@ import tempfile
# Google-internal import(s). # Google-internal import(s).
from tensorflow.python.debug.cli import analyzer_cli from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import debugger_cli_common
...@@ -38,6 +39,7 @@ from tensorflow.python.debug.wrappers import framework ...@@ -38,6 +39,7 @@ from tensorflow.python.debug.wrappers import framework
_DUMP_ROOT_PREFIX = "tfdbg_" _DUMP_ROOT_PREFIX = "tfdbg_"
# TODO(donglin) Remove use_random_config_path after b/137652456 is fixed.
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
"""Concrete subclass of BaseDebugWrapperSession implementing a local CLI. """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.
...@@ -51,7 +53,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): ...@@ -51,7 +53,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
dump_root=None, dump_root=None,
log_usage=True, log_usage=True,
ui_type="curses", ui_type="curses",
thread_name_filter=None): thread_name_filter=None,
use_random_config_path=False):
"""Constructor of LocalCLIDebugWrapperSession. """Constructor of LocalCLIDebugWrapperSession.
Args: Args:
...@@ -66,6 +69,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): ...@@ -66,6 +69,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
(curses | readline) (curses | readline)
thread_name_filter: Regular-expression white list for thread name. See thread_name_filter: Regular-expression white list for thread name. See
the doc of `BaseDebugWrapperSession` for details. the doc of `BaseDebugWrapperSession` for details.
use_random_config_path: If true, set config file path to a random file in
the temporary directory.
Raises: Raises:
ValueError: If dump_root is an existing and non-empty directory or if ValueError: If dump_root is an existing and non-empty directory or if
...@@ -120,8 +125,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): ...@@ -120,8 +125,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._skip_debug = False self._skip_debug = False
self._run_start_response = None self._run_start_response = None
self._is_run_start = True self._is_run_start = True
self._ui_type = ui_type self._ui_type = ui_type
self._config = None
if use_random_config_path:
self._config = cli_config.CLIConfig(
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config"))
def _is_disk_usage_reset_each_run(self): def _is_disk_usage_reset_each_run(self):
# The dumped tensors are all cleaned up after every Session.run # The dumped tensors are all cleaned up after every Session.run
...@@ -279,8 +287,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): ...@@ -279,8 +287,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
def _prep_cli_for_run_start(self): def _prep_cli_for_run_start(self):
"""Prepare (but not launch) the CLI for run-start.""" """Prepare (but not launch) the CLI for run-start."""
self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)
self._run_cli = ui_factory.get_ui(self._ui_type)
help_intro = debugger_cli_common.RichTextLines([]) help_intro = debugger_cli_common.RichTextLines([])
if self._run_call_count == 1: if self._run_call_count == 1:
...@@ -409,8 +416,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): ...@@ -409,8 +416,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._title_color = "red_on_white" self._title_color = "red_on_white"
self._run_cli = analyzer_cli.create_analyzer_ui( self._run_cli = analyzer_cli.create_analyzer_ui(
debug_dump, self._tensor_filters, ui_type=self._ui_type, debug_dump,
on_ui_exit=self._remove_dump_root) self._tensor_filters,
ui_type=self._ui_type,
on_ui_exit=self._remove_dump_root,
config=self._config)
# Get names of all dumped tensors. # Get names of all dumped tensors.
dumped_tensor_names = [] dumped_tensor_names = []
......
...@@ -23,6 +23,7 @@ import tempfile ...@@ -23,6 +23,7 @@ import tempfile
import numpy as np import numpy as np
from tensorflow.python.debug.cli import cli_config
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
...@@ -112,7 +113,10 @@ class LocalCLIDebuggerWrapperSessionForTest( ...@@ -112,7 +113,10 @@ class LocalCLIDebuggerWrapperSessionForTest(
else: else:
self.observers["run_end_cli_run_numbers"].append(self._run_call_count) self.observers["run_end_cli_run_numbers"].append(self._run_call_count)
readline_cli = ui_factory.get_ui("readline") readline_cli = ui_factory.get_ui(
"readline",
config=cli_config.CLIConfig(
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config")))
self._register_this_run_info(readline_cli) self._register_this_run_info(readline_cli)
while True: while True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册