提交 ab45bec8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4924 Modify API comments and fix error of st

Merge pull request !4924 from byweng/fix_param_check
...@@ -67,8 +67,13 @@ class WithBNNLossCell: ...@@ -67,8 +67,13 @@ class WithBNNLossCell:
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`') raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if dnn_factor < 0:
raise ValueError('The value of `dnn_factor` should >= 0')
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)): if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`') raise TypeError('The type of `bnn_factor` should be `int` or `float`')
if bnn_factor < 0:
raise ValueError('The value of `bnn_factor` should >= 0')
self.backbone = backbone self.backbone = backbone
self.loss_fn = loss_fn self.loss_fn = loss_fn
......
...@@ -61,12 +61,6 @@ class _ConvVariational(_Conv): ...@@ -61,12 +61,6 @@ class _ConvVariational(_Conv):
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
if isinstance(stride, bool) or not isinstance(stride, (int, tuple)):
raise TypeError('The type of `stride` should be `int` of `tuple`')
if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)):
raise TypeError('The type of `dilation` should be `int` of `tuple`')
# convolution args # convolution args
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
......
...@@ -29,7 +29,7 @@ class NormalPrior(Cell): ...@@ -29,7 +29,7 @@ class NormalPrior(Cell):
To initialize a normal distribution of mean 0 and standard deviation 0.1. To initialize a normal distribution of mean 0 and standard deviation 0.1.
Args: Args:
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
Default: mindspore.float32. Default: mindspore.float32.
mean (int, float): Mean of normal distribution. mean (int, float): Mean of normal distribution.
std (int, float): Standard deviation of normal distribution. std (int, float): Standard deviation of normal distribution.
...@@ -52,7 +52,7 @@ class NormalPosterior(Cell): ...@@ -52,7 +52,7 @@ class NormalPosterior(Cell):
Args: Args:
name (str): Name prepended to trainable parameter. name (str): Name prepended to trainable parameter.
shape (list, tuple): Shape of the mean and standard deviation. shape (list, tuple): Shape of the mean and standard deviation.
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
Default: mindspore.float32. Default: mindspore.float32.
loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0. loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1. loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
......
...@@ -63,8 +63,13 @@ class TransformToBNN: ...@@ -63,8 +63,13 @@ class TransformToBNN:
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1): def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`') raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if dnn_factor < 0:
raise ValueError('The value of `dnn_factor` should >= 0')
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)): if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`') raise TypeError('The type of `bnn_factor` should be `int` or `float`')
if bnn_factor < 0:
raise ValueError('The value of `bnn_factor` should >= 0')
net_with_loss = trainable_dnn.network net_with_loss = trainable_dnn.network
self.optimizer = trainable_dnn.optimizer self.optimizer = trainable_dnn.optimizer
...@@ -88,9 +93,9 @@ class TransformToBNN: ...@@ -88,9 +93,9 @@ class TransformToBNN:
Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell. Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.
Args: Args:
get_dense_args (function): The arguments gotten from the DNN full connection layer. Default: lambda dp: get_dense_args (:class:`function`): The arguments gotten from the DNN full connection layer. Default: lambda dp:
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}. {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.
get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp: get_conv_args (:class:`function`): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode, {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
"kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}. "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
...@@ -134,10 +139,10 @@ class TransformToBNN: ...@@ -134,10 +139,10 @@ class TransformToBNN:
Args: Args:
dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are
nn.Dense, nn.Conv2d. nn.Dense, nn.Conv2d.
bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
DenseReparameterization, ConvReparameterization. DenseReparam, ConvReparam.
get_args (dict): The arguments gotten from the DNN layer. Default: None. get_args (:class:`function`): The arguments gotten from the DNN layer. Default: None.
add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
duplicate arguments in `get_args`. Default: None. duplicate arguments in `get_args`. Default: None.
......
...@@ -108,22 +108,22 @@ class VaeGan(nn.Cell): ...@@ -108,22 +108,22 @@ class VaeGan(nn.Cell):
return ld_real, ld_fake, ld_p, recon_x, x, mu, std return ld_real, ld_fake, ld_p, recon_x, x, mu, std
class VaeGanLoss(nn.Cell): class VaeGanLoss(ELBO):
def __init__(self): def __init__(self):
super(VaeGanLoss, self).__init__() super(VaeGanLoss, self).__init__()
self.zeros = P.ZerosLike() self.zeros = P.ZerosLike()
self.mse = nn.MSELoss(reduction='sum') self.mse = nn.MSELoss(reduction='sum')
self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')
def construct(self, data, label): def construct(self, data, label):
ld_real, ld_fake, ld_p, recon_x, x, mean, std = data ld_real, ld_fake, ld_p, recon_x, x, mu, std = data
y_real = self.zeros(ld_real) + 1 y_real = self.zeros(ld_real) + 1
y_fake = self.zeros(ld_fake) y_fake = self.zeros(ld_fake)
elbo_data = (recon_x, x, mean, std)
loss_D = self.mse(ld_real, y_real) loss_D = self.mse(ld_real, y_real)
loss_GD = self.mse(ld_p, y_fake) loss_GD = self.mse(ld_p, y_fake)
loss_G = self.mse(ld_fake, y_real) loss_G = self.mse(ld_fake, y_real)
elbo_loss = self.elbo(elbo_data, label) reconstruct_loss = self.recon_loss(x, recon_x)
kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
elbo_loss = reconstruct_loss + self.sum(kl_loss)
return loss_D + loss_G + loss_GD + elbo_loss return loss_D + loss_G + loss_GD + elbo_loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册