Commit 0d822c01 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Fix so we preserve the value of `executing_eagerly_outside_functions()`

in the specific case of:
* Eager execution enabled
* Inside a FuncGraph, inside a graph
* In a replica context (such as in a call to
  `tf.distribute.Strategy.call_for_each_replica()`).

PiperOrigin-RevId: 224930182
Parent ce608761
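
For context, the scenario this fix targets looks roughly like the sketch below; the explicit `tf.distribute.MirroredStrategy()` construction and the `strategy`/`model_fn` names are illustrative, while the new test in this commit drives the same path through `distribution.extended.call_for_each_replica`.

# Minimal sketch of the affected case (TF 1.13-era APIs and internal
# modules, as used by the test added below).
import tensorflow as tf
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops

tf.enable_eager_execution()  # the case in question has eager enabled

strategy = tf.distribute.MirroredStrategy()

def model_fn():
  # Should still be True: eager is enabled outside all functions, even though
  # this runs on a replica inside a FuncGraph (itself nested in a graph in
  # the reported case).
  return ops.executing_eagerly_outside_functions()

with func_graph.FuncGraph("fg").as_default(), strategy.scope():
  per_replica = strategy.extended.call_for_each_replica(model_fn)
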
@@ -183,6 +183,34 @@ class MirroredStrategyVariableCreatorStackTest(
     expected = ("main_thread:thread_0", "main_thread:thread_1")
     self.assertEqual(expected, result)


+@combinations.generate(combinations.combine(
+    distribution=[
+        combinations.mirrored_strategy_with_gpu_and_cpu,
+        combinations.core_mirrored_strategy_with_gpu_and_cpu],
+    mode=["graph", "eager"]))
+class MirroredStrategyCallForEachReplicaTest(test.TestCase):
+
+  def testExecutingEagerlyOutsideFunction(self, distribution):
+    """Verify we preserve the value of executing_eagerly_outside_functions()."""
+    def model_fn():
+      return ops.executing_eagerly_outside_functions()
+
+    originally = ops.executing_eagerly_outside_functions()
+    with distribution.scope():
+      in_scope = ops.executing_eagerly_outside_functions()
+      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
+      unwrapped = distribution.unwrap(in_model_fn)
+      self.assertEqual(in_scope, unwrapped[0])
+      self.assertEqual(in_scope, originally)
+
+    # Verify this all again, but this time in a FuncGraph.
+    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
+      in_scope = ops.executing_eagerly_outside_functions()
+      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
+      unwrapped = distribution.unwrap(in_model_fn)
+      self.assertEqual(in_scope, unwrapped[0])
+      self.assertEqual(in_scope, originally)
+
+
 @combinations.generate(combinations.combine(
     distribution=[
@@ -50,8 +50,8 @@ from tensorflow.python.util.tf_export import tf_export
 @contextlib.contextmanager
-def _enter_graph(g):
-  if context.executing_eagerly():
+def _enter_graph(g, eager):
+  if eager:
     with g.as_default(), context.eager_mode():
       yield
   else:
@@ -839,14 +839,19 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
       self.has_paused = threading.Event()
       # These fields have to do with inheriting various contexts from the
      # parent thread:
+      ctx = context.context()
+      self.in_eager = ctx.executing_eagerly()
       # pylint: disable=protected-access
-      self.context_mode = context.context()._eager_context.mode
-      if not context.context()._context_handle:
-        context.context()._initialize_handle_and_devices()
+      if not ctx._context_handle:
+        ctx._initialize_handle_and_devices()
       self.context_device_policy = (
           pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
-              context.context()._context_handle))
+              ctx._context_handle))
       self.graph = ops.get_default_graph()
+      with ops.init_scope():
+        self._init_in_eager = context.executing_eagerly()
+        self._init_graph = ops.get_default_graph()
+
       self._variable_creator_stack = self.graph._variable_creator_stack[:]
       self._captured_var_scope = variable_scope.get_variable_scope()
       # Adding a "/" at end lets us re-enter this scope later.
@@ -867,9 +872,9 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
       if self.coord.should_stop():
         return
       with self.coord.stop_on_exception(), \
-          context.context()._mode(self.context_mode), \
+          _enter_graph(self._init_graph, self._init_in_eager), \
+          _enter_graph(self.graph, self.in_eager), \
           context.context().device_policy(self.context_device_policy), \
-          _enter_graph(self.graph), \
           MirroredReplicaContext(self.distribution, constant_op.constant(
               self.replica_id, dtypes.int32)), \
           ops.device(self.device), \
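
Taken together, the mirrored_strategy.py changes follow a capture-then-re-enter pattern: the replica thread records, while still on the parent thread, both the parent's current graph and eagerness and the pair visible under `ops.init_scope()` (roughly the state that `executing_eagerly_outside_functions()` reports), and its `run()` re-enters both, outermost first. Below is a minimal, self-contained sketch of that pattern; `_Worker` and `fn` are hypothetical names, and only `_enter_graph` mirrors the code above.

# Sketch of the capture-then-re-enter pattern used by the replica thread.
import contextlib
import threading

from tensorflow.python.eager import context
from tensorflow.python.framework import ops


@contextlib.contextmanager
def _enter_graph(g, eager):
  # `eager` is decided on the parent thread; the worker thread's own context
  # mode may differ, so it cannot be queried here.
  if eager:
    with g.as_default(), context.eager_mode():
      yield
  else:
    with g.as_default():
      yield


class _Worker(threading.Thread):  # hypothetical stand-in for the replica thread

  def __init__(self, fn):
    super(_Worker, self).__init__()
    self.fn = fn
    # Captured on the parent thread: its current eagerness and graph ...
    self.in_eager = context.executing_eagerly()
    self.graph = ops.get_default_graph()
    # ... and the same pair as seen from the outermost context, which is
    # (roughly) what executing_eagerly_outside_functions() reports on.
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()

  def run(self):
    # Re-enter both captured contexts, outermost first, so code running in
    # this thread sees the same nesting the parent thread had.
    with _enter_graph(self._init_graph, self._init_in_eager), \
        _enter_graph(self.graph, self.in_eager):
      self.result = self.fn()

In the real thread, the same `with` block also re-enters the device placement policy, the replica context, the target device, and the captured name and variable scopes, as the last hunk above shows.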