提交 c1b09be8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3786 add gpu multinomial sample python code

Merge pull request !3786 from baihuawei/multinomial
......@@ -19,6 +19,9 @@ from mindspore.ops import _utils as utils
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
def cast_to_tensor(t, dtype=mstype.float32):
"""
......@@ -196,3 +199,55 @@ def check_prob(p):
comp = np.greater(p.asnumpy(), np.ones(p.shape))
if comp.any():
raise ValueError('Probabilities should be less than or equal to one')
def logits_to_probs(logits, is_binary=False):
"""
converts logits into probabilities.
Args:
logits (Tensor)
is_binary (bool)
"""
if is_binary:
return nn.sigmoid()(logits)
return nn.softmax(axis=-1)(logits)
def clamp_probs(probs):
"""
clamp probs boundary
Args:
probs (Tensor)
"""
eps = P.Eps()(probs)
return C.clip_by_value(probs, eps, 1-eps)
def probs_to_logits(probs, is_binary=False):
"""
converts probabilities into logits.
Args:
probs (Tensor)
is_binary (bool)
"""
ps_clamped = clamp_probs(probs)
if is_binary:
return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
return P.Log()(ps_clamped)
def check_tensor_type(name, inputs, valid_type):
"""
Check if inputs is proper.
Args:
inputs: Tensor to be checked.
name: inputs name
Raises:
ValueError: if inputs is not a proper Tensor.
"""
if not isinstance(inputs, Tensor):
raise TypeError(f"{name} should be a Tensor")
inputs = P.DType()(inputs)
if inputs not in valid_type:
raise TypeError(f"{name} dtype is invalid")
......@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import set_seed, normal
from .random_ops import set_seed, normal, multinomial
__all__ = [
......@@ -50,4 +50,5 @@ __all__ = [
'zip_operation',
'set_seed',
'normal',
'multinomial',
'clip_by_value',]
......@@ -20,6 +20,9 @@ from .. import functional as F
from ..primitive import constexpr
from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
# set graph-level RNG seed
_GRAPH_SEED = 0
......@@ -68,3 +71,51 @@ def normal(shape, mean, stddev, seed=0):
rnd = stdnormal(shape)
value = rnd * stddev + mean
return value
def multinomial(inputs, num_sample=None, replacement=True, seed=0):
r"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding
row of tensor input.
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.
Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Inputs:
- **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
- **num_samples** (int) - number of samples to draw, default None.
- **replacement** (bool, optional) - whether to draw with replacement or not, default True.
Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices.
Examples:
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = C.multinomial(input, 2, True)
"""
shape = P.Shape()
reshape = P.Reshape()
validator.check_value_type('replacement', replacement, (bool,), None)
validator.check_value_type('num_sample', num_sample, (int,), None)
validator.check_integer("num_sample", num_sample, 0, Rel.GT, None)
if inputs.dim() != 1 and inputs.dim() != 2:
raise ValueError("inputs dim must be 1d or 2d")
if not replacement:
if shape(inputs)[-1] < num_sample:
raise ValueError("num_sample must be less than shape(input)[-1] without replacement")
n_dist = 1
if len(shape(inputs)) > 1:
n_dist = shape(inputs)[-2]
a = Tensor(0.0, mstype.float32)
b = Tensor(1.0, mstype.float32)
uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
if n_dist != 1:
uniform = reshape(uniform, (n_dist, num_sample))
vals = P.RealDiv()(P.Log()(uniform), inputs + 1e-6)
_, indices = P.TopK()(vals, num_sample)
return indices
return P.Multinomial(seed=seed)(inputs, num_sample)
......@@ -57,7 +57,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, Laplace)
RandomCategorical, Laplace, Multinomial)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
......@@ -184,6 +184,7 @@ __all__ = [
'Tanh',
'RandomChoiceWithMask',
'StandardNormal',
'Multinomial',
'Gamma',
'Poisson',
'UniformInt',
......
......@@ -409,6 +409,7 @@ class RandomCategorical(PrimitiveWithInfer):
>>> net = Net(8)
>>> output = net(Tensor(x))
"""
@prim_attr_register
def __init__(self, dtype=mstype.int64):
"""Init RandomCategorical"""
......@@ -436,3 +437,54 @@ class RandomCategorical(PrimitiveWithInfer):
return {'shape': (x_shape),
'dtype': (self.dtype),
'value': None}
class Multinomial(PrimitiveWithInfer):
r"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding
row of tensor input.
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.
Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Inputs:
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims.
- **num_samples** (int) - number of samples to draw.
Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices.
Examples:
>>> input = Tensor([0., 9., 4., 0.], mstype.float32)
>>> multinomial = P.Multinomial(seed=10)
>>> output = multinomial(input, 2)
"""
@prim_attr_register
def __init__(self, seed=0):
"""init"""
validator.check_value_type("seed", seed, [int], self.name)
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
def __infer__(self, inputs, num_samples):
input_shape = inputs["shape"]
if len(input_shape) != 1 and len(input_shape) != 2:
raise ValueError("input dim must be 1 or 2")
validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
num_samples_value = num_samples["value"]
if num_samples_value is None:
raise ValueError(f"For {self.name}, shape nust be const")
validator.check_value_type("num_samples", num_samples_value, [int], self.name)
validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None)
y_shape = (num_samples_value,)
if len(input_shape) == 2:
y_shape = (input_shape[0], num_samples_value)
out = {
"shape": y_shape,
"dtype": mstype.int32,
"value": None}
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册