未验证 提交 017452e9 编写于 作者: C Charles-hit 提交者: GitHub

[OpTest]add only_check_prim parameter in check grad (#51210)

* support elementwise_pow bfloat16

* add only_check_prim parameters in check_grad

* modify unit test

* fix floor test

* fix sigmoid bfloat16 test
上级 10b95e8d
......@@ -648,6 +648,5 @@ struct ElementwiseInversePowFunctor<dtype::float16> {
return static_cast<dtype::float16>(std::pow(f_b, f_a));
}
};
} // namespace funcs
} // namespace phi
......@@ -1672,9 +1672,6 @@ class OpTest(unittest.TestCase):
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True)
self.__class__.op_type = self.op_type
if prim_checker.is_only_check_prim():
self.only_prim = True
return
# set some flags by the combination of arguments.
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if (
......@@ -1842,8 +1839,6 @@ class OpTest(unittest.TestCase):
check_prim=check_prim,
inplace_atol=inplace_atol,
)
if hasattr(self, 'only_prim') and self.only_prim:
continue
if check_dygraph:
outs, dygraph_dygraph_outs, fetch_list = res
else:
......@@ -1958,6 +1953,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
check_dygraph=True,
check_prim=False,
only_check_prim=False,
):
self._check_grad_helper()
places = self._get_places()
......@@ -1974,6 +1970,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs,
check_dygraph=check_dygraph,
check_prim=check_prim,
only_check_prim=only_check_prim,
)
def check_grad_with_place(
......@@ -1989,6 +1986,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
check_dygraph=True,
check_prim=False,
only_check_prim=False,
numeric_place=None,
):
core._set_prim_all_enabled(False)
......@@ -2005,8 +2003,7 @@ class OpTest(unittest.TestCase):
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True)
self._check_grad_helper()
if prim_grad_checker.is_only_check_prim():
self.only_prim = True
if only_check_prim:
return
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
......
......@@ -1506,10 +1506,6 @@ class OpTest(unittest.TestCase):
# Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True)
self.__class__.op_type = self.op_type
if prim_checker.is_only_check_prim():
self.only_prim = True
return
# disable legacy dygraph check when check_eager is True
if check_eager:
check_dygraph = False
......@@ -2087,8 +2083,6 @@ class OpTest(unittest.TestCase):
check_eager=check_eager,
check_prim=check_prim,
)
if hasattr(self, 'only_prim') and self.only_prim:
continue
if check_eager:
assert not check_dygraph
outs, eager_dygraph_outs, fetch_list = res
......@@ -2212,6 +2206,7 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
check_eager=False,
check_prim=False,
only_check_prim=False,
):
# disable legacy dygraph check when check_eager is True
if check_eager:
......@@ -2233,6 +2228,7 @@ class OpTest(unittest.TestCase):
check_dygraph,
check_eager=check_eager,
check_prim=check_prim,
only_check_prim=only_check_prim,
)
def check_grad_with_place(
......@@ -2250,6 +2246,7 @@ class OpTest(unittest.TestCase):
numeric_place=None,
check_eager=False,
check_prim=False,
only_check_prim=False,
):
core._set_prim_all_enabled(False)
if check_prim:
......@@ -2265,8 +2262,7 @@ class OpTest(unittest.TestCase):
# Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True)
self._check_grad_helper()
if prim_grad_checker.is_only_check_prim():
self.only_prim = True
if only_check_prim:
return
# disable legacy dygraph check when check_eager is True
if check_eager:
......
......@@ -297,11 +297,6 @@ class PrimForwardChecker:
if hasattr(self.op_test, 'enable_check_jit_comp_with_cinn')
else True
)
self.only_prim = (
self.op_test.only_prim
if hasattr(self.op_test, 'only_prim')
else False
)
self.kernel_sig = self.get_kernel_sig()
def init_checker_threshold(self):
......@@ -413,9 +408,6 @@ class PrimForwardChecker:
)
return kernel_sig
def is_only_check_prim(self):
return self.only_prim
def get_eager_desire(self):
paddle.disable_static()
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
......@@ -601,7 +593,7 @@ class PrimForwardChecker:
msg = (
'Check static comp forward api out failed. Mismatch between static comp '
'and eager on %s, when enable_fw_comp is %s,the forward api out tensor\'s index is : %d \n'
'static comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
'static comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
% (
str(self.place),
self.enable_fw_comp,
......@@ -663,7 +655,7 @@ class PrimForwardChecker:
msg = (
'Check jit comp forward api out failed. Mismatch between jit comp '
'and eager on %s, when enable_fw_comp is %s,the forward api out tensor\'s index is : %d \n'
'jit comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
% (
str(self.place),
self.enable_fw_comp,
......@@ -743,7 +735,7 @@ class PrimForwardChecker:
msg = (
'Check jit comp with cinn forward api out failed. Mismatch between jit comp and eager on %s, '
'when enable_fw_comp is %s, enable_cinn is %s, the forward api out tensor\'s index is : %d \n'
'jit comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
% (
str(self.place),
self.enable_fw_comp,
......@@ -931,7 +923,7 @@ class PrimGradChecker(PrimForwardChecker):
msg = (
'Check eager comp grad out failed. Mismatch between eager comp '
'and eager on %s, when enable_rev_comp is %s,the eager comp grad out tensor\'s index is : %d \n'
'eager comp grad out tensor:%s\n eager grad out tensor:%s\n'
'eager comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
% (
str(self.place),
self.enable_rev_comp,
......@@ -1021,7 +1013,7 @@ class PrimGradChecker(PrimForwardChecker):
msg = (
'Check static comp grad out failed. Mismatch between static comp '
'and eager on %s, when enable_fw_comp is %s,enable_rev_comp is %s,the forward api out tensor\'s index is : %d \n'
'static comp grad out tensor:%s\n eager grad out tensor:%s\n'
'static comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
% (
str(self.place),
self.enable_fw_comp,
......@@ -1118,7 +1110,7 @@ class PrimGradChecker(PrimForwardChecker):
msg = (
'Check jit comp grad out failed. Mismatch between jit comp '
'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s,the grad out tensor\'s index is : %d \n'
'jit comp grad out tensor:%s\n eager grad out out tensor:%s\n'
'jit comp grad out tensor:\n%s\n eager grad out out tensor:\n%s\n'
% (
str(self.place),
self.enable_fw_comp,
......@@ -1229,7 +1221,7 @@ class PrimGradChecker(PrimForwardChecker):
msg = (
'Check jit comp with cinn grad out failed. Mismatch between jit comp with cinn '
'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, enable_cinn is %s,'
'the grad out tensor\'s index is : %d ,jit comp with cinn grad out tensor:%s\n eager grad out out tensor:%s\n'
'the grad out tensor\'s index is : %d ,jit comp with cinn grad out tensor:\n%s\n eager grad out out tensor:\n%s\n'
% (
str(self.place),
self.enable_fw_comp,
......
......@@ -94,7 +94,7 @@ class TestActivation_ZeroDim(TestActivation):
self.shape = []
class TestExpPrimFp32(OpTest):
class TestExpFp32_Prim(OpTest):
def setUp(self):
self.op_type = "exp"
self.prim_op_type = "prim"
......@@ -108,8 +108,7 @@ class TestExpPrimFp32(OpTest):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.skip_cinn()
self.set_only_prim()
self.if_skip_cinn()
def test_check_output(self):
self.check_output()
......@@ -123,40 +122,34 @@ class TestExpPrimFp32(OpTest):
def init_shape(self):
self.shape = [12, 17]
def skip_cinn(self):
def if_skip_cinn(self):
self.enable_cinn = True
def set_only_prim(self):
pass
class TestExpPrimFp64(TestExpPrimFp32):
class TestExpFp64_Prim(TestExpFp32_Prim):
def init_dtype(self):
self.dtype = np.float64
class TestExpPrimFp16(TestExpPrimFp32):
class TestExpFp16_Prim(TestExpFp32_Prim):
def init_dtype(self):
self.dtype = np.float16
def set_only_prim(self):
self.only_prim = True
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
self.check_grad(['X'], 'Out', check_prim=True, only_check_prim=True)
def skip_cinn(self):
def if_skip_cinn(self):
self.enable_cinn = True
class TestExpPrim_ZeroDim(TestExpPrimFp32):
class TestExpPrim_ZeroDim(TestExpFp32_Prim):
def init_shape(self):
self.shape = []
def skip_cinn(self):
def if_skip_cinn(self):
self.enable_cinn = False
......@@ -287,36 +280,6 @@ class TestSigmoid_ZeroDim(TestSigmoid):
self.shape = []
class TestSigmoidFP16(TestActivation):
def setUp(self):
self.op_type = "sigmoid"
self.prim_op_type = "comp"
self.enable_cinn = False
self.only_prim = True
self.python_api = paddle.nn.functional.sigmoid
self.init_dtype()
self.init_shape()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = 1 / (1 + np.exp(-x))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def init_dtype(self):
self.dtype = np.float16
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.01, check_prim=True)
def test_check_output(self):
check_eager = False
if hasattr(self, 'check_eager'):
check_eager = self.check_eager
self.check_output(check_eager=check_eager, check_prim=True)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
......@@ -328,7 +291,6 @@ class TestSigmoidBF16(OpTest):
self.python_api = paddle.nn.functional.sigmoid
self.init_dtype()
self.init_shape()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(np.float32)
out = 1 / (1 + np.exp(-x))
......@@ -346,7 +308,7 @@ class TestSigmoidBF16(OpTest):
def test_check_output(self):
place = core.CUDAPlace(0)
# elementwise_pow can not support bfloat16, skip check_prim = True.
# elementwise_pow doesn't support bfloat16, skip check_prim here.
self.check_output_with_place(place)
def test_check_grad(self):
......@@ -370,6 +332,7 @@ class TestSilu(TestActivation):
self.python_api = paddle.nn.functional.silu
self.init_dtype()
self.init_shape()
self.if_skip_cinn()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
......@@ -381,46 +344,19 @@ class TestSilu(TestActivation):
def init_dtype(self):
self.dtype = np.float32
def if_skip_cinn(self):
pass
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True)
class TestSilu_ZeroDim(TestSilu):
def init_shape(self):
self.shape = []
self.enable_cinn = False
class TestSiluFP16(TestActivation):
def setUp(self):
self.op_type = "silu"
self.prim_op_type = "comp"
self.enable_cinn = True
self.only_prim = True
self.python_api = paddle.nn.functional.silu
self.init_dtype()
self.init_shape()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = x / (np.exp(-x) + 1)
self.inputs = {'X': x}
self.outputs = {'Out': out}
def init_dtype(self):
self.dtype = np.float16
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
def test_check_output(self):
check_eager = False
if hasattr(self, 'check_eager'):
check_eager = self.check_eager
self.check_output(check_eager=check_eager, check_prim=True)
def if_skip_cinn(self):
self.enable_cinn = False
class TestSiluAPI(unittest.TestCase):
......@@ -1408,6 +1344,7 @@ class TestCeil_ZeroDim(TestCeil):
class TestFloor(TestActivation):
def setUp(self):
self.op_type = "floor"
self.prim_op_type = "prim"
self.check_eager = True
self.python_api = paddle.floor
self.init_dtype()
......@@ -1435,16 +1372,10 @@ class TestFloor_ZeroDim(TestFloor):
self.shape = []
class TestFloorPrim(TestActivation):
class TestFloor_Prim(TestActivation):
def setUp(self):
self.op_type = "floor"
self.prim_op_type = "prim"
# the gradient on floor, ceil, round is undefined.
# we return zero as gradient, but the numpy return nan.
# for prim, we compare result with eager python api,
# so, we use only_prim flag to express we only test prim.
self.only_prim = True
self.check_eager = True
self.python_api = paddle.floor
self.init_dtype()
......@@ -1465,15 +1396,19 @@ class TestFloorPrim(TestActivation):
self.shape = [10, 12]
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
# the gradient on floor, ceil, round is undefined.
# we return zero as gradient, but the numpy return nan.
# for prim, we compare result with eager python api,
# so, we use only_prim flag to express we only test prim.
self.check_grad(['X'], 'Out', check_prim=True, only_check_prim=True)
class TestFloorPrim_ZeroDim(TestFloorPrim):
class TestFloor_ZeroDim_Prim(TestFloor_Prim):
def init_shape(self):
self.shape = []
class TestFloorPrimFp16(TestFloorPrim):
class TestFloorFp16_Prim(TestFloor_Prim):
def init_dtype(self):
self.dtype = np.float16
......@@ -2284,8 +2219,17 @@ class TestHardSwish(TestActivation):
def init_shape(self):
self.shape = [10, 12]
def if_only_check_prim(self):
return False
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
self.check_grad(
['X'],
'Out',
check_eager=True,
check_prim=True,
only_check_prim=self.if_only_check_prim(),
)
def test_check_output(self):
self.check_output(check_eager=True, check_prim=True)
......@@ -2303,9 +2247,11 @@ class TestHardSwish_ZeroDim(TestHardSwish):
class TestHardSwishFP16(TestHardSwish):
def setUp(self):
super().setUp()
self.only_prim = True
self.enable_cinn = False
def if_only_check_prim(self):
return True
def init_dtype(self):
self.dtype = np.float16
......@@ -3813,7 +3759,12 @@ create_test_act_cudnn_class(TestTanh)
# ------------------ Test Fp16 ----------------------
def create_test_act_fp16_class(
parent, atol=1e-3, grad_check=True, grad_atol=0.80
parent,
atol=1e-3,
grad_check=True,
check_prim=False,
enable_cinn=True,
grad_atol=0.80,
):
@unittest.skipIf(
not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA"
......@@ -3822,18 +3773,27 @@ def create_test_act_fp16_class(
def init_dtype(self):
self.dtype = np.float16
def if_skip_cinn(self):
self.enable_cinn = enable_cinn
def test_check_output(self):
place = core.CUDAPlace(0)
support_fp16 = core.is_float16_supported(place)
if support_fp16:
self.check_output_with_place(place, atol=atol)
self.check_output_with_place(
place, atol=atol, check_prim=check_prim
)
def test_check_grad(self):
place = core.CUDAPlace(0)
support_fp16 = core.is_float16_supported(place)
if support_fp16 and grad_check:
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=grad_atol
place,
['X'],
'Out',
check_prim=check_prim,
max_relative_error=grad_atol,
)
cls_name = "{0}_{1}".format(parent.__name__, "fp16")
......@@ -3843,10 +3803,8 @@ def create_test_act_fp16_class(
create_test_act_fp16_class(TestActivation)
create_test_act_fp16_class(TestExpm1)
create_test_act_fp16_class(TestSigmoid)
create_test_act_fp16_class(TestSigmoidFP16)
create_test_act_fp16_class(TestSilu)
create_test_act_fp16_class(TestSiluFP16)
create_test_act_fp16_class(TestSigmoid, check_prim=True)
create_test_act_fp16_class(TestSilu, check_prim=True)
create_test_act_fp16_class(TestLogSigmoid)
create_test_act_fp16_class(TestTanh)
create_test_act_fp16_class(TestTanhshrink)
......@@ -3855,7 +3813,7 @@ create_test_act_fp16_class(TestSoftshrink)
create_test_act_fp16_class(TestSqrt)
create_test_act_fp16_class(TestAbs)
create_test_act_fp16_class(TestCeil, grad_check=False)
create_test_act_fp16_class(TestFloor, grad_check=False)
create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False)
create_test_act_fp16_class(TestCos, grad_atol=0.85)
create_test_act_fp16_class(TestTan, grad_atol=0.85)
create_test_act_fp16_class(TestCosh, grad_atol=0.85)
......
......@@ -37,7 +37,6 @@ class TestElementwiseAddOp(OpTest):
self.init_input_output()
self.init_kernel_type()
self.init_axis()
self.only_prim()
self.if_check_prim()
self.if_skip_cinn()
......@@ -103,9 +102,6 @@ class TestElementwiseAddOp(OpTest):
def init_axis(self):
self.axis = -1
def only_prim(self):
pass
def if_check_prim(self):
self.check_prim = self.axis == -1
......@@ -156,41 +152,6 @@ class TestFP16ElementwiseAddOp(TestElementwiseAddOp):
check_prim=self.check_prim,
)
def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X', 'Y'],
'Out',
check_dygraph=self.check_dygraph(),
check_prim=self.check_prim,
)
def test_check_grad_ingore_x(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['Y'],
'Out',
no_grad_set=set("X"),
check_dygraph=self.check_dygraph(),
check_prim=self.check_prim,
)
def test_check_grad_ingore_y(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X'],
'Out',
no_grad_set=set('Y'),
check_dygraph=self.check_dygraph(),
check_prim=self.check_prim,
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
......@@ -266,9 +227,6 @@ class TestFP16ElementwiseAddOp_scalar(TestFP16ElementwiseAddOp):
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y
def only_prim(self):
self.only_prim = True
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1,1) to test broadcast."
......@@ -519,9 +477,12 @@ class TestFP16ElementwiseAddOp_rowwise_add_0(TestFP16ElementwiseAddOp):
class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1)
self.x = np.random.rand(10, 100, 1).astype(self.dtype)
self.y = np.random.rand(100, 1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 100, 1)
def if_skip_cinn(self):
self.enable_cinn = False
class TestFP16ElementwiseAddOp_rowwise_add_1(TestFP16ElementwiseAddOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册