diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py index 22777231f65ba54a933d257d350e57a145232341..e6530b9f7026d7cc1acf8ad3b30ae5ba053de8e2 100644 --- a/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/nn/probability/bijector/bijector.py @@ -69,31 +69,31 @@ class Bijector(Cell): def is_injective(self): return self._is_injective - def forward(self, *args): + def forward(self, *args, **kwargs): """ Forward transformation: transform the input value to another distribution. """ - return self._forward(*args) + return self._forward(*args, **kwargs) - def inverse(self, *args): + def inverse(self, *args, **kwargs): """ Inverse transformation: transform the input value back to the original distribution. """ - return self._inverse(*args) + return self._inverse(*args, **kwargs) - def forward_log_jacobian(self, *args): + def forward_log_jacobian(self, *args, **kwargs): """ Logarithm of the derivative of forward transformation. """ - return self._forward_log_jacobian(*args) + return self._forward_log_jacobian(*args, **kwargs) - def inverse_log_jacobian(self, *args): + def inverse_log_jacobian(self, *args, **kwargs): """ Logarithm of the derivative of forward transformation. """ - return self._inverse_log_jacobian(*args) + return self._inverse_log_jacobian(*args, **kwargs) - def __call__(self, *args): + def __call__(self, *args, **kwargs): """ Call Bijector directly. This __call__ may go into two directions: @@ -107,9 +107,9 @@ class Bijector(Cell): """ if isinstance(args[0], Distribution): return TransformedDistribution(self, args[0]) - return super(Bijector, self).__call__(*args) + return super(Bijector, self).__call__(*args, **kwargs) - def construct(self, name, *args): + def construct(self, name, *args, **kwargs): """ Override construct in Cell. @@ -120,11 +120,11 @@ class Bijector(Cell): Always raise RuntimeError as Distribution should not be called directly. """ if name == 'forward': - return self.forward(*args) + return self.forward(*args, **kwargs) if name == 'inverse': - return self.inverse(*args) + return self.inverse(*args, **kwargs) if name == 'forward_log_jacobian': - return self.forward_log_jacobian(*args) + return self.forward_log_jacobian(*args, **kwargs) if name == 'inverse_log_jacobian': - return self.inverse_log_jacobian(*args) + return self.inverse_log_jacobian(*args, **kwargs) return None diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index dd3d39f0d702188a2805f88de5d2a33357bc347f..90ad66292d7e273e5d8f1303c9b69c0b8b29356a 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -27,7 +27,7 @@ class Distribution(Cell): Note: Derived class should override operations such as ,_mean, _prob, - and _log_prob. Arguments should be passed in through *args. + and _log_prob. Arguments should be passed in through *args or **kwargs. Dist_spec_args are unique for each type of distribution. For example, mean and sd are the dist_spec_args for a Normal distribution. @@ -171,7 +171,7 @@ class Distribution(Cell): if hasattr(self, '_cross_entropy'): self._call_cross_entropy = self._cross_entropy - def log_prob(self, *args): + def log_prob(self, *args, **kwargs): """ Evaluate the log probability(pdf or pmf) at the given value. @@ -179,18 +179,18 @@ class Distribution(Cell): Args must include value. Dist_spec_args are optional. """ - return self._call_log_prob(*args) + return self._call_log_prob(*args, **kwargs) - def _calc_prob_from_log_prob(self, *args): + def _calc_prob_from_log_prob(self, *args, **kwargs): r""" Evaluate prob from log probability. .. math:: probability(x) = \exp(log_likehood(x)) """ - return self.exp(self._log_prob(*args)) + return self.exp(self._log_prob(*args, **kwargs)) - def prob(self, *args): + def prob(self, *args, **kwargs): """ Evaluate the probability (pdf or pmf) at given value. @@ -198,18 +198,18 @@ class Distribution(Cell): Args must include value. Dist_spec_args are optional. """ - return self._call_prob(*args) + return self._call_prob(*args, **kwargs) - def _calc_log_prob_from_prob(self, *args): + def _calc_log_prob_from_prob(self, *args, **kwargs): r""" Evaluate log probability from probability. .. math:: log_prob(x) = \log(prob(x)) """ - return self.log(self._prob(*args)) + return self.log(self._prob(*args, **kwargs)) - def cdf(self, *args): + def cdf(self, *args, **kwargs): """ Evaluate the cdf at given value. @@ -217,36 +217,36 @@ class Distribution(Cell): Args must include value. Dist_spec_args are optional. """ - return self._call_cdf(*args) + return self._call_cdf(*args, **kwargs) - def _calc_cdf_from_log_cdf(self, *args): + def _calc_cdf_from_log_cdf(self, *args, **kwargs): r""" Evaluate cdf from log_cdf. .. math:: cdf(x) = \exp(log_cdf(x)) """ - return self.exp(self._log_cdf(*args)) + return self.exp(self._log_cdf(*args, **kwargs)) - def _calc_cdf_from_survival(self, *args): + def _calc_cdf_from_survival(self, *args, **kwargs): r""" Evaluate cdf from survival function. .. math:: cdf(x) = 1 - (survival_function(x)) """ - return 1.0 - self._survival_function(*args) + return 1.0 - self._survival_function(*args, **kwargs) - def _calc_cdf_from_log_survival(self, *args): + def _calc_cdf_from_log_survival(self, *args, **kwargs): r""" Evaluate cdf from log survival function. .. math:: cdf(x) = 1 - (\exp(log_survival(x))) """ - return 1.0 - self.exp(self._log_survival(*args)) + return 1.0 - self.exp(self._log_survival(*args, **kwargs)) - def log_cdf(self, *args): + def log_cdf(self, *args, **kwargs): """ Evaluate the log cdf at given value. @@ -254,18 +254,18 @@ class Distribution(Cell): Args must include value. Dist_spec_args are optional. """ - return self._call_log_cdf(*args) + return self._call_log_cdf(*args, **kwargs) - def _calc_log_cdf_from_call_cdf(self, *args): + def _calc_log_cdf_from_call_cdf(self, *args, **kwargs): r""" Evaluate log cdf from cdf. .. math:: log_cdf(x) = \log(cdf(x)) """ - return self.log(self._call_cdf(*args)) + return self.log(self._call_cdf(*args, **kwargs)) - def survival_function(self, *args): + def survival_function(self, *args, **kwargs): """ Evaluate the survival function at given value. @@ -273,27 +273,27 @@ class Distribution(Cell): Args must include value. Dist_spec_args are optional. """ - return self._call_survival(*args) + return self._call_survival(*args, **kwargs) - def _calc_survival_from_call_cdf(self, *args): + def _calc_survival_from_call_cdf(self, *args, **kwargs): r""" Evaluate survival function from cdf. .. math:: survival_function(x) = 1 - (cdf(x)) """ - return 1.0 - self._call_cdf(*args) + return 1.0 - self._call_cdf(*args, **kwargs) - def _calc_survival_from_log_survival(self, *args): + def _calc_survival_from_log_survival(self, *args, **kwargs): r""" Evaluate survival function from log survival function. .. math:: survival(x) = \exp(survival_function(x)) """ - return self.exp(self._log_survival(*args)) + return self.exp(self._log_survival(*args, **kwargs)) - def log_survival(self, *args): + def log_survival(self, *args, **kwargs): """ Evaluate the log survival function at given value. @@ -301,18 +301,18 @@ class Distribution(Cell): Args must include value. Dist_spec_args are optional. """ - return self._call_log_survival(*args) + return self._call_log_survival(*args, **kwargs) - def _calc_log_survival_from_call_survival(self, *args): + def _calc_log_survival_from_call_survival(self, *args, **kwargs): r""" Evaluate log survival function from survival function. .. math:: log_survival(x) = \log(survival_function(x)) """ - return self.log(self._call_survival(*args)) + return self.log(self._call_survival(*args, **kwargs)) - def kl_loss(self, *args): + def kl_loss(self, *args, **kwargs): """ Evaluate the KL divergence, i.e. KL(a||b). @@ -320,72 +320,72 @@ class Distribution(Cell): Args must include type of the distribution, parameters of distribution b. Parameters for distribution a are optional. """ - return self._kl_loss(*args) + return self._kl_loss(*args, **kwargs) - def mean(self, *args): + def mean(self, *args, **kwargs): """ Evaluate the mean. Note: Dist_spec_args are optional. """ - return self._mean(*args) + return self._mean(*args, **kwargs) - def mode(self, *args): + def mode(self, *args, **kwargs): """ Evaluate the mode. Note: Dist_spec_args are optional. """ - return self._mode(*args) + return self._mode(*args, **kwargs) - def sd(self, *args): + def sd(self, *args, **kwargs): """ Evaluate the standard deviation. Note: Dist_spec_args are optional. """ - return self._call_sd(*args) + return self._call_sd(*args, **kwargs) - def var(self, *args): + def var(self, *args, **kwargs): """ Evaluate the variance. Note: Dist_spec_args are optional. """ - return self._call_var(*args) + return self._call_var(*args, **kwargs) - def _calc_sd_from_var(self, *args): + def _calc_sd_from_var(self, *args, **kwargs): r""" Evaluate log probability from probability. .. math:: STD(x) = \sqrt(VAR(x)) """ - return self.sqrt(self._var(*args)) + return self.sqrt(self._var(*args, **kwargs)) - def _calc_var_from_sd(self, *args): + def _calc_var_from_sd(self, *args, **kwargs): r""" Evaluate log probability from probability. .. math:: VAR(x) = STD(x) ^ 2 """ - return self.sq(self._sd(*args)) + return self.sq(self._sd(*args, **kwargs)) - def entropy(self, *args): + def entropy(self, *args, **kwargs): """ Evaluate the entropy. Note: Dist_spec_args are optional. """ - return self._entropy(*args) + return self._entropy(*args, **kwargs) - def cross_entropy(self, *args): + def cross_entropy(self, *args, **kwargs): """ Evaluate the cross_entropy between distribution a and b. @@ -393,32 +393,29 @@ class Distribution(Cell): Args must include type of the distribution, parameters of distribution b. Parameters for distribution a are optional. """ - return self._call_cross_entropy(*args) + return self._call_cross_entropy(*args, **kwargs) - def _calc_cross_entropy(self, *args): + def _calc_cross_entropy(self, *args, **kwargs): r""" Evaluate cross_entropy from entropy and kl divergence. .. math:: H(X, Y) = H(X) + KL(X||Y) """ - return self._entropy(*args) + self._kl_loss(*args) + return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs) - def sample(self, *args): + def sample(self, *args, **kwargs): """ Sampling function. - Args: - *args (list): arguments passed in through construct. - Note: Shape of the sample is default to (). Dist_spec_args are optional. """ - return self._sample(*args) + return self._sample(*args, **kwargs) - def construct(self, name, *args): + def construct(self, name, *args, **kwargs): """ Override construct in Cell. @@ -433,31 +430,31 @@ class Distribution(Cell): """ if name == 'log_prob': - return self._call_log_prob(*args) + return self._call_log_prob(*args, **kwargs) if name == 'prob': - return self._call_prob(*args) + return self._call_prob(*args, **kwargs) if name == 'cdf': - return self._call_cdf(*args) + return self._call_cdf(*args, **kwargs) if name == 'log_cdf': - return self._call_log_cdf(*args) + return self._call_log_cdf(*args, **kwargs) if name == 'survival_function': - return self._call_survival(*args) + return self._call_survival(*args, **kwargs) if name == 'log_survival': - return self._call_log_survival(*args) + return self._call_log_survival(*args, **kwargs) if name == 'kl_loss': - return self._kl_loss(*args) + return self._kl_loss(*args, **kwargs) if name == 'mean': - return self._mean(*args) + return self._mean(*args, **kwargs) if name == 'mode': - return self._mode(*args) + return self._mode(*args, **kwargs) if name == 'sd': - return self._call_sd(*args) + return self._call_sd(*args, **kwargs) if name == 'var': - return self._call_var(*args) + return self._call_var(*args, **kwargs) if name == 'entropy': - return self._entropy(*args) + return self._entropy(*args, **kwargs) if name == 'cross_entropy': - return self._call_cross_entropy(*args) + return self._call_cross_entropy(*args, **kwargs) if name == 'sample': - return self._sample(*args) + return self._sample(*args, **kwargs) return None diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index f243a2bc31fc664b4f82ebf00eefe814dbc39c28..a984c161bfd5ed6e06c59b8bb071a1814a179289 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -256,8 +256,5 @@ class Normal(Distribution): sd = self._sd_value if sd is None else sd batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) sample_shape = shape + batch_shape - mean_zero = self.const(0.0) - sd_one = self.const(1.0) - sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample = mean + sample_norm * sd - return sample + sample_norm = C.normal(sample_shape, mean, sd, self.seed) + return sample_norm