Unverified · Commit 46951224 · Authored by: Difer · Committed by: GitHub

【Hackathon No57】add fp16 & bf16 for max_pool2d_with_index, max_pool3d_with_index (#52314)

* add fp16 & bf16 support for max_pool2d_with_index / max_pool3d_with_index

* fix some errors

* fix errors

* fix code-style errors

* fix mask type

* fix input bf16 type

* fix input bf16 dtype conversion error

* go back to converting the input to bf16 first

* fix conversion error

* fix bf16 gradient check
Parent 5ca3bc6d
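For context, the user-visible effect of this change is that the `max_pool2d_with_index` / `max_pool3d_with_index` GPU kernels (reached through `paddle.nn.functional.max_pool2d` / `max_pool3d` with `return_mask=True`) now also accept float16 and bfloat16 inputs. A minimal usage sketch, not part of the commit; it assumes a CUDA build of Paddle, and the shapes and device string are illustrative:

```python
import paddle
import paddle.nn.functional as F

# Hypothetical usage sketch: requires a CUDA build, since the fp16/bf16
# pooling kernels registered in this commit are GPU-only.
paddle.set_device('gpu')

x = paddle.rand([1, 3, 8, 8]).astype('float16')  # NCHW input
out, mask = F.max_pool2d(x, kernel_size=2, stride=2, return_mask=True)
print(out.dtype, mask.dtype)

x3d = paddle.rand([1, 3, 4, 8, 8]).astype('bfloat16')  # NCDHW input
out3d, mask3d = F.max_pool3d(x3d, kernel_size=2, stride=2, return_mask=True)
print(out3d.shape, mask3d.shape)
```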
@@ -1963,7 +1963,7 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads,
       wstart = max(wstart, 0);
     }
 
-    T1 ele = -FLT_MAX;
+    T1 ele = static_cast<T1>(-FLT_MAX);
     int max_index = -1;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
@@ -2015,7 +2015,7 @@ __global__ void AdaptiveKernelMaxPool2dWithIdx(const int nthreads,
     wstart = AdaptStartIndex(w_offset, input_width, output_width);
     wend = AdaptEndIndex(w_offset, input_width, output_width);
 
-    T1 ele = -FLT_MAX;
+    T1 ele = static_cast<T1>(-FLT_MAX);
     int max_index = -1;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
@@ -2089,7 +2089,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
       pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
     }
 
-    T1 input_grad_data = 0;
+    T1 input_grad_data = static_cast<T1>(0);
     int input_current_featuremap_idx = h_offset * input_width + w_offset;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
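The kernel hunks above replace the plain initializers `T1 ele = -FLT_MAX;` and `T1 input_grad_data = 0;` with explicit `static_cast<T1>(...)` so the initializers remain valid when `T1` is instantiated as `phi::dtype::float16` or `phi::dtype::bfloat16`. Numerically the seed value is still a safe identity for the running max in both new types; a small NumPy sketch of why (the bf16 truncation below is an illustrative stand-in, not Paddle code):

```python
import numpy as np

FLT_MAX = np.finfo(np.float32).max  # 3.4028235e+38

# fp16 cannot represent FLT_MAX: the cast overflows to infinity, so the
# kernel's running-max seed becomes -inf, which is still smaller than
# every real element and therefore harmless.
with np.errstate(over='ignore'):
    print(np.float16(-FLT_MAX))  # -inf

# bf16 shares fp32's exponent range, so -FLT_MAX stays finite after
# truncating the low 16 mantissa bits (illustrative bf16 conversion).
bits = np.asarray(-FLT_MAX, dtype=np.float32).view(np.uint32)
bf16_bits = (bits >> 16) << 16
print(bf16_bits.view(np.float32))  # about -3.39e+38
```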
@@ -2259,6 +2259,14 @@ template class MaxPool2dWithIndexFunctor<phi::GPUContext, float, int>;
 template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, float, int>;
 template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
 template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
+template class MaxPool2dWithIndexFunctor<phi::GPUContext, dtype::float16, int>;
+template class MaxPool2dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::float16,
+                                             int>;
+template class MaxPool2dWithIndexFunctor<phi::GPUContext, dtype::bfloat16, int>;
+template class MaxPool2dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::bfloat16,
+                                             int>;
 
 template <typename T1, typename T2>
 __global__ void KernelMaxPool3DWithIdx(const int ncd,
@@ -2324,7 +2332,7 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd,
       wstart = max(wstart, 0);
     }
 
-    T1 ele = -FLT_MAX;
+    T1 ele = static_cast<T1>(-FLT_MAX);
     int max_index = -1;
     for (int d = dstart; d < dend; ++d) {
       for (int h = hstart; h < hend; ++h) {
@@ -2560,6 +2568,14 @@ template class MaxPool3dWithIndexFunctor<phi::GPUContext, float, int>;
 template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, float, int>;
 template class MaxPool3dWithIndexFunctor<phi::GPUContext, double, int>;
 template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, double, int>;
+template class MaxPool3dWithIndexFunctor<phi::GPUContext, dtype::float16, int>;
+template class MaxPool3dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::float16,
+                                             int>;
+template class MaxPool3dWithIndexFunctor<phi::GPUContext, dtype::bfloat16, int>;
+template class MaxPool3dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::bfloat16,
+                                             int>;
 
 }  // namespace funcs
 }  // namespace phi
@@ -38,7 +38,9 @@ PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
                    ALL_LAYOUT,
                    phi::MaxPool2dWithIndexGradKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
@@ -55,6 +57,8 @@ PD_REGISTER_KERNEL(max_pool3d_with_index_grad,
                    ALL_LAYOUT,
                    phi::MaxPool3dWithIndexGradKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
@@ -32,7 +32,9 @@ PD_REGISTER_KERNEL(max_pool2d_with_index,
                    ALL_LAYOUT,
                    phi::MaxPool2dWithIndexKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
@@ -49,6 +51,8 @@ PD_REGISTER_KERNEL(max_pool3d_with_index,
                    ALL_LAYOUT,
                    phi::MaxPool3dWithIndexKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
@@ -15,9 +15,16 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+    get_numeric_gradient,
+)
 
 import paddle
+from paddle.fluid import core
+from paddle.fluid.tests.unittests.testsuite import create_op
 
 
 def adaptive_start_index(index, input_size, output_size):
@@ -149,9 +156,18 @@ class TestMaxPoolWithIndex_Op(OpTest):
         self.init_test_case()
         self.init_global()
         self.init_adaptive()
+        self.init_dtype()
 
-        input = np.random.random(self.shape).astype("float64")
-        input = np.round(input * 100.0, 2)
+        if self.is_bfloat16_op():
+            input = np.random.random(self.shape).astype(np.float32)
+            input = convert_uint16_to_float(
+                convert_float_to_uint16(np.round(input * 100.0, 2))
+            )
+        else:
+            input = np.random.random(self.shape).astype(self.dtype)
+            input = np.round(input * 100.0, 2)
+
         output, mask = self.pool_forward_naive(
             input,
             self.ksize,
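In the bfloat16 branch above, the reference input is round-tripped through `convert_float_to_uint16` and `convert_uint16_to_float` so that the NumPy reference pooling sees exactly the quantized values the bf16 kernel will receive. A standalone sketch of what that round-trip does; the truncation-based helpers below are simplified stand-ins for the ones imported from `eager_op_test`:

```python
import numpy as np


def float_to_bf16_bits(x):
    # Keep only the top 16 bits of each fp32 value; a simplified stand-in
    # for eager_op_test.convert_float_to_uint16.
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)


def bf16_bits_to_float(bits):
    # Re-expand the bf16 bit pattern back to fp32, mirroring
    # eager_op_test.convert_uint16_to_float.
    return (np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)


x = np.round(np.random.random((2, 3, 5, 5)).astype(np.float32) * 100.0, 2)
x_seen_by_kernel = bf16_bits_to_float(float_to_bf16_bits(x))

# bf16 keeps only ~8 mantissa bits, so the round-trip quantizes the input;
# computing the expected output from the quantized tensor keeps the test
# comparison exact instead of fighting bf16 rounding error.
print(np.max(np.abs(x - x_seen_by_kernel)))
```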
@@ -160,8 +176,11 @@ class TestMaxPoolWithIndex_Op(OpTest):
             self.global_pool,
             self.adaptive,
         )
-        output = output.astype("float64")
         mask = mask.astype("int32")
+        if self.is_bfloat16_op():
+            output = output.astype(np.float32)
+        else:
+            output = output.astype(self.dtype)
 
         self.attrs = {
             'strides': self.strides,
@@ -171,9 +190,21 @@ class TestMaxPoolWithIndex_Op(OpTest):
             'adaptive': self.adaptive,
         }
 
-        self.inputs = {'X': input}
-        self.outputs = {'Out': output, "Mask": mask}
+        if self.is_bfloat16_op():
+            self.inputs = {'X': convert_float_to_uint16(input)}
+            self.outputs = {
+                'Out': convert_float_to_uint16(output),
+                "Mask": mask,
+            }
+            self.inputs_fp32 = {'X': input}
+        else:
+            self.inputs = {'X': input}
+            self.outputs = {'Out': output, "Mask": mask}
+
+    def init_dtype(self):
+        self.dtype = np.float64
 
     def test_check_output(self):
         self.check_output()
@@ -220,9 +251,90 @@ class TestCase3(TestCase2):
         self.global_pool = False
 
 
-# ----------------max_pool2d_with_index----------------
+class TestCastAdaptive3d(TestMaxPoolWithIndex_Op):
+    def init_adaptive(self):
+        self.adaptive = True
+
+
+# ----------------max_pool3d_with_index_fp16----------------
+def create_test_fp16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    )
+    class TestMaxPool3dFP16(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(place, {'X'}, ['Out'])
+
+    cls_name = "{}_{}".format(parent.__name__, "FP16OP")
+    TestMaxPool3dFP16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool3dFP16
+
+
+create_test_fp16_class(TestMaxPoolWithIndex_Op)
+create_test_fp16_class(TestCase1)
+create_test_fp16_class(TestCase2)
+create_test_fp16_class(TestCase3)
+create_test_fp16_class(TestCastAdaptive3d)
+
+
+# ----------------max_pool3d_with_index_bf16----------------
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA and do not support bfloat16",
+    )
+    class TestMaxPool3dBF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def get_numeric_grad(self, place, check_name):
+            scope = core.Scope()
+            self._check_grad_helper()
+            op = create_op(
+                scope, self.op_type, self.inputs, self.outputs, self.attrs
+            )
+            return get_numeric_gradient(
+                place, scope, op, self.inputs_fp32, check_name, ['Out']
+            )
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            numeric_grads = self.get_numeric_grad(place, 'X')
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place, {'X'}, ['Out'], user_defined_grads=[numeric_grads]
+                )
+
+    cls_name = "{}_{}".format(parent.__name__, "BF16OP")
+    TestMaxPool3dBF16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool3dBF16
+
+
+create_test_bf16_class(TestMaxPoolWithIndex_Op)
+create_test_bf16_class(TestCase1)
+create_test_bf16_class(TestCase2)
+create_test_bf16_class(TestCase3)
+create_test_bf16_class(TestCastAdaptive3d)
+
+
+# ----------------max_pool2d_with_index----------------
 def max_pool2d_with_index_wapper(
     x,
     kernel_size=[],
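The `create_test_fp16_class` / `create_test_bf16_class` helpers above (and their 2-D counterparts in the next hunk) follow a common subclass-factory pattern: derive a new test class from each configuration parent, override only `init_dtype`, rename the generated class, and publish it through `globals()` so unittest discovery picks it up. A minimal, self-contained sketch of that pattern; the class and function names here are illustrative, not from the Paddle test suite:

```python
import unittest


class BasePoolConfig(unittest.TestCase):
    """Stand-in for a configuration parent such as TestCase1 or TestCase2."""

    def init_dtype(self):
        self.dtype = "float64"

    def test_dtype_is_set(self):
        self.init_dtype()
        self.assertIn(self.dtype, ("float64", "float16"))


def create_fp16_variant(parent):
    class Fp16Variant(parent):
        def init_dtype(self):
            self.dtype = "float16"  # only the dtype hook changes

    cls_name = "{}_{}".format(parent.__name__, "FP16OP")
    Fp16Variant.__name__ = cls_name
    # Exporting under a unique name lets the unittest loader discover the
    # generated class alongside its parent.
    globals()[cls_name] = Fp16Variant


create_fp16_variant(BasePoolConfig)

if __name__ == "__main__":
    unittest.main()
```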
@@ -279,9 +391,82 @@ class TestCastAdaptive2d(TestCase6):
         self.adaptive = True
 
 
-class TestCastAdaptive3d(TestMaxPoolWithIndex_Op):
-    def init_adaptive(self):
-        self.adaptive = True
+# ----------------max_pool2d_with_index_fp16----------------
+def create_test_fp16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    )
+    class TestMaxPool2dFP16(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(place, {'X'}, ['Out'])
+
+    cls_name = "{}_{}".format(parent.__name__, "FP16OP")
+    TestMaxPool2dFP16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool2dFP16
+
+
+create_test_fp16_class(TestCase4)
+create_test_fp16_class(TestCase5)
+create_test_fp16_class(TestCase6)
+create_test_fp16_class(TestCase7)
+create_test_fp16_class(TestCastAdaptive2d)
+
+
+# ----------------max_pool2d_with_index_bf16----------------
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA and do not support bfloat16",
+    )
+    class TestMaxPool2dBF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def get_numeric_grad(self, place, check_name):
+            scope = core.Scope()
+            self._check_grad_helper()
+            op = create_op(
+                scope, self.op_type, self.inputs, self.outputs, self.attrs
+            )
+            return get_numeric_gradient(
+                place, scope, op, self.inputs_fp32, check_name, ['Out']
+            )
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            numeric_grads = self.get_numeric_grad(place, 'X')
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place, {'X'}, ['Out'], user_defined_grads=[numeric_grads]
+                )
+
+    cls_name = "{}_{}".format(parent.__name__, "BF16OP")
+    TestMaxPool2dBF16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool2dBF16
+
+
+create_test_bf16_class(TestCase4)
+create_test_bf16_class(TestCase5)
+create_test_bf16_class(TestCase6)
+create_test_bf16_class(TestCase7)
+create_test_bf16_class(TestCastAdaptive2d)
+
 
 if __name__ == '__main__':
......