Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
86616ac5
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
86616ac5
编写于
8月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4773 Fix empty shape issue in distribution sample functions
Merge pull request !4773 from peixu_ren/custom_bijector
上级
0aeaa7f0
b4767023
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
56 addition
and
16 deletion
+56
-16
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+1
-5
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+11
-2
mindspore/nn/probability/distribution/exponential.py
mindspore/nn/probability/distribution/exponential.py
+11
-2
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+11
-2
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+11
-3
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+11
-2
未找到文件。
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
86616ac5
...
...
@@ -45,10 +45,6 @@ def cast_to_tensor(t, hint_type=mstype.float32):
return
t
t_type
=
hint_type
if
isinstance
(
t
,
Tensor
):
#check if the Tensor in shape of Tensor(4)
if
t
.
dim
()
==
0
:
value
=
t
.
asnumpy
()
return
Tensor
([
value
],
dtype
=
t_type
)
#convert the type of tensor to dtype
return
Tensor
(
t
.
asnumpy
(),
dtype
=
t_type
)
if
isinstance
(
t
,
(
list
,
np
.
ndarray
)):
...
...
@@ -56,7 +52,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
if
isinstance
(
t
,
bool
):
raise
TypeError
(
f
'Input cannot be Type Bool'
)
if
isinstance
(
t
,
(
int
,
float
)):
return
Tensor
(
[
t
]
,
dtype
=
t_type
)
return
Tensor
(
t
,
dtype
=
t_type
)
raise
TypeError
(
"Input type is not supported."
)
def
convert_to_batch
(
t
,
batch_shape
,
required_type
):
...
...
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
86616ac5
...
...
@@ -107,6 +107,7 @@ class Bernoulli(Distribution):
self
.
_probs
=
probs
# ops needed for the class
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
dtypeop
=
P
.
DType
()
...
...
@@ -284,8 +285,16 @@ class Bernoulli(Distribution):
probs1
=
self
.
cast
(
probs
,
self
.
parameter_type
)
if
probs
is
not
None
else
self
.
probs
if
probs1
is
None
:
raise_none_error
(
"probs"
)
origin_shape
=
shape
+
self
.
shape
(
probs1
)
if
origin_shape
==
():
sample_shape
=
(
1
,)
else
:
sample_shape
=
origin_shape
l_zero
=
self
.
const
(
0.0
)
h_one
=
self
.
const
(
1.0
)
sample_uniform
=
self
.
uniform
(
s
hape
+
self
.
shape
(
probs1
)
,
l_zero
,
h_one
,
self
.
seed
)
sample_uniform
=
self
.
uniform
(
s
ample_shape
,
l_zero
,
h_one
,
self
.
seed
)
sample
=
self
.
less
(
sample_uniform
,
probs1
)
return
self
.
cast
(
sample
,
self
.
dtype
)
value
=
self
.
cast
(
sample
,
self
.
dtype
)
if
origin_shape
==
():
value
=
self
.
squeeze
(
value
)
return
value
mindspore/nn/probability/distribution/exponential.py
浏览文件 @
86616ac5
...
...
@@ -111,6 +111,7 @@ class Exponential(Distribution):
self
.
minval
=
np
.
finfo
(
np
.
float
).
tiny
# ops needed for the class
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
dtypeop
=
P
.
DType
()
...
...
@@ -276,8 +277,16 @@ class Exponential(Distribution):
rate
=
self
.
cast
(
rate
,
self
.
parameter_type
)
if
rate
is
not
None
else
self
.
rate
if
rate
is
None
:
raise_none_error
(
"rate"
)
origin_shape
=
shape
+
self
.
shape
(
rate
)
if
origin_shape
==
():
sample_shape
=
(
1
,)
else
:
sample_shape
=
origin_shape
minval
=
self
.
const
(
self
.
minval
)
maxval
=
self
.
const
(
1.0
)
sample_uniform
=
self
.
uniform
(
s
hape
+
self
.
shape
(
rate
)
,
minval
,
maxval
,
self
.
seed
)
sample_uniform
=
self
.
uniform
(
s
ample_shape
,
minval
,
maxval
,
self
.
seed
)
sample
=
-
self
.
log
(
sample_uniform
)
/
rate
return
self
.
cast
(
sample
,
self
.
dtype
)
value
=
self
.
cast
(
sample
,
self
.
dtype
)
if
origin_shape
==
():
value
=
self
.
squeeze
(
value
)
return
value
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
86616ac5
...
...
@@ -112,6 +112,7 @@ class Geometric(Distribution):
self
.
minval
=
np
.
finfo
(
np
.
float
).
tiny
# ops needed for the class
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
dtypeop
=
P
.
DType
()
...
...
@@ -283,8 +284,16 @@ class Geometric(Distribution):
probs1
=
self
.
cast
(
probs
,
self
.
parameter_type
)
if
probs
is
not
None
else
self
.
probs
if
probs1
is
None
:
raise_none_error
(
"probs"
)
origin_shape
=
shape
+
self
.
shape
(
probs1
)
if
origin_shape
==
():
sample_shape
=
(
1
,)
else
:
sample_shape
=
origin_shape
minval
=
self
.
const
(
self
.
minval
)
maxval
=
self
.
const
(
1.0
)
sample_uniform
=
self
.
uniform
(
s
hape
+
self
.
shape
(
probs1
)
,
minval
,
maxval
,
self
.
seed
)
sample_uniform
=
self
.
uniform
(
s
ample_shape
,
minval
,
maxval
,
self
.
seed
)
sample
=
self
.
floor
(
self
.
log
(
sample_uniform
)
/
self
.
log
(
1.0
-
probs1
))
return
self
.
cast
(
sample
,
self
.
dtype
)
value
=
self
.
cast
(
sample
,
self
.
dtype
)
if
origin_shape
==
():
value
=
self
.
squeeze
(
value
)
return
value
mindspore/nn/probability/distribution/normal.py
浏览文件 @
86616ac5
...
...
@@ -114,6 +114,7 @@ class Normal(Distribution):
#ops needed for the class
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
erf
=
P
.
Erf
()
...
...
@@ -305,7 +306,14 @@ class Normal(Distribution):
sd
=
self
.
cast
(
sd
,
self
.
parameter_type
)
if
sd
is
not
None
else
self
.
_sd_value
if
sd
is
None
:
raise_none_error
(
"sd"
)
batch_shape
=
self
.
shape
(
self
.
zeroslike
(
mean
)
+
self
.
zeroslike
(
sd
))
sample_shape
=
shape
+
batch_shape
batch_shape
=
self
.
shape
(
mean
+
sd
)
origin_shape
=
shape
+
batch_shape
if
origin_shape
==
():
sample_shape
=
(
1
,)
else
:
sample_shape
=
origin_shape
sample_norm
=
C
.
normal
(
sample_shape
,
mean
,
sd
,
self
.
seed
)
return
sample_norm
value
=
self
.
cast
(
sample_norm
,
self
.
dtype
)
if
origin_shape
==
():
value
=
self
.
squeeze
(
value
)
return
value
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
86616ac5
...
...
@@ -112,6 +112,7 @@ class Uniform(Distribution):
self
.
_high
=
high
# ops needed for the class
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
dtypeop
=
P
.
DType
()
...
...
@@ -327,8 +328,16 @@ class Uniform(Distribution):
if
high
is
None
:
raise_none_error
(
"high"
)
broadcast_shape
=
self
.
shape
(
low
+
high
)
origin_shape
=
shape
+
broadcast_shape
if
origin_shape
==
():
sample_shape
=
(
1
,)
else
:
sample_shape
=
origin_shape
l_zero
=
self
.
const
(
0.0
)
h_one
=
self
.
const
(
1.0
)
sample_uniform
=
self
.
uniform
(
s
hape
+
broadcast
_shape
,
l_zero
,
h_one
,
self
.
seed
)
sample_uniform
=
self
.
uniform
(
s
ample
_shape
,
l_zero
,
h_one
,
self
.
seed
)
sample
=
(
high
-
low
)
*
sample_uniform
+
low
return
self
.
cast
(
sample
,
self
.
dtype
)
value
=
self
.
cast
(
sample
,
self
.
dtype
)
if
origin_shape
==
():
value
=
self
.
squeeze
(
value
)
return
value
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录