Unverified commit 98fc4277, authored by co63oc, committed by GitHub

Add broadcast_tensors tests (#52961)

Parent 8163faaa
@@ -16,7 +16,7 @@
#include <vector>
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
@@ -104,8 +104,10 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
GPU,
ALL_LAYOUT,
phi::BroadcastTensorsGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
@@ -14,7 +14,7 @@
#include "paddle/phi/kernels/broadcast_tensors_kernel.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h"
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(broadcast_tensors,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
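
With bfloat16 now registered for both the forward and gradient GPU kernels, a dygraph call along the following lines is expected to dispatch them. This is a minimal sketch rather than code from the commit, and it assumes a CUDA build running on a device with bfloat16 support.

import paddle

# Illustration only: cast is used to build bfloat16 inputs; the shapes follow
# the rank-diff test case below and broadcast to [6, 2, 60, 10].
if paddle.is_compiled_with_cuda():
    x = paddle.cast(paddle.rand([2, 60, 1]), "bfloat16")
    y = paddle.cast(paddle.rand([6, 2, 1, 10]), "bfloat16")
    x.stop_gradient = False
    y.stop_gradient = False
    out0, out1 = paddle.broadcast_tensors([x, y])  # forward kernel
    (out0.sum() + out1.sum()).backward()           # broadcast_tensors_grad kernel
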
@@ -2689,6 +2689,26 @@ class OpTest(unittest.TestCase):
def np_value_to_fluid_value(input):
return input
def cast_bf16_output(self, block, cast_inputs):
output_names = []
for i in range(0, len(cast_inputs)):
cast_output = block.create_var(
dtype="float32", shape=cast_inputs[i].shape
)
cast_op = block.append_op(
inputs={"X": cast_inputs[i]},
outputs={"Out": cast_output},
type="cast",
attrs={
"in_dtype": core.VarDesc.VarType.BF16,
"out_dtype": core.VarDesc.VarType.FP32,
},
)
cast_op.desc.infer_var_type(block.desc)
cast_op.desc.infer_shape(block.desc)
output_names.append(cast_output.name)
return output_names
def _get_gradient(
self,
input_to_check,
@@ -2712,6 +2732,9 @@ class OpTest(unittest.TestCase):
if user_defined_grad_outputs is None:
if self.dtype == np.uint16:
cast_inputs = list(map(block.var, output_names))
if self.op_type == "broadcast_tensors":
output_names = self.cast_bf16_output(block, cast_inputs)
else:
cast_outputs = block.create_var(
dtype="float32", shape=cast_inputs[0].shape
)
......
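
The cast ops inserted by cast_bf16_output above undo the uint16 encoding that bfloat16 test data uses, so gradients can be checked in float32. A rough numpy sketch of that encoding follows for reference; it is an illustration only, and the real convert_float_to_uint16 helper may round rather than truncate.

import numpy as np

def float32_to_bf16_bits(x):
    # keep the high 16 bits of the float32 bit pattern (truncation)
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(x):
    # put the stored 16 bits back into the high half of a float32 pattern
    return (np.asarray(x, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)
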
@@ -16,7 +16,7 @@ import random
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
@@ -43,7 +43,7 @@ def find_output_shape(input_list):
return list(reversed(output_shape))
def make_inputs_outputs(input_shapes, dtype):
def make_inputs_outputs(input_shapes, dtype, is_bfloat16=False):
"""Automatically generate formatted inputs and outputs from input_shapes"""
input_list = [
np.random.random(shape).astype(dtype) for shape in input_shapes
@@ -53,6 +53,16 @@ def make_inputs_outputs(input_shapes, dtype):
x + np.zeros(output_shape).astype(x.dtype) for x in input_list
]
if is_bfloat16:
input_list = [
convert_float_to_uint16(input_list[i])
for i in range(len(input_list))
]
output_list = [
convert_float_to_uint16(output_list[i])
for i in range(len(output_list))
]
output_formatted = {
"Out": [(f"out{i}", output_list[i]) for i in range(len(output_list))]
}
@@ -63,24 +73,24 @@ def make_inputs_outputs(input_shapes, dtype):
return input_formatted, output_formatted
def gen_rank_diff_test(dtype):
def gen_rank_diff_test(dtype, is_bfloat16=False):
input_shapes = [(2, 60, 1), (6, 2, 1, 10)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
def gen_no_broadcast_test(dtype):
def gen_no_broadcast_test(dtype, is_bfloat16=False):
input_shapes = [(12, 1, 10, 1), (12, 1, 10, 1)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
def gen_mixed_tensors_test(dtype):
def gen_mixed_tensors_test(dtype, is_bfloat16=False):
input_shapes = [(2, 60, 1), (2, 2, 1, 30), (1, 2, 60, 1)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
def gen_empty_tensors_test(dtype):
def gen_empty_tensors_test(dtype, is_bfloat16=False):
input_shapes = [(0), (0), (0)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
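
As a quick sanity check on the shapes used above (an aside, not part of the commit), numpy's broadcasting rules give the same result shapes that find_output_shape computes; broadcast_shapes requires numpy >= 1.20.

import numpy as np

# gen_rank_diff_test shapes
print(np.broadcast_shapes((2, 60, 1), (6, 2, 1, 10)))                 # (6, 2, 60, 10)
# gen_mixed_tensors_test shapes
print(np.broadcast_shapes((2, 60, 1), (2, 2, 1, 30), (1, 2, 60, 1)))  # (2, 2, 60, 30)
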
class TestCPUBroadcastTensorsOp(OpTest):
@@ -125,7 +135,7 @@ class TestCPUBroadcastTensorsOp(OpTest):
def test_check_output(self):
self.run_dual_test(
self.check_output_with_place,
{"place": self.place, "atol": 1e-1},
{"place": self.place},
)
def test_check_grad_normal(self):
@@ -135,7 +145,6 @@ class TestCPUBroadcastTensorsOp(OpTest):
"place": self.place,
"inputs_to_check": ['x0', 'x1'],
"output_names": ['out0', 'out1'],
"max_relative_error": 0.05,
},
)
self.run_triple_in_test(
@@ -144,7 +153,6 @@ class TestCPUBroadcastTensorsOp(OpTest):
"place": self.place,
"inputs_to_check": ['x0', 'x1', 'x2'],
"output_names": ['out0', 'out1', "out2"],
"max_relative_error": 0.05,
},
)
@@ -152,14 +160,77 @@ class TestCPUBroadcastTensorsOp(OpTest):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestCUDABroadcastTensorsOp(TestCPUBroadcastTensorsOp):
class TestBroadcastTensorsFP16Op(TestCPUBroadcastTensorsOp):
def set_place(self):
self.place = core.CUDAPlace(0)
def set_dtypes(self):
self.dtypes = ['float64']
if core.is_float16_supported(self.place):
self.dtypes.append('float16')
self.dtypes = ['float16']
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestBroadcastTensorsBF16Op(OpTest):
def setUp(self):
self.op_type = "broadcast_tensors"
self.dtype = np.uint16
self.np_dtype = "float32"
self.use_mkldnn = False
self.attrs = {'use_mkldnn': self.use_mkldnn}
self.test_gen_func_list = [
gen_rank_diff_test,
gen_no_broadcast_test,
gen_mixed_tensors_test,
]
self.python_api = paddle.broadcast_tensors
self.place = core.CUDAPlace(0)
def run_dual_test(self, test_func, args):
for gen_func in self.test_gen_func_list:
self.inputs, self.outputs = gen_func(self.np_dtype, True)
if len(self.outputs["Out"]) < 3:
self.python_out_sig = [
f"out{i}" for i in range(len(self.outputs["Out"]))
]
test_func(**args)
def run_triple_in_test(self, test_func, args):
self.inputs, self.outputs = self.test_gen_func_list[2](
self.np_dtype, True
)
self.python_out_sig = [
f"out{i}" for i in range(len(self.outputs["Out"]))
]
test_func(**args)
def test_check_output(self):
self.run_dual_test(
self.check_output_with_place,
{"place": self.place},
)
def test_check_grad_normal(self):
self.run_dual_test(
self.check_grad_with_place,
{
"place": self.place,
"inputs_to_check": ['x0', 'x1'],
"output_names": ['out0', 'out1'],
"check_dygraph": False,
},
)
self.run_triple_in_test(
self.check_grad_with_place,
{
"place": self.place,
"inputs_to_check": ['x0', 'x1', 'x2'],
"output_names": ['out0', 'out1', 'out2'],
"check_dygraph": False,
},
)
class TestBroadcastTensorsAPI(unittest.TestCase):
......
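
To run just the new BF16 case in isolation, something like the following should work; the module name is an assumption here, since this diff does not show the test file's path.

import unittest

# Module name assumed; adjust it to the actual test file.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_broadcast_tensors_op.TestBroadcastTensorsBF16Op"
)
unittest.TextTestRunner(verbosity=2).run(suite)
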
@@ -1240,7 +1240,15 @@ def broadcast_tensors(input, name=None):
check_variable_and_dtype(
x,
'input[' + str(id) + ']',
['bool', 'float32', 'float64', 'int32', 'int64'],
[
'bool',
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'broadcast_tensors',
)
if x.dtype != input[0].dtype:
......
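
With 'float16' and 'uint16' (the bfloat16 representation) added to the accepted dtypes, a static-graph program like the sketch below should now pass the dtype check. It is an illustration rather than code from the commit, and it assumes a float16-capable build.

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x0 = paddle.static.data(name="x0", shape=[2, 60, 1], dtype="float16")
    x1 = paddle.static.data(name="x1", shape=[2, 1, 30], dtype="float16")
    # check_variable_and_dtype now accepts float16 for broadcast_tensors
    out0, out1 = paddle.broadcast_tensors([x0, x1])
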