Unverified commit a33a4d01, authored by jiangcheng, committed by GitHub

[AMP] add fp16&bf16 support for flatten op (#52035)

* [AMP] add fp16&bf16 support for flatten op

* fix ci bug

* fix bug where input should be cast with astype(self.dtype), and fix zero-dim test name

* remove 0D-tensor bf16 test so the Windows inference CI passes

* remove flatten from op_accuracy_white_list
Parent a34abdb5
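The new test classes pick their dtype through init_test_dtype() and build inputs through init_input_data(); bfloat16 ("uint16") data is drawn as float32 and then packed into uint16 storage with convert_float_to_uint16. Below is a minimal sketch of that pattern, mirroring the diff that follows (the helper name make_flatten_input is hypothetical, and eager_op_test is only importable from Paddle's test directory):

```python
import numpy as np
from eager_op_test import convert_float_to_uint16

def make_flatten_input(shape, dtype):
    # float16/float32/float64 inputs are drawn directly in the target dtype
    if dtype != "uint16":
        return np.random.random(shape).astype(dtype)
    # bfloat16 ("uint16") inputs: sample as float32, then pack into uint16 storage
    x = np.random.random(shape).astype("float32")
    return convert_float_to_uint16(x)
```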
@@ -15,9 +15,10 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
class TestFlattenOp(OpTest):
@@ -31,7 +32,8 @@ class TestFlattenOp(OpTest):
self.stop_axis = -1
self.skip_cinn()
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float64")}
self.init_test_dtype()
self.init_input_data()
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
@@ -42,10 +44,20 @@ class TestFlattenOp(OpTest):
self.enable_cinn = True
def test_check_output(self):
self.check_output(no_check_set=["XShape"], check_prim=True)
if str(self.dtype) in {"float16", "uint16"}:
self.check_output_with_place(
core.CUDAPlace(0), no_check_set=["XShape"], check_prim=True
)
else:
self.check_output(no_check_set=["XShape"], check_prim=True)
def test_check_grad(self):
self.check_grad(["X"], "Out", check_prim=True)
if str(self.dtype) in {"float16", "uint16"}:
self.check_grad_with_place(
core.CUDAPlace(0), ["X"], "Out", check_prim=True
)
else:
self.check_grad(["X"], "Out", check_prim=True)
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -59,6 +71,42 @@ class TestFlattenOp(OpTest):
"stop_axis": self.stop_axis,
}
def init_test_dtype(self):
self.dtype = "float64"
def init_input_data(self):
if str(self.dtype) != "uint16":
x = np.random.random(self.in_shape).astype(self.dtype)
else:
x = np.random.random(self.in_shape).astype("float32")
x = convert_float_to_uint16(x)
self.inputs = {"X": x}
class TestFlattenFP32Op(TestFlattenOp):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op(TestFlattenOp):
def init_test_dtype(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op(TestFlattenOp):
def init_test_dtype(self):
self.dtype = "uint16"
class TestFlattenOp_1(TestFlattenOp):
def init_test_case(self):
@@ -74,6 +122,30 @@ class TestFlattenOp_1(TestFlattenOp):
}
class TestFlattenFP32Op_1(TestFlattenOp_1):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_1(TestFlattenOp_1):
def init_test_dtype(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_1(TestFlattenOp_1):
def init_test_dtype(self):
self.dtype = "uint16"
class TestFlattenOp_2(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -88,6 +160,30 @@ class TestFlattenOp_2(TestFlattenOp):
}
class TestFlattenFP32Op_2(TestFlattenOp_2):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_2(TestFlattenOp_2):
def init_test_dtype(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_2(TestFlattenOp_2):
def init_test_dtype(self):
self.dtype = "uint16"
class TestFlattenOp_3(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -102,6 +198,30 @@ class TestFlattenOp_3(TestFlattenOp):
}
class TestFlattenFP32Op_3(TestFlattenOp_3):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_3(TestFlattenOp_3):
def init_test_dtype(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_3(TestFlattenOp_3):
def init_test_dtype(self):
self.dtype = "uint16"
class TestFlattenOp_4(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -116,6 +236,30 @@ class TestFlattenOp_4(TestFlattenOp):
}
class TestFlattenFP32Op_4(TestFlattenOp_4):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_4(TestFlattenOp_4):
def init_test_dtype(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_4(TestFlattenOp_4):
def init_test_dtype(self):
self.dtype = "uint16"
class TestFlattenOp_5(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -130,7 +274,31 @@ class TestFlattenOp_5(TestFlattenOp):
}
class TestFlattenOp_6(TestFlattenOp):
class TestFlattenFP32Op_5(TestFlattenOp_5):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_5(TestFlattenOp_5):
def init_test_dtype(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_5(TestFlattenOp_5):
def init_test_dtype(self):
self.dtype = "uint16"
class TestFlattenOp_ZeroDim(TestFlattenOp):
def init_test_case(self):
self.in_shape = ()
self.start_axis = 0
@@ -147,6 +315,20 @@ class TestFlattenOp_6(TestFlattenOp):
}
class TestFlattenFP32Op_ZeroDim(TestFlattenOp_ZeroDim):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_ZeroDim(TestFlattenOp_ZeroDim):
def init_test_dtype(self):
self.dtype = "float16"
class TestFlattenOpSixDims(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
@@ -161,6 +343,30 @@ class TestFlattenOpSixDims(TestFlattenOp):
}
class TestFlattenFP32OpSixDims(TestFlattenOpSixDims):
def init_test_dtype(self):
self.dtype = "float32"
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16OpSixDims(TestFlattenOpSixDims):
def init_test_dtype(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16OpSixDims(TestFlattenOpSixDims):
def init_test_dtype(self):
self.dtype = "uint16"
class TestFlatten2OpError(unittest.TestCase):
def test_errors(self):
image_shape = (2, 3, 4, 4)
@@ -1591,6 +1591,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
'int32',
'int64',
'uint8',
'uint16',
],
'flatten',
)
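The second hunk extends the static-graph dtype check of paddle.flatten with 'uint16', the storage dtype Paddle uses for bfloat16. A hedged usage sketch, assuming a CUDA build where bfloat16 kernels are available (values and shapes are illustrative only):

```python
import paddle

# Assumes a CUDA build with bfloat16 support; bf16 tensors are stored as uint16.
x = paddle.randn([3, 2, 5, 4]).astype("bfloat16")
# Flatten axes 1 through 2: (3, 2, 5, 4) -> (3, 10, 4)
y = paddle.flatten(x, start_axis=1, stop_axis=2)
print(y.shape)  # [3, 10, 4]
```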