Unverified  Commit 8bfd978f  authored by NetPunk, committed by GitHub

[PaddlePaddle Hackathon 4]: Support float16 data type for the maxout operator (#50976)

* support fp16 for maxout op

* format code

* change api

* add test for static float16

* format code

* formatting code

* atol alignment

* experiment-1

* experiment-2

* experiment-3

* format code
Parent 25b4ba7f
@@ -175,9 +175,11 @@ void MaxOutGradFunctor<DeviceContext, T>::operator()(
 }
 
 template class MaxOutGradFunctor<phi::GPUContext, float>;
+template class MaxOutGradFunctor<phi::GPUContext, phi::dtype::float16>;
 template class MaxOutGradFunctor<phi::GPUContext, double>;
 
 template class MaxOutFunctor<phi::GPUContext, float>;
+template class MaxOutFunctor<phi::GPUContext, phi::dtype::float16>;
 template class MaxOutFunctor<phi::GPUContext, double>;
 
 }  // namespace funcs
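These explicit instantiations matter because the functor bodies are compiled in this .cu translation unit: without the new phi::dtype::float16 lines, the float16 kernel registrations below would compile but fail at link time for fp16 inputs.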
@@ -15,5 +15,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/maxout_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    maxout_grad, GPU, ALL_LAYOUT, phi::MaxOutGradKernel, float, double) {}
+PD_REGISTER_KERNEL(maxout_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MaxOutGradKernel,
+                   float,
+                   phi::dtype::float16,
+                   double) {}
@@ -15,4 +15,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/maxout_kernel_impl.h"
 
-PD_REGISTER_KERNEL(maxout, GPU, ALL_LAYOUT, phi::MaxOutKernel, float, double) {}
+PD_REGISTER_KERNEL(maxout,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MaxOutKernel,
+                   float,
+                   phi::dtype::float16,
+                   double) {}
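With both the forward and backward kernels now registered for phi::dtype::float16, fp16 maxout can run end to end on a GPU. A minimal dygraph sketch (assuming a CUDA build of Paddle; the shapes mirror the tests below):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.set_device('gpu')  # the float16 kernels above are GPU-only

x = paddle.to_tensor(
    np.random.uniform(-1, 1, [2, 6, 5, 4]).astype(np.float16),
    stop_gradient=False,
)
out = F.maxout(x, groups=2, axis=1)  # shape [2, 3, 5, 4], dtype float16
out.sum().backward()                 # exercises MaxOutGradKernel in fp16
print(out.dtype, x.grad.dtype)       # paddle.float16 paddle.float16
```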
@@ -136,5 +136,40 @@ class TestMaxoutAPI(unittest.TestCase):
         self.assertRaises(ValueError, F.maxout, x_float32, 2, 2)
 
 
+class TestMaxOutOpFP16(TestMaxOutOp):
+    def set_attrs(self):
+        self.dtype = 'float16'
+
+
+class TestMaxoutFP16Case1(TestMaxOutOpFP16):
+    def set_attrs(self):
+        self.axis = -1
+
+
+class TestMaxoutFP16Case2(TestMaxOutOpFP16):
+    def set_attrs(self):
+        self.axis = 3
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestMaxoutStaticAPIFP16(unittest.TestCase):
+    def setUp(self):
+        self.x_np = np.random.uniform(-1, 1, [2, 6, 5, 4]).astype(np.float16)
+        self.groups = 2
+        self.axis = 1
+        self.place = paddle.CUDAPlace(0)
+
+    def test_static_api(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
+            out = F.maxout(x, self.groups, self.axis)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
+        out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
+        np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
+
+
 if __name__ == '__main__':
     unittest.main()
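The static-graph test compares against maxout_forward_naive, a NumPy reference defined earlier in the test module and not shown in this diff. For intuition, a hypothetical stand-in (the helper name and layout handling here are assumptions, not the module's code) could look like:

```python
import numpy as np

def maxout_ref(x, groups, axis=1):
    # Hypothetical stand-in for the test module's maxout_forward_naive:
    # split the channel axis into (C // groups, groups) and reduce the
    # groups sub-axis with max, so output channel j covers input
    # channels [j * groups, (j + 1) * groups).
    axis = axis if axis >= 0 else axis + x.ndim
    shape = list(x.shape)
    shape[axis : axis + 1] = [shape[axis] // groups, groups]
    return x.reshape(shape).max(axis=axis + 1)

x = np.random.uniform(-1, 1, [2, 6, 5, 4]).astype(np.float16)
print(maxout_ref(x, groups=2, axis=1).shape)  # (2, 3, 5, 4)
```

Either way, the reduction runs over blocks of `groups` adjacent channels, which is why `groups` must evenly divide the channel count.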
@@ -787,7 +787,7 @@ def maxout(x, groups, axis=1, name=None):
 
     Parameters:
         x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C], the data type
-            of input is float32 or float64.
+            of input is float16, float32 or float64.
         groups (int): The groups number of maxout. `groups` specifies the
             index of channel dimension where maxout will be performed. This must be
             a factor of number of features.
@@ -822,7 +822,9 @@ def maxout(x, groups, axis=1, name=None):
     if in_dygraph_mode():
         return _C_ops.maxout(x, groups, axis)
     else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64'], 'maxout'
+        )
         if axis not in [1, -1, 3]:
             raise ValueError(
                 "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
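A note on the "atol alignment" bullet in the commit message: float16 carries only a 10-bit mantissa (roughly 3 significant decimal digits), so fp16 results are normally checked against a reference with looser tolerances than fp32. A small illustration of the precision gap, purely for intuition:

```python
import numpy as np

a = np.float16(0.1) + np.float16(0.2)  # arithmetic performed in fp16
print(a)                               # 0.2998, not 0.3
np.testing.assert_allclose(a, 0.3, rtol=1e-3)  # passes with fp16-scale rtol
```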