diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function.py index 6079d7ef2b30d305d1395d608fff8e7783f76c5c..7b91eb0da12539bdde3ae1342e44a05d3712655c 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function.py @@ -448,10 +448,10 @@ class OptionalXlaContext: self.xla_context.Exit() -# TODO(mdan): Consider expose this type for instance type checking. +# TODO(b/297237997): Use PolymorphicFunction here after migrating uses. @tf_export("__internal__.function.Function", v1=[]) class Function(core.GenericFunction, trackable.Trackable): - """A `tf.types.experimental.GenericFunction` created by `tf.function`. + """A `tf.types.experimental.PolymorphicFunction` created by `tf.function`. Currently, individual methods/attributes under this class are not guaranteed by the TF API contract, and are subject to future changes. @@ -802,7 +802,7 @@ class Function(core.GenericFunction, trackable.Trackable): @traceback_utils.filter_traceback def __call__(self, *args, **kwds): - # Implements GenericFunction.__call__. + # Implements PolymorphicFunction.__call__. if self._run_functions_eagerly: with trace.Trace(self._name, tf_function_call="eager"): return self._python_function(*args, **kwds) @@ -954,7 +954,7 @@ class Function(core.GenericFunction, trackable.Trackable): ) def experimental_get_compiler_ir(self, *args, **kwargs): - # Implements GenericFunction.experimental_get_compiler_ir + # Implements PolymorphicFunction.experimental_get_compiler_ir context.ensure_initialized() if not self._jit_compile: raise ValueError("Compiler IR can only be returned for functions marked " @@ -1222,7 +1222,7 @@ class Function(core.GenericFunction, trackable.Trackable): return concrete def get_concrete_function(self, *args, **kwargs): - # Implements GenericFunction.get_concrete_function. + # Implements PolymorphicFunction.get_concrete_function. concrete = self._get_concrete_function_garbage_collected(*args, **kwargs) concrete._garbage_collector.release() # pylint: disable=protected-access return concrete @@ -1295,10 +1295,10 @@ def function( experimental_relax_shapes=None, experimental_compile=None, experimental_follow_type_hints=None # pylint: disable=unused-argument -) -> core.GenericFunction: +) -> core.PolymorphicFunction: """Compiles a function into a callable TensorFlow graph. - `tf.function` constructs a `tf.types.experimental.GenericFunction` that + `tf.function` constructs a `tf.types.experimental.PolymorphicFunction` that executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the TensorFlow operations in `func`. More information on the topic can be found in [Introduction to Graphs and tf.function] @@ -1320,7 +1320,7 @@ def function( The trace-compilation allows non-TensorFlow operations to execute, but under special conditions. In general, only TensorFlow operations are guaranteed to - run and create fresh results whenever the `GenericFunction` is called. + run and create fresh results whenever the `PolymorphicFunction` is called. ## Features @@ -1385,7 +1385,7 @@ def function( ## `tf.function` creates polymorphic callables - Internally, `tf.types.experimental.GenericFunction` may contain multiple + Internally, `tf.types.experimental.PolymorphicFunction` may contain multiple `tf.types.experimental.ConcreteFunction`s, each specialized to arguments with different data types or shapes, since TensorFlow can perform more optimizations on graphs of specific shapes, dtypes and values of constant @@ -1395,11 +1395,11 @@ def function( For more information, see the [tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing) - Executing a `GenericFunction` will select and execute the appropriate + Executing a `PolymorphicFunction` will select and execute the appropriate `ConcreteFunction` based on the argument types and values. To obtain an individual `ConcreteFunction`, use the - `GenericFunction.get_concrete_function` method. It can be called with the + `PolymorphicFunction.get_concrete_function` method. It can be called with the same arguments as `func` and returns a `tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by a single `tf.Graph`: @@ -1410,14 +1410,14 @@ def function( >>> isinstance(f.get_concrete_function(1).graph, tf.Graph) True - `ConcreteFunction`s can be executed just like `GenericFunction`s, but their + `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but their input is resticted to the types to which they're specialized. ## Retracing - `ConcreteFunctions` are built (traced) on the fly, as the `GenericFunction` is + `ConcreteFunctions` are built (traced) on the fly, as the `PolymorphicFunction` is called with new TensorFlow types or shapes, or with new Python values as - arguments. When `GenericFunction` builds a new trace, it is said that `func` + arguments. When `PolymorphicFunction` builds a new trace, it is said that `func` is retraced. Retracing is a frequent performance concern for `tf.function` as it can be considerably slower than executing a graph that's already been traced. It is ideal to minimize the amount of retracing in your code. @@ -1443,7 +1443,7 @@ def function( ## Input signatures - For Tensor arguments, `GenericFunction`creates a new `ConcreteFunction` for + For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction` for every unique set of input shapes and datatypes. The example below creates two separate `ConcreteFunction`s, each specialized to a different shape: @@ -1459,7 +1459,7 @@ def function( this process. The input signature specifies the shape and type of each Tensor argument to the function using a `tf.TensorSpec` object. More general shapes can be used. This ensures only one `ConcreteFunction` is created, and - restricts the `GenericFunction` to the specified shapes and types. It is + restricts the `PolymorphicFunction` to the specified shapes and types. It is an effective way to limit retracing when Tensors have dynamic shapes. >>> @tf.function( @@ -1602,9 +1602,9 @@ def function( reduce_retracing instead. Returns: - If `func` is not None, returns a `tf.types.experimental.GenericFunction`. + If `func` is not None, returns a `tf.types.experimental.PolymorphicFunction`. If `func` is None, returns a decorator that, when invoked with a single - `func` argument, returns a `tf.types.experimental.GenericFunction`. + `func` argument, returns a `tf.types.experimental.PolymorphicFunction`. Raises: `ValueError` when attempting to use `jit_compile=True`, but XLA support is diff --git a/tensorflow/python/ops/cond.py b/tensorflow/python/ops/cond.py index 9fae845aaeb4695f737129309b4882cb0f4743c1..23940e23847693a64e220c25419e0260f2979658 100644 --- a/tensorflow/python/ops/cond.py +++ b/tensorflow/python/ops/cond.py @@ -358,8 +358,8 @@ def _eager_cond_implementation(pred, true_fn, false_fn, strict, name): # Eager tensors from a parallel device may not have a constant # value. Running the cond op itself would work, but we don't have logic to # build cond ops without wrapping in a function first. - if (not isinstance(true_fn, core.GenericFunction) - or not isinstance(false_fn, core.GenericFunction)): + if (not isinstance(true_fn, core.PolymorphicFunction) + or not isinstance(false_fn, core.PolymorphicFunction)): raise TypeError("When running tf.cond on a parallel device, 'true_fn' " "and 'false_fn' must be decorated with `tf.function`.") functions_run_eagerly = eager_function_run.functions_run_eagerly() diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py index 38362c8087a83858b29ca26c91afd655afa1c1e4..3a7a0ff4d3e9fa834c449939823e8d3b03c88562 100644 --- a/tensorflow/python/saved_model/signature_serialization.py +++ b/tensorflow/python/saved_model/signature_serialization.py @@ -226,7 +226,7 @@ def canonicalize_signatures(signatures): # pylint: enable=protected-access concrete_signatures[signature_key] = final_concrete # pylint: enable=cell-var-from-loop - if isinstance(function, core.GenericFunction): + if isinstance(function, core.PolymorphicFunction): flattened_defaults = nest.flatten( function.function_spec.fullargspec.defaults # pylint: disable=protected-access ) diff --git a/tensorflow/python/trackable/autotrackable.py b/tensorflow/python/trackable/autotrackable.py index d70e4d0079a0269af14223d1dff2568871397e39..2a0a20535ebf802a74a9e1427a996931485ea543 100644 --- a/tensorflow/python/trackable/autotrackable.py +++ b/tensorflow/python/trackable/autotrackable.py @@ -118,7 +118,7 @@ class AutoTrackable(base.Trackable): # (e.g. captured variables). Make sure we return those too. children = {} for name, child in self._checkpoint_dependencies: - if isinstance(child, (core_types.GenericFunction, + if isinstance(child, (core_types.PolymorphicFunction, core_types.ConcreteFunction)): # Skip "tracked" functions for now since there may be objects that # automatically track functions that should not be saved. diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py index 5e09162d4cb721feeecb571e0c3d8242f7f559c8..c2159d5a85c78f4c2b004424ccb8a86211f1ba96 100644 --- a/tensorflow/python/types/core.py +++ b/tensorflow/python/types/core.py @@ -183,9 +183,8 @@ class ConcreteFunction(Callable, metaclass=abc.ABCMeta): """Returns the original `AtomicFunction` owned by this ConcreteFunction.""" -# TODO(mdan): Name just `types.Function`, for historic continuity? -@tf_export("types.experimental.GenericFunction", v1=[]) -class GenericFunction(Callable, metaclass=abc.ABCMeta): +@tf_export("types.experimental.PolymorphicFunction", v1=[]) +class PolymorphicFunction(Callable, metaclass=abc.ABCMeta): """Base class for polymorphic graph functions. Graph functions are Python callable objects that dispatch calls to a @@ -365,6 +364,13 @@ class GenericFunction(Callable, metaclass=abc.ABCMeta): pass +# TODO(b/297237997): Delete this once all usages are removed. +@tf_export("types.experimental.GenericFunction", v1=[]) +class GenericFunction(PolymorphicFunction): + """Please use tf.types.experimental.PolymorphicFunction instead.""" + pass + + @runtime_checkable class TensorProtocol(Protocol): """Protocol type for objects that can be converted to Tensor.""" diff --git a/tensorflow/python/types/trace.py b/tensorflow/python/types/trace.py index cac45067a18656f30d2b8bb1c536f821ac9b644b..e47a0c0f0536019e5f3a7afd4b93ad2957c0d09b 100644 --- a/tensorflow/python/types/trace.py +++ b/tensorflow/python/types/trace.py @@ -14,9 +14,9 @@ # ============================================================================== """tf.function tracing types. -See `core.GenericFunction` and `core.ConcreteFunction`. +See `core.PolymorphicFunction` and `core.ConcreteFunction`. -`GenericFunction` assigns types to call arguments, forming a signature. +`PolymorphicFunction` assigns types to call arguments, forming a signature. Function signatures are used to match arguments to `ConcreteFunction`s. For example, when a new `ConcreteFunction` is traced, it is assigned a the signature of the arguments it was traced with. Subsequent call arguments diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.function.-function.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.function.-function.pbtxt index 9ad1541c05b72266b99e5d96282e0a7704ccf684..8b9e7df373dd1b4102f7687e653cb1c687266fd8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.function.-function.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.function.-function.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.__internal__.function.Function" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.-generic-function.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.-generic-function.pbtxt index 8857721f2a5aa49c5b9f0f93fb994f394a2a7b63..d8764b58a58073577521f2d4a2bac09b731c7625 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.-generic-function.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.-generic-function.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.types.experimental.GenericFunction" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.-polymorphic-function.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.-polymorphic-function.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..966ee0dc5e481cd03606370dba6455831b07f4ec --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.-polymorphic-function.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.types.experimental.PolymorphicFunction" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "function_type" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "experimental_get_compiler_ir" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "get_concrete_function" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.pbtxt index 64094dea498c3188d22a72367679278a808ab01e..ec0460a30412a5601a63d3e54a2d5a1c0eea28bb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.types.experimental.pbtxt @@ -20,6 +20,10 @@ tf_module { name: "GenericFunction" mtype: "" } + member { + name: "PolymorphicFunction" + mtype: "" + } member { name: "SupportsTracingProtocol" mtype: ""