From ff6dd474a624960e89a0a62a7855595cece2edca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 8 Sep 2017 14:33:01 -0700 Subject: [PATCH] Use self._in_graph_mode consistently in ResourceVariable instead of sometimes getting it from the context. Also: fix formatting of a comment and use a more precise test to detect if initial_value is set. PiperOrigin-RevId: 168047258 --- .../python/ops/resource_variable_ops.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index fdc8a5843fe..c735be06983 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -42,14 +42,14 @@ from tensorflow.python.ops.gen_resource_variable_ops import * from tensorflow.python.util import compat -def _eager_safe_variable_handle(shape, dtype, shared_name, name, +def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode, container=None): """Creates a variable handle with information to do shape inference.""" handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, shared_name=shared_name, name=name, container=container) - if context.in_graph_mode(): + if graph_mode: return handle with context.graph_mode(), ops.Graph().as_default(): h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, @@ -152,8 +152,8 @@ class ResourceVariable(variables.Variable): uniquified automatically. dtype: If set, initial_value will be converted to the given type. If None, either the datatype will be kept (if initial_value is - a Tensor) or float32 will be used (if it is a Python object convertible - to a Tensor). + a Tensor) or float32 will be used (if it is a Python object convertible + to a Tensor). variable_def: `VariableDef` protocol buffer. If not None, recreates the `ResourceVariable` object with its contents. `variable_def` and other arguments (except for import_scope) are mutually exclusive. @@ -172,7 +172,7 @@ class ResourceVariable(variables.Variable): shape and `validate_shape` is `True`. """ if variable_def: - if initial_value: + if initial_value is not None: raise ValueError("variable_def and initial_value are mutually " "exclusive.") if not context.in_graph_mode(): @@ -277,7 +277,8 @@ class ResourceVariable(variables.Variable): shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, - name=name) + name=name, + graph_mode=self._in_graph_mode) self._handle_device = ( self._handle.device if self._in_graph_mode else context.get_default_context().device_name) @@ -291,6 +292,7 @@ class ResourceVariable(variables.Variable): dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, + graph_mode=False, container="") self._handle_device = ( self._handle.device if self._in_graph_mode else @@ -316,6 +318,7 @@ class ResourceVariable(variables.Variable): dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, + graph_mode=self._in_graph_mode, container="") self._handle_device = (self._handle.device if self._in_graph_mode else context.get_default_context().device_name) @@ -372,6 +375,7 @@ class ResourceVariable(variables.Variable): """Initializes from `VariableDef` proto.""" # Note that init_from_proto is currently not supported in Eager mode. assert context.in_graph_mode() + self._in_graph_mode = True assert isinstance(variable_def, variable_pb2.VariableDef) if not variable_def.is_resource: raise ValueError("Trying to restore Variable as ResourceVariable.") @@ -434,7 +438,7 @@ class ResourceVariable(variables.Variable): @property def create(self): """The op responsible for initializing this variable.""" - if not context.in_graph_mode(): + if not self._in_graph_mode: raise RuntimeError("Calling create in EAGER mode not supported.") return self._initializer_op @@ -520,7 +524,7 @@ class ResourceVariable(variables.Variable): # In graph mode, ensure we read the variable in the same device as the # handle. In eager mode, however, this sometimes tries to read a GPU # variable in the CPU because the handle is host memory. For now, then, we - # need to skip the device block in eager. TODO(apassos) eager should have + # need to skip the device block in eager. TODO(apassos): eager should have # separate notions of device and memory, so handle.device can be GPU while # handle.memory_space is always CPU. if context.in_graph_mode(): -- GitLab