From 8b5a12a14a42ea1805416a6b6982ec1624306213 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Mon, 25 Feb 2019 06:26:28 -0800 Subject: [PATCH] Make V1 and V2 VariableAggregation enums interchangeable. PiperOrigin-RevId: 235518682 --- .../python/distribute/mirrored_strategy.py | 10 ++++---- .../python/kernel_tests/variables_test.py | 24 +++++++++++++++++++ tensorflow/python/ops/variables.py | 14 +++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 96c7191652a..70fb302f109 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -220,8 +220,9 @@ def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: d # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. is_replica_local = False else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) + raise ValueError( + "Invalid variable synchronization mode: %s for variable: %s" % + (synchronization, kwargs["name"])) # Get aggregation value aggregation = kwargs.pop("aggregation", @@ -232,8 +233,9 @@ def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: d variable_scope.VariableAggregation.MEAN, variable_scope.VariableAggregation.ONLY_FIRST_REPLICA ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) + raise ValueError( + "Invalid variable aggregation mode: %s for variable: %s" % + (aggregation, kwargs["name"])) # Ignore user-specified caching device, not needed for mirrored variables. kwargs.pop("caching_device", None) diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 028ef11fc49..b3316b73ff6 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -828,5 +828,29 @@ class VariableContainerTest(test.TestCase): self.assertEqual(compat.as_bytes(""), v4.op.get_attr("container")) +class AggregationModesTest(test.TestCase): + + def testV1V2Equal(self): + v1 = variables.VariableAggregation + v2 = variables.VariableAggregationV2 + + self.assertEqual(v1.NONE, v2.NONE) + self.assertEqual(v1.SUM, v2.SUM) + self.assertEqual(v1.MEAN, v2.MEAN) + self.assertEqual(v1.ONLY_FIRST_REPLICA, v2.ONLY_FIRST_REPLICA) + self.assertEqual(v1.ONLY_FIRST_TOWER, v2.ONLY_FIRST_REPLICA) + + self.assertEqual(v2.NONE, v1.NONE) + self.assertEqual(v2.SUM, v1.SUM) + self.assertEqual(v2.MEAN, v1.MEAN) + self.assertEqual(v2.ONLY_FIRST_REPLICA, v1.ONLY_FIRST_REPLICA) + self.assertEqual(v2.ONLY_FIRST_REPLICA, v1.ONLY_FIRST_TOWER) + + self.assertEqual(hash(v1.NONE), hash(v2.NONE)) + self.assertEqual(hash(v1.SUM), hash(v2.SUM)) + self.assertEqual(hash(v1.MEAN), hash(v2.MEAN)) + self.assertEqual(hash(v1.ONLY_FIRST_REPLICA), hash(v2.ONLY_FIRST_REPLICA)) + self.assertEqual(hash(v1.ONLY_FIRST_TOWER), hash(v2.ONLY_FIRST_REPLICA)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 219ba7fbb2e..412300772b5 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -103,6 +103,17 @@ class VariableAggregationV2(enum.Enum): MEAN = 2 ONLY_FIRST_REPLICA = 3 + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + if self is other: + return True + elif isinstance(other, VariableAggregation): + return int(self.value) == int(other.value) + else: + return False + @tf_export(v1=["VariableAggregation"]) class VariableAggregation(enum.Enum): @@ -112,6 +123,9 @@ class VariableAggregation(enum.Enum): ONLY_FIRST_REPLICA = 3 ONLY_FIRST_TOWER = 3 # DEPRECATED + def __hash__(self): + return hash(self.value) + VariableAggregation.__doc__ = ( VariableAggregationV2.__doc__ + -- GitLab