提交 093f0363 编写于 作者: A Alexandre Passos 提交者: TensorFlower Gardener

Utility to run tests inside tf.function and eager.

Relies on being able to run the assert* test methods inside a
py_func to run them inside the graph, so there's no need for
self.evaluate or similar methods which create a graph/eager
hybrid programming model.

PiperOrigin-RevId: 224575790
上级 71ea120a
......@@ -428,20 +428,21 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(value), 2.0)
@test_util.run_in_graph_and_eager_modes
@test_util.also_run_as_tf_function
def testInitScopeTensorInitializationInFunction(self):
@def_function.function
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
# Note: this variable bypasses tf.function's variable creation
# requirements by bypassing variable_creator_scope by using
# ResourceVariable instead of Variable.
self.v = resource_variable_ops.ResourceVariable(const)
return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(value), 2.0)
self.assertAllEqual(value, 2.0)
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
......
......@@ -54,6 +54,7 @@ from tensorflow.python import tf2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import tape
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
......@@ -67,6 +68,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
......@@ -76,6 +78,7 @@ from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
......@@ -1009,6 +1012,58 @@ def run_in_graph_and_eager_modes(func=None,
return decorator
def py_func_if_in_function(f):
def decorated(*args, **kwds):
if not ops.get_default_graph()._building_function:
return f(*args, **kwds)
tensor_args, tensor_indices = zip(
*[(x, i) for i, x in enumerate(args)
if isinstance(x, (ops.Tensor, variables.Variable))])
def inner_f(*inner_tensor_args):
my_args = list(args)
for i, n in zip(tensor_indices, inner_tensor_args):
my_args[i] = n
return f(*my_args, **kwds)
return script_ops.py_func(inner_f, tensor_args, [])
return tf_decorator.make_decorator(f, decorated)
def also_run_as_tf_function(f):
"""Runs the decorated test twice--once as is, once inside a tf.function.
This allows you to run a test both in eager execution and inside a
tf.function, exercising the two execution modes supported in tf 2.0. The test
assertions are automatically done inside tf.py_funcs, and tf.function ensures
that they run in the proper order and with the proper side effects.
Currently variable creation is not supported in tests annotated with this
decorator since it's tricky to ensure the variable doesn't get repeatedly
created when retracing the tf.function.
Args:
f: the test method to be decorated
Returns:
The decorated test method, which will run both in eager and inside a
tf.function.
"""
def decorated(*args, **kwds):
with context.eager_mode():
# Running in eager mode
f(*args, **kwds)
defun_f = def_function.function(f)
defun_f(*args, **kwds)
return decorated
def run_deprecated_v1(func=None):
"""Execute the decorated test in graph mode.
......@@ -1783,8 +1838,8 @@ class TensorFlowTestCase(googletest.TestCase):
return ret
# pylint: enable=invalid-name
# pylint: enable=invalid-name
@py_func_if_in_function
def assertNear(self, f1, f2, err, msg=None):
"""Asserts that two floats are near each other.
......@@ -1803,6 +1858,7 @@ class TensorFlowTestCase(googletest.TestCase):
"%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
if msg is not None else ""))
@py_func_if_in_function
def assertArrayNear(self, farray1, farray2, err, msg=None):
"""Asserts that two float arrays are near each other.
......@@ -1822,6 +1878,7 @@ class TensorFlowTestCase(googletest.TestCase):
def _NDArrayNear(self, ndarray1, ndarray2, err):
return np.linalg.norm(ndarray1 - ndarray2) < err
@py_func_if_in_function
def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
"""Asserts that two numpy arrays have near values.
......@@ -1959,6 +2016,7 @@ class TensorFlowTestCase(googletest.TestCase):
e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
raise
@py_func_if_in_function
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
"""Asserts that two structures of numpy arrays or Tensors, have near values.
......@@ -1984,6 +2042,7 @@ class TensorFlowTestCase(googletest.TestCase):
"""
self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
@py_func_if_in_function
def assertAllCloseAccordingToType(self,
a,
b,
......@@ -2031,6 +2090,7 @@ class TensorFlowTestCase(googletest.TestCase):
self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
@py_func_if_in_function
def assertNotAllClose(self, a, b, **kwargs):
"""Assert that two numpy arrays, or or Tensors, do not have near values.
......@@ -2049,6 +2109,7 @@ class TensorFlowTestCase(googletest.TestCase):
return
raise AssertionError("The two values are close at all elements")
@py_func_if_in_function
def assertAllEqual(self, a, b, msg=None):
"""Asserts that two numpy arrays or Tensors have the same values.
......@@ -2091,6 +2152,7 @@ class TensorFlowTestCase(googletest.TestCase):
msgs.append("not equal rhs = {}".format(y))
np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
@py_func_if_in_function
def assertAllGreater(self, a, comparison_target):
"""Assert element values are all greater than a target value.
......@@ -2102,6 +2164,7 @@ class TensorFlowTestCase(googletest.TestCase):
a = self._GetNdArray(a)
self.assertGreater(np.min(a), comparison_target)
@py_func_if_in_function
def assertAllLess(self, a, comparison_target):
"""Assert element values are all less than a target value.
......@@ -2113,6 +2176,7 @@ class TensorFlowTestCase(googletest.TestCase):
a = self._GetNdArray(a)
self.assertLess(np.max(a), comparison_target)
@py_func_if_in_function
def assertAllGreaterEqual(self, a, comparison_target):
"""Assert element values are all greater than or equal to a target value.
......@@ -2124,6 +2188,7 @@ class TensorFlowTestCase(googletest.TestCase):
a = self._GetNdArray(a)
self.assertGreaterEqual(np.min(a), comparison_target)
@py_func_if_in_function
def assertAllLessEqual(self, a, comparison_target):
"""Assert element values are all less than or equal to a target value.
......@@ -2166,6 +2231,7 @@ class TensorFlowTestCase(googletest.TestCase):
lines.append(prefix + "...")
return lines
@py_func_if_in_function
def assertAllInRange(self,
target,
lower_bound,
......@@ -2224,6 +2290,7 @@ class TensorFlowTestCase(googletest.TestCase):
"Subscript(s) and value(s) of the offending elements:\n" +
"\n".join(self._format_subscripts(violation_subscripts, target)))
@py_func_if_in_function
def assertAllInSet(self, target, expected_set):
"""Assert that elements of a Tensor are all in a given closed set.
......@@ -2245,6 +2312,7 @@ class TensorFlowTestCase(googletest.TestCase):
raise AssertionError("%d unique element(s) are not in the set %s: %s" %
(np.size(diff), expected_set, diff))
@py_func_if_in_function
def assertDTypeEqual(self, target, expected_dtype):
"""Assert ndarray data type is equal to expected.
......
......@@ -237,7 +237,8 @@ class VariableScopeTest(test.TestCase):
_ = d2(x)
self.assertEqual(len(d2.variables), 2)
v3, v4 = d2.variables
self.assertAllEqual([v1, v2], [v3, v4])
self.assertEqual(v1, v3)
self.assertEqual(v2, v4)
f()
# TODO(mihaimaruseac): Not converted to use wrap_function because of
......@@ -1684,7 +1685,7 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
with variable_scope.variable_creator_scope(creator_b):
variable_scope.variable(1.0, name="one_name")
self.assertAllEqual(variable_names, ["forced_name"])
self.assertEqual(variable_names[0], "forced_name")
called = [False]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册