Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c1b09be8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1b09be8
编写于
8月 06, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3786 add gpu multinomial sample python code
Merge pull request !3786 from baihuawei/multinomial
上级
bb776efe
4d92c5b3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
162 addition
and
2 deletion
+162
-2
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+55
-0
mindspore/ops/composite/__init__.py
mindspore/ops/composite/__init__.py
+2
-1
mindspore/ops/composite/random_ops.py
mindspore/ops/composite/random_ops.py
+51
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+2
-1
mindspore/ops/operations/random_ops.py
mindspore/ops/operations/random_ops.py
+52
-0
未找到文件。
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
c1b09be8
...
...
@@ -19,6 +19,9 @@ from mindspore.ops import _utils as utils
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.parameter
import
Parameter
from
mindspore.common
import
dtype
as
mstype
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
composite
as
C
import
mindspore.nn
as
nn
def
cast_to_tensor
(
t
,
dtype
=
mstype
.
float32
):
"""
...
...
@@ -196,3 +199,55 @@ def check_prob(p):
comp
=
np
.
greater
(
p
.
asnumpy
(),
np
.
ones
(
p
.
shape
))
if
comp
.
any
():
raise
ValueError
(
'Probabilities should be less than or equal to one'
)
def
logits_to_probs
(
logits
,
is_binary
=
False
):
"""
converts logits into probabilities.
Args:
logits (Tensor)
is_binary (bool)
"""
if
is_binary
:
return
nn
.
sigmoid
()(
logits
)
return
nn
.
softmax
(
axis
=-
1
)(
logits
)
def
clamp_probs
(
probs
):
"""
clamp probs boundary
Args:
probs (Tensor)
"""
eps
=
P
.
Eps
()(
probs
)
return
C
.
clip_by_value
(
probs
,
eps
,
1
-
eps
)
def
probs_to_logits
(
probs
,
is_binary
=
False
):
"""
converts probabilities into logits.
Args:
probs (Tensor)
is_binary (bool)
"""
ps_clamped
=
clamp_probs
(
probs
)
if
is_binary
:
return
P
.
Log
()(
ps_clamped
)
-
P
.
Log
()(
1
-
ps_clamped
)
return
P
.
Log
()(
ps_clamped
)
def
check_tensor_type
(
name
,
inputs
,
valid_type
):
"""
Check if inputs is proper.
Args:
inputs: Tensor to be checked.
name: inputs name
Raises:
ValueError: if inputs is not a proper Tensor.
"""
if
not
isinstance
(
inputs
,
Tensor
):
raise
TypeError
(
f
"
{
name
}
should be a Tensor"
)
inputs
=
P
.
DType
()(
inputs
)
if
inputs
not
in
valid_type
:
raise
TypeError
(
f
"
{
name
}
dtype is invalid"
)
mindspore/ops/composite/__init__.py
浏览文件 @
c1b09be8
...
...
@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
from
.multitype_ops.add_impl
import
hyper_add
from
.multitype_ops.ones_like_impl
import
ones_like
from
.multitype_ops.zeros_like_impl
import
zeros_like
from
.random_ops
import
set_seed
,
normal
from
.random_ops
import
set_seed
,
normal
,
multinomial
__all__
=
[
...
...
@@ -50,4 +50,5 @@ __all__ = [
'zip_operation'
,
'set_seed'
,
'normal'
,
'multinomial'
,
'clip_by_value'
,]
mindspore/ops/composite/random_ops.py
浏览文件 @
c1b09be8
...
...
@@ -20,6 +20,9 @@ from .. import functional as F
from
..primitive
import
constexpr
from
.multitype_ops
import
_constexpr_utils
as
const_utils
from
...common
import
dtype
as
mstype
from
...common.tensor
import
Tensor
from
..._checkparam
import
Validator
as
validator
from
..._checkparam
import
Rel
# set graph-level RNG seed
_GRAPH_SEED
=
0
...
...
@@ -68,3 +71,51 @@ def normal(shape, mean, stddev, seed=0):
rnd
=
stdnormal
(
shape
)
value
=
rnd
*
stddev
+
mean
return
value
def
multinomial
(
inputs
,
num_sample
=
None
,
replacement
=
True
,
seed
=
0
):
r
"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding
row of tensor input.
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.
Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Inputs:
- **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
- **num_samples** (int) - number of samples to draw, default None.
- **replacement** (bool, optional) - whether to draw with replacement or not, default True.
Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices.
Examples:
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = C.multinomial(input, 2, True)
"""
shape
=
P
.
Shape
()
reshape
=
P
.
Reshape
()
validator
.
check_value_type
(
'replacement'
,
replacement
,
(
bool
,),
None
)
validator
.
check_value_type
(
'num_sample'
,
num_sample
,
(
int
,),
None
)
validator
.
check_integer
(
"num_sample"
,
num_sample
,
0
,
Rel
.
GT
,
None
)
if
inputs
.
dim
()
!=
1
and
inputs
.
dim
()
!=
2
:
raise
ValueError
(
"inputs dim must be 1d or 2d"
)
if
not
replacement
:
if
shape
(
inputs
)[
-
1
]
<
num_sample
:
raise
ValueError
(
"num_sample must be less than shape(input)[-1] without replacement"
)
n_dist
=
1
if
len
(
shape
(
inputs
))
>
1
:
n_dist
=
shape
(
inputs
)[
-
2
]
a
=
Tensor
(
0.0
,
mstype
.
float32
)
b
=
Tensor
(
1.0
,
mstype
.
float32
)
uniform
=
P
.
UniformReal
(
seed
=
seed
)((
n_dist
*
num_sample
,),
a
,
b
)
if
n_dist
!=
1
:
uniform
=
reshape
(
uniform
,
(
n_dist
,
num_sample
))
vals
=
P
.
RealDiv
()(
P
.
Log
()(
uniform
),
inputs
+
1e-6
)
_
,
indices
=
P
.
TopK
()(
vals
,
num_sample
)
return
indices
return
P
.
Multinomial
(
seed
=
seed
)(
inputs
,
num_sample
)
mindspore/ops/operations/__init__.py
浏览文件 @
c1b09be8
...
...
@@ -57,7 +57,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Square
,
Sub
,
TensorAdd
,
Sign
,
Round
,
SquareSumAll
,
Atan
,
Atanh
,
Cosh
,
Sinh
,
Eps
,
Tan
)
from
.random_ops
import
(
RandomChoiceWithMask
,
StandardNormal
,
Gamma
,
Poisson
,
UniformInt
,
UniformReal
,
RandomCategorical
,
Laplace
)
RandomCategorical
,
Laplace
,
Multinomial
)
from
.nn_ops
import
(
LSTM
,
SGD
,
Adam
,
FusedSparseAdam
,
FusedSparseLazyAdam
,
ApplyMomentum
,
BatchNorm
,
BiasAdd
,
Conv2D
,
DepthwiseConv2dNative
,
...
...
@@ -184,6 +184,7 @@ __all__ = [
'Tanh'
,
'RandomChoiceWithMask'
,
'StandardNormal'
,
'Multinomial'
,
'Gamma'
,
'Poisson'
,
'UniformInt'
,
...
...
mindspore/ops/operations/random_ops.py
浏览文件 @
c1b09be8
...
...
@@ -409,6 +409,7 @@ class RandomCategorical(PrimitiveWithInfer):
>>> net = Net(8)
>>> output = net(Tensor(x))
"""
@
prim_attr_register
def
__init__
(
self
,
dtype
=
mstype
.
int64
):
"""Init RandomCategorical"""
...
...
@@ -436,3 +437,54 @@ class RandomCategorical(PrimitiveWithInfer):
return
{
'shape'
:
(
x_shape
),
'dtype'
:
(
self
.
dtype
),
'value'
:
None
}
class
Multinomial
(
PrimitiveWithInfer
):
r
"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding
row of tensor input.
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.
Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Inputs:
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims.
- **num_samples** (int) - number of samples to draw.
Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices.
Examples:
>>> input = Tensor([0., 9., 4., 0.], mstype.float32)
>>> multinomial = P.Multinomial(seed=10)
>>> output = multinomial(input, 2)
"""
@
prim_attr_register
def
__init__
(
self
,
seed
=
0
):
"""init"""
validator
.
check_value_type
(
"seed"
,
seed
,
[
int
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'input'
,
'num_sample'
],
outputs
=
[
'output'
])
def
__infer__
(
self
,
inputs
,
num_samples
):
input_shape
=
inputs
[
"shape"
]
if
len
(
input_shape
)
!=
1
and
len
(
input_shape
)
!=
2
:
raise
ValueError
(
"input dim must be 1 or 2"
)
validator
.
check_tensor_type_same
({
'inputs'
:
inputs
[
'dtype'
]},
[
mstype
.
float32
],
self
.
name
)
num_samples_value
=
num_samples
[
"value"
]
if
num_samples_value
is
None
:
raise
ValueError
(
f
"For
{
self
.
name
}
, shape nust be const"
)
validator
.
check_value_type
(
"num_samples"
,
num_samples_value
,
[
int
],
self
.
name
)
validator
.
check_integer
(
"num_samples"
,
num_samples_value
,
0
,
Rel
.
GT
,
None
)
y_shape
=
(
num_samples_value
,)
if
len
(
input_shape
)
==
2
:
y_shape
=
(
input_shape
[
0
],
num_samples_value
)
out
=
{
"shape"
:
y_shape
,
"dtype"
:
mstype
.
int32
,
"value"
:
None
}
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录