From 75d8bd3a50a43ca68967a9c7c60db8eacab4ae36 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Wed, 6 Mar 2019 08:00:13 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 237047796 --- tensorflow/python/keras/optimizer_v2/optimizer_v2.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 62050506751..b3327f12836 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -526,11 +526,13 @@ class OptimizerV2(trackable.Trackable): initializer, shape=var.shape, dtype=var.dtype) else: initial_value = initializer - weight = tf_variables.Variable( - name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access - dtype=var.dtype, - trainable=False, - initial_value=initial_value) + strategy = distribute_ctx.get_strategy() + with strategy.colocate_vars_with(var): + weight = tf_variables.Variable( + name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access + dtype=var.dtype, + trainable=False, + initial_value=initial_value) backend.track_variable(weight) slot_dict[slot_name] = weight self._restore_slot_variable( -- GitLab