提交 61c7cbca 编写于 作者: M Martin Wicke 提交者: TensorFlower Gardener

Add functions to switch between 1.x and 2.x global behavior.

PiperOrigin-RevId: 223595880
上级 f30d3d01
......@@ -35,8 +35,9 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: di
if _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Calls to enable and disable features.
enable_eager_execution() # pylint: disable=undefined-variable
# Enable TF2 behaviors
from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
......
......@@ -9,7 +9,10 @@ py_library(
srcs = ["compat.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/python:util"],
deps = [
"//tensorflow/python:tf2",
"//tensorflow/python:util",
],
)
tf_py_test(
......
......@@ -23,6 +23,12 @@ from __future__ import division
from __future__ import print_function
import datetime
from tensorflow.python import tf2
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
......@@ -132,3 +138,40 @@ def forward_compatibility_horizon(year, month, day):
yield
finally:
_FORWARD_COMPATIBILITY_HORIZON = old_compat_date
@tf_export(v1=["enable_v2_behavior"])
def enable_v2_behavior():
"""Enables TensorFlow 2.x behaviors.
This function can be called at the beginning of the program (before `Tensors`,
`Graphs` or other structures have been created, and before devices have been
initialized. It switches all global behaviors that are different between
TensorFlow 1.x and 2.x to behave as intended for 2.x.
This function is called in the main TensorFlow `__init__.py` file, user should
not need to call it, except during complex migrations.
"""
tf2.enable() # Switches TensorArrayV2 and control flow V2
ops.enable_eager_execution()
tensor_shape.enable_v2_tensorshape() # Also switched by tf2
variable_scope.enable_resource_variables()
@tf_export(v1=["disable_v2_behavior"])
def disable_v2_behavior():
"""Enables TensorFlow 2.x behaviors.
This function can be called at the beginning of the program (before `Tensors`,
`Graphs` or other structures have been created, and before devices have been
initialized. It switches all global behaviors that are different between
TensorFlow 1.x and 2.x to behave as intended for 1.x.
User can call this function to disable 2.x behavior during complex migrations.
"""
tf2.disable() # Switches TensorArrayV2 and control flow V2
ops.disable_eager_execution()
tensor_shape.disable_v2_tensorshape() # Also switched by tf2
variable_scope.disable_resource_variables()
......@@ -5393,7 +5393,7 @@ def inside_function():
return get_default_graph().building_function
@tf_export("enable_eager_execution")
@tf_export(v1=["enable_eager_execution"])
def enable_eager_execution(config=None,
device_policy=None,
execution_mode=None):
......@@ -5464,6 +5464,17 @@ def enable_eager_execution(config=None,
server_def=None)
@tf_export(v1=["disable_eager_execution"])
def disable_eager_execution():
"""Disables eager execution.
This function can only be called before any Graphs, Ops, or Tensors have been
created. It can be used at the beginning of the program for complex migration
projects from TensorFlow 1.x to 2.x.
"""
context.default_execution_mode = context.GRAPH_MODE
def enable_eager_execution_internal(config=None,
device_policy=None,
execution_mode=None,
......@@ -5471,6 +5482,7 @@ def enable_eager_execution_internal(config=None,
"""Enables eager execution for the lifetime of this program.
Most of the doc string for enable_eager_execution is relevant here as well.
Args:
config: See enable_eager_execution doc string
device_policy: See enable_eager_execution doc string
......
......@@ -25,6 +25,21 @@ from __future__ import print_function
import os
_force_enable = False
def enable():
"""Enables v2 behaviors."""
global _force_enable
_force_enable = True
def disable():
"""Disables v2 behaviors (TF2_BEHAVIOR env variable is still respected)."""
global _force_enable
_force_enable = False
def enabled():
"""Returns True iff TensorFlow 2.0 behavior should be enabled."""
return os.getenv("TF2_BEHAVIOR") is not None
return _force_enable or os.getenv("TF2_BEHAVIOR") is not None
......@@ -1052,10 +1052,18 @@ tf_module {
name: "dimension_value"
argspec: "args=[\'dimension\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "disable_eager_execution"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "disable_resource_variables"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "disable_v2_behavior"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "disable_v2_tensorshape"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
......@@ -1096,6 +1104,10 @@ tf_module {
name: "enable_resource_variables"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "enable_v2_behavior"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "enable_v2_tensorshape"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
......
......@@ -608,10 +608,6 @@ tf_module {
name: "einsum"
argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
}
member_method {
name: "enable_eager_execution"
argspec: "args=[\'config\', \'device_policy\', \'execution_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "ensure_shape"
argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
......
......@@ -126,7 +126,9 @@ renames = {
'tf.digamma': 'tf.math.digamma',
'tf.dimension_at_index': 'tf.compat.v1.dimension_at_index',
'tf.dimension_value': 'tf.compat.v1.dimension_value',
'tf.disable_eager_execution': 'tf.compat.v1.disable_eager_execution',
'tf.disable_resource_variables': 'tf.compat.v1.disable_resource_variables',
'tf.disable_v2_behavior': 'tf.compat.v1.disable_v2_behavior',
'tf.disable_v2_tensorshape': 'tf.compat.v1.disable_v2_tensorshape',
'tf.distributions.Bernoulli': 'tf.compat.v1.distributions.Bernoulli',
'tf.distributions.Beta': 'tf.compat.v1.distributions.Beta',
......@@ -147,7 +149,9 @@ renames = {
'tf.distributions.Uniform': 'tf.compat.v1.distributions.Uniform',
'tf.distributions.kl_divergence': 'tf.compat.v1.distributions.kl_divergence',
'tf.div': 'tf.compat.v1.div',
'tf.enable_eager_execution': 'tf.compat.v1.enable_eager_execution',
'tf.enable_resource_variables': 'tf.compat.v1.enable_resource_variables',
'tf.enable_v2_behavior': 'tf.compat.v1.enable_v2_behavior',
'tf.enable_v2_tensorshape': 'tf.compat.v1.enable_v2_tensorshape',
'tf.encode_base64': 'tf.io.encode_base64',
'tf.erf': 'tf.math.erf',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册