Unverified commit ec2f68e8, authored by levi131, committed by GitHub

Add functional autograd API: jacobian (#35917)

* init functional jacobian api

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* polish API docstring

* modify docstring
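
For reference, a minimal usage sketch of the new API (it mirrors the first docstring example below):

    import paddle

    def func(x):
        return paddle.matmul(x, x)

    x = paddle.ones(shape=[2, 2], dtype='float32')
    x.stop_gradient = False
    J = paddle.autograd.jacobian(func, x)  # [4, 4] Tensor: flattened d(output)/d(input)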
Parent 6d62769a
@@ -18,5 +18,6 @@ from .backward_mode import backward  # noqa: F401
 from .py_layer import PyLayer, PyLayerContext  # noqa: F401
 from ..framework import set_grad_enabled  # noqa: F401
 from ..fluid.dygraph.base import no_grad_ as no_grad  # noqa: F401
+from .functional import jacobian  # noqa: F401
 __all__ = ['backward', 'PyLayer', 'PyLayerContext']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import framework
import paddle
def _check_tensors(in_out_list, name):
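    # Normalize the argument to a list of Tensors: accept a single Tensor or a
    # non-empty list/tuple of Tensors, and fail fast on anything else.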
assert in_out_list is not None, "{} should not be None".format(name)
if isinstance(in_out_list, (list, tuple)):
        assert len(in_out_list) > 0, "{} cannot be empty".format(name)
for each_var in in_out_list:
assert isinstance(
each_var,
paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
name)
return in_out_list
else:
assert isinstance(
in_out_list,
paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
return [in_out_list]
def _stack_tensor_or_return_none(origin_list):
    # Stack a list of Tensors along a new axis 0; entries are None when an
    # input was unused (allow_unused=True), in which case None is returned.
    assert len(origin_list) > 0, "Cannot stack an empty list"
    return paddle.stack(
        origin_list, axis=0) if isinstance(origin_list[0],
                                           paddle.Tensor) else None
@framework.dygraph_only
def jacobian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in imperative mode.**
This API computes the Jacobian matrix of `func` with respect to `inputs`.
Parameters:
func (function): a Python function that takes a Tensor or a Tensor
list/tuple as inputs and returns a Tensor or a Tensor tuple.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the function ``func``.
        create_graph (bool, optional): whether to create the gradient graphs
            of the computing process. When it is True, higher-order derivatives
            can be computed; when it is False, the gradient graphs of the
            computing process are discarded. Defaults to ``False``.
        allow_unused (bool, optional): whether to raise an error or return
            None if some Tensors of `inputs` are unreachable in the graph.
            An error is raised if allow_unused=False, and None is returned
            as their gradients if allow_unused=True. Defaults to ``False``.
Returns:
        Jacobian (Tensor or nested tuple of Tensors): if function ``func``
            takes a Tensor as input and returns a Tensor as output, Jacobian
            will be a single Tensor containing the Jacobian matrix for the
            flattened input and output. If one of the inputs or outputs is a
            Tensor and the other is a Tensor list/tuple, then the Jacobian
            will be a tuple of Tensors. If both the inputs and outputs are
            Tensor lists/tuples, then the Jacobian will be a tuple of tuples
            of Tensors, where ``Jacobian[i][j]`` contains the Jacobian matrix
            of the flattened ``i``-th output with respect to the flattened
            ``j``-th input, with the same dtype and device as the
            corresponding input. ``Jacobian[i][j]`` has size ``m * n``, where
            ``m`` and ``n`` denote the numbers of elements of the ``i``-th
            output and the ``j``-th input respectively.
    Example 1:
.. code-block:: python
import paddle
def func(x):
return paddle.matmul(x, x)
x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, x)
print(jacobian)
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2., 1., 1., 0.],
# [1., 2., 0., 1.],
# [1., 0., 2., 1.],
# [0., 1., 1., 2.]])
    Example 2:
.. code-block:: python
import paddle
def func(x, y):
return paddle.matmul(x, y)
x = paddle.ones(shape=[2, 2], dtype='float32')
y = paddle.ones(shape=[2, 2], dtype='float32') * 2
x.stop_gradient = False
y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [x, y], create_graph=True)
print(jacobian)
# (Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [[2., 2., 0., 0.],
# [2., 2., 0., 0.],
# [0., 0., 2., 2.],
# [0., 0., 2., 2.]]),
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [[1., 0., 1., 0.],
# [0., 1., 0., 1.],
# [1., 0., 1., 0.],
# [0., 1., 0., 1.]]))
    Example 3:
.. code-block:: python
import paddle
def func(x, y):
return paddle.matmul(x, y), x * x
x = paddle.ones(shape=[2, 2], dtype='float32')
y = paddle.ones(shape=[2, 2], dtype='float32') * 2
x.stop_gradient = False
y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [x, y], allow_unused=True)
print(jacobian)
# ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2., 2., 0., 0.],
# [2., 2., 0., 0.],
# [0., 0., 2., 2.],
# [0., 0., 2., 2.]]),
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0., 1., 0.],
# [0., 1., 0., 1.],
# [1., 0., 1., 0.],
# [0., 1., 0., 1.]])),
# (Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 0., 0.],
# [0., 2., 0., 0.],
# [0., 0., 2., 0.],
# [0., 0., 0., 2.]]), None))
'''
inputs = _check_tensors(inputs, "inputs")
outputs = _check_tensors(func(*inputs), "outputs")
fin_size = len(inputs)
fout_size = len(outputs)
flat_outputs = tuple(
paddle.reshape(
output, shape=[-1]) for output in outputs)
jacobian = tuple()
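    # Build the Jacobian block-row for each output: one reverse-mode
    # paddle.grad call per scalar element flat_output[k] yields row k of the
    # Jacobian w.r.t. every input; retain_graph=True keeps the forward graph
    # alive across these repeated calls.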
for i, flat_output in enumerate(flat_outputs):
jac_i = list([] for _ in range(fin_size))
for k in range(len(flat_output)):
row_k = paddle.grad(
flat_output[k],
inputs,
create_graph=create_graph,
retain_graph=True,
allow_unused=allow_unused)
for j in range(fin_size):
jac_i[j].append(
paddle.reshape(
row_k[j], shape=[-1])
if isinstance(row_k[j], paddle.Tensor) else None)
jacobian += (tuple(
_stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), )
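    # Unwrap singleton tuples so the return type matches the docstring:
    # a bare Tensor, a tuple of Tensors, or a tuple of tuples of Tensors.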
if fin_size == 1 and fout_size == 1:
return jacobian[0][0]
elif fin_size == 1 and fout_size != 1:
return tuple(jacobian[i][0] for i in range(fout_size))
elif fin_size != 1 and fout_size == 1:
return jacobian[0]
else:
return jacobian
@@ -414,7 +414,7 @@ def grad(outputs,
              no_grad_vars=None):
     '''
     .. note::
-        **This API is ONLY available in Dygraph mode.**
+        **This API is ONLY available in imperative mode.**

     This API computes the sum of gradients of `outputs` with respect to each `inputs` .
...
@@ -702,6 +702,7 @@ endif()
 add_subdirectory(sequence)
 add_subdirectory(dygraph_to_static)
 add_subdirectory(rnn)
+add_subdirectory(autograd)
 if (NOT WIN32 OR NOT WITH_GPU)
     add_subdirectory(fft)
...
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach(TEST_OP)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
from paddle.autograd.functional import _check_tensors
def _product(t):
    # Number of elements described by t: t is either an int or a shape tuple.
    if isinstance(t, int):
        return t
    else:
        return np.prod(t)
def _get_item(t, idx):
    # Read element idx from the flattened (1-D) view of Tensor t.
    assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
    assert isinstance(idx, int), "The second argument idx must be an int."
    flat_t = paddle.reshape(t, [-1])
    return flat_t.__getitem__(idx)
def _set_item(t, idx, value):
    # Write value at element idx of the flattened view and restore t's shape.
    assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
    assert isinstance(idx, int), "The second argument idx must be an int."
    flat_t = paddle.reshape(t, [-1])
    flat_t.__setitem__(idx, value)
    return paddle.reshape(flat_t, t.shape)
def _compute_numerical_jacobian(func, xs, delta, np_dtype):
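    # Reference Jacobian via central finite differences:
    # J[i][j][p][q] ≈ (y_pos[p] - y_neg[p]) / (2 * delta), where element q of
    # input x_j is perturbed by ±delta and y is the i-th output.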
xs = _check_tensors(xs, "xs")
ys = _check_tensors(func(*xs), "ys")
fin_size = len(xs)
fout_size = len(ys)
jacobian = list([] for _ in range(fout_size))
for i in range(fout_size):
jac_i = list([] for _ in range(fin_size))
for j in range(fin_size):
jac_i[j] = np.zeros(
(_product(ys[i].shape), _product(xs[j].shape)), dtype=np_dtype)
jacobian[i] = jac_i
for j in range(fin_size):
for q in range(_product(xs[j].shape)):
orig = _get_item(xs[j], q)
x_pos = orig + delta
xs[j] = _set_item(xs[j], q, x_pos)
ys_pos = _check_tensors(func(*xs), "ys_pos")
x_neg = orig - delta
xs[j] = _set_item(xs[j], q, x_neg)
ys_neg = _check_tensors(func(*xs), "ys_neg")
xs[j] = _set_item(xs[j], q, orig)
for i in range(fout_size):
for p in range(_product(ys[i].shape)):
y_pos = _get_item(ys_pos[i], p)
y_neg = _get_item(ys_neg[i], p)
jacobian[i][j][p][q] = (y_pos - y_neg) / delta / 2.
return jacobian
class TestJacobian(unittest.TestCase):
@classmethod
    def setUpClass(cls):
        cls.shape = (4, 4)
        cls.dtype = 'float32'
        cls.np_dtype = np.float32
        cls.numerical_delta = 1e-4
        cls.rtol = 1e-3
        cls.atol = 1e-3
        cls.x = paddle.rand(shape=cls.shape, dtype=cls.dtype)
        cls.y = paddle.rand(shape=cls.shape, dtype=cls.dtype)
def test_single_input_and_single_output(self):
def func(x):
return paddle.matmul(x, x)
numerical_jacobian = _compute_numerical_jacobian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, self.x)
assert np.allclose(jacobian.numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
def test_single_input_and_multi_output(self):
def func(x):
return paddle.matmul(x, x), x * x
numerical_jacobian = _compute_numerical_jacobian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, self.x)
for i in range(len(jacobian)):
assert np.allclose(jacobian[i].numpy(), numerical_jacobian[i][0],
self.rtol, self.atol)
def test_multi_input_and_single_output(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
def test_multi_input_and_multi_output(self):
def func(x, y):
return paddle.matmul(x, y), x * y
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for i in range(len(jacobian)):
for j in range(len(jacobian[0])):
assert np.allclose(jacobian[i][j].numpy(),
numerical_jacobian[i][j], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.matmul(x, x)
        self.x.stop_gradient = False
        self.y.stop_gradient = False
        try:
            paddle.autograd.jacobian(func, [self.x, self.y])
            self.fail("Expected ValueError when allow_unused=False")
        except ValueError as e:
            error_msg = cpt.get_exception_message(e)
            assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.matmul(x, x)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(
func, [self.x, self.y], allow_unused=True)
assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
assert jacobian[1] is None
def test_create_graph_false(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == True
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
        try:
            paddle.grad(jacobian[0], [self.x, self.y])
            self.fail("Expected RuntimeError when create_graph=False")
        except RuntimeError as e:
            error_msg = cpt.get_exception_message(e)
            assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(
func, [self.x, self.y], create_graph=True)
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == False
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
double_grad = paddle.grad(jacobian[0], [self.x, self.y])
assert double_grad is not None
class TestJacobianFloat64(TestJacobian):
@classmethod
    def setUpClass(cls):
        cls.shape = (4, 4)
        cls.dtype = 'float64'
        cls.np_dtype = np.float64
        cls.numerical_delta = 1e-7
        cls.rtol = 1e-7
        cls.atol = 1e-7
        cls.x = paddle.rand(shape=cls.shape, dtype=cls.dtype)
        cls.y = paddle.rand(shape=cls.shape, dtype=cls.dtype)
    # NOTE(levi): skip this test case temporarily.
def test_create_graph_true(self):
pass
if __name__ == "__main__":
unittest.main()