提交 d5f641cb 编写于 作者: P Peter Buchlovsky 提交者: TensorFlower Gardener

Make MirroredStrategy throw an error when creating a trainable ReplicaLocalVariable.

PiperOrigin-RevId: 251596345
上级 f57f1ce4
......@@ -844,6 +844,10 @@ class ParameterServerStrategyTest(
num_gpus_per_worker=2)
self._test_all_reduce_mean_gradient_tape(distribution)
def testTrainableVariables(self):
distribution = parameter_server_strategy.ParameterServerStrategy()
self._test_trainable_variable(distribution)
class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
parameterized.TestCase):
......
......@@ -217,7 +217,6 @@ def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: d
elif synchronization == variable_scope.VariableSynchronization.ON_READ:
# Variables that are to be synced on read are replica local.
is_sync_on_read = True
kwargs["trainable"] = False
elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
synchronization == variable_scope.VariableSynchronization.AUTO):
# `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
......
......@@ -226,6 +226,9 @@ class MirroredTwoDeviceDistributionTest(
def testSummaryForReplicaZeroOnly(self, distribution):
self._test_summary_for_replica_zero_only(distribution)
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
def one_device_combinations():
return combinations.combine(
......
......@@ -110,6 +110,9 @@ class OneDeviceStrategyTest(
def testAllReduceMeanGradientTape(self, distribution):
self._test_all_reduce_mean_gradient_tape(distribution)
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
@combinations.generate(
combinations.combine(
......
......@@ -426,6 +426,23 @@ class DistributionTestBase(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
run_and_concatenate(strategy, i)
def _test_trainable_variable(self, strategy):
with strategy.scope():
v1 = variables.Variable(1.0)
self.assertEqual(True, v1.trainable)
v2 = variables.Variable(
1.0, synchronization=variables.VariableSynchronization.ON_READ)
self.assertEqual(False, v2.trainable)
with self.assertRaisesRegexp(
ValueError,
"Synchronization value can be set to VariableSynchronization.ON_READ "
"only for non-trainable variables"):
_ = variables.Variable(
1.0, trainable=True,
synchronization=variables.VariableSynchronization.ON_READ)
class OneDeviceDistributionTestBase(test.TestCase):
"""Some tests that should work with any one-device DistributionStrategy."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册