From e0007f319c53f8d214f1292c34d5c1ac86bc7fdd Mon Sep 17 00:00:00 2001
From: denglianbin <112610123+denglianbin@users.noreply.github.com>
Date: Fri, 17 Mar 2023 19:16:57 +0800
Subject: [PATCH] 【Hackathon No.46】Add float16 data type support for the
 Paddle gumbel_softmax operator (#50923)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* finish task

* fix some issues.

* fix errors.

* change unittest: zeroDim.
---
 .../kernels/gpu/gumbel_softmax_grad_kernel.cu |  1 +
 .../phi/kernels/gpu/gumbel_softmax_kernel.cu  | 32 ++++++++------
 .../tests/unittests/test_gumbel_softmax_op.py | 42 ++++++++++++++++++-
 python/paddle/nn/functional/activation.py    |  6 ++-
 4 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu
index 2d7b4bdeb69..119b30eadff 100644
--- a/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu
@@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(gumbel_softmax_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::GumbelSoftmaxGradKernel,
+                   phi::dtype::float16,
                    float,
                    double) {}
diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
index dcbf003281f..aee591894cc 100644
--- a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
+++ b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/gumbel_softmax_kernel.h"
-
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h"
@@ -116,17 +116,18 @@ struct OneHotGenerator {
   }
 };
 
-template <typename T>
+template <typename T, typename MPType>
 __global__ void AddGumbelNoiseCUDAKernel(const T* input_data,
                                          T* output_data,
-                                         T* noise,
+                                         MPType* noise,
                                          const float temperature,
                                          int64_t n) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   int step = blockDim.x * gridDim.x;
   for (int64_t i = index; i < n; i += step) {
-    T gumbel_noise = -log(-log(noise[i]));
-    output_data[i] = (gumbel_noise + input_data[i]) / temperature;
+    MPType gumbel_noise = -log(-log(noise[i]));
+    output_data[i] = static_cast<T>(
+        (gumbel_noise + static_cast<MPType>(input_data[i])) / temperature);
   }
 }
 
@@ -141,7 +142,8 @@ struct GumbleNoiseGenerator {
     DenseTensor random_tensor;
     int64_t size = size_to_axis * size_from_axis;
     random_tensor.Resize(make_ddim({size}));
-    T* random_data = ctx.template Alloc<T>(&random_tensor);
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+    MPType* random_data = ctx.template Alloc<MPType>(&random_tensor);
 
     // generate gumbel noise
     int device_id = ctx.GetPlace().GetDeviceId();
@@ -152,10 +154,11 @@
     uint64_t offset = seed_offset.second;
 
     thrust::counting_iterator<int64_t> index_sequence_begin(0);
-    thrust::transform(index_sequence_begin,
-                      index_sequence_begin + size,
-                      thrust::device_ptr<T>(random_data),
-                      UniformCUDAGenerator<T>(0.00001, 1, seed, size * offset));
+    thrust::transform(
+        index_sequence_begin,
+        index_sequence_begin + size,
+        thrust::device_ptr<MPType>(random_data),
+        UniformCUDAGenerator<MPType>(0.00001, 1, seed, size * offset));
 
     // add gumbel noise to X
     const int thread_size = 512;
@@ -168,5 +171,10 @@
 }  // namespace phi
 #endif
 
-PD_REGISTER_KERNEL(
-    gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
+PD_REGISTER_KERNEL(gumbel_softmax,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GumbelSoftmaxKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
diff --git a/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py b/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
index 30e4d7943ff..0a2725c1e3e 100644
--- a/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
@@ -53,10 +53,13 @@ class TestGumbelSoftmaxOp(OpTest):
 
 
 class TestGumbelSoftmax_ZeroDim(OpTest):
+    def init_attrs(self):
+        self.dtype = "float64"
+
     def setUp(self):
         self.op_type = "gumbel_softmax"
         self.python_api = F.gumbel_softmax
-        self.dtype = "float64"
+        self.init_attrs()
         x = np.random.uniform(0.1, 1, []).astype(self.dtype)
         out = np.array(1.0).astype(self.dtype)
 
@@ -103,6 +106,43 @@ class TestGumbelSoftmaxOp5(TestGumbelSoftmaxOp):
         self.dtype = "float64"
 
 
+class TestGumbelSoftmax_ZeroDim_FP16OP(TestGumbelSoftmax_ZeroDim):
+    def init_attrs(self):
+        self.dtype = np.float16
+
+
+class TestGumbelSoftmaxFP16OP2(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [20, 10]
+        self.attrs = {"hard": True, "axis": 0}
+        self.count_expected = 10
+        self.dtype = np.float16
+
+
+class TestGumbelSoftmaxFP16OP3(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [100]
+        self.attrs = {"hard": True, "axis": -1}
+        self.count_expected = 1
+        self.dtype = np.float16
+
+
+class TestGumbelSoftmaxFP16OP4(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [20, 10, 5]
+        self.attrs = {"hard": True, "axis": -1}
+        self.count_expected = 200
+        self.dtype = np.float16
+
+
+class TestGumbelSoftmaxFP16OP5(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [20, 10, 5]
+        self.attrs = {"hard": True, "axis": 1}
+        self.count_expected = 100
+        self.dtype = np.float16
+
+
 class TestGumbelSoftmaxOpSampleDistribution(OpTest):
     def softmax(self, x):
         x_row_max = x.max(axis=-1)
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 8ee01b5e58f..0ec80b3c5bb 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1664,7 +1664,7 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
     Parameters:
         x (Tensor): An N-D Tensor, the first N - 1 dimensions index into a batch
             of independent distributions and the last dimension represents
-            a vector of probabilities with datatype float32, float64.
+            a vector of probabilities with datatype float16, float32, float64.
         temperature (float, optional): non-negative scalar temperature.
             Default is 1.0.
         hard (bool, optional): if True, the returned samples will be discretized as
@@ -1705,7 +1705,9 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
     )
 
     helper = LayerHelper("gumbel_softmax", **locals())
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gumbel_softmax')
+    check_variable_and_dtype(
+        x, 'x', ['float16', 'float32', 'float64'], 'gumbel_softmax'
+    )
     out = helper.create_variable_for_type_inference(x.dtype)
     helper.append_op(
         type='gumbel_softmax',
-- 
GitLab
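
Usage note: a minimal sketch of the float16 path this patch enables, assuming a CUDA build of Paddle that includes the change; the tensor shape, values, and temperature below are illustrative only and do not come from the patch itself.

    import paddle
    import paddle.nn.functional as F

    paddle.set_device('gpu')                        # only the GPU kernels are registered for float16 here
    logits = paddle.rand([4, 6]).astype('float16')  # illustrative shape and values

    # Soft samples: differentiable, rows sum to 1 along the chosen axis.
    soft = F.gumbel_softmax(logits, temperature=0.5)
    # Hard samples: one-hot in the forward pass, soft gradients in the backward pass.
    hard = F.gumbel_softmax(logits, temperature=0.5, hard=True, axis=-1)
    print(soft.dtype, hard.sum(axis=-1))

Since only the GPU registrations gain phi::dtype::float16 in this patch, the float16 path is exercised on CUDA devices; float32 and float64 behave as before.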
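
For reference, a NumPy sketch (not Paddle code) of the mixed-precision trick used by AddGumbelNoiseCUDAKernel above: the uniform noise and the -log(-log(u)) transform are computed in float32 (the MPType for float16) and only the final result is cast back to float16. The helper name is ours, and the float32 softmax at the end is an assumption about the downstream softmax kernel, added only to make the sketch end-to-end.

    import numpy as np

    def gumbel_softmax_fp16_sketch(x_fp16, temperature=1.0, axis=-1):
        # Noise is drawn and transformed in float32, mirroring MPType in the kernel;
        # the 1e-5 lower bound matches UniformCUDAGenerator(0.00001, 1, ...).
        u = np.random.uniform(1e-5, 1.0, size=x_fp16.shape).astype(np.float32)
        gumbel = -np.log(-np.log(u))
        noisy = (gumbel + x_fp16.astype(np.float32)) / temperature
        y = noisy.astype(np.float16)  # cast back to T, as the kernel does
        # Assumed: softmax accumulated in float32 for numerical stability.
        z = y.astype(np.float32)
        z -= z.max(axis=axis, keepdims=True)
        e = np.exp(z)
        return (e / e.sum(axis=axis, keepdims=True)).astype(np.float16)

    x = np.random.rand(4, 6).astype(np.float16)
    print(gumbel_softmax_fp16_sketch(x, temperature=0.5).sum(axis=-1))

Doing the log-of-log transform in float32 avoids the overflow and precision loss that float16 would suffer for noise values near the ends of the (1e-5, 1) range, which is the motivation for threading MPType through the kernel.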