From 1b8fd85d4460b5cf9dab3ce68897b130f83ebfb2 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Fri, 22 Apr 2022 11:12:23 +0800
Subject: [PATCH] Support double grad check of op in Eager mode and Add log
 double grad yaml (#42090)

* Support double grad check of op in Eager mode

* fix bugs of backward yaml

* adjust code format
---
 .../fluid/tests/unittests/gradient_checker.py | 224 ++++++++++++++++++
 .../unittests/test_activation_nn_grad.py      |  20 +-
 python/paddle/utils/code_gen/backward.yaml    |  13 +-
 3 files changed, 255 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/gradient_checker.py b/python/paddle/fluid/tests/unittests/gradient_checker.py
index dff2b7aa8d8..562d52668ce 100644
--- a/python/paddle/fluid/tests/unittests/gradient_checker.py
+++ b/python/paddle/fluid/tests/unittests/gradient_checker.py
@@ -20,11 +20,13 @@ import six
 import collections
 import numpy as np
 from itertools import product
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.executor import Executor
 from paddle.fluid.backward import _append_grad_suffix_, _as_list
+from paddle.fluid.framework import _test_eager_guard
 
 
 def _product(t):
@@ -58,6 +60,19 @@ def _get_item(t, i, np_dtype):
         raise ValueError("Not supported data type " + str(np_dtype))
 
 
+def _get_item_for_dygraph(t, i, np_dtype):
+    if np_dtype == np.float16:
+        np_t = t.numpy().astype(np.float16)
+    elif np_dtype == np.float32:
+        np_t = t.numpy().astype(np.float32)
+    elif np_dtype == np.float64:
+        np_t = t.numpy().astype(np.float64)
+    else:
+        raise ValueError("Not supported data type " + str(np_dtype))
+    np_t = np_t.flatten()
+    return np_t[i]
+
+
 def _set_item(t, i, e, np_dtype):
     if np_dtype == np.float16:
         np_t = np.array(t).astype(np.float16)
@@ -74,6 +89,22 @@ def _set_item(t, i, e, np_dtype):
         raise ValueError("Not supported data type " + str(np_dtype))
 
 
+def _set_item_for_dygraph(t, i, e, np_dtype):
+    if np_dtype == np.float16:
+        np_t = t.numpy().astype(np.float16)
+    elif np_dtype == np.float32:
+        np_t = t.numpy().astype(np.float32)
+    elif np_dtype == np.float64:
+        np_t = t.numpy().astype(np.float64)
+    else:
+        raise ValueError("Not supported data type " + str(np_dtype))
+    shape = np_t.shape
+    np_t = np_t.flatten()
+    np_t[i] = e
+    np_t = np_t.reshape(shape)
+    paddle.assign(np_t, t)
+
+
 def set_var_in_scope(scope, place, name, value, recursive_seq_len=None):
     t = scope.var(name).get_tensor()
     t.set(value, place)
@@ -138,6 +169,8 @@ def _compute_numerical_jacobian(program, x, y, place, scope, delta):
     np_type = dtype_to_np_dtype(x.dtype)
     jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y]
+    if np_type == np.float64:
+        delta = 1e-5
 
     for i in six.moves.xrange(x_size):
         orig = _get_item(x_t, i, np_type)
         x_pos = orig + delta
@@ -510,3 +543,194 @@ def triple_grad_check(x,
         eps=eps,
         atol=atol,
         rtol=rtol)
+
+
+def get_static_double_grad(x, y, x_init=None, dy_init=None, place=None):
+    """
+    Get the double grad (second derivative) result of a static graph.
+
+    Args:
+        x (Variable|list[Variable]): input variables to the program.
+        y (Variable|list[Variable]): output variables to the program.
+        x_init (numpy.array|list[numpy.array]|None): the init value for input x.
+        dy_init (numpy.array|list[numpy.array]|None): the init value for the gradient of output y.
+        place (fluid.CPUPlace or fluid.CUDAPlace): the device.
+    Returns:
+        A list of numpy arrays that store the second derivative results
+        calculated by the static graph.
+ """ + + program = fluid.default_main_program() + scope = fluid.executor.global_scope() + y_grads = [] + for i in six.moves.xrange(len(y)): + yi = y[i] + dyi_name = _append_grad_suffix_(yi.name) + np_type = dtype_to_np_dtype(yi.dtype) + dy = program.global_block().create_var( + name=dyi_name, shape=yi.shape, dtype=np_type, persistable=True) + dy.stop_gradient = False + set_var_in_scope(scope, place, dyi_name, dy_init[i]) + y_grads.append(dy) + + # append first order grads + dx = fluid.gradients(y, x, y_grads) + + # y_grads are the input of first-order backward, + # so, they are also the input of second-order backward. + x += y_grads + x_init += dy_init + y = dx + + # check input arguments + x = _as_list(x) + y = _as_list(y) + + for v in x: + v.stop_gradient = False + v.persistable = True + if place is None: + place = fluid.CPUPlace() + if program is None: + program = fluid.default_main_program() + + # init variable in strtup program + scope = fluid.executor.global_scope() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + x_init = _as_list(x_init) + # init inputs if x_init is not None + if x_init: + if len(x_init) != len(x): + raise ValueError('len(x_init) (=%d) is not the same' + ' as len(x) (= %d)' % (len(x_init), len(x))) + # init variable in main program + for var, arr in zip(x, x_init): + assert var.shape == arr.shape + feeds = {k.name: v for k, v in zip(x, x_init)} + exe.run(program, feed=feeds, scope=scope) + + dys = [] + for yi in y: + np_type = dtype_to_np_dtype(yi.dtype) + dy_name = _append_grad_suffix_(yi.name) + # create dy Variable in Program + dy = program.global_block().create_var( + name=dy_name, shape=yi.shape, dtype=np_type, persistable=True) + # init dy tensor in scope + value = np.ones(yi.shape, dtype=np_type) + dy_t = set_var_in_scope(scope, place, dy_name, value) + dys.append(dy) + + # append second order backward + ddx = fluid.gradients(y, x, dys) + exe = fluid.Executor(place) + + # filter None in dx for DX/DY may be None in kernel + # only fetch not None dx in exe.run + filted = [(i, dxi) for i, dxi in enumerate(ddx) if dxi is not None] + filted_idx, filted_ddx = zip(*filted) + ddx_res = exe.run(program, scope=scope, fetch_list=filted_ddx) + + return ddx_res + + +def get_eager_double_grad(func, x_init=None, dy_init=None): + """ + Get Double Grad result of dygraph. + + Args: + func: A wrapped dygraph function that its logic is equal to static program + x_init (numpy.array|list[numpy.array]|None): the init value for input x. + dy_init (numpy.array|list[numpy.array]|None): the init value for gradient of output. 
+    Returns:
+        A list of numpy arrays that store the second derivative results
+        calculated by dygraph.
+    """
+    inputs = []
+    dys = []
+    for x in x_init:
+        input_tensor = paddle.to_tensor(x)
+        input_tensor.stop_gradient = False
+        inputs.append(input_tensor)
+    for dy in dy_init:
+        dy_tensor = paddle.to_tensor(dy)
+        dy_tensor.stop_gradient = False
+        dys.append(dy_tensor)
+    # calculate first derivative
+    outputs = func(inputs)
+    d_inputs = paddle.grad(
+        outputs=outputs, inputs=inputs, grad_outputs=dys, create_graph=True)
+
+    # calculate second derivative
+    inputs = inputs + dys
+    ddys = []
+    for d_input in d_inputs:
+        d_input.stop_gradient = False
+        ddy = paddle.ones(shape=d_input.shape, dtype=d_input.dtype)
+        ddy.stop_gradient = False
+        ddys.append(ddy)
+    dd_inputs = paddle.grad(outputs=d_inputs, inputs=inputs, grad_outputs=ddys)
+    return [dd_input.numpy() for dd_input in dd_inputs]
+
+
+def double_grad_check_for_dygraph(func,
+                                  x,
+                                  y,
+                                  x_init=None,
+                                  place=None,
+                                  atol=1e-5,
+                                  rtol=1e-3,
+                                  raise_exception=True):
+    """
+    Check gradients of gradients by comparing the double grad of the static
+    graph with the double grad computed in eager mode. This function appends
+    backward ops to the program before the second-order gradient check.
+
+    Args:
+        func: a wrapped dygraph function whose logic is equivalent to the static program.
+        x (Variable|list[Variable]): input variables to the program.
+        y (Variable|list[Variable]): output variables to the program.
+        x_init (numpy.array|list[numpy.array]|None): the init value for input x.
+        place (fluid.CPUPlace or fluid.CUDAPlace): the device.
+        atol (float): absolute tolerance.
+        rtol (float): relative tolerance.
+        raise_exception (bool): whether to raise an exception if
+            the check fails. Default is True.
+    """
+
+    def fail_test(msg):
+        if raise_exception:
+            raise RuntimeError(msg)
+        return False
+
+    # check input arguments
+    x = _as_list(x)
+    for v in x:
+        v.stop_gradient = False
+        v.persistable = True
+    y = _as_list(y)
+
+    y_grads_init = []
+    for yi in y:
+        np_type = dtype_to_np_dtype(yi.dtype)
+        v = np.random.random(size=yi.shape).astype(np_type)
+        y_grads_init.append(v)
+
+    x_init = _as_list(x_init)
+
+    paddle.disable_static()
+    with _test_eager_guard():
+        eager_double_grad = get_eager_double_grad(func, x_init, y_grads_init)
+    paddle.enable_static()
+
+    static_double_grad = get_static_double_grad(x, y, x_init, y_grads_init,
+                                                place)
+
+    for i in six.moves.xrange(len(static_double_grad)):
+        if not np.allclose(static_double_grad[i], eager_double_grad[i], rtol,
+                           atol):
+            msg = 'Check eager double grad result failed. Mismatch between ' \
+                  'static graph double grad and eager double grad on %s, ' \
+                  'the %d-th double grad:\n' \
+                  'static: %s\n eager: %s\n' \
+                  % (str(place), i, static_double_grad[i], eager_double_grad[i])
+            return fail_test(msg)
diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
index eb4243ef1cb..72240be41dd 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -23,6 +23,7 @@ import paddle.fluid.layers as layers
 import paddle.fluid.core as core
 import gradient_checker
 import paddle.nn.functional as F
+from paddle.fluid.framework import _test_eager_guard
 
 from decorator_helper import prog_scope
@@ -42,6 +43,7 @@ class TestSigmoidTripleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -64,6 +66,7 @@ class TestSigmoidDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -86,6 +89,7 @@ class TestTanhTripleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -108,6 +112,7 @@ class TestTanhDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -132,6 +137,7 @@ class TestReluDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -158,6 +164,7 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places = [fluid.CUDAPlace(0)]
@@ -184,6 +191,7 @@ class TestELUDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -210,6 +218,7 @@ class TestCELUDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -234,6 +243,7 @@ class TestSqrtDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places = [fluid.CUDAPlace(0)]
@@ -258,6 +268,7 @@ class TestRsqrtDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places = [fluid.CUDAPlace(0)]
@@ -282,6 +293,7 @@ class TestSquareDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -310,6 +322,7 @@ class TestAbsDoubleGradCheck(unittest.TestCase):
             [x], y, x_init=x_arr, place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -318,6 +331,9 @@ class TestAbsDoubleGradCheck(unittest.TestCase):
 
 
 class TestLogDoubleGradCheck(unittest.TestCase):
+    def log_wrapper(self, x):
+        return paddle.log(x[0])
+
     @prog_scope()
     def func(self, place):
         shape = [2, 3, 7, 9]
@@ -332,8 +348,11 @@ class TestLogDoubleGradCheck(unittest.TestCase):
 
         gradient_checker.double_grad_check(
             [x], y, x_init=x_arr, place=place, eps=eps)
+        gradient_checker.double_grad_check_for_dygraph(
+            self.log_wrapper, [x], y, x_init=x_arr, place=place)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -342,5 +361,4 @@ class TestLogDoubleGradCheck(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 64acc140c21..dfdc2335ae1 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -839,6 +839,16 @@
   kernel :
     func : log2_grad
 
+- backward_api : log_double_grad
+  forward : log_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
+  args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
+  output : Tensor(x_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, x]
+  kernel :
+    func : log_double_grad
+
 - backward_api : log_grad
   forward : log (Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
@@ -848,6 +858,7 @@
     param : [x]
   kernel :
     func : log_grad
+  backward : log_double_grad
 
 - backward_api : log_loss_grad
   forward : log_loss (Tensor input, Tensor label, float epsilon) -> Tensor(out)
@@ -1473,7 +1484,7 @@
     func : UnchangedInferMeta
     param : [x]
   kernel :
-    func : sigmoid_cross_entropy_with_logits_grad
+    func : sigmoid_cross_entropy_with_logits_grad
 
 - backward_api : sigmoid_double_grad
   forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
--
GitLab
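Usage sketch (illustrative, not part of the diff): the snippet below shows how the new gradient_checker.double_grad_check_for_dygraph helper is intended to be called from a unit test, mirroring the TestLogDoubleGradCheck change in this patch. The shape, dtype, x_arr initialization and the names check_log_double_grad / log_wrapper are placeholders introduced only for the example; they are not identifiers added by this commit.

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import gradient_checker
from decorator_helper import prog_scope


def log_wrapper(x):
    # get_eager_double_grad passes the inputs as a list of tensors
    return paddle.log(x[0])


@prog_scope()
def check_log_double_grad(place):
    shape = [2, 3, 7, 9]
    dtype = np.float64
    # build the op under the static graph, as the existing tests do
    x = layers.data('x', shape, False, dtype)
    x.persistable = True
    y = layers.log(x)
    # keep inputs away from zero so log and its gradients stay finite
    x_arr = np.random.uniform(0.1, 1.0, shape).astype(dtype)

    # numerical vs. analytic double grad on the static graph
    gradient_checker.double_grad_check([x], y, x_init=x_arr, place=place)
    # static-graph double grad vs. eager double grad (the check added here)
    gradient_checker.double_grad_check_for_dygraph(
        log_wrapper, [x], y, x_init=x_arr, place=place)


if __name__ == "__main__":
    paddle.enable_static()
    check_log_double_grad(fluid.CPUPlace())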