未验证 提交 79dbbcce 编写于 作者: T Tongxin Bai 提交者: GitHub

[autograd.functional] Fix a bug on handling v=None in vjp and jvp (#36445)

* autograd.functional passed pylint checker.

* autograd.functional: fix import errors.

* autograd.functional: fixed unit tests.

* autograd.functional minor format change

* [autograd.functional] Fixed vjp and jvp's v=None bug.
上级 e496d1e9
...@@ -23,9 +23,10 @@ from .utils import _tensors, _stack_tensor_or_return_none, _replace_none_with_ze ...@@ -23,9 +23,10 @@ from .utils import _tensors, _stack_tensor_or_return_none, _replace_none_with_ze
@contextlib.contextmanager @contextlib.contextmanager
def gradient_scope(*var_lists, create_graph=False, allow_unused=False): def gradient_scope(*var_lists, create_graph=False, allow_unused=False):
def grad_fn(ys, xs, v, create_graph=create_graph): def grad_fn(ys, xs, v=None, create_graph=create_graph):
if v is not None:
assert len(ys) == len(v), ( assert len(ys) == len(v), (
f'`v` is expected to be of the same size as the output. ' f'The argument {v} is expected to be of the same size as the output. '
f'Here the output is {ys}, and `v` is {v}.') f'Here the output is {ys}, and `v` is {v}.')
if allow_unused: if allow_unused:
ys = [ ys = [
...@@ -49,6 +50,8 @@ def gradient_scope(*var_lists, create_graph=False, allow_unused=False): ...@@ -49,6 +50,8 @@ def gradient_scope(*var_lists, create_graph=False, allow_unused=False):
return out return out
def process(vl): def process(vl):
if vl is None:
return None
out = [] out = []
# If v is treated as constant in the outer scope, its gradient is guaranteed # If v is treated as constant in the outer scope, its gradient is guaranteed
# not to be taken beyond this scope. Within this scope, however, v's gradient # not to be taken beyond this scope. Within this scope, however, v's gradient
...@@ -151,7 +154,9 @@ def vjp(func, inputs, v=None, create_graph=False, allow_unused=False): ...@@ -151,7 +154,9 @@ def vjp(func, inputs, v=None, create_graph=False, allow_unused=False):
# [[2., 1.], # [[2., 1.],
# [1., 0.]]), None] # [1., 0.]]), None]
""" """
xs, v = _tensors(inputs, "inputs"), _tensors(v, "v") xs = _tensors(inputs, "inputs")
if v is not None:
v = _tensors(v, "v")
with gradient_scope( with gradient_scope(
xs, v, create_graph=create_graph, xs, v, create_graph=create_graph,
...@@ -221,7 +226,9 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False): ...@@ -221,7 +226,9 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
# [0., 0.]])] # [0., 0.]])]
""" """
xs, v = _tensors(inputs, "inputs"), _tensors(v, "v") xs = _tensors(inputs, "inputs")
if v is not None:
v = _tensors(v, "v")
with gradient_scope( with gradient_scope(
xs, v, create_graph=create_graph, xs, v, create_graph=create_graph,
......
...@@ -205,6 +205,16 @@ class TestVJP(TestAutogradFunctional): ...@@ -205,6 +205,16 @@ class TestVJP(TestAutogradFunctional):
vjp_result, grad_result = vjp(), grad() vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result) self.check_results(grad_result, vjp_result)
def test_vjp_i2o2_omitting_v_no_create_graph(self):
test_cases = [
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_nested_no_create_graph(self): def test_vjp_nested_no_create_graph(self):
x = self.gen_input('a') x = self.gen_input('a')
test_cases = [ test_cases = [
...@@ -289,6 +299,17 @@ class TestJVP(TestAutogradFunctional): ...@@ -289,6 +299,17 @@ class TestJVP(TestAutogradFunctional):
reverse_jac = jac(vjp, f, inputs) reverse_jac = jac(vjp, f, inputs)
self.check_results(forward_jac, reverse_jac) self.check_results(forward_jac, reverse_jac)
def test_jvp_i2o2_omitting_v_no_create_graph(self):
test_cases = [ #noqa
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
results_omitting_v = jvp(f, inputs)
v = [ones_like(x) for x in inputs]
results_with_v = jvp(f, inputs, v)
self.check_results(results_omitting_v, results_with_v)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册