Unverified · Commit 585f1136 authored by Charles-hit, committed by GitHub

support some prim ops for bf16 dtype (#54285)

Parent f5342918
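Note: the tests in this diff follow the convention of storing bfloat16 data as np.uint16 bit patterns (hence self.dtype = np.uint16), converting with the test helper convert_float_to_uint16. A minimal numpy sketch of that storage convention, assuming simple truncation (the real helper may round to nearest even):

import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 is the upper half of an IEEE-754 float32, so the tests can
    # store it in a plain np.uint16 array. Truncating variant, for illustration.
    return (np.asarray(x, np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(u):
    # Widen back by placing the 16 bits in the high half of a float32.
    return (np.asarray(u, np.uint16).astype(np.uint32) << 16).view(np.float32)

print(bf16_bits_to_float32(float32_to_bf16_bits([3.14159])))  # -> [3.140625]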
@@ -2735,6 +2735,7 @@ def gather(x, index, axis=None, name=None):
                 'int32',
                 'int64',
                 'uint8',
+                'uint16',
             ],
             'gather',
         )
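With 'uint16' added to the accepted dtypes, gather over bf16 tensors is plain index selection on the stored bit patterns; no arithmetic touches the values. A small illustrative numpy check of that property:

import numpy as np

x = np.random.random((4, 3)).astype(np.float32)
bits = (x.view(np.uint32) >> 16).astype(np.uint16)  # bf16 bit patterns
index = np.array([0, 2], dtype=np.int32)
# Gathering the bit patterns equals converting the gathered float32 rows.
assert np.array_equal(
    bits[index], (x[index].view(np.uint32) >> 16).astype(np.uint16)
)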
@@ -71,7 +71,7 @@ class TestElementwiseOp(OpTest):
         self.check_prim = True

     def if_enable_cinn(self):
-        self.enable_cinn = False
+        pass


 class TestElementwiseFP16OP(TestElementwiseOp):
@@ -87,6 +87,7 @@ class TestElementwiseFP16OP(TestElementwiseOp):
 class TestElementwiseBF16OP(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_sub"
+        self.prim_op_type = "prim"
         self.dtype = np.uint16
         self.python_api = paddle.subtract
         self.public_python_api = paddle.subtract
@@ -103,6 +104,9 @@ class TestElementwiseBF16OP(TestElementwiseOp):
         self.if_check_prim()
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False

     def test_check_grad_normal(self):
         place = core.CUDAPlace(0)
         self.check_grad_with_place(
@@ -118,7 +122,12 @@ class TestElementwiseBF16OP(TestElementwiseOp):
     def test_check_grad_ingore_y(self):
         place = core.CUDAPlace(0)
         self.check_grad_with_place(
-            place, ['X'], 'Out', no_grad_set=set('Y'), max_relative_error=0.1
+            place,
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            max_relative_error=0.1,
+            check_prim=True,
         )
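The loose max_relative_error=0.1 reflects bfloat16 precision: with only 8 significand bits, representable values near 1.0 are 2**-7 apart, so rounding alone can cost a few tenths of a percent per element before gradient error accumulates. A quick numpy illustration, again using a truncating stand-in for the conversion:

import numpy as np

a = np.array([1.003], dtype=np.float32)
# Truncate to a bfloat16 bit pattern and widen back.
bf16 = ((a.view(np.uint32) >> 16) << 16).view(np.float32)
print(a[0], bf16[0])  # 1.003 -> 1.0: the whole fractional part is lost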
@@ -135,6 +144,10 @@ class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False


 class TestElementwiseSubFP16OP_ZeroDim1(TestElementwiseSubOp_ZeroDim1):
@@ -181,6 +194,10 @@ class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False


 class TestElementwiseSubFP16OP_ZeroDim2(TestElementwiseSubOp_ZeroDim2):
@@ -227,6 +244,10 @@ class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False


 class TestElementwiseSubFP16OP_ZeroDim3(TestElementwiseSubOp_ZeroDim3):
@@ -580,6 +601,7 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
+        self.if_enable_cinn()


 @unittest.skipIf(
@@ -653,6 +675,7 @@ class TestElementwiseBF16OP_commonuse_1(TestElementwiseBF16OP):
         }
         self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestElementwiseSubOp_commonuse_2(TestElementwiseOp):
@@ -698,6 +721,7 @@ class TestElementwiseBF16OP_commonuse_2(TestElementwiseBF16OP):
         }
         self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
@@ -717,6 +741,7 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
             'Out': self.inputs['X'].reshape(1, 1, 10, 12) - self.inputs['Y']
         }
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestElementwiseSubFP16OP_xsize_lessthan_ysize(
@@ -750,6 +775,7 @@ class TestElementwiseBF16OP_xsize_lessthan_ysize(TestElementwiseBF16OP):
         }
         self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
         self.if_check_prim()
+        self.if_enable_cinn()


 class TestComplexElementwiseSubOp(OpTest):
@@ -37,12 +37,8 @@ class TestGatherOp(OpTest):
         self.public_python_api = paddle.gather
         self.config()
         self.prim_op_type = "prim"
-        xnp = np.random.random(self.x_shape).astype(self.x_type)
-        self.inputs = {
-            'X': xnp,
-            'Index': np.array(self.index).astype(self.index_type),
-        }
-        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
+        self.init_inputs_and_outputs()
+        self.if_enable_cinn()

     def test_check_output(self):
         self.check_output()
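This setUp refactor is a template-method change: data construction moves into an init_inputs_and_outputs hook and CINN control into if_enable_cinn, so the FP16/BF16 subclasses below override only those pieces. A generic, runnable sketch of the pattern (all names here are illustrative, not Paddle API):

class BaseCase:
    def set_up(self):
        self.init_inputs_and_outputs()  # overridable hook
        self.if_enable_cinn()           # overridable hook

    def init_inputs_and_outputs(self):
        self.inputs = {'X': [1.0, 2.0]}

    def if_enable_cinn(self):
        pass  # default: leave the CINN setting untouched

class BF16Case(BaseCase):
    def init_inputs_and_outputs(self):
        self.inputs = {'X': [16256, 16384]}  # bf16 bit patterns of 1.0, 2.0

    def if_enable_cinn(self):
        self.enable_cinn = False  # bf16 cases opt out of CINN

case = BF16Case()
case.set_up()
print(case.inputs, case.enable_cinn)  # {'X': [16256, 16384]} False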
@@ -62,12 +58,56 @@ class TestGatherOp(OpTest):
     def config_dtype(self):
         self.x_type = "float64"

+    def init_inputs_and_outputs(self):
+        xnp = np.random.random(self.x_shape).astype(self.x_type)
+        self.inputs = {
+            'X': xnp,
+            'Index': np.array(self.index).astype(self.index_type),
+        }
+        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
+
+    def if_enable_cinn(self):
+        pass


 class TestGatherOpFP16(TestGatherOp):
     def config_dtype(self):
         self.x_type = "float16"


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or core.cudnn_version() < 8100
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "only support compiled with CUDA and cudnn version need larger than 8.1.0 and device's compute capability is at least 8.0",
+)
+class TestGatherOpBFP16(TestGatherOp):
+    def config_dtype(self):
+        self.x_type = "float32"
+        self.dtype = np.uint16
+
+    def init_inputs_and_outputs(self):
+        xnp = np.random.random(self.x_shape).astype(self.x_type)
+        self.inputs = {
+            'X': convert_float_to_uint16(xnp),
+            'Index': np.array(self.index).astype(self.index_type),
+        }
+        self.outputs = {
+            'Out': convert_float_to_uint16(xnp[self.inputs["Index"]])
+        }
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
 class TestCase1(TestGatherOp):
     def config(self):
         """
@@ -87,6 +127,14 @@ class TestCase1FP16(TestCase1):
         self.x_type = "float16"


+class TestCase1BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = 100
+        self.config_dtype()
+        self.index = [1, 3, 5]
+        self.index_type = "int32"
+
+
 class TestCase2(TestGatherOp):
     def config(self):
         """
@@ -106,6 +154,14 @@ class TestCase2FP16(TestCase2):
         self.x_type = "float16"


+class TestCase2BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = 100
+        self.config_dtype()
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
+
+
 class TestCase3(TestGatherOp):
     def config(self):
         """
@@ -125,6 +181,14 @@ class TestCase3Fp16(TestCase3):
         self.x_type = "float16"


+class TestCase3BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.config_dtype()
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
+
+
 class TestCase4(TestGatherOp):
     def config(self):
         self.x_shape = (10, 20)
@@ -142,6 +206,15 @@ class TestCase4FP16(TestCase4):
         self.x_type = "float16"


+class TestCase4BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': False}
+        self.config_dtype()
+        self.index = [1, 1]
+        self.index_type = "int32"
+
+
 class TestCase5(TestGatherOp):
     def config(self):
         self.x_shape = (10, 20)
@@ -154,6 +227,15 @@ class TestCase5(TestGatherOp):
         self.x_type = "float64"


+class TestCase5BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': False}
+        self.config_dtype()
+        self.index = [1, 1]
+        self.index_type = "int32"
+
+
 class TestCase5FP16(TestCase5):
     def config_dtype(self):
         self.x_type = "float16"
@@ -176,6 +258,15 @@ class TestCase6FP16(TestCase6):
         self.x_type = "float16"


+class TestCase6BFP16(TestGatherOpBFP16):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.attrs = {'overwrite': True}
+        self.config_dtype()
+        self.index = [1, 3]
+        self.index_type = "int32"
+
+
 class TestGatherBF16Op(OpTest):
     def setUp(self):
         self.op_type = "gather"
@@ -36,7 +36,7 @@ class TestSumOp(OpTest):
         self.prim_op_type = "prim"
         self.inputs = {'X': self.x}
         self.outputs = {'Out': self.out}
-        self.enable_cinn = True
+        self.if_enable_cinn()

     def init_dtype(self):
         self.dtype = np.float64
@@ -47,6 +47,9 @@ class TestSumOp(OpTest):
     def init_attrs(self):
         self.attrs = {'dim': [0]}

+    def if_enable_cinn(self):
+        pass
+
     def calc_output(self):
         self.out = self.x.sum(axis=tuple(self.attrs['dim']))
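For reference, the reduce attrs exercised throughout these tests map onto numpy like this (illustrative sketch):

import numpy as np

x = np.random.random((2, 3, 4))
print(x.sum(axis=(0,)).shape)                 # 'dim': [0]         -> (3, 4)
print(x.sum(axis=(0,), keepdims=True).shape)  # 'keep_dim': True   -> (1, 3, 4)
print(np.shape(x.sum()))                      # 'reduce_all': True -> ()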
@@ -984,7 +987,10 @@ class Test1DReduce(OpTest):
         self.prim_op_type = "prim"
         self.inputs = {'X': np.random.random(120).astype("float64")}
         self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1002,6 +1008,7 @@ class Test2DReduce0(Test1DReduce):
         self.attrs = {'dim': [0]}
         self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
         self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
+        self.if_enable_cinn()


 class Test2DReduce1(Test1DReduce):
@@ -1015,6 +1022,7 @@ class Test2DReduce1(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce0(Test1DReduce):
@@ -1028,6 +1036,7 @@ class Test3DReduce0(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce1(Test1DReduce):
@@ -1041,6 +1050,7 @@ class Test3DReduce1(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce2(Test1DReduce):
@@ -1054,6 +1064,7 @@ class Test3DReduce2(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 class Test3DReduce3(Test1DReduce):
@@ -1067,6 +1078,7 @@ class Test3DReduce3(Test1DReduce):
         self.outputs = {
             'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
         }
+        self.if_enable_cinn()


 def reduce_sum_wrapper2(x, axis=[0], dtype=None, keepdim=False):
@@ -1105,6 +1117,7 @@ class TestKeepDimReduce(Test1DReduce):
                 axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim']
             )
         }
+        self.if_enable_cinn()


 class TestKeepDimReduceForEager(Test1DReduce):
@@ -1208,6 +1221,10 @@ class TestKeepDimReduceSumMultiAxises(OpTest):
                 axis=tuple(self.attrs['dim']), keepdims=True
             )
         }
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1248,7 +1265,10 @@ class TestReduceSumWithDimOne(OpTest):
                 axis=tuple(self.attrs['dim']), keepdims=True
             )
         }
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1290,7 +1310,10 @@ class TestReduceSumWithNumelOne(OpTest):
                 axis=tuple(self.attrs['dim']), keepdims=False
             )
         }
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1314,7 +1337,10 @@ class TestReduceAll(OpTest):
         self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")}
         self.attrs = {'reduce_all': True, 'keep_dim': False}
         self.outputs = {'Out': self.inputs['X'].sum()}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1332,7 +1358,10 @@ class TestReduceAllFp32(OpTest):
         self.inputs = {'X': np.random.random((100, 1, 1)).astype("float32")}
         self.attrs = {'reduce_all': True, 'keep_dim': False}
         self.outputs = {'Out': self.inputs['X'].sum()}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1350,7 +1379,10 @@ class Test1DReduceWithAxes1(OpTest):
         self.inputs = {'X': np.random.random(100).astype("float64")}
         self.attrs = {'dim': [0], 'keep_dim': False}
         self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
-        self.enable_cinn = True
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -1380,6 +1412,10 @@ class TestReduceWithDtype(OpTest):
                 'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
             }
         )
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -531,9 +531,8 @@ class TestBF16(OpTest):
     def test_check_output(self):
         self.check_output()

-    # pad not support bfloat16, so we can't test prim.
     def test_check_grad_normal(self):
-        self.check_grad(['Input'], 'Out')
+        self.check_grad(['Input'], 'Out', check_prim=True)


 # Test python API