Unverified commit ad5536eb, authored by wangxinxin08, committed by GitHub

[AMP OP&Test] add fp16/bf16 unittest for conv ops (#51787)

* add unittest for conv2d/depthwise_conv2d/conv2d_transpose

* add bf16 for DWConv and ConvTranspose

* fix unittest of conv2d_transpose

* modify DWConv2d op and unittest

* fix unittest of conv2d_transpose_bf16

* modify unittest name according to review

* modify atol of DWConv2D unittest
Parent: b94ef537
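The atol bump mentioned in the last bullet follows directly from bfloat16's short significand; a back-of-the-envelope check (plain NumPy, not Paddle code):

```python
# bfloat16 keeps 8 significand bits vs float16's 11, so each rounding step is
# roughly 8x coarser, which is why the bf16 depthwise cases need atol in the
# 1e-2 to 4e-2 range while the fp16 cases get by with 2e-2.
import numpy as np

print(np.finfo(np.float16).eps)   # ~9.77e-4  (2**-10)
print(2.0 ** -7)                  # ~7.81e-3  bf16 machine epsilon
```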
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
......@@ -142,4 +144,5 @@ PD_REGISTER_KERNEL(depthwise_conv2d_grad,
phi::DepthwiseConvGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
......@@ -127,4 +129,5 @@ PD_REGISTER_KERNEL(depthwise_conv2d,
phi::DepthwiseConvKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
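With bfloat16 added to both the forward and backward GPU registrations, a bf16 depthwise convolution can be exercised from Python roughly as below. This is a hedged sketch, assuming a CUDA build of Paddle on which `core.is_bfloat16_supported` holds and that `paddle.amp.auto_cast` accepts `dtype='bfloat16'`; it is not code from this PR.

```python
import paddle
import paddle.nn.functional as F

paddle.set_device('gpu:0')
x = paddle.randn([2, 8, 16, 16])   # NCHW input
w = paddle.randn([8, 1, 3, 3])     # groups == in_channels -> depthwise filters
with paddle.amp.auto_cast(dtype='bfloat16'):   # assumes bf16 autocast support
    y = F.conv2d(x, w, groups=8, padding=1)
print(y.dtype, y.shape)
```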
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -1071,6 +1072,32 @@ PD_REGISTER_KERNEL(conv3d_transpose_grad,
float,
float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(conv2d_transpose_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv2dTransposeGradGPUDNNKernel,
float,
double,
float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(conv2d_transpose_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv2dTransposeDoubleGradGPUDNNKernel,
float,
double,
float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(conv3d_transpose_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv3dTransposeGradGPUDNNKernel,
float,
double,
float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(conv2d_transpose_grad,
GPUDNN,
ALL_LAYOUT,
......@@ -1093,3 +1120,5 @@ PD_REGISTER_KERNEL(conv3d_transpose_grad,
double,
float16) {}
#endif
#endif
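The new transpose-conv registrations sit behind `CUDNN_VERSION_MIN(8, 1, 0)` because cuDNN only added bf16 convolution support in 8.1. The same guard can be mirrored on the Python side; a small sketch, assuming `paddle.device.get_cudnn_version()` reports the linked cuDNN version as an integer:

```python
import paddle


def cudnn_has_bf16_conv():
    # get_cudnn_version() returns e.g. 8201 for cuDNN 8.2.1; assumed to be
    # None/0 on builds without cuDNN. bf16 convolutions need cuDNN >= 8.1.
    version = paddle.device.get_cudnn_version()
    return bool(version) and version >= 8100
```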
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -367,6 +368,24 @@ PD_REGISTER_KERNEL(conv3d_transpose,
float,
float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(conv2d_transpose,
GPUDNN,
ALL_LAYOUT,
phi::Conv2dTransposeGPUDNNKernel,
float,
double,
float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(conv3d_transpose,
GPUDNN,
ALL_LAYOUT,
phi::Conv3dTransposeGPUDNNKernel,
float,
double,
float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(conv2d_transpose,
GPUDNN,
ALL_LAYOUT,
......@@ -382,3 +401,5 @@ PD_REGISTER_KERNEL(conv3d_transpose,
double,
float16) {}
#endif
#endif
......@@ -672,12 +672,12 @@ create_test_cudnn_class(TestWithInput1x1Filter1x1)
# ----------------Conv2DCUDNN fp16----------------
create_test_cudnn_fp16_class(TestConv2DOp, grad_check=False)
create_test_cudnn_fp16_class(TestWithPad, grad_check=False)
create_test_cudnn_fp16_class(TestWithStride, grad_check=False)
create_test_cudnn_fp16_class(TestWithGroup, grad_check=False)
create_test_cudnn_fp16_class(TestWith1x1, grad_check=False)
create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False)
create_test_cudnn_fp16_class(TestConv2DOp)
create_test_cudnn_fp16_class(TestWithPad)
create_test_cudnn_fp16_class(TestWithStride)
create_test_cudnn_fp16_class(TestWithGroup)
create_test_cudnn_fp16_class(TestWith1x1)
create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1)
# ----------------Conv2DCUDNN bf16----------------
......@@ -1061,21 +1061,11 @@ create_test_cudnn_channel_last_class(TestWithStride_AsyPadding)
create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding)
create_test_cudnn_channel_last_class(TestWithDilation_AsyPadding)
create_test_cudnn_channel_last_fp16_class(
TestConv2DOp_AsyPadding, grad_check=False
)
create_test_cudnn_channel_last_fp16_class(
TestWithPad_AsyPadding, grad_check=False
)
create_test_cudnn_channel_last_fp16_class(
TestWithStride_AsyPadding, grad_check=False
)
create_test_cudnn_channel_last_fp16_class(
TestWithGroup_AsyPadding, grad_check=False
)
create_test_cudnn_channel_last_fp16_class(
TestWithDilation_AsyPadding, grad_check=False
)
create_test_cudnn_channel_last_fp16_class(TestConv2DOp_AsyPadding)
create_test_cudnn_channel_last_fp16_class(TestWithPad_AsyPadding)
create_test_cudnn_channel_last_fp16_class(TestWithStride_AsyPadding)
create_test_cudnn_channel_last_fp16_class(TestWithGroup_AsyPadding)
create_test_cudnn_channel_last_fp16_class(TestWithDilation_AsyPadding)
if __name__ == '__main__':
paddle.enable_static()
......
......@@ -30,6 +30,8 @@ from test_conv2d_op import (
)
from paddle.fluid import core
from paddle.fluid.tests.unittests.op_test import get_numeric_gradient
from paddle.fluid.tests.unittests.testsuite import create_op
# ----------------TestDepthwiseConv -----
......@@ -349,6 +351,160 @@ class TestDepthwiseConvWithDilation2andFuse_AsyPadding(TestConv2DOp_v2):
self.padding_algorithm = "EXPLICIT"
def create_test_fp16_class(parent, grad_check=True):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestDepthwiseConvFP16(parent):
def init_kernel_type(self):
self.use_cuda = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set=set(['Filter'])
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set=set(['Input'])
)
cls_name = "{0}_{1}".format(parent.__name__, "FP16OP")
TestDepthwiseConvFP16.__name__ = cls_name
globals()[cls_name] = TestDepthwiseConvFP16
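The factory above follows the repo's usual pattern: subclass the parent test, override the dtype in `init_kernel_type`, rename the class, and publish it into `globals()` so unittest discovery picks it up. A minimal standalone version of that pattern (illustrative names only, not the Paddle helper):

```python
import unittest

import numpy as np


class BaseConvCase(unittest.TestCase):
    dtype = np.float32

    def test_has_dtype(self):
        self.assertTrue(np.issubdtype(self.dtype, np.floating))


def create_fp16_variant(parent):
    class Fp16Case(parent):
        dtype = np.float16

    cls_name = "{0}_{1}".format(parent.__name__, "FP16")
    Fp16Case.__name__ = cls_name
    globals()[cls_name] = Fp16Case   # register for test discovery


create_fp16_variant(BaseConvCase)    # -> BaseConvCase_FP16 is now discoverable
```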
def create_test_bf16_class(parent, atol=1e-2):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestDepthwiseConvBF16(parent):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(
scope, self.op_type, self.inputs, self.outputs, self.attrs
)
return get_numeric_gradient(
place, scope, op, self.inputs_fp32, check_name, ['Output']
)
def init_kernel_type(self):
self.use_cuda = True
self.no_need_check_grad = True
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Input')
self.check_grad_with_place(
place,
['Input'],
'Output',
no_grad_set=set(['Filter']),
user_defined_grads=[numeric_grads],
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Filter')
self.check_grad_with_place(
place,
['Filter'],
'Output',
no_grad_set=set(['Input']),
user_defined_grads=[numeric_grads],
)
cls_name = "{0}_{1}".format(parent.__name__, "BF16OP")
TestDepthwiseConvBF16.__name__ = cls_name
globals()[cls_name] = TestDepthwiseConvBF16
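Because a bf16 kernel's analytic gradient is too coarse to compare against a finite-difference estimate taken in bf16 itself, the class above feeds `check_grad_with_place` a float32 numeric gradient through `user_defined_grads`. The underlying idea is plain central differencing; a self-contained sketch of that idea, not the op_test helper itself:

```python
import numpy as np


def numeric_grad(f, x, eps=1e-3):
    """Central-difference gradient of a scalar-valued f at float32 array x."""
    grad = np.zeros_like(x, dtype=np.float32)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        f_pos = f(x)
        x[idx] = orig - eps
        f_neg = f(x)
        x[idx] = orig                       # restore the perturbed entry
        grad[idx] = (f_pos - f_neg) / (2.0 * eps)
    return grad


# Example: d/dx sum(x**2) == 2*x
x = np.random.rand(2, 3).astype(np.float32)
g = numeric_grad(lambda a: float((a.astype(np.float64) ** 2).sum()), x)
print(np.allclose(g, 2 * x, atol=1e-2))     # True
```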
def create_test_channel_last_fp16_class(parent, grad_check=True):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestChannelLastFP16(parent):
def init_kernel_type(self):
self.use_cuda = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set=set(['Filter'])
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set=set(['Input'])
)
def init_data_format(self):
self.data_format = "NHWC"
def init_test_case_2(self):
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
cls_name = "{0}_{1}".format(parent.__name__, "ChannelLastFP16")
TestChannelLastFP16.__name__ = cls_name
globals()[cls_name] = TestChannelLastFP16
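The channel-last variants only re-declare the input shape as NHWC via `init_test_case_2`; the test data is then generated directly in that layout. For an existing NCHW tensor the equivalent move is a plain axis transpose (illustrative):

```python
import numpy as np

x_nchw = np.random.rand(2, 3, 5, 5).astype(np.float32)
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))   # NCHW -> NHWC
print(x_nhwc.shape)                            # (2, 5, 5, 3)
```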
# depthwise conv2d fp16
create_test_fp16_class(TestDepthwiseConv)
create_test_fp16_class(TestDepthwiseConv2)
create_test_fp16_class(TestDepthwiseConv3)
create_test_fp16_class(TestDepthwiseConvWithDilation)
create_test_fp16_class(TestDepthwiseConvWithDilation2)
create_test_fp16_class(TestDepthwiseConvandFuse)
create_test_fp16_class(TestDepthwiseConv2andFuse)
create_test_fp16_class(TestDepthwiseConv3andFuse)
create_test_fp16_class(TestDepthwiseConvWithDilationandFuse)
create_test_fp16_class(TestDepthwiseConvWithDilation2andFuse)
# depthwise conv2d bf16
create_test_bf16_class(TestDepthwiseConv)
create_test_bf16_class(TestDepthwiseConv2)
create_test_bf16_class(TestDepthwiseConv3, atol=4e-2)
create_test_bf16_class(TestDepthwiseConvWithDilation)
create_test_bf16_class(TestDepthwiseConvWithDilation2)
create_test_bf16_class(TestDepthwiseConvandFuse)
create_test_bf16_class(TestDepthwiseConv2andFuse)
create_test_bf16_class(TestDepthwiseConv3andFuse)
create_test_bf16_class(TestDepthwiseConvWithDilationandFuse)
create_test_bf16_class(TestDepthwiseConvWithDilation2andFuse)
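These factories reuse the depthwise cases defined earlier in the file. What the op computes, independent of precision, is one small filter per input channel; a plain NumPy reference for the multiplier-1 case (illustrative only, not the test's ground-truth helper):

```python
import numpy as np


def depthwise_conv2d_naive(x, w, stride=1, pad=0):
    """Reference NCHW depthwise conv: one kh x kw filter per input channel."""
    n, c, h, w_in = x.shape
    kc, _, kh, kw = w.shape          # expected weight shape: (C, 1, kh, kw)
    assert kc == c
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    ho = (h + 2 * pad - kh) // stride + 1
    wo = (w_in + 2 * pad - kw) // stride + 1
    out = np.zeros((n, c, ho, wo), dtype=x.dtype)
    for i in range(ho):
        for j in range(wo):
            patch = xp[:, :, i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[:, :, i, j] = (patch * w[None, :, 0]).sum(axis=(2, 3))
    return out
```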
# depthwise conv2d
create_test_padding_SAME_class(TestDepthwiseConv_AsyPadding)
......@@ -368,6 +524,15 @@ create_test_channel_last_class(TestDepthwiseConvWithDilation2_AsyPadding)
create_test_channel_last_class(TestDepthwiseConvandFuse_AsyPadding)
create_test_channel_last_class(TestDepthwiseConvWithDilationandFuse_AsyPadding)
# channel last fp16
create_test_channel_last_fp16_class(TestDepthwiseConv_AsyPadding)
create_test_channel_last_fp16_class(TestDepthwiseConvWithDilation2_AsyPadding)
create_test_channel_last_fp16_class(TestDepthwiseConvandFuse_AsyPadding)
create_test_channel_last_fp16_class(
TestDepthwiseConvWithDilationandFuse_AsyPadding
)
# ------------ depthwise conv2d in MIOPEN ---------
if core.is_compiled_with_rocm():
create_test_cudnn_padding_SAME_class(TestDepthwiseConv_AsyPadding)
......
......@@ -21,11 +21,12 @@ import paddle
from paddle import nn
paddle.enable_static()
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from test_attribute_var import UnittestBase
from paddle import fluid
from paddle.fluid import Program, core, program_guard
from paddle.fluid.tests.unittests.testsuite import create_op
def conv2dtranspose_forward_naive(input_, filter_, attrs):
......@@ -182,10 +183,13 @@ class TestConv2DTransposeOp(OpTest):
self.init_op_type()
self.init_test_case()
input_ = np.random.random(self.input_size).astype(self.dtype)
filter_ = np.random.random(self.filter_size).astype(self.dtype)
if self.is_bfloat16_op():
input_ = np.random.random(self.input_size).astype(np.float32)
filter_ = np.random.random(self.filter_size).astype(np.float32)
else:
input_ = np.random.random(self.input_size).astype(self.dtype)
filter_ = np.random.random(self.filter_size).astype(self.dtype)
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
......@@ -203,9 +207,18 @@ class TestConv2DTransposeOp(OpTest):
if len(self.output_padding) > 0:
self.attrs['output_padding'] = self.output_padding
output = conv2dtranspose_forward_naive(
input_, filter_, self.attrs
).astype(self.dtype)
output = conv2dtranspose_forward_naive(input_, filter_, self.attrs)
if self.is_bfloat16_op():
output = output.astype(np.float32)
self.inputs = {
'Input': convert_float_to_uint16(input_),
'Filter': convert_float_to_uint16(filter_),
}
self.inputs_fp32 = {'Input': input_, 'Filter': filter_}
else:
output = output.astype(self.dtype)
self.inputs = {'Input': input_, 'Filter': filter_}
self.outputs = {'Output': output}
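For the bf16 path, inputs are handed to the op as uint16 bit patterns produced by `convert_float_to_uint16`, while the float32 originals are kept in `inputs_fp32` for the numeric gradient. The conversion is just a reinterpretation of the upper half of the float32 bits; a standalone sketch of the same idea (this version rounds to nearest even, the op_test helper may simply truncate):

```python
import numpy as np


def float32_to_bf16_bits(x):
    """float32 array -> bfloat16 stored as uint16 (round to nearest even)."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    rounded = bits + 0x7FFF + ((bits >> 16) & 1)   # RNE on the upper 16 bits
    return (rounded >> 16).astype(np.uint16)


def bf16_bits_to_float32(b):
    """uint16 bfloat16 bit pattern -> float32."""
    out = np.zeros(b.shape, dtype=np.uint32)
    out[...] = b
    out <<= 16
    return out.view(np.float32)


x = np.random.rand(4).astype(np.float32)
x_bf16 = bf16_bits_to_float32(float32_to_bf16_bits(x))
print(np.abs(x_bf16 - x).max())   # bounded by roughly 2**-8 * max|x|
```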
......@@ -758,7 +771,7 @@ class TestCUDNN_FP16(TestConv2DTransposeOp):
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.need_check_grad = False
self.need_check_grad = True
self.use_cudnn = True
self.op_type = "conv2d_transpose"
self.python_api = conv2dtranspose_wrapper
......@@ -766,12 +779,63 @@ class TestCUDNN_FP16(TestConv2DTransposeOp):
def test_check_output(self):
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(
place, atol=0.02, check_dygraph=(not self.use_mkldnn)
)
if core.is_float16_supported(place):
self.check_output_with_place(
place, atol=0.02, check_dygraph=(not self.use_mkldnn)
)
else:
self.check_output(check_dygraph=(not self.use_mkldnn))
def test_check_grad_no_input(self):
if self.need_check_grad:
if self.use_cudnn:
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place,
['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']),
)
else:
self.check_grad(
['Filter'], 'Output', no_grad_set=set(['Input'])
)
def test_check_grad_no_filter(self):
if self.need_check_grad:
if self.use_cudnn:
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place,
['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']),
)
else:
self.check_grad(
['Input'], 'Output', no_grad_set=set(['Filter'])
)
def test_check_grad(self):
if self.need_check_grad:
if self.use_cudnn:
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.02,
)
else:
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.02
)
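The fp16 gradient checks above relax `max_relative_error` to 0.02. Conceptually that bound is applied elementwise between the analytic and numeric gradients, roughly as sketched here (an approximation of the criterion, not OpTest's exact implementation):

```python
import numpy as np


def max_relative_error(analytic, numeric, floor=1e-3):
    # Scale each absolute difference by the reference magnitude, with a floor
    # so near-zero entries do not blow the ratio up.
    ref = np.abs(numeric).copy()
    ref[ref < floor] = 1.0
    return float(np.max(np.abs(analytic - numeric) / ref))


a = np.array([1.000, 2.010, 0.0001], dtype=np.float32)
n = np.array([1.005, 2.000, 0.0002], dtype=np.float32)
print(max_relative_error(a, n) < 0.02)   # True
```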
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
......@@ -870,6 +934,178 @@ class TestCUDNNWithEvenUpsample_NHWC_FP16(TestCUDNN_FP16):
self.data_format = 'NHWC'
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestCUDNN_BF16(TestConv2DTransposeOp):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(
scope, self.op_type, self.inputs, self.outputs, self.attrs
)
return get_numeric_gradient(
place, scope, op, self.inputs_fp32, check_name, ['Output']
)
def init_test_case(self):
self.dtype = np.uint16
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.need_check_grad = False
self.use_cudnn = True
self.op_type = "conv2d_transpose"
self.python_api = conv2dtranspose_wrapper
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, atol=0.02, check_dygraph=(not self.use_mkldnn)
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Filter')
self.check_grad_with_place(
place,
['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']),
user_defined_grads=[numeric_grads],
)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Input')
self.check_grad_with_place(
place,
['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']),
user_defined_grads=[numeric_grads],
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestCUDNN_NHWC_BF16(TestCUDNN_BF16):
def init_test_case(self):
self.dtype = np.uint16
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestCUDNNWithSymmetricPad_NHWC_BF16(TestCUDNN_BF16):
def init_test_case(self):
self.dtype = np.uint16
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestCUDNNWithAsymmetricPad_NHWC_BF16(TestCUDNN_BF16):
def init_test_case(self):
self.dtype = np.uint16
self.pad = [1, 0, 2, 3]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestCUDNNWithStride_NHWC_BF16(TestCUDNN_BF16):
def init_test_case(self):
self.dtype = np.uint16
self.pad = [1, 1]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestCUDNNWithGroups_NHWC_BF16(TestCUDNN_BF16):
def init_test_case(self):
self.dtype = np.uint16
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4]  # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestCUDNNWithEvenUpsample_NHWC_BF16(TestCUDNN_BF16):
def init_test_case(self):
self.dtype = np.uint16
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
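The even-upsample cases request `output_size = [14, 14]` from a 7x7 input with stride 2, pad 2, and a 5x5 filter, which implies one extra row/column of output padding under the usual transposed-convolution size formula. A quick check (assumes the standard formula, not quoted from Paddle source):

```python
def conv_transpose_out(h_in, stride, pad, kernel, dilation=1, output_padding=0):
    # H_out = (H_in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1
    return (h_in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1


print(conv_transpose_out(7, 2, 2, 5))                    # 13
print(conv_transpose_out(7, 2, 2, 5, output_padding=1))  # 14, matching output_size
```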
class TestConv2DTransposeAPI(unittest.TestCase):
def test_case1(self):
data1 = paddle.static.data(
......