提交 9c20948d 编写于 作者: D Dong Lin 提交者: TensorFlower Gardener

Unit tests should use file in temporary directory as config file for CLIConfig

PiperOrigin-RevId: 258503116
上级 e25e18d6
......@@ -168,6 +168,9 @@ tf_cc_test(
size = "small",
srcs = ["debug_io_utils_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = [
"no_oss", # TODO(b/137652456): remove when fixed
],
deps = [
":debug_callback_registry",
":debug_grpc_testlib",
......
......@@ -1171,6 +1171,7 @@ sh_test(
":offline_analyzer",
],
tags = [
"no_oss", # TODO(b/137652456): remove when fixed
"no_windows",
],
)
......@@ -19,12 +19,14 @@ from __future__ import print_function
import argparse
import curses
import os
import tempfile
import threading
import numpy as np
from six.moves import queue
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_test_utils
from tensorflow.python.debug.cli import curses_ui
from tensorflow.python.debug.cli import debugger_cli_common
......@@ -81,7 +83,10 @@ class MockCursesUI(curses_ui.CursesUI):
# Observer for toast messages.
self.toasts = []
curses_ui.CursesUI.__init__(self)
curses_ui.CursesUI.__init__(
self,
config=cli_config.CLIConfig(
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config")))
# Override the default path to the command history file to avoid test
# concurrency issues.
......
......@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import argparse
import tempfile
......@@ -53,6 +55,16 @@ class MockReadlineUI(readline_ui.ReadlineUI):
class CursesTest(test_util.TensorFlowTestCase):
def setUp(self):
self._tmp_dir = tempfile.mkdtemp()
self._tmp_config_path = os.path.join(self._tmp_dir, ".tfdbg_config")
self.assertFalse(gfile.Exists(self._tmp_config_path))
super(CursesTest, self).setUp()
def tearDown(self):
shutil.rmtree(self._tmp_dir)
super(CursesTest, self).tearDown()
def _babble(self, args, screen_info=None):
ap = argparse.ArgumentParser(
description="Do babble.", usage=argparse.SUPPRESS)
......@@ -70,16 +82,23 @@ class CursesTest(test_util.TensorFlowTestCase):
return debugger_cli_common.RichTextLines(lines)
def testUIFactoryCreatesReadlineUI(self):
ui = ui_factory.get_ui("readline")
ui = ui_factory.get_ui(
"readline",
config=cli_config.CLIConfig(config_file_path=self._tmp_config_path))
self.assertIsInstance(ui, readline_ui.ReadlineUI)
def testUIFactoryRaisesExceptionOnInvalidUIType(self):
with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'foobar'"):
ui_factory.get_ui("foobar")
ui_factory.get_ui(
"foobar",
config=cli_config.CLIConfig(config_file_path=self._tmp_config_path))
def testUIFactoryRaisesExceptionOnInvalidUITypeGivenAvailable(self):
with self.assertRaisesRegexp(ValueError, "Invalid ui_type: 'readline'"):
ui_factory.get_ui("readline", available_ui_types=["curses"])
ui_factory.get_ui(
"readline",
available_ui_types=["curses"],
config=cli_config.CLIConfig(config_file_path=self._tmp_config_path))
def testRunUIExitImmediately(self):
"""Make sure that the UI can exit properly after launch."""
......
......@@ -41,7 +41,10 @@ def main(_):
z = tf.matmul(m, v, name="z")
if FLAGS.debug:
sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)
sess = tf_debug.LocalCLIDebugWrapperSession(
sess,
ui_type=FLAGS.ui_type,
use_random_config_path=FLAGS.use_random_config_path)
if FLAGS.error == "shape_mismatch":
print(sess.run(y, feed_dict={ph_float: np.array([[0.0], [1.0], [2.0]])}))
......@@ -76,6 +79,14 @@ if __name__ == "__main__":
const=True,
default=False,
help="Use debugger to track down bad values during training")
parser.add_argument(
"--use_random_config_path",
type="bool",
nargs="?",
const=True,
default=False,
help="""If set, set config file path to a random file in the temporary
directory.""")
FLAGS, unparsed = parser.parse_known_args()
with tf.Graph().as_default():
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
......@@ -125,7 +125,10 @@ def main(_):
"The --debug and --tensorboard_debug_address flags are mutually "
"exclusive.")
if FLAGS.debug:
sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)
sess = tf_debug.LocalCLIDebugWrapperSession(
sess,
ui_type=FLAGS.ui_type,
use_random_config_path=FLAGS.use_random_config_path)
elif FLAGS.tensorboard_debug_address:
sess = tf_debug.TensorBoardDebugWrapperSession(
sess, FLAGS.tensorboard_debug_address)
......@@ -189,6 +192,14 @@ if __name__ == "__main__":
help="Connect to the TensorBoard Debugger Plugin backend specified by "
"the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
"--debug flag.")
parser.add_argument(
"--use_random_config_path",
type="bool",
nargs="?",
const=True,
default=False,
help="""If set, set config file path to a random file in the temporary
directory.""")
FLAGS, unparsed = parser.parse_known_args()
with tf.Graph().as_default():
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
......@@ -71,13 +71,13 @@ run
exit
EOF
cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline
cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline --use_random_config_path
run
ni -a -d -t v/read
exit
EOF
cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline
cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline --use_random_config_path
run -t 1
run --node_name_filter hidden --op_type_filter MatMul
run -f has_inf_or_nan
......
......@@ -25,6 +25,7 @@ import tempfile
# Google-internal import(s).
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
......@@ -38,6 +39,7 @@ from tensorflow.python.debug.wrappers import framework
_DUMP_ROOT_PREFIX = "tfdbg_"
# TODO(donglin) Remove use_random_config_path after b/137652456 is fixed.
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
"""Concrete subclass of BaseDebugWrapperSession implementing a local CLI.
......@@ -51,7 +53,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
dump_root=None,
log_usage=True,
ui_type="curses",
thread_name_filter=None):
thread_name_filter=None,
use_random_config_path=False):
"""Constructor of LocalCLIDebugWrapperSession.
Args:
......@@ -66,6 +69,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
(curses | readline)
thread_name_filter: Regular-expression white list for thread name. See
the doc of `BaseDebugWrapperSession` for details.
use_random_config_path: If true, set config file path to a random file in
the temporary directory.
Raises:
ValueError: If dump_root is an existing and non-empty directory or if
......@@ -120,8 +125,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._skip_debug = False
self._run_start_response = None
self._is_run_start = True
self._ui_type = ui_type
self._config = None
if use_random_config_path:
self._config = cli_config.CLIConfig(
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config"))
def _is_disk_usage_reset_each_run(self):
# The dumped tensors are all cleaned up after every Session.run
......@@ -279,8 +287,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
def _prep_cli_for_run_start(self):
"""Prepare (but not launch) the CLI for run-start."""
self._run_cli = ui_factory.get_ui(self._ui_type)
self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)
help_intro = debugger_cli_common.RichTextLines([])
if self._run_call_count == 1:
......@@ -409,8 +416,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._title_color = "red_on_white"
self._run_cli = analyzer_cli.create_analyzer_ui(
debug_dump, self._tensor_filters, ui_type=self._ui_type,
on_ui_exit=self._remove_dump_root)
debug_dump,
self._tensor_filters,
ui_type=self._ui_type,
on_ui_exit=self._remove_dump_root,
config=self._config)
# Get names of all dumped tensors.
dumped_tensor_names = []
......
......@@ -23,6 +23,7 @@ import tempfile
import numpy as np
from tensorflow.python.debug.cli import cli_config
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
......@@ -112,7 +113,10 @@ class LocalCLIDebuggerWrapperSessionForTest(
else:
self.observers["run_end_cli_run_numbers"].append(self._run_call_count)
readline_cli = ui_factory.get_ui("readline")
readline_cli = ui_factory.get_ui(
"readline",
config=cli_config.CLIConfig(
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config")))
self._register_this_run_info(readline_cli)
while True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册