magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 61dbb1b1
Authored: Aug 13, 2020
Author: bingyaweng
Parent: fb2f888e

add bnn_layers to nn.probability

Showing 9 changed files with 685 additions and 5 deletions (+685 -5)
Files changed:

mindspore/nn/probability/bnn_layers/__init__.py                    +31  -0
mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py            +92  -0
mindspore/nn/probability/bnn_layers/conv_variational.py            +270 -0
mindspore/nn/probability/bnn_layers/dense_variational.py           +188 -0
mindspore/nn/probability/bnn_layers/layer_distribution.py          +96  -0
mindspore/nn/probability/distribution/_utils/utils.py              +3   -2
mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py   +2   -2
requirements.txt                                                   +1   -0
setup.py                                                           +2   -1
mindspore/nn/probability/bnn_layers/__init__.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bayesian Layer.
The high-level components(Cells) used to construct the bayesian neural network.
"""
from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
from .conv_variational import ConvReparam
from .dense_variational import DenseReparam
from .layer_distribution import NormalPrior, NormalPosterior
from .bnn_cell_wrapper import WithBNNLossCell

__all__ = []
__all__.extend(conv_variational.__all__)
__all__.extend(dense_variational.__all__)
__all__.extend(layer_distribution.__all__)
__all__.extend(bnn_cell_wrapper.__all__)
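For orientation, a minimal usage sketch of the public surface this __init__.py re-exports, assuming the package is importable as mindspore.nn.probability.bnn_layers (layer arguments below are illustrative only):

# Illustrative sketch: consume the symbols aggregated by this __init__.py.
from mindspore.nn.probability import bnn_layers

dense = bnn_layers.DenseReparam(128, 10, activation='relu')   # Bayesian dense layer
conv = bnn_layers.ConvReparam(3, 16, 3, has_bias=False)       # Bayesian convolution layer
print(bnn_layers.__all__)  # symbols collected from the four submodules above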
mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate WithLossCell suitable for BNN."""
from .conv_variational import _ConvVariational
from .dense_variational import _DenseVariational
from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss

__all__ = ['WithBNNLossCell']


class ClassWrap:
    """Decorator of WithBNNLossCell"""

    def __init__(self, cls):
        self._cls = cls
        self.bnn_loss_file = None

    def __call__(self, backbone, loss_fn, backbone_factor, kl_factor):
        obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor)
        bnn_with_loss = obj()
        self.bnn_loss_file = obj.bnn_loss_file
        return bnn_with_loss


@ClassWrap
class WithBNNLossCell:
    r"""
    Generate WithLossCell suitable for BNN.

    Args:
        backbone (Cell): The target network.
        loss_fn (Cell): The loss function used to compute loss.
        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function. Default: 1.
        bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer. Default: 1.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> net_with_criterion_object = WithBNNLossCell(net, loss_fn)
        >>> net_with_criterion = net_with_criterion_object()
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
        >>>
        >>> net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
        self.backbone = backbone
        self.loss_fn = loss_fn
        self.dnn_factor = dnn_factor
        self.bnn_factor = bnn_factor
        self.bnn_loss_file = None

    def _generate_loss_cell(self):
        """Generate WithBNNLossCell by ast."""
        layer_count = self._kl_loss_count(self.backbone)
        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
                                                               self.dnn_factor, self.bnn_factor)
        return bnn_with_loss

    def _kl_loss_count(self, net):
        """Calculate the number of Bayesian layers."""
        count = 0
        for (_, layer) in net.name_cells().items():
            if isinstance(layer, (_DenseVariational, _ConvVariational)):
                count += 1
            else:
                count += self._kl_loss_count(layer)
        return count

    def __call__(self):
        return self._generate_loss_cell()
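Because WithBNNLossCell is decorated with ClassWrap, the name WithBNNLossCell is bound to a ClassWrap instance: constructing "WithBNNLossCell(...)" actually runs ClassWrap.__call__, which builds the inner object, calls it, and hands back the generated loss cell. A standalone sketch of that class-decorator pattern, with purely illustrative names and no MindSpore dependencies:

# Illustrative stand-in for the ClassWrap pattern above.
class Wrap:
    def __init__(self, cls):
        self._cls = cls

    def __call__(self, *args, **kwargs):
        obj = self._cls(*args, **kwargs)   # construct the decorated class
        return obj()                       # calling the instance yields its product


@Wrap
class Factory:
    def __init__(self, backbone, loss_fn):
        self.backbone = backbone
        self.loss_fn = loss_fn

    def __call__(self):
        # Stand-in for _generate_loss_cell(): return whatever the factory produces.
        return ('loss_cell_for', self.backbone, self.loss_fn)


print(Factory('net', 'ce_loss'))   # ('loss_cell_for', 'net', 'ce_loss'), not a Factory instance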
mindspore/nn/probability/bnn_layers/conv_variational.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convolutional variational layers."""
from mindspore.ops import operations as P
from mindspore._checkparam import twice
from ...layer.conv import _Conv
from ...cell import Cell
from .layer_distribution import NormalPrior, NormalPosterior

__all__ = ['ConvReparam']


class _ConvVariational(_Conv):
    """
    Base class for all convolutional variational layers.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_prior_fn=NormalPrior,
                 weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
                 bias_prior_fn=NormalPrior,
                 bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        kernel_size = twice(kernel_size)
        stride = twice(stride)
        dilation = twice(dilation)
        super(_ConvVariational, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init='normal',
            bias_init='zeros')
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' + str(pad_mode) +
                             ', should be one of values in \'valid\', \'same\', \'pad\'.')

        # convolution args
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = dilation
        self.group = group
        self.has_bias = has_bias

        # distribution trainable parameters
        self.shape = [self.out_channels, self.in_channels // self.group, *self.kernel_size]

        self.weight.requires_grad = False
        if isinstance(weight_prior_fn, Cell):
            self.weight_prior = weight_prior_fn
        else:
            self.weight_prior = weight_prior_fn()
        self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')

        if self.has_bias:
            self.bias.requires_grad = False
            if isinstance(bias_prior_fn, Cell):
                self.bias_prior = bias_prior_fn
            else:
                self.bias_prior = bias_prior_fn()
            self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')

        # mindspore operations
        self.bias_add = P.BiasAdd()
        self.conv2d = P.Conv2D(out_channel=self.out_channels,
                               kernel_size=self.kernel_size,
                               mode=1,
                               pad_mode=self.pad_mode,
                               pad=self.padding,
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group)
        self.log = P.Log()
        self.sum = P.ReduceSum()

    def construct(self, inputs):
        outputs = self._apply_variational_weight(inputs)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        return outputs

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, kernel_size={}, weight_mean={}, stride={}, pad_mode={}, ' \
                   'padding={}, dilation={}, group={}, weight_std={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode,
                    self.padding, self.dilation, self.group, self.weight_posterior.mean,
                    self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}' \
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
        return str_info

    def _apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute kl loss"""
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")

        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")

            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss


class ConvReparam(_ConvVariational):
    r"""
    Convolutional variational layers with Reparameterization.

    See more details in paper `Auto-Encoding Variational Bayes
    <https://arxiv.org/abs/1312.6114>`

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is int or
            tuple with 2 integers. Specifies the height and width of the 2D
            convolution window. A single int means the value is for both
            height and width of the kernel. A tuple of 2 ints means the
            first value is for the height and the other is for the width of
            the kernel.
        stride (Union[int, tuple[int]]): The distance of kernel moving,
            an int number that represents the height and width of movement
            are both strides, or a tuple of two int numbers that represent
            height and width of movement respectively. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are
            "same", "valid", "pad". Default: "same".

            - same: Adopts the way of completion. Output height and width
              will be the same as the input. The total number of padding
              will be calculated for the horizontal and vertical directions
              and evenly distributed to top and bottom, left and right if
              possible. Otherwise, the last extra padding will be done from
              the bottom and the right side. If this mode is set, `padding`
              must be 0.
            - valid: Adopts the way of discarding. The possibly largest
              height and width of output will be returned without padding.
              Extra pixels will be discarded. If this mode is set, `padding`
              must be 0.
            - pad: Implicit paddings on both sides of the input. The number
              of `padding` will be padded to the input Tensor borders.
              `padding` should be greater than or equal to 0.

        padding (Union[int, tuple[int]]): Implicit paddings on both sides of
            the input. Default: 0.
        dilation (Union[int, tuple[int]]): The data type is int or tuple
            with 2 integers. Specifies the dilation rate to use for dilated
            convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling
            location. Its value should be greater than or equal to 1 and
            bounded by the height and width of the input. Default: 1.
        group (int): Split filter into groups; `in_channels` and
            `out_channels` should be divisible by the number of groups.
            Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector.
            Default: False.
        weight_prior_fn: prior distribution for convolution kernel.
            It should return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of standard
            normal distribution).
        weight_posterior_fn: posterior distribution for sampling convolution
            kernel. It should be a function handle which returns a mindspore
            distribution instance. Default: NormalPosterior.
        bias_prior_fn: prior distribution for bias vector. It should return
            a mindspore distribution. Default: NormalPrior (which creates an
            instance of standard normal distribution).
        bias_posterior_fn: posterior distribution for sampling bias vector.
            It should be a function handle which returns a mindspore
            distribution instance. Default: NormalPosterior.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = ConvReparam(120, 240, 4, has_bias=False)
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> net(input).shape
        (1, 240, 1024, 640)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_prior_fn=NormalPrior,
                 weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
                 bias_prior_fn=NormalPrior,
                 bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(ConvReparam, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            pad_mode=pad_mode,
            padding=padding,
            dilation=dilation,
            group=group,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn)

    def _apply_variational_weight(self, inputs):
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.conv2d(inputs, weight_posterior_tensor)
        return outputs
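compute_kl_loss sums, over every weight (and bias) entry, the KL divergence between the learned Gaussian posterior and the Gaussian prior, then ReduceSum collapses it to a scalar. A NumPy sketch of the standard closed-form term such a "kl_loss" call evaluates elementwise; which of the two Gaussians plays the role of prior versus posterior is fixed by the Normal distribution's kl_loss implementation, and the shapes and values below are illustrative only:

import numpy as np

def kl_normal(mu_q, sigma_q, mu_p, sigma_p):
    # Elementwise KL( N(mu_q, sigma_q) || N(mu_p, sigma_p) ), closed form for univariate Gaussians.
    return (np.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)

# Toy 2x3 weight posterior; the prior uses NormalPrior's defaults (mean 0, std 0.1).
post_mean = np.random.normal(0.0, 0.1, (2, 3))
post_std = np.log1p(np.exp(np.random.normal(-5.0, 0.1, (2, 3))))  # softplus, as in NormalPosterior.std_trans
kl_loss = kl_normal(post_mean, post_std, 0.0, 0.1).sum()          # scalar, like ReduceSum over the kl tensor
print(kl_loss)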
mindspore/nn/probability/bnn_layers/dense_variational.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dense_variational"""
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive, check_bool
from ...cell import Cell
from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior

__all__ = ['DenseReparam']


class _DenseVariational(Cell):
    """
    Base class for all dense variational layers.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 activation=None,
                 has_bias=True,
                 weight_prior_fn=NormalPrior,
                 weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
                 bias_prior_fn=NormalPrior,
                 bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(_DenseVariational, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)

        if isinstance(weight_prior_fn, Cell):
            self.weight_prior = weight_prior_fn
        else:
            self.weight_prior = weight_prior_fn()
        self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')

        if self.has_bias:
            if isinstance(bias_prior_fn, Cell):
                self.bias_prior = bias_prior_fn
            else:
                self.bias_prior = bias_prior_fn()
            self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')

        self.activation = activation
        if isinstance(self.activation, str):
            self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.sum = P.ReduceSum()

    def construct(self, x):
        outputs = self._apply_variational_weight(x)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        if self.activation_flag:
            outputs = self.activation(outputs)
        return outputs

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
                    self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}' \
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info

    def _apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute kl loss."""
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")

        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")

            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss


class DenseReparam(_DenseVariational):
    r"""
    Dense variational layers with Reparameterization.

    See more details in paper `Auto-Encoding Variational Bayes
    <https://arxiv.org/abs/1312.6114>`

    Applies dense-connected layer for the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as the inputs created by the layer, :math:`\text{weight}` is a weight
    matrix sampled from the posterior distribution of weight, and :math:`\text{bias}` is a
    bias vector with the same data type as the inputs created by the layer (only if
    has_bias is True). The bias vector is sampled from the posterior distribution of
    :math:`\text{bias}`.

    Args:
        in_channels (int): The number of input channel.
        out_channels (int): The number of output channel.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
        weight_prior_fn: prior distribution for weight.
            It should return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of standard
            normal distribution).
        weight_posterior_fn: posterior distribution for sampling weight.
            It should be a function handle which returns a mindspore
            distribution instance. Default: NormalPosterior.
        bias_prior_fn: prior distribution for bias vector. It should return
            a mindspore distribution. Default: NormalPrior (which creates an
            instance of standard normal distribution).
        bias_posterior_fn: posterior distribution for sampling bias vector.
            It should be a function handle which returns a mindspore
            distribution instance. Default: NormalPosterior.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = DenseReparam(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 activation=None,
                 has_bias=True,
                 weight_prior_fn=NormalPrior,
                 weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
                 bias_prior_fn=NormalPrior,
                 bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(DenseReparam, self).__init__(
            in_channels,
            out_channels,
            activation=activation,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn)

    def _apply_variational_weight(self, inputs):
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.matmul(inputs, weight_posterior_tensor)
        return outputs
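_apply_variational_weight draws one weight sample from the posterior per forward pass and feeds it to a plain matmul; the "Reparameterization" in the class name refers to drawing that sample as mean + std * eps with eps ~ N(0, 1), which keeps the draw differentiable with respect to the variational parameters. A NumPy sketch of that single step, with illustrative shapes and values:

import numpy as np

rng = np.random.default_rng(0)

in_channels, out_channels, batch = 3, 4, 2
x = rng.normal(size=(batch, in_channels))

# Weight posterior parameters, shaped [out_channels, in_channels] as in _DenseVariational
# (the layer's MatMul uses transpose_b=True).
w_mean = rng.normal(0.0, 0.1, size=(out_channels, in_channels))
w_std = np.log1p(np.exp(rng.normal(-5.0, 0.1, size=(out_channels, in_channels))))  # softplus-style transform

# Reparameterized sample: differentiable w.r.t. w_mean and w_std.
eps = rng.standard_normal(size=w_mean.shape)
w_sample = w_mean + w_std * eps

outputs = x @ w_sample.T   # equivalent of matmul(inputs, weight_posterior_tensor) with transpose_b=True
print(outputs.shape)       # (2, 4)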
mindspore/nn/probability/bnn_layers/layer_distribution.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialize normal distributions"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from ...cell import Cell
from ..distribution.normal import Normal

__all__ = ['NormalPrior', 'NormalPosterior']


class NormalPrior(Cell):
    r"""
    To initialize a normal distribution of mean 0 and standard deviation 0.1.

    Args:
        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
            Default: mindspore.float32.
        mean (int, float): Mean of normal distribution.
        std (int, float): Standard deviation of normal distribution.

    Returns:
        Cell, a normal distribution.
    """

    def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
        super(NormalPrior, self).__init__()
        self.normal = Normal(mean, std, dtype=dtype)

    def construct(self, *inputs):
        return self.normal(*inputs)


class NormalPosterior(Cell):
    r"""
    Build Normal distributions with trainable parameters.

    Args:
        name (str): Name prepended to trainable parameter.
        shape (list): Shape of the mean and standard deviation.
        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
            Default: mindspore.float32.
        loc_mean (float, array_like of floats): Mean of distribution to initialize trainable parameters. Default: 0.
        loc_std (float, array_like of floats): Standard deviation of distribution to initialize trainable parameters.
            Default: 0.1.
        untransformed_scale_mean (float, array_like of floats): Mean of distribution to initialize trainable
            parameters. Default: -5.
        untransformed_scale_std (float, array_like of floats): Standard deviation of distribution to initialize
            trainable parameters. Default: 0.1.

    Returns:
        Cell, a normal distribution.
    """

    def __init__(self,
                 name,
                 shape,
                 dtype=mstype.float32,
                 loc_mean=0,
                 loc_std=0.1,
                 untransformed_scale_mean=-5,
                 untransformed_scale_std=0.1):
        super(NormalPosterior, self).__init__()
        if not isinstance(name, str):
            raise ValueError('The type of `name` should be `str`')

        self.mean = Parameter(
            Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')

        self.untransformed_std = Parameter(
            Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype),
            name=name + '_untransformed_std')

        self.normal = Normal()

    def std_trans(self, std_pre):
        """Transform std_pre to prevent its value being zero."""
        std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1)
        return std

    def construct(self, *inputs):
        std = self.std_trans(self.untransformed_std)
        return self.normal(*inputs, mean=self.mean, sd=std)
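std_trans maps the unconstrained parameter through a softplus, std = 1e-6 + log(exp(x) + 1), so the standard deviation stays strictly positive; with the default initialization around x = -5 it yields a std near log(1 + e^-5), roughly 0.0067. A quick numeric check in NumPy, as a sketch only:

import numpy as np

def std_trans(std_pre):
    # Same transform as NormalPosterior.std_trans: softplus plus a small floor.
    return 1e-6 + np.log(np.exp(std_pre) + 1.0)

print(std_trans(-5.0))                    # ~0.0067: small positive std from the default init around -5
print(std_trans(0.0))                     # ~0.6931 (log 2)
print(std_trans(np.array([-20.0, 20.0]))) # stays positive; roughly linear for large positive inputs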
mindspore/nn/probability/distribution/_utils/utils.py (modified)
@@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 import mindspore.nn as nn
+import mindspore.nn.probability as msp

 def cast_to_tensor(t, hint_dtype=mstype.float32):
     """
     ...

@@ -84,7 +85,7 @@ def check_scalar_from_param(params):
     Notes: String parameters are excluded.
     """
     for value in params.values():
-        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
+        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
             return params['distribution'].is_scalar_batch
         if isinstance(value, Parameter):
             return False
     ...

@@ -109,7 +110,7 @@ def calc_broadcast_shape_from_param(params):
     """
     broadcast_shape = []
     for value in params.values():
-        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
+        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
             return params['distribution'].broadcast_shape
         if isinstance(value, (str, type(params['dtype']))):
             continue
     ...
mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py (modified)
@@ -36,7 +36,7 @@ class _CodeTransformer(ast.NodeTransformer):
     def visit_FunctionDef(self, node):
         """visit function and add kl_loss computation."""
         self.generic_visit(node)
-        if node.name == 'compute_kl_loss':
+        if node.name == 'cal_kl_loss':
             for i in range(self.layer_count):
                 func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())],
                                   value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(),
     ...

@@ -71,7 +71,7 @@ def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
         layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
         backbone (Cell): The target network to wrap.
         loss_fn (Cell): The loss function used to compute loss.
-        dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function.
+        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function.
         bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer.
     """
     bnn_loss_func = _generate_kl_loss_func(layer_count)
     ...
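gain_bnn_with_loss builds the loss cell by rewriting template source at the AST level, which is why astunparse is added to requirements.txt and setup.py below: the transformed tree has to be turned back into source text (the generated bnn_loss_file) before it can be executed. A minimal, self-contained sketch of that rewrite-and-unparse flow; the template and transformer here are illustrative stand-ins, not the actual code in generate_kl_loss.py:

import ast
import astunparse

TEMPLATE = """
def cal_kl_loss(self):
    loss = 0
    return loss
"""

class AddKlTerms(ast.NodeTransformer):
    """Illustrative transformer: append one kl-loss accumulation statement per Bayesian layer."""
    def __init__(self, layer_count):
        self.layer_count = layer_count

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        if node.name == 'cal_kl_loss':
            for i in range(self.layer_count):
                stmt = ast.parse('loss = loss + self.layer_{}.compute_kl_loss()'.format(i)).body[0]
                node.body.insert(-1, stmt)   # keep the trailing `return loss` last
        return node

tree = ast.parse(TEMPLATE)
tree = AddKlTerms(layer_count=2).visit(tree)
ast.fix_missing_locations(tree)
print(astunparse.unparse(tree))   # source text that could be written out as a generated loss file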
requirements.txt (modified)

@@ -14,3 +14,4 @@ opencv-python >= 4.1.2.30 # for ut test
 sklearn >= 0.0 # for st test
 pandas >= 1.0.2 # for ut test
 bs4
+astunparse
setup.py (modified)

@@ -92,7 +92,8 @@ required_package = [
     'easydict >= 1.9',
     'sympy >= 1.4',
     'cffi >= 1.13.2',
-    'decorator >= 4.4.0'
+    'decorator >= 4.4.0',
+    'astunparse >= 1.6.3'
 ]

 package_data = {
 ...