diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index d06e1f574b9fba7e8a8597b5f42a2782b640971c..af1687c8ef1cb7de4b4a86fb23c43d9ab101058e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1023,12 +1023,13 @@ def also_run_as_tf_function(f): """ def decorated(*args, **kwds): + def bound_f(): + f(*args, **kwds) with context.eager_mode(): # Running in eager mode - f(*args, **kwds) - - defun_f = def_function.function(f) - defun_f(*args, **kwds) + bound_f() + # Running as TF function + def_function.function(bound_f)() return decorated diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index ddb2ddaf63279532ab7bc559ae4de55978dbd795..bd5c103b38dc1561fbcb19b326052bd4f3c6f293 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1068,6 +1068,25 @@ tf_py_test( ], ) +tf_py_test( + name = "summary_ops_test", + size = "small", + srcs = ["summary_ops_test.py"], + additional_deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:summary_ops_v2", + "//tensorflow/python:tensor_util", + "//tensorflow/python/eager:function", + "//tensorflow/python/eager:context", + ], +) + tf_py_test( name = "summary_v1_ops_test", size = "small", diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cd446eb40eb9ff1931a3eb4555f9dd81a77b659f --- /dev/null +++ b/tensorflow/python/kernel_tests/summary_ops_test.py @@ -0,0 +1,267 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for V2 summary ops from summary_ops_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.core.framework import summary_pb2 +from tensorflow.core.util import event_pb2 +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import test_util +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import summary_ops_v2 as summary_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test + + +class SummaryOpsTest(test_util.TensorFlowTestCase): + + def testWrite(self): + logdir = self.get_temp_dir() + with context.eager_mode(): + with summary_ops.create_file_writer(logdir).as_default(): + output = summary_ops.write('tag', 42, step=12) + self.assertTrue(output.numpy()) + events = events_from_logdir(logdir) + self.assertEqual(2, len(events)) + self.assertEqual(12, events[1].step) + value = events[1].summary.value[0] + self.assertEqual('tag', value.tag) + self.assertEqual(42, to_numpy(value)) + + def testWrite_fromFunction(self): + logdir = self.get_temp_dir() + @def_function.function + def f(): + with summary_ops.create_file_writer(logdir).as_default(): + return summary_ops.write('tag', 42, step=12) + with context.eager_mode(): + output = f() + self.assertTrue(output.numpy()) + events = events_from_logdir(logdir) + self.assertEqual(2, len(events)) + self.assertEqual(12, events[1].step) + value = events[1].summary.value[0] + self.assertEqual('tag', value.tag) + self.assertEqual(42, to_numpy(value)) + + def testWrite_metadata(self): + logdir = self.get_temp_dir() + metadata = summary_pb2.SummaryMetadata() + metadata.plugin_data.plugin_name = 'foo' + with context.eager_mode(): + with summary_ops.create_file_writer(logdir).as_default(): + summary_ops.write('obj', 0, 0, metadata=metadata) + summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString()) + m = constant_op.constant(metadata.SerializeToString()) + summary_ops.write('string_tensor', 0, 0, metadata=m) + events = events_from_logdir(logdir) + self.assertEqual(4, len(events)) + self.assertEqual(metadata, events[1].summary.value[0].metadata) + self.assertEqual(metadata, events[2].summary.value[0].metadata) + self.assertEqual(metadata, events[3].summary.value[0].metadata) + + def testWrite_name(self): + @def_function.function + def f(): + output = summary_ops.write('tag', 42, step=12, name='anonymous') + self.assertTrue(output.name.startswith('anonymous')) + f() + + def testWrite_ndarray(self): + logdir = self.get_temp_dir() + with context.eager_mode(): + with summary_ops.create_file_writer(logdir).as_default(): + summary_ops.write('tag', [[1, 2], [3, 4]], step=12) + events = events_from_logdir(logdir) + value = events[1].summary.value[0] + self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value)) + + def testWrite_tensor(self): + logdir = self.get_temp_dir() + with context.eager_mode(): + t = constant_op.constant([[1, 2], [3, 4]]) + with summary_ops.create_file_writer(logdir).as_default(): + summary_ops.write('tag', t, step=12) + expected = t.numpy() + events = events_from_logdir(logdir) + value = events[1].summary.value[0] + self.assertAllEqual(expected, to_numpy(value)) + + def testWrite_tensor_fromFunction(self): + logdir = self.get_temp_dir() + @def_function.function + def f(t): + with summary_ops.create_file_writer(logdir).as_default(): + summary_ops.write('tag', t, step=12) + with context.eager_mode(): + t = constant_op.constant([[1, 2], [3, 4]]) + f(t) + expected = t.numpy() + events = events_from_logdir(logdir) + value = events[1].summary.value[0] + self.assertAllEqual(expected, to_numpy(value)) + + def testWrite_stringTensor(self): + logdir = self.get_temp_dir() + with context.eager_mode(): + with summary_ops.create_file_writer(logdir).as_default(): + summary_ops.write('tag', [b'foo', b'bar'], step=12) + events = events_from_logdir(logdir) + value = events[1].summary.value[0] + self.assertAllEqual([b'foo', b'bar'], to_numpy(value)) + + @test_util.also_run_as_tf_function + def testWrite_noDefaultWriter(self): + with context.eager_mode(): + self.assertFalse(summary_ops.write('tag', 42, step=0)) + + def testWrite_shouldRecordSummaries(self): + logdir = self.get_temp_dir() + with context.eager_mode(): + with summary_ops.create_file_writer(logdir).as_default(): + self.assertTrue(summary_ops.write('default_on', 1, step=0)) + with summary_ops.always_record_summaries(): + self.assertTrue(summary_ops.write('set_on', 1, step=0)) + with summary_ops.never_record_summaries(): + self.assertFalse(summary_ops.write('set_off', 1, step=0)) + events = events_from_logdir(logdir) + self.assertEqual(3, len(events)) + self.assertEqual('default_on', events[1].summary.value[0].tag) + self.assertEqual('set_on', events[2].summary.value[0].tag) + + def testWrite_shouldRecordSummaries_fromFunction(self): + logdir = self.get_temp_dir() + @def_function.function + def f(tag_prefix): + with summary_ops.create_file_writer(logdir).as_default(): + default_output = summary_ops.write(tag_prefix + '_default', 1, step=0) + with summary_ops.always_record_summaries(): + on_output = summary_ops.write(tag_prefix + '_on', 1, step=0) + with summary_ops.never_record_summaries(): + off_output = summary_ops.write(tag_prefix + '_off', 1, step=0) + return [default_output, on_output, off_output] + with context.eager_mode(): + self.assertAllEqual([True, True, False], f('default')) + with summary_ops.always_record_summaries(): + self.assertAllEqual([True, True, False], f('on')) + with summary_ops.never_record_summaries(): + self.assertAllEqual([False, True, False], f('off')) + events = events_from_logdir(logdir) + self.assertEqual(6, len(events)) + self.assertEqual('default_default', events[1].summary.value[0].tag) + self.assertEqual('default_on', events[2].summary.value[0].tag) + self.assertEqual('on_default', events[3].summary.value[0].tag) + self.assertEqual('on_on', events[4].summary.value[0].tag) + self.assertEqual('off_on', events[5].summary.value[0].tag) + + @test_util.also_run_as_tf_function + def testSummaryScope(self): + with summary_ops.summary_scope('foo') as (tag, scope): + self.assertEqual('foo', tag) + self.assertEqual('foo/', scope) + with summary_ops.summary_scope('bar') as (tag, scope): + self.assertEqual('foo/bar', tag) + self.assertEqual('foo/bar/', scope) + with summary_ops.summary_scope('with/slash') as (tag, scope): + self.assertEqual('foo/with/slash', tag) + self.assertEqual('foo/with/slash/', scope) + with ops.name_scope(None): + with summary_ops.summary_scope('unnested') as (tag, scope): + self.assertEqual('unnested', tag) + self.assertEqual('unnested/', scope) + + @test_util.also_run_as_tf_function + def testSummaryScope_defaultName(self): + with summary_ops.summary_scope(None) as (tag, scope): + self.assertEqual('summary', tag) + self.assertEqual('summary/', scope) + with summary_ops.summary_scope(None, 'backup') as (tag, scope): + self.assertEqual('backup', tag) + self.assertEqual('backup/', scope) + + @test_util.also_run_as_tf_function + def testSummaryScope_handlesCharactersIllegalForScope(self): + with summary_ops.summary_scope('f?o?o') as (tag, scope): + self.assertEqual('f?o?o', tag) + self.assertEqual('foo/', scope) + # If all characters aren't legal for a scope name, use default name. + with summary_ops.summary_scope('???', 'backup') as (tag, scope): + self.assertEqual('???', tag) + self.assertEqual('backup/', scope) + + @test_util.also_run_as_tf_function + def testSummaryScope_nameNotUniquifiedForTag(self): + constant_op.constant(0, name='foo') + with summary_ops.summary_scope('foo') as (tag, _): + self.assertEqual('foo', tag) + with summary_ops.summary_scope('foo') as (tag, _): + self.assertEqual('foo', tag) + with ops.name_scope('with'): + constant_op.constant(0, name='slash') + with summary_ops.summary_scope('with/slash') as (tag, _): + self.assertEqual('with/slash', tag) + + +def events_from_file(filepath): + """Returns all events in a single event file. + + Args: + filepath: Path to the event file. + + Returns: + A list of all tf.Event protos in the event file. + """ + records = list(tf_record.tf_record_iterator(filepath)) + result = [] + for r in records: + event = event_pb2.Event() + event.ParseFromString(r) + result.append(event) + return result + + +def events_from_logdir(logdir): + """Returns all events in the single eventfile in logdir. + + Args: + logdir: The directory in which the single event file is sought. + + Returns: + A list of all tf.Event protos from the single event file. + + Raises: + AssertionError: If logdir does not contain exactly one file. + """ + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files + return events_from_file(os.path.join(logdir, files[0])) + + +def to_numpy(summary_value): + return tensor_util.MakeNdarray(summary_value.tensor) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index 3f99b9f8773b3d26cf334044e0d127bf7443bfea..168cb975548095be4648a9e705deb797241363c7 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -58,14 +58,31 @@ _RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$") _USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I) -def should_record_summaries(): - """Returns boolean Tensor which is true if summaries should be recorded.""" +def _should_record_summaries_internal(): + """Returns boolean Tensor if summaries should/shouldn't be recorded, or None. + """ global _SHOULD_RECORD_SUMMARIES key = ops.get_default_graph()._graph_key # pylint: disable=protected-access - should = _SHOULD_RECORD_SUMMARIES.setdefault(key, False) + should = _SHOULD_RECORD_SUMMARIES.get(key) return should() if callable(should) else should +def _should_record_summaries_v2(): + """Returns boolean Tensor which is true if summaries should be recorded. + + If no recording status has been set, this defaults to True, unlike the public + should_record_summaries(). + """ + result = _should_record_summaries_internal() + return True if result is None else result + + +def should_record_summaries(): + """Returns boolean Tensor which is true if summaries should be recorded.""" + result = _should_record_summaries_internal() + return False if result is None else result + + @tf_contextlib.contextmanager def _record_summaries(boolean=True): """Sets summary recording on or off per the provided boolean value. @@ -86,7 +103,7 @@ def _record_summaries(boolean=True): # TODO(nickfelt): make this threadlocal global _SHOULD_RECORD_SUMMARIES key = ops.get_default_graph()._graph_key # pylint: disable=protected-access - old = _SHOULD_RECORD_SUMMARIES.setdefault(key, False) + old = _SHOULD_RECORD_SUMMARIES.setdefault(key, None) try: _SHOULD_RECORD_SUMMARIES[key] = boolean yield @@ -370,6 +387,98 @@ def summary_writer_initializer_op(): return _SUMMARY_WRITER_INIT_OP.setdefault(key, []) +_INVALID_SCOPE_CHARACTERS = re.compile(r"[^-_/.A-Za-z0-9]") + + +@tf_export("summary.summary_scope", v1=[]) +@tf_contextlib.contextmanager +def summary_scope(name, default_name="summary", values=None): + """A context manager for use when defining a custom summary op. + + This behaves similarly to `tf.name_scope`, except that it returns a generated + summary tag in addition to the scope name. The tag is structurally similar to + the scope name - derived from the user-provided name, prefixed with enclosing + name scopes if any - but we relax the constraint that it be uniquified, as + well as the character set limitation (so the user-provided name can contain + characters not legal for scope names; in the scope name these are removed). + + This makes the summary tag more predictable and consistent for the user. + + For example, to define a new summary op called `my_op`: + + ```python + def my_op(name, my_value, step): + with tf.summary.summary_scope(name, "MyOp", [my_value]) as (tag, scope): + my_value = tf.convert_to_tensor(my_value) + return tf.summary.write(tag, my_value, step=step) + ``` + + Args: + name: string name for the summary. + default_name: Optional; if provided, used as default name of the summary. + values: Optional; passed as `values` parameter to name_scope. + + Yields: + A tuple `(tag, scope)` as described above. + """ + name = name or default_name + current_scope = ops.get_name_scope() + tag = current_scope + "/" + name if current_scope else name + # Strip illegal characters from the scope name, and if that leaves nothing, + # use None instead so we pick up the default name. + name = _INVALID_SCOPE_CHARACTERS.sub("", name) or None + with ops.name_scope(name, default_name, values) as scope: + yield tag, scope + + +@tf_export("summary.write", v1=[]) +def write(tag, tensor, step, metadata=None, name=None): + """Writes a generic summary to the default SummaryWriter if one exists. + + This exists primarily to support the definition of type-specific summary ops + like scalar() and image(), and is not intended for direct use unless defining + a new type-specific summary op. + + Args: + tag: string tag used to identify the summary (e.g. in TensorBoard), usually + generated with `tf.summary.summary_scope` + tensor: the Tensor holding the summary data to write + step: `int64`-castable monotic step value for this summary + metadata: Optional SummaryMetadata, as a proto or serialized bytes + name: Optional string name for this op. + + Returns: + True on success, or false if no summary was written because no default + summary writer was available. + """ + with ops.name_scope(name, "write_summary") as scope: + if context.context().summary_writer_resource is None: + return constant_op.constant(False) + if metadata is None: + serialized_metadata = constant_op.constant(b"") + elif hasattr(metadata, "SerializeToString"): + serialized_metadata = constant_op.constant(metadata.SerializeToString()) + else: + serialized_metadata = metadata + + def record(): + """Record the actual summary and return True.""" + # Note the identity to move the tensor to the CPU. + with ops.device("cpu:0"): + write_summary_op = gen_summary_ops.write_summary( + context.context().summary_writer_resource, + step, + array_ops.identity(tensor), + tag, + serialized_metadata, + name=scope) + with ops.control_dependencies([write_summary_op]): + return constant_op.constant(True) + + return smart_cond.smart_cond( + _should_record_summaries_v2(), record, _nothing, name="summary_cond") + + def summary_writer_function(name, tensor, function, family=None): """Helper function to write summaries. diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt index 5cf4d7cfd9ac54eeccea5094ad789aede29540b8..61670bd15122f65ef05d20ee5d023a3c326f7757 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt @@ -40,4 +40,12 @@ tf_module { name: "import_event" argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "summary_scope" + argspec: "args=[\'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'summary\', \'None\'], " + } + member_method { + name: "write" + argspec: "args=[\'tag\', \'tensor\', \'step\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } }