diff --git a/mindspore/nn/probability/dpn/vae/cvae.py b/mindspore/nn/probability/dpn/vae/cvae.py index ddb6279a17947bdb9b0dcdb38a855678715e2edd..96a7d7b20ea5b513bb5b53a0656f5f88ecf2b4d0 100644 --- a/mindspore/nn/probability/dpn/vae/cvae.py +++ b/mindspore/nn/probability/dpn/vae/cvae.py @@ -93,18 +93,21 @@ class ConditionalVAE(Cell): recon_x = self._decode(z_c) return recon_x, x, mu, std - def generate_sample(self, sample_y, generate_nums=None, shape=None): + def generate_sample(self, sample_y, generate_nums, shape): """ Randomly sample from latent space to generate sample. Args: sample_y (Tensor): Define the label of sample, int tensor. generate_nums (int): The number of samples to generate. - shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`. + shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`. Returns: Tensor, the generated sample. """ + generate_nums = check_int_positive(generate_nums) + if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1: + raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).') sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) sample_y = self.one_hot(sample_y) sample_c = self.concat((sample_z, sample_y)) diff --git a/mindspore/nn/probability/dpn/vae/vae.py b/mindspore/nn/probability/dpn/vae/vae.py index 3137d6a4e174a7536f465d70b57f29eaee75b1ad..731a7608621940d2791bae9a47d15330d7ae9260 100644 --- a/mindspore/nn/probability/dpn/vae/vae.py +++ b/mindspore/nn/probability/dpn/vae/vae.py @@ -88,11 +88,14 @@ class VAE(Cell): Args: generate_nums (int): The number of samples to generate. - shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`. + shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`. Returns: Tensor, the generated sample. """ + generate_nums = check_int_positive(generate_nums) + if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1: + raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).') sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) sample = self._decode(sample_z) sample = self.reshape(sample, shape) diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py index 4467eb9b2a5f45e36a859b5909679d6c23b89328..f8a12f07abd1e0b5a337a6d36b07ff24a192e715 100644 --- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py +++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py @@ -13,18 +13,20 @@ # limitations under the License. # ============================================================================ """Toolbox for Uncertainty Evaluation.""" -import numpy as np +from copy import deepcopy +import numpy as np from mindspore._checkparam import check_int_positive, check_bool from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore.train import Model from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig from mindspore.train.serialization import load_checkpoint, load_param_into_net + from ...cell import Cell from ...layer.basic import Dense, Flatten, Dropout -from ...layer.conv import Conv2d from ...layer.container import SequentialCell +from ...layer.conv import Conv2d from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss from ...metrics import Accuracy, MSE from ...optim import Adam @@ -36,8 +38,7 @@ class UncertaintyEvaluation: Args: model (Cell): The model for uncertainty evaluation. - epi_train_dataset (Dataset): A dataset iterator to train model for obtain epistemic uncertainty. - ale_train_dataset (Dataset): A dataset iterator to train model for obtain aleatoric uncertainty. + train_dataset (Dataset): A dataset iterator to train model. task_type (str): Option for the task types of model - regression: A regression model. - classification: A classification model. @@ -45,22 +46,20 @@ class UncertaintyEvaluation: If the task type is classification, it must be set; if not classification, it need not to be set. Default: None. epochs (int): Total number of iterations on the data. Default: 1. - epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. - ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. + epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None. + ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None. save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path and ale_uncer_model_path should not be None. If False, give the path of the uncertainty model, it will load the model to evaluate, if not given - the path, it will not save or load the uncertainty model. + the path, it will not save or load the uncertainty model. Default: False. Examples: >>> network = LeNet() >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt') >>> load_param_into_net(network, param_dict) - >>> epi_ds_train = create_dataset('workspace/mnist/train') - >>> ale_ds_train = create_dataset('workspace/mnist/train') + >>> ds_train = create_dataset('workspace/mnist/train') >>> evaluation = UncertaintyEvaluation(model=network, - >>> epi_train_dataset=epi_ds_train, - >>> ale_train_dataset=ale_ds_train, + >>> train_dataset=ds_train, >>> task_type='classification', >>> num_classes=10, >>> epochs=1, @@ -71,12 +70,12 @@ class UncertaintyEvaluation: >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data) """ - def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1, + def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1, epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False): self.epi_model = model self.ale_model = model - self.epi_train_dataset = epi_train_dataset - self.ale_train_dataset = ale_train_dataset + self.epi_train_dataset = train_dataset + self.ale_train_dataset = deepcopy(train_dataset) self.task_type = task_type self.epochs = check_int_positive(epochs) self.epi_uncer_model_path = epi_uncer_model_path @@ -93,6 +92,8 @@ class UncertaintyEvaluation: raise ValueError('The task should be regression or classification.') if task_type == 'classification': self.num_classes = check_int_positive(num_classes) + else: + self.num_classes = num_classes if save_model: if epi_uncer_model_path is None or ale_uncer_model_path is None: raise ValueError("If save_model is True, the epi_uncer_model_path and " diff --git a/tests/st/probability/test_uncertainty.py b/tests/st/probability/test_uncertainty.py index ed3f45883f999671c7ac10c557b2ae25882327d9..92850141eb801981009833b7b30450be951e204c 100644 --- a/tests/st/probability/test_uncertainty.py +++ b/tests/st/probability/test_uncertainty.py @@ -119,12 +119,10 @@ if __name__ == '__main__': param_dict = load_checkpoint('checkpoint_lenet.ckpt') load_param_into_net(network, param_dict) # get train and eval dataset - epi_ds_train = create_dataset('workspace/mnist/train') - ale_ds_train = create_dataset('workspace/mnist/train') + ds_train = create_dataset('workspace/mnist/train') ds_eval = create_dataset('workspace/mnist/test') evaluation = UncertaintyEvaluation(model=network, - epi_train_dataset=epi_ds_train, - ale_train_dataset=ale_ds_train, + train_dataset=ds_train, task_type='classification', num_classes=10, epochs=1,