提交 5d4a29ea 编写于 作者: A Akshay Modi 提交者: TensorFlower Gardener

Special case wrapping of ndarrays in the gradient tape code.

PiperOrigin-RevId: 317762474
Change-Id: Ie848ad90a88aff5b2faef4069c3f05887038c367
上级 2d8d440d
......@@ -588,6 +588,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:unconnected_gradients",
"//tensorflow/python:util",
"//tensorflow/python/ops/numpy_ops:numpy",
"//tensorflow/python/ops/parallel_for:control_flow_ops",
"@six_archive//:six",
],
......
......@@ -62,6 +62,9 @@ from tensorflow.python.util.tf_export import tf_export
pfor_ops = LazyLoader(
"pfor_ops", globals(),
"tensorflow.python.ops.parallel_for.control_flow_ops")
np_arrays = LazyLoader(
"np_arrays", globals(),
"tensorflow.python.ops.numpy_ops.np_arrays")
function = LazyLoader("function", globals(),
"tensorflow.python.eager.function")
......@@ -721,9 +724,11 @@ pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
"""If x is ResourceVariable, return its handle, else x."""
"""Unwrap resource variable/ndarray to return tensors."""
if resource_variable_ops.is_resource_variable(x):
x = x.handle
return x.handle
if isinstance(x, np_arrays.ndarray):
return x.data
return x
......@@ -1023,6 +1028,7 @@ class GradientTape(object):
"gradient in order to compute higher order "
"derivatives.", 1)
num_ndarrays = 0
flat_targets = []
for t in nest.flatten(target):
if not backprop_util.IsTrainable(t):
......@@ -1033,7 +1039,12 @@ class GradientTape(object):
if resource_variable_ops.is_resource_variable(t):
with self:
t = ops.convert_to_tensor(t)
elif isinstance(t, np_arrays.ndarray):
t = t.data
num_ndarrays += 1
flat_targets.append(t)
# Only rewrap if all targets are ndarray. If not, prefer tensors.
rewrap_as_ndarray = num_ndarrays == len(flat_targets)
flat_sources = nest.flatten(sources)
flat_sources_raw = flat_sources
......@@ -1066,6 +1077,9 @@ class GradientTape(object):
self._watched_variables = self._tape.watched_variables()
self._tape = None
if rewrap_as_ndarray:
flat_grad = nest.map_structure(np_arrays.tensor_to_ndarray, flat_grad)
grad = nest.pack_sequence_as(sources, flat_grad)
return grad
......@@ -1120,6 +1134,10 @@ class GradientTape(object):
ValueError: If vectorization of jacobian computation fails.
"""
flat_sources = nest.flatten(sources)
rewrap_as_ndarray = False
if isinstance(target, np_arrays.ndarray):
target = target.data
rewrap_as_ndarray = True
target_static_shape = target.shape
target_shape = array_ops.shape(target)
# Note that we push and pop the tape here and below. This is needed since we
......@@ -1169,6 +1187,8 @@ class GradientTape(object):
out = array_ops.reshape(out, new_shape)
if context.executing_eagerly():
out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
if rewrap_as_ndarray:
out = np_arrays.tensor_to_ndarray(out)
output[i] = out
return nest.pack_sequence_as(sources, output)
......
......@@ -82,10 +82,10 @@ class NdarraySpec(type_spec.BatchableTypeSpec):
return (self._data_spec,)
def _batch(self, batch_size):
return NdarraySpec(self._data_spec.batch(batch_size))
return NdarraySpec(self._data_spec._batch(batch_size)) # pylint: disable=protected-access
def _unbatch(self):
return NdarraySpec(self._data_spec.unbatch())
return NdarraySpec(self._data_spec._unbatch()) # pylint: disable=protected-access
class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
......@@ -306,10 +306,6 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
def __repr__(self):
return 'ndarray<{}>'.format(self.data.__repr__())
@property
def _id(self):
return self.data._id # pylint: disable=protected-access
def tensor_to_ndarray(tensor):
return ndarray.from_tensor(tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册