Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8acb5bdf
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8acb5bdf
编写于
8月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4734 fix error of bnn_layers
Merge pull request !4734 from byweng/add_test
上级
86616ac5
d8a4827f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
43 addition
and
43 deletion
+43
-43
mindspore/nn/probability/bnn_layers/conv_variational.py
mindspore/nn/probability/bnn_layers/conv_variational.py
+14
-14
mindspore/nn/probability/bnn_layers/dense_variational.py
mindspore/nn/probability/bnn_layers/dense_variational.py
+8
-14
mindspore/nn/probability/bnn_layers/layer_distribution.py
mindspore/nn/probability/bnn_layers/layer_distribution.py
+21
-15
未找到文件。
mindspore/nn/probability/bnn_layers/conv_variational.py
浏览文件 @
8acb5bdf
...
...
@@ -61,6 +61,12 @@ class _ConvVariational(_Conv):
raise
ValueError
(
'Attr
\'
pad_mode
\'
of
\'
Conv2d
\'
Op passed '
+
str
(
pad_mode
)
+
', should be one of values in
\'
valid
\'
,
\'
same
\'
,
\'
pad
\'
.'
)
if
not
isinstance
(
stride
,
(
int
,
tuple
)):
raise
TypeError
(
'The type of `stride` should be `int` of `tuple`'
)
if
not
isinstance
(
dilation
,
(
int
,
tuple
)):
raise
TypeError
(
'The type of `dilation` should be `int` of `tuple`'
)
# convolution args
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
...
...
@@ -87,13 +93,10 @@ class _ConvVariational(_Conv):
raise
TypeError
(
'The type of `weight_prior_fn` should be `NormalPrior`'
)
self
.
weight_prior
=
weight_prior_fn
()
if
isinstance
(
weight_posterior_fn
,
Cell
):
if
weight_posterior_fn
.
__class__
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `weight_posterior_fn` should be `NormalPosterior`'
)
else
:
if
weight_posterior_fn
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `weight_posterior_fn` should be `NormalPosterior`'
)
self
.
weight_posterior
=
weight_posterior_fn
(
shape
=
self
.
shape
,
name
=
'bnn_weight'
)
try
:
self
.
weight_posterior
=
weight_posterior_fn
(
shape
=
self
.
shape
,
name
=
'bnn_weight'
)
except
TypeError
:
raise
TypeError
(
'The type of `weight_posterior_fn` should be `NormalPosterior`'
)
if
self
.
has_bias
:
self
.
bias
.
requires_grad
=
False
...
...
@@ -107,13 +110,10 @@ class _ConvVariational(_Conv):
raise
TypeError
(
'The type of `bias_prior_fn` should be `NormalPrior`'
)
self
.
bias_prior
=
bias_prior_fn
()
if
isinstance
(
bias_posterior_fn
,
Cell
):
if
bias_posterior_fn
.
__class__
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `bias_posterior_fn` should be `NormalPosterior`'
)
else
:
if
bias_posterior_fn
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `bias_posterior_fn` should be `NormalPosterior`'
)
self
.
bias_posterior
=
bias_posterior_fn
(
shape
=
[
self
.
out_channels
],
name
=
'bnn_bias'
)
try
:
self
.
bias_posterior
=
bias_posterior_fn
(
shape
=
[
self
.
out_channels
],
name
=
'bnn_bias'
)
except
TypeError
:
raise
TypeError
(
'The type of `bias_posterior_fn` should be `NormalPosterior`'
)
# mindspore operations
self
.
bias_add
=
P
.
BiasAdd
()
...
...
mindspore/nn/probability/bnn_layers/dense_variational.py
浏览文件 @
8acb5bdf
...
...
@@ -51,13 +51,10 @@ class _DenseVariational(Cell):
raise
TypeError
(
'The type of `weight_prior_fn` should be `NormalPrior`'
)
self
.
weight_prior
=
weight_prior_fn
()
if
isinstance
(
weight_posterior_fn
,
Cell
):
if
weight_posterior_fn
.
__class__
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `weight_posterior_fn` should be `NormalPosterior`'
)
else
:
if
weight_posterior_fn
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `weight_posterior_fn` should be `NormalPosterior`'
)
self
.
weight_posterior
=
weight_posterior_fn
(
shape
=
[
self
.
out_channels
,
self
.
in_channels
],
name
=
'bnn_weight'
)
try
:
self
.
weight_posterior
=
weight_posterior_fn
(
shape
=
[
self
.
out_channels
,
self
.
in_channels
],
name
=
'bnn_weight'
)
except
TypeError
:
raise
TypeError
(
'The type of `weight_posterior_fn` should be `NormalPosterior`'
)
if
self
.
has_bias
:
if
isinstance
(
bias_prior_fn
,
Cell
):
...
...
@@ -69,13 +66,10 @@ class _DenseVariational(Cell):
raise
TypeError
(
'The type of `bias_prior_fn` should be `NormalPrior`'
)
self
.
bias_prior
=
bias_prior_fn
()
if
isinstance
(
bias_posterior_fn
,
Cell
):
if
bias_posterior_fn
.
__class__
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `bias_posterior_fn` should be `NormalPosterior`'
)
else
:
if
bias_posterior_fn
.
__name__
!=
'NormalPosterior'
:
raise
TypeError
(
'The type of `bias_posterior_fn` should be `NormalPosterior`'
)
self
.
bias_posterior
=
bias_posterior_fn
(
shape
=
[
self
.
out_channels
],
name
=
'bnn_bias'
)
try
:
self
.
bias_posterior
=
bias_posterior_fn
(
shape
=
[
self
.
out_channels
],
name
=
'bnn_bias'
)
except
TypeError
:
raise
TypeError
(
'The type of `bias_posterior_fn` should be `NormalPosterior`'
)
self
.
activation
=
activation
if
isinstance
(
self
.
activation
,
str
):
...
...
mindspore/nn/probability/bnn_layers/layer_distribution.py
浏览文件 @
8acb5bdf
...
...
@@ -51,15 +51,16 @@ class NormalPosterior(Cell):
Args:
name (str): Name prepended to trainable parameter.
shape (list): Shape of the mean and standard deviation.
shape (list
, tuple
): Shape of the mean and standard deviation.
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
Default: mindspore.float32.
loc_mean ( float, array_like of floats): Mean of distribution to initialize trainable parameters. Default: 0.
loc_std ( float, array_like of floats): Standard deviation of distribution to initialize trainable parameters.
Default: 0.1.
untransformed_scale_mean ( float, array_like of floats): Mean of distribution to initialize trainable
loc_mean (int, float, array_like of floats): Mean of distribution to initialize trainable parameters.
Default: 0.
loc_std (int, float, array_like of floats): Standard deviation of distribution to initialize trainable
parameters. Default: 0.1.
untransformed_scale_mean (int, float, array_like of floats): Mean of distribution to initialize trainable
parameters. Default: -5.
untransformed_scale_std ( float, array_like of floats): Standard deviation of distribution to initialize
untransformed_scale_std (
int,
float, array_like of floats): Standard deviation of distribution to initialize
trainable parameters. Default: 0.1.
Returns:
...
...
@@ -80,20 +81,25 @@ class NormalPosterior(Cell):
if
not
isinstance
(
shape
,
(
tuple
,
list
)):
raise
TypeError
(
'The type of `shape` should be `tuple` or `list`'
)
if
not
(
np
.
array
(
shape
)
>
0
).
all
():
raise
ValueError
(
'Negative dimensions are not allowed'
)
try
:
mean_arr
=
np
.
random
.
normal
(
loc_mean
,
loc_std
,
shape
)
except
ValueError
as
msg
:
raise
ValueError
(
msg
)
except
TypeError
as
msg
:
raise
TypeError
(
msg
)
if
not
(
np
.
array
(
loc_std
)
>=
0
).
all
():
raise
ValueError
(
'The value of `loc_std` < 0'
)
if
not
(
np
.
array
(
untransformed_scale_std
)
>=
0
).
all
():
raise
ValueError
(
'The value of `untransformed_scale_std` < 0'
)
try
:
untransformed_scale_arr
=
np
.
random
.
normal
(
untransformed_scale_mean
,
untransformed_scale_std
,
shape
)
except
ValueError
as
msg
:
raise
ValueError
(
msg
)
except
TypeError
as
msg
:
raise
TypeError
(
msg
)
self
.
mean
=
Parameter
(
Tensor
(
np
.
random
.
normal
(
loc_mean
,
loc_std
,
shape
)
,
dtype
=
dtype
),
name
=
name
+
'_mean'
)
Tensor
(
mean_arr
,
dtype
=
dtype
),
name
=
name
+
'_mean'
)
self
.
untransformed_std
=
Parameter
(
Tensor
(
np
.
random
.
normal
(
untransformed_scale_mean
,
untransformed_scale_std
,
shape
),
dtype
=
dtype
),
name
=
name
+
'_untransformed_std'
)
Tensor
(
untransformed_scale_arr
,
dtype
=
dtype
),
name
=
name
+
'_untransformed_std'
)
self
.
normal
=
Normal
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录