diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index cf4d7b1d670b8add6ff5a138851c6a23ee54169e..8a405cc6fc1baefe997fb5b6133a56d6a2fc0438 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -201,12 +201,14 @@ REGISTER_OPERATOR(gather_grad, ops::GatherGradOp, REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel, ops::GatherOpKernel, ops::GatherOpKernel, ops::GatherOpKernel, - ops::GatherOpKernel); + ops::GatherOpKernel, + ops::GatherOpKernel); REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel, ops::GatherGradientOpKernel, ops::GatherGradientOpKernel, ops::GatherGradientOpKernel, - ops::GatherGradientOpKernel); + ops::GatherGradientOpKernel, + ops::GatherGradientOpKernel); REGISTER_OP_VERSION(gather) .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC", paddle::framework::compatible::OpVersionDesc().NewInput( diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 19568835a6e96080bb1c0af642bf9cb19c346bf9..a502a13040949a34e88a4d585327a58ffe92562c 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -130,9 +130,11 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, - ops::GatherOpCUDAKernel); + ops::GatherOpCUDAKernel, + ops::GatherOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel, ops::GatherGradOpCUDAKernel, ops::GatherGradOpCUDAKernel, ops::GatherGradOpCUDAKernel, - ops::GatherGradOpCUDAKernel); + ops::GatherGradOpCUDAKernel, + ops::GatherGradOpCUDAKernel); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 8563d8b05b186c025ecc4c970a400765adeb0c5d..a4678550cf7bd0d4aa2759d4887dddabed5f9ba4 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/float16.h" #include "paddle/phi/kernels/funcs/math_function.h" @@ -445,6 +446,7 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; template struct MergeAdd>; template struct MergeAdd>; diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index 3e2d2a5495b3428ce0fad9d61431d53b44eea330..33590c1d7cca04e215e55abb26fb2aa3c3b61bec 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -258,4 +258,5 @@ REGISTER_OP_CUDA_KERNEL( ops::SumKernel, ops::SumKernel, ops::SumKernel, - ops::SumKernel); + ops::SumKernel, + ops::SumKernel); diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h index 8aec8e840f33273a3130355c751e635e4a3f6736..803674779e756f000005d106f950659ea765c5ce 100644 --- a/paddle/fluid/platform/device/gpu/gpu_primitives.h +++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #endif #include +#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -244,6 +245,72 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock( #endif #endif +// NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16. +inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) { + bfloat16 low_half; + // the bfloat16 in lower 16bits + low_half.x = static_cast(val & 0xFFFFu); + low_half = static_cast(static_cast(low_half) + x); + return (val & 0xFFFF0000u) | low_half.x; +} + +inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) { + bfloat16 high_half; + // the bfloat16 in higher 16bits + high_half.x = static_cast(val >> 16); + high_half = static_cast(static_cast(high_half) + x); + return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); +} + +#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) { + return *reinterpret_cast(&x); +} + +static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) { + return *reinterpret_cast<__nv_bfloat16 *>(&x); +} + +CUDA_ATOMIC_WRAPPER(Add, bfloat16) { + return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), + PDBF16ToCUDABF16(val))); +} +#else +CUDA_ATOMIC_WRAPPER(Add, bfloat16) { + // concrete packed bfloat16 value may exsits in lower or higher 16bits + // of the 32bits address. + uint32_t *address_as_ui = reinterpret_cast( + reinterpret_cast(address) - + (reinterpret_cast(address) & 0x02)); + float val_f = static_cast(val); + uint32_t old = *address_as_ui; + uint32_t sum; + uint32_t newval; + uint32_t assumed; + if (((uintptr_t)address & 0x02) == 0) { + // the bfloat16 value stay at lower 16 bits of the address. + do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, + bf16_add_to_low_half(assumed, val_f)); + } while (old != assumed); + bfloat16 ret; + ret.x = old & 0xFFFFu; + return ret; + } else { + // the bfloat16 value stay at higher 16 bits of the address. + do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, + bf16_add_to_high_half(assumed, val_f)); + } while (old != assumed); + bfloat16 ret; + ret.x = old >> 16; + return ret; + } +} +#endif + CUDA_ATOMIC_WRAPPER(Add, complex) { float *real = reinterpret_cast(address); float *imag = real + 1; diff --git a/paddle/phi/kernels/gpu/scale_kernel.cu b/paddle/phi/kernels/gpu/scale_kernel.cu index d9c8de21c5bc2d26cb371d03be30ed0616a27a64..930c50a24be8fae40535c2d5e6dbbe85e7ced990 100644 --- a/paddle/phi/kernels/gpu/scale_kernel.cu +++ b/paddle/phi/kernels/gpu/scale_kernel.cu @@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(scale, float, double, phi::dtype::float16, + phi::dtype::bfloat16, uint8_t, int8_t, int16_t, diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 848ebae0706e3c62e0e0e6579cd3c04f02d43be4..5694ef25c79ffe700b7d77de6acd3595936471ca 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -482,7 +482,12 @@ class OpTest(unittest.TestCase): op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) "infer datatype from inputs and outputs for this test case" - self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) + if self.is_bfloat16_op(): + self.dtype = np.uint16 + self.__class__.dtype = self.dtype + self.output_dtype = np.uint16 + else: + self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) inputs = append_input_output(block, op_proto, self.inputs, True, self.dtype) outputs = append_input_output(block, op_proto, self.outputs, False, diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 83b39a62f152d2c7e02abe313ffeeafe017d033d..978a3d86d882a2e0d59e8244a956f5c97a4bd9ef 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid as fluid from paddle.framework import core @@ -117,6 +117,39 @@ class TestCase6(TestGatherOp): self.index_type = "int32" +class TestGatherBF16Op(OpTest): + def setUp(self): + self.op_type = "gather" + self.dtype = np.uint16 + self.config() + xnp = np.random.random(self.x_shape).astype(np.float32) + axis_np = np.array(self.axis).astype(self.axis_type) + index_np = np.array(self.index).astype(self.index_type) + self.inputs = { + 'X': convert_float_to_uint16(xnp), + 'Index': index_np, + 'Axis': axis_np + } + out = gather_numpy(self.inputs['X'], index_np, axis_np[0]) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', numeric_grad_delta=0.5) + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (3, 88, 3) + self.index = [1, 3, 5] + self.index_type = "int32" + self.axis = [1] + self.axis_type = "int32" + + class TestGatherOp1(OpTest): def setUp(self): self.op_type = "gather" diff --git a/python/paddle/fluid/tests/unittests/test_scale_op.py b/python/paddle/fluid/tests/unittests/test_scale_op.py index c1ce032f506127e495dfd3231471fdabe6dfa26b..d432b8057f624831f40b8cd48a0ede694f8d0a55 100644 --- a/python/paddle/fluid/tests/unittests/test_scale_op.py +++ b/python/paddle/fluid/tests/unittests/test_scale_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid as fluid import paddle.fluid.core as core @@ -153,6 +153,23 @@ class TestScaleFp16Op(TestScaleOp): place, ["X"], "Out", max_relative_error=0.05) +class TestScaleBF16Op(OpTest): + def setUp(self): + self.op_type = "scale" + self.dtype = np.uint16 + self.attrs = {'scale': -2.3} + x = np.random.random((10, 10)).astype(np.float32) + out = x * np.float32(self.attrs['scale']) + self.inputs = {'X': convert_float_to_uint16(x)} + self.outputs = {'Out': convert_float_to_uint16(out)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', numeric_grad_delta=0.8) + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows): diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index eddccd4ff24f1a8b7c23bda3da813bc87c199cbe..7040145a76833588f0a5738b1b09e10061497e8c 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -298,6 +298,32 @@ def create_test_sum_fp16_class(parent): globals()[cls_name] = TestSumFp16Case +#----------- test bf16 ----------- +class TestSumBF16Op(OpTest): + def setUp(self): + self.op_type = "sum" + self.init_kernel_type() + x0 = np.random.random((3, 40)).astype(np.float32) + x1 = np.random.random((3, 40)).astype(np.float32) + x2 = np.random.random((3, 40)).astype(np.float32) + y = x0 + x1 + x2 + self.inputs = { + "X": [("x0", convert_float_to_uint16(x0)), + ("x1", convert_float_to_uint16(x1)), + ("x2", convert_float_to_uint16(x2))] + } + self.outputs = {'Out': convert_float_to_uint16(y)} + + def init_kernel_type(self): + self.dtype = np.uint16 + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['x0'], 'Out', numeric_grad_delta=0.5) + + class API_Test_Add_n(unittest.TestCase): def test_api(self): with fluid.program_guard(fluid.Program(), fluid.Program()):