Unverified commit 1b8fd85d, authored by YuanRisheng, committed by GitHub

Support double grad check of op in Eager mode and Add log double grad yaml (#42090)

* Support double grad check of op in Eager mode

* fix bugs of backward yaml

* adjust code format
Parent 86a88631
......@@ -20,11 +20,13 @@ import six
import collections
import numpy as np
from itertools import product
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.backward import _append_grad_suffix_, _as_list
from paddle.fluid.framework import _test_eager_guard
def _product(t):
......@@ -58,6 +60,19 @@ def _get_item(t, i, np_dtype):
raise ValueError("Not supported data type " + str(np_dtype))
def _get_item_for_dygraph(t, i, np_dtype):
if np_dtype == np.float16:
np_t = t.numpy().astype(np.float16)
elif np_dtype == np.float32:
np_t = t.numpy().astype(np.float32)
elif np_dtype == np.float64:
np_t = t.numpy().astype(np.float64)
else:
raise ValueError("Not supported data type " + str(np_dtype))
np_t = np_t.flatten()
return np_t[i]
def _set_item(t, i, e, np_dtype):
if np_dtype == np.float16:
np_t = np.array(t).astype(np.float16)
......@@ -74,6 +89,22 @@ def _set_item(t, i, e, np_dtype):
raise ValueError("Not supported data type " + str(np_dtype))
def _set_item_for_dygraph(t, i, e, np_dtype):
if np_dtype == np.float16:
np_t = t.numpy().astype(np.float16)
elif np_dtype == np.float32:
np_t = t.numpy().astype(np.float32)
elif np_dtype == np.float64:
np_t = t.numpy().astype(np.float64)
else:
raise ValueError("Not supported data type " + str(np_dtype))
shape = np_t.shape
np_t = np_t.flatten()
np_t[i] = e
np_t = np_t.reshape(shape)
paddle.assign(np_t, t)
def set_var_in_scope(scope, place, name, value, recursive_seq_len=None):
t = scope.var(name).get_tensor()
t.set(value, place)
......@@ -138,6 +169,8 @@ def _compute_numerical_jacobian(program, x, y, place, scope, delta):
np_type = dtype_to_np_dtype(x.dtype)
jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y]
if np_type == np.float64:
delta = 1e-5
for i in six.moves.xrange(x_size):
orig = _get_item(x_t, i, np_type)
x_pos = orig + delta
......@@ -510,3 +543,194 @@ def triple_grad_check(x,
eps=eps,
atol=atol,
rtol=rtol)
def get_static_double_grad(x, y, x_init=None, dy_init=None, place=None):
"""
Get Double Grad result of static graph.
Args:
x (Variable|list[Variable]): input variables to the program.
y (Variable|list[Variable]): output variables to the program.
x_init (numpy.array|list[numpy.array]|None): the init value for input x.
        dy_init (numpy.array|list[numpy.array]|None): the init value for the gradient of output y.
place (fluid.CPUPlace or fluid.CUDAPlace): the device.
Returns:
        A list of numpy arrays that stores the second-order derivative results calculated by the static graph.
"""
program = fluid.default_main_program()
scope = fluid.executor.global_scope()
y_grads = []
for i in six.moves.xrange(len(y)):
yi = y[i]
dyi_name = _append_grad_suffix_(yi.name)
np_type = dtype_to_np_dtype(yi.dtype)
dy = program.global_block().create_var(
name=dyi_name, shape=yi.shape, dtype=np_type, persistable=True)
dy.stop_gradient = False
set_var_in_scope(scope, place, dyi_name, dy_init[i])
y_grads.append(dy)
# append first order grads
dx = fluid.gradients(y, x, y_grads)
    # y_grads are the inputs of the first-order backward pass,
    # so they are also inputs of the second-order backward pass.
x += y_grads
x_init += dy_init
y = dx
# check input arguments
x = _as_list(x)
y = _as_list(y)
for v in x:
v.stop_gradient = False
v.persistable = True
if place is None:
place = fluid.CPUPlace()
if program is None:
program = fluid.default_main_program()
    # init variables in the startup program
scope = fluid.executor.global_scope()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
x_init = _as_list(x_init)
# init inputs if x_init is not None
if x_init:
if len(x_init) != len(x):
raise ValueError('len(x_init) (=%d) is not the same'
' as len(x) (= %d)' % (len(x_init), len(x)))
# init variable in main program
for var, arr in zip(x, x_init):
assert var.shape == arr.shape
feeds = {k.name: v for k, v in zip(x, x_init)}
exe.run(program, feed=feeds, scope=scope)
dys = []
for yi in y:
np_type = dtype_to_np_dtype(yi.dtype)
dy_name = _append_grad_suffix_(yi.name)
# create dy Variable in Program
dy = program.global_block().create_var(
name=dy_name, shape=yi.shape, dtype=np_type, persistable=True)
# init dy tensor in scope
value = np.ones(yi.shape, dtype=np_type)
dy_t = set_var_in_scope(scope, place, dy_name, value)
dys.append(dy)
# append second order backward
ddx = fluid.gradients(y, x, dys)
exe = fluid.Executor(place)
    # Filter out None entries in ddx, since dX/dY may be None in the kernel;
    # only fetch the non-None gradients in exe.run.
    filtered = [(i, dxi) for i, dxi in enumerate(ddx) if dxi is not None]
    filtered_idx, filtered_ddx = zip(*filtered)
    ddx_res = exe.run(program, scope=scope, fetch_list=filtered_ddx)
return ddx_res
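For reference, a minimal standalone sketch of calling get_static_double_grad directly; the program setup here is an assumed example, mirroring how the check below builds y = log(x):

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers

paddle.enable_static()
place = fluid.CPUPlace()
# build a small static program whose double grad we want
x = layers.data('x', [2, 3], False, 'float64')
x.stop_gradient = False
x.persistable = True
y = paddle.log(x)
x_arr = np.random.uniform(0.1, 1.0, [2, 3]).astype('float64')
dy_arr = np.ones([2, 3], dtype='float64')
# returns a list of numpy arrays holding the non-None second-order gradients
ddx = get_static_double_grad([x], [y], x_init=[x_arr], dy_init=[dy_arr], place=place)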
def get_eager_double_grad(func, x_init=None, dy_init=None):
"""
Get Double Grad result of dygraph.
Args:
        func: a wrapped dygraph function whose logic is equivalent to the static graph program.
x_init (numpy.array|list[numpy.array]|None): the init value for input x.
dy_init (numpy.array|list[numpy.array]|None): the init value for gradient of output.
Returns:
        A list of numpy arrays that stores the second-order derivative results calculated in dygraph mode.
"""
inputs = []
dys = []
for x in x_init:
input_tensor = paddle.to_tensor(x)
input_tensor.stop_gradient = False
inputs.append(input_tensor)
for dy in dy_init:
dy_tensor = paddle.to_tensor(dy)
dy_tensor.stop_gradient = False
dys.append(dy_tensor)
# calculate first derivative
outputs = func(inputs)
d_inputs = paddle.grad(
outputs=outputs, inputs=inputs, grad_outputs=dys, create_graph=True)
    # calculate second derivative
inputs = inputs + dys
ddys = []
for d_input in d_inputs:
d_input.stop_gradient = False
ddy = paddle.ones(shape=d_input.shape, dtype=d_input.dtype)
ddy.stop_gradient = False
ddys.append(ddy)
dd_inputs = paddle.grad(outputs=d_inputs, inputs=inputs, grad_outputs=ddys)
return [dd_input.numpy() for dd_input in dd_inputs]
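The helper above obtains the second derivative by calling paddle.grad twice, with create_graph=True on the first call so that the first-order gradient stays differentiable. A minimal standalone sketch of the same pattern for y = log(x); the tensor values are assumed for illustration:

import paddle
from paddle.fluid.framework import _test_eager_guard

paddle.disable_static()
with _test_eager_guard():
    x = paddle.to_tensor([0.5, 2.0], dtype='float64', stop_gradient=False)
    y = paddle.log(x)
    dy = paddle.ones_like(y)
    dy.stop_gradient = False
    # first-order grad, kept differentiable so it can be differentiated again
    dx = paddle.grad(outputs=y, inputs=x, grad_outputs=dy, create_graph=True)[0]
    ddy = paddle.ones_like(dx)
    ddx = paddle.grad(outputs=dx, inputs=[x, dy], grad_outputs=ddy)
    # ddx[0] == -1 / x**2 (second derivative of log), ddx[1] == 1 / x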
def double_grad_check_for_dygraph(func,
x,
y,
x_init=None,
place=None,
atol=1e-5,
rtol=1e-3,
raise_exception=True):
"""
    Check second-order gradients. This function compares the double grad
    computed in eager (dygraph) mode against the double grad computed by
    the static graph program.
Args:
        func: a wrapped dygraph function whose logic is equivalent to the static graph program.
x (Variable|list[Variable]): input variables to the program.
y (Variable|list[Variable]): output variables to the program.
x_init (numpy.array|list[numpy.array]|None): the init value for input x.
place (fluid.CPUPlace or fluid.CUDAPlace): the device.
atol (float): absolute tolerance.
rtol (float): relative tolerance.
raise_exception (bool): whether to raise an exception if
the check fails. Default is True.
"""
def fail_test(msg):
if raise_exception:
raise RuntimeError(msg)
return False
# check input arguments
x = _as_list(x)
for v in x:
v.stop_gradient = False
v.persistable = True
y = _as_list(y)
y_grads_init = []
for yi in y:
np_type = dtype_to_np_dtype(yi.dtype)
v = np.random.random(size=yi.shape).astype(np_type)
y_grads_init.append(v)
x_init = _as_list(x_init)
paddle.disable_static()
with _test_eager_guard():
eager_double_grad = get_eager_double_grad(func, x_init, y_grads_init)
paddle.enable_static()
static_double_grad = get_static_double_grad(x, y, x_init, y_grads_init,
place)
for i in six.moves.xrange(len(static_double_grad)):
if not np.allclose(static_double_grad[i], eager_double_grad[i], rtol,
atol):
            msg = 'Check eager double grad failed. Mismatch between the static graph double grad ' \
                  'and the eager double grad of output %d on %s,\n' \
                  'static: %s\n eager: %s\n' \
                  % (i, str(place), static_double_grad[i], eager_double_grad[i])
            return fail_test(msg)
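A typical call site, mirroring the log test added in this change (the data setup shown here is the assumed part):

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import gradient_checker

def log_wrapper(x):
    # the wrapped dygraph function; x is a list of eager tensors
    return paddle.log(x[0])

paddle.enable_static()
place = fluid.CPUPlace()
shape = [2, 3, 7, 9]
x = layers.data('x', shape, False, 'float64')
x.persistable = True
y = paddle.log(x)
x_arr = np.random.uniform(0.1, 1.0, shape).astype('float64')

# compares the eager (dygraph) double grad of log_wrapper with the
# static graph double grad built from x and y
gradient_checker.double_grad_check_for_dygraph(
    log_wrapper, [x], y, x_init=x_arr, place=place)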
......@@ -23,6 +23,7 @@ import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
import paddle.nn.functional as F
from paddle.fluid.framework import _test_eager_guard
from decorator_helper import prog_scope
......@@ -42,6 +43,7 @@ class TestSigmoidTripleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -64,6 +66,7 @@ class TestSigmoidDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -86,6 +89,7 @@ class TestTanhTripleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -108,6 +112,7 @@ class TestTanhDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -132,6 +137,7 @@ class TestReluDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -158,6 +164,7 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places = [fluid.CUDAPlace(0)]
......@@ -184,6 +191,7 @@ class TestELUDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -210,6 +218,7 @@ class TestCELUDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -234,6 +243,7 @@ class TestSqrtDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places = [fluid.CUDAPlace(0)]
......@@ -258,6 +268,7 @@ class TestRsqrtDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places = [fluid.CUDAPlace(0)]
......@@ -282,6 +293,7 @@ class TestSquareDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -310,6 +322,7 @@ class TestAbsDoubleGradCheck(unittest.TestCase):
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -318,6 +331,9 @@ class TestAbsDoubleGradCheck(unittest.TestCase):
class TestLogDoubleGradCheck(unittest.TestCase):
def log_wrapper(self, x):
return paddle.log(x[0])
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
......@@ -332,8 +348,11 @@ class TestLogDoubleGradCheck(unittest.TestCase):
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
gradient_checker.double_grad_check_for_dygraph(
self.log_wrapper, [x], y, x_init=x_arr, place=place)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
......@@ -342,5 +361,4 @@ class TestLogDoubleGradCheck(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -839,6 +839,16 @@
kernel :
func : log2_grad
- backward_api : log_double_grad
forward : log_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
output : Tensor(x_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, x]
kernel :
func : log_double_grad
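For context, log_grad computes grad_x = grad_out / x. Differentiating that expression with respect to its inputs gives the two outputs of log_double_grad; a numpy reference sketch of the assumed formulation:

import numpy as np

def log_double_grad_ref(x, grad_out, grad_x_grad):
    # derivative of grad_x = grad_out / x w.r.t. x, contracted with grad_x_grad
    x_grad = -grad_x_grad * grad_out / (x * x)
    # derivative of grad_x w.r.t. grad_out, contracted with grad_x_grad
    grad_out_grad = grad_x_grad / x
    return x_grad, grad_out_grad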
- backward_api : log_grad
forward : log (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......@@ -848,6 +858,7 @@
param : [x]
kernel :
func : log_grad
backward : log_double_grad
- backward_api : log_loss_grad
forward : log_loss (Tensor input, Tensor label, float epsilon) -> Tensor(out)
......@@ -1473,7 +1484,7 @@
func : UnchangedInferMeta
param : [x]
kernel :
func : sigmoid_cross_entropy_with_logits_grad
func : sigmoid_cross_entropy_with_logits_grad
- backward_api : sigmoid_double_grad
forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
......