提交 82749d0a 编写于 作者: B bingyaweng

fix error in bnn_layers and transforms

上级 61dbb1b1
...@@ -26,8 +26,8 @@ class ClassWrap: ...@@ -26,8 +26,8 @@ class ClassWrap:
self._cls = cls self._cls = cls
self.bnn_loss_file = None self.bnn_loss_file = None
def __call__(self, backbone, loss_fn, backbone_factor, kl_factor): def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor) obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
bnn_with_loss = obj() bnn_with_loss = obj()
self.bnn_loss_file = obj.bnn_loss_file self.bnn_loss_file = obj.bnn_loss_file
return bnn_with_loss return bnn_with_loss
...@@ -65,6 +65,11 @@ class WithBNNLossCell: ...@@ -65,6 +65,11 @@ class WithBNNLossCell:
""" """
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
self.backbone = backbone self.backbone = backbone
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.dnn_factor = dnn_factor self.dnn_factor = dnn_factor
......
...@@ -79,20 +79,40 @@ class _ConvVariational(_Conv): ...@@ -79,20 +79,40 @@ class _ConvVariational(_Conv):
self.weight.requires_grad = False self.weight.requires_grad = False
if isinstance(weight_prior_fn, Cell): if isinstance(weight_prior_fn, Cell):
if weight_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn self.weight_prior = weight_prior_fn
else: else:
if weight_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn() self.weight_prior = weight_prior_fn()
if isinstance(weight_posterior_fn, Cell):
if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
else:
if weight_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight') self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
if self.has_bias: if self.has_bias:
self.bias.requires_grad = False self.bias.requires_grad = False
if isinstance(bias_prior_fn, Cell): if isinstance(bias_prior_fn, Cell):
if bias_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn self.bias_prior = bias_prior_fn
else: else:
if bias_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn() self.bias_prior = bias_prior_fn()
if isinstance(bias_posterior_fn, Cell):
if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
else:
if bias_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias') self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
# mindspore operations # mindspore operations
......
...@@ -43,18 +43,38 @@ class _DenseVariational(Cell): ...@@ -43,18 +43,38 @@ class _DenseVariational(Cell):
self.has_bias = check_bool(has_bias) self.has_bias = check_bool(has_bias)
if isinstance(weight_prior_fn, Cell): if isinstance(weight_prior_fn, Cell):
if weight_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn self.weight_prior = weight_prior_fn
else: else:
if weight_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn() self.weight_prior = weight_prior_fn()
if isinstance(weight_posterior_fn, Cell):
if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
else:
if weight_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight') self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
if self.has_bias: if self.has_bias:
if isinstance(bias_prior_fn, Cell): if isinstance(bias_prior_fn, Cell):
if bias_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn self.bias_prior = bias_prior_fn
else: else:
if bias_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn() self.bias_prior = bias_prior_fn()
if isinstance(bias_posterior_fn, Cell):
if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
else:
if bias_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias') self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
self.activation = activation self.activation = activation
......
...@@ -75,7 +75,18 @@ class NormalPosterior(Cell): ...@@ -75,7 +75,18 @@ class NormalPosterior(Cell):
untransformed_scale_std=0.1): untransformed_scale_std=0.1):
super(NormalPosterior, self).__init__() super(NormalPosterior, self).__init__()
if not isinstance(name, str): if not isinstance(name, str):
raise ValueError('The type of `name` should be `str`') raise TypeError('The type of `name` should be `str`')
if not isinstance(shape, (tuple, list)):
raise TypeError('The type of `shape` should be `tuple` or `list`')
if not (np.array(shape) > 0).all():
raise ValueError('Negative dimensions are not allowed')
if not (np.array(loc_std) >= 0).all():
raise ValueError('The value of `loc_std` < 0')
if not (np.array(untransformed_scale_std) >= 0).all():
raise ValueError('The value of `untransformed_scale_std` < 0')
self.mean = Parameter( self.mean = Parameter(
Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean') Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
""" """
Transforms. Transforms.
The high-level components used to transform model between DNN and DNN. The high-level components used to transform model between DNN and BNN.
""" """
from . import transform_bnn from . import transform_bnn
from .transform_bnn import TransformToBNN from .transform_bnn import TransformToBNN
......
...@@ -54,3 +54,13 @@ class WithBNNLossCell(nn.Cell): ...@@ -54,3 +54,13 @@ class WithBNNLossCell(nn.Cell):
self.kl_loss.append(layer.compute_kl_loss) self.kl_loss.append(layer.compute_kl_loss)
else: else:
self._add_kl_loss(layer) self._add_kl_loss(layer)
@property
def backbone_network(self):
"""
Returns the backbone network.
Returns:
Cell, the backbone network.
"""
return self._backbone
...@@ -61,6 +61,11 @@ class TransformToBNN: ...@@ -61,6 +61,11 @@ class TransformToBNN:
""" """
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1): def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
net_with_loss = trainable_dnn.network net_with_loss = trainable_dnn.network
self.optimizer = trainable_dnn.optimizer self.optimizer = trainable_dnn.optimizer
self.backbone = net_with_loss.backbone_network self.backbone = net_with_loss.backbone_network
...@@ -88,8 +93,10 @@ class TransformToBNN: ...@@ -88,8 +93,10 @@ class TransformToBNN:
get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp: get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode, {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
"kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}. "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
add_dense_args (dict): The new arguments added to BNN full connection layer. Default: {}. add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
add_conv_args (dict): The new arguments added to BNN convolutional layer. Default: {}. `add_dense_args` should not duplicate arguments in `get_dense_args`. Default: {}.
add_conv_args (dict): The new arguments added to BNN convolutional layer. Note that the arguments in
`add_conv_args` should not duplicate arguments in `get_conv_args`. Default: {}.
Returns: Returns:
Cell, a trainable BNN model wrapped by TrainOneStepCell. Cell, a trainable BNN model wrapped by TrainOneStepCell.
...@@ -131,7 +138,8 @@ class TransformToBNN: ...@@ -131,7 +138,8 @@ class TransformToBNN:
bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
DenseReparameterization, ConvReparameterization. DenseReparameterization, ConvReparameterization.
get_args (dict): The arguments gotten from the DNN layer. Default: None. get_args (dict): The arguments gotten from the DNN layer. Default: None.
add_args (dict): The new arguments added to BNN layer. Default: None. add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
duplicate arguments in `get_args`. Default: None.
Returns: Returns:
Cell, a trainable model wrapped by TrainOneStepCell, whose sprcific type of layer is transformed to the Cell, a trainable model wrapped by TrainOneStepCell, whose sprcific type of layer is transformed to the
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册