提交 83966daa 编写于 作者: S Stephan Lee 提交者: TensorFlower Gardener

Summary API for trace.

The Trace API allows users to trace execution and collect
graph or profile information.

PiperOrigin-RevId: 235563123
上级 64ff7c16
......@@ -3308,6 +3308,7 @@ py_library(
":util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:profiler",
"@six_archive//:six",
],
)
......
......@@ -51,16 +51,24 @@ _profiler_lock = threading.Lock()
_run_num = 0
class ProfilerAlreadyRunningError(Exception):
  """Raised when profiling is started while another session is running."""
class ProfilerNotRunningError(Exception):
  """Raised when profiling is stopped while no session is running."""
def start():
"""Start profiling.
Raises:
AssertionError: If another profiling session is running.
ProfilerAlreadyRunningError: If another profiling session is running.
"""
global _profiler
with _profiler_lock:
if _profiler is not None:
raise AssertionError('Another profiler is running.')
raise ProfilerAlreadyRunningError('Another profiler is running.')
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
if context.default_execution_mode == context.EAGER_MODE:
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
......@@ -82,13 +90,14 @@ def stop():
to file for offline analysis by tensorboard.
Raises:
AssertionError: If there is no active profiling session.
ProfilerNotRunningError: If there is no active profiling session.
"""
global _profiler
global _run_num
with _profiler_lock:
if _profiler is None:
raise AssertionError('Cannot stop profiling. No profiler is running.')
raise ProfilerNotRunningError(
'Cannot stop profiling. No profiler is running.')
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ProfilerSerializeToString(
context.context()._handle, # pylint: disable=protected-access
......
......@@ -33,7 +33,7 @@ class ProfilerTest(test_util.TensorFlowTestCase):
five = constant_op.constant(5)
product = three * five
self.assertAllEqual(15, product)
with self.assertRaises(AssertionError):
with self.assertRaises(profiler.ProfilerAlreadyRunningError):
profiler.start()
profile_result = profiler.stop()
......@@ -41,7 +41,7 @@ class ProfilerTest(test_util.TensorFlowTestCase):
profile_pb.ParseFromString(profile_result)
profile_pb_str = '%s' % profile_pb
self.assertTrue('Mul' in profile_pb_str)
with self.assertRaises(AssertionError):
with self.assertRaises(profiler.ProfilerNotRunningError):
profiler.stop()
......
......@@ -1113,6 +1113,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:summary_ops_v2",
"@six_archive//:six",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variables",
......
......@@ -21,6 +21,8 @@ from __future__ import print_function
import os
import unittest
import six
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import step_stats_pb2
......@@ -45,6 +47,7 @@ from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
......@@ -571,6 +574,9 @@ class SummaryWriterTest(test_util.TensorFlowTestCase):
class SummaryOpsTest(test_util.TensorFlowTestCase):
def tearDown(self):
  """Resets trace state after each test, then runs the inherited teardown."""
  summary_ops.disable_trace()
  # Chain to the parent class; an override that does not call it silently
  # skips whatever teardown the base test case performs.
  super(SummaryOpsTest, self).tearDown()
def run_metadata(self, *args, **kwargs):
assert context.executing_eagerly()
logdir = self.get_temp_dir()
......@@ -616,6 +622,18 @@ class SummaryOpsTest(test_util.TensorFlowTestCase):
# the second event.
return events[1].summary
def run_trace(self, f):
  """Runs `f` with graph tracing enabled and returns the exported summary.

  Args:
    f: Zero-argument callable executed while the trace is enabled.

  Returns:
    The `summary` field of the second event read back from the log directory.
  """
  assert context.executing_eagerly()
  log_path = self.get_temp_dir()
  file_writer = summary_ops.create_file_writer(log_path)
  summary_ops.enable_trace(graph=True, profiler=False)
  with file_writer.as_default():
    f()
    summary_ops.export_trace(name='foo', step=1)
  file_writer.close()
  return events_from_logdir(log_path)[1].summary
@test_util.run_v2_only
def testRunMetadata_usesNameAsTag(self):
meta = config_pb2.RunMetadata()
......@@ -718,6 +736,62 @@ class SummaryOpsTest(test_util.TensorFlowTestCase):
first_val = summary.value[0]
self.assertEqual(model.to_json(), first_val.tensor.string_val[0])
@test_util.run_v2_only
def testTrace(self):
  """A traced function call exports a summary that parses as RunMetadata."""

  @def_function.function
  def f():
    base = constant_op.constant(2)
    exponent = constant_op.constant(3)
    return base**exponent

  trace_summary = self.run_trace(f)

  serialized = trace_summary.value[0].tensor.string_val[0]
  actual_run_metadata = config_pb2.RunMetadata.FromString(serialized)

  # function_graphs is large and varies (for instance by device), so only the
  # presence of the field is checked.
  self.assertTrue(hasattr(actual_run_metadata, 'function_graphs'))
@test_util.run_v2_only
def testTrace_cannotEnableTraceInFunction(self):
  """Enabling a trace inside a `def_function.function` only logs a warning."""

  @def_function.function
  def f():
    summary_ops.enable_trace(graph=True, profiler=False)
    x = constant_op.constant(2)
    y = constant_op.constant(3)
    return x**y

  with test.mock.patch.object(logging, 'warn') as mock_log:
    f()
    # Use the six helper (as the assertRaisesRegex check in this file does)
    # rather than the deprecated assertRegexpMatches alias.
    six.assertRegex(
        self, str(mock_log.call_args), 'Must enable trace in eager mode.')
@test_util.run_v2_only
def testTrace_cannotExportTraceWithoutTrace(self):
  """Exporting when no trace has been enabled raises ValueError."""
  expected_message = 'Must enable trace before export.'
  with six.assertRaisesRegex(self, ValueError, expected_message):
    summary_ops.export_trace(name='foo', step=1)
@test_util.run_v2_only
def testTrace_cannotExportTraceInFunction(self):
  """Exporting a trace inside a `def_function.function` only logs a warning."""
  summary_ops.enable_trace(graph=True, profiler=False)

  @def_function.function
  def f():
    x = constant_op.constant(2)
    y = constant_op.constant(3)
    summary_ops.export_trace(name='foo', step=1)
    return x**y

  with test.mock.patch.object(logging, 'warn') as mock_log:
    f()
    # Use the six helper (as the assertRaisesRegex check in this file does)
    # rather than the deprecated assertRegexpMatches alias.
    six.assertRegex(
        self,
        str(mock_log.call_args),
        'Can only export trace while executing eagerly.')
def events_from_file(filepath):
"""Returns all events in a single event file.
......
......@@ -20,10 +20,12 @@ from __future__ import division
from __future__ import print_function
import abc
import collections
import functools
import getpass
import os
import re
import threading
import time
import six
......@@ -32,6 +34,7 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import profiler as _profiler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
......@@ -974,3 +977,97 @@ def keras_model(name, data, step):
tensor=constant_op.constant(json_string, dtype=dtypes.string),
step=step,
metadata=summary_metadata)
# Records which options (`graph`, `profiler`) the active trace was enabled with.
_TraceContext = collections.namedtuple("TraceContext", ("graph", "profiler"))
# Guards reads and writes of `_current_trace_context`.
_current_trace_context_lock = threading.Lock()
# The `_TraceContext` of the active trace, or None when no trace is enabled.
_current_trace_context = None
@tf_export("summary.enable_trace", v1=[])
def enable_trace(graph=True, profiler=False):
  """Starts collecting an execution trace for a later export.

  Has no effect beyond logging a warning when not executing eagerly or when a
  trace is already enabled.

  Args:
    graph: If true (and `profiler` is false), enable graph collection.
    profiler: If true, enable run metadata collection and start the profiler.

  Returns:
    None
  """
  global _current_trace_context

  eager_context = context.context()
  if not eager_context.executing_eagerly():
    logging.warn("Must enable trace in eager mode.")
    return

  with _current_trace_context_lock:
    if _current_trace_context is not None:
      logging.warn("Trace already enabled")
      return

    if profiler:
      eager_context.enable_run_metadata()
      _profiler.start()
    elif graph:
      eager_context.enable_graph_collection()

    _current_trace_context = _TraceContext(graph=graph, profiler=profiler)
