未验证 提交 7cd2c13f 编写于 作者: P pangyoki 提交者: GitHub

add multinomial op (#27219)

* add multinomial cpu kernel

* fix C++ notype error

* fix windows ci array len error

* let array len be const

* change array to vector

* add cuda kernrl with num_distribution is 1, and not support replacement=False

* add multinomial python api

* support num_distribution different multinomial distributions

* add multinomial python api unittest

* change output dtype to int64

* fix coverage prob

* optimize format

* fix dtype of output error, should be int64_t
上级 d2369dd9
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/multinomial_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
namespace paddle {
namespace operators {
class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "A tensor contains probabilities of categories");
AddOutput("Out", "The output tensor of multinomial op");
AddAttr<int>("num_samples", "number of the generated samples")
.SetDefault(1);
AddAttr<bool>("replacement", "can a category be sampled more than once")
.SetDefault(false);
AddComment(R"DOC(
This OP returns a Tensor filled with the sampled categoris according to Multinomial probabilities.
Out ~ Multinomial(X)
)DOC");
}
};
class MultinomialOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial");
auto x_dim = ctx->GetInputDim("X");
int64_t x_rank = x_dim.size();
std::vector<int64_t> out_dims(x_rank);
for (int64_t i = 0; i < x_rank - 1; i++) {
out_dims[i] = x_dim[i];
}
int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
out_dims[x_rank - 1] = num_samples;
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
}
};
template <typename T>
class MultinomialOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
const int64_t num_samples = ctx.Attr<int>("num_samples");
const bool replacement = ctx.Attr<bool>("replacement");
auto *in_data = x->data<T>();
int64_t *out_data = out->mutable_data<int64_t>(ctx.GetPlace());
auto in_dims = x->dims();
int64_t in_rank = in_dims.size();
const int64_t num_categories = in_dims[in_rank - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
MultinomialFunctor<T>(out_data, in_data, num_samples, replacement,
num_categories, num_distributions);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
multinomial, ops::MultinomialOp, ops::MultinomialOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
multinomial, ops::MultinomialOpKernel<plat::CPUDeviceContext, float>,
ops::MultinomialOpKernel<plat::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
}
template <typename T>
__global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_distributions,
int64_t num_categories,
T* cumulative_probs) {
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs + id * num_categories);
}
}
template <typename T>
struct RandomGeneratorCudaFunctor {
unsigned int seed_;
__host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
rng.discard(n);
return dist(rng);
}
};
template <typename T>
__device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data,
int num_categories, T rng_number) {
int left = 0;
int right = num_categories;
while (right - left > 0) {
int mid = left + (right - left) / 2;
T temp_prob = cumulative_probs[mid];
if (temp_prob < rng_number) {
left = mid + 1;
} else {
right = mid;
}
}
if (left == num_categories) {
left = num_categories - 1;
}
while (left >= 1 && norm_probs_data[left] == 0) left--;
return left;
}
template <typename T>
__global__ void sampleMultinomialWithReplacement(
T* rng_data, const int64_t num_samples, int64_t* out_data,
const int64_t num_distributions, const int64_t num_categories,
T* cumulative_probs, T* norm_probs_data) {
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
int idx = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
// for every distribution
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
// for every sample
for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
sample < num_samples; sample += blockDim.x * gridDim.x) {
T rng_number = rng_data[sample + dist * num_samples];
// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);
out_data[sample + dist * num_samples] = selected_category;
}
}
}
template <typename T>
class MultinomialOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
const int64_t num_samples = ctx.Attr<int>("num_samples");
const bool replacement = ctx.Attr<bool>("replacement");
auto* in_data = x->data<T>();
int64_t* out_data = out->mutable_data<int64_t>(ctx.GetPlace());
auto in_dims = x->dims();
int64_t in_rank = in_dims.size();
const int64_t num_categories = in_dims[in_rank - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
// If replacement is False, it's not a replaceable sample. Every category
// can
// be used only once. So after every sample, probability of the distribution
// will change. The implementation can't be parallelizable. Thus, call CPU
// implementation ``MultinomialFunctor`` to sample the distribution.
if (!replacement) {
int64_t in_data_numel = x->numel();
int64_t out_data_numel = out->numel();
T* cpu_in_data = new T[in_data_numel];
int64_t* cpu_out_data = new int64_t[out_data_numel];
cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
cudaMemcpyDeviceToHost);
MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples, replacement,
num_categories, num_distributions);
cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t),
cudaMemcpyHostToDevice);
delete[] cpu_in_data;
delete[] cpu_out_data;
return;
}
// Sum of input may not be 1. To get probability in range [0, 1], calculate
// sum of each row of input, and then use the sum to normalize the input.
// sum_row_data: sum of each row
framework::Tensor sum_rows_tensor;
auto* sum_rows_data =
sum_rows_tensor.mutable_data<T>({num_distributions}, ctx.GetPlace());
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
if (num_distributions == 1) {
auto eigen_input = framework::EigenVector<T>::Flatten(*x);
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) =
eigen_input.sum(Eigen::DSizes<int, 1>(1))
.eval()
.reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
} else {
auto eigen_input = framework::EigenMatrix<T>::From(*x);
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
}
// Normalize row of each distribution to get the probability in range [0,
// 1].
// norm_probs_data: probability of the distribution
framework::Tensor norm_probs_tensor;
auto* norm_probs_data = norm_probs_tensor.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace());
// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
NormalizeProbability<
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data);
// Get cumulative probability of each distribution. It's the same function
// of
// ``cumsum`` op.
framework::Tensor cumulative_probs_tensor;
auto* cumulative_probs = cumulative_probs_tensor.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace());
dim3 block_cumsum(1);
dim3 grid_cumsum(num_distributions);
GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0,
ctx.cuda_device_context().stream()>>>(
norm_probs_data, num_distributions, num_categories, cumulative_probs);
// Generate random number for each sample.
std::random_device rd;
auto seed = rd();
framework::Tensor rng_data_tensor;
auto* rng_data = rng_data_tensor.mutable_data<T>(
{num_distributions, num_samples}, ctx.GetPlace());
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
platform::Transform<platform::CUDADeviceContext> trans;
auto* context =
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
trans(*context, index_sequence_begin,
index_sequence_begin + num_distributions * num_samples, rng_data,
RandomGeneratorCudaFunctor<T>(seed));
// Sample the multinomial distributions.
dim3 block_sample(128);
dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
sampleMultinomialWithReplacement<T><<<grid_sample, block_sample, 0,
ctx.cuda_device_context().stream()>>>(
rng_data, num_samples, out_data, num_distributions, num_categories,
cumulative_probs, norm_probs_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
multinomial, ops::MultinomialOpKernel<plat::CUDADeviceContext, float>,
ops::MultinomialOpKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
/**
* Samples a multinomial distribution given a probability input
*/
template <typename T>
void MultinomialFunctor(int64_t* out_data, const T* in_data,
const int64_t num_samples, const bool replacement,
const int64_t num_categories,
const int64_t num_distributions) {
std::vector<T> cumulative_probs(num_categories);
std::uniform_real_distribution<T> dist(0, 1);
auto gen_ptr = framework::DefaultCPUGenerator();
auto engine = gen_ptr->GetCPUEngine();
for (int64_t i = 0; i < num_distributions; i++) {
T probs_sum = 0;
T prob_value;
int64_t num_zeros = 0;
for (int64_t j = 0; j < num_categories; j++) {
prob_value = in_data[i * num_categories + j];
PADDLE_ENFORCE_GE(
prob_value, 0.0,
platform::errors::OutOfRange(
"The input of multinomial distribution should be >= 0"));
PADDLE_ENFORCE_EQ((std::isinf(static_cast<double>(prob_value)) ||
std::isnan(static_cast<double>(prob_value))),
false, platform::errors::OutOfRange(
"The input of multinomial distribution "
"shoud not be infinity or NaN"));
probs_sum += prob_value;
if (prob_value == 0) {
num_zeros += 1;
}
cumulative_probs[j] = probs_sum;
}
PADDLE_ENFORCE_GT(probs_sum, 0.0, platform::errors::OutOfRange(
"The sum of input should not be 0"));
PADDLE_ENFORCE_EQ(
(replacement || (num_categories - num_zeros >= num_samples)), true,
platform::errors::OutOfRange("When replacement is False, number of "
"samples should be less than non-zero "
"categories"));
for (int64_t j = 0; j < num_categories; j++) {
cumulative_probs[j] /= probs_sum;
}
for (int64_t s = 0; s < num_samples; s++) {
T uniform_rand = dist(*engine);
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < uniform_rand < cumulative_probs[id].
int64_t left = 0;
int64_t right = num_categories;
int64_t mid;
int64_t sample_id;
T temp_prob;
cumulative_probs[(num_categories - 1)] = 1;
while (right > left) {
mid = left + (right - left) / 2;
temp_prob = cumulative_probs[mid];
if (temp_prob < uniform_rand) {
left = mid + 1;
} else {
right = mid;
}
}
sample_id = left;
out_data[i * num_samples + s] = sample_id;
// if replacement is false, the selected category should be removed.
if (!replacement && s < num_samples - 1) {
T sample_prob;
T new_prob = 0;
T new_sum;
if (sample_id != 0) {
new_prob = cumulative_probs[sample_id - 1];
}
sample_prob = cumulative_probs[sample_id] - new_prob;
new_sum = 1.0 - sample_prob;
for (int64_t j = 0; j < num_categories; j++) {
new_prob = cumulative_probs[j];
if (j >= sample_id) {
new_prob -= sample_prob;
}
new_prob /= new_sum;
cumulative_probs[j] = new_prob;
}
}
}
}
}
template <typename DeviceContext, typename T>
class MultinomialOpKernel;
} // namespace operators
} // namespace paddle
......@@ -201,6 +201,7 @@ from .tensor.math import isfinite #DEFINE_ALIAS
from .tensor.math import isinf #DEFINE_ALIAS
from .tensor.math import isnan #DEFINE_ALIAS
from .tensor.math import prod #DEFINE_ALIAS
from .tensor.random import multinomial #DEFINE_ALIAS
from .tensor.random import standard_normal
from .tensor.random import normal
from .tensor.random import uniform #DEFINE_ALIAS
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
from op_test import OpTest
import numpy as np
class TestMultinomialOp(OpTest):
def setUp(self):
self.op_type = "multinomial"
self.init_data()
self.inputs = {"X": self.input_np}
def init_data(self):
# input probability is a vector, and replacement is True
self.input_np = np.random.rand(4)
self.outputs = {"Out": np.zeros(100000).astype("int64")}
self.attrs = {"num_samples": 100000, "replacement": True}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def sample_output(self, out):
# count numbers of different categories
sample_prob = np.unique(out, return_counts=True)[1].astype("float32")
sample_prob /= sample_prob.sum()
return sample_prob
def verify_output(self, outs):
# normalize the input to get the probability
prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True)
sample_prob = self.sample_output(np.array(outs[0]))
self.assertTrue(
np.allclose(
sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
class TestMultinomialOp2(TestMultinomialOp):
def init_data(self):
# input probability is a matrix
self.input_np = np.random.rand(3, 4)
self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")}
self.attrs = {"num_samples": 100000, "replacement": True}
def sample_output(self, out):
out_list = np.split(out, 3, axis=0)
count_array = [0] * 3
for i in range(3):
count_array[i] = np.unique(
out_list[i], return_counts=True)[1].astype("float32")
sample_prob = np.stack(count_array, axis=0)
sample_prob /= sample_prob.sum(axis=-1, keepdims=True)
return sample_prob
class TestMultinomialOp3(TestMultinomialOp):
def init_data(self):
# replacement is False. number of samples must be less than number of categories.
self.input_np = np.random.rand(1000)
self.outputs = {"Out": np.zeros(100).astype("int64")}
self.attrs = {"num_samples": 100, "replacement": False}
def verify_output(self, outs):
out = np.array(outs[0])
unique_out = np.unique(out)
self.assertEqual(
len(unique_out), 100,
"replacement is False. categories can't be sampled repeatedly")
class TestMultinomialApi(unittest.TestCase):
def test_dygraph(self):
# input probability is a vector, and replacement is True
paddle.disable_static()
x = paddle.rand([4])
out = paddle.multinomial(x, num_samples=100000, replacement=True)
x_numpy = x.numpy()
paddle.enable_static()
sample_prob = np.unique(
out.numpy(), return_counts=True)[1].astype("float32")
sample_prob /= sample_prob.sum()
prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
def test_dygraph2(self):
# input probability is a matrix, and replacement is True
paddle.disable_static()
x = paddle.rand([3, 4])
out = paddle.multinomial(x, num_samples=100000, replacement=True)
x_numpy = x.numpy()
out_list = np.split(out.numpy(), 3, axis=0)
count_array = [0] * 3
for i in range(3):
count_array[i] = np.unique(
out_list[i], return_counts=True)[1].astype("float32")
sample_prob = np.stack(count_array, axis=0)
sample_prob /= sample_prob.sum(axis=-1, keepdims=True)
prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
paddle.enable_static()
def test_dygraph3(self):
# replacement is False. number of samples must be less than number of categories.
paddle.disable_static()
x = paddle.rand([1000])
out = paddle.multinomial(x, num_samples=100, replacement=False)
x_numpy = x.numpy()
unique_out = np.unique(out.numpy())
self.assertEqual(
len(unique_out), 100,
"replacement is False. categories can't be sampled repeatedly")
paddle.enable_static()
def test_static(self):
paddle.enable_static()
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
x = fluid.data('x', shape=[4], dtype='float32')
out = paddle.multinomial(x, num_samples=100000, replacement=True)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
x_np = np.random.rand(4).astype('float32')
out = exe.run(train_program, feed={'x': x_np}, fetch_list=[out])
sample_prob = np.unique(out, return_counts=True)[1].astype("float32")
sample_prob /= sample_prob.sum()
prob = x_np / x_np.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
class TestMultinomialAlias(unittest.TestCase):
def test_alias(self):
paddle.disable_static()
x = paddle.rand([4])
paddle.multinomial(x, num_samples=10, replacement=True)
paddle.tensor.multinomial(x, num_samples=10, replacement=True)
paddle.tensor.random.multinomial(x, num_samples=10, replacement=True)
if __name__ == "__main__":
unittest.main()
......@@ -166,6 +166,7 @@ from .math import isfinite #DEFINE_ALIAS
from .math import isinf #DEFINE_ALIAS
from .math import isnan #DEFINE_ALIAS
from .math import prod #DEFINE_ALIAS
from .random import multinomial #DEFINE_ALIAS
from .random import standard_normal
from .random import normal
from .random import uniform #DEFINE_ALIAS
......
......@@ -23,6 +23,7 @@ import paddle
__all__ = [
'bernoulli',
'multinomial',
'standard_normal',
'normal',
'uniform',
......@@ -85,6 +86,71 @@ def bernoulli(x, name=None):
return out
def multinomial(x, num_samples=1, replacement=False, name=None):
"""
This OP returns a Tensor filled with random values sampled from a Multinomical
distribution. The input ``x`` is a tensor with probabilities for generating the
random number. Each element in ``x`` should be larger or equal to 0, but not all
0. ``replacement`` indicates whether it is a replaceable sample. If ``replacement``
is True, a category can be sampled more than once.
Args:
x(Tensor): A tensor with probabilities for generating the random number. The data type
should be float32, float64.
num_samples(int, optional): Number of samples, default is 1.
replacement(bool, optional): Whether it is a replaceable sample, default is False.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with sampled category index after ``num_samples`` times samples.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
x = paddle.rand([2,4])
print(x.numpy())
# [[0.7713825 0.4055941 0.433339 0.70706886]
# [0.9223313 0.8519825 0.04574518 0.16560672]]
out1 = paddle.multinomial(x, num_samples=5, replacement=True)
print(out1.numpy())
# [[3 3 1 1 0]
# [0 0 0 0 1]]
# out2 = paddle.multinomial(x, num_samples=5)
# OutOfRangeError: When replacement is False, number of samples
# should be less than non-zero categories
out3 = paddle.multinomial(x, num_samples=3)
print(out3.numpy())
# [[0 2 3]
# [0 1 3]]
"""
if in_dygraph_mode():
return core.ops.multinomial(x, 'num_samples', num_samples,
'replacement', replacement)
check_variable_and_dtype(x, "x", ["float32", "float64"], "multinomial")
helper = LayerHelper("multinomial", **locals())
out = helper.create_variable_for_type_inference(
dtype=convert_np_dtype_to_dtype_('int64'))
helper.append_op(
type='multinomial',
inputs={"X": x},
outputs={'Out': out},
attrs={'num_samples': num_samples,
'replacement': replacement})
return out
def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
"""
This OP returns a Tensor filled with random values sampled from a Gaussian
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册