Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
19d80b87
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
19d80b87
编写于
7月 22, 2020
作者:
P
peixu_ren
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix minor errors in probabilistic programming
上级
380db207
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
18 addition
and
10 deletion
+18
-10
mindspore/ops/composite/__init__.py
mindspore/ops/composite/__init__.py
+2
-1
mindspore/ops/composite/random_ops.py
mindspore/ops/composite/random_ops.py
+14
-7
tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py
tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py
+1
-1
tests/st/ops/gpu/test_standard_normal.py
tests/st/ops/gpu/test_standard_normal.py
+1
-1
未找到文件。
mindspore/ops/composite/__init__.py
浏览文件 @
19d80b87
...
@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
...
@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
from
.multitype_ops.add_impl
import
hyper_add
from
.multitype_ops.add_impl
import
hyper_add
from
.multitype_ops.ones_like_impl
import
ones_like
from
.multitype_ops.ones_like_impl
import
ones_like
from
.multitype_ops.zeros_like_impl
import
zeros_like
from
.multitype_ops.zeros_like_impl
import
zeros_like
from
.random_ops
import
normal
from
.random_ops
import
set_seed
,
normal
__all__
=
[
__all__
=
[
...
@@ -48,5 +48,6 @@ __all__ = [
...
@@ -48,5 +48,6 @@ __all__ = [
'zeros_like'
,
'zeros_like'
,
'ones_like'
,
'ones_like'
,
'zip_operation'
,
'zip_operation'
,
'set_seed'
,
'normal'
,
'normal'
,
'clip_by_value'
,]
'clip_by_value'
,]
mindspore/ops/composite/random_ops.py
浏览文件 @
19d80b87
...
@@ -15,8 +15,11 @@
...
@@ -15,8 +15,11 @@
"""Operations for random number generatos."""
"""Operations for random number generatos."""
from
mindspore.ops.primitive
import
constexpr
from
..
import
operations
as
P
from
..
import
operations
as
P
from
..
import
functional
as
F
from
..primitive
import
constexpr
from
.multitype_ops
import
_constexpr_utils
as
const_utils
from
...common
import
dtype
as
mstype
# set graph-level RNG seed
# set graph-level RNG seed
_GRAPH_SEED
=
0
_GRAPH_SEED
=
0
...
@@ -31,17 +34,17 @@ def get_seed():
...
@@ -31,17 +34,17 @@ def get_seed():
return
_GRAPH_SEED
return
_GRAPH_SEED
def
normal
(
shape
,
mean
,
stddev
,
seed
):
def
normal
(
shape
,
mean
,
stddev
,
seed
=
0
):
"""
"""
Generates random numbers according to the Normal (or Gaussian) random number distribution.
Generates random numbers according to the Normal (or Gaussian) random number distribution.
It is defined as:
It is defined as:
Args:
Args:
- **shape** (tuple) -
The shape of random tensor to be generated.
shape (tuple):
The shape of random tensor to be generated.
- **mean** (Tensor) -
The mean μ distribution parameter, which specifies the location of the peak.
mean (Tensor):
The mean μ distribution parameter, which specifies the location of the peak.
With float32 data type.
With float32 data type.
- **stddev** (Tensor) -
The deviation σ distribution parameter. With float32 data type.
stddev (Tensor):
The deviation σ distribution parameter. With float32 data type.
- **seed**
(int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
seed
(int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
Default: 0.
Returns:
Returns:
...
@@ -52,9 +55,13 @@ def normal(shape, mean, stddev, seed):
...
@@ -52,9 +55,13 @@ def normal(shape, mean, stddev, seed):
>>> shape = (4, 16)
>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> C.set_seed(10)
>>> output = C.normal(shape, mean, stddev, seed=5)
>>> output = C.normal(shape, mean, stddev, seed=5)
"""
"""
set_seed
(
10
)
mean_dtype
=
F
.
dtype
(
mean
)
stddev_dtype
=
F
.
dtype
(
stddev
)
const_utils
.
check_tensors_dtype_same
(
mean_dtype
,
mstype
.
float32
,
"normal"
)
const_utils
.
check_tensors_dtype_same
(
stddev_dtype
,
mstype
.
float32
,
"normal"
)
seed1
=
get_seed
()
seed1
=
get_seed
()
seed2
=
seed
seed2
=
seed
stdnormal
=
P
.
StandardNormal
(
seed1
,
seed2
)
stdnormal
=
P
.
StandardNormal
(
seed1
,
seed2
)
...
...
tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py
浏览文件 @
19d80b87
...
@@ -29,7 +29,7 @@ class Net(nn.Cell):
...
@@ -29,7 +29,7 @@ class Net(nn.Cell):
self
.
stdnormal
=
P
.
StandardNormal
(
seed
,
seed2
)
self
.
stdnormal
=
P
.
StandardNormal
(
seed
,
seed2
)
def
construct
(
self
):
def
construct
(
self
):
return
self
.
stdnormal
(
self
.
shape
,
self
.
seed
,
self
.
seed2
)
return
self
.
stdnormal
(
self
.
shape
)
def
test_net
():
def
test_net
():
...
...
tests/st/ops/gpu/test_standard_normal.py
浏览文件 @
19d80b87
...
@@ -29,7 +29,7 @@ class Net(nn.Cell):
...
@@ -29,7 +29,7 @@ class Net(nn.Cell):
self
.
stdnormal
=
P
.
StandardNormal
(
seed
,
seed2
)
self
.
stdnormal
=
P
.
StandardNormal
(
seed
,
seed2
)
def
construct
(
self
):
def
construct
(
self
):
return
self
.
stdnormal
(
self
.
shape
,
self
.
seed
,
self
.
seed2
)
return
self
.
stdnormal
(
self
.
shape
)
def
test_net
():
def
test_net
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录