From bc91012f826c72c0620a49651c11de76b7afecd4 Mon Sep 17 00:00:00 2001 From: chenxujun Date: Tue, 18 Apr 2023 14:48:33 +0800 Subject: [PATCH] =?UTF-8?q?=20=E3=80=90Hackathon=20No.60=E3=80=91randperm,?= =?UTF-8?q?=20split,=20split=5Fwith=5Fnum=20=E7=AE=97=E5=AD=90FP16/BF16?= =?UTF-8?q?=E5=8D=95=E6=B5=8B=E5=AE=8C=E5=96=84=20(#52683)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add split, split_with_num tests * Add randperm tests * Fix code --- paddle/phi/kernels/gpu/randperm_kernel.cu | 5 +- .../fluid/tests/unittests/test_randperm_op.py | 60 ++++++++++++++++++- .../fluid/tests/unittests/test_split_op.py | 54 +++++++++++------ python/paddle/tensor/manipulation.py | 1 + 4 files changed, 97 insertions(+), 23 deletions(-) diff --git a/paddle/phi/kernels/gpu/randperm_kernel.cu b/paddle/phi/kernels/gpu/randperm_kernel.cu index 6ae0e6df07c..456d541451e 100644 --- a/paddle/phi/kernels/gpu/randperm_kernel.cu +++ b/paddle/phi/kernels/gpu/randperm_kernel.cu @@ -28,6 +28,7 @@ namespace cub = hipcub; #include "gflags/gflags.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" @@ -165,4 +166,6 @@ PD_REGISTER_KERNEL(randperm, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/python/paddle/fluid/tests/unittests/test_randperm_op.py b/python/paddle/fluid/tests/unittests/test_randperm_op.py index eaecf087f9f..5df2873b93e 100644 --- a/python/paddle/fluid/tests/unittests/test_randperm_op.py +++ b/python/paddle/fluid/tests/unittests/test_randperm_op.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import ( + OpTest, + convert_float_to_uint16, + convert_uint16_to_float, +) import paddle from paddle.fluid import core @@ -40,12 +44,21 @@ def error_msg(data_np): def convert_dtype(dtype_str): - dtype_str_list = ["int32", "int64", "float32", "float64"] + dtype_str_list = [ + "int32", + "int64", + "float16", + "float32", + "float64", + "uint16", + ] dtype_num_list = [ core.VarDesc.VarType.INT32, core.VarDesc.VarType.INT64, + core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64, + core.VarDesc.VarType.BF16, ] assert dtype_str in dtype_str_list, ( dtype_str + " should in " + str(dtype_str_list) @@ -62,9 +75,9 @@ class TestRandpermOp(OpTest): self.n = 200 self.dtype = "int64" + self.init_attrs() self.inputs = {} self.outputs = {"Out": np.zeros(self.n).astype(self.dtype)} - self.init_attrs() self.attrs = { "n": self.n, "dtype": convert_dtype(self.dtype), @@ -103,6 +116,47 @@ class TestRandpermOpFloat64(TestRandpermOp): self.dtype = "float64" +class TestRandpermFP16Op(TestRandpermOp): + def init_attrs(self): + self.dtype = "float16" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestRandpermBF16Op(OpTest): + def setUp(self): + self.op_type = "randperm" + self.python_api = paddle.randperm + self.n = 200 + + self.init_attrs() + self.inputs = {} + self.outputs = {"Out": np.zeros(self.n).astype(self.np_dtype)} + self.attrs = { + "n": self.n, + "dtype": convert_dtype(self.dtype), + } + + self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out']) + self.place = core.CUDAPlace(0) + + def init_attrs(self): + 
self.dtype = "uint16" + self.np_dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place_customized(self.verify_output, self.place) + + def verify_output(self, outs): + out_np = convert_uint16_to_float(np.array(outs[0])) + self.assertTrue( + check_randperm_out(self.n, out_np), msg=error_msg(out_np) + ) + + class TestRandpermOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index 3149ca82b3f..29446bafbbf 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -65,7 +65,7 @@ class TestSplitOp(OpTest): # test with attr(num) -class TestSplitOp_2(OpTest): +class TestSplitWithNumOp(OpTest): def setUp(self): self.python_api = paddle.split self.public_python_api = paddle.split @@ -74,18 +74,32 @@ class TestSplitOp_2(OpTest): self.prim_op_type = "prim" self.dtype = self.get_dtype() self.init_data() - self.inputs = {'X': self.x} self.attrs = { 'axis': self.axis, 'sections': self.sections, 'num': self.num, } - - out = np.split(self.x, self.indices_or_sections, self.axis) - self.outputs = {'Out': [('out%d' % i, out[i]) for i in range(len(out))]} + if self.dtype == np.uint16: + self.inputs = {'X': convert_float_to_uint16(self.x)} + out = np.split(self.x, self.indices_or_sections, self.axis) + self.outputs = { + 'Out': [ + ('out%d' % i, convert_float_to_uint16(out[i])) + for i in range(len(out)) + ] + } + else: + self.inputs = {'X': self.x} + out = np.split(self.x, self.indices_or_sections, self.axis) + self.outputs = { + 'Out': [('out%d' % i, out[i]) for i in range(len(out))] + } def init_data(self): - self.x = np.random.random((4, 5, 6)).astype(self.dtype) + if self.dtype == np.uint16: + self.x = np.random.random((4, 5, 6)).astype(np.float32) + else: + self.x = np.random.random((4, 5, 6)).astype(self.dtype) self.axis = 2 self.sections = [] self.num = 3 @@ -240,28 +254,28 @@ def create_test_fp16(parent): @unittest.skipIf( not core.is_compiled_with_cuda(), "core is not compiled with CUDA" ) - class TestSplitFp16(parent): + class TestSplitFP16Op(parent): def get_dtype(self): return np.float16 - def test_check_grad(self): - pass - - cls_name = "{}_{}".format(parent.__name__, "Fp16") - TestSplitFp16.__name__ = cls_name - globals()[cls_name] = TestSplitFp16 + cls_name = "{}_{}".format(parent.__name__, "FP16Op") + TestSplitFP16Op.__name__ = cls_name + globals()[cls_name] = TestSplitFP16Op create_test_fp16(TestSplitOp) +create_test_fp16(TestSplitWithNumOp) # ----------------Split Bf16---------------- def create_test_bf16(parent): @unittest.skipIf( - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", ) - class TestSplitBf16(parent): + class TestSplitBF16Op(parent): def get_dtype(self): return np.uint16 @@ -270,14 +284,16 @@ def create_test_bf16(parent): self.check_output_with_place(place) def test_check_grad(self): - pass + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'out2') - cls_name = "{}_{}".format(parent.__name__, "Bf16") - TestSplitBf16.__name__ = cls_name - globals()[cls_name] = TestSplitBf16 + cls_name = "{}_{}".format(parent.__name__, "BF16Op") + TestSplitBF16Op.__name__ = cls_name + globals()[cls_name] = TestSplitBF16Op create_test_bf16(TestSplitOp) 
+create_test_bf16(TestSplitWithNumOp) class TestSplitAPI(unittest.TestCase): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 7309f804013..6ce6587d44e 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1976,6 +1976,7 @@ def split(x, num_or_sections, axis=0, name=None): 'int32', 'int64', 'uint8', + 'uint16', 'int8', ], 'split', -- GitLab
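
Note on the bfloat16 representation the new tests rely on: OpTest carries bfloat16 tensors as numpy uint16 arrays holding the upper 16 bits of an IEEE-754 float32, and the convert_float_to_uint16 / convert_uint16_to_float helpers imported from eager_op_test translate between the two views. The sketch below is not part of the patch; it is a minimal, self-contained illustration of that convention using plain numpy, with hypothetical helper names (float_to_bf16_bits, bf16_bits_to_float) and simple truncation where the real helpers may round differently.

import numpy as np

def float_to_bf16_bits(x):
    # Assumed illustration, not Paddle's helper: bfloat16 is the upper half of
    # a float32 word, i.e. the sign bit, the full 8-bit exponent, and the top
    # 7 mantissa bits. Truncate by shifting the bit pattern right by 16.
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(bits):
    # Re-expand the stored 16-bit payload into float32 by shifting it back
    # into the high half of a 32-bit word and reinterpreting the bits.
    return (np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16).astype(np.uint32).view(np.float32)

if __name__ == "__main__":
    x = np.random.random((4, 5, 6)).astype(np.float32)  # same shape the split tests use
    roundtrip = bf16_bits_to_float(float_to_bf16_bits(x))
    # bfloat16 keeps only ~8 mantissa bits, so roughly 2-3 decimal digits survive.
    assert np.allclose(x, roundtrip, atol=1e-2)
    print("max abs round-trip error:", np.abs(x - roundtrip).max())

Because only about 8 mantissa bits survive this round trip, the BF16 tests above build their expected outputs through the same conversion (e.g. convert_float_to_uint16(out[i]) in TestSplitWithNumOp, and convert_uint16_to_float in TestRandpermBF16Op.verify_output) instead of comparing raw float32 values bit-for-bit.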