diff --git a/paddle/phi/kernels/gpu/cross_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_grad_kernel.cu
index b3316ea875b9060a7d0a73c86d7bd7fd8517760f..58f53fcf3f3d2261b003bb419315225572acf264 100644
--- a/paddle/phi/kernels/gpu/cross_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_grad_kernel.cu
@@ -191,6 +191,7 @@ PD_REGISTER_KERNEL(cross_grad,
                    ALL_LAYOUT,
                    phi::CrossGradKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
diff --git a/paddle/phi/kernels/gpu/cross_kernel.cu b/paddle/phi/kernels/gpu/cross_kernel.cu
index 60623cb8e3d747063cea5c20e36660ab8849853b..461e3a219d5d6ac4358688c2abd447d03992a441 100644
--- a/paddle/phi/kernels/gpu/cross_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_kernel.cu
@@ -168,6 +168,7 @@ PD_REGISTER_KERNEL(cross,
                    ALL_LAYOUT,
                    phi::CrossKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
diff --git a/paddle/phi/kernels/gpu/dot_grad_kernel.cu b/paddle/phi/kernels/gpu/dot_grad_kernel.cu
index 874d0f03b7dce3889cb313ef19d990e3e08f9066..0bd448339b661dd5e61402de680005b13ebfa96a 100644
--- a/paddle/phi/kernels/gpu/dot_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/dot_grad_kernel.cu
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/phi/kernels/dot_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/complex.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/dot_grad_kernel_impl.h"
 
@@ -28,4 +30,6 @@ PD_REGISTER_KERNEL(dot_grad,
                    int,
                    int64_t,
                    phi::dtype::complex<float>,
-                   phi::dtype::complex<double>) {}
+                   phi::dtype::complex<double>,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/dot_kernel.cu b/paddle/phi/kernels/gpu/dot_kernel.cu
index 144fc66e3837b96598a8f7ace4d09d0a8d1edbdf..5005f6390d2ac0dfd0eca04d00df9272a0d50510 100644
--- a/paddle/phi/kernels/gpu/dot_kernel.cu
+++ b/paddle/phi/kernels/gpu/dot_kernel.cu
@@ -15,6 +15,8 @@
 #include "paddle/phi/kernels/dot_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 
@@ -61,4 +63,6 @@ PD_REGISTER_KERNEL(dot,
                    int,
                    int64_t,
                    complex64,
-                   complex128) {}
+                   complex128,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_cross_op.py b/python/paddle/fluid/tests/unittests/test_cross_op.py
index bbfa19aa7ff044014007109a4d99880dba6a6a7e..1114bb0b69ffbdedacbfec466435b1b0895e0fbf 100644
--- a/python/paddle/fluid/tests/unittests/test_cross_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cross_op.py
@@ -15,11 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
-from paddle.fluid import Program, program_guard
+from paddle.fluid import Program, core, program_guard
 
 
 class TestCrossOp(OpTest):
@@ -65,6 +65,9 @@ class TestCrossOpCase1(TestCrossOp):
         self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
 class TestCrossFP16Op(TestCrossOp):
     def initTestCase(self):
         self.shape = (2048, 3)
@@ -77,6 +80,51 @@ class TestCrossFP16Op(TestCrossOp):
         self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestCrossBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "cross"
+        self.python_api = paddle.cross
+        self.initTestCase()
+        self.x = np.random.random(self.shape).astype(np.float32)
+        self.y = np.random.random(self.shape).astype(np.float32)
+        self.inputs = {
+            'X': convert_float_to_uint16(self.x),
+            'Y': convert_float_to_uint16(self.y),
+        }
+        self.init_output()
+
+    def initTestCase(self):
+        self.attrs = {'dim': -2}
+        self.dtype = np.uint16
+        self.shape = (1024, 3, 1)
+
+    def init_output(self):
+        x = np.squeeze(self.x, 2)
+        y = np.squeeze(self.y, 2)
+        z_list = []
+        for i in range(1024):
+            z_list.append(np.cross(x[i], y[i]))
+        out = np.array(z_list).astype(np.float32).reshape(self.shape)
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(place)
+
+    def test_check_grad_normal(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+
+
 class TestCrossAPI(unittest.TestCase):
     def input_data(self):
         self.data_x = np.array(
diff --git a/python/paddle/fluid/tests/unittests/test_dot_op.py b/python/paddle/fluid/tests/unittests/test_dot_op.py
index 4acf5f4ed14ef99ac3d2cf6a2e8040df4b3cbd8f..5cb061c368b900d643245cbc56af0863d6b5af60 100644
--- a/python/paddle/fluid/tests/unittests/test_dot_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dot_op.py
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
@@ -85,7 +85,7 @@ class DotOp(OpTest):
     def init_input_output(self):
         self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype)
         self.y = np.random.uniform(1, 3, [121]).astype(self.dtype)
-        self.out = np.dot(self.x, self.y)
+        self.out = np.dot(self.x, self.y).astype(self.dtype)
 
     def init_dtype(self):
         self.dtype = np.float64
@@ -314,6 +314,201 @@ class TestComplexDotOp2D(OpTest):
         )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestDotFP16Op(OpTest):
+    def setUp(self):
+        self.op_type = "dot"
+        self.python_api = paddle.dot
+        self.init_dtype()
+        self.init_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
+        }
+        self.outputs = {'Out': self.out}
+        self.attrs = {}
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=0.125)
+
+    def test_check_grad_normal(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+
+    def test_check_grad_ingore_x(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, ['Y'], 'Out', no_grad_set=set("X")
+                )
+
+    def test_check_grad_ingore_y(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, ['X'], 'Out', no_grad_set=set("Y")
+                )
+
+    def init_input_output(self):
+        self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype)
+        self.y = np.random.uniform(1, 3, [121]).astype(self.dtype)
+        self.out = np.dot(self.x, self.y)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class DotFP16OpBatch(TestDotFP16Op):
+    def init_input_output(self):
+        self.x = (
+            np.random.uniform(0.1, 1, [132])
+            .astype(self.dtype)
+            .reshape([11, 12])
+        )
+        self.y = (
+            np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
+        )
+        self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestDotBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "dot"
+        self.python_api = paddle.dot
+        self.init_dtype()
+        self.init_input_output()
+
+        self.inputs = {
+            'X': convert_float_to_uint16(self.x),
+            'Y': convert_float_to_uint16(self.y),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(self.out)}
+        self.attrs = {}
+
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(place, atol=0.5)
+
+    def test_check_grad_normal(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place,
+                    ['X', 'Y'],
+                    'Out',
+                    user_defined_grads=[self.inputs['Y'], self.inputs['X']],
+                )
+
+    def test_check_grad_ingore_x(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place,
+                    ['Y'],
+                    'Out',
+                    no_grad_set=set("X"),
+                    user_defined_grads=[self.inputs['X']],
+                )
+
+    def test_check_grad_ingore_y(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place,
+                    ['X'],
+                    'Out',
+                    no_grad_set=set("Y"),
+                    user_defined_grads=[self.inputs['Y']],
+                )
+
+    def init_input_output(self):
+        self.x = np.random.uniform(0.1, 1, [121]).astype(np.float32)
+        self.y = np.random.uniform(1, 3, [121]).astype(np.float32)
+        self.out = np.dot(self.x, self.y)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class DotBF16OpBatch(TestDotBF16Op):
+    def init_input_output(self):
+        self.x = (
+            np.random.uniform(0.1, 1, [132])
+            .astype(np.float32)
+            .reshape([11, 12])
+        )
+        self.y = (
+            np.random.uniform(1, 3, [132]).astype(np.float32).reshape([11, 12])
+        )
+        self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+
+    def test_check_grad_normal(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place,
+                    ['X', 'Y'],
+                    'Out',
+                    user_defined_grads=[
+                        self.y / self.y.shape[0],
+                        self.x / self.x.shape[0],
+                    ],
+                )
+
+    def test_check_grad_ingore_x(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place,
+                    ['Y'],
+                    'Out',
+                    no_grad_set=set("X"),
+                    user_defined_grads=[self.x / self.x.shape[0]],
+                )
+
+    def test_check_grad_ingore_y(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place,
+                    ['X'],
+                    'Out',
+                    no_grad_set=set("Y"),
+                    user_defined_grads=[self.y / self.y.shape[0]],
+                )
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
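
For reference, a minimal smoke-test sketch (not part of the diff) of how the newly registered float16/bfloat16 GPU kernels can be exercised through the public Python API. It assumes a CUDA build with bfloat16 support; the tensor shapes and dtype strings are illustrative, and only existing public calls (paddle.is_compiled_with_cuda, paddle.set_device, paddle.rand, Tensor.astype, paddle.cross, paddle.dot) are used:

# Hypothetical usage sketch: runs the low-precision dot and cross GPU
# kernels end-to-end on small random inputs.
import paddle

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    for dtype in ('float16', 'bfloat16'):
        x = paddle.rand([4, 3]).astype(dtype)
        y = paddle.rand([4, 3]).astype(dtype)
        print(paddle.cross(x, y, axis=1))  # cross kernel in low precision
        print(paddle.dot(x[0], y[0]))      # dot kernel in low precision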