diff --git a/paddle/fluid/operators/multinomial_op.cc b/paddle/fluid/operators/multinomial_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..844164a0b8afd2e7d06062e3dc5762a0c06f0403 --- /dev/null +++ b/paddle/fluid/operators/multinomial_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/multinomial_op.h" + +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/common_infer_shape_functions.h" + +namespace paddle { +namespace operators { + +class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "A tensor contains probabilities of categories"); + AddOutput("Out", "The output tensor of multinomial op"); + AddAttr("num_samples", "number of the generated samples") + .SetDefault(1); + AddAttr("replacement", "can a category be sampled more than once") + .SetDefault(false); + AddComment(R"DOC( +This OP returns a Tensor filled with the sampled categoris according to Multinomial probabilities. + + Out ~ Multinomial(X) + +)DOC"); + } +}; + +class MultinomialOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial"); + + auto x_dim = ctx->GetInputDim("X"); + int64_t x_rank = x_dim.size(); + std::vector out_dims(x_rank); + for (int64_t i = 0; i < x_rank - 1; i++) { + out_dims[i] = x_dim[i]; + } + + int64_t num_samples = ctx->Attrs().Get("num_samples"); + out_dims[x_rank - 1] = num_samples; + + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + } +}; + +template +class MultinomialOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + const int64_t num_samples = ctx.Attr("num_samples"); + const bool replacement = ctx.Attr("replacement"); + + auto *in_data = x->data(); + auto *out_data = out->mutable_data(ctx.GetPlace()); + + auto in_dims = x->dims(); + int64_t in_rank = in_dims.size(); + const int64_t num_categories = in_dims[in_rank - 1]; + const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; + + MultinomialFunctor(out_data, in_data, num_samples, replacement, + num_categories, num_distributions); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR( + multinomial, ops::MultinomialOp, ops::MultinomialOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + multinomial, ops::MultinomialOpKernel, + ops::MultinomialOpKernel); diff --git a/paddle/fluid/operators/multinomial_op.h b/paddle/fluid/operators/multinomial_op.h new file mode 100644 index 0000000000000000000000000000000000000000..05c96ffb4b0ff44bab25818274f488d17c3c2a21 --- /dev/null +++ b/paddle/fluid/operators/multinomial_op.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +/** + * Samples a multinomial distribution given a probability input + */ + +template +void MultinomialFunctor(T* out_data, const T* in_data, + const int64_t num_samples, const bool replacement, + const int64_t num_categories, + const int64_t num_distributions) { + C = num_categories; + T cumulative_probs[C]; + + std::uniform_real_distribution dist(0, 1); + auto gen_ptr = framework::DefaultCPUGenerator(); + auto engine = gen_ptr->GetCPUEngine(); + + for (int64_t i = 0; i < num_distributions; i++) { + T probs_sum = 0; + T prob_value; + int64_t num_zeros = 0; + for (int64_t j = 0; j < num_categories; j++) { + prob_value = in_data[i * num_categories + j]; + PADDLE_ENFORCE_GE( + prob_value, 0.0, + platform::errors::OutOfRange( + "The input of multinomial distribution should be >= 0")); + PADDLE_ENFORCE_EQ((std::isinf(static_cast(prob_value)) || + std::isnan(static_cast(prob_value))), + false, platform::errors::OutOfRange( + "The input of multinomial distribution " + "shoud not be infinity or NaN")); + probs_sum += prob_value; + if (prob_value == 0) { + num_zeros += 1; + } + cumulative_probs[j] = probs_sum; + } + PADDLE_ENFORCE_GT(probs_sum, 0.0, platform::errors::OutOfRange( + "The sum of input should not be 0")); + PADDLE_ENFORCE_EQ( + (replacement || (num_categories - num_zeros >= num_samples)), true, + platform::errors::OutOfRange("When replacement is False, number of " + "samples should be less than non-zero " + "categories")); + + for (int64_t j = 0; j < num_categories; j++) { + cumulative_probs[j] /= probs_sum; + } + + for (int64_t s = 0; s < num_samples; s++) { + T uniform_rand = dist(*engine); + // use binary search to get the selected category sample id. + // let cumulative_probs[id-1] < uniform_rand < cumulative_probs[id]. + int64_t left = 0; + int64_t right = num_categories; + int64_t mid; + int64_t sample_id; + T temp_prob; + cumulative_probs[(num_categories - 1)] = 1; + + while (right > left) { + mid = left + (right - left) / 2; + temp_prob = cumulative_probs[mid]; + if (temp_prob < uniform_rand) { + left = mid + 1; + } else { + right = mid; + } + } + sample_id = left; + + out_data[i * num_samples + s] = sample_id; + + // if replacement is false, the selected category should be removed. + if (!replacement && s < num_samples - 1) { + T sample_prob; + T new_prob = 0; + T new_sum; + + if (sample_id != 0) { + new_prob = cumulative_probs[sample_id - 1]; + } + sample_prob = cumulative_probs[sample_id] - new_prob; + new_sum = 1.0 - sample_prob; + + for (int64_t j = 0; j < num_categories; j++) { + new_prob = cumulative_probs[j]; + if (j >= sample_id) { + new_prob -= sample_prob; + } + new_prob /= new_sum; + cumulative_probs[j] = new_prob; + } + } + } + } +} + +template +class MultinomialOpKernel; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py new file mode 100644 index 0000000000000000000000000000000000000000..0168ae6bc6d1b41e335c937fb68520c4577b4ed5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py @@ -0,0 +1,97 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +from op_test import OpTest +import numpy as np + + +class TestMultinomialOp(OpTest): + def setUp(self): + self.op_type = "multinomial" + self.init_data() + self.inputs = {"X": self.input_np} + + def init_data(self): + # input probability is a vector, and replacement is True + self.input_np = np.random.rand(4) + self.outputs = {"Out": np.zeros(100000).astype("int64")} + self.attrs = {"num_samples": 100000, "replacement": True} + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def sample_output(self, out): + # count numbers of different categories + sample_prob = np.unique(out, return_counts=True)[1].astype("float32") + sample_prob /= sample_prob.sum() + return sample_prob + + def verify_output(self, outs): + # normalize the input to get the probability + prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True) + sample_prob = self.sample_output(np.array(outs[0])) + self.assertTrue( + np.allclose( + sample_prob, prob, rtol=0, atol=0.01), + "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) + + +class TestMultinomialOp2(TestMultinomialOp): + def init_data(self): + # input probability is a matrix + self.input_np = np.random.rand(3, 4) + self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")} + self.attrs = {"num_samples": 100000, "replacement": True} + + def sample_output(self, out): + out_list = np.split(out, 3, axis=0) + count_array = [0] * 3 + for i in range(3): + count_array[i] = np.unique( + out_list[i], return_counts=True)[1].astype("float32") + sample_prob = np.stack(count_array, axis=0) + sample_prob /= sample_prob.sum(axis=-1, keepdims=True) + return sample_prob + + +class TestMultinomialOp3(TestMultinomialOp): + def init_data(self): + # replacement is False. number of samples must be less than number of categories. + self.input_np = np.random.rand(1000) + self.outputs = {"Out": np.zeros(100).astype("int64")} + self.attrs = {"num_samples": 100, "replacement": False} + + def verify_output(self, outs): + out = np.array(outs[0]) + unique_out = np.unique(out) + self.assertEqual( + len(unique_out), 100, + "replacement is False. categories can't be sampled repeatedly") + + +""" +class TestReplacementError(unittest.TestCase): + def init_data(self): + # replacement is False. if number of samples is larger than number of categories, raise error. + self.input_np = np.random.rand(4) + self.outputs = {"Out": np.zeros(10).astype("int64")} + self.attrs = {"num_samples": 10, "replacement": False} +""" + +if __name__ == "__main__": + unittest.main()