From 640cff0a5c72e2a5e525d48ea83cd48b3048bbe9 Mon Sep 17 00:00:00 2001
From: Difer <707065510@qq.com>
Date: Tue, 16 May 2023 19:31:07 +0800
Subject: [PATCH] 【Hackathon No57】add bf16 for mode (#53195)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add bf16 for mode

* remove random seed 666

* try to fix op_type error

* test for me

* try to fix op_type

* fix redundant code

* add fp16/bf16 for lastdim

* fix some errors

* simplify code

* fix shape error

* fix op_type error

* fix skipif for bf16

---
 paddle/phi/kernels/gpu/mode_grad_kernel.cu    |   6 +-
 paddle/phi/kernels/gpu/mode_kernel.cu         |  14 +-
 .../fluid/tests/unittests/test_mode_op.py     | 133 ++++++++++++++----
 3 files changed, 125 insertions(+), 28 deletions(-)

diff --git a/paddle/phi/kernels/gpu/mode_grad_kernel.cu b/paddle/phi/kernels/gpu/mode_grad_kernel.cu
index e297eb88d0c..e0fe5a3d0ab 100644
--- a/paddle/phi/kernels/gpu/mode_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/mode_grad_kernel.cu
@@ -15,6 +15,8 @@
 #include "paddle/phi/kernels/mode_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/mode.h"
@@ -89,4 +91,6 @@ PD_REGISTER_KERNEL(mode_grad,
                    float,
                    double,
                    int32_t,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/mode_kernel.cu b/paddle/phi/kernels/gpu/mode_kernel.cu
index c834d87aca9..ed598b2e75d 100644
--- a/paddle/phi/kernels/gpu/mode_kernel.cu
+++ b/paddle/phi/kernels/gpu/mode_kernel.cu
@@ -15,6 +15,8 @@
 #include "paddle/phi/kernels/mode_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/mode.h"
@@ -129,7 +131,15 @@ void ModeKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    mode, GPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {
+PD_REGISTER_KERNEL(mode,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ModeKernel,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }
diff --git a/python/paddle/fluid/tests/unittests/test_mode_op.py b/python/paddle/fluid/tests/unittests/test_mode_op.py
index eaf9c39631c..15c376f57ab 100644
--- a/python/paddle/fluid/tests/unittests/test_mode_op.py
+++ b/python/paddle/fluid/tests/unittests/test_mode_op.py
@@ -15,10 +15,15 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
 
 import paddle
 from paddle import fluid
+from paddle.fluid import core
 
 
 def _mode1D(a):
@@ -34,7 +39,7 @@ def _mode1D(a):
             mode = sorted_array[i]
             index = sorted_inds[i]
             max_freq = cur_freq
-            cur_freq = 0
+        cur_freq = 0
 
     return mode, index
 
@@ -59,51 +64,129 @@ def cal_mode(a, axis, keepdim=False):
 class TestModeOp(OpTest):
     def init_args(self):
         self.axis = 1
+        self.input_shape = (2, 64, 1)
+
+    def init_input_data(self):
+        self.input_data = np.random.rand(*self.input_shape).astype(self.dtype)
+        self.inputs = {'X': self.input_data}
+
+    def init_dtype(self):
+        self.dtype = np.float64
 
     def setUp(self):
         self.op_type = "mode"
         self.python_api = paddle.mode
-        self.dtype = np.float64
-        np.random.seed(666)
-        self.input_data = np.random.rand(2, 64, 1)
+        self.init_dtype()
         self.init_args()
-        self.inputs = {'X': self.input_data}
+        self.init_input_data()
         self.attrs = {'axis': self.axis}
         output, indices = cal_mode(self.input_data, axis=self.axis)
         self.outputs = {'Out': output, 'Indices': indices}
 
+    def init_numeric_grads(self):
+        if self.axis < 0:
+            axis = len(self.input_data.shape) + self.axis
+        else:
+            axis = self.axis
+        if self.dtype == np.float64:
+            dtype = np.float64
+        else:
+            dtype = np.float32
+        grad = np.zeros(self.input_data.shape).astype(dtype)
+        in_dims = list(range(grad.ndim))
+        if axis == len(self.input_data.shape) - 1:
+            a_view = grad
+        else:
+            a_view = np.transpose(
+                grad,
+                in_dims[:axis] + in_dims[axis + 1 :] + [axis],
+            )
+        idx = np.array(self.outputs['Indices']).flatten()
+        inds = np.ndindex(a_view.shape[:-1])
+        for i, ind in enumerate(inds):
+            a_view[ind][idx[i]] = 1 / np.prod(self.outputs['Indices'].shape)
+        if axis == len(self.input_data.shape) - 1:
+            grad = a_view
+        else:
+            grad = np.transpose(
+                a_view,
+                in_dims[:axis] + in_dims[-1:] + in_dims[axis:-1],
+            )
+        return grad
+
     def test_check_output(self):
         paddle.enable_static()
         self.check_output()
 
     def test_check_grad(self):
         paddle.enable_static()
-        self.check_grad({'X'}, 'Out')
+        grad = self.init_numeric_grads()
+        self.check_grad({'X'}, 'Out', user_defined_grads=[grad])
 
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestModeFP16Op(TestModeOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestModeBF16Op(TestModeOp):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def init_input_data(self):
+        self.input_data = np.random.rand(*self.input_shape).astype(np.float32)
+        self.input_data = convert_uint16_to_float(
+            convert_float_to_uint16(self.input_data)
+        )
+        self.inputs = {'X': convert_float_to_uint16(self.input_data)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        paddle.enable_static()
+        if core.is_bfloat16_supported(place):
+            self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        paddle.enable_static()
+        grad = self.init_numeric_grads()
+        if core.is_bfloat16_supported(place):
+            self.check_grad_with_place(
+                place, {'X'}, 'Out', user_defined_grads=[grad]
+            )
 
-class TestModeOpLastdim(OpTest):
+
+class TestModeOpLastdim(TestModeOp):
     def init_args(self):
         self.axis = -1
+        self.input_shape = (2, 1, 1, 2, 30)
 
-    def setUp(self):
-        self.op_type = "mode"
-        self.python_api = paddle.mode
-        self.dtype = np.float64
-        np.random.seed(666)
-        self.input_data = np.random.rand(2, 1, 1, 2, 30)
-        self.init_args()
-        self.inputs = {'X': self.input_data}
-        self.attrs = {'axis': self.axis}
-        output, indices = cal_mode(self.input_data, axis=self.axis)
-        self.outputs = {'Out': output, 'Indices': indices}
 
-    def test_check_output(self):
-        paddle.enable_static()
-        self.check_output()
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestModeFP16OpLastdim(TestModeFP16Op):
+    def init_args(self):
+        self.axis = -1
+        self.input_shape = (2, 1, 1, 2, 30)
 
-    def test_check_grad(self):
-        paddle.enable_static()
-        self.check_grad({'X'}, 'Out')
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestModeBF16OpLastdim(TestModeBF16Op):
+    def init_args(self):
+        self.axis = -1
+        self.input_shape = (2, 1, 1, 2, 30)
 
 
 class TestModeOpKernels(unittest.TestCase):
-- 
GitLab
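
A minimal usage sketch, not part of the patch itself, of how the newly
registered bfloat16 kernel can be exercised through the public API. It
assumes a CUDA build of Paddle containing this change, running on a GPU
that supports bfloat16; the exact tensor shape and axis are arbitrary:

    import paddle

    paddle.set_device('gpu')
    # Cast float32 data to bfloat16; paddle.mode accepts this dtype once
    # phi::dtype::bfloat16 is registered for the mode kernel above.
    x = paddle.rand([2, 64, 1], dtype='float32').astype('bfloat16')
    out, indices = paddle.mode(x, axis=1)
    print(out.dtype)      # paddle.bfloat16
    print(indices.dtype)  # paddle.int64: the Indices output is pinned to
                          # INT64 by SetDataType in the kernel registration

The test classes follow the same idea: TestModeBF16Op round-trips its
float32 input through convert_float_to_uint16 / convert_uint16_to_float
before feeding it to the op, so the reference mode and the numeric
gradient are computed on exactly the values the kernel sees after
bfloat16 precision loss.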