From 93439a553937e77e8877a149d13039960da59abf Mon Sep 17 00:00:00 2001 From: Anna R Date: Tue, 11 Dec 2018 13:44:11 -0800 Subject: [PATCH] Use "in symbol.__dict__" instead of "hasattr" to check if a symbol has api names set. The former would behave correctly for subclasses. Also, moving get_v1_names|constants and get_v2_names|constants functions to tf_export.py to reduce code duplication. PiperOrigin-RevId: 225063242 --- tensorflow/python/util/tf_export.py | 88 +++++++++++++++++++ tensorflow/python/util/tf_export_test.py | 28 ++++++ .../tools/compatibility/tf_upgrade_v2_test.py | 38 ++------ .../update/generate_v2_renames_map.py | 62 ++----------- .../update/generate_v2_reorders_map.py | 36 +------- 5 files changed, 128 insertions(+), 124 deletions(-) diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index ec70cae7d2f..74afc3746fb 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -147,6 +147,94 @@ def get_canonical_name(api_names, deprecated_api_names): return None +def get_v1_names(symbol): + """Get a list of TF 1.* names for this symbol. + + Args: + symbol: symbol to get API names for. + + Returns: + List of all API names for this symbol including TensorFlow and + Estimator names. + """ + names_v1 = [] + tensorflow_api_attr_v1 = API_ATTRS_V1[TENSORFLOW_API_NAME].names + estimator_api_attr_v1 = API_ATTRS_V1[ESTIMATOR_API_NAME].names + + if not hasattr(symbol, tensorflow_api_attr_v1): + return names_v1 + if tensorflow_api_attr_v1 in symbol.__dict__: + names_v1.extend(getattr(symbol, tensorflow_api_attr_v1)) + if estimator_api_attr_v1 in symbol.__dict__: + names_v1.extend(getattr(symbol, estimator_api_attr_v1)) + return names_v1 + + +def get_v2_names(symbol): + """Get a list of TF 2.0 names for this symbol. + + Args: + symbol: symbol to get API names for. + + Returns: + List of all API names for this symbol including TensorFlow and + Estimator names. + """ + names_v2 = [] + tensorflow_api_attr = API_ATTRS[TENSORFLOW_API_NAME].names + estimator_api_attr = API_ATTRS[ESTIMATOR_API_NAME].names + + if not hasattr(symbol, tensorflow_api_attr): + return names_v2 + if tensorflow_api_attr in symbol.__dict__: + names_v2.extend(getattr(symbol, tensorflow_api_attr)) + if estimator_api_attr in symbol.__dict__: + names_v2.extend(getattr(symbol, estimator_api_attr)) + return names_v2 + + +def get_v1_constants(module): + """Get a list of TF 1.* constants in this module. + + Args: + module: TensorFlow module. + + Returns: + List of all API constants under the given module including TensorFlow and + Estimator constants. + """ + constants_v1 = [] + tensorflow_constants_attr_v1 = API_ATTRS_V1[TENSORFLOW_API_NAME].constants + estimator_constants_attr_v1 = API_ATTRS_V1[ESTIMATOR_API_NAME].constants + + if hasattr(module, tensorflow_constants_attr_v1): + constants_v1.extend(getattr(module, tensorflow_constants_attr_v1)) + if hasattr(module, estimator_constants_attr_v1): + constants_v1.extend(getattr(module, estimator_constants_attr_v1)) + return constants_v1 + + +def get_v2_constants(module): + """Get a list of TF 2.0 constants in this module. + + Args: + module: TensorFlow module. + + Returns: + List of all API constants under the given module including TensorFlow and + Estimator constants. + """ + constants_v2 = [] + tensorflow_constants_attr = API_ATTRS[TENSORFLOW_API_NAME].constants + estimator_constants_attr = API_ATTRS[ESTIMATOR_API_NAME].constants + + if hasattr(module, tensorflow_constants_attr): + constants_v2.extend(getattr(module, tensorflow_constants_attr)) + if hasattr(module, estimator_constants_attr): + constants_v2.extend(getattr(module, estimator_constants_attr)) + return constants_v2 + + class api_export(object): # pylint: disable=invalid-name """Provides ways to export symbols to the TensorFlow API.""" diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index a0fac8bf362..20625792e9b 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -62,6 +62,10 @@ class ValidateExportTest(test.TestCase): del symbol._tf_api_names if hasattr(symbol, '_tf_api_names_v1'): del symbol._tf_api_names_v1 + if hasattr(symbol, '_estimator_api_names'): + del symbol._estimator_api_names + if hasattr(symbol, '_estimator_api_names_v1'): + del symbol._estimator_api_names_v1 def _CreateMockModule(self, name): mock_module = self.MockModule(name) @@ -74,6 +78,10 @@ class ValidateExportTest(test.TestCase): decorated_function = export_decorator(_test_function) self.assertEquals(decorated_function, _test_function) self.assertEquals(('nameA', 'nameB'), decorated_function._tf_api_names) + self.assertEquals(['nameA', 'nameB'], + tf_export.get_v1_names(decorated_function)) + self.assertEquals(['nameA', 'nameB'], + tf_export.get_v2_names(decorated_function)) def testExportMultipleFunctions(self): export_decorator1 = tf_export.tf_export('nameA', 'nameB') @@ -95,6 +103,22 @@ class ValidateExportTest(test.TestCase): export_decorator_b(TestClassB) self.assertEquals(('TestClassA1',), TestClassA._tf_api_names) self.assertEquals(('TestClassB1',), TestClassB._tf_api_names) + self.assertEquals(['TestClassA1'], tf_export.get_v1_names(TestClassA)) + self.assertEquals(['TestClassB1'], tf_export.get_v1_names(TestClassB)) + + def testExportClassInEstimator(self): + export_decorator_a = tf_export.tf_export('TestClassA1') + export_decorator_a(TestClassA) + self.assertEquals(('TestClassA1',), TestClassA._tf_api_names) + + export_decorator_b = tf_export.estimator_export( + 'estimator.TestClassB1') + export_decorator_b(TestClassB) + self.assertTrue('_tf_api_names' not in TestClassB.__dict__) + self.assertEquals(('TestClassA1',), TestClassA._tf_api_names) + self.assertEquals(['TestClassA1'], tf_export.get_v1_names(TestClassA)) + self.assertEquals(['estimator.TestClassB1'], + tf_export.get_v1_names(TestClassB)) def testExportSingleConstant(self): module1 = self._CreateMockModule('module1') @@ -103,6 +127,10 @@ class ValidateExportTest(test.TestCase): export_decorator.export_constant('module1', 'test_constant') self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')], module1._tf_api_constants) + self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')], + tf_export.get_v1_constants(module1)) + self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')], + tf_export.get_v2_constants(module1)) def testExportMultipleConstants(self): module1 = self._CreateMockModule('module1') diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 0fc7a187342..2cc874fe7f5 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -37,32 +37,6 @@ from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import tf_upgrade_v2 -_TENSORFLOW_API_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names) -_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names -_ESTIMATOR_API_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names) -_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names - - -def get_v1_names(symbol): - names_v1 = [] - if hasattr(symbol, _TENSORFLOW_API_ATTR_V1): - names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1)) - if hasattr(symbol, _ESTIMATOR_API_ATTR_V1): - names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1)) - return names_v1 - - -def get_v2_names(symbol): - names_v2 = set() - if hasattr(symbol, _TENSORFLOW_API_ATTR): - names_v2.update(getattr(symbol, _TENSORFLOW_API_ATTR)) - if hasattr(symbol, _ESTIMATOR_API_ATTR): - names_v2.update(getattr(symbol, _ESTIMATOR_API_ATTR)) - return list(names_v2) - - def get_symbol_for_name(root, name): name_parts = name.split(".") symbol = root @@ -118,7 +92,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def symbol_collector(unused_path, unused_parent, children): for child in children: _, attr = tf_decorator.unwrap(child[1]) - api_names_v2 = get_v2_names(attr) + api_names_v2 = tf_export.get_v2_names(attr) for name in api_names_v2: cls.v2_symbols["tf." + name] = attr @@ -166,7 +140,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def conversion_visitor(unused_path, unused_parent, children): for child in children: _, attr = tf_decorator.unwrap(child[1]) - api_names = get_v1_names(attr) + api_names = tf_export.get_v1_names(attr) for name in api_names: _, _, _, text = self._upgrade("tf." + name) if (text and @@ -190,7 +164,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def conversion_visitor(unused_path, unused_parent, children): for child in children: _, attr = tf_decorator.unwrap(child[1]) - api_names = get_v1_names(attr) + api_names = tf_export.get_v1_names(attr) for name in api_names: if collect: v1_symbols.add("tf." + name) @@ -219,7 +193,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def arg_test_visitor(unused_path, unused_parent, children): for child in children: _, attr = tf_decorator.unwrap(child[1]) - names_v1 = get_v1_names(attr) + names_v1 = tf_export.get_v1_names(attr) for name in names_v1: name = "tf.%s" % name @@ -270,7 +244,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): _, attr = tf_decorator.unwrap(child[1]) if not tf_inspect.isfunction(attr): continue - names_v1 = get_v1_names(attr) + names_v1 = tf_export.get_v1_names(attr) arg_names_v1 = get_args(attr) for name in names_v1: @@ -340,7 +314,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map # get other names for this function attr = get_symbol_for_name(tf.compat.v1, name) _, attr = tf_decorator.unwrap(attr) - v1_names = get_v1_names(attr) + v1_names = tf_export.get_v1_names(attr) self.assertTrue(v1_names) v1_names = ["tf.%s" % n for n in v1_names] # check if any other name is in diff --git a/tensorflow/tools/compatibility/update/generate_v2_renames_map.py b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py index 19ad6c3a2a5..a2c5e7cf82d 100644 --- a/tensorflow/tools/compatibility/update/generate_v2_renames_map.py +++ b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py @@ -64,58 +64,6 @@ from __future__ import print_function """ -_TENSORFLOW_API_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names) -_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names -_TENSORFLOW_CONSTANTS_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].constants) -_TENSORFLOW_CONSTANTS_ATTR = ( - tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].constants) - -_ESTIMATOR_API_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names) -_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names -_ESTIMATOR_CONSTANTS_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].constants) -_ESTIMATOR_CONSTANTS_ATTR = ( - tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].constants) - - -def get_v1_names(symbol): - names_v1 = [] - if hasattr(symbol, _TENSORFLOW_API_ATTR_V1): - names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1)) - if hasattr(symbol, _ESTIMATOR_API_ATTR_V1): - names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1)) - return names_v1 - - -def get_v2_names(symbol): - names_v2 = [] - if hasattr(symbol, _TENSORFLOW_API_ATTR): - names_v2.extend(getattr(symbol, _TENSORFLOW_API_ATTR)) - if hasattr(symbol, _ESTIMATOR_API_ATTR): - names_v2.extend(getattr(symbol, _ESTIMATOR_API_ATTR)) - return list(names_v2) - - -def get_v1_constants(module): - constants_v1 = [] - if hasattr(module, _TENSORFLOW_CONSTANTS_ATTR_V1): - constants_v1.extend(getattr(module, _TENSORFLOW_CONSTANTS_ATTR_V1)) - if hasattr(module, _ESTIMATOR_CONSTANTS_ATTR_V1): - constants_v1.extend(getattr(module, _ESTIMATOR_CONSTANTS_ATTR_V1)) - return constants_v1 - - -def get_v2_constants(module): - constants_v2 = [] - if hasattr(module, _TENSORFLOW_CONSTANTS_ATTR): - constants_v2.extend(getattr(module, _TENSORFLOW_CONSTANTS_ATTR)) - if hasattr(module, _ESTIMATOR_CONSTANTS_ATTR): - constants_v2.extend(getattr(module, _ESTIMATOR_CONSTANTS_ATTR)) - return constants_v2 - def get_canonical_name(v2_names, v1_name): if v2_names: @@ -131,7 +79,7 @@ def get_all_v2_names(): """Visitor that collects TF 2.0 names.""" for child in children: _, attr = tf_decorator.unwrap(child[1]) - api_names_v2 = get_v2_names(attr) + api_names_v2 = tf_export.get_v2_names(attr) for name in api_names_v2: v2_names.add(name) @@ -149,8 +97,8 @@ def collect_constant_renames(): """ renames = set() for module in sys.modules.values(): - constants_v1_list = get_v1_constants(module) - constants_v2_list = get_v2_constants(module) + constants_v1_list = tf_export.get_v1_constants(module) + constants_v2_list = tf_export.get_v2_constants(module) # _tf_api_constants attribute contains a list of tuples: # (api_names_list, constant_name) @@ -186,8 +134,8 @@ def collect_function_renames(): """Visitor that collects rename strings to add to rename_line_set.""" for child in children: _, attr = tf_decorator.unwrap(child[1]) - api_names_v1 = get_v1_names(attr) - api_names_v2 = get_v2_names(attr) + api_names_v1 = tf_export.get_v1_names(attr) + api_names_v2 = tf_export.get_v2_names(attr) deprecated_api_names = set(api_names_v1) - set(api_names_v2) for name in deprecated_api_names: renames.add((name, get_canonical_name(api_names_v2, name))) diff --git a/tensorflow/tools/compatibility/update/generate_v2_reorders_map.py b/tensorflow/tools/compatibility/update/generate_v2_reorders_map.py index 63541771bf3..0eb942d3961 100644 --- a/tensorflow/tools/compatibility/update/generate_v2_reorders_map.py +++ b/tensorflow/tools/compatibility/update/generate_v2_reorders_map.py @@ -64,40 +64,6 @@ from __future__ import print_function """ -_TENSORFLOW_API_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names) -_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names -_TENSORFLOW_CONSTANTS_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].constants) -_TENSORFLOW_CONSTANTS_ATTR = ( - tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].constants) - -_ESTIMATOR_API_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names) -_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names -_ESTIMATOR_CONSTANTS_ATTR_V1 = ( - tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].constants) -_ESTIMATOR_CONSTANTS_ATTR = ( - tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].constants) - - -def get_v1_names(symbol): - names_v1 = [] - if hasattr(symbol, _TENSORFLOW_API_ATTR_V1): - names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1)) - if hasattr(symbol, _ESTIMATOR_API_ATTR_V1): - names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1)) - return names_v1 - - -def get_v2_names(symbol): - names_v2 = [] - if hasattr(symbol, _TENSORFLOW_API_ATTR): - names_v2.extend(getattr(symbol, _TENSORFLOW_API_ATTR)) - if hasattr(symbol, _ESTIMATOR_API_ATTR): - names_v2.extend(getattr(symbol, _ESTIMATOR_API_ATTR)) - return list(names_v2) - def collect_function_arg_names(function_names): """Determines argument names for reordered function signatures. @@ -115,7 +81,7 @@ def collect_function_arg_names(function_names): """Visitor that collects arguments for reordered functions.""" for child in children: _, attr = tf_decorator.unwrap(child[1]) - api_names_v1 = get_v1_names(attr) + api_names_v1 = tf_export.get_v1_names(attr) api_names_v1 = ['tf.%s' % name for name in api_names_v1] matches_function_names = any( name in function_names for name in api_names_v1) -- GitLab