Unverified commit ca4df333, authored by zhangbo9674, committed by GitHub

[bf16] add bf16 kernel: elementwise_div (#39602)

* add elementwise_div

* refine rocm

* refine code

* refine op register

* solve conflict

* refine unittest

* refine unittest precision

* add rocm
Parent 1fcaab45
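For context, here is a minimal usage sketch of what this commit enables: elementwise division on bfloat16 tensors on the GPU. This example is illustrative and not part of the commit; it assumes a CUDA build of Paddle with bf16 support and uses the public paddle.cast / paddle.divide APIs.

import paddle

paddle.set_device('gpu')  # the bf16 kernels added here are CUDA-only
x = paddle.rand([12, 13], dtype='float32')
y = paddle.rand([12, 13], dtype='float32') + 0.1  # keep the divisor away from zero
x_bf16 = paddle.cast(x, 'bfloat16')
y_bf16 = paddle.cast(y, 'bfloat16')
out = paddle.divide(x_bf16, y_bf16)  # dispatches to the new bf16 elementwise_div kernel
print(out.dtype)  # paddle.bfloat16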
@@ -53,6 +53,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                               paddle::platform::float16>,
+    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
+                              paddle::platform::bfloat16>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -65,6 +67,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                   paddle::platform::float16>,
+    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
+                                  paddle::platform::bfloat16>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -78,6 +82,8 @@ REGISTER_OP_CUDA_KERNEL(
                                         float>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         paddle::platform::float16>,
+    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::bfloat16>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         double>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
......
@@ -105,6 +105,18 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
   return float16(__shfl_xor_sync(mask, val.to_half(), width));
 }
 
+template <>
+__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
+                                                       bfloat16 val,
+                                                       int width) {
+#if defined(PADDLE_CUDA_BF16)
+  return bfloat16(__shfl_xor_sync(mask, static_cast<nv_bfloat16>(val), width));
+#else
+  PADDLE_ENFORCE(
+      false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11.");
+#endif
+}
+
 template <>
 __forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
     unsigned mask, paddle::platform::complex<float> val, int width) {
......
@@ -91,6 +91,13 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
   return float16(__shfl_xor(static_cast<float>(val), width));
 }
 
+template <>
+__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
+                                                       bfloat16 val,
+                                                       int width) {
+  return bfloat16(__shfl_xor(static_cast<float>(val), width));
+}
+
 template <>
 __forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
     unsigned mask, paddle::platform::complex<float> val, int width) {
......
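The two specializations above extend the warp shuffle helper to bfloat16 so that warp-level reductions keep working for the new dtype; the CUDA path uses the native nv_bfloat16 shuffle where PADDLE_CUDA_BF16 is available, while the ROCm path round-trips through float. As a rough illustration of the XOR-butterfly pattern that CudaShuffleXorSync enables (a NumPy simulation for clarity, not Paddle code): after log2(width) rounds in which every lane exchanges values with its XOR partner, all lanes hold the full sum.

import numpy as np

def butterfly_sum(vals, width=32):
    # Simulates an __shfl_xor_sync-based reduction: lane i reads lane i ^ step.
    lanes = np.arange(width)
    acc = vals.astype(np.float32).copy()
    step = width // 2
    while step >= 1:
        acc = acc + acc[lanes ^ step]  # every lane adds its XOR partner's value
        step //= 2
    return acc

v = np.arange(32, dtype=np.float32)
print(butterfly_sum(v)[0])  # 496.0 == sum(range(32)); identical in every lane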
@@ -92,6 +92,7 @@ DEFINE_CUDA_ELEMENTWISE_OP(Divide)
 }  // namespace phi
 
 using float16 = phi::dtype::float16;
+using bfloat16 = phi::dtype::bfloat16;
 using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;
 
@@ -128,6 +129,7 @@ PD_REGISTER_KERNEL(divide_raw,
                    int,
                    int64_t,
                    float16,
+                   bfloat16,
                    complex64,
                    complex128) {}
 
 PD_REGISTER_KERNEL(multiply_raw,
......
@@ -18,7 +18,7 @@ import numpy as np
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
-from op_test import OpTest, skip_check_grad_ci
+from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
 
 
 class ElementwiseDivOp(OpTest):
@@ -55,6 +55,42 @@ class ElementwiseDivOp(OpTest):
         pass
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
+    "core is not compiled with CUDA or cudnn version is less than 8.1.0")
+class TestElementwiseDivOpBF16(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_div"
+        self.dtype = np.uint16
+        x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
+        y = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
+        out = np.divide(x, y)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            'Y': convert_float_to_uint16(y)
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+
+    def test_check_grad_ignore_x(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out', no_grad_set=set('Y'))
+
+
 @skip_check_grad_ci(
     reason="[skip shape check] Use y_shape(1) to test broadcast.")
 class TestElementwiseDivOp_scalar(ElementwiseDivOp):
......
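A note on the uint16 dtype in the test above: Paddle stores bfloat16 tensors as raw 16-bit patterns, and the test helper convert_float_to_uint16 produces them from float32. The following is a minimal sketch of the underlying idea, not the helper itself (the real helper may round rather than truncate): bfloat16 keeps the sign bit, the full 8-bit exponent, and the top 7 mantissa bits of float32, so the conversion is essentially a right shift of the float32 bit pattern.

import numpy as np

def float32_to_bf16_bits(a):
    # Reinterpret float32 as uint32 and keep the high half:
    # sign (1) + exponent (8) + mantissa (7) = the bfloat16 bit pattern.
    a = np.asarray(a, dtype=np.float32)
    return (a.view(np.uint32) >> 16).astype(np.uint16)

x = np.array([0.15625], dtype=np.float32)  # 2**-3 + 2**-5, exact in bf16
print(hex(float32_to_bf16_bits(x)[0]))     # 0x3e20

This truncation is also why the unit test compares against float32 reference values converted through the same helper: both sides lose the same low-order mantissa bits, keeping the comparison meaningful at bf16 precision.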