From 7cd2c13f1b58a7c225883ff8c58cf49a34442835 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 29 Sep 2020 10:45:35 -0500 Subject: [PATCH] add multinomial op (#27219) * add multinomial cpu kernel * fix C++ notype error * fix windows ci array len error * let array len be const * change array to vector * add cuda kernrl with num_distribution is 1, and not support replacement=False * add multinomial python api * support num_distribution different multinomial distributions * add multinomial python api unittest * change output dtype to int64 * fix coverage prob * optimize format * fix dtype of output error, should be int64_t --- paddle/fluid/operators/multinomial_op.cc | 103 ++++++++ paddle/fluid/operators/multinomial_op.cu | 245 ++++++++++++++++++ paddle/fluid/operators/multinomial_op.h | 127 +++++++++ python/paddle/__init__.py | 1 + .../tests/unittests/test_multinomial_op.py | 179 +++++++++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/random.py | 66 +++++ 7 files changed, 722 insertions(+) create mode 100644 paddle/fluid/operators/multinomial_op.cc create mode 100644 paddle/fluid/operators/multinomial_op.cu create mode 100644 paddle/fluid/operators/multinomial_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_multinomial_op.py diff --git a/paddle/fluid/operators/multinomial_op.cc b/paddle/fluid/operators/multinomial_op.cc new file mode 100644 index 00000000000..94c9fc2d974 --- /dev/null +++ b/paddle/fluid/operators/multinomial_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/multinomial_op.h" + +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/common_infer_shape_functions.h" + +namespace paddle { +namespace operators { + +class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "A tensor contains probabilities of categories"); + AddOutput("Out", "The output tensor of multinomial op"); + AddAttr("num_samples", "number of the generated samples") + .SetDefault(1); + AddAttr("replacement", "can a category be sampled more than once") + .SetDefault(false); + AddComment(R"DOC( +This OP returns a Tensor filled with the sampled categoris according to Multinomial probabilities. + + Out ~ Multinomial(X) + +)DOC"); + } +}; + +class MultinomialOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial"); + + auto x_dim = ctx->GetInputDim("X"); + int64_t x_rank = x_dim.size(); + std::vector out_dims(x_rank); + for (int64_t i = 0; i < x_rank - 1; i++) { + out_dims[i] = x_dim[i]; + } + + int64_t num_samples = ctx->Attrs().Get("num_samples"); + out_dims[x_rank - 1] = num_samples; + + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + } +}; + +template +class MultinomialOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + const int64_t num_samples = ctx.Attr("num_samples"); + const bool replacement = ctx.Attr("replacement"); + + auto *in_data = x->data(); + int64_t *out_data = out->mutable_data(ctx.GetPlace()); + + auto in_dims = x->dims(); + int64_t in_rank = in_dims.size(); + const int64_t num_categories = in_dims[in_rank - 1]; + const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; + + MultinomialFunctor(out_data, in_data, num_samples, replacement, + num_categories, num_distributions); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR( + multinomial, ops::MultinomialOp, ops::MultinomialOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + multinomial, ops::MultinomialOpKernel, + ops::MultinomialOpKernel); diff --git a/paddle/fluid/operators/multinomial_op.cu b/paddle/fluid/operators/multinomial_op.cu new file mode 100644 index 00000000000..2762f0ce9bd --- /dev/null +++ b/paddle/fluid/operators/multinomial_op.cu @@ -0,0 +1,245 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/multinomial_op.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { +namespace operators { + +template +__global__ void NormalizeProbability(T* norm_probs, const T* in_data, + T* sum_rows) { + int id = threadIdx.x + blockIdx.x * blockDim.x + + blockIdx.y * gridDim.x * blockDim.x; + norm_probs[id] = in_data[id] / sum_rows[blockIdx.y]; +} + +template +__global__ void GetCumulativeProbs(T* norm_probs_data, + int64_t num_distributions, + int64_t num_categories, + T* cumulative_probs) { + for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) { + thrust::inclusive_scan(thrust::device, + norm_probs_data + id * num_categories, + norm_probs_data + (id + 1) * num_categories, + cumulative_probs + id * num_categories); + } +} + +template +struct RandomGeneratorCudaFunctor { + unsigned int seed_; + __host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(0.0, 1.0); + rng.discard(n); + return dist(rng); + } +}; + +template +__device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data, + int num_categories, T rng_number) { + int left = 0; + int right = num_categories; + + while (right - left > 0) { + int mid = left + (right - left) / 2; + + T temp_prob = cumulative_probs[mid]; + if (temp_prob < rng_number) { + left = mid + 1; + } else { + right = mid; + } + } + + if (left == num_categories) { + left = num_categories - 1; + } + + while (left >= 1 && norm_probs_data[left] == 0) left--; + + return left; +} + +template +__global__ void sampleMultinomialWithReplacement( + T* rng_data, const int64_t num_samples, int64_t* out_data, + const int64_t num_distributions, const int64_t num_categories, + T* cumulative_probs, T* norm_probs_data) { + // use binary search to get the selected category sample id. + // let cumulative_probs[id-1] < rng_data < cumulative_probs[id]. + + int idx = threadIdx.x + blockIdx.x * blockDim.x + + blockIdx.y * gridDim.x * blockDim.x; + + // for every distribution + for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { + // for every sample + for (int sample = blockIdx.x * blockDim.x + threadIdx.x; + sample < num_samples; sample += blockDim.x * gridDim.x) { + T rng_number = rng_data[sample + dist * num_samples]; + + // Find the bucket that a uniform random number lies in + int selected_category = binarySearchFunctor( + cumulative_probs + dist * num_categories, + norm_probs_data + dist * num_categories, num_categories, rng_number); + + out_data[sample + dist * num_samples] = selected_category; + } + } +} + +template +class MultinomialOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + + const int64_t num_samples = ctx.Attr("num_samples"); + const bool replacement = ctx.Attr("replacement"); + + auto* in_data = x->data(); + int64_t* out_data = out->mutable_data(ctx.GetPlace()); + + auto in_dims = x->dims(); + int64_t in_rank = in_dims.size(); + const int64_t num_categories = in_dims[in_rank - 1]; + const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; + + // If replacement is False, it's not a replaceable sample. Every category + // can + // be used only once. So after every sample, probability of the distribution + // will change. The implementation can't be parallelizable. Thus, call CPU + // implementation ``MultinomialFunctor`` to sample the distribution. + if (!replacement) { + int64_t in_data_numel = x->numel(); + int64_t out_data_numel = out->numel(); + + T* cpu_in_data = new T[in_data_numel]; + int64_t* cpu_out_data = new int64_t[out_data_numel]; + + cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T), + cudaMemcpyDeviceToHost); + + MultinomialFunctor(cpu_out_data, cpu_in_data, num_samples, replacement, + num_categories, num_distributions); + cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t), + cudaMemcpyHostToDevice); + + delete[] cpu_in_data; + delete[] cpu_out_data; + return; + } + + // Sum of input may not be 1. To get probability in range [0, 1], calculate + // sum of each row of input, and then use the sum to normalize the input. + // sum_row_data: sum of each row + framework::Tensor sum_rows_tensor; + auto* sum_rows_data = + sum_rows_tensor.mutable_data({num_distributions}, ctx.GetPlace()); + + auto& place = *ctx.template device_context() + .eigen_device(); + + if (num_distributions == 1) { + auto eigen_input = framework::EigenVector::Flatten(*x); + auto eigen_sum_rows = framework::EigenVector::Flatten(sum_rows_tensor); + eigen_sum_rows.device(place) = + eigen_input.sum(Eigen::DSizes(1)) + .eval() + .reshape(Eigen::DSizes(sum_rows_tensor.dims()[0])); + } else { + auto eigen_input = framework::EigenMatrix::From(*x); + auto eigen_sum_rows = framework::EigenVector::Flatten(sum_rows_tensor); + eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes(1)); + } + + // Normalize row of each distribution to get the probability in range [0, + // 1]. + // norm_probs_data: probability of the distribution + framework::Tensor norm_probs_tensor; + auto* norm_probs_data = norm_probs_tensor.mutable_data( + {num_distributions, num_categories}, ctx.GetPlace()); + + // number of threads in a block is min(num_categories, 512) + dim3 block_norm(num_categories < 512 ? num_categories : 512); + dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions); + NormalizeProbability< + T><<>>( + norm_probs_data, in_data, sum_rows_data); + + // Get cumulative probability of each distribution. It's the same function + // of + // ``cumsum`` op. + framework::Tensor cumulative_probs_tensor; + auto* cumulative_probs = cumulative_probs_tensor.mutable_data( + {num_distributions, num_categories}, ctx.GetPlace()); + dim3 block_cumsum(1); + dim3 grid_cumsum(num_distributions); + GetCumulativeProbs<<>>( + norm_probs_data, num_distributions, num_categories, cumulative_probs); + + // Generate random number for each sample. + std::random_device rd; + auto seed = rd(); + + framework::Tensor rng_data_tensor; + auto* rng_data = rng_data_tensor.mutable_data( + {num_distributions, num_samples}, ctx.GetPlace()); + + thrust::counting_iterator index_sequence_begin(0); + platform::Transform trans; + auto* context = + static_cast(&ctx.device_context()); + trans(*context, index_sequence_begin, + index_sequence_begin + num_distributions * num_samples, rng_data, + RandomGeneratorCudaFunctor(seed)); + + // Sample the multinomial distributions. + dim3 block_sample(128); + dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions); + sampleMultinomialWithReplacement<<>>( + rng_data, num_samples, out_data, num_distributions, num_categories, + cumulative_probs, norm_probs_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + multinomial, ops::MultinomialOpKernel, + ops::MultinomialOpKernel); diff --git a/paddle/fluid/operators/multinomial_op.h b/paddle/fluid/operators/multinomial_op.h new file mode 100644 index 00000000000..420d2cd11e3 --- /dev/null +++ b/paddle/fluid/operators/multinomial_op.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +/** + * Samples a multinomial distribution given a probability input + */ + +template +void MultinomialFunctor(int64_t* out_data, const T* in_data, + const int64_t num_samples, const bool replacement, + const int64_t num_categories, + const int64_t num_distributions) { + std::vector cumulative_probs(num_categories); + + std::uniform_real_distribution dist(0, 1); + auto gen_ptr = framework::DefaultCPUGenerator(); + auto engine = gen_ptr->GetCPUEngine(); + + for (int64_t i = 0; i < num_distributions; i++) { + T probs_sum = 0; + T prob_value; + int64_t num_zeros = 0; + for (int64_t j = 0; j < num_categories; j++) { + prob_value = in_data[i * num_categories + j]; + PADDLE_ENFORCE_GE( + prob_value, 0.0, + platform::errors::OutOfRange( + "The input of multinomial distribution should be >= 0")); + PADDLE_ENFORCE_EQ((std::isinf(static_cast(prob_value)) || + std::isnan(static_cast(prob_value))), + false, platform::errors::OutOfRange( + "The input of multinomial distribution " + "shoud not be infinity or NaN")); + probs_sum += prob_value; + if (prob_value == 0) { + num_zeros += 1; + } + cumulative_probs[j] = probs_sum; + } + PADDLE_ENFORCE_GT(probs_sum, 0.0, platform::errors::OutOfRange( + "The sum of input should not be 0")); + PADDLE_ENFORCE_EQ( + (replacement || (num_categories - num_zeros >= num_samples)), true, + platform::errors::OutOfRange("When replacement is False, number of " + "samples should be less than non-zero " + "categories")); + + for (int64_t j = 0; j < num_categories; j++) { + cumulative_probs[j] /= probs_sum; + } + + for (int64_t s = 0; s < num_samples; s++) { + T uniform_rand = dist(*engine); + // use binary search to get the selected category sample id. + // let cumulative_probs[id-1] < uniform_rand < cumulative_probs[id]. + int64_t left = 0; + int64_t right = num_categories; + int64_t mid; + int64_t sample_id; + T temp_prob; + cumulative_probs[(num_categories - 1)] = 1; + + while (right > left) { + mid = left + (right - left) / 2; + temp_prob = cumulative_probs[mid]; + if (temp_prob < uniform_rand) { + left = mid + 1; + } else { + right = mid; + } + } + sample_id = left; + + out_data[i * num_samples + s] = sample_id; + + // if replacement is false, the selected category should be removed. + if (!replacement && s < num_samples - 1) { + T sample_prob; + T new_prob = 0; + T new_sum; + + if (sample_id != 0) { + new_prob = cumulative_probs[sample_id - 1]; + } + sample_prob = cumulative_probs[sample_id] - new_prob; + new_sum = 1.0 - sample_prob; + + for (int64_t j = 0; j < num_categories; j++) { + new_prob = cumulative_probs[j]; + if (j >= sample_id) { + new_prob -= sample_prob; + } + new_prob /= new_sum; + cumulative_probs[j] = new_prob; + } + } + } + } +} + +template +class MultinomialOpKernel; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 93b9a71ed7d..84713d513fb 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -201,6 +201,7 @@ from .tensor.math import isfinite #DEFINE_ALIAS from .tensor.math import isinf #DEFINE_ALIAS from .tensor.math import isnan #DEFINE_ALIAS from .tensor.math import prod #DEFINE_ALIAS +from .tensor.random import multinomial #DEFINE_ALIAS from .tensor.random import standard_normal from .tensor.random import normal from .tensor.random import uniform #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py new file mode 100644 index 00000000000..7cca7738efd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py @@ -0,0 +1,179 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import paddle.fluid as fluid +from op_test import OpTest +import numpy as np + + +class TestMultinomialOp(OpTest): + def setUp(self): + self.op_type = "multinomial" + self.init_data() + self.inputs = {"X": self.input_np} + + def init_data(self): + # input probability is a vector, and replacement is True + self.input_np = np.random.rand(4) + self.outputs = {"Out": np.zeros(100000).astype("int64")} + self.attrs = {"num_samples": 100000, "replacement": True} + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def sample_output(self, out): + # count numbers of different categories + sample_prob = np.unique(out, return_counts=True)[1].astype("float32") + sample_prob /= sample_prob.sum() + return sample_prob + + def verify_output(self, outs): + # normalize the input to get the probability + prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True) + sample_prob = self.sample_output(np.array(outs[0])) + self.assertTrue( + np.allclose( + sample_prob, prob, rtol=0, atol=0.01), + "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) + + +class TestMultinomialOp2(TestMultinomialOp): + def init_data(self): + # input probability is a matrix + self.input_np = np.random.rand(3, 4) + self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")} + self.attrs = {"num_samples": 100000, "replacement": True} + + def sample_output(self, out): + out_list = np.split(out, 3, axis=0) + count_array = [0] * 3 + for i in range(3): + count_array[i] = np.unique( + out_list[i], return_counts=True)[1].astype("float32") + sample_prob = np.stack(count_array, axis=0) + sample_prob /= sample_prob.sum(axis=-1, keepdims=True) + return sample_prob + + +class TestMultinomialOp3(TestMultinomialOp): + def init_data(self): + # replacement is False. number of samples must be less than number of categories. + self.input_np = np.random.rand(1000) + self.outputs = {"Out": np.zeros(100).astype("int64")} + self.attrs = {"num_samples": 100, "replacement": False} + + def verify_output(self, outs): + out = np.array(outs[0]) + unique_out = np.unique(out) + self.assertEqual( + len(unique_out), 100, + "replacement is False. categories can't be sampled repeatedly") + + +class TestMultinomialApi(unittest.TestCase): + def test_dygraph(self): + # input probability is a vector, and replacement is True + paddle.disable_static() + x = paddle.rand([4]) + out = paddle.multinomial(x, num_samples=100000, replacement=True) + x_numpy = x.numpy() + paddle.enable_static() + + sample_prob = np.unique( + out.numpy(), return_counts=True)[1].astype("float32") + sample_prob /= sample_prob.sum() + + prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True) + self.assertTrue( + np.allclose( + sample_prob, prob, rtol=0, atol=0.01), + "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) + + def test_dygraph2(self): + # input probability is a matrix, and replacement is True + paddle.disable_static() + x = paddle.rand([3, 4]) + out = paddle.multinomial(x, num_samples=100000, replacement=True) + x_numpy = x.numpy() + + out_list = np.split(out.numpy(), 3, axis=0) + count_array = [0] * 3 + for i in range(3): + count_array[i] = np.unique( + out_list[i], return_counts=True)[1].astype("float32") + sample_prob = np.stack(count_array, axis=0) + sample_prob /= sample_prob.sum(axis=-1, keepdims=True) + + prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True) + self.assertTrue( + np.allclose( + sample_prob, prob, rtol=0, atol=0.01), + "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) + paddle.enable_static() + + def test_dygraph3(self): + # replacement is False. number of samples must be less than number of categories. + paddle.disable_static() + x = paddle.rand([1000]) + out = paddle.multinomial(x, num_samples=100, replacement=False) + x_numpy = x.numpy() + + unique_out = np.unique(out.numpy()) + self.assertEqual( + len(unique_out), 100, + "replacement is False. categories can't be sampled repeatedly") + paddle.enable_static() + + def test_static(self): + paddle.enable_static() + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + x = fluid.data('x', shape=[4], dtype='float32') + out = paddle.multinomial(x, num_samples=100000, replacement=True) + + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + + exe.run(startup_program) + x_np = np.random.rand(4).astype('float32') + out = exe.run(train_program, feed={'x': x_np}, fetch_list=[out]) + + sample_prob = np.unique(out, return_counts=True)[1].astype("float32") + sample_prob /= sample_prob.sum() + + prob = x_np / x_np.sum(axis=-1, keepdims=True) + self.assertTrue( + np.allclose( + sample_prob, prob, rtol=0, atol=0.01), + "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) + + +class TestMultinomialAlias(unittest.TestCase): + def test_alias(self): + paddle.disable_static() + x = paddle.rand([4]) + paddle.multinomial(x, num_samples=10, replacement=True) + paddle.tensor.multinomial(x, num_samples=10, replacement=True) + paddle.tensor.random.multinomial(x, num_samples=10, replacement=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b6bab16c968..940bd1a4674 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -166,6 +166,7 @@ from .math import isfinite #DEFINE_ALIAS from .math import isinf #DEFINE_ALIAS from .math import isnan #DEFINE_ALIAS from .math import prod #DEFINE_ALIAS +from .random import multinomial #DEFINE_ALIAS from .random import standard_normal from .random import normal from .random import uniform #DEFINE_ALIAS diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 9ffd81995ed..a46946cea86 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -23,6 +23,7 @@ import paddle __all__ = [ 'bernoulli', + 'multinomial', 'standard_normal', 'normal', 'uniform', @@ -85,6 +86,71 @@ def bernoulli(x, name=None): return out +def multinomial(x, num_samples=1, replacement=False, name=None): + """ + This OP returns a Tensor filled with random values sampled from a Multinomical + distribution. The input ``x`` is a tensor with probabilities for generating the + random number. Each element in ``x`` should be larger or equal to 0, but not all + 0. ``replacement`` indicates whether it is a replaceable sample. If ``replacement`` + is True, a category can be sampled more than once. + + Args: + x(Tensor): A tensor with probabilities for generating the random number. The data type + should be float32, float64. + num_samples(int, optional): Number of samples, default is 1. + replacement(bool, optional): Whether it is a replaceable sample, default is False. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + Returns: + Tensor: A Tensor filled with sampled category index after ``num_samples`` times samples. + + Examples: + .. code-block:: python + + import paddle + + paddle.disable_static() + + x = paddle.rand([2,4]) + print(x.numpy()) + # [[0.7713825 0.4055941 0.433339 0.70706886] + # [0.9223313 0.8519825 0.04574518 0.16560672]] + + out1 = paddle.multinomial(x, num_samples=5, replacement=True) + print(out1.numpy()) + # [[3 3 1 1 0] + # [0 0 0 0 1]] + + # out2 = paddle.multinomial(x, num_samples=5) + # OutOfRangeError: When replacement is False, number of samples + # should be less than non-zero categories + + out3 = paddle.multinomial(x, num_samples=3) + print(out3.numpy()) + # [[0 2 3] + # [0 1 3]] + + """ + + if in_dygraph_mode(): + return core.ops.multinomial(x, 'num_samples', num_samples, + 'replacement', replacement) + + check_variable_and_dtype(x, "x", ["float32", "float64"], "multinomial") + + helper = LayerHelper("multinomial", **locals()) + out = helper.create_variable_for_type_inference( + dtype=convert_np_dtype_to_dtype_('int64')) + helper.append_op( + type='multinomial', + inputs={"X": x}, + outputs={'Out': out}, + attrs={'num_samples': num_samples, + 'replacement': replacement}) + return out + + def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None): """ This OP returns a Tensor filled with random values sampled from a Gaussian -- GitLab