Commit 99bac634 authored by bingyaweng

fix errors in bnn_layers

Parent d8a4827f

@@ -222,22 +222,22 @@ class ConvReparam(_ConvVariational):
             Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector.
             Default: False.
-        weight_prior_fn: prior distribution for convolution kernel.
+        weight_prior_fn: prior distribution for weight.
             It should return a mindspore distribution instance.
             Default: NormalPrior. (which creates an instance of standard
-            normal distribution).
-        weight_posterior_fn: posterior distribution for sampling convolution
-            kernel. It should be a function handle which returns a mindspore
-            distribution instance.
-            Default: NormalPosterior.
+            normal distribution). The current version only supports NormalPrior.
+        weight_posterior_fn: posterior distribution for sampling weight.
+            It should be a function handle which returns a mindspore
+            distribution instance. Default: NormalPosterior. The current
+            version only supports NormalPosterior.
         bias_prior_fn: prior distribution for bias vector. It should return
-            a mindspore distribution.
-            Default: NormalPrior(which creates an instance of standard
-            normal distribution).
+            a mindspore distribution. Default: NormalPrior(which creates an
+            instance of standard normal distribution). The current version
+            only supports NormalPrior.
         bias_posterior_fn: posterior distribution for sampling bias vector.
             It should be a function handle which returns a mindspore
-            distribution instance.
-            Default: NormalPosterior.
+            distribution instance. Default: NormalPosterior. The current
+            version only supports NormalPosterior.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
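
For context, a minimal usage sketch of ConvReparam with the documented defaults follows. It is not part of this commit; it assumes the public `mindspore.nn.probability.bnn_layers` import path and the default NormalPrior/NormalPosterior factories described above.

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.nn.probability import bnn_layers

    # Bayesian convolution via the reparameterization trick. weight_prior_fn,
    # weight_posterior_fn, bias_prior_fn and bias_posterior_fn are left at their
    # defaults (NormalPrior / NormalPosterior), which per the docstring above are
    # the only values the current version supports.
    conv = bnn_layers.ConvReparam(in_channels=3, out_channels=8, kernel_size=3, has_bias=False)

    x = Tensor(np.ones([1, 3, 32, 32]), mindspore.float32)  # (N, C_in, H_in, W_in)
    out = conv(x)  # the convolution weight is sampled from the posterior on each call
    print(out.shape)  # with the default 'same' padding the spatial size is kept: (1, 8, 32, 32)
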
@@ -72,9 +72,16 @@ class _DenseVariational(Cell):
             raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')

         self.activation = activation
-        if isinstance(self.activation, str):
-            self.activation = get_activation(activation)
-        self.activation_flag = self.activation is not None
+        if not self.activation:
+            self.activation_flag = False
+        else:
+            self.activation_flag = True
+            if isinstance(self.activation, str):
+                self.activation = get_activation(activation)
+            elif isinstance(self.activation, Cell):
+                self.activation = activation
+            else:
+                raise ValueError('The type of `activation` is wrong.')

         self.matmul = P.MatMul(transpose_b=True)
         self.bias_add = P.BiasAdd()
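
The change above replaces the old `activation_flag` derivation with an explicit branch over None, str and Cell. The standalone sketch below mirrors that resolution logic for illustration only; `resolve_activation` is a hypothetical helper, and it assumes `Cell` and `get_activation` are importable from `mindspore.nn`.

    from mindspore.nn import Cell, get_activation

    def resolve_activation(activation):
        """Illustrative re-statement of the new branching; not the library code."""
        if not activation:
            # None means "no activation": the layer skips the activation call.
            return None, False
        if isinstance(activation, str):
            return get_activation(activation), True   # e.g. 'relu' -> an nn.ReLU() Cell
        if isinstance(activation, Cell):
            return activation, True                   # an already-instantiated Cell is used as-is
        raise ValueError('The type of `activation` is wrong.')
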
@@ -145,23 +152,25 @@ class DenseReparam(_DenseVariational):
         in_channels (int): The number of input channel.
         out_channels (int): The number of output channel .
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
+        activation (str, Cell): Regularizer function applied to the output of the layer. The type of activation can
+            be str (eg. 'relu') or Cell (eg. nn.ReLU()). Note that if the type of activation is Cell, it must have been
+            instantiated. Default: None.
         weight_prior_fn: prior distribution for weight.
             It should return a mindspore distribution instance.
             Default: NormalPrior. (which creates an instance of standard
-            normal distribution).
+            normal distribution). The current version only supports NormalPrior.
         weight_posterior_fn: posterior distribution for sampling weight.
             It should be a function handle which returns a mindspore
-            distribution instance.
-            Default: NormalPosterior.
+            distribution instance. Default: NormalPosterior. The current
+            version only supports NormalPosterior.
         bias_prior_fn: prior distribution for bias vector. It should return
-            a mindspore distribution.
-            Default: NormalPrior(which creates an instance of standard
-            normal distribution).
+            a mindspore distribution. Default: NormalPrior(which creates an
+            instance of standard normal distribution). The current version
+            only supports NormalPrior.
         bias_posterior_fn: posterior distribution for sampling bias vector.
             It should be a function handle which returns a mindspore
-            distribution instance.
-            Default: NormalPosterior.
+            distribution instance. Default: NormalPosterior. The current
+            version only supports NormalPosterior.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
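
A minimal sketch of the updated `activation` contract follows; it is illustrative only (the variable names are not from the commit) and assumes the public `bnn_layers` API.

    import numpy as np
    import mindspore
    from mindspore import nn, Tensor
    from mindspore.nn.probability import bnn_layers

    # activation may now be given either as a name or as an already-instantiated Cell.
    dense_by_name = bnn_layers.DenseReparam(16, 8, activation='relu')
    dense_by_cell = bnn_layers.DenseReparam(16, 8, activation=nn.ReLU())  # instantiated, not nn.ReLU

    x = Tensor(np.ones([4, 16]), mindspore.float32)  # (N, in_channels)
    print(dense_by_name(x).shape, dense_by_cell(x).shape)  # both: (4, 8)
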
@@ -54,14 +54,11 @@ class NormalPosterior(Cell):
         shape (list, tuple): Shape of the mean and standard deviation.
         dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
             Default: mindspore.float32.
-        loc_mean (int, float, array_like of floats): Mean of distribution to initialize trainable parameters.
-            Default: 0.
-        loc_std (int, float, array_like of floats): Standard deviation of distribution to initialize trainable
-            parameters. Default: 0.1.
-        untransformed_scale_mean (int, float, array_like of floats): Mean of distribution to initialize trainable
-            parameters. Default: -5.
-        untransformed_scale_std (int, float, array_like of floats): Standard deviation of distribution to initialize
-            trainable parameters. Default: 0.1.
+        loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
+        loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
+        untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters. Default: -5.
+        untransformed_scale_std (int, float): Standard deviation of distribution to initialize trainable parameters.
+            Default: 0.1.

     Returns:
         Cell, a normal distribution.
@@ -81,25 +78,25 @@
         if not isinstance(shape, (tuple, list)):
             raise TypeError('The type of `shape` should be `tuple` or `list`')
-        try:
-            mean_arr = np.random.normal(loc_mean, loc_std, shape)
-        except ValueError as msg:
-            raise ValueError(msg)
-        except TypeError as msg:
-            raise TypeError(msg)
+        if not isinstance(loc_mean, (int, float)):
+            raise TypeError('The type of `loc_mean` should be `int` or `float`')
-        try:
-            untransformed_scale_arr = np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape)
-        except ValueError as msg:
-            raise ValueError(msg)
-        except TypeError as msg:
-            raise TypeError(msg)
+        if not isinstance(untransformed_scale_mean, (int, float)):
+            raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`')
+        if not (isinstance(loc_std, (int, float)) and loc_std >= 0):
+            raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0')
+        if not (isinstance(untransformed_scale_std, (int, float)) and untransformed_scale_std >= 0):
+            raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and '
+                            'its value should > 0')
         self.mean = Parameter(
-            Tensor(mean_arr, dtype=dtype), name=name + '_mean')
+            Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')
         self.untransformed_std = Parameter(
-            Tensor(untransformed_scale_arr, dtype=dtype), name=name + '_untransformed_std')
+            Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype),
+            name=name + '_untransformed_std')
         self.normal = Normal()
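
Finally, a short sketch of how the new argument checks surface to callers of NormalPosterior; illustrative only, assuming the public `bnn_layers` export.

    import mindspore
    from mindspore.nn.probability import bnn_layers

    # Valid construction: scalar int/float hyper-parameters, as the tightened docstring states.
    posterior = bnn_layers.NormalPosterior(name='weight', shape=[8, 16], dtype=mindspore.float32,
                                           loc_mean=0, loc_std=0.1,
                                           untransformed_scale_mean=-5, untransformed_scale_std=0.1)

    # A non-scalar argument now fails fast with an explicit TypeError instead of
    # surfacing an error from inside np.random.normal.
    try:
        bnn_layers.NormalPosterior(name='bad', shape=[8, 16], loc_mean=[0, 0])
    except TypeError as err:
        print(err)  # The type of `loc_mean` should be `int` or `float`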