提交 8acb5bdf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4734 fix error of bnn_layers

Merge pull request !4734 from byweng/add_test
...@@ -61,6 +61,12 @@ class _ConvVariational(_Conv): ...@@ -61,6 +61,12 @@ class _ConvVariational(_Conv):
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
if not isinstance(stride, (int, tuple)):
raise TypeError('The type of `stride` should be `int` of `tuple`')
if not isinstance(dilation, (int, tuple)):
raise TypeError('The type of `dilation` should be `int` of `tuple`')
# convolution args # convolution args
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
...@@ -87,13 +93,10 @@ class _ConvVariational(_Conv): ...@@ -87,13 +93,10 @@ class _ConvVariational(_Conv):
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`') raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn() self.weight_prior = weight_prior_fn()
if isinstance(weight_posterior_fn, Cell): try:
if weight_posterior_fn.__class__.__name__ != 'NormalPosterior': self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`') except TypeError:
else: raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
if weight_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
if self.has_bias: if self.has_bias:
self.bias.requires_grad = False self.bias.requires_grad = False
...@@ -107,13 +110,10 @@ class _ConvVariational(_Conv): ...@@ -107,13 +110,10 @@ class _ConvVariational(_Conv):
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`') raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn() self.bias_prior = bias_prior_fn()
if isinstance(bias_posterior_fn, Cell): try:
if bias_posterior_fn.__class__.__name__ != 'NormalPosterior': self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`') except TypeError:
else: raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
if bias_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
# mindspore operations # mindspore operations
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
......
...@@ -51,13 +51,10 @@ class _DenseVariational(Cell): ...@@ -51,13 +51,10 @@ class _DenseVariational(Cell):
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`') raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn() self.weight_prior = weight_prior_fn()
if isinstance(weight_posterior_fn, Cell): try:
if weight_posterior_fn.__class__.__name__ != 'NormalPosterior': self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`') except TypeError:
else: raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
if weight_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
if self.has_bias: if self.has_bias:
if isinstance(bias_prior_fn, Cell): if isinstance(bias_prior_fn, Cell):
...@@ -69,13 +66,10 @@ class _DenseVariational(Cell): ...@@ -69,13 +66,10 @@ class _DenseVariational(Cell):
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`') raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn() self.bias_prior = bias_prior_fn()
if isinstance(bias_posterior_fn, Cell): try:
if bias_posterior_fn.__class__.__name__ != 'NormalPosterior': self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`') except TypeError:
else: raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
if bias_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
self.activation = activation self.activation = activation
if isinstance(self.activation, str): if isinstance(self.activation, str):
......
...@@ -51,15 +51,16 @@ class NormalPosterior(Cell): ...@@ -51,15 +51,16 @@ class NormalPosterior(Cell):
Args: Args:
name (str): Name prepended to trainable parameter. name (str): Name prepended to trainable parameter.
shape (list): Shape of the mean and standard deviation. shape (list, tuple): Shape of the mean and standard deviation.
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
Default: mindspore.float32. Default: mindspore.float32.
loc_mean ( float, array_like of floats): Mean of distribution to initialize trainable parameters. Default: 0. loc_mean (int, float, array_like of floats): Mean of distribution to initialize trainable parameters.
loc_std ( float, array_like of floats): Standard deviation of distribution to initialize trainable parameters. Default: 0.
Default: 0.1. loc_std (int, float, array_like of floats): Standard deviation of distribution to initialize trainable
untransformed_scale_mean ( float, array_like of floats): Mean of distribution to initialize trainable parameters. Default: 0.1.
untransformed_scale_mean (int, float, array_like of floats): Mean of distribution to initialize trainable
parameters. Default: -5. parameters. Default: -5.
untransformed_scale_std ( float, array_like of floats): Standard deviation of distribution to initialize untransformed_scale_std (int, float, array_like of floats): Standard deviation of distribution to initialize
trainable parameters. Default: 0.1. trainable parameters. Default: 0.1.
Returns: Returns:
...@@ -80,20 +81,25 @@ class NormalPosterior(Cell): ...@@ -80,20 +81,25 @@ class NormalPosterior(Cell):
if not isinstance(shape, (tuple, list)): if not isinstance(shape, (tuple, list)):
raise TypeError('The type of `shape` should be `tuple` or `list`') raise TypeError('The type of `shape` should be `tuple` or `list`')
if not (np.array(shape) > 0).all(): try:
raise ValueError('Negative dimensions are not allowed') mean_arr = np.random.normal(loc_mean, loc_std, shape)
except ValueError as msg:
raise ValueError(msg)
except TypeError as msg:
raise TypeError(msg)
if not (np.array(loc_std) >= 0).all(): try:
raise ValueError('The value of `loc_std` < 0') untransformed_scale_arr = np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape)
if not (np.array(untransformed_scale_std) >= 0).all(): except ValueError as msg:
raise ValueError('The value of `untransformed_scale_std` < 0') raise ValueError(msg)
except TypeError as msg:
raise TypeError(msg)
self.mean = Parameter( self.mean = Parameter(
Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean') Tensor(mean_arr, dtype=dtype), name=name + '_mean')
self.untransformed_std = Parameter( self.untransformed_std = Parameter(
Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype), Tensor(untransformed_scale_arr, dtype=dtype), name=name + '_untransformed_std')
name=name + '_untransformed_std')
self.normal = Normal() self.normal = Normal()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册