diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 62050506751054305bd1081f7785cccce3961f2f..b3327f128361c5898d153d62590347248e1b4a38 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -526,11 +526,13 @@ class OptimizerV2(trackable.Trackable): initializer, shape=var.shape, dtype=var.dtype) else: initial_value = initializer - weight = tf_variables.Variable( - name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access - dtype=var.dtype, - trainable=False, - initial_value=initial_value) + strategy = distribute_ctx.get_strategy() + with strategy.colocate_vars_with(var): + weight = tf_variables.Variable( + name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access + dtype=var.dtype, + trainable=False, + initial_value=initial_value) backend.track_variable(weight) slot_dict[slot_name] = weight self._restore_slot_variable(