Unverified commit 3c2bdaa8, authored by levi131 and committed by GitHub

unify usage of tuple and list (#36368)

* modify format

* modify format
Parent: 033a73c3
@@ -18,20 +18,7 @@ from ..fluid import framework
 from ..fluid.dygraph import grad
 from ..nn.initializer import assign
 from ..tensor import reshape, zeros_like, to_tensor
-from .utils import _check_tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor
-
-
-def to_tensorlist(tl):
-    if not isinstance(tl, list):
-        if isinstance(tl, tuple):
-            tl = list(tl)
-        else:
-            tl = [tl]
-    for t in tl:
-        assert isinstance(t, paddle.Tensor) or t is None, (
-            f'{t} is expected to be paddle.Tensor or None, but found {type(t)}.'
-        )
-    return tl
+from .utils import _tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor
 
 
 @contextlib.contextmanager
@@ -98,19 +85,19 @@ def vjp(func, inputs, v=None, create_graph=False, allow_unused=False):
     reverse mode automatic differentiation.
 
     Args:
-        func(Callable): `func` takes as input a tensor or a list
-            of tensors and returns a tensor or a list of tensors.
-        inputs(list[Tensor]|Tensor): used as positional arguments
-            to evaluate `func`. `inputs` is accepted as one tensor
-            or a list of tensors.
-        v(list[Tensor]|Tensor, optional): the cotangent vector
-            invovled in the VJP computation. `v` matches the size
-            and shape of `func`'s output. Default value is None
+        func(Callable): `func` takes as input a tensor or a list/tuple
+            of tensors and returns a tensor or a list/tuple of tensors.
+        inputs(list[Tensor]|tuple[Tensor]|Tensor): used as positional
+            arguments to evaluate `func`. `inputs` is accepted as one
+            tensor or a list of tensors.
+        v(list[Tensor]|tuple[Tensor]|Tensor|None, optional): the
+            cotangent vector involved in the VJP computation. `v` matches
+            the size and shape of `func`'s output. Default value is None
             and in this case is equivalent to all ones the same size
             of `func`'s output.
-        create_graph(bool, optional): if `True`, gradients can
-            be evaluated on the results. If `False`, taking gradients
-            on the results is invalid. Default value is False.
+        create_graph(bool, optional): if `True`, gradients can be
+            evaluated on the results. If `False`, taking gradients on
+            the results is invalid. Default value is False.
         allow_unused(bool, optional): In case that some Tensors of
             `inputs` do not contribute to the computation of the output.
             If `allow_unused` is False, an error will be raised,
@@ -119,8 +106,9 @@ def vjp(func, inputs, v=None, create_graph=False, allow_unused=False):
 
     Returns:
         output(tuple):
-            func_out: the output of `func(inputs)`
-            vjp(list[Tensor]|Tensor): the pullback results of `v` on `func`
+            func_out(list[Tensor]|tuple[Tensor]|Tensor): the output of
+                `func(inputs)`
+            vjp(list[Tensor]): the pullback results of `v` on `func`
 
     Examples:
         .. code-block:: python
@@ -163,13 +151,13 @@ def vjp(func, inputs, v=None, create_graph=False, allow_unused=False):
             # [[2., 1.],
             # [1., 0.]]), None]
     """
-    xs, v = to_tensorlist(inputs), to_tensorlist(v)
+    xs, v = _tensors(inputs, "inputs"), _tensors(v, "v")
     with gradient_scope(
             xs, v, create_graph=create_graph,
             allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]:
         outputs = func(*xs)
-        ys = to_tensorlist(outputs)
+        ys = _tensors(outputs, "outputs")
         grads = grad_fn(ys, xs, v)
         outputs, grads = return_fn(outputs), return_fn(grads)
@@ -186,16 +174,16 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
     **This API is ONLY available in imperative mode.**
 
     Args:
-        func(Callable): `func` takes as input a tensor or a list
-            of tensors and returns a tensor or a list of tensors.
-        inputs(list[Tensor]|Tensor): used as positional arguments
-            to evaluate `func`. `inputs` is accepted as one tensor
-            or a list of tensors.
-        v(list[Tensor]|Tensor, optional): the tangent vector
-            invovled in the JVP computation. `v` matches the size
-            and shape of `inputs`. `v` is Optional if `func` returns
-            a single tensor. Default value is None and in this case
-            is equivalent to all ones the same size of `inputs`.
+        func(Callable): `func` takes as input a tensor or a list/tuple
+            of tensors and returns a tensor or a list/tuple of tensors.
+        inputs(list[Tensor]|tuple[Tensor]|Tensor): used as positional
+            arguments to evaluate `func`. `inputs` is accepted as one
+            tensor or a list/tuple of tensors.
+        v(list[Tensor]|tuple[Tensor]|Tensor|None, optional): the
+            tangent vector involved in the JVP computation. `v` matches
+            the size and shape of `inputs`. `v` is Optional if `func`
+            returns a single tensor. Default value is None and in this
+            case is equivalent to all ones the same size of `inputs`.
         create_graph(bool, optional): if `True`, gradients can
             be evaluated on the results. If `False`, taking gradients
             on the results is invalid. Default value is False.
@@ -207,8 +195,9 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
 
     Returns:
         output(tuple):
-            func_out: the output of `func(inputs)`
-            jvp(list[Tensor]|Tensor): the pullback results of `v` on `func`
+            func_out(list[Tensor]|tuple[Tensor]|Tensor): the output of
+                `func(inputs)`
+            jvp(list[Tensor]): the pushforward results of `v` on `func`
 
     Examples:
         .. code-block:: python
@@ -232,13 +221,13 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
             # [0., 0.]])]
     """
-    xs, v = to_tensorlist(inputs), to_tensorlist(v)
+    xs, v = _tensors(inputs, "inputs"), _tensors(v, "v")
     with gradient_scope(
             xs, v, create_graph=create_graph,
             allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]:
         outputs = func(*xs)
-        ys = to_tensorlist(outputs)
+        ys = _tensors(outputs, "outputs")
         ys_grad = [zeros_like(y) for y in ys]
         xs_grad = grad_fn(ys, xs, ys_grad, create_graph=True)
         ys_grad = grad_fn(xs_grad, ys_grad, v)
@@ -357,8 +346,8 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
            # [0., 0., 0., 2.]]), None))
     '''
-    inputs = _check_tensors(inputs, "inputs")
-    outputs = _check_tensors(func(*inputs), "outputs")
+    inputs = _tensors(inputs, "inputs")
+    outputs = _tensors(func(*inputs), "outputs")
     fin_size = len(inputs)
     fout_size = len(outputs)
     flat_outputs = tuple(reshape(output, shape=[-1]) for output in outputs)
@@ -494,7 +483,7 @@ def hessian(func, inputs, create_graph=False, allow_unused=False):
            # [0., 1., 1., 2.]]), None), (None, None))
     '''
-    inputs = _check_tensors(inputs, "inputs")
+    inputs = _tensors(inputs, "inputs")
     outputs = func(*inputs)
     assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
         1
......
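The changes to this file remove the local `to_tensorlist` helper and route `inputs`, `v`, and `outputs` of both `vjp` and `jvp` through the shared `_tensors` helper, so tuples are accepted everywhere lists were. A minimal sketch of calling the updated API in dynamic-graph mode (the function and tensor names here are illustrative, not part of the patch):

    import paddle
    from paddle.autograd.functional import vjp, jvp

    def func(x):
        return paddle.matmul(x, x)

    x = paddle.ones(shape=[2, 2], dtype='float32')

    # inputs may now be a single tensor, a list, or a tuple
    func_out, grads = vjp(func, (x,))
    func_out, tangents = jvp(func, [x])

Note also how `jvp` is implemented with two reverse-mode passes rather than true forward mode: the first `grad_fn` call pulls back a placeholder cotangent built with `zeros_like` (keeping `create_graph=True`), and the second differentiates that result with respect to the placeholder, which yields the Jacobian-vector product.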
@@ -15,22 +15,20 @@
 import paddle
 
 
-def _check_tensors(in_out_list, name):
-    assert in_out_list is not None, "{} should not be None".format(name)
-
-    if isinstance(in_out_list, (list, tuple)):
-        assert len(in_out_list) > 0, "{} connot be empyt".format(name)
-        for each_var in in_out_list:
-            assert isinstance(
-                each_var,
-                paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
-                    name)
-        return list(in_out_list)
+def _tensors(ts, name):
+    if isinstance(ts, (list, tuple)):
+        assert len(ts) > 0, "{} cannot be empty".format(name)
+        for each_t in ts:
+            assert isinstance(
+                each_t, paddle.Tensor
+            ) or each_t is None, "Elements of {} must be paddle.Tensor or None".format(
+                name)
+        return list(ts)
     else:
         assert isinstance(
-            in_out_list,
-            paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
-        return [in_out_list]
+            ts, paddle.Tensor
+        ) or ts is None, "{} must be Tensor or list of Tensor".format(name)
+        return [ts]
 
 
 def _stack_tensor_or_return_none(origin_list):
......
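`_tensors` is now the single normalization point: whatever the caller passes in, downstream code always sees a plain list. Unlike the old `_check_tensors`, it also tolerates `None`, both as the whole argument and as a list element, which is what lets `v=None` flow through `vjp` and `jvp`. A behavior sketch based on the code above (the tensors are illustrative):

    import paddle
    from paddle.autograd.functional import _tensors

    x = paddle.ones([2])

    _tensors(x, "inputs")       # -> [x]: a single tensor is wrapped
    _tensors((x, x), "inputs")  # -> [x, x]: a tuple becomes a list
    _tensors([x, None], "v")    # -> [x, None]: None entries are allowed
    _tensors([], "inputs")      # AssertionError: inputs cannot be empty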
@@ -456,7 +456,7 @@ def grad(outputs,
             the Tensors whose gradients are not needed to compute. Default None.
 
     Returns:
-        tuple: a tuple of Tensors, whose length is the same as the Tensor number
+        list: a list of Tensors, whose length is the same as the Tensor number
             inside `inputs`, and the i-th returned Tensor is the sum of gradients of
             `outputs` with respect to the i-th `inputs`.
......
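This docstring fix brings the documented return type in line with what `paddle.grad` actually returns in dynamic-graph mode. A quick check (the variable names are illustrative):

    import paddle

    x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
    y = x * x

    grads = paddle.grad(outputs=[y], inputs=[x])
    print(type(grads))  # <class 'list'>, one entry per input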
@@ -15,7 +15,7 @@
 import unittest
 import paddle
-from paddle.autograd.functional import vjp, jvp, to_tensorlist
+from paddle.autograd.functional import vjp, jvp, _tensors
 from paddle import grad, ones_like, zeros_like

@@ -55,7 +55,7 @@ def nested(x):
 
 def make_v(f, inputs):
-    outputs = to_tensorlist(f(*inputs))
+    outputs = _tensors(f(*inputs), "outputs")
     return [ones_like(x) for x in outputs]
......
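In the tests, `make_v` builds an all-ones cotangent matching each of `func`'s outputs, which per the docstring above is exactly what `vjp` defaults to when `v` is omitted. Rough usage, with illustrative `func` and `x`:

    v = make_v(func, [x])          # one ones_like tensor per output of func
    func_out, grads = vjp(func, [x], v)
    # should match the default behavior:
    func_out2, grads2 = vjp(func, [x])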
@@ -14,7 +14,7 @@
 import numpy as np
 import paddle
-from paddle.autograd.functional import _check_tensors
+from paddle.autograd.functional import _tensors

 def _product(t):

@@ -42,8 +42,8 @@ def _set_item(t, idx, value):
 
 def _compute_numerical_jacobian(func, xs, delta, np_dtype):
-    xs = _check_tensors(xs, "xs")
-    ys = _check_tensors(func(*xs), "ys")
+    xs = _tensors(xs, "xs")
+    ys = _tensors(func(*xs), "ys")
     fin_size = len(xs)
     fout_size = len(ys)
     jacobian = list([] for _ in range(fout_size))

@@ -59,11 +59,11 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
             orig = _get_item(xs[j], q)
             x_pos = orig + delta
             xs[j] = _set_item(xs[j], q, x_pos)
-            ys_pos = _check_tensors(func(*xs), "ys_pos")
+            ys_pos = _tensors(func(*xs), "ys_pos")
 
             x_neg = orig - delta
             xs[j] = _set_item(xs[j], q, x_neg)
-            ys_neg = _check_tensors(func(*xs), "ys_neg")
+            ys_neg = _tensors(func(*xs), "ys_neg")
 
             xs[j] = _set_item(xs[j], q, orig)

@@ -76,8 +76,8 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
 
 def _compute_numerical_hessian(func, xs, delta, np_dtype):
-    xs = _check_tensors(xs, "xs")
-    ys = _check_tensors(func(*xs), "ys")
+    xs = _tensors(xs, "xs")
+    ys = _tensors(func(*xs), "ys")
     fin_size = len(xs)
     hessian = list([] for _ in range(fin_size))
     for i in range(fin_size):
......
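For context, `_compute_numerical_jacobian` perturbs one input element at a time by ±delta and evaluates `func` at both points; the `_tensors` calls merely normalize the outputs to a list on every evaluation. A standalone NumPy sketch of the central-difference estimate this scheme computes (the function name and signature are assumptions, not the test code):

    import numpy as np

    def numerical_partial(f, x, j, delta=1e-4):
        # d f(x) / d x[j] via the two-sided difference (f(x+d) - f(x-d)) / (2d)
        e = np.zeros_like(x)
        e[j] = delta
        return (f(x + e) - f(x - e)) / (2.0 * delta)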