Unverified commit 706503d0, authored by Charles-hit, committed by GitHub

[AMP OP&Test] Support prod, meshgrid, expand_as bf16 dtype (#53865)

* Add bf16 kernels for meshgrid, expand_as, prod and their gradients

* Fix bf16 handling in OpTest

* Fix code style

* Fix AMP tests
Parent: 1007690b
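For reference, a minimal usage sketch of the three ops with bfloat16 inputs (assumes a CUDA build of Paddle with bfloat16 support on the current GPU; not part of the commit itself):

import paddle

paddle.set_device('gpu')  # the bf16 kernels added here are GPU/KPS kernels

x = paddle.rand([100])
y = paddle.rand([2, 100])
xb = paddle.cast(x, 'bfloat16')
yb = paddle.cast(y, 'bfloat16')

print(paddle.expand_as(xb, yb).shape)   # [2, 100]
gx, gy = paddle.meshgrid(xb, xb)        # two [100, 100] bf16 grids
print(paddle.prod(yb, axis=0).dtype)    # paddle.bfloat16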
@@ -60,4 +60,5 @@ PD_REGISTER_KERNEL(expand_as_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::bfloat16) {}

@@ -87,4 +87,5 @@ PD_REGISTER_KERNEL(expand_as,
                    double,
                    int,
                    int64_t,
-                   bool) {}
+                   bool,
+                   phi::dtype::bfloat16) {}

@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(meshgrid_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::bfloat16) {}

@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(meshgrid,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::bfloat16) {}

@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(prod_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}

@@ -35,6 +35,14 @@ void ProdKernel(const Context& dev_ctx,
 #ifdef PADDLE_WITH_XPU_KP
 PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {}
 #else
-PD_REGISTER_KERNEL(
-    prod, KPS, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(prod,
+                   KPS,
+                   ALL_LAYOUT,
+                   phi::ProdKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #endif

@@ -48,7 +48,9 @@ PD_REGISTER_KERNEL(prod_infer,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #endif

 #if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
...@@ -2009,12 +2009,6 @@ class OpTest(unittest.TestCase): ...@@ -2009,12 +2009,6 @@ class OpTest(unittest.TestCase):
return True return True
return super()._is_skip_name(name) return super()._is_skip_name(name)
if check_prim:
prim_checker = PrimForwardChecker(self, place)
prim_checker.check()
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
self.__class__.check_prim = True
self.__class__.op_type = self.op_type
# set some flags by the combination of arguments. # set some flags by the combination of arguments.
if self.is_float16_op(): if self.is_float16_op():
self.dtype = np.float16 self.dtype = np.float16
...@@ -2058,6 +2052,14 @@ class OpTest(unittest.TestCase): ...@@ -2058,6 +2052,14 @@ class OpTest(unittest.TestCase):
raise AssertionError( raise AssertionError(
"no_check_set of op %s must be set to None." % self.op_type "no_check_set of op %s must be set to None." % self.op_type
) )
if check_prim:
prim_checker = PrimForwardChecker(self, place)
prim_checker.check()
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
self.__class__.check_prim = True
self.__class__.op_type = self.op_type
static_checker = StaticChecker(self, self.outputs) static_checker = StaticChecker(self, self.outputs)
static_checker.check() static_checker.check()
outs, fetch_list = static_checker.outputs, static_checker.fetch_list outs, fetch_list = static_checker.outputs, static_checker.fetch_list
...@@ -2404,6 +2406,7 @@ class OpTest(unittest.TestCase): ...@@ -2404,6 +2406,7 @@ class OpTest(unittest.TestCase):
core._set_prim_all_enabled(False) core._set_prim_all_enabled(False)
core.set_prim_eager_enabled(False) core.set_prim_eager_enabled(False)
if check_prim: if check_prim:
self._check_grad_helper()
prim_grad_checker = PrimGradChecker( prim_grad_checker = PrimGradChecker(
self, self,
place, place,
...@@ -2415,7 +2418,6 @@ class OpTest(unittest.TestCase): ...@@ -2415,7 +2418,6 @@ class OpTest(unittest.TestCase):
prim_grad_checker.check() prim_grad_checker.check()
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
self.__class__.check_prim = True self.__class__.check_prim = True
self._check_grad_helper()
if only_check_prim: if only_check_prim:
return return
self.scope = core.Scope() self.scope = core.Scope()
@@ -2657,33 +2659,24 @@ class OpTest(unittest.TestCase):
                 outputs = dygraph_outputs

             if self.dtype == np.uint16:
-                cast_inputs = self._find_var_in_dygraph(
-                    outputs, output_names[0]
-                )
-                if isinstance(cast_inputs, paddle.Tensor):
-                    cast_outputs = paddle.cast(
-                        cast_inputs, core.VarDesc.VarType.FP32
-                    )
-                elif isinstance(cast_inputs, list):
-                    cast_outputs = []
-                    for cast_input in cast_inputs:
-                        if isinstance(cast_input, paddle.Tensor):
-                            cast_outputs.append(
-                                paddle.cast(
-                                    cast_input, core.VarDesc.VarType.FP32
-                                )
-                            )
-                        else:
-                            raise TypeError(
-                                "Unsupported test data type %s."
-                                % type(cast_input)
-                            )
-                else:
-                    raise TypeError(
-                        "Unsupported test data type %s." % type(cast_inputs)
-                    )
-                outputs = {output_names[0]: cast_outputs}
+                cast_inputs = []
+                for output_name in output_names:
+                    cast_input = self._find_var_in_dygraph(outputs, output_name)
+                    cast_inputs = cast_inputs + cast_input
+                cast_outputs = []
+                for cast_input in cast_inputs:
+                    if isinstance(cast_input, paddle.Tensor):
+                        cast_outputs.append(
+                            paddle.cast(cast_input, core.VarDesc.VarType.FP32)
+                        )
+                    else:
+                        raise TypeError(
+                            "Unsupported test data type %s." % type(cast_input)
+                        )
+                outputs = {}
+                for i in range(len(output_names)):
+                    outputs.update({output_names[i]: [cast_outputs[i]]})

             outputs_valid = {}
             for output_name in output_names:
                 outputs_valid[output_name] = self._find_var_in_dygraph(
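Note: the np.uint16 branch above exists because OpTest carries bfloat16 data as raw uint16 bit patterns (hence convert_float_to_uint16 in the test files below). A rough illustration of that encoding, assuming round-to-nearest-even truncation, not Paddle's exact helper:

import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 keeps the upper 16 bits of an IEEE float32; add a bias
    # for round-to-nearest-even before truncating.
    u32 = np.asarray(x, dtype=np.float32).view(np.uint32)
    bias = ((u32 >> 16) & 1) + 0x7FFF
    return ((u32 + bias) >> 16).astype(np.uint16)

print(float32_to_bf16_bits([1.0, 3.14159]))  # e.g. [16256 16457]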
@@ -2791,7 +2784,7 @@ class OpTest(unittest.TestCase):
         if user_defined_grad_outputs is None:
             if self.dtype == np.uint16:
                 cast_inputs = list(map(block.var, output_names))
-                if self.op_type == "broadcast_tensors":
+                if self.op_type in ["broadcast_tensors", "meshgrid"]:
                     output_names = self.cast_bf16_output(block, cast_inputs)
                 else:
                     cast_outputs = block.create_var(
@@ -15,10 +15,11 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle import fluid
+from paddle.fluid import core


 class TestExpandAsBasic(OpTest):
@@ -27,14 +28,25 @@ class TestExpandAsBasic(OpTest):
         self.prim_op_type = "comp"
         self.python_api = paddle.expand_as
         self.public_python_api = paddle.expand_as
-        x = np.random.rand(100).astype("float64")
-        target_tensor = np.random.rand(2, 100).astype("float64")
+        self.init_dtype()
+        self.init_inputs_and_outputs()
+        self.if_enable_cinn()
+
+    def init_dtype(self):
+        self.dtype = np.float64
+
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(100).astype(self.dtype)
+        target_tensor = np.random.rand(2, 100).astype(self.dtype)
         self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [2, 1]
         output = np.tile(self.inputs['X'], bcast_dims)
         self.outputs = {'Out': output}

+    def if_enable_cinn(self):
+        pass
+
     def test_check_output(self):
         self.check_output(check_prim=True)
@@ -42,14 +54,43 @@ class TestExpandAsBasic(OpTest):
         self.check_grad(['X'], 'Out', check_prim=True)


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestExpandAsBasicBFP16OP(TestExpandAsBasic):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(100).astype(np.float32)
+        target_tensor = np.random.rand(2, 100).astype(np.float32)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            "Y": convert_float_to_uint16(target_tensor),
+        }
+        self.attrs = {'target_shape': target_tensor.shape}
+        bcast_dims = [2, 1]
+        output = np.tile(x, bcast_dims)
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
 class TestExpandAsOpRank2(TestExpandAsBasic):
-    def setUp(self):
-        self.op_type = "expand_as_v2"
-        self.prim_op_type = "comp"
-        self.python_api = paddle.expand_as
-        self.public_python_api = paddle.expand_as
-        x = np.random.rand(10, 12).astype("float64")
-        target_tensor = np.random.rand(10, 12).astype("float64")
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(10, 12).astype(self.dtype)
+        target_tensor = np.random.rand(10, 12).astype(self.dtype)
         self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [1, 1]
@@ -57,14 +98,43 @@ class TestExpandAsOpRank2(TestExpandAsBasic):
         self.outputs = {'Out': output}


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestExpandAsOpRank2BFP16OP(TestExpandAsOpRank2):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(10, 12).astype(np.float32)
+        target_tensor = np.random.rand(10, 12).astype(np.float32)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            "Y": convert_float_to_uint16(target_tensor),
+        }
+        self.attrs = {'target_shape': target_tensor.shape}
+        bcast_dims = [1, 1]
+        output = np.tile(x, bcast_dims)
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
 class TestExpandAsOpRank3(TestExpandAsBasic):
-    def setUp(self):
-        self.op_type = "expand_as_v2"
-        self.prim_op_type = "comp"
-        self.python_api = paddle.expand_as
-        self.public_python_api = paddle.expand_as
-        x = np.random.rand(2, 3, 20).astype("float64")
-        target_tensor = np.random.rand(2, 3, 20).astype("float64")
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(2, 3, 20).astype(self.dtype)
+        target_tensor = np.random.rand(2, 3, 20).astype(self.dtype)
         self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [1, 1, 1]
@@ -72,14 +142,43 @@ class TestExpandAsOpRank3(TestExpandAsBasic):
         self.outputs = {'Out': output}


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestExpandAsOpRank3BFP16OP(TestExpandAsOpRank3):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(2, 3, 20).astype(np.float32)
+        target_tensor = np.random.rand(2, 3, 20).astype(np.float32)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            "Y": convert_float_to_uint16(target_tensor),
+        }
+        self.attrs = {'target_shape': target_tensor.shape}
+        bcast_dims = [1, 1, 1]
+        output = np.tile(x, bcast_dims)
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
 class TestExpandAsOpRank4(TestExpandAsBasic):
-    def setUp(self):
-        self.op_type = "expand_as_v2"
-        self.prim_op_type = "comp"
-        self.python_api = paddle.expand_as
-        self.public_python_api = paddle.expand_as
-        x = np.random.rand(1, 1, 7, 16).astype("float64")
-        target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(1, 1, 7, 16).astype(self.dtype)
+        target_tensor = np.random.rand(4, 6, 7, 16).astype(self.dtype)
         self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [4, 6, 1, 1]
@@ -87,6 +186,39 @@ class TestExpandAsOpRank4(TestExpandAsBasic):
         self.outputs = {'Out': output}


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestExpandAsOpRank4BFP16OP(TestExpandAsOpRank3):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def init_inputs_and_outputs(self):
+        x = np.random.rand(1, 1, 7, 16).astype(np.float32)
+        target_tensor = np.random.rand(4, 6, 7, 16).astype(np.float32)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            "Y": convert_float_to_uint16(target_tensor),
+        }
+        self.attrs = {'target_shape': target_tensor.shape}
+        bcast_dims = [4, 6, 1, 1]
+        output = np.tile(x, bcast_dims)
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
 class TestExpandAsOpRank5(TestExpandAsBasic):
     no_need_check_grad = True
@@ -107,6 +239,32 @@ class TestExpandAsOpRank5(TestExpandAsBasic):
         pass


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestExpandAsOpRank5BFP16OP(TestExpandAsOpRank5):
+    def setUp(self):
+        self.op_type = "expand_as_v2"
+        self.prim_op_type = "comp"
+        self.python_api = paddle.expand_as
+        self.public_python_api = paddle.expand_as
+        x = np.random.rand(1, 1, 7, 16).astype("int64")
+        target_tensor = np.random.rand(4, 6, 7, 16).astype("float32")
+        self.inputs = {'X': x, "Y": convert_float_to_uint16(target_tensor)}
+        self.attrs = {'target_shape': target_tensor.shape}
+        bcast_dims = [4, 6, 1, 1]
+        output = np.tile(x, bcast_dims)
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        pass
+
+
 class TestExpandAsV2Error(unittest.TestCase):
     def test_errors(self):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
@@ -15,10 +15,11 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle import fluid
+from paddle.fluid import core


 def meshgrid_wrapper(x):
@@ -32,11 +33,7 @@ class TestMeshgridOp(OpTest):
         self.python_api = meshgrid_wrapper
         self.public_python_api = meshgrid_wrapper
         self.dtype = self.get_dtype()
-        ins, outs = self.init_test_data()
-        self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]}
-        self.outputs = {
-            'Out': [('out%d' % i, outs[i]) for i in range(len(outs))]
-        }
+        self.init_inputs_and_outputs()
         self.python_out_sig = ['out0', 'out1']
         self.if_enable_cinn()
@@ -49,7 +46,7 @@ class TestMeshgridOp(OpTest):
     def test_check_grad(self):
         self.check_grad(['x0'], ['out0', 'out1'], check_prim=True)

-    def init_test_data(self):
+    def init_inputs_and_outputs(self):
         self.shape = self.get_x_shape()
         ins = []
         outs = []
@@ -61,7 +58,10 @@ class TestMeshgridOp(OpTest):
             out_reshape[i] = self.shape[i]
             out_temp = np.reshape(ins[i], out_reshape)
             outs.append(np.broadcast_to(out_temp, self.shape))
-        return ins, outs
+        self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]}
+        self.outputs = {
+            'Out': [('out%d' % i, outs[i]) for i in range(len(outs))]
+        }

     def get_x_shape(self):
         return [100, 200]
@@ -84,6 +84,52 @@ class TestMeshgridOp2Fp16(TestMeshgridOp):
         return np.float16


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestMeshgridOpBFP16OP(TestMeshgridOp):
+    def init_data_type(self):
+        self.data_type = np.uint16
+
+    def init_inputs_and_outputs(self):
+        self.shape = self.get_x_shape()
+        ins = []
+        outs = []
+        for i in range(len(self.shape)):
+            ins.append(np.random.random((self.shape[i],)).astype(self.dtype))
+        for i in range(len(self.shape)):
+            out_reshape = [1] * len(self.shape)
+            out_reshape[i] = self.shape[i]
+            out_temp = np.reshape(ins[i], out_reshape)
+            outs.append(np.broadcast_to(out_temp, self.shape))
+        self.inputs = {
+            'X': [
+                ('x%d' % i, convert_float_to_uint16(ins[i]))
+                for i in range(len(ins))
+            ]
+        }
+        self.outputs = {
+            'Out': [
+                ('out%d' % i, convert_float_to_uint16(outs[i]))
+                for i in range(len(outs))
+            ]
+        }
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['x0'], ['out0', 'out1'], check_prim=True
+        )
+
+
 class TestMeshgridOp3(unittest.TestCase):
     def test_api(self):
         x = paddle.static.data(shape=[100], dtype='int32', name='x')
@@ -273,7 +319,7 @@ class TestMeshgridOp8(unittest.TestCase):
 class TestMeshGrid_ZeroDim(TestMeshgridOp):
-    def init_test_data(self):
+    def init_inputs_and_outputs(self):
         self.shape = self.get_x_shape()
         ins = []
         outs = []
@@ -285,7 +331,10 @@ class TestMeshGrid_ZeroDim(TestMeshgridOp):
             out_reshape[i] = self.shape[i]
             out_temp = np.reshape(ins[i], out_reshape)
             outs.append(np.broadcast_to(out_temp, self.shape))
-        return ins, outs
+        self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]}
+        self.outputs = {
+            'Out': [('out%d' % i, outs[i]) for i in range(len(outs))]
+        }

     def get_x_shape(self):
         return [1, 2, 3]
@@ -303,6 +303,7 @@ class TestMaxFP32Op(OpTest):
         self.python_api = paddle.max
         self.public_python_api = paddle.max
         self.init_dtype()
+        self.if_enable_cinn()
         if self.dtype == np.uint16:
             x = np.random.random((5, 6, 10)).astype(np.float32)
             self.inputs = {'X': convert_float_to_uint16(x)}
@@ -316,6 +317,9 @@ class TestMaxFP32Op(OpTest):
         else:
             self.outputs = {'Out': out}

+    def if_enable_cinn(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
@@ -346,6 +350,9 @@ class TestMaxBF16Op(TestMaxFP32Op):
     def init_dtype(self):
         self.dtype = np.uint16

+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
     def test_check_output(self):
         self.check_output_with_place(core.CUDAPlace(0))
@@ -487,8 +494,11 @@ class TestProdOp(OpTest):
         self.python_api = raw_reduce_prod
         self.public_python_api = raw_reduce_prod
         self.prim_op_type = "prim"
         self.init_data_type()
+        self.init_inputs_and_outputs()
+        self.if_enable_cinn()
+
+    def init_inputs_and_outputs(self):
         self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)}
         self.outputs = {'Out': self.inputs['X'].prod(axis=0)}
@@ -497,6 +507,9 @@ class TestProdOp(OpTest):
             "float32" if core.is_compiled_with_rocm() else "float64"
         )

+    def if_enable_cinn(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
@@ -504,6 +517,49 @@ class TestProdOp(OpTest):
         self.check_grad(['X'], 'Out', check_prim=True)


+@unittest.skipIf(
+    not paddle.is_compiled_with_cuda(), "FP16 test runs only on GPU"
+)
+class TestProdFP16OP(TestProdOp):
+    def init_data_type(self):
+        self.data_type = "float16"
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestProdBFP16OP(TestProdOp):
+    def init_data_type(self):
+        self.data_type = np.uint16
+
+    def init_inputs_and_outputs(self):
+        x = np.random.random((5, 6, 10)).astype("float32")
+        out = x.prod(axis=0)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
 class TestProdOpFp64(TestProdOp):
     def init_data_type(self):
         self.data_type = "float64"
@@ -522,11 +578,16 @@ class TestProdOp_ZeroDim(OpTest):
         # 0-D tensor doesn't support in cinn
         self.enable_cinn = False

+    def init_inputs_and_outputs(self):
+        self.inputs = {'X': np.random.random([]).astype("float64")}
+        self.outputs = {'Out': self.inputs['X'].prod()}
+        self.attrs = {'dim': [], 'reduce_all': True}
+
     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)


 class TestProd6DOp(OpTest):
@@ -536,6 +597,15 @@ class TestProd6DOp(OpTest):
         self.public_python_api = raw_reduce_prod
         self.prim_op_type = "prim"
         self.init_data_type()
+        self.init_inputs_and_outputs()
+        self.if_enable_cinn()
+
+    def init_data_type(self):
+        self.data_type = (
+            "float32" if core.is_compiled_with_rocm() else "float64"
+        )
+
+    def init_inputs_and_outputs(self):
         self.inputs = {
             'X': np.random.random((5, 6, 2, 3, 4, 2)).astype(self.data_type)
         }
@@ -544,10 +614,8 @@ class TestProd6DOp(OpTest):
             'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim']))
         }

-    def init_data_type(self):
-        self.data_type = (
-            "float32" if core.is_compiled_with_rocm() else "float64"
-        )
+    def if_enable_cinn(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -556,12 +624,59 @@ class TestProd6DOp(OpTest):
         self.check_grad(['X'], 'Out', check_prim=True)


+@unittest.skipIf(
+    not paddle.is_compiled_with_cuda(), "FP16 test runs only on GPU"
+)
+class TestProd6DFP16OP(TestProd6DOp):
+    def init_data_type(self):
+        self.data_type = "float16"
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestProd6DBFP16OP(TestProd6DOp):
+    def init_data_type(self):
+        self.data_type = np.uint16
+
+    def init_inputs_and_outputs(self):
+        x = np.random.random((5, 6, 2, 3, 4, 2)).astype("float32")
+        self.attrs = {'dim': [2, 3, 4]}
+        out = x.prod(axis=tuple(self.attrs['dim']))
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
+        )
+
+
 class TestProd8DOp(OpTest):
     def setUp(self):
         self.op_type = "reduce_prod"
         self.python_api = raw_reduce_prod
         self.public_python_api = raw_reduce_prod
         self.init_data_type()
+        self.init_inputs_and_outputs()
+
+    def init_inputs_and_outputs(self):
         self.inputs = {
             'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype(
                 self.data_type
@@ -584,6 +699,43 @@ class TestProd8DOp(OpTest):
         self.check_grad(['X'], 'Out')


+@unittest.skipIf(
+    not paddle.is_compiled_with_cuda(), "FP16 test runs only on GPU"
+)
+class TestProd8DFP16OP(TestProd8DOp):
+    def init_data_type(self):
+        self.data_type = "float16"
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(paddle.CUDAPlace(0), ['X'], 'Out')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestProd8DBFP16OP(TestProd8DOp):
+    def init_data_type(self):
+        self.data_type = np.uint16
+
+    def init_inputs_and_outputs(self):
+        x = np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float32")
+        self.attrs = {'dim': [2, 3, 4]}
+        out = x.prod(axis=tuple(self.attrs['dim']))
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(paddle.CUDAPlace(0), ['X'], 'Out')
+
+
 class TestAllOp(OpTest):
     def setUp(self):
         self.op_type = "reduce_all"
@@ -1563,7 +1563,7 @@ def meshgrid(*args, **kwargs):
             check_dtype(
                 input_.dtype,
                 'create data type',
-                ['float16', 'float32', 'float64', 'int32', 'int64'],
+                ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
                 'meshgrid',
             )
@@ -3275,7 +3275,15 @@ def expand_as(x, y, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['bool', 'float32', 'float64', 'int32', 'int64'],
+            [
+                'bool',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                'float16',
+                'uint16',
+            ],
             'expand_as',
         )
         check_type(y, 'y', Variable, 'expand_as')
@@ -3348,7 +3356,15 @@ def broadcast_to(x, shape, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+            [
+                'bool',
+                'uint16',
+                'float16',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+            ],
             'broadcast_to',
         )
         check_type(shape, 'shape', (list, tuple, Variable), 'broadcast_to')
@@ -3718,7 +3718,10 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
     """
     if dtype is not None:
         check_dtype(
-            dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'prod'
+            dtype,
+            'dtype',
+            ['float32', 'float64', 'int32', 'int64', "float16", "uint16"],
+            'prod',
         )
         if x.dtype != convert_np_dtype_to_dtype_(dtype):
             x = cast(x, dtype)

@@ -3731,7 +3734,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
     check_variable_and_dtype(
         x,
         'x/input',
-        ['float32', 'float64', 'int32', 'int64'],
+        ['float32', 'float64', 'int32', 'int64', "float16", "uint16"],
         'reduce_prod',
     )
     out = helper.create_variable_for_type_inference(