Unverified commit 017452e9, authored by Charles-hit, committed by GitHub

[OpTest] add only_check_prim parameter in check_grad (#51210)

* support elementwise_pow bfloat16

* add only_check_prim parameter in check_grad

* modify unit test

* fix floor test

* fix sigmoid bfloat16 test
Parent commit 10b95e8d
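The sketch below is a hypothetical usage example, not part of the commit: an fp16 exp test that keeps the composite (prim) gradient comparison while skipping the classic numeric gradient check. The class name, input range/shape, and import path are placeholders; only the check_prim/only_check_prim arguments to check_grad come from this change.

# Hypothetical usage sketch (not from the commit): an fp16 "exp" test that runs
# only the prim (composite) gradient comparison.
import numpy as np
import paddle
from eager_op_test import OpTest  # import path inside Paddle's test tree; adjust locally


class TestExpFp16PrimSketch(OpTest):
    def setUp(self):
        self.op_type = "exp"
        self.prim_op_type = "prim"
        self.python_api = paddle.exp
        self.dtype = np.float16
        x = np.random.uniform(0.1, 1, [12, 17]).astype(self.dtype)
        self.inputs = {'X': x}
        self.outputs = {'Out': np.exp(x)}

    def test_check_grad(self):
        # check_prim=True runs the prim/composite gradient comparison;
        # only_check_prim=True (added by this commit) returns right after it,
        # so the fp64 numeric-gradient path is never exercised for fp16.
        self.check_grad(['X'], 'Out', check_prim=True, only_check_prim=True)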
@@ -648,6 +648,5 @@ struct ElementwiseInversePowFunctor<dtype::float16> {
     return static_cast<dtype::float16>(std::pow(f_b, f_a));
   }
 };
 }  // namespace funcs
 }  // namespace phi
@@ -1672,9 +1672,6 @@ class OpTest(unittest.TestCase):
             # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
             setattr(self.__class__, 'check_prim', True)
             self.__class__.op_type = self.op_type
-            if prim_checker.is_only_check_prim():
-                self.only_prim = True
-                return
         # set some flags by the combination of arguments.
         self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         if (
@@ -1842,8 +1839,6 @@ class OpTest(unittest.TestCase):
                 check_prim=check_prim,
                 inplace_atol=inplace_atol,
             )
-            if hasattr(self, 'only_prim') and self.only_prim:
-                continue
             if check_dygraph:
                 outs, dygraph_dygraph_outs, fetch_list = res
             else:
@@ -1958,6 +1953,7 @@ class OpTest(unittest.TestCase):
         user_defined_grad_outputs=None,
         check_dygraph=True,
         check_prim=False,
+        only_check_prim=False,
     ):
         self._check_grad_helper()
         places = self._get_places()
@@ -1974,6 +1970,7 @@ class OpTest(unittest.TestCase):
                 user_defined_grad_outputs,
                 check_dygraph=check_dygraph,
                 check_prim=check_prim,
+                only_check_prim=only_check_prim,
             )

     def check_grad_with_place(
@@ -1989,6 +1986,7 @@ class OpTest(unittest.TestCase):
         user_defined_grad_outputs=None,
         check_dygraph=True,
         check_prim=False,
+        only_check_prim=False,
         numeric_place=None,
     ):
         core._set_prim_all_enabled(False)
@@ -2005,8 +2003,7 @@ class OpTest(unittest.TestCase):
             # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
             setattr(self.__class__, 'check_prim', True)
             self._check_grad_helper()
-            if prim_grad_checker.is_only_check_prim():
-                self.only_prim = True
+            if only_check_prim:
                 return
         self.scope = core.Scope()
         op_inputs = self.inputs if hasattr(self, "inputs") else dict()
......
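A framework-free restatement of the control flow introduced above may help: when check_prim is set, the prim gradient comparison runs first, and the new only_check_prim flag returns before the regular numeric/analytic gradient check ever starts. The helper names in this sketch are placeholders, not Paddle APIs.

# Simplified sketch of check_grad_with_place's new early-return path; the two
# helpers are stand-ins for the real PrimGradChecker run and the classic
# numeric-vs-analytic gradient comparison.
def run_prim_grad_checker():
    print("prim (composite) gradient comparison against eager results")


def run_numeric_grad_check():
    print("fp64 numeric vs. analytic gradient comparison")


def check_grad_with_place(check_prim=False, only_check_prim=False):
    if check_prim:
        run_prim_grad_checker()
        if only_check_prim:
            return  # skip the numeric check entirely, as in the hunk above
    run_numeric_grad_check()


check_grad_with_place(check_prim=True, only_check_prim=True)  # prim check only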
@@ -1506,10 +1506,6 @@ class OpTest(unittest.TestCase):
             # Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
             setattr(self.__class__, 'check_prim', True)
             self.__class__.op_type = self.op_type
-            if prim_checker.is_only_check_prim():
-                self.only_prim = True
-                return
         # disable legacy dygraph check when check_eager is True
         if check_eager:
             check_dygraph = False
@@ -2087,8 +2083,6 @@ class OpTest(unittest.TestCase):
                 check_eager=check_eager,
                 check_prim=check_prim,
             )
-            if hasattr(self, 'only_prim') and self.only_prim:
-                continue
             if check_eager:
                 assert not check_dygraph
                 outs, eager_dygraph_outs, fetch_list = res
@@ -2212,6 +2206,7 @@ class OpTest(unittest.TestCase):
         check_dygraph=True,
         check_eager=False,
         check_prim=False,
+        only_check_prim=False,
     ):
         # disable legacy dygraph check when check_eager is True
        if check_eager:
@@ -2233,6 +2228,7 @@ class OpTest(unittest.TestCase):
                 check_dygraph,
                 check_eager=check_eager,
                 check_prim=check_prim,
+                only_check_prim=only_check_prim,
             )

     def check_grad_with_place(
@@ -2250,6 +2246,7 @@ class OpTest(unittest.TestCase):
         numeric_place=None,
         check_eager=False,
         check_prim=False,
+        only_check_prim=False,
     ):
         core._set_prim_all_enabled(False)
         if check_prim:
@@ -2265,8 +2262,7 @@ class OpTest(unittest.TestCase):
             # Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
             setattr(self.__class__, 'check_prim', True)
             self._check_grad_helper()
-            if prim_grad_checker.is_only_check_prim():
-                self.only_prim = True
+            if only_check_prim:
                 return
         # disable legacy dygraph check when check_eager is True
         if check_eager:
......
@@ -297,11 +297,6 @@ class PrimForwardChecker:
             if hasattr(self.op_test, 'enable_check_jit_comp_with_cinn')
             else True
         )
-        self.only_prim = (
-            self.op_test.only_prim
-            if hasattr(self.op_test, 'only_prim')
-            else False
-        )
         self.kernel_sig = self.get_kernel_sig()

     def init_checker_threshold(self):
@@ -413,9 +408,6 @@ class PrimForwardChecker:
         )
         return kernel_sig

-    def is_only_check_prim(self):
-        return self.only_prim
-
     def get_eager_desire(self):
         paddle.disable_static()
         if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
@@ -601,7 +593,7 @@ class PrimForwardChecker:
                 msg = (
                     'Check static comp forward api out failed. Mismatch between static comp '
                     'and eager on %s, when enable_fw_comp is %s,the forward api out tensor\'s index is : %d \n'
-                    'static comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
+                    'static comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
                     % (
                         str(self.place),
                         self.enable_fw_comp,
@@ -663,7 +655,7 @@ class PrimForwardChecker:
                 msg = (
                     'Check jit comp forward api out failed. Mismatch between jit comp '
                     'and eager on %s, when enable_fw_comp is %s,the forward api out tensor\'s index is : %d \n'
-                    'jit comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
+                    'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
                     % (
                         str(self.place),
                         self.enable_fw_comp,
@@ -743,7 +735,7 @@ class PrimForwardChecker:
                 msg = (
                     'Check jit comp with cinn forward api out failed. Mismatch between jit comp and eager on %s, '
                     'when enable_fw_comp is %s, enable_cinn is %s, the forward api out tensor\'s index is : %d \n'
-                    'jit comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
+                    'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
                     % (
                         str(self.place),
                         self.enable_fw_comp,
@@ -931,7 +923,7 @@ class PrimGradChecker(PrimForwardChecker):
                 msg = (
                     'Check eager comp grad out failed. Mismatch between eager comp '
                     'and eager on %s, when enable_rev_comp is %s,the eager comp grad out tensor\'s index is : %d \n'
-                    'eager comp grad out tensor:%s\n eager grad out tensor:%s\n'
+                    'eager comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
                     % (
                         str(self.place),
                         self.enable_rev_comp,
@@ -1021,7 +1013,7 @@ class PrimGradChecker(PrimForwardChecker):
                 msg = (
                     'Check static comp grad out failed. Mismatch between static comp '
                     'and eager on %s, when enable_fw_comp is %s,enable_rev_comp is %s,the forward api out tensor\'s index is : %d \n'
-                    'static comp grad out tensor:%s\n eager grad out tensor:%s\n'
+                    'static comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
                     % (
                         str(self.place),
                         self.enable_fw_comp,
@@ -1118,7 +1110,7 @@ class PrimGradChecker(PrimForwardChecker):
                 msg = (
                     'Check jit comp grad out failed. Mismatch between jit comp '
                     'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s,the grad out tensor\'s index is : %d \n'
-                    'jit comp grad out tensor:%s\n eager grad out out tensor:%s\n'
+                    'jit comp grad out tensor:\n%s\n eager grad out out tensor:\n%s\n'
                     % (
                         str(self.place),
                         self.enable_fw_comp,
@@ -1229,7 +1221,7 @@ class PrimGradChecker(PrimForwardChecker):
                 msg = (
                     'Check jit comp with cinn grad out failed. Mismatch between jit comp with cinn '
                     'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, enable_cinn is %s,'
-                    'the grad out tensor\'s index is : %d ,jit comp with cinn grad out tensor:%s\n eager grad out out tensor:%s\n'
+                    'the grad out tensor\'s index is : %d ,jit comp with cinn grad out tensor:\n%s\n eager grad out out tensor:\n%s\n'
                     % (
                         str(self.place),
                         self.enable_fw_comp,
......
@@ -94,7 +94,7 @@ class TestActivation_ZeroDim(TestActivation):
         self.shape = []


-class TestExpPrimFp32(OpTest):
+class TestExpFp32_Prim(OpTest):
     def setUp(self):
         self.op_type = "exp"
         self.prim_op_type = "prim"
@@ -108,8 +108,7 @@ class TestExpPrimFp32(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
-        self.skip_cinn()
-        self.set_only_prim()
+        self.if_skip_cinn()

     def test_check_output(self):
         self.check_output()
@@ -123,40 +122,34 @@ class TestExpPrimFp32(OpTest):
     def init_shape(self):
         self.shape = [12, 17]

-    def skip_cinn(self):
+    def if_skip_cinn(self):
         self.enable_cinn = True

-    def set_only_prim(self):
-        pass
-

-class TestExpPrimFp64(TestExpPrimFp32):
+class TestExpFp64_Prim(TestExpFp32_Prim):
     def init_dtype(self):
         self.dtype = np.float64


-class TestExpPrimFp16(TestExpPrimFp32):
+class TestExpFp16_Prim(TestExpFp32_Prim):
     def init_dtype(self):
         self.dtype = np.float16

-    def set_only_prim(self):
-        self.only_prim = True
-
     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, only_check_prim=True)

-    def skip_cinn(self):
+    def if_skip_cinn(self):
         self.enable_cinn = True


-class TestExpPrim_ZeroDim(TestExpPrimFp32):
+class TestExpPrim_ZeroDim(TestExpFp32_Prim):
     def init_shape(self):
         self.shape = []

-    def skip_cinn(self):
+    def if_skip_cinn(self):
         self.enable_cinn = False
@@ -287,36 +280,6 @@ class TestSigmoid_ZeroDim(TestSigmoid):
         self.shape = []


-class TestSigmoidFP16(TestActivation):
-    def setUp(self):
-        self.op_type = "sigmoid"
-        self.prim_op_type = "comp"
-        self.enable_cinn = False
-        self.only_prim = True
-        self.python_api = paddle.nn.functional.sigmoid
-        self.init_dtype()
-        self.init_shape()
-
-        np.random.seed(1024)
-        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
-        out = 1 / (1 + np.exp(-x))
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
-        self.outputs = {'Out': out}
-
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out', max_relative_error=0.01, check_prim=True)
-
-    def test_check_output(self):
-        check_eager = False
-        if hasattr(self, 'check_eager'):
-            check_eager = self.check_eager
-        self.check_output(check_eager=check_eager, check_prim=True)
-
-
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
@@ -328,7 +291,6 @@ class TestSigmoidBF16(OpTest):
         self.python_api = paddle.nn.functional.sigmoid
         self.init_dtype()
         self.init_shape()
-
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(np.float32)
         out = 1 / (1 + np.exp(-x))
@@ -346,7 +308,7 @@ class TestSigmoidBF16(OpTest):
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        # elementwise_pow can not support bfloat16, skip check_prim = True.
+        # elementwise_pow doesn't support bfloat16, skip check_prim here.
         self.check_output_with_place(place)

     def test_check_grad(self):
@@ -370,6 +332,7 @@ class TestSilu(TestActivation):
         self.python_api = paddle.nn.functional.silu
         self.init_dtype()
         self.init_shape()
+        self.if_skip_cinn()

         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
@@ -381,46 +344,19 @@ class TestSilu(TestActivation):
     def init_dtype(self):
         self.dtype = np.float32

+    def if_skip_cinn(self):
+        pass
+
     def test_check_grad(self):
-        if self.dtype == np.float16:
-            return
         self.check_grad(['X'], 'Out', check_prim=True)


 class TestSilu_ZeroDim(TestSilu):
     def init_shape(self):
         self.shape = []
-        self.enable_cinn = False
-
-
-class TestSiluFP16(TestActivation):
-    def setUp(self):
-        self.op_type = "silu"
-        self.prim_op_type = "comp"
-        self.enable_cinn = True
-        self.only_prim = True
-        self.python_api = paddle.nn.functional.silu
-        self.init_dtype()
-        self.init_shape()
-
-        np.random.seed(1024)
-        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
-        out = x / (np.exp(-x) + 1)
-        self.inputs = {'X': x}
-        self.outputs = {'Out': out}
-
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True)
-
-    def test_check_output(self):
-        check_eager = False
-        if hasattr(self, 'check_eager'):
-            check_eager = self.check_eager
-        self.check_output(check_eager=check_eager, check_prim=True)
+
+    def if_skip_cinn(self):
+        self.enable_cinn = False


 class TestSiluAPI(unittest.TestCase):
@@ -1408,6 +1344,7 @@ class TestCeil_ZeroDim(TestCeil):
 class TestFloor(TestActivation):
     def setUp(self):
         self.op_type = "floor"
+        self.prim_op_type = "prim"
         self.check_eager = True
         self.python_api = paddle.floor
         self.init_dtype()
@@ -1435,16 +1372,10 @@ class TestFloor_ZeroDim(TestFloor):
         self.shape = []


-class TestFloorPrim(TestActivation):
+class TestFloor_Prim(TestActivation):
     def setUp(self):
         self.op_type = "floor"
         self.prim_op_type = "prim"
-        # the gradient on floor, ceil, round is undefined.
-        # we return zero as gradient, but the numpy return nan.
-        # for prim, we compare result with eager python api,
-        # so, we use only_prim flag to express we only test prim.
-        self.only_prim = True
         self.check_eager = True
         self.python_api = paddle.floor
         self.init_dtype()
@@ -1465,15 +1396,19 @@ class TestFloorPrim(TestActivation):
         self.shape = [10, 12]

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True)
+        # the gradient on floor, ceil, round is undefined.
+        # we return zero as gradient, but the numpy return nan.
+        # for prim, we compare result with eager python api,
+        # so, we use only_prim flag to express we only test prim.
+        self.check_grad(['X'], 'Out', check_prim=True, only_check_prim=True)


-class TestFloorPrim_ZeroDim(TestFloorPrim):
+class TestFloor_ZeroDim_Prim(TestFloor_Prim):
     def init_shape(self):
         self.shape = []


-class TestFloorPrimFp16(TestFloorPrim):
+class TestFloorFp16_Prim(TestFloor_Prim):
     def init_dtype(self):
         self.dtype = np.float16
@@ -2284,8 +2219,17 @@ class TestHardSwish(TestActivation):
     def init_shape(self):
         self.shape = [10, 12]

+    def if_only_check_prim(self):
+        return False
+
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
+        self.check_grad(
+            ['X'],
+            'Out',
+            check_eager=True,
+            check_prim=True,
+            only_check_prim=self.if_only_check_prim(),
+        )

     def test_check_output(self):
         self.check_output(check_eager=True, check_prim=True)
@@ -2303,9 +2247,11 @@ class TestHardSwish_ZeroDim(TestHardSwish):
 class TestHardSwishFP16(TestHardSwish):
     def setUp(self):
         super().setUp()
-        self.only_prim = True
         self.enable_cinn = False

+    def if_only_check_prim(self):
+        return True
+
     def init_dtype(self):
         self.dtype = np.float16
@@ -3813,7 +3759,12 @@ create_test_act_cudnn_class(TestTanh)
 # ------------------ Test Fp16 ----------------------
 def create_test_act_fp16_class(
-    parent, atol=1e-3, grad_check=True, grad_atol=0.80
+    parent,
+    atol=1e-3,
+    grad_check=True,
+    check_prim=False,
+    enable_cinn=True,
+    grad_atol=0.80,
 ):
     @unittest.skipIf(
         not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA"
@@ -3822,18 +3773,27 @@ def create_test_act_fp16_class(
         def init_dtype(self):
             self.dtype = np.float16

+        def if_skip_cinn(self):
+            self.enable_cinn = enable_cinn
+
         def test_check_output(self):
             place = core.CUDAPlace(0)
             support_fp16 = core.is_float16_supported(place)
             if support_fp16:
-                self.check_output_with_place(place, atol=atol)
+                self.check_output_with_place(
+                    place, atol=atol, check_prim=check_prim
+                )

         def test_check_grad(self):
             place = core.CUDAPlace(0)
             support_fp16 = core.is_float16_supported(place)
             if support_fp16 and grad_check:
                 self.check_grad_with_place(
-                    place, ['X'], 'Out', max_relative_error=grad_atol
+                    place,
+                    ['X'],
+                    'Out',
+                    check_prim=check_prim,
+                    max_relative_error=grad_atol,
                 )

     cls_name = "{0}_{1}".format(parent.__name__, "fp16")
@@ -3843,10 +3803,8 @@ def create_test_act_fp16_class(
 create_test_act_fp16_class(TestActivation)
 create_test_act_fp16_class(TestExpm1)
-create_test_act_fp16_class(TestSigmoid)
-create_test_act_fp16_class(TestSigmoidFP16)
-create_test_act_fp16_class(TestSilu)
-create_test_act_fp16_class(TestSiluFP16)
+create_test_act_fp16_class(TestSigmoid, check_prim=True)
+create_test_act_fp16_class(TestSilu, check_prim=True)
 create_test_act_fp16_class(TestLogSigmoid)
 create_test_act_fp16_class(TestTanh)
 create_test_act_fp16_class(TestTanhshrink)
@@ -3855,7 +3813,7 @@ create_test_act_fp16_class(TestSoftshrink)
 create_test_act_fp16_class(TestSqrt)
 create_test_act_fp16_class(TestAbs)
 create_test_act_fp16_class(TestCeil, grad_check=False)
-create_test_act_fp16_class(TestFloor, grad_check=False)
+create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False)
 create_test_act_fp16_class(TestCos, grad_atol=0.85)
 create_test_act_fp16_class(TestTan, grad_atol=0.85)
 create_test_act_fp16_class(TestCosh, grad_atol=0.85)
......
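The activation tests above replace ad-hoc attributes such as only_prim and enable_cinn set in setUp with overridable hooks (if_skip_cinn, if_only_check_prim). Stripped of the OpTest machinery, the pattern looks like the sketch below; the class names and return values are illustrative only.

# Minimal illustration of the hook-override pattern used by TestHardSwish /
# TestHardSwishFP16 above: the base class supplies the default answer and a
# subclass flips it, instead of each subclass mutating ad-hoc attributes.
class BaseCase:
    def if_only_check_prim(self):
        return False

    def run_grad_check(self):
        # in the real tests this value is forwarded to check_grad(..., only_check_prim=...)
        return "prim-only" if self.if_only_check_prim() else "full"


class Fp16Case(BaseCase):
    def if_only_check_prim(self):
        return True


assert BaseCase().run_grad_check() == "full"
assert Fp16Case().run_grad_check() == "prim-only"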
@@ -37,7 +37,6 @@ class TestElementwiseAddOp(OpTest):
         self.init_input_output()
         self.init_kernel_type()
         self.init_axis()
-        self.only_prim()
         self.if_check_prim()
         self.if_skip_cinn()
@@ -103,9 +102,6 @@ class TestElementwiseAddOp(OpTest):
     def init_axis(self):
         self.axis = -1

-    def only_prim(self):
-        pass
-
     def if_check_prim(self):
         self.check_prim = self.axis == -1
@@ -156,41 +152,6 @@ class TestFP16ElementwiseAddOp(TestElementwiseAddOp):
             check_prim=self.check_prim,
         )

-    def test_check_grad_normal(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place,
-                ['X', 'Y'],
-                'Out',
-                check_dygraph=self.check_dygraph(),
-                check_prim=self.check_prim,
-            )
-
-    def test_check_grad_ingore_x(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place,
-                ['Y'],
-                'Out',
-                no_grad_set=set("X"),
-                check_dygraph=self.check_dygraph(),
-                check_prim=self.check_prim,
-            )
-
-    def test_check_grad_ingore_y(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place,
-                ['X'],
-                'Out',
-                no_grad_set=set('Y'),
-                check_dygraph=self.check_dygraph(),
-                check_prim=self.check_prim,
-            )
-

 @unittest.skipIf(
     not core.is_compiled_with_cuda()
@@ -266,9 +227,6 @@ class TestFP16ElementwiseAddOp_scalar(TestFP16ElementwiseAddOp):
         self.y = np.random.rand(1).astype(self.dtype)
         self.out = self.x + self.y

-    def only_prim(self):
-        self.only_prim = True
-

 @skip_check_grad_ci(
     reason="[skip shape check] Use y_shape(1,1) to test broadcast."
@@ -519,9 +477,12 @@ class TestFP16ElementwiseAddOp_rowwise_add_0(TestFP16ElementwiseAddOp):
 class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp):
     def init_input_output(self):
-        self.x = np.random.rand(100, 1).astype(self.dtype)
-        self.y = np.random.rand(1).astype(self.dtype)
-        self.out = self.x + self.y.reshape(1, 1)
+        self.x = np.random.rand(10, 100, 1).astype(self.dtype)
+        self.y = np.random.rand(100, 1).astype(self.dtype)
+        self.out = self.x + self.y.reshape(1, 100, 1)
+
+    def if_skip_cinn(self):
+        self.enable_cinn = False


 class TestFP16ElementwiseAddOp_rowwise_add_1(TestFP16ElementwiseAddOp):
......
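The reshaped inputs in TestElementwiseAddOp_rowwise_add_1 can be sanity-checked with plain numpy; the shapes below are taken from the hunk above, and the check itself is only an illustration.

# Broadcast sanity check for the new rowwise_add_1 shapes: y of shape (100, 1)
# broadcasts against x of shape (10, 100, 1) exactly like y.reshape(1, 100, 1).
import numpy as np

x = np.random.rand(10, 100, 1).astype(np.float32)
y = np.random.rand(100, 1).astype(np.float32)
out = x + y.reshape(1, 100, 1)

assert out.shape == (10, 100, 1)
assert np.array_equal(out, x + y)  # plain broadcasting gives the same result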