Unverified commit 585f1136, authored by Charles-hit and committed by GitHub

support some prim ops for bf16 dtype (#54285)

Parent: f5342918
@@ -2735,6 +2735,7 @@ def gather(x, index, axis=None, name=None):
                 'int32',
                 'int64',
                 'uint8',
+                'uint16',
             ],
             'gather',
         )
...
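A note on the 'uint16' entry above: numpy has no native bfloat16 dtype, so Paddle's Python side (including OpTest helpers such as convert_float_to_uint16) carries bf16 data as uint16 arrays holding the raw bit pattern, i.e. the upper 16 bits of a float32. A minimal round-trip sketch of that convention (truncating conversion; the real helper may round rather than truncate):

```python
import numpy as np

def float_to_uint16(x: np.ndarray) -> np.ndarray:
    # keep the high 16 bits of the float32 representation: the bf16 bit pattern
    return np.right_shift(x.astype(np.float32).view(np.uint32), 16).astype(np.uint16)

def uint16_to_float(x: np.ndarray) -> np.ndarray:
    # re-expand the bf16 bit pattern into a float32
    return np.left_shift(x.astype(np.uint32), 16).view(np.float32)

x = np.random.random(4).astype(np.float32)
assert np.allclose(uint16_to_float(float_to_uint16(x)), x, atol=1e-2)
```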
@@ -71,7 +71,7 @@ class TestElementwiseOp(OpTest):
         self.check_prim = True

     def if_enable_cinn(self):
-        self.enable_cinn = False
+        pass


 class TestElementwiseFP16OP(TestElementwiseOp):
@@ -87,6 +87,7 @@ class TestElementwiseFP16OP(TestElementwiseOp):
 class TestElementwiseBF16OP(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_sub"
+        self.prim_op_type = "prim"
         self.dtype = np.uint16
         self.python_api = paddle.subtract
         self.public_python_api = paddle.subtract
@@ -103,6 +104,9 @@ class TestElementwiseBF16OP(TestElementwiseOp):
         self.if_check_prim()
         self.if_enable_cinn()

+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
     def test_check_grad_normal(self):
         place = core.CUDAPlace(0)
         self.check_grad_with_place(
@@ -118,7 +122,12 @@ class TestElementwiseBF16OP(TestElementwiseOp):
     def test_check_grad_ingore_y(self):
         place = core.CUDAPlace(0)
         self.check_grad_with_place(
-            place, ['X'], 'Out', no_grad_set=set('Y'), max_relative_error=0.1
+            place,
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            max_relative_error=0.1,
+            check_prim=True,
         )
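The loose max_relative_error=0.1 above reflects bf16 precision rather than a weak test; a back-of-the-envelope check (illustrative only, not part of the patch):

```python
# bfloat16 stores 7 explicit mantissa bits, so the gap between 1.0 and the
# next representable value is 2**-7 ~= 0.0078. Gradients that pass through
# bf16 tensors can therefore deviate by a few percent, which is why the
# bf16 gradient checks tolerate up to 10% relative error.
bf16_eps = 2.0**-7
print(f"bf16 machine epsilon: {bf16_eps:.4f}")  # 0.0078
```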
@@ -135,6 +144,10 @@ class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False


 class TestElementwiseSubFP16OP_ZeroDim1(TestElementwiseSubOp_ZeroDim1):
@@ -181,6 +194,10 @@ class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False


 class TestElementwiseSubFP16OP_ZeroDim2(TestElementwiseSubOp_ZeroDim2):
@@ -227,6 +244,10 @@ class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False


 class TestElementwiseSubFP16OP_ZeroDim3(TestElementwiseSubOp_ZeroDim3):
@@ -580,6 +601,7 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()


 @unittest.skipIf(
@@ -653,6 +675,7 @@ class TestElementwiseBF16OP_commonuse_1(TestElementwiseBF16OP):
         }
         self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestElementwiseSubOp_commonuse_2(TestElementwiseOp):
@@ -698,6 +721,7 @@ class TestElementwiseBF16OP_commonuse_2(TestElementwiseBF16OP):
         }
         self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
@@ -717,6 +741,7 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
             'Out': self.inputs['X'].reshape(1, 1, 10, 12) - self.inputs['Y']
         }
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestElementwiseSubFP16OP_xsize_lessthan_ysize(
@@ -750,6 +775,7 @@ class TestElementwiseBF16OP_xsize_lessthan_ysize(TestElementwiseBF16OP):
         }
         self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestComplexElementwiseSubOp(OpTest):
...
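A pattern worth noting in the BF16 test classes above: inputs and expected outputs are computed in float32 first and truncated to bf16 afterwards, so the expectation is exactly what an ideal bf16 kernel could produce. An illustrative sketch (to_bf16_bits repeats the truncation sketched earlier and is not a Paddle API):

```python
import numpy as np

def to_bf16_bits(a: np.ndarray) -> np.ndarray:
    # same truncation as sketched above: float32 -> bf16 bit pattern in uint16
    return np.right_shift(a.astype(np.float32).view(np.uint32), 16).astype(np.uint16)

x = np.random.random(4).astype(np.float32)
y = np.random.random(4).astype(np.float32)
inputs = {'X': to_bf16_bits(x), 'Y': to_bf16_bits(y)}
expected = {'Out': to_bf16_bits(x - y)}  # reference in fp32, truncated last
```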
@@ -37,12 +37,8 @@ class TestGatherOp(OpTest):
         self.public_python_api = paddle.gather
         self.config()
         self.prim_op_type = "prim"
-        xnp = np.random.random(self.x_shape).astype(self.x_type)
-        self.inputs = {
-            'X': xnp,
-            'Index': np.array(self.index).astype(self.index_type),
-        }
-        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
+        self.init_inputs_and_outputs()
+        self.if_enable_cinn()

     def test_check_output(self):
         self.check_output()
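The setUp rewrite above turns construction into a template method: building inputs/outputs and toggling CINN become hooks that dtype-specific subclasses override. A condensed, self-contained sketch of the pattern (class names here are illustrative, not from the patch):

```python
import unittest

import numpy as np

class GatherTestBase(unittest.TestCase):
    def setUp(self):
        self.config()                   # shapes, index, dtype knobs
        self.init_inputs_and_outputs()  # hook: build self.inputs / self.outputs
        self.if_enable_cinn()           # hook: opt out of CINN where unsupported

    def config(self):
        self.x_shape, self.index = (10, 20), [1, 3, 5]

    def init_inputs_and_outputs(self):
        x = np.random.random(self.x_shape).astype("float64")
        self.inputs = {'X': x, 'Index': np.array(self.index, dtype="int64")}
        self.outputs = {'Out': x[self.inputs['Index']]}

    def if_enable_cinn(self):
        pass  # default: keep the framework's setting

class GatherTestBF16(GatherTestBase):
    def if_enable_cinn(self):
        self.enable_cinn = False  # CINN does not cover this bf16 path yet
```

Each BF16 subclass in the diff then only overrides the hooks that actually differ, which is why the repeated if_enable_cinn additions below are so small.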
@@ -62,12 +58,56 @@ class TestGatherOp(OpTest):
     def config_dtype(self):
         self.x_type = "float64"

+    def init_inputs_and_outputs(self):
+        xnp = np.random.random(self.x_shape).astype(self.x_type)
+        self.inputs = {
+            'X': xnp,
+            'Index': np.array(self.index).astype(self.index_type),
+        }
+        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
+
+    def if_enable_cinn(self):
+        pass
+

 class TestGatherOpFP16(TestGatherOp):
     def config_dtype(self):
         self.x_type = "float16"


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or core.cudnn_version() < 8100
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "only support compiled with CUDA and cudnn version need larger than 8.1.0 and device's compute capability is at least 8.0",
+)
+class TestGatherOpBFP16(TestGatherOp):
+    def config_dtype(self):
+        self.x_type = "float32"
+        self.dtype = np.uint16
+
+    def init_inputs_and_outputs(self):
+        xnp = np.random.random(self.x_shape).astype(self.x_type)
+        self.inputs = {
+            'X': convert_float_to_uint16(xnp),
+            'Index': np.array(self.index).astype(self.index_type),
+        }
+        self.outputs = {
+            'Out': convert_float_to_uint16(xnp[self.inputs["Index"]])
+        }
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
 class TestCase1(TestGatherOp):
     def config(self):
         """
@@ -87,6 +127,14 @@ class TestCase1FP16(TestCase1):
         self.x_type = "float16"

+class TestCase1BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = 100
+        self.config_dtype()
+        self.index = [1, 3, 5]
+        self.index_type = "int32"
+

 class TestCase2(TestGatherOp):
     def config(self):
         """
@@ -106,6 +154,14 @@ class TestCase2FP16(TestCase2):
         self.x_type = "float16"

+class TestCase2BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = 100
+        self.config_dtype()
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
+

 class TestCase3(TestGatherOp):
     def config(self):
         """
@@ -125,6 +181,14 @@ class TestCase3Fp16(TestCase3):
         self.x_type = "float16"

+class TestCase3BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.config_dtype()
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
+

 class TestCase4(TestGatherOp):
     def config(self):
         self.x_shape = (10, 20)
@@ -142,6 +206,15 @@ class TestCase4FP16(TestCase4):
         self.x_type = "float16"

+class TestCase4BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': False}
+        self.config_dtype()
+        self.index = [1, 1]
+        self.index_type = "int32"
+

 class TestCase5(TestGatherOp):
     def config(self):
         self.x_shape = (10, 20)
@@ -154,6 +227,15 @@ class TestCase5(TestGatherOp):
         self.x_type = "float64"

+class TestCase5BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': False}
+        self.config_dtype()
+        self.index = [1, 1]
+        self.index_type = "int32"
+

 class TestCase5FP16(TestCase5):
     def config_dtype(self):
         self.x_type = "float16"
@@ -176,6 +258,15 @@ class TestCase6FP16(TestCase6):
         self.x_type = "float16"

+class TestCase6BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': True}
+        self.config_dtype()
+        self.index = [1, 3]
+        self.index_type = "int32"
+

 class TestGatherBF16Op(OpTest):
     def setUp(self):
         self.op_type = "gather"
...
@@ -36,7 +36,7 @@ class TestSumOp(OpTest):
         self.prim_op_type = "prim"
         self.inputs = {'X': self.x}
         self.outputs = {'Out': self.out}
-        self.enable_cinn = True
+        self.if_enable_cinn()

     def init_dtype(self):
         self.dtype = np.float64
@@ -47,6 +47,9 @@ class TestSumOp(OpTest):
     def init_attrs(self):
         self.attrs = {'dim': [0]}

+    def if_enable_cinn(self):
+        pass
+
     def calc_output(self):
         self.out = self.x.sum(axis=tuple(self.attrs['dim']))
@@ -984,7 +987,10 @@ class Test1DReduce(OpTest):
         self.prim_op_type = "prim"
         self.inputs = {'X': np.random.random(120).astype("float64")}
         self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1002,6 +1008,7 @@ class Test2DReduce0(Test1DReduce):
         self.attrs = {'dim': [0]}
         self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
         self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
+        self.if_enable_cinn()


 class Test2DReduce1(Test1DReduce):
@@ -1015,6 +1022,7 @@ class Test2DReduce1(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce0(Test1DReduce):
@@ -1028,6 +1036,7 @@ class Test3DReduce0(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce1(Test1DReduce):
@@ -1041,6 +1050,7 @@ class Test3DReduce1(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce2(Test1DReduce):
@@ -1054,6 +1064,7 @@ class Test3DReduce2(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce3(Test1DReduce):
@@ -1067,6 +1078,7 @@ class Test3DReduce3(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 def reduce_sum_wrapper2(x, axis=[0], dtype=None, keepdim=False):
@@ -1105,6 +1117,7 @@ class TestKeepDimReduce(Test1DReduce):
                 axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim']
             )
         }
+        self.if_enable_cinn()


 class TestKeepDimReduceForEager(Test1DReduce):
@@ -1208,6 +1221,10 @@ class TestKeepDimReduceSumMultiAxises(OpTest):
                 axis=tuple(self.attrs['dim']), keepdims=True
             )
         }
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1248,7 +1265,10 @@ class TestReduceSumWithDimOne(OpTest):
                 axis=tuple(self.attrs['dim']), keepdims=True
             )
         }
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1290,7 +1310,10 @@ class TestReduceSumWithNumelOne(OpTest):
                 axis=tuple(self.attrs['dim']), keepdims=False
             )
         }
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1314,7 +1337,10 @@ class TestReduceAll(OpTest):
         self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")}
         self.attrs = {'reduce_all': True, 'keep_dim': False}
         self.outputs = {'Out': self.inputs['X'].sum()}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1332,7 +1358,10 @@ class TestReduceAllFp32(OpTest):
         self.inputs = {'X': np.random.random((100, 1, 1)).astype("float32")}
         self.attrs = {'reduce_all': True, 'keep_dim': False}
         self.outputs = {'Out': self.inputs['X'].sum()}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1350,7 +1379,10 @@ class Test1DReduceWithAxes1(OpTest):
         self.inputs = {'X': np.random.random(100).astype("float64")}
         self.attrs = {'dim': [0], 'keep_dim': False}
         self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1380,6 +1412,10 @@ class TestReduceWithDtype(OpTest):
                 'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
             }
         )
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
...
@@ -531,9 +531,8 @@ class TestBF16(OpTest):
     def test_check_output(self):
         self.check_output()

-    # pad not support bfloat16, so we can't test prim.
     def test_check_grad_normal(self):
-        self.check_grad(['Input'], 'Out')
+        self.check_grad(['Input'], 'Out', check_prim=True)


 # Test python API
...
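On the pad change above: check_prim=True additionally validates the gradient produced by the op's composite (prim) decomposition. For constant padding that decomposition is simple, since the gradient of pad is just a slice that drops the padded border; a numpy illustration (hypothetical helper, not Paddle code):

```python
import numpy as np

def pad_grad(grad_out: np.ndarray, paddings) -> np.ndarray:
    """Gradient of a constant pad: slice away the padding.

    paddings: [(before, after), ...] per axis, as accepted by np.pad.
    """
    slices = tuple(
        slice(before, grad_out.shape[i] - after)
        for i, (before, after) in enumerate(paddings)
    )
    return grad_out[slices]

x = np.ones((2, 3), dtype=np.float32)
paddings = [(1, 1), (0, 2)]
y = np.pad(x, paddings, mode="constant")
assert pad_grad(np.ones_like(y), paddings).shape == x.shape
```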