提交 ff6dd474 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Use self._in_graph_mode consistently in ResourceVariable

instead of sometimes getting it from the context.

Also: fix formatting of a comment and use a more precise test to detect
if initial_value is set.
PiperOrigin-RevId: 168047258
上级 f331f528
......@@ -42,14 +42,14 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
from tensorflow.python.util import compat
def _eager_safe_variable_handle(shape, dtype, shared_name, name,
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode,
container=None):
"""Creates a variable handle with information to do shape inference."""
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
shared_name=shared_name,
name=name,
container=container)
if context.in_graph_mode():
if graph_mode:
return handle
with context.graph_mode(), ops.Graph().as_default():
h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
......@@ -152,8 +152,8 @@ class ResourceVariable(variables.Variable):
uniquified automatically.
dtype: If set, initial_value will be converted to the given type.
If None, either the datatype will be kept (if initial_value is
a Tensor) or float32 will be used (if it is a Python object convertible
to a Tensor).
a Tensor) or float32 will be used (if it is a Python object convertible
to a Tensor).
variable_def: `VariableDef` protocol buffer. If not None, recreates the
`ResourceVariable` object with its contents. `variable_def` and other
arguments (except for import_scope) are mutually exclusive.
......@@ -172,7 +172,7 @@ class ResourceVariable(variables.Variable):
shape and `validate_shape` is `True`.
"""
if variable_def:
if initial_value:
if initial_value is not None:
raise ValueError("variable_def and initial_value are mutually "
"exclusive.")
if not context.in_graph_mode():
......@@ -277,7 +277,8 @@ class ResourceVariable(variables.Variable):
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name)
name=name,
graph_mode=self._in_graph_mode)
self._handle_device = (
self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
......@@ -291,6 +292,7 @@ class ResourceVariable(variables.Variable):
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name,
graph_mode=False,
container="")
self._handle_device = (
self._handle.device if self._in_graph_mode else
......@@ -316,6 +318,7 @@ class ResourceVariable(variables.Variable):
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name,
graph_mode=self._in_graph_mode,
container="")
self._handle_device = (self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
......@@ -372,6 +375,7 @@ class ResourceVariable(variables.Variable):
"""Initializes from `VariableDef` proto."""
# Note that init_from_proto is currently not supported in Eager mode.
assert context.in_graph_mode()
self._in_graph_mode = True
assert isinstance(variable_def, variable_pb2.VariableDef)
if not variable_def.is_resource:
raise ValueError("Trying to restore Variable as ResourceVariable.")
......@@ -434,7 +438,7 @@ class ResourceVariable(variables.Variable):
@property
def create(self):
"""The op responsible for initializing this variable."""
if not context.in_graph_mode():
if not self._in_graph_mode:
raise RuntimeError("Calling create in EAGER mode not supported.")
return self._initializer_op
......@@ -520,7 +524,7 @@ class ResourceVariable(variables.Variable):
# In graph mode, ensure we read the variable in the same device as the
# handle. In eager mode, however, this sometimes tries to read a GPU
# variable in the CPU because the handle is host memory. For now, then, we
# need to skip the device block in eager. TODO(apassos) eager should have
# need to skip the device block in eager. TODO(apassos): eager should have
# separate notions of device and memory, so handle.device can be GPU while
# handle.memory_space is always CPU.
if context.in_graph_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册