From 706503d0ece627851ef1ce9fdef66bc439b567b5 Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Thu, 18 May 2023 19:27:46 +0800 Subject: [PATCH] =?UTF-8?q?[AMP=20OP&Test]support=20prod=E3=80=81meshgrid?= =?UTF-8?q?=E3=80=81expand=5Fas=20bf16=20dtype=20(#53865)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add meshgrid,expand_as, prod and grad bf16 kernel * fix bf16 for optest * modify code style * fix amp test --- .../phi/kernels/gpu/expand_as_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/expand_as_kernel.cu | 3 +- .../kernels/gpu/meshgrid_grad_kernel.cu.cc | 3 +- paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc | 3 +- paddle/phi/kernels/gpu/prod_grad_kernel.cu | 4 +- paddle/phi/kernels/kps/prod_kernel.cu | 12 +- paddle/phi/kernels/prod_kernel.cc | 4 +- .../fluid/tests/unittests/eager_op_test.py | 61 +++--- .../tests/unittests/test_expand_as_v2_op.py | 206 ++++++++++++++++-- .../fluid/tests/unittests/test_meshgrid_op.py | 69 +++++- .../fluid/tests/unittests/test_reduce_op.py | 164 +++++++++++++- python/paddle/tensor/creation.py | 2 +- python/paddle/tensor/manipulation.py | 20 +- python/paddle/tensor/math.py | 7 +- 14 files changed, 474 insertions(+), 87 deletions(-) diff --git a/paddle/phi/kernels/gpu/expand_as_grad_kernel.cu b/paddle/phi/kernels/gpu/expand_as_grad_kernel.cu index f7f0ed63324..635dbf7101c 100644 --- a/paddle/phi/kernels/gpu/expand_as_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_as_grad_kernel.cu @@ -60,4 +60,5 @@ PD_REGISTER_KERNEL(expand_as_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/expand_as_kernel.cu b/paddle/phi/kernels/gpu/expand_as_kernel.cu index de47024f295..603e43482b9 100644 --- a/paddle/phi/kernels/gpu/expand_as_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_as_kernel.cu @@ -87,4 +87,5 @@ PD_REGISTER_KERNEL(expand_as, double, int, int64_t, - bool) {} + bool, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc index 17f74cd3743..2dd9e7dc6ce 100644 --- a/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc @@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(meshgrid_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc index 73120c13916..5a1c74f4193 100644 --- a/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc @@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(meshgrid, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/prod_grad_kernel.cu b/paddle/phi/kernels/gpu/prod_grad_kernel.cu index 301cc46b0b7..8a407386508 100644 --- a/paddle/phi/kernels/gpu/prod_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/prod_grad_kernel.cu @@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(prod_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/kps/prod_kernel.cu b/paddle/phi/kernels/kps/prod_kernel.cu index a584af357cf..40f5cdcf642 100644 --- a/paddle/phi/kernels/kps/prod_kernel.cu +++ b/paddle/phi/kernels/kps/prod_kernel.cu @@ -35,6 +35,14 @@ void ProdKernel(const Context& dev_ctx, #ifdef PADDLE_WITH_XPU_KP PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {} #else -PD_REGISTER_KERNEL( - prod, KPS, 
ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(prod, + KPS, + ALL_LAYOUT, + phi::ProdKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} #endif diff --git a/paddle/phi/kernels/prod_kernel.cc b/paddle/phi/kernels/prod_kernel.cc index 12b55e12030..ea3faaebd95 100644 --- a/paddle/phi/kernels/prod_kernel.cc +++ b/paddle/phi/kernels/prod_kernel.cc @@ -48,7 +48,9 @@ PD_REGISTER_KERNEL(prod_infer, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} #endif #if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU) diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index 63a14a66efa..ca16e71a8ea 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -2009,12 +2009,6 @@ class OpTest(unittest.TestCase): return True return super()._is_skip_name(name) - if check_prim: - prim_checker = PrimForwardChecker(self, place) - prim_checker.check() - # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 - self.__class__.check_prim = True - self.__class__.op_type = self.op_type # set some flags by the combination of arguments. if self.is_float16_op(): self.dtype = np.float16 @@ -2058,6 +2052,14 @@ class OpTest(unittest.TestCase): raise AssertionError( "no_check_set of op %s must be set to None." % self.op_type ) + + if check_prim: + prim_checker = PrimForwardChecker(self, place) + prim_checker.check() + # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 + self.__class__.check_prim = True + self.__class__.op_type = self.op_type + static_checker = StaticChecker(self, self.outputs) static_checker.check() outs, fetch_list = static_checker.outputs, static_checker.fetch_list @@ -2404,6 +2406,7 @@ class OpTest(unittest.TestCase): core._set_prim_all_enabled(False) core.set_prim_eager_enabled(False) if check_prim: + self._check_grad_helper() prim_grad_checker = PrimGradChecker( self, place, @@ -2415,7 +2418,6 @@ class OpTest(unittest.TestCase): prim_grad_checker.check() # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 self.__class__.check_prim = True - self._check_grad_helper() if only_check_prim: return self.scope = core.Scope() @@ -2657,33 +2659,24 @@ class OpTest(unittest.TestCase): outputs = dygraph_outputs if self.dtype == np.uint16: - cast_inputs = self._find_var_in_dygraph( - outputs, output_names[0] - ) - if isinstance(cast_inputs, paddle.Tensor): - cast_outputs = paddle.cast( - cast_inputs, core.VarDesc.VarType.FP32 - ) - elif isinstance(cast_inputs, list): - cast_outputs = [] - for cast_input in cast_inputs: - if isinstance(cast_input, paddle.Tensor): - cast_outputs.append( - paddle.cast( - cast_input, core.VarDesc.VarType.FP32 - ) - ) - else: - raise TypeError( - "Unsupported test data type %s." - % type(cast_input) - ) - else: - raise TypeError( - "Unsupported test data type %s." 
% type(cast_inputs) - ) - outputs = {output_names[0]: cast_outputs} + cast_inputs = [] + for output_name in output_names: + cast_input = self._find_var_in_dygraph(outputs, output_name) + cast_inputs = cast_inputs + cast_input + cast_outputs = [] + for cast_input in cast_inputs: + if isinstance(cast_input, paddle.Tensor): + cast_outputs.append( + paddle.cast(cast_input, core.VarDesc.VarType.FP32) + ) + else: + raise TypeError( + "Unsupported test data type %s." % type(cast_input) + ) + outputs = {} + for i in range(len(output_names)): + outputs.update({output_names[i]: [cast_outputs[i]]}) outputs_valid = {} for output_name in output_names: outputs_valid[output_name] = self._find_var_in_dygraph( @@ -2791,7 +2784,7 @@ class OpTest(unittest.TestCase): if user_defined_grad_outputs is None: if self.dtype == np.uint16: cast_inputs = list(map(block.var, output_names)) - if self.op_type == "broadcast_tensors": + if self.op_type in ["broadcast_tensors", "meshgrid"]: output_names = self.cast_bf16_output(block, cast_inputs) else: cast_outputs = block.create_var( diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py index 2e69b1c3091..ee472130f64 100755 --- a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py @@ -15,10 +15,11 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle from paddle import fluid +from paddle.fluid import core class TestExpandAsBasic(OpTest): @@ -27,14 +28,25 @@ class TestExpandAsBasic(OpTest): self.prim_op_type = "comp" self.python_api = paddle.expand_as self.public_python_api = paddle.expand_as - x = np.random.rand(100).astype("float64") - target_tensor = np.random.rand(2, 100).astype("float64") + self.init_dtype() + self.init_inputs_and_outputs() + self.if_enable_cinn() + + def init_dtype(self): + self.dtype = np.float64 + + def init_inputs_and_outputs(self): + x = np.random.rand(100).astype(self.dtype) + target_tensor = np.random.rand(2, 100).astype(self.dtype) self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [2, 1] output = np.tile(self.inputs['X'], bcast_dims) self.outputs = {'Out': output} + def if_enable_cinn(self): + pass + def test_check_output(self): self.check_output(check_prim=True) @@ -42,14 +54,43 @@ class TestExpandAsBasic(OpTest): self.check_grad(['X'], 'Out', check_prim=True) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestExpandAsBasicBFP16OP(TestExpandAsBasic): + def init_dtype(self): + self.dtype = np.uint16 + + def init_inputs_and_outputs(self): + x = np.random.rand(100).astype(np.float32) + target_tensor = np.random.rand(2, 100).astype(np.float32) + self.inputs = { + 'X': convert_float_to_uint16(x), + "Y": convert_float_to_uint16(target_tensor), + } + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [2, 1] + output = np.tile(x, bcast_dims) + self.outputs = {'Out': convert_float_to_uint16(output)} + + def if_enable_cinn(self): + self.enable_cinn = False + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + class 
TestExpandAsOpRank2(TestExpandAsBasic): - def setUp(self): - self.op_type = "expand_as_v2" - self.prim_op_type = "comp" - self.python_api = paddle.expand_as - self.public_python_api = paddle.expand_as - x = np.random.rand(10, 12).astype("float64") - target_tensor = np.random.rand(10, 12).astype("float64") + def init_inputs_and_outputs(self): + x = np.random.rand(10, 12).astype(self.dtype) + target_tensor = np.random.rand(10, 12).astype(self.dtype) self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [1, 1] @@ -57,14 +98,43 @@ class TestExpandAsOpRank2(TestExpandAsBasic): self.outputs = {'Out': output} +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestExpandAsOpRank2BFP16OP(TestExpandAsOpRank2): + def init_dtype(self): + self.dtype = np.uint16 + + def init_inputs_and_outputs(self): + x = np.random.rand(10, 12).astype(np.float32) + target_tensor = np.random.rand(10, 12).astype(np.float32) + self.inputs = { + 'X': convert_float_to_uint16(x), + "Y": convert_float_to_uint16(target_tensor), + } + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [1, 1] + output = np.tile(x, bcast_dims) + self.outputs = {'Out': convert_float_to_uint16(output)} + + def if_enable_cinn(self): + self.enable_cinn = False + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + class TestExpandAsOpRank3(TestExpandAsBasic): - def setUp(self): - self.op_type = "expand_as_v2" - self.prim_op_type = "comp" - self.python_api = paddle.expand_as - self.public_python_api = paddle.expand_as - x = np.random.rand(2, 3, 20).astype("float64") - target_tensor = np.random.rand(2, 3, 20).astype("float64") + def init_inputs_and_outputs(self): + x = np.random.rand(2, 3, 20).astype(self.dtype) + target_tensor = np.random.rand(2, 3, 20).astype(self.dtype) self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [1, 1, 1] @@ -72,14 +142,43 @@ class TestExpandAsOpRank3(TestExpandAsBasic): self.outputs = {'Out': output} +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestExpandAsOpRank3BFP16OP(TestExpandAsOpRank3): + def init_dtype(self): + self.dtype = np.uint16 + + def init_inputs_and_outputs(self): + x = np.random.rand(2, 3, 20).astype(np.float32) + target_tensor = np.random.rand(2, 3, 20).astype(np.float32) + self.inputs = { + 'X': convert_float_to_uint16(x), + "Y": convert_float_to_uint16(target_tensor), + } + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [1, 1, 1] + output = np.tile(x, bcast_dims) + self.outputs = {'Out': convert_float_to_uint16(output)} + + def if_enable_cinn(self): + self.enable_cinn = False + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + class TestExpandAsOpRank4(TestExpandAsBasic): - def setUp(self): - self.op_type = "expand_as_v2" - self.prim_op_type = "comp" - self.python_api = paddle.expand_as - self.public_python_api = paddle.expand_as - x = np.random.rand(1, 1, 7, 16).astype("float64") - 
target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") + def init_inputs_and_outputs(self): + x = np.random.rand(1, 1, 7, 16).astype(self.dtype) + target_tensor = np.random.rand(4, 6, 7, 16).astype(self.dtype) self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [4, 6, 1, 1] @@ -87,6 +186,39 @@ class TestExpandAsOpRank4(TestExpandAsBasic): self.outputs = {'Out': output} +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestExpandAsOpRank4BFP16OP(TestExpandAsOpRank3): + def init_dtype(self): + self.dtype = np.uint16 + + def init_inputs_and_outputs(self): + x = np.random.rand(1, 1, 7, 16).astype(np.float32) + target_tensor = np.random.rand(4, 6, 7, 16).astype(np.float32) + self.inputs = { + 'X': convert_float_to_uint16(x), + "Y": convert_float_to_uint16(target_tensor), + } + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [4, 6, 1, 1] + output = np.tile(x, bcast_dims) + self.outputs = {'Out': convert_float_to_uint16(output)} + + def if_enable_cinn(self): + self.enable_cinn = False + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + class TestExpandAsOpRank5(TestExpandAsBasic): no_need_check_grad = True @@ -107,6 +239,32 @@ class TestExpandAsOpRank5(TestExpandAsBasic): pass +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestExpandAsOpRank5BFP16OP(TestExpandAsOpRank5): + def setUp(self): + self.op_type = "expand_as_v2" + self.prim_op_type = "comp" + self.python_api = paddle.expand_as + self.public_python_api = paddle.expand_as + x = np.random.rand(1, 1, 7, 16).astype("int64") + target_tensor = np.random.rand(4, 6, 7, 16).astype("float32") + self.inputs = {'X': x, "Y": convert_float_to_uint16(target_tensor)} + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [4, 6, 1, 1] + output = np.tile(x, bcast_dims) + self.outputs = {'Out': convert_float_to_uint16(output)} + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + pass + + class TestExpandAsV2Error(unittest.TestCase): def test_errors(self): with fluid.program_guard(fluid.Program(), fluid.Program()): diff --git a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py index 0039d4ee422..1928108bfae 100644 --- a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py +++ b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py @@ -15,10 +15,11 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle from paddle import fluid +from paddle.fluid import core def meshgrid_wrapper(x): @@ -32,11 +33,7 @@ class TestMeshgridOp(OpTest): self.python_api = meshgrid_wrapper self.public_python_api = meshgrid_wrapper self.dtype = self.get_dtype() - ins, outs = self.init_test_data() - self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]} - self.outputs = { - 'Out': [('out%d' % i, outs[i]) for i in range(len(outs))] - } + self.init_inputs_and_outputs() self.python_out_sig = ['out0', 'out1'] self.if_enable_cinn() @@ 
-49,7 +46,7 @@ class TestMeshgridOp(OpTest): def test_check_grad(self): self.check_grad(['x0'], ['out0', 'out1'], check_prim=True) - def init_test_data(self): + def init_inputs_and_outputs(self): self.shape = self.get_x_shape() ins = [] outs = [] @@ -61,7 +58,10 @@ class TestMeshgridOp(OpTest): out_reshape[i] = self.shape[i] out_temp = np.reshape(ins[i], out_reshape) outs.append(np.broadcast_to(out_temp, self.shape)) - return ins, outs + self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]} + self.outputs = { + 'Out': [('out%d' % i, outs[i]) for i in range(len(outs))] + } def get_x_shape(self): return [100, 200] @@ -84,6 +84,52 @@ class TestMeshgridOp2Fp16(TestMeshgridOp): return np.float16 +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestMeshgridOpBFP16OP(TestMeshgridOp): + def init_data_type(self): + self.data_type = np.uint16 + + def init_inputs_and_outputs(self): + self.shape = self.get_x_shape() + ins = [] + outs = [] + for i in range(len(self.shape)): + ins.append(np.random.random((self.shape[i],)).astype(self.dtype)) + + for i in range(len(self.shape)): + out_reshape = [1] * len(self.shape) + out_reshape[i] = self.shape[i] + out_temp = np.reshape(ins[i], out_reshape) + outs.append(np.broadcast_to(out_temp, self.shape)) + self.inputs = { + 'X': [ + ('x%d' % i, convert_float_to_uint16(ins[i])) + for i in range(len(ins)) + ] + } + self.outputs = { + 'Out': [ + ('out%d' % i, convert_float_to_uint16(outs[i])) + for i in range(len(outs)) + ] + } + + def if_enable_cinn(self): + self.enable_cinn = False + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['x0'], ['out0', 'out1'], check_prim=True + ) + + class TestMeshgridOp3(unittest.TestCase): def test_api(self): x = paddle.static.data(shape=[100], dtype='int32', name='x') @@ -273,7 +319,7 @@ class TestMeshgridOp8(unittest.TestCase): class TestMeshGrid_ZeroDim(TestMeshgridOp): - def init_test_data(self): + def init_inputs_and_outputs(self): self.shape = self.get_x_shape() ins = [] outs = [] @@ -285,7 +331,10 @@ class TestMeshGrid_ZeroDim(TestMeshgridOp): out_reshape[i] = self.shape[i] out_temp = np.reshape(ins[i], out_reshape) outs.append(np.broadcast_to(out_temp, self.shape)) - return ins, outs + self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]} + self.outputs = { + 'Out': [('out%d' % i, outs[i]) for i in range(len(outs))] + } def get_x_shape(self): return [1, 2, 3] diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 9c22e9c043a..f567f427532 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -303,6 +303,7 @@ class TestMaxFP32Op(OpTest): self.python_api = paddle.max self.public_python_api = paddle.max self.init_dtype() + self.if_enable_cinn() if self.dtype == np.uint16: x = np.random.random((5, 6, 10)).astype(np.float32) self.inputs = {'X': convert_float_to_uint16(x)} @@ -316,6 +317,9 @@ class TestMaxFP32Op(OpTest): else: self.outputs = {'Out': out} + def if_enable_cinn(self): + pass + def test_check_output(self): self.check_output() @@ -346,6 +350,9 @@ class TestMaxBF16Op(TestMaxFP32Op): def init_dtype(self): self.dtype = np.uint16 + def if_enable_cinn(self): + self.enable_cinn = False + def 
test_check_output(self): self.check_output_with_place(core.CUDAPlace(0)) @@ -487,8 +494,11 @@ class TestProdOp(OpTest): self.python_api = raw_reduce_prod self.public_python_api = raw_reduce_prod self.prim_op_type = "prim" - self.init_data_type() + self.init_inputs_and_outputs() + self.if_enable_cinn() + + def init_inputs_and_outputs(self): self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)} self.outputs = {'Out': self.inputs['X'].prod(axis=0)} @@ -497,6 +507,9 @@ class TestProdOp(OpTest): "float32" if core.is_compiled_with_rocm() else "float64" ) + def if_enable_cinn(self): + pass + def test_check_output(self): self.check_output() @@ -504,6 +517,49 @@ class TestProdOp(OpTest): self.check_grad(['X'], 'Out', check_prim=True) +@unittest.skipIf( + not paddle.is_compiled_with_cuda(), "FP16 test runs only on GPU" +) +class TestProdFP16OP(TestProdOp): + def init_data_type(self): + self.data_type = "float16" + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestProdBFP16OP(TestProdOp): + def init_data_type(self): + self.data_type = np.uint16 + + def init_inputs_and_outputs(self): + x = np.random.random((5, 6, 10)).astype("float32") + out = x.prod(axis=0) + self.inputs = {'X': convert_float_to_uint16(x)} + self.outputs = {'Out': convert_float_to_uint16(out)} + + def if_enable_cinn(self): + self.enable_cinn = False + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + class TestProdOpFp64(TestProdOp): def init_data_type(self): self.data_type = "float64" @@ -522,11 +578,16 @@ class TestProdOp_ZeroDim(OpTest): # 0-D tensor doesn't support in cinn self.enable_cinn = False + def init_inputs_and_outputs(self): + self.inputs = {'X': np.random.random([]).astype("float64")} + self.outputs = {'Out': self.inputs['X'].prod()} + self.attrs = {'dim': [], 'reduce_all': True} + def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_prim=True) class TestProd6DOp(OpTest): @@ -536,6 +597,15 @@ class TestProd6DOp(OpTest): self.public_python_api = raw_reduce_prod self.prim_op_type = "prim" self.init_data_type() + self.init_inputs_and_outputs() + self.if_enable_cinn() + + def init_data_type(self): + self.data_type = ( + "float32" if core.is_compiled_with_rocm() else "float64" + ) + + def init_inputs_and_outputs(self): self.inputs = { 'X': np.random.random((5, 6, 2, 3, 4, 2)).astype(self.data_type) } @@ -544,10 +614,8 @@ class TestProd6DOp(OpTest): 'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim'])) } - def init_data_type(self): - self.data_type = ( - "float32" if core.is_compiled_with_rocm() else "float64" - ) + def if_enable_cinn(self): + pass def test_check_output(self): self.check_output() @@ -556,12 +624,59 @@ class TestProd6DOp(OpTest): self.check_grad(['X'], 'Out', check_prim=True) +@unittest.skipIf( + not paddle.is_compiled_with_cuda(), "FP16 test runs only on GPU" +) +class TestProd6DFP16OP(TestProd6DOp): + def init_data_type(self): + self.data_type = "float16" + + def test_check_output(self): + 
self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestProd6DBFP16OP(TestProd6DOp): + def init_data_type(self): + self.data_type = np.uint16 + + def init_inputs_and_outputs(self): + x = np.random.random((5, 6, 2, 3, 4, 2)).astype("float32") + self.attrs = {'dim': [2, 3, 4]} + out = x.prod(axis=tuple(self.attrs['dim'])) + self.inputs = {'X': convert_float_to_uint16(x)} + self.outputs = {'Out': convert_float_to_uint16(out)} + + def if_enable_cinn(self): + self.enable_cinn = False + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + ) + + class TestProd8DOp(OpTest): def setUp(self): self.op_type = "reduce_prod" self.python_api = raw_reduce_prod self.public_python_api = raw_reduce_prod self.init_data_type() + self.init_inputs_and_outputs() + + def init_inputs_and_outputs(self): self.inputs = { 'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype( self.data_type @@ -584,6 +699,43 @@ class TestProd8DOp(OpTest): self.check_grad(['X'], 'Out') +@unittest.skipIf( + not paddle.is_compiled_with_cuda(), "FP16 test runs only on GPU" +) +class TestProd8DFP16OP(TestProd8DOp): + def init_data_type(self): + self.data_type = "float16" + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place(paddle.CUDAPlace(0), ['X'], 'Out') + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestProd8DBFP16OP(TestProd8DOp): + def init_data_type(self): + self.data_type = np.uint16 + + def init_inputs_and_outputs(self): + x = np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float32") + self.attrs = {'dim': [2, 3, 4]} + out = x.prod(axis=tuple(self.attrs['dim'])) + self.inputs = {'X': convert_float_to_uint16(x)} + self.outputs = {'Out': convert_float_to_uint16(out)} + + def test_check_output(self): + self.check_output_with_place(place=paddle.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place(paddle.CUDAPlace(0), ['X'], 'Out') + + class TestAllOp(OpTest): def setUp(self): self.op_type = "reduce_all" diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index b6c9cce21ff..bf6c603a7ec 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1563,7 +1563,7 @@ def meshgrid(*args, **kwargs): check_dtype( input_.dtype, 'create data type', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'], 'meshgrid', ) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 6b94843b65c..1371db3ea93 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3275,7 +3275,15 @@ def expand_as(x, y, name=None): check_variable_and_dtype( x, 'x', - ['bool', 'float32', 'float64', 'int32', 'int64'], + [ + 'bool', + 'float32', + 'float64', + 'int32', + 'int64', + 'float16', + 'uint16', + ], 'expand_as', ) check_type(y, 'y', Variable, 'expand_as') @@ 
-3348,7 +3356,15 @@ def broadcast_to(x, shape, name=None): check_variable_and_dtype( x, 'x', - ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + [ + 'bool', + 'uint16', + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + ], 'broadcast_to', ) check_type(shape, 'shape', (list, tuple, Variable), 'broadcast_to') diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index f3635a69651..d36e0ff3064 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3718,7 +3718,10 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): """ if dtype is not None: check_dtype( - dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'prod' + dtype, + 'dtype', + ['float32', 'float64', 'int32', 'int64', "float16", "uint16"], + 'prod', ) if x.dtype != convert_np_dtype_to_dtype_(dtype): x = cast(x, dtype) @@ -3731,7 +3734,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): check_variable_and_dtype( x, 'x/input', - ['float32', 'float64', 'int32', 'int64'], + ['float32', 'float64', 'int32', 'int64', "float16", "uint16"], 'reduce_prod', ) out = helper.create_variable_for_type_inference( -- GitLab
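
A minimal usage sketch (illustrative only, not part of the applied diff): with the bfloat16 kernels registered above, prod, meshgrid and expand_as can be run on bfloat16 tensors on a CUDA build that supports bf16. The shapes, variable names and the astype('bfloat16') conversion below are assumptions made for the example, not code taken from this patch.

import paddle

paddle.set_device('gpu')  # assumes a CUDA build with bfloat16 support

# Assumed sample data; any bfloat16 tensors would do.
x = paddle.rand([4, 6, 7, 16]).astype('bfloat16')
y = paddle.rand([1, 1, 7, 16]).astype('bfloat16')
x.stop_gradient = False

out_prod = paddle.prod(x, axis=0)    # dispatches to the newly registered bf16 prod kernel
out_expand = paddle.expand_as(y, x)  # bf16 expand_as: y broadcast to x's shape
mx, my = paddle.meshgrid(
    paddle.rand([3]).astype('bfloat16'),
    paddle.rand([5]).astype('bfloat16'),
)                                    # bf16 meshgrid

paddle.sum(out_prod).backward()      # exercises the new bf16 prod_grad kernel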