提交 93439a55 编写于 作者: A Anna R 提交者: TensorFlower Gardener

Use "in symbol.__dict__" instead of "hasattr" to check if a symbol has api

names set. The former would behave correctly for subclasses.

Also, moving get_v1_names|constants and get_v2_names|constants functions to tf_export.py to reduce code duplication.

PiperOrigin-RevId: 225063242
上级 f9dbe986
......@@ -147,6 +147,94 @@ def get_canonical_name(api_names, deprecated_api_names):
return None
def get_v1_names(symbol):
"""Get a list of TF 1.* names for this symbol.
Args:
symbol: symbol to get API names for.
Returns:
List of all API names for this symbol including TensorFlow and
Estimator names.
"""
names_v1 = []
tensorflow_api_attr_v1 = API_ATTRS_V1[TENSORFLOW_API_NAME].names
estimator_api_attr_v1 = API_ATTRS_V1[ESTIMATOR_API_NAME].names
if not hasattr(symbol, tensorflow_api_attr_v1):
return names_v1
if tensorflow_api_attr_v1 in symbol.__dict__:
names_v1.extend(getattr(symbol, tensorflow_api_attr_v1))
if estimator_api_attr_v1 in symbol.__dict__:
names_v1.extend(getattr(symbol, estimator_api_attr_v1))
return names_v1
def get_v2_names(symbol):
"""Get a list of TF 2.0 names for this symbol.
Args:
symbol: symbol to get API names for.
Returns:
List of all API names for this symbol including TensorFlow and
Estimator names.
"""
names_v2 = []
tensorflow_api_attr = API_ATTRS[TENSORFLOW_API_NAME].names
estimator_api_attr = API_ATTRS[ESTIMATOR_API_NAME].names
if not hasattr(symbol, tensorflow_api_attr):
return names_v2
if tensorflow_api_attr in symbol.__dict__:
names_v2.extend(getattr(symbol, tensorflow_api_attr))
if estimator_api_attr in symbol.__dict__:
names_v2.extend(getattr(symbol, estimator_api_attr))
return names_v2
def get_v1_constants(module):
"""Get a list of TF 1.* constants in this module.
Args:
module: TensorFlow module.
Returns:
List of all API constants under the given module including TensorFlow and
Estimator constants.
"""
constants_v1 = []
tensorflow_constants_attr_v1 = API_ATTRS_V1[TENSORFLOW_API_NAME].constants
estimator_constants_attr_v1 = API_ATTRS_V1[ESTIMATOR_API_NAME].constants
if hasattr(module, tensorflow_constants_attr_v1):
constants_v1.extend(getattr(module, tensorflow_constants_attr_v1))
if hasattr(module, estimator_constants_attr_v1):
constants_v1.extend(getattr(module, estimator_constants_attr_v1))
return constants_v1
def get_v2_constants(module):
"""Get a list of TF 2.0 constants in this module.
Args:
module: TensorFlow module.
Returns:
List of all API constants under the given module including TensorFlow and
Estimator constants.
"""
constants_v2 = []
tensorflow_constants_attr = API_ATTRS[TENSORFLOW_API_NAME].constants
estimator_constants_attr = API_ATTRS[ESTIMATOR_API_NAME].constants
if hasattr(module, tensorflow_constants_attr):
constants_v2.extend(getattr(module, tensorflow_constants_attr))
if hasattr(module, estimator_constants_attr):
constants_v2.extend(getattr(module, estimator_constants_attr))
return constants_v2
class api_export(object): # pylint: disable=invalid-name
"""Provides ways to export symbols to the TensorFlow API."""
......
......@@ -62,6 +62,10 @@ class ValidateExportTest(test.TestCase):
del symbol._tf_api_names
if hasattr(symbol, '_tf_api_names_v1'):
del symbol._tf_api_names_v1
if hasattr(symbol, '_estimator_api_names'):
del symbol._estimator_api_names
if hasattr(symbol, '_estimator_api_names_v1'):
del symbol._estimator_api_names_v1
def _CreateMockModule(self, name):
mock_module = self.MockModule(name)
......@@ -74,6 +78,10 @@ class ValidateExportTest(test.TestCase):
decorated_function = export_decorator(_test_function)
self.assertEquals(decorated_function, _test_function)
self.assertEquals(('nameA', 'nameB'), decorated_function._tf_api_names)
self.assertEquals(['nameA', 'nameB'],
tf_export.get_v1_names(decorated_function))
self.assertEquals(['nameA', 'nameB'],
tf_export.get_v2_names(decorated_function))
def testExportMultipleFunctions(self):
export_decorator1 = tf_export.tf_export('nameA', 'nameB')
......@@ -95,6 +103,22 @@ class ValidateExportTest(test.TestCase):
export_decorator_b(TestClassB)
self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
self.assertEquals(('TestClassB1',), TestClassB._tf_api_names)
self.assertEquals(['TestClassA1'], tf_export.get_v1_names(TestClassA))
self.assertEquals(['TestClassB1'], tf_export.get_v1_names(TestClassB))
def testExportClassInEstimator(self):
export_decorator_a = tf_export.tf_export('TestClassA1')
export_decorator_a(TestClassA)
self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
export_decorator_b = tf_export.estimator_export(
'estimator.TestClassB1')
export_decorator_b(TestClassB)
self.assertTrue('_tf_api_names' not in TestClassB.__dict__)
self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
self.assertEquals(['TestClassA1'], tf_export.get_v1_names(TestClassA))
self.assertEquals(['estimator.TestClassB1'],
tf_export.get_v1_names(TestClassB))
def testExportSingleConstant(self):
module1 = self._CreateMockModule('module1')
......@@ -103,6 +127,10 @@ class ValidateExportTest(test.TestCase):
export_decorator.export_constant('module1', 'test_constant')
self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
module1._tf_api_constants)
self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
tf_export.get_v1_constants(module1))
self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
tf_export.get_v2_constants(module1))
def testExportMultipleConstants(self):
module1 = self._CreateMockModule('module1')
......
......@@ -37,32 +37,6 @@ from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade_v2
_TENSORFLOW_API_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)
_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
_ESTIMATOR_API_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names)
_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names
def get_v1_names(symbol):
names_v1 = []
if hasattr(symbol, _TENSORFLOW_API_ATTR_V1):
names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1))
if hasattr(symbol, _ESTIMATOR_API_ATTR_V1):
names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1))
return names_v1
def get_v2_names(symbol):
names_v2 = set()
if hasattr(symbol, _TENSORFLOW_API_ATTR):
names_v2.update(getattr(symbol, _TENSORFLOW_API_ATTR))
if hasattr(symbol, _ESTIMATOR_API_ATTR):
names_v2.update(getattr(symbol, _ESTIMATOR_API_ATTR))
return list(names_v2)
def get_symbol_for_name(root, name):
name_parts = name.split(".")
symbol = root
......@@ -118,7 +92,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def symbol_collector(unused_path, unused_parent, children):
for child in children:
_, attr = tf_decorator.unwrap(child[1])
api_names_v2 = get_v2_names(attr)
api_names_v2 = tf_export.get_v2_names(attr)
for name in api_names_v2:
cls.v2_symbols["tf." + name] = attr
......@@ -166,7 +140,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def conversion_visitor(unused_path, unused_parent, children):
for child in children:
_, attr = tf_decorator.unwrap(child[1])
api_names = get_v1_names(attr)
api_names = tf_export.get_v1_names(attr)
for name in api_names:
_, _, _, text = self._upgrade("tf." + name)
if (text and
......@@ -190,7 +164,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def conversion_visitor(unused_path, unused_parent, children):
for child in children:
_, attr = tf_decorator.unwrap(child[1])
api_names = get_v1_names(attr)
api_names = tf_export.get_v1_names(attr)
for name in api_names:
if collect:
v1_symbols.add("tf." + name)
......@@ -219,7 +193,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def arg_test_visitor(unused_path, unused_parent, children):
for child in children:
_, attr = tf_decorator.unwrap(child[1])
names_v1 = get_v1_names(attr)
names_v1 = tf_export.get_v1_names(attr)
for name in names_v1:
name = "tf.%s" % name
......@@ -270,7 +244,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_, attr = tf_decorator.unwrap(child[1])
if not tf_inspect.isfunction(attr):
continue
names_v1 = get_v1_names(attr)
names_v1 = tf_export.get_v1_names(attr)
arg_names_v1 = get_args(attr)
for name in names_v1:
......@@ -340,7 +314,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
# get other names for this function
attr = get_symbol_for_name(tf.compat.v1, name)
_, attr = tf_decorator.unwrap(attr)
v1_names = get_v1_names(attr)
v1_names = tf_export.get_v1_names(attr)
self.assertTrue(v1_names)
v1_names = ["tf.%s" % n for n in v1_names]
# check if any other name is in
......
......@@ -64,58 +64,6 @@ from __future__ import print_function
"""
_TENSORFLOW_API_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)
_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
_TENSORFLOW_CONSTANTS_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].constants)
_TENSORFLOW_CONSTANTS_ATTR = (
tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].constants)
_ESTIMATOR_API_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names)
_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names
_ESTIMATOR_CONSTANTS_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].constants)
_ESTIMATOR_CONSTANTS_ATTR = (
tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].constants)
def get_v1_names(symbol):
names_v1 = []
if hasattr(symbol, _TENSORFLOW_API_ATTR_V1):
names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1))
if hasattr(symbol, _ESTIMATOR_API_ATTR_V1):
names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1))
return names_v1
def get_v2_names(symbol):
names_v2 = []
if hasattr(symbol, _TENSORFLOW_API_ATTR):
names_v2.extend(getattr(symbol, _TENSORFLOW_API_ATTR))
if hasattr(symbol, _ESTIMATOR_API_ATTR):
names_v2.extend(getattr(symbol, _ESTIMATOR_API_ATTR))
return list(names_v2)
def get_v1_constants(module):
constants_v1 = []
if hasattr(module, _TENSORFLOW_CONSTANTS_ATTR_V1):
constants_v1.extend(getattr(module, _TENSORFLOW_CONSTANTS_ATTR_V1))
if hasattr(module, _ESTIMATOR_CONSTANTS_ATTR_V1):
constants_v1.extend(getattr(module, _ESTIMATOR_CONSTANTS_ATTR_V1))
return constants_v1
def get_v2_constants(module):
constants_v2 = []
if hasattr(module, _TENSORFLOW_CONSTANTS_ATTR):
constants_v2.extend(getattr(module, _TENSORFLOW_CONSTANTS_ATTR))
if hasattr(module, _ESTIMATOR_CONSTANTS_ATTR):
constants_v2.extend(getattr(module, _ESTIMATOR_CONSTANTS_ATTR))
return constants_v2
def get_canonical_name(v2_names, v1_name):
if v2_names:
......@@ -131,7 +79,7 @@ def get_all_v2_names():
"""Visitor that collects TF 2.0 names."""
for child in children:
_, attr = tf_decorator.unwrap(child[1])
api_names_v2 = get_v2_names(attr)
api_names_v2 = tf_export.get_v2_names(attr)
for name in api_names_v2:
v2_names.add(name)
......@@ -149,8 +97,8 @@ def collect_constant_renames():
"""
renames = set()
for module in sys.modules.values():
constants_v1_list = get_v1_constants(module)
constants_v2_list = get_v2_constants(module)
constants_v1_list = tf_export.get_v1_constants(module)
constants_v2_list = tf_export.get_v2_constants(module)
# _tf_api_constants attribute contains a list of tuples:
# (api_names_list, constant_name)
......@@ -186,8 +134,8 @@ def collect_function_renames():
"""Visitor that collects rename strings to add to rename_line_set."""
for child in children:
_, attr = tf_decorator.unwrap(child[1])
api_names_v1 = get_v1_names(attr)
api_names_v2 = get_v2_names(attr)
api_names_v1 = tf_export.get_v1_names(attr)
api_names_v2 = tf_export.get_v2_names(attr)
deprecated_api_names = set(api_names_v1) - set(api_names_v2)
for name in deprecated_api_names:
renames.add((name, get_canonical_name(api_names_v2, name)))
......
......@@ -64,40 +64,6 @@ from __future__ import print_function
"""
_TENSORFLOW_API_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)
_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
_TENSORFLOW_CONSTANTS_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].constants)
_TENSORFLOW_CONSTANTS_ATTR = (
tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].constants)
_ESTIMATOR_API_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names)
_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names
_ESTIMATOR_CONSTANTS_ATTR_V1 = (
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].constants)
_ESTIMATOR_CONSTANTS_ATTR = (
tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].constants)
def get_v1_names(symbol):
names_v1 = []
if hasattr(symbol, _TENSORFLOW_API_ATTR_V1):
names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1))
if hasattr(symbol, _ESTIMATOR_API_ATTR_V1):
names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1))
return names_v1
def get_v2_names(symbol):
names_v2 = []
if hasattr(symbol, _TENSORFLOW_API_ATTR):
names_v2.extend(getattr(symbol, _TENSORFLOW_API_ATTR))
if hasattr(symbol, _ESTIMATOR_API_ATTR):
names_v2.extend(getattr(symbol, _ESTIMATOR_API_ATTR))
return list(names_v2)
def collect_function_arg_names(function_names):
"""Determines argument names for reordered function signatures.
......@@ -115,7 +81,7 @@ def collect_function_arg_names(function_names):
"""Visitor that collects arguments for reordered functions."""
for child in children:
_, attr = tf_decorator.unwrap(child[1])
api_names_v1 = get_v1_names(attr)
api_names_v1 = tf_export.get_v1_names(attr)
api_names_v1 = ['tf.%s' % name for name in api_names_v1]
matches_function_names = any(
name in function_names for name in api_names_v1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册