diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 36be5c83f8bafb6c934d1d7682b5227b1f71c089..337a86b3421fdb90c98cd5097dd880fdbe5871b9 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -183,6 +183,34 @@ class MirroredStrategyVariableCreatorStackTest( expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyCallForEachReplicaTest(test.TestCase): + + def testExecutingEagerlyOutsideFunction(self, distribution): + """Verify we preserve the value of executing_eagerly_outside_functions().""" + def model_fn(): + return ops.executing_eagerly_outside_functions() + + originally = ops.executing_eagerly_outside_functions() + with distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + + # Verify this all again, but this time in a FuncGraph. + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + @combinations.generate(combinations.combine( distribution=[ diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index cb94dfcfbd206eb81bbb76b36ded23a4f3bc2515..9692c88dfcba2a2e73c8447bb4d374cb923953bd 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -50,8 +50,8 @@ from tensorflow.python.util.tf_export import tf_export @contextlib.contextmanager -def _enter_graph(g): - if context.executing_eagerly(): +def _enter_graph(g, eager): + if eager: with g.as_default(), context.eager_mode(): yield else: @@ -839,14 +839,19 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): self.has_paused = threading.Event() # These fields have to do with inheriting various contexts from the # parent thread: + ctx = context.context() + self.in_eager = ctx.executing_eagerly() # pylint: disable=protected-access - self.context_mode = context.context()._eager_context.mode - if not context.context()._context_handle: - context.context()._initialize_handle_and_devices() + if not ctx._context_handle: + ctx._initialize_handle_and_devices() self.context_device_policy = ( pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( - context.context()._context_handle)) + ctx._context_handle)) self.graph = ops.get_default_graph() + with ops.init_scope(): + self._init_in_eager = context.executing_eagerly() + self._init_graph = ops.get_default_graph() + self._variable_creator_stack = self.graph._variable_creator_stack[:] self._captured_var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. @@ -867,9 +872,9 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): if self.coord.should_stop(): return with self.coord.stop_on_exception(), \ - context.context()._mode(self.context_mode), \ + _enter_graph(self._init_graph, self._init_in_eager), \ + _enter_graph(self.graph, self.in_eager), \ context.context().device_policy(self.context_device_policy), \ - _enter_graph(self.graph), \ MirroredReplicaContext(self.distribution, constant_op.constant( self.replica_id, dtypes.int32)), \ ops.device(self.device), \