Commit 778dd23c authored by Faizan Muhammad, committed by TensorFlower Gardener

Prepare rename GenericFunction to PolymorphicFunction

PiperOrigin-RevId: 564434713
Parent 6e2a4776
@@ -448,10 +448,10 @@ class OptionalXlaContext:
self.xla_context.Exit()
# TODO(mdan): Consider expose this type for instance type checking.
# TODO(b/297237997): Use PolymorphicFunction here after migrating uses.
@tf_export("__internal__.function.Function", v1=[])
class Function(core.GenericFunction, trackable.Trackable):
"""A `tf.types.experimental.GenericFunction` created by `tf.function`.
"""A `tf.types.experimental.PolymorphicFunction` created by `tf.function`.
Currently, individual methods/attributes under this class are not guaranteed
by the TF API contract, and are subject to future changes.
@@ -802,7 +802,7 @@ class Function(core.GenericFunction, trackable.Trackable):
@traceback_utils.filter_traceback
def __call__(self, *args, **kwds):
# Implements GenericFunction.__call__.
# Implements PolymorphicFunction.__call__.
if self._run_functions_eagerly:
with trace.Trace(self._name, tf_function_call="eager"):
return self._python_function(*args, **kwds)
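For context on the eager branch above: it is normally exercised through the public `tf.config.run_functions_eagerly` switch. A minimal sketch, not part of this change, with illustrative names:

```python
import tensorflow as tf

@tf.function
def f(x):
  return x * x

tf.config.run_functions_eagerly(True)   # __call__ now takes the eager branch above
print(f(tf.constant(3)))                # the Python body runs eagerly: tf.Tensor(9, ...)
tf.config.run_functions_eagerly(False)  # restore normal graph dispatch
```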
@@ -954,7 +954,7 @@ class Function(core.GenericFunction, trackable.Trackable):
)
def experimental_get_compiler_ir(self, *args, **kwargs):
# Implements GenericFunction.experimental_get_compiler_ir
# Implements PolymorphicFunction.experimental_get_compiler_ir
context.ensure_initialized()
if not self._jit_compile:
raise ValueError("Compiler IR can only be returned for functions marked "
@@ -1222,7 +1222,7 @@ class Function(core.GenericFunction, trackable.Trackable):
return concrete
def get_concrete_function(self, *args, **kwargs):
# Implements GenericFunction.get_concrete_function.
# Implements PolymorphicFunction.get_concrete_function.
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
concrete._garbage_collector.release() # pylint: disable=protected-access
return concrete
@@ -1295,10 +1295,10 @@ def function(
experimental_relax_shapes=None,
experimental_compile=None,
experimental_follow_type_hints=None # pylint: disable=unused-argument
) -> core.GenericFunction:
) -> core.PolymorphicFunction:
"""Compiles a function into a callable TensorFlow graph.
`tf.function` constructs a `tf.types.experimental.GenericFunction` that
`tf.function` constructs a `tf.types.experimental.PolymorphicFunction` that
executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the
TensorFlow operations in `func`. More information on the topic can be found
in [Introduction to Graphs and tf.function]
@@ -1320,7 +1320,7 @@ def function(
The trace-compilation allows non-TensorFlow operations to execute, but under
special conditions. In general, only TensorFlow operations are guaranteed to
run and create fresh results whenever the `GenericFunction` is called.
run and create fresh results whenever the `PolymorphicFunction` is called.
## Features
@@ -1385,7 +1385,7 @@ def function(
## `tf.function` creates polymorphic callables
Internally, `tf.types.experimental.GenericFunction` may contain multiple
Internally, `tf.types.experimental.PolymorphicFunction` may contain multiple
`tf.types.experimental.ConcreteFunction`s, each specialized to arguments with
different data types or shapes, since TensorFlow can perform more
optimizations on graphs of specific shapes, dtypes and values of constant
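To make the dispatch described above concrete, a small illustrative sketch (not part of this change) showing two specializations created for two input dtypes:

```python
import tensorflow as tf

@tf.function
def double(a):
  return a + a

# Different input dtypes dispatch to different specializations.
cf_int = double.get_concrete_function(tf.constant(1))      # int32 trace
cf_float = double.get_concrete_function(tf.constant(1.5))  # float32 trace
print(cf_int is cf_float)  # False: each is a separate ConcreteFunction
```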
@@ -1395,11 +1395,11 @@ def function(
For more information, see the
[tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
Executing a `GenericFunction` will select and execute the appropriate
Executing a `PolymorphicFunction` will select and execute the appropriate
`ConcreteFunction` based on the argument types and values.
To obtain an individual `ConcreteFunction`, use the
`GenericFunction.get_concrete_function` method. It can be called with the
`PolymorphicFunction.get_concrete_function` method. It can be called with the
same arguments as `func` and returns a
`tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by a
single `tf.Graph`:
@@ -1410,14 +1410,14 @@ def function(
>>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
`ConcreteFunction`s can be executed just like `GenericFunction`s, but their
`ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but their
input is restricted to the types to which they're specialized.
## Retracing
`ConcreteFunctions` are built (traced) on the fly, as the `GenericFunction` is
`ConcreteFunctions` are built (traced) on the fly, as the `PolymorphicFunction` is
called with new TensorFlow types or shapes, or with new Python values as
arguments. When `GenericFunction` builds a new trace, it is said that `func`
arguments. When `PolymorphicFunction` builds a new trace, it is said that `func`
is retraced. Retracing is a frequent performance concern for `tf.function` as
it can be considerably slower than executing a graph that's already been
traced. It is ideal to minimize the amount of retracing in your code.
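A brief illustrative sketch of the retracing behaviour described above (not part of this change):

```python
import tensorflow as tf

@tf.function
def square(x):
  print("Tracing for", x)  # Python-level print runs only while tracing
  return x * x

square(tf.constant(2))  # first int32 Tensor argument: traces once
square(tf.constant(3))  # same Tensor signature: reuses the trace, no print
square(2)               # Python int: each distinct value triggers a retrace
square(3)               # retraces again; passing Tensors limits retracing
```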
@@ -1443,7 +1443,7 @@ def function(
## Input signatures
For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for
For Tensor arguments, `PolymorphicFunction` creates a new `ConcreteFunction` for
every unique set of input shapes and datatypes. The example below creates two
separate `ConcreteFunction`s, each specialized to a different shape:
@@ -1459,7 +1459,7 @@ def function(
this process. The input signature specifies the shape and type of each
Tensor argument to the function using a `tf.TensorSpec` object. More general
shapes can be used. This ensures only one `ConcreteFunction` is created, and
restricts the `GenericFunction` to the specified shapes and types. It is
restricts the `PolymorphicFunction` to the specified shapes and types. It is
an effective way to limit retracing when Tensors have dynamic shapes.
>>> @tf.function(
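As a separate illustrative sketch (not part of this change), an input signature with a dynamic first dimension can be written like this:

```python
import tensorflow as tf

# A single trace covers every 1-D float32 input, whatever its length.
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def normalize(v):
  return v / tf.norm(v)

normalize(tf.constant([3.0, 4.0]))       # uses the one ConcreteFunction
normalize(tf.constant([1.0, 2.0, 2.0]))  # same trace, different length
# A rank-2 input such as tf.constant([[1.0]]) would raise: it violates the signature.
```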
@@ -1602,9 +1602,9 @@ def function(
reduce_retracing instead.
Returns:
If `func` is not None, returns a `tf.types.experimental.GenericFunction`.
If `func` is not None, returns a `tf.types.experimental.PolymorphicFunction`.
If `func` is None, returns a decorator that, when invoked with a single
`func` argument, returns a `tf.types.experimental.GenericFunction`.
`func` argument, returns a `tf.types.experimental.PolymorphicFunction`.
Raises:
`ValueError` when attempting to use `jit_compile=True`, but XLA support is
......
@@ -358,8 +358,8 @@ def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
# Eager tensors from a parallel device may not have a constant
# value. Running the cond op itself would work, but we don't have logic to
# build cond ops without wrapping in a function first.
if (not isinstance(true_fn, core.GenericFunction)
or not isinstance(false_fn, core.GenericFunction)):
if (not isinstance(true_fn, core.PolymorphicFunction)
or not isinstance(false_fn, core.PolymorphicFunction)):
raise TypeError("When running tf.cond on a parallel device, 'true_fn' "
"and 'false_fn' must be decorated with `tf.function`.")
functions_run_eagerly = eager_function_run.functions_run_eagerly()
......
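A hedged sketch of the call shape the check above requires on a parallel device; the parallel-device setup itself is omitted and the names are illustrative:

```python
import tensorflow as tf

# On a parallel device, both branch callables must be tf.functions.
@tf.function
def true_branch():
  return tf.constant(1)

@tf.function
def false_branch():
  return tf.constant(0)

result = tf.cond(tf.constant(True), true_fn=true_branch, false_fn=false_branch)
print(result)  # tf.Tensor(1, shape=(), dtype=int32)
```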
@@ -226,7 +226,7 @@ def canonicalize_signatures(signatures):
# pylint: enable=protected-access
concrete_signatures[signature_key] = final_concrete
# pylint: enable=cell-var-from-loop
if isinstance(function, core.GenericFunction):
if isinstance(function, core.PolymorphicFunction):
flattened_defaults = nest.flatten(
function.function_spec.fullargspec.defaults # pylint: disable=protected-access
)
......
@@ -118,7 +118,7 @@ class AutoTrackable(base.Trackable):
# (e.g. captured variables). Make sure we return those too.
children = {}
for name, child in self._checkpoint_dependencies:
if isinstance(child, (core_types.GenericFunction,
if isinstance(child, (core_types.PolymorphicFunction,
core_types.ConcreteFunction)):
# Skip "tracked" functions for now since there may be objects that
# automatically track functions that should not be saved.
......
@@ -183,9 +183,8 @@ class ConcreteFunction(Callable, metaclass=abc.ABCMeta):
"""Returns the original `AtomicFunction` owned by this ConcreteFunction."""
# TODO(mdan): Name just `types.Function`, for historic continuity?
@tf_export("types.experimental.GenericFunction", v1=[])
class GenericFunction(Callable, metaclass=abc.ABCMeta):
@tf_export("types.experimental.PolymorphicFunction", v1=[])
class PolymorphicFunction(Callable, metaclass=abc.ABCMeta):
"""Base class for polymorphic graph functions.
Graph functions are Python callable objects that dispatch calls to a
@@ -365,6 +364,13 @@ class GenericFunction(Callable, metaclass=abc.ABCMeta):
pass
# TODO(b/297237997): Delete this once all usages are removed.
@tf_export("types.experimental.GenericFunction", v1=[])
class GenericFunction(PolymorphicFunction):
"""Please use tf.types.experimental.PolymorphicFunction instead."""
pass
@runtime_checkable
class TensorProtocol(Protocol):
"""Protocol type for objects that can be converted to Tensor."""
......
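Because `GenericFunction` is kept as a thin subclass of the new `PolymorphicFunction` export, existing isinstance checks keep passing. A small sketch, assuming a TensorFlow build that includes this change:

```python
import tensorflow as tf

@tf.function
def f(x):
  return x + 1

# Both checks pass: tf.function still produces instances of the deprecated
# GenericFunction alias, which now derives from PolymorphicFunction.
print(isinstance(f, tf.types.experimental.PolymorphicFunction))  # True
print(isinstance(f, tf.types.experimental.GenericFunction))      # True (deprecated)
```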
@@ -14,9 +14,9 @@
# ==============================================================================
"""tf.function tracing types.
See `core.GenericFunction` and `core.ConcreteFunction`.
See `core.PolymorphicFunction` and `core.ConcreteFunction`.
`GenericFunction` assigns types to call arguments, forming a signature.
`PolymorphicFunction` assigns types to call arguments, forming a signature.
Function signatures are used to match arguments to `ConcreteFunction`s.
For example, when a new `ConcreteFunction` is traced, it is assigned
the signature of the arguments it was traced with. Subsequent call arguments
......
@@ -2,6 +2,7 @@ path: "tensorflow.__internal__.function.Function"
tf_class {
is_instance: "<class \'tensorflow.python.eager.polymorphic_function.polymorphic_function.Function\'>"
is_instance: "<class \'tensorflow.python.types.core.GenericFunction\'>"
is_instance: "<class \'tensorflow.python.types.core.PolymorphicFunction\'>"
is_instance: "<class \'tensorflow.python.types.core.Callable\'>"
is_instance: "<class \'tensorflow.python.trackable.base.Trackable\'>"
is_instance: "<type \'object\'>"
......
path: "tensorflow.types.experimental.GenericFunction"
tf_class {
is_instance: "<class \'tensorflow.python.types.core.GenericFunction\'>"
is_instance: "<class \'tensorflow.python.types.core.PolymorphicFunction\'>"
is_instance: "<class \'tensorflow.python.types.core.Callable\'>"
is_instance: "<type \'object\'>"
member {
......
path: "tensorflow.types.experimental.PolymorphicFunction"
tf_class {
is_instance: "<class \'tensorflow.python.types.core.PolymorphicFunction\'>"
is_instance: "<class \'tensorflow.python.types.core.Callable\'>"
is_instance: "<type \'object\'>"
member {
name: "function_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "experimental_get_compiler_ir"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "get_concrete_function"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
}
@@ -20,6 +20,10 @@ tf_module {
name: "GenericFunction"
mtype: "<type \'type\'>"
}
member {
name: "PolymorphicFunction"
mtype: "<type \'type\'>"
}
member {
name: "SupportsTracingProtocol"
mtype: "<class \'typing._ProtocolMeta\'>"
......