Unverified commit a33a4d01, authored by jiangcheng and committed by GitHub

[AMP] add fp16&bf16 support for flatten op (#52035)

* [AMP] add fp16&bf16 support for flatten op

* fix ci bug

* fix bug where the input should be cast via astype(self.dtype), and fix the zero-dim test name

* remove the 0D-tensor bf16 test so the Windows inference CI passes

* remove flatten from op_accuracy_white_list
Parent a34abdb5
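For context, here is a minimal, hypothetical usage sketch of what this change enables: calling paddle.flatten on float16 and bfloat16 tensors on a CUDA device. It is not part of the commit; the device selection, the astype("bfloat16") cast, and the printed shapes are assumptions about the runtime environment rather than code from this PR.

```python
import paddle
from paddle.fluid import core

# Hedged sketch (not from the commit): exercise flatten with the AMP dtypes
# this change adds support for. Assumes a CUDA build of Paddle.
if core.is_compiled_with_cuda():
    paddle.set_device("gpu")
    x = paddle.randn([3, 2, 5, 4])

    y_fp16 = paddle.flatten(x.astype("float16"), start_axis=1, stop_axis=2)
    print(y_fp16.shape)  # [3, 10, 4]

    # bfloat16 additionally requires a device/build that supports bf16
    if core.is_bfloat16_supported(core.CUDAPlace(0)):
        y_bf16 = paddle.flatten(x.astype("bfloat16"), start_axis=1, stop_axis=2)
        print(y_bf16.shape)  # [3, 10, 4]
```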
@@ -15,9 +15,10 @@
import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import core


class TestFlattenOp(OpTest):
@@ -31,7 +32,8 @@ class TestFlattenOp(OpTest):
        self.stop_axis = -1
        self.skip_cinn()
        self.init_test_case()
        self.init_test_dtype()
        self.init_input_data()
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
@@ -42,10 +44,20 @@ class TestFlattenOp(OpTest):
        self.enable_cinn = True

    def test_check_output(self):
        if str(self.dtype) in {"float16", "uint16"}:
            self.check_output_with_place(
                core.CUDAPlace(0), no_check_set=["XShape"], check_prim=True
            )
        else:
            self.check_output(no_check_set=["XShape"], check_prim=True)

    def test_check_grad(self):
        if str(self.dtype) in {"float16", "uint16"}:
            self.check_grad_with_place(
                core.CUDAPlace(0), ["X"], "Out", check_prim=True
            )
        else:
            self.check_grad(["X"], "Out", check_prim=True)

    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
@@ -59,6 +71,42 @@ class TestFlattenOp(OpTest):
            "stop_axis": self.stop_axis,
        }

    def init_test_dtype(self):
        self.dtype = "float64"

    def init_input_data(self):
        if str(self.dtype) != "uint16":
            x = np.random.random(self.in_shape).astype(self.dtype)
        else:
            x = np.random.random(self.in_shape).astype("float32")
            x = convert_float_to_uint16(x)

        self.inputs = {"X": x}


class TestFlattenFP32Op(TestFlattenOp):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16Op(TestFlattenOp):
    def init_test_dtype(self):
        self.dtype = "float16"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or the place does not support bfloat16",
)
class TestFlattenBF16Op(TestFlattenOp):
    def init_test_dtype(self):
        self.dtype = "uint16"


class TestFlattenOp_1(TestFlattenOp):
    def init_test_case(self):
@@ -74,6 +122,30 @@ class TestFlattenOp_1(TestFlattenOp):
        }


class TestFlattenFP32Op_1(TestFlattenOp_1):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16Op_1(TestFlattenOp_1):
    def init_test_dtype(self):
        self.dtype = "float16"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or the place does not support bfloat16",
)
class TestFlattenBF16Op_1(TestFlattenOp_1):
    def init_test_dtype(self):
        self.dtype = "uint16"


class TestFlattenOp_2(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
@@ -88,6 +160,30 @@ class TestFlattenOp_2(TestFlattenOp):
        }


class TestFlattenFP32Op_2(TestFlattenOp_2):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16Op_2(TestFlattenOp_2):
    def init_test_dtype(self):
        self.dtype = "float16"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or the place does not support bfloat16",
)
class TestFlattenBF16Op_2(TestFlattenOp_2):
    def init_test_dtype(self):
        self.dtype = "uint16"


class TestFlattenOp_3(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
@@ -102,6 +198,30 @@ class TestFlattenOp_3(TestFlattenOp):
        }


class TestFlattenFP32Op_3(TestFlattenOp_3):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16Op_3(TestFlattenOp_3):
    def init_test_dtype(self):
        self.dtype = "float16"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or the place does not support bfloat16",
)
class TestFlattenBF16Op_3(TestFlattenOp_3):
    def init_test_dtype(self):
        self.dtype = "uint16"


class TestFlattenOp_4(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
@@ -116,6 +236,30 @@ class TestFlattenOp_4(TestFlattenOp):
        }


class TestFlattenFP32Op_4(TestFlattenOp_4):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16Op_4(TestFlattenOp_4):
    def init_test_dtype(self):
        self.dtype = "float16"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or the place does not support bfloat16",
)
class TestFlattenBF16Op_4(TestFlattenOp_4):
    def init_test_dtype(self):
        self.dtype = "uint16"


class TestFlattenOp_5(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
@@ -130,7 +274,31 @@ class TestFlattenOp_5(TestFlattenOp):
        }


class TestFlattenFP32Op_5(TestFlattenOp_5):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16Op_5(TestFlattenOp_5):
    def init_test_dtype(self):
        self.dtype = "float16"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or the place does not support bfloat16",
)
class TestFlattenBF16Op_5(TestFlattenOp_5):
    def init_test_dtype(self):
        self.dtype = "uint16"


class TestFlattenOp_ZeroDim(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = ()
        self.start_axis = 0
@@ -147,6 +315,20 @@ class TestFlattenOp_6(TestFlattenOp):
        }


class TestFlattenFP32Op_ZeroDim(TestFlattenOp_ZeroDim):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16Op_ZeroDim(TestFlattenOp_ZeroDim):
    def init_test_dtype(self):
        self.dtype = "float16"


class TestFlattenOpSixDims(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 3, 2, 4, 4)
@@ -161,6 +343,30 @@ class TestFlattenOpSixDims(TestFlattenOp):
        }


class TestFlattenFP32OpSixDims(TestFlattenOpSixDims):
    def init_test_dtype(self):
        self.dtype = "float32"


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA",
)
class TestFlattenFP16OpSixDims(TestFlattenOpSixDims):
    def init_test_dtype(self):
        self.dtype = "float16"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or the place does not support bfloat16",
)
class TestFlattenBF16OpSixDims(TestFlattenOpSixDims):
    def init_test_dtype(self):
        self.dtype = "uint16"


class TestFlatten2OpError(unittest.TestCase):
    def test_errors(self):
        image_shape = (2, 3, 4, 4)
......
@@ -1591,6 +1591,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
            'int32',
            'int64',
            'uint8',
            'uint16',
        ],
        'flatten',
    )
......