@tf_export("summary.export_trace", v1=[])
def export_trace(name, step, profiler_outdir=None):
  """Exports the current trace as a Summary and/or profile file.

  When not executing eagerly, a warning is logged and nothing is exported.
  Once the export has been attempted the trace is disabled, whether or not
  the export succeeded.

  Args:
    name: A name for the summary to be written.
    step: Required `int64`-castable monotonic step value.
    profiler_outdir: Output directory for the profiler. Required if the trace
      was enabled with the profiler on; otherwise it is ignored.

  Returns:
    None

  Raises:
    ValueError: If no trace is enabled, or if the trace was enabled with the
      profiler on and `profiler_outdir` is not specified.
  """
  # TODO(stephanlee): See if we can remove profiler_outdir and infer it from
  # the SummaryWriter's logdir.
  if not context.context().executing_eagerly():
    logging.warn("Can only export trace while executing eagerly.")
    return

  with _current_trace_context_lock:
    if _current_trace_context is None:
      raise ValueError("Must enable trace before export.")
    graph, profiler = _current_trace_context
    if profiler and profiler_outdir is None:
      raise ValueError("Required profiler_outdir is not specified")

  try:
    run_meta = context.context().export_run_metadata()

    if graph and not profiler:
      run_metadata_graphs(name, run_meta, step)
    else:
      run_metadata(name, run_meta, step)

    if profiler:
      _profiler.save(profiler_outdir, _profiler.stop())
  finally:
    # Always reset the trace state: if writing the summary or saving the
    # profile fails, a stale trace context would otherwise stay registered.
    disable_trace()
@tf_export("summary.disable_trace", v1=[])
def disable_trace():
  """Turns tracing off and clears the recorded trace state.

  A profiler that is not running is tolerated: the resulting
  `ProfilerNotRunningError` is swallowed.
  """
  global _current_trace_context
  with _current_trace_context_lock:
    _current_trace_context = None

  # Turning off run metadata turns off graph collection as well.
  context.context().disable_run_metadata()

  # The profiler exposes only start and stop, and stop raises when nothing is
  # running, so attempt the stop and ignore that specific error.
  try:
    _profiler.stop()
  except _profiler.ProfilerNotRunningError:
    pass
......@@ -16,6 +16,18 @@ tf_module {
name: "create_noop_writer"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "disable_trace"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "enable_trace"
argspec: "args=[\'graph\', \'profiler\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
}
member_method {
name: "export_trace"
argspec: "args=[\'name\', \'step\', \'profiler_outdir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "flush"
argspec: "args=[\'writer\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册