Unverified commit 26187c27, authored by Qi Li, committed by GitHub

Fix elementwise_div UT by providing user defined gradients (#43536) (#43909)

Cherry-pick of #43536

Background in #43262

In the elementwise_div UT, the numeric gradient (used for validation) has a large relative error compared with the analytic gradient (computed by the Paddle OP).

    The default rtol for UTs is 0.005
    The rtol for the float32 and float64 elementwise_div OP is set to 0.05
    The rtol for the float16 and bfloat16 elementwise_div OP is set to 1.0

The relative error is too large, so this PR provides user-defined gradients, computed with the analytic method, to test elementwise_div.
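
For reference, since out = x / y, the analytic gradients are dOut/dX = 1 / y and dOut/dY = -x / y**2 = -out / y. Below is a minimal NumPy sketch of how such user-defined gradients are formed (illustrative only; the shape here is made up, and the expressions mirror the compute_gradient_x / compute_gradient_y helpers in the diff that follows):

    import numpy as np

    # Illustrative inputs; the actual test uses shapes such as [13, 17].
    x = np.random.uniform(0.1, 1, [13, 17])
    y = np.random.uniform(0.1, 1, [13, 17])
    out = x / y
    grad_out = np.ones_like(out)  # supplied via user_defined_grad_outputs

    # dOut/dX = 1 / y  ->  grad_x = grad_out / y
    grad_x = grad_out / y
    # dOut/dY = -x / y**2 = -out / y  ->  grad_y = -grad_out * out / y
    grad_y = -1 * grad_out * out / y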
Parent 69e82d83
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -15,277 +15,295 @@
from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle
from paddle import fluid
from paddle.fluid import core


class ElementwiseDivOp(OpTest):

    def setUp(self):
        self.op_type = "elementwise_div"
        self.python_api = paddle.divide
        self.init_args()
        self.init_dtype()
        self.init_shape()

        x = self.gen_data(self.x_shape).astype(self.val_dtype)
        y = self.gen_data(self.y_shape).astype(self.val_dtype)
        out = self.compute_output(x, y).astype(self.val_dtype)
        grad_out = np.ones(out.shape).astype(self.val_dtype)
        grad_x = self.compute_gradient_x(grad_out, y).astype(self.val_dtype)
        grad_y = self.compute_gradient_y(grad_out, out,
                                         y).astype(self.val_dtype)

        # Convert np.float32 data to np.uint16 for bfloat16 Paddle OP
        if self.dtype == np.uint16:
            x = convert_float_to_uint16(x)
            y = convert_float_to_uint16(y)
            out = convert_float_to_uint16(out)
            grad_out = convert_float_to_uint16(grad_out)
            grad_x = convert_float_to_uint16(grad_x)
            grad_y = convert_float_to_uint16(grad_y)

        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': out}
        self.grad_out = grad_out
        self.grad_x = grad_x
        self.grad_y = grad_y

    def init_args(self):
        self.check_dygraph = True
        self.place = None

    def init_dtype(self):
        self.dtype = np.float64
        self.val_dtype = np.float64

    def init_shape(self):
        self.x_shape = [13, 17]
        self.y_shape = [13, 17]

    def gen_data(self, shape):
        return np.random.uniform(0.1, 1, shape)

    def compute_output(self, x, y):
        return x / y

    def compute_gradient_x(self, grad_out, y):
        return grad_out / y

    def compute_gradient_y(self, grad_out, out, y):
        return -1 * grad_out * out / y

    def test_check_output(self):
        if self.place is None:
            self.check_output()
        else:
            self.check_output_with_place(self.place)

    def test_check_gradient(self):
        check_list = []
        check_list.append({
            'grad': ['X', 'Y'],
            'no_grad': None,
            'val_grad': [self.grad_x, self.grad_y]
        })
        check_list.append({
            'grad': ['Y'],
            'no_grad': set('X'),
            'val_grad': [self.grad_y]
        })
        check_list.append({
            'grad': ['X'],
            'no_grad': set('Y'),
            'val_grad': [self.grad_x]
        })
        for check_option in check_list:
            check_args = [check_option['grad'], 'Out']
            check_kwargs = {
                'no_grad_set': check_option['no_grad'],
                'user_defined_grads': check_option['val_grad'],
                'user_defined_grad_outputs': [self.grad_out],
                'check_dygraph': self.check_dygraph
            }
            if self.place is None:
                self.check_grad(*check_args, **check_kwargs)
            else:
                check_args.insert(0, self.place)
                self.check_grad_with_place(*check_args, **check_kwargs)


@unittest.skipIf(not core.is_compiled_with_cuda()
                 or not core.is_bfloat16_supported(core.CUDAPlace(0)),
                 "core is not compiled with CUDA or not support the bfloat16")
class TestElementwiseDivOpBF16(ElementwiseDivOp):

    def init_args(self):
        # In due to output data type inconsistence of bfloat16 paddle op, we disable the dygraph check.
        self.check_dygraph = False
        self.place = core.CUDAPlace(0)

    def init_dtype(self):
        self.dtype = np.uint16
        self.val_dtype = np.float32

    def init_shape(self):
        self.x_shape = [12, 13]
        self.y_shape = [12, 13]


@skip_check_grad_ci(
    reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseDivOpScalar(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [20, 3, 4]
        self.y_shape = [1]

    def compute_gradient_y(self, grad_out, out, y):
        return np.array([np.sum(-1 * grad_out * out / y)])


class TestElementwiseDivOpVector(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [100]
        self.y_shape = [100]


class TestElementwiseDivOpBroadcast0(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [100, 3, 4]
        self.y_shape = [100]
        self.attrs = {'axis': 0}

    def compute_output(self, x, y):
        return x / y.reshape(100, 1, 1)

    def compute_gradient_x(self, grad_out, y):
        return grad_out / y.reshape(100, 1, 1)

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y.reshape(100, 1, 1), axis=(1, 2))


class TestElementwiseDivOpBroadcast1(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [2, 100, 4]
        self.y_shape = [100]
        self.attrs = {'axis': 1}

    def compute_output(self, x, y):
        return x / y.reshape(1, 100, 1)

    def compute_gradient_x(self, grad_out, y):
        return grad_out / y.reshape(1, 100, 1)

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y.reshape(1, 100, 1), axis=(0, 2))


class TestElementwiseDivOpBroadcast2(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [2, 3, 100]
        self.y_shape = [100]

    def compute_output(self, x, y):
        return x / y.reshape(1, 1, 100)

    def compute_gradient_x(self, grad_out, y):
        return grad_out / y.reshape(1, 1, 100)

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y.reshape(1, 1, 100), axis=(0, 1))


class TestElementwiseDivOpBroadcast3(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [2, 10, 12, 5]
        self.y_shape = [10, 12]
        self.attrs = {'axis': 1}

    def compute_output(self, x, y):
        return x / y.reshape(1, 10, 12, 1)

    def compute_gradient_x(self, grad_out, y):
        return grad_out / y.reshape(1, 10, 12, 1)

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y.reshape(1, 10, 12, 1),
                      axis=(0, 3))


class TestElementwiseDivOpBroadcast4(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [2, 3, 50]
        self.y_shape = [2, 1, 50]

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y, axis=(1)).reshape(2, 1, 50)


class TestElementwiseDivOpBroadcast5(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [2, 3, 4, 20]
        self.y_shape = [2, 3, 1, 20]

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y, axis=(2)).reshape(2, 3, 1, 20)


class TestElementwiseDivOpCommonuse1(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [2, 3, 100]
        self.y_shape = [1, 1, 100]

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y, axis=(0, 1)).reshape(1, 1, 100)


class TestElementwiseDivOpCommonuse2(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [30, 3, 1, 5]
        self.y_shape = [30, 1, 4, 1]

    def compute_gradient_x(self, grad_out, y):
        return np.sum(grad_out / y, axis=(2)).reshape(30, 3, 1, 5)

    def compute_gradient_y(self, grad_out, out, y):
        return np.sum(-1 * grad_out * out / y, axis=(1, 3)).reshape(30, 1, 4, 1)


class TestElementwiseDivOpXsizeLessThanYsize(ElementwiseDivOp):

    def init_shape(self):
        self.x_shape = [10, 12]
        self.y_shape = [2, 3, 10, 12]
        self.attrs = {'axis': 2}

    def compute_gradient_x(self, grad_out, y):
        return np.sum(grad_out / y, axis=(0, 1))


class TestElementwiseDivOpInt(ElementwiseDivOp):

    def init_dtype(self):
        self.dtype = np.int32
        self.val_dtype = np.int32

    def gen_data(self, shape):
        return np.random.randint(1, 5, size=shape)

    def compute_output(self, x, y):
        return x // y


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestElementwiseDivOpFp16(ElementwiseDivOp):

    def init_dtype(self):
        self.dtype = np.float16
        self.val_dtype = np.float16


class TestElementwiseDivBroadcast(unittest.TestCase):

    def test_shape_with_batch_sizes(self):
        with fluid.program_guard(fluid.Program()):
            x_var = fluid.data(name='x',
                               dtype='float32',
                               shape=[None, 3, None, None])
            one = 2.
            out = one / x_var
            exe = fluid.Executor(fluid.CPUPlace())
...
@@ -295,6 +313,7 @@ class TestElementwiseDivBroadcast(unittest.TestCase):
class TestDivideOp(unittest.TestCase):

    def test_name(self):
        with fluid.program_guard(fluid.Program()):
            x = fluid.data(name="x", shape=[2, 3], dtype="float32")
...
@@ -316,6 +335,7 @@ class TestDivideOp(unittest.TestCase):
class TestComplexElementwiseDivOp(OpTest):

    def setUp(self):
        self.op_type = "elementwise_div"
        self.python_api = paddle.divide
...
@@ -352,30 +372,28 @@ class TestComplexElementwiseDivOp(OpTest):
        self.check_output(check_eager=False)

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'],
                        'Out',
                        user_defined_grads=[self.grad_x, self.grad_y],
                        user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        self.check_grad(['Y'],
                        'Out',
                        no_grad_set=set("X"),
                        user_defined_grads=[self.grad_y],
                        user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        self.check_grad(['X'],
                        'Out',
                        no_grad_set=set('Y'),
                        user_defined_grads=[self.grad_x],
                        user_defined_grad_outputs=[self.grad_out])


class TestRealComplexElementwiseDivOp(TestComplexElementwiseDivOp):

    def init_input_output(self):
        self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype)
        self.y = np.random.random(
...