From a6d8ffae097d0132989ae4688d224121ec6d8f35 Mon Sep 17 00:00:00 2001 From: Todd Wang Date: Thu, 1 Nov 2018 18:35:10 -0700 Subject: [PATCH] Fix a bug in tpu.py and xla.py that while creating an identity node for control input edges under rewrite context, the parent control flow context is lost. (#23446) PiperOrigin-RevId: 219724472 --- tensorflow/contrib/compiler/xla.py | 13 +++++-------- tensorflow/contrib/tpu/python/tpu/tpu.py | 13 +++++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 873b03580d6..83d9d8c54ab 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -179,14 +179,11 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 11aaa1c66a8..a5ccaa071b9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -371,14 +371,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access -- GitLab