diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index da19acad24372447a0bd21ef99189d5cf82d2da7..04865a72d3dbe00993c65f8b29b301be3f049a37 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -13,7 +13,52 @@ # limitations under the License. # ============================================================================== # pylint: disable=unidiomatic-typecheck -"""Prototype decorator for defining graph functions with eager semantics.""" +"""API for defining graph functions with some additional eager semantics. + +def_function.function wraps the function concept in function.py ("defun") to +allow initializing `tf.Variable`s with subgraphs of the function. For example: + +```python +class M(tf.Module): + def __init__(self): + self.v_opinit = None + self.v_arginit = None + + @tf.function + def __call__(self, x): + # Variables are only created on the first call to the function. This is a + # common pattern in layer libraries. + if self.v_opinit is None: + # self.v_opinit will outlive the function call, but `tf.ones` is traced as + # part of the function body before the `tf.Variable` object is + # created. This subgraph is easy to lift out of the function. + self.v_opinit = tf.Variable(tf.ones([])) + + # If arguments feed into variable initialization, it can be very tricky to + # disentangle from the rest of the function. We don't attempt it. + self.v_arginit = tf.Variable(tf.ones(tf.shape(x)) * tf.constant(2.)) + return self.v_opinit + self.v_arginit + x +``` + +These patterns with "defun" throw an error asking the user to put the variable's +initializer in a lambda. With tf.function they work with eager semantics either +by lifting the subgraph out of the function and using it to initialize the +variable, or by initializing variables on the first call to the function (if +they weren't already initialized by something else, e.g. a checkpoint API). The +latter requires tf.conds, and is not well supported by TF-XLA, so we only do it +when necessary. + +Since these patterns are relatively common in layer libraries, we expose the +wrapper in this file as `tf.function`. The function concept in function.py is an +internal implementation detail. + +In order to support these variable initialization patterns, tf.function defines +a variable subtype (UnliftedInitializerVariable) which collects the input +subgraph. This type of variable replaces the regular variable type on the first +tf.function trace. To exclude initializers from the function body (the `tf.ones` +ops above and associated assignment operations), tf.function traces a second +time if it sees variables on the first call. +""" from __future__ import absolute_import from __future__ import division