提交 8b5a12a1 编写于 作者: T Tom Hennigan 提交者: TensorFlower Gardener

Make V1 and V2 VariableAggregation enums interchangeable.

PiperOrigin-RevId: 235518682
上级 2e1ea487
......@@ -220,8 +220,9 @@ def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: d
# `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
is_replica_local = False
else:
raise ValueError("Invalid variable synchronization mode: " +
synchronization + " for variable: " + kwargs["name"])
raise ValueError(
"Invalid variable synchronization mode: %s for variable: %s" %
(synchronization, kwargs["name"]))
# Get aggregation value
aggregation = kwargs.pop("aggregation",
......@@ -232,8 +233,9 @@ def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: d
variable_scope.VariableAggregation.MEAN,
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
raise ValueError(
"Invalid variable aggregation mode: %s for variable: %s" %
(aggregation, kwargs["name"]))
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
......
......@@ -828,5 +828,29 @@ class VariableContainerTest(test.TestCase):
self.assertEqual(compat.as_bytes(""), v4.op.get_attr("container"))
class AggregationModesTest(test.TestCase):
def testV1V2Equal(self):
v1 = variables.VariableAggregation
v2 = variables.VariableAggregationV2
self.assertEqual(v1.NONE, v2.NONE)
self.assertEqual(v1.SUM, v2.SUM)
self.assertEqual(v1.MEAN, v2.MEAN)
self.assertEqual(v1.ONLY_FIRST_REPLICA, v2.ONLY_FIRST_REPLICA)
self.assertEqual(v1.ONLY_FIRST_TOWER, v2.ONLY_FIRST_REPLICA)
self.assertEqual(v2.NONE, v1.NONE)
self.assertEqual(v2.SUM, v1.SUM)
self.assertEqual(v2.MEAN, v1.MEAN)
self.assertEqual(v2.ONLY_FIRST_REPLICA, v1.ONLY_FIRST_REPLICA)
self.assertEqual(v2.ONLY_FIRST_REPLICA, v1.ONLY_FIRST_TOWER)
self.assertEqual(hash(v1.NONE), hash(v2.NONE))
self.assertEqual(hash(v1.SUM), hash(v2.SUM))
self.assertEqual(hash(v1.MEAN), hash(v2.MEAN))
self.assertEqual(hash(v1.ONLY_FIRST_REPLICA), hash(v2.ONLY_FIRST_REPLICA))
self.assertEqual(hash(v1.ONLY_FIRST_TOWER), hash(v2.ONLY_FIRST_REPLICA))
if __name__ == "__main__":
test.main()
......@@ -103,6 +103,17 @@ class VariableAggregationV2(enum.Enum):
MEAN = 2
ONLY_FIRST_REPLICA = 3
def __hash__(self):
return hash(self.value)
def __eq__(self, other):
if self is other:
return True
elif isinstance(other, VariableAggregation):
return int(self.value) == int(other.value)
else:
return False
@tf_export(v1=["VariableAggregation"])
class VariableAggregation(enum.Enum):
......@@ -112,6 +123,9 @@ class VariableAggregation(enum.Enum):
ONLY_FIRST_REPLICA = 3
ONLY_FIRST_TOWER = 3 # DEPRECATED
def __hash__(self):
return hash(self.value)
VariableAggregation.__doc__ = (
VariableAggregationV2.__doc__ +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册