未验证 提交 cf7cd247 编写于 作者: F Fisher 提交者: GitHub

[CINN] Enable CINN unittest on atan2, tile, top_k, where (#54280)

* Enable check_cinn on atan2, tile, top_k and where

* Update cmakelists in legacy_test

* Reformat code

* Enable check_cinn on op take_along_axis legacy test

* Enable check_cinn on pool2d

* Remove check_cinn=False

* Try fix tile test error

* Rename enable_cinn to test_cinn

* Refactor test_tile_op

* Replace all enable_cinn to check_cinn

* Revert pool2d test timeout

* Remove check_prim and use enable_cinn
上级 1a30fe54
...@@ -1191,12 +1191,17 @@ set(TEST_CINN_OPS ...@@ -1191,12 +1191,17 @@ set(TEST_CINN_OPS
test_roll_op test_roll_op
test_sum_op test_sum_op
test_elementwise_min_op test_elementwise_min_op
test_atan2_op
test_top_k_op
test_where_op
test_take_along_axis_op
test_arg_min_max_op test_arg_min_max_op
test_reverse_op test_reverse_op
test_flip test_flip
test_triangular_solve_op test_triangular_solve_op
test_scatter_nd_op test_scatter_nd_op
test_strided_slice_op test_strided_slice_op
test_pool2d_op
test_instance_norm_op test_instance_norm_op
test_cumsum_op test_cumsum_op
test_pad_op test_pad_op
......
...@@ -34,6 +34,7 @@ class TestAtan2(OpTest): ...@@ -34,6 +34,7 @@ class TestAtan2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "atan2" self.op_type = "atan2"
self.python_api = paddle.atan2 self.python_api = paddle.atan2
self.check_cinn = True
self.init_dtype() self.init_dtype()
x1 = np.random.uniform(-1, -0.1, [15, 17]).astype(self.dtype) x1 = np.random.uniform(-1, -0.1, [15, 17]).astype(self.dtype)
...@@ -44,10 +45,10 @@ class TestAtan2(OpTest): ...@@ -44,10 +45,10 @@ class TestAtan2(OpTest):
self.outputs = {'Out': out} self.outputs = {'Out': out}
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X1', 'X2'], 'Out') self.check_grad(['X1', 'X2'], 'Out', check_cinn=self.check_cinn)
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
def init_dtype(self): def init_dtype(self):
self.dtype = np.float64 self.dtype = np.float64
...@@ -67,6 +68,7 @@ class TestAtan2_float(TestAtan2): ...@@ -67,6 +68,7 @@ class TestAtan2_float(TestAtan2):
self.inputs['X2'], self.inputs['X2'],
1 / self.inputs['X1'].size, 1 / self.inputs['X1'].size,
), ),
check_cinn=self.check_cinn,
) )
...@@ -139,6 +141,7 @@ class TestAtan2BF16OP(OpTest): ...@@ -139,6 +141,7 @@ class TestAtan2BF16OP(OpTest):
self.op_type = 'atan2' self.op_type = 'atan2'
self.python_api = paddle.atan2 self.python_api = paddle.atan2
self.dtype = np.uint16 self.dtype = np.uint16
self.check_cinn = True
x1 = np.random.uniform(-1, -0.1, [15, 17]).astype('float32') x1 = np.random.uniform(-1, -0.1, [15, 17]).astype('float32')
x2 = np.random.uniform(0.1, 1, [15, 17]).astype('float32') x2 = np.random.uniform(0.1, 1, [15, 17]).astype('float32')
out = np.arctan2(x1, x2) out = np.arctan2(x1, x2)
...@@ -151,11 +154,13 @@ class TestAtan2BF16OP(OpTest): ...@@ -151,11 +154,13 @@ class TestAtan2BF16OP(OpTest):
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place) self.check_output_with_place(place, check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X1', 'X2'], 'Out') self.check_grad_with_place(
place, ['X1', 'X2'], 'Out', check_cinn=self.check_cinn
)
class TestAtan2Error(unittest.TestCase): class TestAtan2Error(unittest.TestCase):
......
...@@ -421,7 +421,10 @@ class TestPool2D_Op_Mixin: ...@@ -421,7 +421,10 @@ class TestPool2D_Op_Mixin:
if self.has_cudnn(): if self.has_cudnn():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place( self.check_output_with_place(
place, atol=1e-5, check_dygraph=(not self.use_mkldnn) place,
atol=1e-5,
check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
else: else:
self.check_output(check_dygraph=(not self.use_mkldnn)) self.check_output(check_dygraph=(not self.use_mkldnn))
...@@ -437,6 +440,7 @@ class TestPool2D_Op_Mixin: ...@@ -437,6 +440,7 @@ class TestPool2D_Op_Mixin:
{'X'}, {'X'},
'Out', 'Out',
check_dygraph=(not self.use_mkldnn), check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
elif self.pool_type != "max": elif self.pool_type != "max":
self.check_grad( self.check_grad(
...@@ -586,6 +590,7 @@ def create_test_cudnn_fp16_class(parent, check_grad=True): ...@@ -586,6 +590,7 @@ def create_test_cudnn_fp16_class(parent, check_grad=True):
self.check_output_with_place( self.check_output_with_place(
place, place,
check_dygraph=(not self.use_mkldnn), check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
def test_check_grad(self): def test_check_grad(self):
...@@ -601,6 +606,7 @@ def create_test_cudnn_fp16_class(parent, check_grad=True): ...@@ -601,6 +606,7 @@ def create_test_cudnn_fp16_class(parent, check_grad=True):
{'X'}, {'X'},
'Out', 'Out',
check_dygraph=(not self.use_mkldnn), check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
cls_name = "{}_{}".format(parent.__name__, "CUDNNFp16Op") cls_name = "{}_{}".format(parent.__name__, "CUDNNFp16Op")
...@@ -625,6 +631,7 @@ def create_test_fp16_class(parent, check_grad=True): ...@@ -625,6 +631,7 @@ def create_test_fp16_class(parent, check_grad=True):
self.check_output_with_place( self.check_output_with_place(
place, place,
check_dygraph=(not self.use_mkldnn), check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
def test_check_grad(self): def test_check_grad(self):
...@@ -640,6 +647,7 @@ def create_test_fp16_class(parent, check_grad=True): ...@@ -640,6 +647,7 @@ def create_test_fp16_class(parent, check_grad=True):
{'X'}, {'X'},
'Out', 'Out',
check_dygraph=(not self.use_mkldnn), check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
cls_name = "{}_{}".format(parent.__name__, "Fp16Op") cls_name = "{}_{}".format(parent.__name__, "Fp16Op")
...@@ -662,6 +670,7 @@ def create_test_bf16_class(parent, check_grad=True): ...@@ -662,6 +670,7 @@ def create_test_bf16_class(parent, check_grad=True):
self.check_output_with_place( self.check_output_with_place(
place, place,
check_dygraph=(not self.use_mkldnn), check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
def test_check_grad(self): def test_check_grad(self):
...@@ -672,6 +681,7 @@ def create_test_bf16_class(parent, check_grad=True): ...@@ -672,6 +681,7 @@ def create_test_bf16_class(parent, check_grad=True):
{'X'}, {'X'},
'Out', 'Out',
check_dygraph=(not self.use_mkldnn), check_dygraph=(not self.use_mkldnn),
check_cinn=True,
) )
cls_name = "{}_{}".format(parent.__name__, "Bf16Op") cls_name = "{}_{}".format(parent.__name__, "Bf16Op")
...@@ -1001,10 +1011,12 @@ class TestCase5_Max(TestCase2): ...@@ -1001,10 +1011,12 @@ class TestCase5_Max(TestCase2):
if self.has_cudnn() and self.pool_type == "max": if self.has_cudnn() and self.pool_type == "max":
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
place, {'X'}, 'Out', max_relative_error=1.00 place, {'X'}, 'Out', max_relative_error=1.00, check_cinn=True
) )
elif self.pool_type == "max": elif self.pool_type == "max":
self.check_grad({'X'}, 'Out', max_relative_error=1.00) self.check_grad(
{'X'}, 'Out', max_relative_error=1.00, check_cinn=True
)
class TestCase5_channel_last_Max(TestCase5_Max): class TestCase5_channel_last_Max(TestCase5_Max):
......
...@@ -28,6 +28,7 @@ class TestTakeAlongAxisOp(OpTest): ...@@ -28,6 +28,7 @@ class TestTakeAlongAxisOp(OpTest):
self.init_data() self.init_data()
self.op_type = "take_along_axis" self.op_type = "take_along_axis"
self.python_api = paddle.tensor.take_along_axis self.python_api = paddle.tensor.take_along_axis
self.check_cinn = True
self.xnp = np.random.random(self.x_shape).astype(self.x_type) self.xnp = np.random.random(self.x_shape).astype(self.x_type)
self.target = np.take_along_axis(self.xnp, self.index, self.axis) self.target = np.take_along_axis(self.xnp, self.index, self.axis)
broadcast_shape_list = list(self.x_shape) broadcast_shape_list = list(self.x_shape)
...@@ -42,10 +43,10 @@ class TestTakeAlongAxisOp(OpTest): ...@@ -42,10 +43,10 @@ class TestTakeAlongAxisOp(OpTest):
self.outputs = {'Result': self.target} self.outputs = {'Result': self.target}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['Input'], 'Result') self.check_grad(['Input'], 'Result', check_cinn=self.check_cinn)
def init_data(self): def init_data(self):
self.x_type = "float64" self.x_type = "float64"
...@@ -81,6 +82,7 @@ class TestTakeAlongAxisBF16Op(OpTest): ...@@ -81,6 +82,7 @@ class TestTakeAlongAxisBF16Op(OpTest):
self.init_data() self.init_data()
self.op_type = "take_along_axis" self.op_type = "take_along_axis"
self.python_api = paddle.tensor.take_along_axis self.python_api = paddle.tensor.take_along_axis
self.check_cinn = True
self.xnp = np.random.random(self.x_shape).astype(self.x_type) self.xnp = np.random.random(self.x_shape).astype(self.x_type)
self.target = np.take_along_axis(self.xnp, self.index, self.axis) self.target = np.take_along_axis(self.xnp, self.index, self.axis)
broadcast_shape_list = list(self.x_shape) broadcast_shape_list = list(self.x_shape)
...@@ -99,10 +101,12 @@ class TestTakeAlongAxisBF16Op(OpTest): ...@@ -99,10 +101,12 @@ class TestTakeAlongAxisBF16Op(OpTest):
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place) self.check_output_with_place(self.place, check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
self.check_grad_with_place(self.place, ['Input'], 'Result') self.check_grad_with_place(
self.place, ['Input'], 'Result', check_cinn=self.check_cinn
)
def init_data(self): def init_data(self):
self.dtype = np.uint16 self.dtype = np.uint16
......
...@@ -40,14 +40,14 @@ class TestTileOpRank1(OpTest): ...@@ -40,14 +40,14 @@ class TestTileOpRank1(OpTest):
self.outputs = {'Out': output} self.outputs = {'Out': output}
def if_enable_cinn(self): def if_enable_cinn(self):
pass self.check_cinn = True
def init_data(self): def init_data(self):
self.ori_shape = [100] self.ori_shape = [100]
self.repeat_times = [2] self.repeat_times = [2]
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True) self.check_grad(['X'], 'Out', check_prim=True)
...@@ -59,6 +59,7 @@ class TestTileOpRank_ZeroDim1(TestTileOpRank1): ...@@ -59,6 +59,7 @@ class TestTileOpRank_ZeroDim1(TestTileOpRank1):
self.repeat_times = [] self.repeat_times = []
def if_enable_cinn(self): def if_enable_cinn(self):
self.check_cinn = False
self.enable_cinn = False self.enable_cinn = False
...@@ -68,6 +69,7 @@ class TestTileOpRank_ZeroDim2(TestTileOpRank1): ...@@ -68,6 +69,7 @@ class TestTileOpRank_ZeroDim2(TestTileOpRank1):
self.repeat_times = [2] self.repeat_times = [2]
def if_enable_cinn(self): def if_enable_cinn(self):
self.check_cinn = False
self.enable_cinn = False self.enable_cinn = False
...@@ -77,6 +79,7 @@ class TestTileOpRank_ZeroDim3(TestTileOpRank1): ...@@ -77,6 +79,7 @@ class TestTileOpRank_ZeroDim3(TestTileOpRank1):
self.repeat_times = [2, 3] self.repeat_times = [2, 3]
def if_enable_cinn(self): def if_enable_cinn(self):
self.check_cinn = False
self.enable_cinn = False self.enable_cinn = False
...@@ -86,38 +89,57 @@ class TestTileOpRank2Expanding(TestTileOpRank1): ...@@ -86,38 +89,57 @@ class TestTileOpRank2Expanding(TestTileOpRank1):
self.ori_shape = [120] self.ori_shape = [120]
self.repeat_times = [2, 2] self.repeat_times = [2, 2]
def if_enable_cinn(self):
self.check_cinn = True
class TestTileOpRank2(TestTileOpRank1): class TestTileOpRank2(TestTileOpRank1):
def init_data(self): def init_data(self):
self.ori_shape = [12, 14] self.ori_shape = [12, 14]
self.repeat_times = [2, 3] self.repeat_times = [2, 3]
def if_enable_cinn(self):
self.check_cinn = True
class TestTileOpRank3_Corner(TestTileOpRank1): class TestTileOpRank3_Corner(TestTileOpRank1):
def init_data(self): def init_data(self):
self.ori_shape = (2, 10, 5) self.ori_shape = (2, 10, 5)
self.repeat_times = (1, 1, 1) self.repeat_times = (1, 1, 1)
def if_enable_cinn(self):
self.check_cinn = True
class TestTileOpRank3_Corner2(TestTileOpRank1): class TestTileOpRank3_Corner2(TestTileOpRank1):
def init_data(self): def init_data(self):
self.ori_shape = (2, 10, 5) self.ori_shape = (2, 10, 5)
self.repeat_times = (2, 2) self.repeat_times = (2, 2)
def if_enable_cinn(self):
self.check_cinn = True
class TestTileOpRank3(TestTileOpRank1): class TestTileOpRank3(TestTileOpRank1):
def init_data(self): def init_data(self):
self.ori_shape = (2, 4, 15) self.ori_shape = (2, 4, 15)
self.repeat_times = (2, 1, 4) self.repeat_times = (2, 1, 4)
def if_enable_cinn(self):
self.check_cinn = True
class TestTileOpRank4(TestTileOpRank1): class TestTileOpRank4(TestTileOpRank1):
def init_data(self): def init_data(self):
self.ori_shape = (2, 4, 5, 7) self.ori_shape = (2, 4, 5, 7)
self.repeat_times = (3, 2, 1, 2) self.repeat_times = (3, 2, 1, 2)
def if_enable_cinn(self):
self.check_cinn = True
# Situation 2: repeat_times is a list (with tensor) # Situation 2: repeat_times is a list (with tensor)
# CINN not support repeat_times is a tensor now
class TestTileOpRank1_tensor_attr(OpTest): class TestTileOpRank1_tensor_attr(OpTest):
def setUp(self): def setUp(self):
self.op_type = "tile" self.op_type = "tile"
...@@ -164,6 +186,7 @@ class TestTileOpRank2_attr_tensor(TestTileOpRank1_tensor_attr): ...@@ -164,6 +186,7 @@ class TestTileOpRank2_attr_tensor(TestTileOpRank1_tensor_attr):
# Situation 3: repeat_times is a tensor # Situation 3: repeat_times is a tensor
# CINN not support repeat_times is a tensor now
class TestTileOpRank1_tensor(OpTest): class TestTileOpRank1_tensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "tile" self.op_type = "tile"
...@@ -206,9 +229,13 @@ class TestTileOpInteger(OpTest): ...@@ -206,9 +229,13 @@ class TestTileOpInteger(OpTest):
self.attrs = {'repeat_times': [2, 1, 4]} self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4)) output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output} self.outputs = {'Out': output}
self.if_enable_cinn()
def if_enable_cinn(self):
self.check_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
class TestTileFP16OP(OpTest): class TestTileFP16OP(OpTest):
...@@ -217,7 +244,6 @@ class TestTileFP16OP(OpTest): ...@@ -217,7 +244,6 @@ class TestTileFP16OP(OpTest):
self.dtype = np.float16 self.dtype = np.float16
self.python_api = paddle.tile self.python_api = paddle.tile
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.enable_cinn = True
self.public_python_api = paddle.tile self.public_python_api = paddle.tile
self.init_data() self.init_data()
x = np.random.uniform(10, size=self.ori_shape).astype(self.dtype) x = np.random.uniform(10, size=self.ori_shape).astype(self.dtype)
...@@ -225,6 +251,10 @@ class TestTileFP16OP(OpTest): ...@@ -225,6 +251,10 @@ class TestTileFP16OP(OpTest):
self.inputs = {'X': x} self.inputs = {'X': x}
self.attrs = {'repeat_times': self.repeat_times} self.attrs = {'repeat_times': self.repeat_times}
self.outputs = {'Out': output} self.outputs = {'Out': output}
self.if_enable_cinn()
def if_enable_cinn(self):
self.check_cinn = True
def init_data(self): def init_data(self):
self.dtype = np.float16 self.dtype = np.float16
...@@ -232,7 +262,7 @@ class TestTileFP16OP(OpTest): ...@@ -232,7 +262,7 @@ class TestTileFP16OP(OpTest):
self.repeat_times = [2, 1, 4] self.repeat_times = [2, 1, 4]
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True) self.check_grad(['X'], 'Out', check_prim=True)
...@@ -256,10 +286,14 @@ class TestTileBF16OP(OpTest): ...@@ -256,10 +286,14 @@ class TestTileBF16OP(OpTest):
self.inputs = {'X': convert_float_to_uint16(x)} self.inputs = {'X': convert_float_to_uint16(x)}
self.attrs = {'repeat_times': self.repeat_times} self.attrs = {'repeat_times': self.repeat_times}
self.outputs = {'Out': convert_float_to_uint16(output)} self.outputs = {'Out': convert_float_to_uint16(output)}
self.if_enable_cinn()
def if_enable_cinn(self):
self.check_cinn = True
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place) self.check_output_with_place(place, check_cinn=self.check_cinn)
def init_data(self): def init_data(self):
self.dtype = np.uint16 self.dtype = np.uint16
...@@ -280,9 +314,13 @@ class TestTileOpBoolean(OpTest): ...@@ -280,9 +314,13 @@ class TestTileOpBoolean(OpTest):
self.attrs = {'repeat_times': [2, 1, 4]} self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4)) output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output} self.outputs = {'Out': output}
self.if_enable_cinn()
def if_enable_cinn(self):
self.check_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
# Situation 56: input x is Integer # Situation 56: input x is Integer
...@@ -296,9 +334,13 @@ class TestTileOpInt64_t(OpTest): ...@@ -296,9 +334,13 @@ class TestTileOpInt64_t(OpTest):
self.attrs = {'repeat_times': [2, 1, 4]} self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4)) output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output} self.outputs = {'Out': output}
self.if_enable_cinn()
def if_enable_cinn(self):
self.check_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
class TestTileError(unittest.TestCase): class TestTileError(unittest.TestCase):
......
...@@ -26,6 +26,7 @@ class TestTopkOp(OpTest): ...@@ -26,6 +26,7 @@ class TestTopkOp(OpTest):
self.set_args() self.set_args()
self.op_type = "top_k" self.op_type = "top_k"
self.dtype = np.float64 self.dtype = np.float64
self.check_cinn = True
self.init_dtype() self.init_dtype()
k = self.top_k k = self.top_k
...@@ -54,10 +55,10 @@ class TestTopkOp(OpTest): ...@@ -54,10 +55,10 @@ class TestTopkOp(OpTest):
self.top_k = 1 self.top_k = 1
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
self.check_grad({'X'}, 'Out') self.check_grad({'X'}, 'Out', check_cinn=self.check_cinn)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -27,15 +27,16 @@ class TestWhereOp(OpTest): ...@@ -27,15 +27,16 @@ class TestWhereOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'where' self.op_type = 'where'
self.python_api = paddle.where self.python_api = paddle.where
self.check_cinn = True
self.init_config() self.init_config()
self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y} self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
self.outputs = {'Out': np.where(self.cond, self.x, self.y)} self.outputs = {'Out': np.where(self.cond, self.x, self.y)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out') self.check_grad(['X', 'Y'], 'Out', check_cinn=self.check_cinn)
def init_config(self): def init_config(self):
self.x = np.random.uniform((-3), 5, 100).astype('float64') self.x = np.random.uniform((-3), 5, 100).astype('float64')
...@@ -68,6 +69,7 @@ class TestWhereBF16OP(OpTest): ...@@ -68,6 +69,7 @@ class TestWhereBF16OP(OpTest):
self.op_type = 'where' self.op_type = 'where'
self.dtype = np.uint16 self.dtype = np.uint16
self.python_api = paddle.where self.python_api = paddle.where
self.check_cinn = True
self.init_config() self.init_config()
self.inputs = { self.inputs = {
'Condition': self.cond, 'Condition': self.cond,
...@@ -80,12 +82,16 @@ class TestWhereBF16OP(OpTest): ...@@ -80,12 +82,16 @@ class TestWhereBF16OP(OpTest):
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place) self.check_output_with_place(place, check_cinn=self.check_cinn)
def test_check_grad(self): def test_check_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
place, ['X', 'Y'], 'Out', numeric_grad_delta=0.05 place,
['X', 'Y'],
'Out',
numeric_grad_delta=0.05,
check_cinn=self.check_cinn,
) )
def init_config(self): def init_config(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册