Unverified · Commit 2fedd39b authored by zhangbo9674, committed by GitHub

[bf16] add bf16 kernel: elementwise_add elementwise_mul elementwise_sub (#39716)

* add ele_add

* add ele_mul

* add ele_sub

* solve conflict

* fix npu

* refine ele_add

* add ele_mul unittest

* refine ele_sub

* refine ci

* refine unittest
Parent 2553af4f
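
For context, the sketch below shows roughly how the newly registered kernels are exercised from the Python API once this commit is in place. It is a minimal illustration, not part of the commit: it assumes a CUDA build of Paddle in which casting to `'bfloat16'` is available, and the tolerance values are arbitrary.

```python
# Minimal sketch (not part of this commit): exercising the new bfloat16
# elementwise kernels through the Python API. Assumes a CUDA build of Paddle
# where paddle.cast to 'bfloat16' is supported.
import numpy as np
import paddle

paddle.set_device('gpu')

x = paddle.to_tensor(np.random.uniform(0.1, 1, [13, 17]).astype('float32'))
y = paddle.to_tensor(np.random.uniform(0.1, 1, [13, 17]).astype('float32'))

xb = paddle.cast(x, 'bfloat16')
yb = paddle.cast(y, 'bfloat16')

# +, * and - dispatch to elementwise_add / elementwise_mul / elementwise_sub,
# which now have bfloat16 kernels registered.
add_bf16 = paddle.cast(xb + yb, 'float32')
mul_bf16 = paddle.cast(xb * yb, 'float32')
sub_bf16 = paddle.cast(xb - yb, 'float32')

# bfloat16 keeps only ~8 significant bits, so compare loosely against float32.
assert np.allclose(add_bf16.numpy(), (x + y).numpy(), rtol=1e-2, atol=1e-2)
assert np.allclose(mul_bf16.numpy(), (x * y).numpy(), rtol=1e-2, atol=1e-2)
assert np.allclose(sub_bf16.numpy(), (x - y).numpy(), rtol=1e-2, atol=1e-2)
```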
@@ -24,5 +24,6 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
+   ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::bfloat16>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
@@ -167,6 +167,8 @@ REGISTER_OP_CPU_KERNEL(
    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, bool>,
+   ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::bfloat16>,
    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
                              paddle::platform::complex<float>>,
    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
@@ -178,6 +180,8 @@ REGISTER_OP_CPU_KERNEL(
    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, bool>,
+   ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
+                                 paddle::platform::bfloat16>,
    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
                                  paddle::platform::complex<float>>,
    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
@@ -194,6 +198,8 @@ REGISTER_OP_CPU_KERNEL(
                                        int64_t>,
    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        bool>,
+   ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                       paddle::platform::bfloat16>,
    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        paddle::platform::complex<float>>,
    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -210,6 +216,8 @@ REGISTER_OP_CPU_KERNEL(
                                        int64_t>,
    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
                                        bool>,
+   ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                       paddle::platform::bfloat16>,
    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
                                        paddle::platform::complex<float>>,
    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
......
@@ -100,6 +100,7 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, bool>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
+   ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::bfloat16>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
@@ -110,6 +111,7 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
+   ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::bfloat16>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
                                  plat::complex<float>>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
@@ -122,6 +124,8 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
+   ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
+                                       plat::bfloat16>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
@@ -134,6 +138,8 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, bool>,
    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
+   ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
+                                       plat::bfloat16>,
    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
......
@@ -99,6 +99,8 @@ REGISTER_OP_CPU_KERNEL(
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int16_t>,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
+   ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::bfloat16>,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
                              paddle::platform::complex<float>>,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
@@ -110,6 +112,8 @@ REGISTER_OP_CPU_KERNEL(
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int16_t>,
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+   ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
+                                 paddle::platform::bfloat16>,
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
                                  paddle::platform::complex<float>>,
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
@@ -126,6 +130,8 @@ REGISTER_OP_CPU_KERNEL(
                                        int>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        int64_t>,
+   ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                       paddle::platform::bfloat16>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        paddle::platform::complex<float>>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
......
@@ -22,6 +22,8 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
+   ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::bfloat16>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -34,6 +36,8 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::float16>,
+   ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
+                                 paddle::platform::bfloat16>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -51,6 +55,8 @@ REGISTER_OP_CUDA_KERNEL(
                                        int>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int64_t>,
+   ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                                       paddle::platform::bfloat16>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex<float>>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
......
@@ -170,6 +170,7 @@ PD_REGISTER_KERNEL(subtract_grad,
                   int16_t,
                   int,
                   int64_t,
+                  phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
@@ -182,5 +183,6 @@ PD_REGISTER_KERNEL(subtract_double_grad,
                   int16_t,
                   int,
                   int64_t,
+                  phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
@@ -139,7 +139,8 @@ PD_REGISTER_KERNEL(subtract_raw,
                   int,
                   int64_t,
                   complex64,
-                  complex128) {}
+                  complex128,
+                  phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(divide_raw,
                   CPU,
                   ALL_LAYOUT,
@@ -160,7 +161,8 @@ PD_REGISTER_KERNEL(multiply_raw,
                   int64_t,
                   bool,
                   complex64,
-                  complex128) {}
+                  complex128,
+                  phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(sum_raw,
                   CPU,
                   ALL_LAYOUT,
......
@@ -76,6 +76,36 @@ struct CBlas<phi::dtype::bfloat16> {
        "Blas VCOPY do not supported on CPU with bfloat16,"
        " please check your code"));
  }

+  template <typename... ARGS>
+  static void VADD(int n,
+                   const phi::dtype::bfloat16 *x,
+                   const phi::dtype::bfloat16 *y,
+                   phi::dtype::bfloat16 *z) {
+    for (int i = 0; i < n; ++i) {
+      z[i] = x[i] + y[i];
+    }
+  }

+  template <typename... ARGS>
+  static void VMUL(int n,
+                   const phi::dtype::bfloat16 *x,
+                   const phi::dtype::bfloat16 *y,
+                   phi::dtype::bfloat16 *z) {
+    for (int i = 0; i < n; ++i) {
+      z[i] = x[i] * y[i];
+    }
+  }

+  template <typename... ARGS>
+  static void VSUB(int n,
+                   const phi::dtype::bfloat16 *x,
+                   const phi::dtype::bfloat16 *y,
+                   phi::dtype::bfloat16 *z) {
+    for (int i = 0; i < n; ++i) {
+      z[i] = x[i] - y[i];
+    }
+  }
};

#ifdef PADDLE_WITH_MKLML
......
@@ -128,6 +128,7 @@ PD_REGISTER_KERNEL(add_grad,
                   int,
                   int64_t,
                   phi::dtype::float16,
+                  phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
@@ -140,6 +141,7 @@ PD_REGISTER_KERNEL(add_double_grad,
                   int,
                   int64_t,
                   phi::dtype::float16,
+                  phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
@@ -152,6 +154,7 @@ PD_REGISTER_KERNEL(add_triple_grad,
                   int,
                   int64_t,
                   phi::dtype::float16,
+                  phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
@@ -164,6 +167,7 @@ PD_REGISTER_KERNEL(subtract_grad,
                   int,
                   int64_t,
                   phi::dtype::float16,
+                  phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
@@ -176,5 +180,6 @@ PD_REGISTER_KERNEL(subtract_double_grad,
                   int,
                   int64_t,
                   phi::dtype::float16,
+                  phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
@@ -106,6 +106,7 @@ PD_REGISTER_KERNEL(add_raw,
                   int,
                   int64_t,
                   float16,
+                  bfloat16,
                   complex64,
                   complex128) {}
PD_REGISTER_KERNEL(subtract_raw,
@@ -118,6 +119,7 @@ PD_REGISTER_KERNEL(subtract_raw,
                   int,
                   int64_t,
                   float16,
+                  bfloat16,
                   complex64,
                   complex128) {}
PD_REGISTER_KERNEL(divide_raw,
@@ -143,7 +145,8 @@ PD_REGISTER_KERNEL(multiply_raw,
                   bool,
                   float16,
                   complex64,
-                  complex128) {}
+                  complex128,
+                  bfloat16) {}
PD_REGISTER_KERNEL(sum_raw,
                   GPU,
                   ALL_LAYOUT,
......
@@ -121,7 +121,8 @@ PD_REGISTER_KERNEL(subtract,
                   int,
                   int64_t,
                   complex64,
-                  complex128) {}
+                  complex128,
+                  phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(divide,
                   CPU,
                   ALL_LAYOUT,
@@ -142,7 +143,8 @@ PD_REGISTER_KERNEL(multiply,
                   int64_t,
                   bool,
                   complex64,
-                  complex128) {}
+                  complex128,
+                  phi::dtype::bfloat16) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(mean,
@@ -180,6 +182,7 @@ PD_REGISTER_KERNEL(add,
                   int,
                   int64_t,
                   phi::dtype::float16,
+                  phi::dtype::bfloat16,
                   complex64,
                   complex128) {}
PD_REGISTER_KERNEL(subtract,
......
@@ -17,7 +17,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
-from op_test import OpTest, skip_check_grad_ci
+from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
@@ -98,6 +98,46 @@ class TestFP16ElementwiseAddOp(TestElementwiseAddOp):
            place, atol=1e-3, check_dygraph=(self.use_mkldnn == False))

+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
+    "core is not compiled with CUDA or cudnn version is less than 8.1.0")
+class TestBF16ElementwiseAddOp(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_add"
+        self.dtype = np.uint16
+        self.x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
+        self.y = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
+        self.out = np.add(self.x, self.y)
+        self.axis = -1
+        self.inputs = {
+            'X':
+            OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(self.x)),
+            'Y':
+            OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(self.y))
+        }
+        self.attrs = {'axis': self.axis, 'use_mkldnn': False}
+        self.outputs = {'Out': convert_float_to_uint16(self.out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+
+    def test_check_grad_ingore_x(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out', no_grad_set=set('Y'))

@skip_check_grad_ci(
    reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseAddOp_scalar(TestElementwiseAddOp):
......
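
A note on the test inputs above: OpTest has no native bfloat16 dtype, so the new tests store bf16 data as np.uint16 bit patterns produced by `convert_float_to_uint16`. The sketch below is a simplified, illustrative stand-in for that conversion (the real helper lives in `op_test.py` and may differ in details such as rounding); it only demonstrates why `self.dtype = np.uint16` is the right container for bf16 values.

```python
# Illustrative stand-in for op_test.convert_float_to_uint16 (not the actual
# helper): bfloat16 is the top 16 bits of an IEEE-754 float32, so bf16 values
# can be stored in a plain np.uint16 array.
import numpy as np


def float32_to_bf16_bits(arr):
    """Truncate each float32 to its upper 16 bits (a simple bf16 encoding)."""
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return (arr.view(np.uint32) >> 16).astype(np.uint16)


def bf16_bits_to_float32(bits):
    """Re-expand stored bf16 bit patterns to float32 for comparisons."""
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)


x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
# Truncating to 8 significant bits loses at most ~0.8% relative precision here.
assert np.allclose(roundtrip, x, rtol=1e-2, atol=0)
```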
@@ -23,7 +23,7 @@ import paddle.fluid.core as core
from paddle.fluid import Program, compiler, program_guard
from paddle.fluid.op import Operator
-from op_test import OpTest, skip_check_grad_ci
+from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16


class ElementwiseMulOp(OpTest):
@@ -83,6 +83,39 @@ class ElementwiseMulOp(OpTest):
        pass

+class TestBF16ElementwiseMulOp(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.dtype = np.uint16
+        self.x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
+        self.y = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
+        self.out = np.multiply(self.x, self.y)
+        self.axis = -1
+        self.inputs = {
+            'X':
+            OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(self.x)),
+            'Y':
+            OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(self.y))
+        }
+        self.outputs = {'Out': convert_float_to_uint16(self.out)}
+        self.attrs = {'axis': self.axis, 'use_mkldnn': False}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))

@skip_check_grad_ci(
    reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseMulOp_scalar(ElementwiseMulOp):
......
@@ -17,7 +17,8 @@ import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
-from op_test import OpTest, skip_check_grad_ci
+import paddle.fluid.core as core
+from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16


class TestElementwiseOp(OpTest):
@@ -44,6 +45,33 @@ class TestElementwiseOp(OpTest):
            ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))

+class TestBF16ElementwiseOp(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_sub"
+        self.dtype = np.uint16
+        x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
+        y = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
+        out = x - y
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            'Y': convert_float_to_uint16(y)
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))

@skip_check_grad_ci(
    reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseSubOp_scalar(TestElementwiseOp):
......
@@ -1143,7 +1143,7 @@ class TestBf16(unittest.TestCase):
    def test_bf16(self):
        out_fp32 = self.train(enable_amp=False)
        out_bf16 = self.train(enable_amp=True)
-        self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-2))
+        self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1))


if __name__ == '__main__':
......