Unverified commit de0bad2a authored by zhangbo9674, committed by GitHub

[bf16] add bf16 cuda kernel: concat and split (#39380)

* add concat & split

* add concat kernel

* add concat unittest

* add split unittest
Parent 24103cbb
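With these kernels registered, bfloat16 tensors can be concatenated and split on the GPU through the regular public APIs. A minimal sketch of what the change enables (illustrative only; it assumes the build also supports creating/casting bfloat16 tensors, which is not part of this diff):

```python
import paddle

# Requires a CUDA build of Paddle; this commit adds GPU bfloat16 kernels for
# concat/split (and concat_grad). Casting to bfloat16 is assumed to work here
# and is not part of this change.
paddle.set_device("gpu")

x = paddle.cast(paddle.rand([5, 1, 4, 5]), "bfloat16")
y = paddle.cast(paddle.rand([5, 2, 4, 5]), "bfloat16")

out = paddle.concat([x, y], axis=1)          # shape [5, 3, 4, 5], dtype bfloat16
x2, y2 = paddle.split(out, [1, 2], axis=1)   # split back into the original pieces
print(out.dtype, x2.shape, y2.shape)
```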
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -25,6 +26,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
......
@@ -21,4 +21,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::SplitOpKernel<plat::CUDADeviceContext, int64_t>,
ops::SplitOpKernel<plat::CUDADeviceContext, int>,
ops::SplitOpKernel<plat::CUDADeviceContext, bool>,
ops::SplitOpKernel<plat::CUDADeviceContext, plat::float16>);
ops::SplitOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::SplitOpKernel<plat::CUDADeviceContext, plat::bfloat16>);
@@ -121,5 +121,6 @@ PT_REGISTER_KERNEL(concat,
int,
uint8_t,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
import paddle
@@ -44,14 +44,32 @@ class TestConcatOp(OpTest):
return "float64"
def test_check_output(self):
if self.dtype == np.uint16:
place = core.CUDAPlace(0)
self.check_output_with_place(place)
else:
self.check_output()
def test_check_grad(self):
if self.dtype == np.uint16:
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['x0'], 'Out')
self.check_grad_with_place(place, ['x1'], 'Out')
self.check_grad_with_place(place, ['x2'], 'Out')
else:
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
def init_test_data(self):
if self.dtype == np.uint16:
x0 = np.random.random((5, 1, 4, 5)).astype(np.float32)
self.x0 = convert_float_to_uint16(x0)
x1 = np.random.random((5, 2, 4, 5)).astype(np.float32)
self.x1 = convert_float_to_uint16(x1)
x2 = np.random.random((5, 3, 4, 5)).astype(np.float32)
self.x2 = convert_float_to_uint16(x2)
else:
self.x0 = np.random.random((5, 1, 4, 5)).astype(self.dtype)
self.x1 = np.random.random((5, 2, 4, 5)).astype(self.dtype)
self.x2 = np.random.random((5, 3, 4, 5)).astype(self.dtype)
@@ -193,6 +211,22 @@ create_test_fp16(TestConcatOp5)
create_test_fp16(TestConcatOp6)


#----------------Concat Bf16----------------
def create_test_bf16(parent):
    @unittest.skipIf(not paddle.is_compiled_with_cuda(),
                     "core is not compiled with CUDA")
    class TestConcatBf16(parent):
        def get_dtype(self):
            return np.uint16

    cls_name = "{0}_{1}".format(parent.__name__, "Bf16")
    TestConcatBf16.__name__ = cls_name
    globals()[cls_name] = TestConcatBf16


create_test_bf16(TestConcatOp)


class TestConcatOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
......
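In these tests, bfloat16 data is stored in numpy `uint16` arrays: `convert_float_to_uint16` from `op_test` keeps the upper 16 bits of each float32 value, which is the bfloat16 bit pattern. A rough numpy equivalent, with hypothetical helper names (not the actual implementation), looks like this:

```python
import numpy as np

def float32_to_bf16_bits(x):
    """Truncate float32 values to bfloat16, returned as uint16 bit patterns."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    bits = x.view(np.uint32)             # raw IEEE-754 float32 bit patterns
    return (bits >> 16).astype(np.uint16)

def bf16_bits_to_float32(u):
    """Widen bfloat16 bit patterns (uint16) back to float32."""
    u = np.ascontiguousarray(u, dtype=np.uint16)
    return (u.astype(np.uint32) << 16).view(np.float32)

x = np.random.random((2, 3)).astype(np.float32)
bf16 = float32_to_bf16_bits(x)       # dtype uint16, like the OpTest inputs above
approx = bf16_bits_to_float32(bf16)  # close to x, within bfloat16 precision
```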
@@ -16,7 +16,7 @@ from __future__ import print_function
import paddle
import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, convert_float_to_uint16
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
@@ -26,12 +26,19 @@ class TestSplitOp(OpTest):
        self._set_op_type()
        self.dtype = self.get_dtype()
        axis = 1
        if self.dtype == np.uint16:
            x = np.random.random((4, 5, 6)).astype(np.float32)
            out = np.split(x, [2, 3], axis)
            self.inputs = {'X': convert_float_to_uint16(x)}
            self.outputs = {'Out': [('out%d' % i, convert_float_to_uint16(out[i])) \
                                    for i in range(len(out))]}
        else:
            x = np.random.random((4, 5, 6)).astype(self.dtype)
            out = np.split(x, [2, 3], axis)
            self.inputs = {'X': x}
            self.outputs = {'Out': [('out%d' % i, out[i]) \
                                    for i in range(len(out))]}
        self.attrs = {'axis': axis, 'sections': [2, 1, 2]}

    def get_dtype(self):
        return "float64"
@@ -226,6 +233,30 @@ def create_test_fp16(parent):
create_test_fp16(TestSplitOp)


#----------------Split Bf16----------------
def create_test_bf16(parent):
    @unittest.skipIf(not core.is_compiled_with_cuda(),
                     "core is not compiled with CUDA")
    class TestSplitBf16(parent):
        def get_dtype(self):
            return np.uint16

        def test_check_output(self):
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

        def test_check_grad(self):
            pass

    cls_name = "{0}_{1}".format(parent.__name__, "Bf16")
    TestSplitBf16.__name__ = cls_name
    globals()[cls_name] = TestSplitBf16


create_test_bf16(TestSplitOp)


class TestSplitAPI(unittest.TestCase):
    def test_api(self):
......
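Note that the generated split case only checks the forward output (`test_check_grad` is a no-op), while the concat case also checks gradients on a CUDA place. The dynamically created classes are named `<Parent>_Bf16`; a small sketch of selecting just the new bfloat16 cases with the standard unittest loader (illustrative only, assuming the test modules are importable from the current directory on a CUDA build):

```python
import unittest

# Load only the generated bfloat16 variants by name; the class names come from
# cls_name = "{0}_{1}".format(parent.__name__, "Bf16") in the tests above.
suite = unittest.defaultTestLoader.loadTestsFromNames([
    "test_concat_op.TestConcatOp_Bf16",
    "test_split_op.TestSplitOp_Bf16",
])
unittest.TextTestRunner(verbosity=2).run(suite)
```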