Unverified commit a97a8dd1 authored by Xiaoxu Chen, committed by GitHub

Add forward_gradients api and enable high-order differentiation for Jacobian/Hessian (#43354)

* enable Jacobian/Hessian to support the new autograd

* fix prim mode failure in PR-CI-Windows

* add forward_gradients api

* add forward_gradients api

* skip test_autograd_functional_prim in Windows CI

* fix test_autograd_functional_prim timeout

* remove the block parameter in prim2orig method

* remove duplicate to_tensors code snippet # test=allcases
Parent 82cd8d21
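A minimal sketch of what this change enables, closely following the Hessian test wrapper added in this PR; the tanh function, the shape [1] input, and the names x/static_x/hess are illustrative only:

import numpy as np
import paddle

paddle.enable_static()
paddle.incubate.autograd.enable_prim()

x = np.random.rand(1).astype('float32')
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
    static_x = paddle.static.data('x', x.shape, 'float32')
    static_x.stop_gradient = False
    # Hessian now composes with primitive operators (high-order differentiation).
    hessian = paddle.incubate.autograd.Hessian(paddle.tanh, [static_x])[:]
    # Lower primitive operators back to original operators before execution.
    paddle.incubate.autograd.prim2orig()

exe = paddle.static.Executor()
exe.run(sp)
[hess] = exe.run(mp, feed={'x': x}, fetch_list=[hessian])

paddle.incubate.autograd.disable_prim()
paddle.disable_static()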
......@@ -17,6 +17,7 @@ import typing
import paddle
from paddle.fluid import framework
from paddle.autograd.utils import as_tensors
def vjp(func, xs, v=None):
......@@ -346,10 +347,16 @@ class _Jacobian(object):
"""
def __init__(self, func, xs):
self._xs = _separate(xs)
self._ys = func(*_as_tensors(self._xs))
self._flatten_xs = self._flatten(_as_tensors(self._xs))
self._flatten_ys = self._flatten(_as_tensors(self._ys))
# Skip separating in prim mode temporarily, as detach and clone are not
# primitive operators.
if not paddle.fluid._non_static_mode(
) and paddle.incubate.autograd.prim_enabled():
self._xs = xs
else:
self._xs = _separate(xs)
self._ys = func(*as_tensors(self._xs))
self._flatten_xs = self._flatten(as_tensors(self._xs))
self._flatten_ys = self._flatten(as_tensors(self._ys))
self._cache = {}
@property
......@@ -385,9 +392,13 @@ class _Jacobian(object):
return self._cached_evaluate(
indexes[self._lazy_axis])[other_indexes]
lazy_indexes = self._lazy_indexes(indexes)
part_jac = paddle.stack(
# Using concat and reshape to replace stack operator temporarily, as
# it is not a primitive operator.
shape = list(self.shape)
shape[self._lazy_axis] = len(lazy_indexes)
part_jac = paddle.concat(
[self._cached_evaluate(i) for i in lazy_indexes],
axis=self._lazy_axis)
axis=self._lazy_axis).reshape(shape)
return part_jac[self._shifted_indexes(indexes, len(lazy_indexes))]
def _cached_evaluate(self, k):
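The concat-plus-reshape workaround above can be checked in isolation; a quick dygraph sketch with illustrative shapes:

import numpy as np
import paddle

xs = [paddle.rand([2, 3]) for _ in range(4)]
axis = 1
stacked = paddle.stack(xs, axis=axis)                     # shape [2, 4, 3]
shape = list(stacked.shape)
via_concat = paddle.concat(xs, axis=axis).reshape(shape)  # [2, 12] -> [2, 4, 3]
np.testing.assert_allclose(stacked.numpy(), via_concat.numpy())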
......@@ -449,7 +460,7 @@ class _JacobianBatchLast(_Jacobian):
def _flatten(self, xs):
return paddle.concat(
tuple(x.reshape((-1, x.shape[-1])) for x in _as_tensors(xs)), 0)
tuple(x.reshape((-1, x.shape[-1])) for x in as_tensors(xs)), 0)
def _evaluate(self, row):
return self._flatten(_grad(self._flatten_ys[row, :], self._xs))
......@@ -475,7 +486,7 @@ class _JacobianBatchFirst(_Jacobian):
def _flatten(self, xs):
return paddle.concat(
tuple(x.reshape((x.shape[0], -1)) for x in _as_tensors(xs)), 1)
tuple(x.reshape((x.shape[0], -1)) for x in as_tensors(xs)), 1)
def _evaluate(self, row_index):
return self._flatten(_grad(self._flatten_ys[:, row_index], self._xs))
......@@ -526,10 +537,6 @@ def _multi_index(indexes, shape):
return tuple(positive_indexes)
def _as_tensors(xs):
return (xs, ) if isinstance(xs, framework.Variable) else xs
def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Cannot stack an empty list"
return paddle.stack(origin_list, axis=0) if isinstance(
......@@ -683,7 +690,7 @@ def _check_v_shape(v, refs):
if v is None:
return
v, refs = _as_tensors(v), _as_tensors(refs)
v, refs = as_tensors(v), as_tensors(refs)
if len(refs) != len(v):
raise RuntimeError(f"The argument v is a tuple of invalid length:"
f"should be {len(refs)} but got {len(v)}.")
......@@ -805,8 +812,8 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
# [0., 0., 0., 2.]]), None))
'''
inputs = _as_tensors(inputs)
outputs = _as_tensors(func(*inputs))
inputs = as_tensors(inputs)
outputs = as_tensors(func(*inputs))
fin_size = len(inputs)
fout_size = len(outputs)
flat_outputs = tuple(
......@@ -942,8 +949,8 @@ def batch_jacobian(func, inputs, create_graph=False, allow_unused=False):
'''
inputs = _as_tensors(inputs)
outputs = _as_tensors(func(*inputs))
inputs = as_tensors(inputs)
outputs = as_tensors(func(*inputs))
batch_size = inputs[0].shape[0]
for input in inputs:
......@@ -1103,7 +1110,7 @@ def batch_hessian(func, inputs, create_graph=False, allow_unused=False):
# [0., 2., 0., 2., 0., 2., 0., 2.]]), None), (None, None))
'''
inputs = _as_tensors(inputs)
inputs = as_tensors(inputs)
outputs = func(*inputs)
batch_size = inputs[0].shape[0]
for input in inputs:
......@@ -1234,7 +1241,7 @@ def hessian(func, inputs, create_graph=False, allow_unused=False):
# [0., 1., 1., 2.]]), None), (None, None))
'''
inputs = _as_tensors(inputs)
inputs = as_tensors(inputs)
outputs = func(*inputs)
assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
1
......@@ -1339,12 +1346,12 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False):
# [[8., 8.],
# [8., 8.]]), None])
'''
xs = _as_tensors(inputs)
xs = as_tensors(inputs)
if v is not None:
v = _as_tensors(v)
v = as_tensors(v)
xs, v = _separate(xs), _separate(v)
outputs = func(*xs)
ys = _as_tensors(outputs)
ys = as_tensors(outputs)
assert len(ys) == 1 and isinstance(
ys[0], framework.Variable
) and ys[0].shape == [
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from paddle.fluid import framework
def as_tensors(xs):
if isinstance(xs, framework.Variable):
return (xs, )
elif isinstance(xs, typing.Sequence):
return tuple(xs)
else:
return xs
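A small sketch of the as_tensors normalization contract (static mode; the variable names are illustrative):

import paddle
from paddle.autograd.utils import as_tensors

paddle.enable_static()
x = paddle.static.data('x', shape=[2], dtype='float32')

as_tensors(x)        # (x,)   a single Variable is wrapped in a tuple
as_tensors((x, x))   # (x, x) sequences are converted to tuples
as_tensors(None)     # None   anything else passes through unchanged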
......@@ -5,10 +5,19 @@ file(
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
if(WIN32)
# TODO: Fix these unittests that fail on Windows
list(REMOVE_ITEM TEST_OPS test_autograd_functional_prim)
list(REMOVE_ITEM TEST_OPS test_primapi)
endif()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 160)
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 200)
set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160)
set_tests_properties(test_gradients_and_minimize PROPERTIES TIMEOUT 60)
if(NOT WIN32)
set_tests_properties(test_autograd_functional_prim PROPERTIES TIMEOUT 60)
endif()
......@@ -21,7 +21,7 @@ import paddle
import paddle.fluid as fluid
import paddle.compat as cpt
import paddle.nn.functional as F
from paddle.autograd.functional import _as_tensors
from paddle.autograd.utils import as_tensors
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check
import config
......@@ -33,7 +33,7 @@ from utils import matmul, mul, nested, o2, pow, reduce, reduce_dim, unuse
def make_v(f, inputs):
outputs = _as_tensors(f(*inputs))
outputs = as_tensors(f(*inputs))
return [paddle.ones_like(x) for x in outputs]
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import unittest
import numpy as np
import paddle
import config
import utils
@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), (
('unary_float32', paddle.tanh, (np.random.rand(2, 3), ), 'float32'),
('binary_float32', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), 'float32'),
('unary_float64', paddle.tanh, (np.random.rand(2, 3), ), 'float64'),
('binary_float64', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), 'float64'),
))
class TestJacobianPrim(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.args = [arg.astype(cls.dtype) for arg in cls.args]
cls._rtol = config.TOLERANCE.get(
cls.dtype).get('first_order_grad').get('rtol')
cls._atol = config.TOLERANCE.get(
cls.dtype).get('first_order_grad').get('atol')
def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
def test_jacobian_prim(self):
def wrapper(fun, args):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_args = [
paddle.static.data(f'arg{i}', arg.shape, self.dtype)
for i, arg in enumerate(args)
]
for arg in static_args:
arg.stop_gradient = False
jac = paddle.incubate.autograd.Jacobian(fun, static_args)[:]
if paddle.incubate.autograd.prim_enabled():
paddle.incubate.autograd.prim2orig()
exe = paddle.static.Executor()
exe.run(sp)
[jac] = exe.run(mp,
feed={f'arg{i}': arg
for i, arg in enumerate(args)},
fetch_list=[jac])
return jac
paddle.incubate.autograd.enable_prim()
prim_jac = wrapper(self.fun, self.args)
paddle.incubate.autograd.disable_prim()
orig_jac = wrapper(self.fun, self.args)
np.testing.assert_allclose(orig_jac,
prim_jac,
rtol=self._rtol,
atol=self._atol)
@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), (
('unary_float32', paddle.tanh, (np.random.rand(1), ), 'float32'),
('binary_float32', paddle.multiply,
(np.random.rand(1), np.random.rand(1)), 'float32'),
('unary_float64', paddle.tanh, (np.random.rand(1), ), 'float64'),
('binary_float64', paddle.multiply,
(np.random.rand(1), np.random.rand(1)), 'float64'),
))
class TestHessianPrim(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.args = [arg.astype(cls.dtype) for arg in cls.args]
cls._rtol = config.TOLERANCE.get(
cls.dtype).get('second_order_grad').get('rtol')
cls._atol = config.TOLERANCE.get(
cls.dtype).get('second_order_grad').get('atol')
def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
def test_hessian_prim(self):
def wrapper(fun, args):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_args = [
paddle.static.data(f'arg{i}', arg.shape, self.dtype)
for i, arg in enumerate(args)
]
for arg in static_args:
arg.stop_gradient = False
hessian = paddle.incubate.autograd.Hessian(fun, static_args)[:]
if paddle.incubate.autograd.prim_enabled():
paddle.incubate.autograd.prim2orig()
exe = paddle.static.Executor()
exe.run(sp)
[hessian
] = exe.run(mp,
feed={f'arg{i}': arg
for i, arg in enumerate(args)},
fetch_list=[hessian])
return hessian
paddle.incubate.autograd.enable_prim()
prim_jac = wrapper(self.fun, self.args)
paddle.incubate.autograd.disable_prim()
orig_jac = wrapper(self.fun, self.args)
np.testing.assert_allclose(orig_jac,
prim_jac,
rtol=self._rtol,
atol=self._atol)
if __name__ == "__main__":
unittest.main()
......@@ -23,7 +23,6 @@ import config
import utils
from utils import (_compute_numerical_batch_jacobian,
_compute_numerical_jacobian)
from paddle.autograd.functional import _as_tensors
paddle.enable_static()
......@@ -58,7 +57,7 @@ class TestVJP(unittest.TestCase):
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=self.stop_gradient)
ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
exe.run(sp)
......@@ -69,7 +68,7 @@ class TestVJP(unittest.TestCase):
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
......@@ -102,7 +101,7 @@ class TestVJPException(unittest.TestCase):
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v)
ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
self.exe.run(sp)
......@@ -113,37 +112,6 @@ class TestVJPException(unittest.TestCase):
self._vjp()
def gen_static_data_and_feed(xs, v, stop_gradient=True):
feed = {}
if isinstance(xs, typing.Sequence):
static_xs = []
for i, x in enumerate(xs):
x = paddle.static.data(f"x{i}", x.shape, x.dtype)
x.stop_gradient = stop_gradient
static_xs.append(x)
feed.update({f'x{idx}': value for idx, value in enumerate(xs)})
else:
static_xs = paddle.static.data('x', xs.shape, xs.dtype)
static_xs.stop_gradient = stop_gradient
feed.update({'x': xs})
if isinstance(v, typing.Sequence):
static_v = []
for i, e in enumerate(v):
e = paddle.static.data(f'v{idx}', v.shape, v.dtype)
e.stop_gradient = stop_gradient
static_v.append(e)
feed.update({f'v{idx}': value for idx, value in v})
elif v is not None:
static_v = paddle.static.data('v', v.shape, v.dtype)
static_v.stop_gradient = stop_gradient
feed.update({'v': v})
else:
static_v = v
return feed, static_xs, static_v
def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False):
r"""Computes an approximate Jacobian matrix of a multi-valued function
using finite differences.
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import unittest
import numpy as np
import paddle
from paddle.incubate.autograd import primapi
import config
import utils
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
(('matmul', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
('multiply', paddle.multiply,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('add', paddle.add,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('input_not_sequence', paddle.tanh,
(np.random.rand(5, 5), ), None, 'float64'),
('input_gradients_not_none', paddle.matmul,
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), np.random.rand(3, 3)), 'float64')))
class TestForwardGradients(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
cls._rtol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("rtol")
cls._atol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("atol")
def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
def test_forward_gradients(self):
def expected():
paddle.incubate.autograd.disable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
_, ys_grad = paddle.autograd.jvp(self.fun, static_xs, static_v)
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.enable_prim()
return out
def actual():
paddle.incubate.autograd.enable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = primapi.forward_gradients(ys, static_xs, static_v)
paddle.incubate.autograd.prim2orig(mp.block(0))
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.disable_prim()
return out
actual = actual()
expected = expected()
self.assertEqual(type(actual), type(expected))
np.testing.assert_allclose(np.concatenate(actual),
np.concatenate(expected),
rtol=self._rtol,
atol=self._atol)
def test_prim_disabled(self):
paddle.incubate.autograd.disable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with self.assertRaises(RuntimeError):
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = primapi.forward_gradients(ys, static_xs, static_v)
paddle.incubate.autograd.prim2orig(mp.block(0))
exe = paddle.static.Executor()
exe.run(sp)
exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.enable_prim()
def test_illegal_param(self):
paddle.incubate.autograd.enable_prim()
with self.assertRaises(TypeError):
primapi.forward_gradients(1, paddle.static.data('inputs',
shape=[1]))
with self.assertRaises(TypeError):
primapi.forward_gradients(paddle.static.data('targets', shape=[1]),
1)
paddle.incubate.autograd.disable_prim()
if __name__ == '__main__':
unittest.main()
......@@ -22,7 +22,7 @@ import contextlib
import collections
import numpy as np
import paddle
from paddle.autograd.functional import _as_tensors
from paddle.autograd.utils import as_tensors
##########################################################
......@@ -57,8 +57,8 @@ def _set_item(t, idx, value):
def _compute_numerical_jacobian(func, xs, delta, np_dtype):
xs = list(_as_tensors(xs))
ys = list(_as_tensors(func(*xs)))
xs = list(as_tensors(xs))
ys = list(as_tensors(func(*xs)))
fin_size = len(xs)
fout_size = len(ys)
jacobian = list([] for _ in range(fout_size))
......@@ -74,11 +74,11 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
orig = _get_item(xs[j], q)
x_pos = orig + delta
xs[j] = _set_item(xs[j], q, x_pos)
ys_pos = _as_tensors(func(*xs))
ys_pos = as_tensors(func(*xs))
x_neg = orig - delta
xs[j] = _set_item(xs[j], q, x_neg)
ys_neg = _as_tensors(func(*xs))
ys_neg = as_tensors(func(*xs))
xs[j] = _set_item(xs[j], q, orig)
......@@ -91,8 +91,8 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
def _compute_numerical_hessian(func, xs, delta, np_dtype):
xs = list(_as_tensors(xs))
ys = list(_as_tensors(func(*xs)))
xs = list(as_tensors(xs))
ys = list(as_tensors(func(*xs)))
fin_size = len(xs)
hessian = list([] for _ in range(fin_size))
for i in range(fin_size):
......@@ -136,8 +136,8 @@ def _compute_numerical_batch_jacobian(func,
np_dtype,
merge_batch=True):
no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype)
xs = list(_as_tensors(xs))
ys = list(_as_tensors(func(*xs)))
xs = list(as_tensors(xs))
ys = list(as_tensors(func(*xs)))
fin_size = len(xs)
fout_size = len(ys)
bs = xs[0].shape[0]
......@@ -164,7 +164,7 @@ def _compute_numerical_batch_jacobian(func,
def _compute_numerical_batch_hessian(func, xs, delta, np_dtype):
xs = list(_as_tensors(xs))
xs = list(as_tensors(xs))
batch_size = xs[0].shape[0]
fin_size = len(xs)
hessian = []
......@@ -202,7 +202,7 @@ def _compute_numerical_batch_hessian(func, xs, delta, np_dtype):
def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
xs = _as_tensors(xs)
xs = as_tensors(xs)
jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype))
if v is None:
v = [paddle.ones_like(x) for x in xs]
......@@ -217,7 +217,7 @@ def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
def _compute_numerical_vhp(func, xs, v, delta, np_dtype):
xs = list(_as_tensors(xs))
xs = list(as_tensors(xs))
hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype))
flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v])
vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs]
......@@ -391,3 +391,37 @@ def _np_concat_matrix_sequence(src, src_format=MatrixFormat.NM):
if not isinstance(src[0], typing.Sequence):
src = [src]
return concat_row(tuple(concat_col(xs) for xs in src))
##########################################################
# Utils for generating test data.
##########################################################
def gen_static_data_and_feed(xs, v, stop_gradient=True):
feed = {}
if isinstance(xs, typing.Sequence):
static_xs = []
for i, x in enumerate(xs):
x = paddle.static.data(f"x{i}", x.shape, x.dtype)
x.stop_gradient = stop_gradient
static_xs.append(x)
feed.update({f'x{idx}': value for idx, value in enumerate(xs)})
else:
static_xs = paddle.static.data('x', xs.shape, xs.dtype)
static_xs.stop_gradient = stop_gradient
feed.update({'x': xs})
if isinstance(v, typing.Sequence):
static_v = []
for i, e in enumerate(v):
e = paddle.static.data(f'v{i}', e.shape, e.dtype)
e.stop_gradient = stop_gradient
static_v.append(e)
feed.update({f'v{i}': value for i, value in enumerate(v)})
elif v is not None:
static_v = paddle.static.data('v', v.shape, v.dtype)
static_v.stop_gradient = stop_gradient
feed.update({'v': v})
else:
static_v = v
return feed, static_xs, static_v
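The helper is used the same way as the updated call sites above; a hedged sketch with an assumed pair of matmul inputs (the `from utils import ...` path assumes running alongside this test utils module):

import numpy as np
import paddle
from utils import gen_static_data_and_feed

paddle.enable_static()
xs = (np.random.rand(2, 3).astype('float32'), np.random.rand(3, 2).astype('float32'))

mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
    feed, static_xs, static_v = gen_static_data_and_feed(xs, v=None, stop_gradient=False)
    ys, xs_grads = paddle.autograd.vjp(paddle.matmul, static_xs)

exe = paddle.static.Executor()
exe.run(sp)
# Fetch the forward output; xs_grads can be fetched the same way.
out = exe.run(mp, feed=feed, fetch_list=[ys])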
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import paddle.autograd.utils as tensor_utils
import paddle.incubate.autograd.utils as prim_utils
from paddle.fluid import framework
from paddle.incubate.autograd import primx
@framework.static_only
def forward_gradients(targets, inputs, input_gradients=None):
"""Forward mode of automatic differentiation.
.. note::
**ONLY available in static graph mode with primitive operators enabled.**
Args:
targets: The target tensor or tensors
inputs: The input tensor or tensors
input_gradients: The gradient Tensor or Tensors of inputs, which have
the same shape as inputs. Defaults to None, which is equivalent to
all ones.
Returns:
target_gradients (Tensor|Sequence[Tensor]): The gradients for targets.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data('x', shape=[1], dtype='float32')
y = x * x
y_grad = paddle.incubate.autograd.forward_gradients(y, x)
paddle.incubate.autograd.prim2orig()
exe = paddle.static.Executor()
exe.run(startup_program)
y_grad = exe.run(main_program, feed={'x': np.array([2.]).astype('float32')}, fetch_list=[y_grad])
print(y_grad)
# [array([4.], dtype=float32)]
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
"""
if not prim_utils.prim_enabled():
raise RuntimeError('forward_gradients must be running on primitive '
'operators, use enable_prim to turn it on.')
if not isinstance(targets, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected targets is Tensor|Sequence[Tensor], '
f'but got {type(targets)}.')
if not isinstance(inputs, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected inputs is Tensor|Sequence[Tensor], '
f'but got {type(inputs)}.')
ys, xs, xs_dot = tensor_utils.as_tensors(targets), tensor_utils.as_tensors(
inputs), tensor_utils.as_tensors(input_gradients)
block = framework.default_main_program().current_block()
if any(x.block != block for x in xs + ys):
raise RuntimeError(
'Variable in inputs and targets should exist in current block of '
'main program.')
primx.orig2prim(block)
ad = primx.Transform(ys[0].block)
_, ys_dot = ad.linearize(xs, ys, xs_dot)
return ys_dot[0] if isinstance(targets, framework.Variable) else ys_dot
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import paddle
......@@ -656,10 +657,14 @@ def split_transpose(op, check_dot, ys_bar):
@REGISTER_TRANSPOSE('concat_p')
def concat_transpose(op, check_dot, y_bar):
xs, = op_position_inputs(op)
if not isinstance(xs, typing.Sequence):
xs = [xs]
for x in xs:
assert check_dot(x), 'check_dot(x) must be True'
axis = op.attr('axis')
sections = [x.shape[axis] for x in xs]
if len(sections) == 1:
return y_bar
return split(y_bar, num_or_sections=sections, axis=axis)
......
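For context, the transpose rule of concat routes the cotangent back to its inputs by splitting along the concat axis; a plain NumPy sketch of that bookkeeping (not the primops API, all shapes illustrative):

import numpy as np

x0, x1 = np.random.rand(2, 3), np.random.rand(2, 5)
y_bar = np.random.rand(2, 8)                    # cotangent of concat([x0, x1], axis=1)
sections = [x0.shape[1], x1.shape[1]]
x0_bar, x1_bar = np.split(y_bar, np.cumsum(sections)[:-1], axis=1)
# With a single input there is nothing to split: its cotangent is y_bar itself.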
......@@ -20,8 +20,9 @@ from paddle import compat as cpt
from .primops import fill_const, add
from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig
from .primrules import _orig2prim, _prim2orig, _jvp, _transpose
from .utils import get_input_var_list, get_output_var_list, to_tensors, flatten, flatten_and_remove_none
from .utils import get_input_var_list, get_output_var_list, flatten, flatten_and_remove_none
from collections import OrderedDict
from paddle.autograd.utils import as_tensors
def topo_path(xs, ys, block=None):
......@@ -457,7 +458,7 @@ def _lower(block, reverse):
for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)),
expand_nested_list(to_tensors(lower_fn(op, *input_args)))):
expand_nested_list(as_tensors(lower_fn(op, *input_args)))):
assert not (orig_out is None) ^ (
new_out is None), "orig_out and new_out should match."
vars_to_remove.add(new_out.name)
......@@ -591,7 +592,7 @@ def _gradients(ys, xs, ys_bar=None):
xs_bar: a list gradients of input `xs`
"""
ys, xs = to_tensors(ys), to_tensors(xs)
ys, xs, ys_bar = as_tensors(ys), as_tensors(xs), as_tensors(ys_bar)
block = default_main_program().current_block()
for el in xs + ys:
assert el is None or el.block == block, f'variable in xs and ys should be None or in current block of main program'
......
......@@ -158,13 +158,6 @@ def get_output_var_list(op):
]
def to_tensors(xs):
if isinstance(xs, paddle.fluid.framework.Variable):
return [xs]
else:
return xs
def flatten(inp):
if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
return [inp]
......