# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.framework import Variable
from paddle.fluid.data_feeder import check_type


def check_input_type(input, name, op_name):
    r"""Validate that ``input`` is a tensor (dynamic mode) or a Variable (static mode).

    Args:
        input: the object to validate.
        name (str): the argument name, forwarded to ``check_type`` in static mode.
        op_name (str): the calling operator's name, forwarded to ``check_type``
            in static mode.

    Raises:
        ValueError: in dynamic mode, if ``input`` is not a ``paddle.Tensor``.
    """
    # Static graph: delegate the check to the framework helper.
    if not paddle.in_dynamic_mode():
        check_type(input, name, Variable, op_name)
        return

    # Dynamic graph: only eager tensors are accepted.
    if isinstance(input, paddle.Tensor):
        return
    raise ValueError("The input: {} must be tensor.".format(input))


def check_initial_inverse_hessian_estimate(H0):
    r"""Check whether the specified initial_inverse_hessian_estimate is symmetric and positive definite.
        Raise errors when precondition not met.

    Args:
        H0 (Tensor): the initial inverse Hessian estimate to validate. It is
            compared against ``H0.t()``, so it is presumably a square 2-D
            matrix -- TODO confirm against callers.

    Raises:
        ValueError: in dynamic mode, if ``H0`` is not symmetric or its Cholesky
            factorization fails. In static graph mode the raising function is
            wrapped in a ``py_func`` op that is placed in the false branch of
            ``paddle.static.nn.cond``.

    Note:
        In static graph can not raise error directly, so use py_func make raise_func as a op,
        and use paddle.static.nn.cond to decide if put the op in net.
        cholesky is the fast way to check positive definition, but in static graph can not catch
        exception to raise value error, so use eigvals rather than cholesky in static graph.
    """
    is_symmetric = paddle.all(paddle.equal(H0, H0.t()))

    def raise_func(*_unused_inputs):
        # py_func calls the registered function with the tensors passed as
        # `x`, so positional arguments are accepted and ignored here.
        raise ValueError(
            "The initial_inverse_hessian_estimate should be symmetric and positive definite, but the specified is not."
        )

    if paddle.in_dynamic_mode():
        if not is_symmetric:
            raise_func()
        try:
            # Cholesky factorization fails when H0 is not positive definite.
            paddle.linalg.cholesky(H0)
        except RuntimeError:
            raise_func()
    else:

        def create_tmp_var(program, name, dtype, shape):
            return program.current_block().create_var(name=name,
                                                      dtype=dtype,
                                                      shape=shape)

        # Placeholder output variable required by py_func.
        out_var = create_tmp_var(paddle.static.default_main_program(),
                                 name='output',
                                 dtype='float32',
                                 shape=[-1])

        def false_fn():
            paddle.static.nn.py_func(func=raise_func,
                                     x=is_symmetric,
                                     out=out_var)

        paddle.static.nn.cond(is_symmetric, None, false_fn)
        # eigvals only support cpu
        # NOTE(review): this changes the global device and never restores it.
        paddle.set_device("cpu")
        eigvals = paddle.linalg.eigvals(H0)
        # Combine both conditions with a tensor op: Python `and` would test
        # the truthiness of the first tensor object instead of building an
        # element-wise AND in the graph.
        is_positive = paddle.logical_and(paddle.all(eigvals.real() > 0.),
                                         paddle.all(eigvals.imag() == 0.))
        paddle.static.nn.cond(is_positive, None, false_fn)


def _value_and_gradient(f, x, v=None):
    r"""Compute function value and gradient of f at x.

    Args:
        f (Callable): the objective function.
        x (Tensor): the input tensor.
        v (optional): not used by this function. Default: None.
    Returns:
        value: a tensor that holds the function value.
        gradient: a tensor that holds the function gradients.
    """
    # use detach to cut off relation between x and original graph
    x = x.detach()
    # re-enable gradient tracking on the detached copy so f(x) can be differentiated
    x.stop_gradient = False
    value = f(x)
    if paddle.in_dynamic_mode():
        # only need to compute first order derivative, and some op dont support high order derivative.
        gradient = paddle.grad([value], [x], create_graph=False)[0]
    else:
        gradient = paddle.static.gradients([value], [x])[0]
    # use detach to make results real number without grad to avoid assign error
    return value.detach(), gradient.detach()