Unverified commit eb677102, authored by Difer, committed by GitHub

【Hackathon No57】add_bf16_fp16 unittest for conv3d & conv3d_transpose (#52195)

* add test+conv3d_transpose_part2

* fix some merge error

* fix codestyle

* fix typo

* fix codestyle

* fix some error

* add redef float2uint

* fix conv3d and conv3d_transpose
Parent d7a5e900
@@ -15,10 +15,16 @@
import unittest
import numpy as np
- from eager_op_test import OpTest, paddle_static_guard
from eager_op_test import (
OpTest,
convert_float_to_uint16,
get_numeric_gradient,
paddle_static_guard,
)
import paddle
from paddle.fluid import core
from paddle.fluid.tests.unittests.testsuite import create_op
def conv3d_forward_naive(
@@ -179,6 +185,77 @@ def create_test_cudnn_class(parent):
globals()[cls_name] = TestCUDNNCase
def create_test_cudnn_bf16_class(parent):
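# Derive a cuDNN bfloat16 test case from `parent`; gradient checks compare
# against numeric gradients computed from the float32 copies kept in
# `self.inputs_fp32`.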
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestConv3DCUDNNBF16(parent):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(
scope, self.op_type, self.inputs, self.outputs, self.attrs
)
return get_numeric_gradient(
place, scope, op, self.inputs_fp32, check_name, ['Output']
)
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=(not self.use_mkldnn)
)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Input')
self.check_grad_with_place(
place,
['Input'],
'Output',
no_grad_set={'Filter'},
check_dygraph=(not self.use_mkldnn),
user_defined_grads=[numeric_grads],
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Filter')
self.check_grad_with_place(
place,
['Filter'],
'Output',
no_grad_set={'Input'},
check_dygraph=(not self.use_mkldnn),
user_defined_grads=[numeric_grads],
)
def test_check_grad(self):
place = core.CUDAPlace(0)
numeric_input_grads = self.get_numeric_grad(place, 'Input')
numeric_filter_grads = self.get_numeric_grad(place, 'Filter')
self.check_grad_with_place(
place,
['Input', 'Filter'],
'Output',
user_defined_grads=[numeric_input_grads, numeric_filter_grads],
check_dygraph=(not self.use_mkldnn),
)
cls_name = "{}_{}".format(parent.__name__, "CUDNNBF16OP")
TestConv3DCUDNNBF16.__name__ = cls_name
globals()[cls_name] = TestConv3DCUDNNBF16
def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent):
def init_paddings(self):
@@ -323,19 +400,37 @@ class TestConv3DOp(OpTest):
'dilations': self.dilations,
}
- input = np.random.random(self.input_size).astype(self.dtype)
- filter = np.random.random(self.filter_size).astype(self.dtype)
if self.is_bfloat16_op():
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
else:
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv3d_forward_naive(
input,
filter,
self.groups,
conv3d_param,
- ).astype(self.dtype)
)
if self.is_bfloat16_op():
output = convert_float_to_uint16(output)
self.inputs = {
'Input': convert_float_to_uint16(input),
'Filter': convert_float_to_uint16(filter),
}
self.inputs_fp32 = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
}
else:
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
}
- self.inputs = {
- 'Input': OpTest.np_dtype_to_fluid_dtype(input),
- 'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
- }
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
@@ -358,8 +453,6 @@ class TestConv3DOp(OpTest):
)
def test_check_grad(self):
- if self.dtype == np.float16:
- return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
@@ -371,8 +464,7 @@ class TestConv3DOp(OpTest):
)
def test_check_grad_no_filter(self):
- if self.dtype == np.float16:
- return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
@@ -385,8 +477,7 @@ class TestConv3DOp(OpTest):
)
def test_check_grad_no_input(self):
- if self.dtype == np.float16:
- return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
@@ -617,6 +708,14 @@ class TestCUDNNExhaustiveSearch(TestCUDNN):
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
# ----------------Conv3DCUDNN bf16----------------
create_test_cudnn_bf16_class(TestConv3DOp)
create_test_cudnn_bf16_class(TestWithGroup1)
create_test_cudnn_bf16_class(TestWithGroup2)
create_test_cudnn_bf16_class(TestWith1x1)
create_test_cudnn_bf16_class(TestWithInput1x1Filter1x1)
# ---- test asymmetric padding ----
@@ -1114,4 +1213,5 @@ class TestConv3DAPI_Error(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
@@ -19,11 +19,25 @@ import numpy as np
import paddle
paddle.enable_static()
- from eager_op_test import OpTest
from eager_op_test import OpTest, copy_bits_from_float_to_uint16
from paddle.fluid import core
def convert_float_to_uint16(float_list, data_format="NCHW"):
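# Reinterpret a float32 array as bfloat16 bit patterns (uint16) by keeping the
# high 16 bits of each element (e.g. 1.0 -> 0x3F80); channels-last data is
# moved to channel-first before the elementwise copy and moved back afterwards.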
if data_format == "NHWC":
float_list = np.transpose(float_list, [0, 4, 1, 2, 3])
new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, float_list.shape).view(np.uint16)
if data_format == "NHWC":
new_output = np.transpose(new_output, [0, 2, 3, 4, 1])
return new_output
def conv3dtranspose_forward_naive(input_, filter_, attrs):
padding_algorithm = attrs['padding_algorithm']
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
@@ -134,6 +148,86 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
return out
def create_test_cudnn_fp16_class(parent, grad_check=True):
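# Derive a cuDNN float16 test case from `parent`; gradient checks run only
# when grad_check is True and the device supports float16.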
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestConv3DTransposeCUDNNFP16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set={'Filter'}
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set={'Input'}
)
cls_name = "{}_{}".format(parent.__name__, "CUDNNFP16OP")
TestConv3DTransposeCUDNNFP16.__name__ = cls_name
globals()[cls_name] = TestConv3DTransposeCUDNNFP16
def create_test_cudnn_bf16_class(parent):
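# Derive a cuDNN bfloat16 test case from `parent`; skipped when CUDA or
# bfloat16 support is unavailable on the device.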
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestConv3DTransposeCUDNNBF16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
{'Input', 'Filter'},
'Output',
)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['Input'],
'Output',
no_grad_set={'Filter'},
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['Filter'],
'Output',
no_grad_set={'Input'},
)
cls_name = "{}_{}".format(parent.__name__, "CUDNNBF16OP")
TestConv3DTransposeCUDNNBF16.__name__ = cls_name
globals()[cls_name] = TestConv3DTransposeCUDNNBF16
def conv3d_transpose_wrapper(
x,
weight,
@@ -172,12 +266,16 @@ class TestConv3DTransposeOp(OpTest):
self.pad = [0, 0, 0]
self.padding_algorithm = "EXPLICIT"
self.init_op_type()
self.init_kernel_type()
self.init_test_case()
- input_ = np.random.random(self.input_size).astype("float32")
- filter_ = np.random.random(self.filter_size).astype("float32")
if self.is_bfloat16_op():
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
else:
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
- self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
@@ -189,9 +287,21 @@ class TestConv3DTransposeOp(OpTest):
}
output = conv3dtranspose_forward_naive(
- input_, filter_, self.attrs
input, filter, self.attrs
).astype("float32")
if self.is_bfloat16_op():
self.inputs = {
'Input': convert_float_to_uint16(input),
'Filter': convert_float_to_uint16(filter),
}
else:
self.inputs = {
'Input': input,
'Filter': filter,
}
output = output.astype(self.dtype)
self.outputs = {'Output': output}
def test_check_output(self):
@@ -264,6 +374,9 @@ class TestConv3DTransposeOp(OpTest):
self.op_type = "conv3d_transpose"
self.python_api = conv3d_transpose_wrapper
def init_kernel_type(self):
self.dtype = np.float32
class TestWithSymmetricPad(TestConv3DTransposeOp):
def init_test_case(self):
@@ -596,6 +709,30 @@ class TestCUDNNWithGroups_NHWC(TestWithGroups):
self.python_api = conv3d_transpose_wrapper
# ----------------Conv3DTransposeCUDNN fp16----------------
create_test_cudnn_fp16_class(TestConv3DTransposeOp)
create_test_cudnn_fp16_class(TestWithSymmetricPad)
create_test_cudnn_fp16_class(TestWithAsymmetricPad)
create_test_cudnn_fp16_class(TestWithSAMEPad)
create_test_cudnn_fp16_class(TestWithVALIDPad)
create_test_cudnn_fp16_class(TestWithStride)
create_test_cudnn_fp16_class(TestWithGroups)
create_test_cudnn_fp16_class(TestWithDilation)
create_test_cudnn_fp16_class(Test_NHWC)
# ----------------Conv3DTransposeCUDNN bf16----------------
create_test_cudnn_bf16_class(TestConv3DTransposeOp)
create_test_cudnn_bf16_class(TestWithSymmetricPad)
create_test_cudnn_bf16_class(TestWithAsymmetricPad)
create_test_cudnn_bf16_class(TestWithSAMEPad)
create_test_cudnn_bf16_class(TestWithVALIDPad)
create_test_cudnn_bf16_class(TestWithStride)
create_test_cudnn_bf16_class(TestWithGroups)
create_test_cudnn_bf16_class(TestWithDilation)
create_test_cudnn_bf16_class(Test_NHWC)
class TestConv3dTranspose(unittest.TestCase):
def error_weight_input(self):
array = np.array([1], dtype=np.float32)
...
@@ -15,7 +15,11 @@
import unittest
import numpy as np
- from test_conv3d_transpose_op import TestConv3DTransposeOp
from test_conv3d_transpose_op import (
TestConv3DTransposeOp,
create_test_cudnn_bf16_class,
create_test_cudnn_fp16_class,
)
import paddle
from paddle import fluid
@@ -84,6 +88,22 @@ class TestWithDilation_NHWC(TestConv3DTransposeOp):
self.data_format = 'NHWC'
# ----------------Conv3DTransposeCUDNN fp16----------------
create_test_cudnn_fp16_class(TestWithSymmetricPad_NHWC)
create_test_cudnn_fp16_class(TestWithAsymmetricPad_NHWC)
create_test_cudnn_fp16_class(TestWithGroups_NHWC)
create_test_cudnn_fp16_class(TestWithStride_NHWC)
create_test_cudnn_fp16_class(TestWithDilation_NHWC)
# ----------------Conv3DTransposeCUDNN bf16----------------
create_test_cudnn_bf16_class(TestWithSymmetricPad_NHWC)
create_test_cudnn_bf16_class(TestWithAsymmetricPad_NHWC)
create_test_cudnn_bf16_class(TestWithGroups_NHWC)
create_test_cudnn_bf16_class(TestWithStride_NHWC)
create_test_cudnn_bf16_class(TestWithDilation_NHWC)
class TestConv3DTransposeAPI(unittest.TestCase):
def test_case1(self):
data1 = paddle.static.data(
...