未验证 提交 48ccb785 编写于 作者: S superwinner1 提交者: GitHub

【Hackathon No.55】 add channel_shuffle FP16/BF16 support and tests (#51884)

* No55 add channel_shuffle FP16/BF16 support and tests
上级 205094f0
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <string> #include <string>
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle_grad, ...@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::ChannelShuffleGradKernel, phi::ChannelShuffleGradKernel,
float, float,
double) {} double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle, ...@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle,
ALL_LAYOUT, ALL_LAYOUT,
phi::ChannelShuffleKernel, phi::ChannelShuffleKernel,
float, float,
double) {} double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from eager_op_test import OpTest from eager_op_test import OpTest, convert_float_to_uint16
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
...@@ -45,6 +45,7 @@ def channel_shuffle_np(x, groups, data_format="NCHW"): ...@@ -45,6 +45,7 @@ def channel_shuffle_np(x, groups, data_format="NCHW"):
class TestChannelShuffleOp(OpTest): class TestChannelShuffleOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "channel_shuffle" self.op_type = "channel_shuffle"
self.init_dtype()
self.init_data_format() self.init_data_format()
n, c, h, w = 2, 9, 4, 4 n, c, h, w = 2, 9, 4, 4
self.python_api = paddle.nn.functional.channel_shuffle self.python_api = paddle.nn.functional.channel_shuffle
...@@ -56,13 +57,16 @@ class TestChannelShuffleOp(OpTest): ...@@ -56,13 +57,16 @@ class TestChannelShuffleOp(OpTest):
groups = 3 groups = 3
x = np.random.random(shape).astype("float64") x = np.random.random(shape).astype(self.dtype)
npresult = channel_shuffle_np(x, groups, self.format) npresult = channel_shuffle_np(x, groups, self.format)
self.inputs = {'X': x} self.inputs = {'X': x}
self.outputs = {'Out': npresult} self.outputs = {'Out': npresult}
self.attrs = {'groups': groups, "data_format": self.format} self.attrs = {'groups': groups, "data_format": self.format}
def init_dtype(self):
self.dtype = 'float64'
def init_data_format(self): def init_data_format(self):
self.format = "NCHW" self.format = "NCHW"
...@@ -268,5 +272,53 @@ class TestChannelShuffleError(unittest.TestCase): ...@@ -268,5 +272,53 @@ class TestChannelShuffleError(unittest.TestCase):
self.assertRaises(ValueError, error_data_format_layer) self.assertRaises(ValueError, error_data_format_layer)
class TestChannelShuffleFP16OP(TestChannelShuffleOp):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestChannelShuffleBF16OP(OpTest):
def setUp(self):
self.op_type = "channel_shuffle"
self.init_data_format()
n, c, h, w = 2, 9, 4, 4
self.python_api = paddle.nn.functional.channel_shuffle
self.dtype = np.uint16
self.use_mkldnn = False
if self.format == "NCHW":
shape = [n, c, h, w]
if self.format == "NHWC":
shape = [n, h, w, c]
groups = 3
x = np.random.random(shape).astype('float32')
out = channel_shuffle_np(x, groups, self.format)
self.inputs = {'X': convert_float_to_uint16(x)}
self.attrs = {'groups': groups, "data_format": self.format}
self.outputs = {'Out': convert_float_to_uint16(out)}
def init_data_format(self):
self.format = "NCHW"
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X'],
'Out',
)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册