From 7788b65e1583b614781366b0fb792b92a640141b Mon Sep 17 00:00:00 2001
From: yunyaoXYY <109218879+yunyaoXYY@users.noreply.github.com>
Date: Thu, 30 Mar 2023 11:53:44 +0800
Subject: [PATCH] [AMP OP&Test] Register FP16 for multinomial. (#52107)

* add FP16 for multinomial

* fix input data

* update code

* fix FP16

* fix code
---
 paddle/phi/kernels/gpu/multinomial_kernel.cu  | 58 ++++++++---------
 .../tests/unittests/test_multinomial_op.py    | 62 +++++++++++++++++++
 python/paddle/tensor/random.py                |  4 +-
 3 files changed, 95 insertions(+), 29 deletions(-)

diff --git a/paddle/phi/kernels/gpu/multinomial_kernel.cu b/paddle/phi/kernels/gpu/multinomial_kernel.cu
index 5666024cfaa..afc4e9d30a1 100644
--- a/paddle/phi/kernels/gpu/multinomial_kernel.cu
+++ b/paddle/phi/kernels/gpu/multinomial_kernel.cu
@@ -27,6 +27,7 @@ namespace cub = hipcub;
 #endif
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -41,25 +42,25 @@ namespace cub = hipcub;
 
 namespace phi {
 
-template <typename T>
-__global__ void NormalizeProbability(T* norm_probs,
+template <typename T, typename MT>
+__global__ void NormalizeProbability(MT* norm_probs,
                                      const T* in_data,
-                                     T* sum_rows,
+                                     MT* sum_rows,
                                      int64_t num_distributions,
                                      int64_t num_categories) {
   int id = threadIdx.x + blockIdx.x * blockDim.x +
            blockIdx.y * gridDim.x * blockDim.x;
   if (id < num_distributions * num_categories) {
     PADDLE_ENFORCE(
-        in_data[id] >= 0.0,
+        static_cast<MT>(in_data[id]) >= 0.0,
         "The input of multinomial distribution should be >= 0, but got %f.",
-        in_data[id]);
+        static_cast<MT>(in_data[id]));
     int64_t row_id = id / num_categories;
     PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
                    "The sum of one multinomial distribution probability should "
                    "be > 0, but got %f.",
                    sum_rows[row_id]);
-    norm_probs[id] = in_data[id] / sum_rows[row_id];
+    norm_probs[id] = static_cast<MT>(in_data[id]) / sum_rows[row_id];
   }
 }
 
@@ -131,6 +132,8 @@ void MultinomialKernel(const Context& dev_ctx,
                        const Scalar& num_samples,
                        bool replacement,
                        DenseTensor* out) {
+  using MT = typename kps::details::MPTypeTrait<T>::Type;
+
   auto int_num_samples = num_samples.to<int>();
   auto* in_data = x.data<T>();
   int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
@@ -138,7 +141,6 @@ void MultinomialKernel(const Context& dev_ctx,
   int64_t dim_size = in_dims.size();
   const int64_t num_categories = in_dims[dim_size - 1];
   const int64_t num_distributions = dim_size > 1 ? in_dims[dim_size - 2] : 1;
-
   // If replacement is False, it's not a replaceable sample. Every category
   // can be used only once.
   if (!replacement) {
@@ -153,11 +155,11 @@ void MultinomialKernel(const Context& dev_ctx,
       for (size_t j = 0; j < num_categories; ++j) {
         T weight = cpu_in_data[i * num_categories + j];
         PADDLE_ENFORCE_GE(
-            weight,
+            static_cast<MT>(weight),
             0,
             errors::InvalidArgument(
                 "Each element of multinomial'input must >= 0, but got %f.",
-                weight));
+                static_cast<MT>(weight)));
         if (weight == static_cast<T>(0)) {
           zero_num++;
         }
@@ -174,8 +176,8 @@ void MultinomialKernel(const Context& dev_ctx,
 
     // Refer to [gumbel softmax algorithm]
     DenseTensor rand = EmptyLike<T, Context>(dev_ctx, x);
     T* rand_data = rand.data<T>();
-    funcs::uniform_distribution<T> dist;
-    funcs::exponential_transform<T> trans(1.0);
+    funcs::uniform_distribution<MT> dist;
+    funcs::exponential_transform<MT> trans(1.0);
     funcs::distribution_and_transform<T>(dev_ctx, &rand, dist, trans);
 
     funcs::ForRange<Context> for_range(dev_ctx, x.numel());
@@ -200,61 +202,60 @@ void MultinomialKernel(const Context& dev_ctx,
   // sum_row_data: sum of each row
   DenseTensor sum_rows_tensor;
   sum_rows_tensor.Resize({num_distributions});
-  auto* sum_rows_data = dev_ctx.template Alloc<T>(&sum_rows_tensor);
-
+  auto* sum_rows_data = dev_ctx.template Alloc<MT>(&sum_rows_tensor);
   auto& place = *dev_ctx.eigen_device();
   if (num_distributions == 1) {
     auto eigen_input = EigenVector<T>::Flatten(x);
-    auto eigen_sum_rows = EigenVector<T>::Flatten(sum_rows_tensor);
+    auto eigen_sum_rows = EigenVector<MT>::Flatten(sum_rows_tensor);
     eigen_sum_rows.device(place) =
         eigen_input.sum(Eigen::DSizes<int, 1>(1))
+            .template cast<MT>()
            .eval()
-            .reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
+            .template cast<MT>()
+            .reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]))
+            .template cast<MT>();
   } else {
     auto eigen_input = EigenMatrix<T>::From(x);
-    auto eigen_sum_rows = EigenVector<T>::Flatten(sum_rows_tensor);
-    eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
+    auto eigen_sum_rows = EigenVector<MT>::Flatten(sum_rows_tensor);
+    eigen_sum_rows.device(place) =
+        eigen_input.sum(Eigen::DSizes<int, 1>(1)).template cast<MT>();
   }
-
   // Normalize row of each distribution to get the probability in range [0,
   // 1].
   // norm_probs_data: probability of the distribution
   DenseTensor norm_probs_tensor;
   norm_probs_tensor.Resize({num_distributions, num_categories});
-  auto* norm_probs_data = dev_ctx.template Alloc<T>(&norm_probs_tensor);
-
+  auto* norm_probs_data = dev_ctx.template Alloc<MT>(&norm_probs_tensor);
   // number of threads in a block is min(num_categories, 512)
   int block_size = num_categories < 512 ? num_categories : 512;
   dim3 block_norm(block_size);
   dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
-  NormalizeProbability<T>
+
+  NormalizeProbability<T, MT>
       <<<grid_norm, block_norm, 0, dev_ctx.stream()>>>(norm_probs_data,
                                                        in_data,
                                                        sum_rows_data,
                                                        num_distributions,
                                                        num_categories);
-
   // Get cumulative probability of each distribution. It's the same function
   // of ``cumsum`` op.
   DenseTensor cumulative_probs_tensor;
   cumulative_probs_tensor.Resize({num_distributions, num_categories});
   auto* cumulative_probs_data =
-      dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
-
+      dev_ctx.template Alloc<MT>(&cumulative_probs_tensor);
   // 'phi::funcs::InclusiveScan' has higher accuracy than
   // 'thrust::inclusive_scan'
-  funcs::InclusiveScan<T, std::plus<T>>(
+  funcs::InclusiveScan<MT, std::plus<MT>>(
       /*in*/ norm_probs_data,
       /*out*/ cumulative_probs_data,
      /*outer_dim*/ static_cast<size_t>(num_distributions),
      /*mid_dim*/ static_cast<size_t>(num_categories),
      /*inner_dim*/ static_cast<size_t>(1),
      /*init*/ static_cast<T>(0),
-      std::plus<T>(),
+      std::plus<MT>(),
      /*reverse=*/false,
      dev_ctx);
-
   // Sample the multinomial distributions.
   dim3 block(128);
   int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
@@ -269,7 +270,7 @@ void MultinomialKernel(const Context& dev_ctx,
   uint64_t increment = curand4_loop_times * 4;
   auto seed_offset = gen_cuda->IncrementOffset(increment);
 
-  sampleMultinomialWithReplacement<T>
+  sampleMultinomialWithReplacement<T, MT>
       <<<grid, block, 0, dev_ctx.stream()>>>(int_num_samples,
                                              out_data,
                                              num_distributions,
@@ -286,6 +287,7 @@ PD_REGISTER_KERNEL(multinomial,  // cuda_only
                    GPU,
                    ALL_LAYOUT,
                    phi::MultinomialKernel,
+                   phi::dtype::float16,
                    float,
                    double) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py
index 7fbfd6f5ec6..2fc10c88ba0 100644
--- a/python/paddle/fluid/tests/unittests/test_multinomial_op.py
+++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py
@@ -104,6 +104,68 @@ class TestMultinomialOp3(TestMultinomialOp):
         )
 
 
+# FP16 OP
+class TestMultinomialFP16Op(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "multinomial"
+        self.dtype = np.float16
+        self.init_data()
+        self.inputs = {"X": self.input_np}
+
+    def init_data(self):
+        # input probability is a vector, and replacement is True
+        self.input_np = np.random.rand(4).astype(self.dtype)
+        self.outputs = {"Out": np.zeros(100000).astype("int64")}
+        self.attrs = {"num_samples": 100000, "replacement": True}
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def sample_output(self, out):
+        return sample_output_one_dimension(out, 4)
+
+    def verify_output(self, outs):
+        # normalize the input to get the probability
+        prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True)
+        sample_prob = self.sample_output(np.array(outs[0]))
+        np.testing.assert_allclose(
+            sample_prob,
+            prob,
+            rtol=0,
+            atol=0.01,
+            err_msg='sample_prob: ' + str(sample_prob) + '\nprob: ' + str(prob),
+        )
+
+
+class TestMultinomialFP16Op2(TestMultinomialFP16Op):
+    def init_data(self):
+        # input probability is a matrix
+        self.input_np = np.random.rand(3, 4).astype(self.dtype)
+        self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")}
+        self.attrs = {"num_samples": 100000, "replacement": True}
+
+    def sample_output(self, out):
+        return sample_output_two_dimension(out, [3, 4])
+
+
+class TestMultinomialFP16Op3(TestMultinomialFP16Op):
+    def init_data(self):
+        # replacement is False. number of samples must be less than number of categories.
+        self.input_np = np.random.rand(1000).astype(self.dtype)
+        self.outputs = {"Out": np.zeros(100).astype("int64")}
+        self.attrs = {"num_samples": 100, "replacement": False}
+
+    def verify_output(self, outs):
+        out = np.array(outs[0])
+        unique_out = np.unique(out)
+        self.assertEqual(
+            len(unique_out),
+            100,
+            "replacement is False. categories can't be sampled repeatedly",
+        )
+
+
 class TestMultinomialApi(unittest.TestCase):
     def test_dygraph(self):
         # input probability is a vector, and replacement is True
diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py
index 33f2b18b9be..4339a25f565 100644
--- a/python/paddle/tensor/random.py
+++ b/python/paddle/tensor/random.py
@@ -187,7 +187,9 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
     if in_dygraph_mode():
         return _C_ops.multinomial(x, num_samples, replacement)
     else:
-        check_variable_and_dtype(x, "x", ["float32", "float64"], "multinomial")
+        check_variable_and_dtype(
+            x, "x", ["uint16", "float16", "float32", "float64"], "multinomial"
+        )
 
         helper = LayerHelper("multinomial", **locals())
         out = helper.create_variable_for_type_inference(
-- 
GitLab
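
A quick usage sketch (not part of the patch): the change above lets paddle.multinomial accept float16 input on GPU, with the kernel normalizing and accumulating probabilities in float32 (the MPTypeTrait "MT" type) before sampling. The snippet below is a minimal illustration under the assumption of a CUDA build of Paddle that already contains this change; the probability values and sample count are arbitrary.

    import paddle

    paddle.set_device("gpu")

    # Unnormalized float16 probabilities; the kernel divides each row by its
    # sum, so they do not need to add up to 1.
    probs = paddle.to_tensor([0.1, 0.2, 0.3, 0.4], dtype="float16")

    # Draw 10000 category indices with replacement (int64 output).
    samples = paddle.multinomial(probs, num_samples=10000, replacement=True)

    # Empirical frequencies should land near [0.1, 0.2, 0.3, 0.4], which is
    # essentially what TestMultinomialFP16Op checks with atol=0.01.
    freq = paddle.bincount(samples, minlength=4).astype("float32") / 10000.0
    print(freq.numpy())

With replacement=False, num_samples must not exceed the number of nonzero categories, which is the case TestMultinomialFP16Op3 exercises for float16 input.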