提交 99bac634 编写于 作者: B bingyaweng

fix error of bnn_layers

上级 d8a4827f
...@@ -222,22 +222,22 @@ class ConvReparam(_ConvVariational): ...@@ -222,22 +222,22 @@ class ConvReparam(_ConvVariational):
Default: 1. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. has_bias (bool): Specifies whether the layer uses a bias vector.
Default: False. Default: False.
weight_prior_fn: prior distribution for convolution kernel. weight_prior_fn: prior distribution for weight.
It should return a mindspore distribution instance. It should return a mindspore distribution instance.
Default: NormalPrior. (which creates an instance of standard Default: NormalPrior. (which creates an instance of standard
normal distribution). normal distribution). The current version only supports NormalPrior.
weight_posterior_fn: posterior distribution for sampling convolution weight_posterior_fn: posterior distribution for sampling weight.
kernel. It should be a function handle which returns a mindspore It should be a function handle which returns a mindspore
distribution instance. distribution instance. Default: NormalPosterior. The current
Default: NormalPosterior. version only supports NormalPosterior.
bias_prior_fn: prior distribution for bias vector. It should return bias_prior_fn: prior distribution for bias vector. It should return
a mindspore distribution. a mindspore distribution. Default: NormalPrior(which creates an
Default: NormalPrior(which creates an instance of standard instance of standard normal distribution). The current version
normal distribution). only supports NormalPrior.
bias_posterior_fn: posterior distribution for sampling bias vector. bias_posterior_fn: posterior distribution for sampling bias vector.
It should be a function handle which returns a mindspore It should be a function handle which returns a mindspore
distribution instance. distribution instance. Default: NormalPosterior. The current
Default: NormalPosterior. version only supports NormalPosterior.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
......
...@@ -72,9 +72,16 @@ class _DenseVariational(Cell): ...@@ -72,9 +72,16 @@ class _DenseVariational(Cell):
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`') raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
self.activation = activation self.activation = activation
if not self.activation:
self.activation_flag = False
else:
self.activation_flag = True
if isinstance(self.activation, str): if isinstance(self.activation, str):
self.activation = get_activation(activation) self.activation = get_activation(activation)
self.activation_flag = self.activation is not None elif isinstance(self.activation, Cell):
self.activation = activation
else:
raise ValueError('The type of `activation` is wrong.')
self.matmul = P.MatMul(transpose_b=True) self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
...@@ -145,23 +152,25 @@ class DenseReparam(_DenseVariational): ...@@ -145,23 +152,25 @@ class DenseReparam(_DenseVariational):
in_channels (int): The number of input channel. in_channels (int): The number of input channel.
out_channels (int): The number of output channel . out_channels (int): The number of output channel .
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. activation (str, Cell): Regularizer function applied to the output of the layer. The type of activation can
be str (eg. 'relu') or Cell (eg. nn.ReLU()). Note that if the type of activation is Cell, it must have been
instantiated. Default: None.
weight_prior_fn: prior distribution for weight. weight_prior_fn: prior distribution for weight.
It should return a mindspore distribution instance. It should return a mindspore distribution instance.
Default: NormalPrior. (which creates an instance of standard Default: NormalPrior. (which creates an instance of standard
normal distribution). normal distribution). The current version only supports NormalPrior.
weight_posterior_fn: posterior distribution for sampling weight. weight_posterior_fn: posterior distribution for sampling weight.
It should be a function handle which returns a mindspore It should be a function handle which returns a mindspore
distribution instance. distribution instance. Default: NormalPosterior. The current
Default: NormalPosterior. version only supports NormalPosterior.
bias_prior_fn: prior distribution for bias vector. It should return bias_prior_fn: prior distribution for bias vector. It should return
a mindspore distribution. a mindspore distribution. Default: NormalPrior(which creates an
Default: NormalPrior(which creates an instance of standard instance of standard normal distribution). The current version
normal distribution). only supports NormalPrior.
bias_posterior_fn: posterior distribution for sampling bias vector. bias_posterior_fn: posterior distribution for sampling bias vector.
It should be a function handle which returns a mindspore It should be a function handle which returns a mindspore
distribution instance. distribution instance. Default: NormalPosterior. The current
Default: NormalPosterior. version only supports NormalPosterior.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
......
...@@ -54,14 +54,11 @@ class NormalPosterior(Cell): ...@@ -54,14 +54,11 @@ class NormalPosterior(Cell):
shape (list, tuple): Shape of the mean and standard deviation. shape (list, tuple): Shape of the mean and standard deviation.
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
Default: mindspore.float32. Default: mindspore.float32.
loc_mean (int, float, array_like of floats): Mean of distribution to initialize trainable parameters. loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
Default: 0. loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
loc_std (int, float, array_like of floats): Standard deviation of distribution to initialize trainable untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters. Default: -5.
parameters. Default: 0.1. untransformed_scale_std (int, float): Standard deviation of distribution to initialize trainable parameters.
untransformed_scale_mean (int, float, array_like of floats): Mean of distribution to initialize trainable Default: 0.1.
parameters. Default: -5.
untransformed_scale_std (int, float, array_like of floats): Standard deviation of distribution to initialize
trainable parameters. Default: 0.1.
Returns: Returns:
Cell, a normal distribution. Cell, a normal distribution.
...@@ -81,25 +78,25 @@ class NormalPosterior(Cell): ...@@ -81,25 +78,25 @@ class NormalPosterior(Cell):
if not isinstance(shape, (tuple, list)): if not isinstance(shape, (tuple, list)):
raise TypeError('The type of `shape` should be `tuple` or `list`') raise TypeError('The type of `shape` should be `tuple` or `list`')
try: if not isinstance(loc_mean, (int, float)):
mean_arr = np.random.normal(loc_mean, loc_std, shape) raise TypeError('The type of `loc_mean` should be `int` or `float`')
except ValueError as msg:
raise ValueError(msg)
except TypeError as msg:
raise TypeError(msg)
try: if not isinstance(untransformed_scale_mean, (int, float)):
untransformed_scale_arr = np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape) raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`')
except ValueError as msg:
raise ValueError(msg) if not (isinstance(loc_std, (int, float)) and loc_std >= 0):
except TypeError as msg: raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0')
raise TypeError(msg)
if not (isinstance(untransformed_scale_std, (int, float)) and untransformed_scale_std >= 0):
raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and '
'its value should > 0')
self.mean = Parameter( self.mean = Parameter(
Tensor(mean_arr, dtype=dtype), name=name + '_mean') Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')
self.untransformed_std = Parameter( self.untransformed_std = Parameter(
Tensor(untransformed_scale_arr, dtype=dtype), name=name + '_untransformed_std') Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype),
name=name + '_untransformed_std')
self.normal = Normal() self.normal = Normal()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册