Commit 662c0c2e authored by TensorFlower Gardener

Merge pull request #44373 from fsx950223:fix_warning

PiperOrigin-RevId: 339946530
Change-Id: I9ff4d4a380b6f2c2181edfc76cbb41af09abf0e2
......
@@ -1594,6 +1594,35 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
        self.assertIn('gradient_tape/my_scope/', op.name)
    self.assertEqual(num_sin_ops_found, 2)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecomputeGradWithDifferentShape(self):

    @custom_gradient.recompute_grad
    def outer(x):
      return [x[0] + 1, x[1] + 1]

    x = [
        variables.Variable([1.0, 2.0], name='a'),
        variables.Variable(1.0, name='b')
    ]
    with backprop.GradientTape():
      y = outer(x)
    self.assertAllEqual(y[0], [2.0, 3.0])
    self.assertAllEqual(y[1], 2.0)

    @custom_gradient.recompute_grad
    def outer_dict(x):
      for key in x.keys():
        x[key] = x[key] + 1
      return x

    x = {x[0].ref(): x[0], x[1].ref(): x[1]}
    with backprop.GradientTape():
      y = outer_dict(x)
      y = list(y.values())
    self.assertAllEqual(y[0], [2.0, 3.0])
    self.assertAllEqual(y[1], 2.0)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
......
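The hunk above adds a regression test: recompute_grad is applied to nested inputs whose leaves have different shapes, once as a list of variables and once as a dict keyed by variable refs, and the forward results are checked under a GradientTape. A rough sketch of the same pattern using the public API (tf.recompute_grad, tf.GradientTape) instead of the internal test modules; the names and the gradient check below are illustrative and not part of the patch:

import tensorflow as tf

# Illustrative public-API version of the new test case.
@tf.recompute_grad
def outer(x):
  # x is a nested structure whose leaves have different shapes.
  return [x[0] + 1.0, x[1] + 1.0]

a = tf.constant([1.0, 2.0])   # shape (2,)
b = tf.constant(1.0)          # scalar

with tf.GradientTape() as tape:
  tape.watch([a, b])
  y = outer([a, b])
  loss = tf.reduce_sum(y[0]) + y[1]

print(y[0].numpy(), y[1].numpy())   # expected: [2. 3.] 2.0
# The backward pass re-runs `outer` instead of keeping its activations.
print(tape.gradient(loss, [a, b]))  # expected: [1., 1.] and 1.0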
......
@@ -524,7 +524,7 @@ def recompute_grad(f):
        # Gradient calculation for reverse mode autodiff.
        variables = grad_kwargs.get("variables")
        with backprop.GradientTape() as t:
-         id_args = [gen_array_ops.identity(x) for x in args]
+         id_args = nest.map_structure(gen_array_ops.identity, args)
          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
......
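The second hunk is the actual fix. The old list comprehension treated args as a flat sequence, so a nested argument (for example a list whose elements have different shapes, or a dict) was passed whole to gen_array_ops.identity, which coerces it into a single tensor and fails or warns when the leaf shapes differ. nest.map_structure applies identity to each leaf and preserves the original nesting. A minimal illustration of the difference with the public tf.nest API (the variable names here are illustrative, not from the patch):

import tensorflow as tf

args = ([tf.constant([1.0, 2.0]), tf.constant(1.0)],)  # one nested argument

# Old behaviour: identity() receives the whole inner list and tries to pack
# leaves with different shapes into one tensor, which raises.
try:
  bad = [tf.identity(x) for x in args]
except Exception as e:
  print('list comprehension failed:', type(e).__name__)

# New behaviour: map_structure applies identity leaf-wise and keeps the
# ([.., ..],) structure intact, so mixed shapes are fine.
good = tf.nest.map_structure(tf.identity, args)
print(tf.nest.map_structure(lambda t: t.shape, good))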