提交 9297ba0a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5048 fix gpu multinomial

Merge pull request !5048 from baihuawei/0821
......@@ -33,7 +33,12 @@ namespace kernel {
template <typename T>
class MultinomialGpuKernel : public GpuKernel {
public:
MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {}
MultinomialGpuKernel()
: input_size_0_(0),
output_size_(0),
distributions_(0),
workspace_size_(sizeof(curandState)),
replacement_(true) {}
~MultinomialGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
......@@ -49,6 +54,19 @@ class MultinomialGpuKernel : public GpuKernel {
int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_;
int num_sample = SizeToInt(outputs[0]->size / sizeof(T)) / distributions_;
// check input
T *cum_sum_input = nullptr;
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_),
"cudaMalloc failed.");
CheckPeram(input_addr, cum_sum_input, categories, stream_ptr);
if (replacement_) {
Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
}
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed.");
return true;
}
void CheckPeram(const T *input_addr, T *cum_sum_input, int categories, void *stream_ptr) {
T *flag = nullptr;
T *cflag = nullptr;
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cflag), sizeof(T)), "cudaMalloc failed.");
......@@ -67,9 +85,6 @@ class MultinomialGpuKernel : public GpuKernel {
if (*flag > 0) {
MS_LOG(EXCEPTION) << "Input is invalid (input element < 0)";
}
T *cum_sum_input = nullptr;
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_),
"cudaMalloc failed.");
CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1,
IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
......@@ -82,14 +97,10 @@ class MultinomialGpuKernel : public GpuKernel {
if (*flag > 0) {
MS_LOG(EXCEPTION) << "Input is invalid (sum <= 0)";
}
Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cflag), "cudaFree failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(flag), "cudaFreeHost failed.");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
......@@ -114,9 +125,15 @@ class MultinomialGpuKernel : public GpuKernel {
}
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
output_size_ = sizeof(int);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
workspace_size_ *= output_shape[i];
workspace_size_ = sizeof(int);
replacement_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("replacement"));
if (replacement_) {
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
}
if (replacement_) {
workspace_size_ = output_size_;
}
seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
InitSizeLists();
......@@ -136,6 +153,7 @@ class MultinomialGpuKernel : public GpuKernel {
size_t output_size_;
size_t distributions_;
size_t workspace_size_;
bool replacement_;
int seed_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
......
......@@ -20,8 +20,6 @@ from .. import functional as F
from ..primitive import constexpr
from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
# set graph-level RNG seed
_GRAPH_SEED = 0
......@@ -204,14 +202,13 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.
Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Inputs:
- **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
- **num_samples** (int) - number of samples to draw.
- **replacement** (bool, optional) - whether to draw with replacement or not, default True.
Args:
input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
num_samples (int) - number of samples to draw.
replacement (bool, optional) - whether to draw with replacement or not, default True.
seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
Must be non-negative. Default: 0.
Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices.
......@@ -222,21 +219,19 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
"""
shape = P.Shape()
reshape = P.Reshape()
validator.check_value_type('replacement', replacement, (bool,), None)
validator.check_value_type('num_sample', num_sample, (int,), None)
validator.check_integer("num_sample", num_sample, 0, Rel.GT, None)
if inputs.dim() != 1 and inputs.dim() != 2:
raise ValueError("inputs dim must be 1d or 2d")
if not replacement:
P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample)
if shape(inputs)[-1] < num_sample:
raise ValueError("num_sample must be less than shape(input)[-1] without replacement")
n_dist = 1
if len(shape(inputs)) > 1:
n_dist = shape(inputs)[-2]
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,))
random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],))
if n_dist != 1:
random_uniform = reshape(random_uniform, (n_dist, num_sample))
random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1]))
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
_, indices = P.TopK()(vals, num_sample)
return indices
return P.Multinomial(seed=seed)(inputs, num_sample)
return P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample)
......@@ -438,11 +438,12 @@ class Multinomial(PrimitiveWithInfer):
but must be non-negative, finite and have a non-zero sum.
Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Must be non-negative. Default: 0.
replacement(bool) - whether to draw with replacement or not.
Inputs:
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims.
- **num_samples** (int) - number of samples to draw.
- **num_samples** (int32) - number of samples to draw.
Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices.
......@@ -450,13 +451,15 @@ class Multinomial(PrimitiveWithInfer):
Examples:
>>> input = Tensor([0., 9., 4., 0.], mstype.float32)
>>> multinomial = P.Multinomial(seed=10)
>>> output = multinomial(input, 2)
>>> output = multinomial(input, 2, True)
"""
@prim_attr_register
def __init__(self, seed=0):
def __init__(self, replacement=True, seed=0):
"""init"""
validator.check_value_type("seed", seed, [int], self.name)
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
validator.check_value_type("replacement", replacement, [bool], self.name)
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
def __infer__(self, inputs, num_samples):
......@@ -467,7 +470,7 @@ class Multinomial(PrimitiveWithInfer):
num_samples_value = num_samples["value"]
if num_samples_value is None:
raise ValueError(f"For {self.name}, shape nust be const")
validator.check_value_type("num_samples", num_samples_value, [int], self.name)
validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None)
y_shape = (num_samples_value,)
if len(input_shape) == 2:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册