diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
index 7acfd33e94a9a415d892440956a071479a9c4665..a202c428394d7f84729ceb3648ff26f143cd7ee6 100644
--- a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
@@ -16,7 +16,7 @@
 
 #include <vector>
 
-#include "paddle/phi/common/float16.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -104,8 +104,10 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::BroadcastTensorsGradKernel,
+                   bool,
                    int,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
index 82dac4552a42febca6f8c159b87d53918c8d9f15..3d16797cb66c09e85a590365f97c9b4004d0d2ed 100644
--- a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
+++ b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
@@ -14,7 +14,7 @@
 
 #include "paddle/phi/kernels/broadcast_tensors_kernel.h"
 
-#include "paddle/phi/common/float16.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h"
 
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(broadcast_tensors,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py
index 0bcb34f956b238ed44dd3cebf16fe43ece56be5d..1be451d65e83b2c23feb977bde3b5fed347b7ffd 100644
--- a/python/paddle/fluid/tests/unittests/eager_op_test.py
+++ b/python/paddle/fluid/tests/unittests/eager_op_test.py
@@ -2689,6 +2689,26 @@ class OpTest(unittest.TestCase):
     def np_value_to_fluid_value(input):
         return input
 
+    def cast_bf16_output(self, block, cast_inputs):
+        output_names = []
+        for i in range(0, len(cast_inputs)):
+            cast_output = block.create_var(
+                dtype="float32", shape=cast_inputs[i].shape
+            )
+            cast_op = block.append_op(
+                inputs={"X": cast_inputs[i]},
+                outputs={"Out": cast_output},
+                type="cast",
+                attrs={
+                    "in_dtype": core.VarDesc.VarType.BF16,
+                    "out_dtype": core.VarDesc.VarType.FP32,
+                },
+            )
+            cast_op.desc.infer_var_type(block.desc)
+            cast_op.desc.infer_shape(block.desc)
+            output_names.append(cast_output.name)
+        return output_names
+
     def _get_gradient(
         self,
         input_to_check,
@@ -2712,21 +2732,24 @@
         if user_defined_grad_outputs is None:
             if self.dtype == np.uint16:
                 cast_inputs = list(map(block.var, output_names))
-                cast_outputs = block.create_var(
-                    dtype="float32", shape=cast_inputs[0].shape
-                )
-                cast_op = block.append_op(
-                    inputs={"X": cast_inputs},
-                    outputs={"Out": cast_outputs},
-                    type="cast",
-                    attrs={
-                        "in_dtype": core.VarDesc.VarType.BF16,
-                        "out_dtype": core.VarDesc.VarType.FP32,
-                    },
-                )
-                cast_op.desc.infer_var_type(block.desc)
-                cast_op.desc.infer_shape(block.desc)
-                output_names = [cast_outputs.name]
+                if self.op_type == "broadcast_tensors":
+                    output_names = self.cast_bf16_output(block, cast_inputs)
+                else:
+                    cast_outputs = block.create_var(
+                        dtype="float32", shape=cast_inputs[0].shape
+                    )
+                    cast_op = block.append_op(
+                        inputs={"X": cast_inputs},
+                        outputs={"Out": cast_outputs},
+                        type="cast",
+                        attrs={
+                            "in_dtype": core.VarDesc.VarType.BF16,
+                            "out_dtype": core.VarDesc.VarType.FP32,
+                        },
+                    )
+                    cast_op.desc.infer_var_type(block.desc)
+                    cast_op.desc.infer_shape(block.desc)
+                    output_names = [cast_outputs.name]
             loss = append_loss_ops(block, output_names)
             param_grad_list = append_backward(
                 loss=loss,
diff --git a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
index d2af1933469aa0395fa2d38989ea1a780bd68c85..be37ff0157898384f06666bc1378001d388ab5b7 100644
--- a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
+++ b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
@@ -16,7 +16,7 @@ import random
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -43,7 +43,7 @@ def find_output_shape(input_list):
     return list(reversed(output_shape))
 
 
-def make_inputs_outputs(input_shapes, dtype):
+def make_inputs_outputs(input_shapes, dtype, is_bfloat16=False):
     """Automatically generate formatted inputs and outputs from input_shapes"""
     input_list = [
         np.random.random(shape).astype(dtype) for shape in input_shapes
@@ -53,6 +53,16 @@
         x + np.zeros(output_shape).astype(x.dtype) for x in input_list
     ]
 
+    if is_bfloat16:
+        input_list = [
+            convert_float_to_uint16(input_list[i])
+            for i in range(len(input_list))
+        ]
+        output_list = [
+            convert_float_to_uint16(output_list[i])
+            for i in range(len(output_list))
+        ]
+
     output_formatted = {
         "Out": [(f"out{i}", output_list[i]) for i in range(len(output_list))]
     }
@@ -63,24 +73,24 @@
     return input_formatted, output_formatted
 
 
-def gen_rank_diff_test(dtype):
+def gen_rank_diff_test(dtype, is_bfloat16=False):
     input_shapes = [(2, 60, 1), (6, 2, 1, 10)]
-    return make_inputs_outputs(input_shapes, dtype)
+    return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
 
 
-def gen_no_broadcast_test(dtype):
+def gen_no_broadcast_test(dtype, is_bfloat16=False):
     input_shapes = [(12, 1, 10, 1), (12, 1, 10, 1)]
-    return make_inputs_outputs(input_shapes, dtype)
+    return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
 
 
-def gen_mixed_tensors_test(dtype):
+def gen_mixed_tensors_test(dtype, is_bfloat16=False):
    input_shapes = [(2, 60, 1), (2, 2, 1, 30), (1, 2, 60, 1)]
-    return make_inputs_outputs(input_shapes, dtype)
+    return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
 
 
-def gen_empty_tensors_test(dtype):
+def gen_empty_tensors_test(dtype, is_bfloat16=False):
     input_shapes = [(0), (0), (0)]
-    return make_inputs_outputs(input_shapes, dtype)
+    return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
 
 
 class TestCPUBroadcastTensorsOp(OpTest):
@@ -125,7 +135,7 @@
     def test_check_output(self):
         self.run_dual_test(
             self.check_output_with_place,
-            {"place": self.place, "atol": 1e-1},
+            {"place": self.place},
         )
 
     def test_check_grad_normal(self):
@@ -135,7 +145,6 @@
                 "place": self.place,
                 "inputs_to_check": ['x0', 'x1'],
                 "output_names": ['out0', 'out1'],
-                "max_relative_error": 0.05,
             },
         )
         self.run_triple_in_test(
@@ -144,7 +153,6 @@
                 "place": self.place,
                 "inputs_to_check": ['x0', 'x1', 'x2'],
                 "output_names": ['out0', 'out1', "out2"],
-                "max_relative_error": 0.05,
             },
         )
 
@@ -152,14 +160,77 @@
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
-class TestCUDABroadcastTensorsOp(TestCPUBroadcastTensorsOp):
+class TestBroadcastTensorsFP16Op(TestCPUBroadcastTensorsOp):
     def set_place(self):
         self.place = core.CUDAPlace(0)
 
     def set_dtypes(self):
-        self.dtypes = ['float64']
-        if core.is_float16_supported(self.place):
-            self.dtypes.append('float16')
+        self.dtypes = ['float16']
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestBroadcastTensorsBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "broadcast_tensors"
+        self.dtype = np.uint16
+        self.np_dtype = "float32"
+        self.use_mkldnn = False
+        self.attrs = {'use_mkldnn': self.use_mkldnn}
+        self.test_gen_func_list = [
+            gen_rank_diff_test,
+            gen_no_broadcast_test,
+            gen_mixed_tensors_test,
+        ]
+        self.python_api = paddle.broadcast_tensors
+        self.place = core.CUDAPlace(0)
+
+    def run_dual_test(self, test_func, args):
+        for gen_func in self.test_gen_func_list:
+            self.inputs, self.outputs = gen_func(self.np_dtype, True)
+            if len(self.outputs["Out"]) < 3:
+                self.python_out_sig = [
+                    f"out{i}" for i in range(len(self.outputs["Out"]))
+                ]
+            test_func(**args)
+
+    def run_triple_in_test(self, test_func, args):
+        self.inputs, self.outputs = self.test_gen_func_list[2](
+            self.np_dtype, True
+        )
+        self.python_out_sig = [
+            f"out{i}" for i in range(len(self.outputs["Out"]))
+        ]
+        test_func(**args)
+
+    def test_check_output(self):
+        self.run_dual_test(
+            self.check_output_with_place,
+            {"place": self.place},
+        )
+
+    def test_check_grad_normal(self):
+        self.run_dual_test(
+            self.check_grad_with_place,
+            {
+                "place": self.place,
+                "inputs_to_check": ['x0', 'x1'],
+                "output_names": ['out0', 'out1'],
+                "check_dygraph": False,
+            },
+        )
+        self.run_triple_in_test(
+            self.check_grad_with_place,
+            {
+                "place": self.place,
+                "inputs_to_check": ['x0', 'x1', 'x2'],
+                "output_names": ['out0', 'out1', 'out2'],
+                "check_dygraph": False,
+            },
+        )
 
 
 class TestBroadcastTensorsAPI(unittest.TestCase):
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 27e2a4b812de57a512cce9b6d69cceddf70885a0..5c4a029711b0fcb7a0cd935fe0ab8fcd89487126 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1240,7 +1240,15 @@ def broadcast_tensors(input, name=None):
             check_variable_and_dtype(
                 x,
                 'input[' + str(id) + ']',
-                ['bool', 'float32', 'float64', 'int32', 'int64'],
+                [
+                    'bool',
+                    'float16',
+                    'float32',
+                    'float64',
+                    'int32',
+                    'int64',
+                    'uint16',
+                ],
                 'broadcast_tensors',
             )
             if x.dtype != input[0].dtype:
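A minimal usage sketch of the behaviour this patch enables (not part of the diff): broadcasting a pair of bfloat16 tensors on a CUDA device through paddle.broadcast_tensors. The shapes are borrowed from gen_rank_diff_test above; the device setup, variable names, and the use of astype("bfloat16") are illustrative assumptions, and the snippet presumes a CUDA build with bfloat16 support, mirroring the skip condition on the new TestBroadcastTensorsBF16Op test.

    # Illustrative only, not part of the patch; assumes a CUDA build with
    # bfloat16 support (the same condition the new unit test skips on).
    import paddle

    paddle.set_device("gpu")

    # Broadcast-compatible shapes taken from gen_rank_diff_test.
    x = paddle.rand([2, 60, 1]).astype("bfloat16")
    y = paddle.rand([6, 2, 1, 10]).astype("bfloat16")

    # Both outputs are broadcast to shape [6, 2, 60, 10] and keep dtype bfloat16.
    out_x, out_y = paddle.broadcast_tensors(input=[x, y])
    print(out_x.shape, out_y.shape)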