From e87e1fc6bcb73892c649b27a785048627561434f Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Wed, 29 Jul 2020 16:47:44 -0400 Subject: [PATCH] changed distribution api --- .../nn/probability/distribution/bernoulli.py | 164 ++++++++-------- .../probability/distribution/distribution.py | 106 +++++------ .../probability/distribution/exponential.py | 147 +++++++-------- .../nn/probability/distribution/geometric.py | 177 ++++++++---------- .../nn/probability/distribution/normal.py | 157 +++++++--------- .../nn/probability/distribution/uniform.py | 177 ++++++++---------- .../test_distribution/test_bernoulli.py | 39 ++-- .../test_distribution/test_exponential.py | 38 ++-- .../test_distribution/test_geometric.py | 38 ++-- .../ascend/test_distribution/test_normal.py | 75 +++++--- .../test_distribution/test_normal_new_api.py | 62 ------ .../ascend/test_distribution/test_uniform.py | 38 ++-- .../python/nn/distribution/test_bernoulli.py | 67 ++++--- .../nn/distribution/test_exponential.py | 68 ++++--- .../python/nn/distribution/test_geometric.py | 68 ++++--- .../ut/python/nn/distribution/test_normal.py | 67 +++++-- .../ut/python/nn/distribution/test_uniform.py | 67 +++++-- 17 files changed, 753 insertions(+), 802 deletions(-) delete mode 100644 tests/st/ops/ascend/test_distribution/test_normal_new_api.py diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index cab74f97f..0aaeabf9a 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -34,55 +34,56 @@ class Bernoulli(Distribution): Examples: >>> # To initialize a Bernoulli distribution of prob 0.5 - >>> n = nn.Bernoulli(0.5, dtype=mstype.int32) + >>> import mindspore.nn.probability.distribution as msd + >>> b = msd.Bernoulli(0.5, dtype=mstype.int32) >>> >>> # The following creates two independent Bernoulli distributions - >>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32) + >>> b = msd.Bernoulli([0.5, 0.5], dtype=mstype.int32) >>> >>> # A Bernoulli distribution can be initilized without arguments - >>> # In this case, probs must be passed in through construct. - >>> n = nn.Bernoulli(dtype=mstype.int32) + >>> # In this case, probs must be passed in through args during function calls. + >>> b = msd.Bernoulli(dtype=mstype.int32) >>> - >>> # To use Bernoulli distribution in a network + >>> # To use Bernoulli in a network >>> class net(Cell): >>> def __init__(self): >>> super(net, self).__init__(): - >>> self.b1 = nn.Bernoulli(0.5, dtype=mstype.int32) - >>> self.b2 = nn.Bernoulli(dtype=mstype.int32) + >>> self.b1 = msd.Bernoulli(0.5, dtype=mstype.int32) + >>> self.b2 = msd.Bernoulli(dtype=mstype.int32) >>> >>> # All the following calls in construct are valid >>> def construct(self, value, probs_b, probs_a): >>> >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function - >>> ans = self.b1('prob', value) + >>> ans = self.b1.prob(value) >>> # Evaluate with the respect to distribution b - >>> ans = self.b1('prob', value, probs_b) + >>> ans = self.b1.prob(value, probs_b) >>> - >>> # probs must be passed in through construct - >>> ans = self.b2('prob', value, probs_a) + >>> # probs must be passed in during function calls + >>> ans = self.b2.prob(value, probs_a) >>> - >>> # Functions 'sd', 'var', 'entropy' have the same usage like 'mean' - >>> # Will return [0.0] - >>> ans = self.b1('mean') - >>> # Will return mean_b - >>> ans = self.b1('mean', probs_b) + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return 0.5 + >>> ans = self.b1.mean() + >>> # Will return probs_b + >>> ans = self.b1.mean(probs_b) >>> - >>> # probs must be passed in through construct - >>> ans = self.b2('mean', probs_a) + >>> # probs must be passed in during function calls + >>> ans = self.b2.mean(probs_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar - >>> ans = self.b1('kl_loss', 'Bernoulli', probs_b) - >>> ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a) + >>> ans = self.b1.kl_loss('Bernoulli', probs_b) + >>> ans = self.b1.kl_loss('Bernoulli', probs_b, probs_a) >>> - >>> # Additional probs_a must be passed in through construct - >>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) + >>> # Additional probs_a must be passed in through + >>> ans = self.b2.kl_loss('Bernoulli', probs_b, probs_a) >>> - >>> # Sample Usage - >>> ans = self.b1('sample') - >>> ans = self.b1('sample', (2,3)) - >>> ans = self.b1('sample', (2,3), probs_b) - >>> ans = self.b2('sample', (2,3), probs_a) + >>> # Sample + >>> ans = self.b1.sample() + >>> ans = self.b1.sample((2,3)) + >>> ans = self.b1.sample((2,3), probs_b) + >>> ans = self.b2.sample((2,3), probs_a) """ def __init__(self, @@ -130,71 +131,61 @@ class Bernoulli(Distribution): """ return self._probs - def _mean(self, name='mean', probs1=None): + def _mean(self, probs1=None): r""" .. math:: MEAN(B) = probs1 """ - if name == 'mean': - return self.probs if probs1 is None else probs1 - return None + return self.probs if probs1 is None else probs1 - def _mode(self, name='mode', probs1=None): + def _mode(self, probs1=None): r""" .. math:: MODE(B) = 1 if probs1 > 0.5 else = 0 """ - if name == 'mode': - probs1 = self.probs if probs1 is None else probs1 - prob_type = self.dtypeop(probs1) - zeros = self.fill(prob_type, self.shape(probs1), 0.0) - ones = self.fill(prob_type, self.shape(probs1), 1.0) - comp = self.less(0.5, probs1) - return self.select(comp, ones, zeros) - return None + probs1 = self.probs if probs1 is None else probs1 + prob_type = self.dtypeop(probs1) + zeros = self.fill(prob_type, self.shape(probs1), 0.0) + ones = self.fill(prob_type, self.shape(probs1), 1.0) + comp = self.less(0.5, probs1) + return self.select(comp, ones, zeros) - def _var(self, name='var', probs1=None): + def _var(self, probs1=None): r""" .. math:: VAR(B) = probs1 * probs0 """ - if name in self._variance_functions: - probs1 = self.probs if probs1 is None else probs1 - probs0 = 1.0 - probs1 - return probs0 * probs1 - return None + probs1 = self.probs if probs1 is None else probs1 + probs0 = 1.0 - probs1 + return probs0 * probs1 - def _entropy(self, name='entropy', probs=None): + def _entropy(self, probs=None): r""" .. math:: H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) """ - if name == 'entropy': - probs1 = self.probs if probs is None else probs - probs0 = 1 - probs1 - return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) - return None + probs1 = self.probs if probs is None else probs + probs0 = 1 - probs1 + return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) - def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): + def _cross_entropy(self, dist, probs1_b, probs1_a=None): """ Evaluate cross_entropy between Bernoulli distributions. Args: - name (str): name of the funtion. dist (str): type of the distributions. Should be "Bernoulli" in this case. probs1_b (Tensor): probs1 of distribution b. probs1_a (Tensor): probs1 of distribution a. Default: self.probs. """ - if name == 'cross_entropy' and dist == 'Bernoulli': - return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) + if dist == 'Bernoulli': + return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) return None - def _prob(self, name, value, probs=None): + def _prob(self, value, probs=None): r""" pmf of Bernoulli distribution. Args: - name (str): name of the function. Should be "prob" when passed in from construct. value (Tensor): a Tensor composed of only zeros and ones. probs (Tensor): probability of outcome is 1. Default: self.probs. @@ -202,18 +193,15 @@ class Bernoulli(Distribution): pmf(k) = probs1 if k = 1; pmf(k) = probs0 if k = 0; """ - if name in self._prob_functions: - probs1 = self.probs if probs is None else probs - probs0 = 1.0 - probs1 - return (probs1 * value) + (probs0 * (1.0 - value)) - return None + probs1 = self.probs if probs is None else probs + probs0 = 1.0 - probs1 + return (probs1 * value) + (probs0 * (1.0 - value)) - def _cdf(self, name, value, probs=None): + def _cdf(self, value, probs=None): r""" cdf of Bernoulli distribution. Args: - name (str): name of the function. value (Tensor): value to be evaluated. probs (Tensor): probability of outcome is 1. Default: self.probs. @@ -222,25 +210,22 @@ class Bernoulli(Distribution): cdf(k) = probs0 if 0 <= k <1; cdf(k) = 1 if k >=1; """ - if name in self._cdf_survival_functions: - probs1 = self.probs if probs is None else probs - prob_type = self.dtypeop(probs1) - value = value * self.fill(prob_type, self.shape(probs1), 1.0) - probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) - comp_zero = self.less(value, 0.0) - comp_one = self.less(value, 1.0) - zeros = self.fill(prob_type, self.shape(value), 0.0) - ones = self.fill(prob_type, self.shape(value), 1.0) - less_than_zero = self.select(comp_zero, zeros, probs0) - return self.select(comp_one, less_than_zero, ones) - return None + probs1 = self.probs if probs is None else probs + prob_type = self.dtypeop(probs1) + value = value * self.fill(prob_type, self.shape(probs1), 1.0) + probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) + comp_zero = self.less(value, 0.0) + comp_one = self.less(value, 1.0) + zeros = self.fill(prob_type, self.shape(value), 0.0) + ones = self.fill(prob_type, self.shape(value), 1.0) + less_than_zero = self.select(comp_zero, zeros, probs0) + return self.select(comp_one, less_than_zero, ones) - def _kl_loss(self, name, dist, probs1_b, probs1_a=None): + def _kl_loss(self, dist, probs1_b, probs1_a=None): r""" Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). Args: - name (str): name of the funtion. dist (str): type of the distributions. Should be "Bernoulli" in this case. probs1_b (Tensor): probs1 of distribution b. probs1_a (Tensor): probs1 of distribution a. Default: self.probs. @@ -249,31 +234,28 @@ class Bernoulli(Distribution): KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + probs0_a * \log(\fract{probs0_a}{probs0_b}) """ - if name in self._divergence_functions and dist == 'Bernoulli': + if dist == 'Bernoulli': probs1_a = self.probs if probs1_a is None else probs1_a probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) return None - def _sample(self, name, shape=(), probs=None): + def _sample(self, shape=(), probs=None): """ Sampling. Args: - name (str): name of the function. Should always be 'sample' when passed in from construct. shape (tuple): shape of the sample. Default: (). probs (Tensor): probs1 of the samples. Default: self.probs. Returns: Tensor, shape is shape + batch_shape. """ - if name == 'sample': - probs1 = self.probs if probs is None else probs - l_zero = self.const(0.0) - h_one = self.const(1.0) - sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) - sample = self.less(sample_uniform, probs1) - sample = self.cast(sample, self.dtype) - return sample - return None + probs1 = self.probs if probs is None else probs + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) + sample = self.less(sample_uniform, probs1) + sample = self.cast(sample, self.dtype) + return sample diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index 9f1e1a120..dd3d39f0d 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -27,11 +27,7 @@ class Distribution(Cell): Note: Derived class should override operations such as ,_mean, _prob, - and _log_prob. Functions should be called through construct when - used inside a network. Arguments should be passed in through *args - in the form of function name followed by additional arguments. - Functions such as cdf and prob, require a value to be passed in while - functions such as mean and sd do not require arguments other than name. + and _log_prob. Arguments should be passed in through *args. Dist_spec_args are unique for each type of distribution. For example, mean and sd are the dist_spec_args for a Normal distribution. @@ -73,11 +69,6 @@ class Distribution(Cell): self._set_log_survival() self._set_cross_entropy() - self._prob_functions = ('prob', 'log_prob') - self._cdf_survival_functions = ('cdf', 'log_cdf', 'survival_function', 'log_survival') - self._variance_functions = ('var', 'sd') - self._divergence_functions = ('kl_loss', 'cross_entropy') - @property def name(self): return self._name @@ -185,7 +176,7 @@ class Distribution(Cell): Evaluate the log probability(pdf or pmf) at the given value. Note: - Args must include name of the function and value. + Args must include value. Dist_spec_args are optional. """ return self._call_log_prob(*args) @@ -204,7 +195,7 @@ class Distribution(Cell): Evaluate the probability (pdf or pmf) at given value. Note: - Args must include name of the function and value. + Args must include value. Dist_spec_args are optional. """ return self._call_prob(*args) @@ -223,7 +214,7 @@ class Distribution(Cell): Evaluate the cdf at given value. Note: - Args must include name of the function and value. + Args must include value. Dist_spec_args are optional. """ return self._call_cdf(*args) @@ -260,7 +251,7 @@ class Distribution(Cell): Evaluate the log cdf at given value. Note: - Args must include name of the function and value. + Args must include value. Dist_spec_args are optional. """ return self._call_log_cdf(*args) @@ -279,7 +270,7 @@ class Distribution(Cell): Evaluate the survival function at given value. Note: - Args must include name of the function and value. + Args must include value. Dist_spec_args are optional. """ return self._call_survival(*args) @@ -307,7 +298,7 @@ class Distribution(Cell): Evaluate the log survival function at given value. Note: - Args must include name of the function and value. + Args must include value. Dist_spec_args are optional. """ return self._call_log_survival(*args) @@ -326,7 +317,7 @@ class Distribution(Cell): Evaluate the KL divergence, i.e. KL(a||b). Note: - Args must include name of the function, type of the distribution, parameters of distribution b. + Args must include type of the distribution, parameters of distribution b. Parameters for distribution a are optional. """ return self._kl_loss(*args) @@ -336,7 +327,7 @@ class Distribution(Cell): Evaluate the mean. Note: - Args must include the name of function. Dist_spec_args are optional. + Dist_spec_args are optional. """ return self._mean(*args) @@ -345,7 +336,7 @@ class Distribution(Cell): Evaluate the mode. Note: - Args must include the name of function. Dist_spec_args are optional. + Dist_spec_args are optional. """ return self._mode(*args) @@ -354,7 +345,7 @@ class Distribution(Cell): Evaluate the standard deviation. Note: - Args must include the name of function. Dist_spec_args are optional. + Dist_spec_args are optional. """ return self._call_sd(*args) @@ -363,7 +354,7 @@ class Distribution(Cell): Evaluate the variance. Note: - Args must include the name of function. Dist_spec_args are optional. + Dist_spec_args are optional. """ return self._call_var(*args) @@ -390,7 +381,7 @@ class Distribution(Cell): Evaluate the entropy. Note: - Args must include the name of function. Dist_spec_args are optional. + Dist_spec_args are optional. """ return self._entropy(*args) @@ -399,7 +390,7 @@ class Distribution(Cell): Evaluate the cross_entropy between distribution a and b. Note: - Args must include name of the function, type of the distribution, parameters of distribution b. + Args must include type of the distribution, parameters of distribution b. Parameters for distribution a are optional. """ return self._call_cross_entropy(*args) @@ -421,13 +412,13 @@ class Distribution(Cell): *args (list): arguments passed in through construct. Note: - Args must include name of the function. - Shape of the sample and dist_spec_args are optional. + Shape of the sample is default to (). + Dist_spec_args are optional. """ return self._sample(*args) - def construct(self, *inputs): + def construct(self, name, *args): """ Override construct in Cell. @@ -437,35 +428,36 @@ class Distribution(Cell): 'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'. Args: - *inputs (list): inputs[0] is always the name of the function. - """ - - if inputs[0] == 'log_prob': - return self._call_log_prob(*inputs) - if inputs[0] == 'prob': - return self._call_prob(*inputs) - if inputs[0] == 'cdf': - return self._call_cdf(*inputs) - if inputs[0] == 'log_cdf': - return self._call_log_cdf(*inputs) - if inputs[0] == 'survival_function': - return self._call_survival(*inputs) - if inputs[0] == 'log_survival': - return self._call_log_survival(*inputs) - if inputs[0] == 'kl_loss': - return self._kl_loss(*inputs) - if inputs[0] == 'mean': - return self._mean(*inputs) - if inputs[0] == 'mode': - return self._mode(*inputs) - if inputs[0] == 'sd': - return self._call_sd(*inputs) - if inputs[0] == 'var': - return self._call_var(*inputs) - if inputs[0] == 'entropy': - return self._entropy(*inputs) - if inputs[0] == 'cross_entropy': - return self._call_cross_entropy(*inputs) - if inputs[0] == 'sample': - return self._sample(*inputs) + name (str): name of the function. + *args (list): list of arguments needed for the function. + """ + + if name == 'log_prob': + return self._call_log_prob(*args) + if name == 'prob': + return self._call_prob(*args) + if name == 'cdf': + return self._call_cdf(*args) + if name == 'log_cdf': + return self._call_log_cdf(*args) + if name == 'survival_function': + return self._call_survival(*args) + if name == 'log_survival': + return self._call_log_survival(*args) + if name == 'kl_loss': + return self._kl_loss(*args) + if name == 'mean': + return self._mean(*args) + if name == 'mode': + return self._mode(*args) + if name == 'sd': + return self._call_sd(*args) + if name == 'var': + return self._call_var(*args) + if name == 'entropy': + return self._entropy(*args) + if name == 'cross_entropy': + return self._call_cross_entropy(*args) + if name == 'sample': + return self._sample(*args) return None diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 4f37ca776..74c6a40ab 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -35,55 +35,56 @@ class Exponential(Distribution): Examples: >>> # To initialize an Exponential distribution of rate 0.5 - >>> n = nn.Exponential(0.5, dtype=mstype.float32) + >>> import mindspore.nn.probability.distribution as msd + >>> e = msd.Exponential(0.5, dtype=mstype.float32) >>> >>> # The following creates two independent Exponential distributions - >>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32) + >>> e = msd.Exponential([0.5, 0.5], dtype=mstype.float32) >>> - >>> # A Exponential distribution can be initilized without arguments - >>> # In this case, rate must be passed in through construct. - >>> n = nn.Exponential(dtype=mstype.float32) + >>> # An Exponential distribution can be initilized without arguments + >>> # In this case, rate must be passed in through args during function calls + >>> e = msd.Exponential(dtype=mstype.float32) >>> - >>> # To use Exponential distribution in a network + >>> # To use Exponential in a network >>> class net(Cell): >>> def __init__(self): >>> super(net, self).__init__(): - >>> self.e1 = nn.Exponential(0.5, dtype=mstype.float32) - >>> self.e2 = nn.Exponential(dtype=mstype.float32) + >>> self.e1 = msd.Exponential(0.5, dtype=mstype.float32) + >>> self.e2 = msd.Exponential(dtype=mstype.float32) >>> >>> # All the following calls in construct are valid >>> def construct(self, value, rate_b, rate_a): >>> >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function - >>> ans = self.e1('prob', value) + >>> ans = self.e1.prob(value) >>> # Evaluate with the respect to distribution b - >>> ans = self.e1('prob', value, rate_b) + >>> ans = self.e1.prob(value, rate_b) >>> - >>> # Rate must be passed in through construct - >>> ans = self.e2('prob', value, rate_a) + >>> # Rate must be passed in during function calls + >>> ans = self.e2.prob(value, rate_a) >>> - >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' - >>> # Will return [0.0] - >>> ans = self.e1('mean') - >>> # Will return mean_b - >>> ans = self.e1('mean', rate_b) + >>> # Functions 'sd', 'var', 'entropy' have the same usage as'mean' + >>> # Will return 2 + >>> ans = self.e1.mean() + >>> # Will return 1 / rate_b + >>> ans = self.e1.mean(rate_b) >>> - >>> # Rate must be passed in through construct - >>> ans = self.e2('mean', rate_a) + >>> # Rate must be passed in during function calls + >>> ans = self.e2.mean(rate_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar - >>> ans = self.e1('kl_loss', 'Exponential', rate_b) - >>> ans = self.e1('kl_loss', 'Exponential', rate_b, rate_a) + >>> ans = self.e1.kl_loss('Exponential', rate_b) + >>> ans = self.e1.kl_loss('Exponential', rate_b, rate_a) >>> - >>> # Additional rate must be passed in through construct - >>> ans = self.e2('kl_loss', 'Exponential', rate_b, rate_a) + >>> # Additional rate must be passed in + >>> ans = self.e2.kl_loss('Exponential', rate_b, rate_a) >>> - >>> # Sample Usage - >>> ans = self.e1('sample') - >>> ans = self.e1('sample', (2,3)) - >>> ans = self.e1('sample', (2,3), rate_b) - >>> ans = self.e2('sample', (2,3), rate_a) + >>> # Sample + >>> ans = self.e1.sample() + >>> ans = self.e1.sample((2,3)) + >>> ans = self.e1.sample((2,3), rate_b) + >>> ans = self.e2.sample((2,3), rate_a) """ def __init__(self, @@ -131,67 +132,59 @@ class Exponential(Distribution): """ return self._rate - def _mean(self, name='mean', rate=None): + def _mean(self, rate=None): r""" .. math:: MEAN(EXP) = \fract{1.0}{\lambda}. """ - if name == 'mean': - rate = self.rate if rate is None else rate - return 1.0 / rate - return None + rate = self.rate if rate is None else rate + return 1.0 / rate - def _mode(self, name='mode', rate=None): + + def _mode(self, rate=None): r""" .. math:: MODE(EXP) = 0. """ - if name == 'mode': - rate = self.rate if rate is None else rate - return self.fill(self.dtype, self.shape(rate), 0.) - return None + rate = self.rate if rate is None else rate + return self.fill(self.dtype, self.shape(rate), 0.) - def _sd(self, name='sd', rate=None): + def _sd(self, rate=None): r""" .. math:: sd(EXP) = \fract{1.0}{\lambda}. """ - if name in self._variance_functions: - rate = self.rate if rate is None else rate - return 1.0 / rate - return None + rate = self.rate if rate is None else rate + return 1.0 / rate - def _entropy(self, name='entropy', rate=None): + def _entropy(self, rate=None): r""" .. math:: H(Exp) = 1 - \log(\lambda). """ rate = self.rate if rate is None else rate - if name == 'entropy': - return 1.0 - self.log(rate) - return None + return 1.0 - self.log(rate) - def _cross_entropy(self, name, dist, rate_b, rate_a=None): + + def _cross_entropy(self, dist, rate_b, rate_a=None): """ Evaluate cross_entropy between Exponential distributions. Args: - name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct. dist (str): type of the distributions. Should be "Exponential" in this case. rate_b (Tensor): rate of distribution b. rate_a (Tensor): rate of distribution a. Default: self.rate. """ - if name == 'cross_entropy' and dist == 'Exponential': - return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a) + if dist == 'Exponential': + return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a) return None - def _prob(self, name, value, rate=None): + def _prob(self, value, rate=None): r""" pdf of Exponential distribution. Args: Args: - name (str): name of the function. value (Tensor): value to be evaluated. rate (Tensor): rate of the distribution. Default: self.rate. @@ -201,20 +194,17 @@ class Exponential(Distribution): .. math:: pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 """ - if name in self._prob_functions: - rate = self.rate if rate is None else rate - prob = rate * self.exp(-1. * rate * value) - zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) - comp = self.less(value, zeros) - return self.select(comp, zeros, prob) - return None + rate = self.rate if rate is None else rate + prob = rate * self.exp(-1. * rate * value) + zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, prob) - def _cdf(self, name, value, rate=None): + def _cdf(self, value, rate=None): r""" cdf of Exponential distribution. Args: - name (str): name of the function. value (Tensor): value to be evaluated. rate (Tensor): rate of the distribution. Default: self.rate. @@ -224,45 +214,40 @@ class Exponential(Distribution): .. math:: cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 """ - if name in self._cdf_survival_functions: - rate = self.rate if rate is None else rate - cdf = 1.0 - self.exp(-1. * rate * value) - zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) - comp = self.less(value, zeros) - return self.select(comp, zeros, cdf) - return None + rate = self.rate if rate is None else rate + cdf = 1.0 - self.exp(-1. * rate * value) + zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, cdf) + - def _kl_loss(self, name, dist, rate_b, rate_a=None): + def _kl_loss(self, dist, rate_b, rate_a=None): """ Evaluate exp-exp kl divergence, i.e. KL(a||b). Args: - name (str): name of the funtion. dist (str): type of the distributions. Should be "Exponential" in this case. rate_b (Tensor): rate of distribution b. rate_a (Tensor): rate of distribution a. Default: self.rate. """ - if name in self._divergence_functions and dist == 'Exponential': + if dist == 'Exponential': rate_a = self.rate if rate_a is None else rate_a return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 return None - def _sample(self, name, shape=(), rate=None): + def _sample(self, shape=(), rate=None): """ Sampling. Args: - name (str): name of the function. shape (tuple): shape of the sample. Default: (). rate (Tensor): rate of the distribution. Default: self.rate. Returns: Tensor, shape is shape + batch_shape. """ - if name == 'sample': - rate = self.rate if rate is None else rate - minval = self.const(self.minval) - maxval = self.const(1.0) - sample = self.uniform(shape + self.shape(rate), minval, maxval) - return -self.log(sample) / rate - return None + rate = self.rate if rate is None else rate + minval = self.const(self.minval) + maxval = self.const(1.0) + sample = self.uniform(shape + self.shape(rate), minval, maxval) + return -self.log(sample) / rate diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index cac194e11..59bc8f0c9 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -36,55 +36,56 @@ class Geometric(Distribution): Examples: >>> # To initialize a Geometric distribution of prob 0.5 - >>> n = nn.Geometric(0.5, dtype=mstype.int32) + >>> import mindspore.nn.probability.distribution as msd + >>> n = msd.Geometric(0.5, dtype=mstype.int32) >>> >>> # The following creates two independent Geometric distributions - >>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32) + >>> n = msd.Geometric([0.5, 0.5], dtype=mstype.int32) >>> >>> # A Geometric distribution can be initilized without arguments - >>> # In this case, probs must be passed in through construct. - >>> n = nn.Geometric(dtype=mstype.int32) + >>> # In this case, probs must be passed in through args during function calls. + >>> n = msd.Geometric(dtype=mstype.int32) >>> - >>> # To use Geometric distribution in a network + >>> # To use Geometric in a network >>> class net(Cell): >>> def __init__(self): >>> super(net, self).__init__(): - >>> self.g1 = nn.Geometric(0.5, dtype=mstype.int32) - >>> self.g2 = nn.Geometric(dtype=mstype.int32) + >>> self.g1 = msd.Geometric(0.5, dtype=mstype.int32) + >>> self.g2 = msd.Geometric(dtype=mstype.int32) >>> >>> # Tthe following calls are valid in construct >>> def construct(self, value, probs_b, probs_a): >>> >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function - >>> ans = self.g1('prob', value) + >>> ans = self.g1.prob(value) >>> # Evaluate with the respect to distribution b - >>> ans = self.g1('prob', value, probs_b) + >>> ans = self.g1.prob(value, probs_b) >>> - >>> # Probs must be passed in through construct - >>> ans = self.g2('prob', value, probs_a) + >>> # Probs must be passed in during function calls + >>> ans = self.g2.prob(value, probs_a) >>> - >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' - >>> # Will return [0.0] - >>> ans = self.g1('mean') - >>> # Will return mean_b - >>> ans = self.g1('mean', probs_b) + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return 1.0 + >>> ans = self.g1.mean() + >>> # Another possible usage + >>> ans = self.g1.mean(probs_b) >>> - >>> # Probs must be passed in through construct - >>> ans = self.g2('mean', probs_a) + >>> # Probs must be passed in during function calls + >>> ans = self.g2.mean(probs_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar - >>> ans = self.g1('kl_loss', 'Geometric', probs_b) - >>> ans = self.g1('kl_loss', 'Geometric', probs_b, probs_a) + >>> ans = self.g1.kl_loss('Geometric', probs_b) + >>> ans = self.g1.kl_loss('Geometric', probs_b, probs_a) >>> - >>> # Additional probs must be passed in through construct - >>> ans = self.g2('kl_loss', 'Geometric', probs_b, probs_a) + >>> # Additional probs must be passed in + >>> ans = self.g2.kl_loss('Geometric', probs_b, probs_a) >>> - >>> # Sample Usage - >>> ans = self.g1('sample') - >>> ans = self.g1('sample', (2,3)) - >>> ans = self.g1('sample', (2,3), probs_b) - >>> ans = self.g2('sample', (2,3), probs_a) + >>> # Sample + >>> ans = self.g1.sample() + >>> ans = self.g1.sample((2,3)) + >>> ans = self.g1.sample((2,3), probs_b) + >>> ans = self.g2.sample((2,3), probs_a) """ def __init__(self, @@ -134,67 +135,57 @@ class Geometric(Distribution): """ return self._probs - def _mean(self, name='mean', probs1=None): + def _mean(self, probs1=None): r""" .. math:: MEAN(Geo) = \fratc{1 - probs1}{probs1} """ - if name == 'mean': - probs1 = self.probs if probs1 is None else probs1 - return (1. - probs1) / probs1 - return None + probs1 = self.probs if probs1 is None else probs1 + return (1. - probs1) / probs1 - def _mode(self, name='mode', probs1=None): + def _mode(self, probs1=None): r""" .. math:: MODE(Geo) = 0 """ - if name == 'mode': - probs1 = self.probs if probs1 is None else probs1 - return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) - return None + probs1 = self.probs if probs1 is None else probs1 + return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) - def _var(self, name='var', probs1=None): + def _var(self, probs1=None): r""" .. math:: VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}} """ - if name in self._variance_functions: - probs1 = self.probs if probs1 is None else probs1 - return (1.0 - probs1) / self.sq(probs1) - return None + probs1 = self.probs if probs1 is None else probs1 + return (1.0 - probs1) / self.sq(probs1) - def _entropy(self, name='entropy', probs=None): + def _entropy(self, probs=None): r""" .. math:: H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} """ - if name == 'entropy': - probs1 = self.probs if probs is None else probs - probs0 = 1.0 - probs1 - return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 - return None + probs1 = self.probs if probs is None else probs + probs0 = 1.0 - probs1 + return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 - def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): + def _cross_entropy(self, dist, probs1_b, probs1_a=None): r""" Evaluate cross_entropy between Geometric distributions. Args: - name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct. dist (str): type of the distributions. Should be "Geometric" in this case. probs1_b (Tensor): probability of success of distribution b. probs1_a (Tensor): probability of success of distribution a. Default: self.probs. """ - if name == 'cross_entropy' and dist == 'Geometric': - return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) + if dist == 'Geometric': + return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) return None - def _prob(self, name, value, probs=None): + def _prob(self, value, probs=None): r""" pmf of Geometric distribution. Args: - name (str): name of the function. Should be "prob" when passed in from construct. value (Tensor): a Tensor composed of only natural numbers. probs (Tensor): probability of success. Default: self.probs. @@ -202,27 +193,24 @@ class Geometric(Distribution): pmf(k) = probs0 ^k * probs1 if k >= 0; pmf(k) = 0 if k < 0. """ - if name in self._prob_functions: - probs1 = self.probs if probs is None else probs - dtype = self.dtypeop(value) - if self.issubclass(dtype, mstype.int_): - pass - elif self.issubclass(dtype, mstype.float_): - value = self.floor(value) - else: - return None - pmf = self.pow((1.0 - probs1), value) * probs1 - zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) - comp = self.less(value, zeros) - return self.select(comp, zeros, pmf) - return None + probs1 = self.probs if probs is None else probs + dtype = self.dtypeop(value) + if self.issubclass(dtype, mstype.int_): + pass + elif self.issubclass(dtype, mstype.float_): + value = self.floor(value) + else: + return None + pmf = self.pow((1.0 - probs1), value) * probs1 + zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, pmf) - def _cdf(self, name, value, probs=None): + def _cdf(self, value, probs=None): r""" cdf of Geometric distribution. Args: - name (str): name of the function. value (Tensor): a Tensor composed of only natural numbers. probs (Tensor): probability of success. Default: self.probs. @@ -231,28 +219,26 @@ class Geometric(Distribution): cdf(k) = 0 if k < 0. """ - if name in self._cdf_survival_functions: - probs1 = self.probs if probs is None else probs - probs0 = 1.0 - probs1 - dtype = self.dtypeop(value) - if self.issubclass(dtype, mstype.int_): - pass - elif self.issubclass(dtype, mstype.float_): - value = self.floor(value) - else: - return None - cdf = 1.0 - self.pow(probs0, value + 1.0) - zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) - comp = self.less(value, zeros) - return self.select(comp, zeros, cdf) - return None + probs1 = self.probs if probs is None else probs + probs0 = 1.0 - probs1 + dtype = self.dtypeop(value) + if self.issubclass(dtype, mstype.int_): + pass + elif self.issubclass(dtype, mstype.float_): + value = self.floor(value) + else: + return None + cdf = 1.0 - self.pow(probs0, value + 1.0) + zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, cdf) - def _kl_loss(self, name, dist, probs1_b, probs1_a=None): + + def _kl_loss(self, dist, probs1_b, probs1_a=None): r""" Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). Args: - name (str): name of the funtion. dist (str): type of the distributions. Should be "Geometric" in this case. probs1_b (Tensor): probability of success of distribution b. probs1_a (Tensor): probability of success of distribution a. Default: self.probs. @@ -260,29 +246,26 @@ class Geometric(Distribution): .. math:: KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b}) """ - if name in self._divergence_functions and dist == 'Geometric': + if dist == 'Geometric': probs1_a = self.probs if probs1_a is None else probs1_a probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) return None - def _sample(self, name, shape=(), probs=None): + def _sample(self, shape=(), probs=None): """ Sampling. Args: - name (str): name of the function. Should always be 'sample' when passed in from construct. shape (tuple): shape of the sample. Default: (). probs (Tensor): probability of success. Default: self.probs. Returns: Tensor, shape is shape + batch_shape. """ - if name == 'sample': - probs = self.probs if probs is None else probs - minval = self.const(self.minval) - maxval = self.const(1.0) - sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) - return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) - return None + probs = self.probs if probs is None else probs + minval = self.const(self.minval) + maxval = self.const(1.0) + sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) + return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index aae9b3963..f243a2bc3 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -17,7 +17,6 @@ import numpy as np from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.common import dtype as mstype -from mindspore.context import get_context from .distribution import Distribution from ._utils.utils import convert_to_batch, check_greater_equal_zero @@ -39,55 +38,56 @@ class Normal(Distribution): Examples: >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0 - >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) + >>> import mindspore.nn.probability.distribution as msd + >>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32) >>> >>> # The following creates two independent Normal distributions - >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) + >>> n = msd.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) >>> - >>> # A normal distribution can be initilize without arguments - >>> # In this case, mean and sd must be passed in through construct. - >>> n = nn.Normal(dtype=mstype.float32) + >>> # A Normal distribution can be initilize without arguments + >>> # In this case, mean and sd must be passed in through args. + >>> n = msd.Normal(dtype=mstype.float32) >>> - >>> # To use normal in a network + >>> # To use Normal in a network >>> class net(Cell): >>> def __init__(self): >>> super(net, self).__init__(): - >>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32) - >>> self.n2 = nn.Normal(dtype=mstype.float32) + >>> self.n1 = msd.Nomral(0.0, 1.0, dtype=mstype.float32) + >>> self.n2 = msd.Normal(dtype=mstype.float32) >>> >>> # The following calls are valid in construct >>> def construct(self, value, mean_b, sd_b, mean_a, sd_a): >>> >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function - >>> ans = self.n1('prob', value) + >>> ans = self.n1.prob(value) >>> # Evaluate with the respect to distribution b - >>> ans = self.n1('prob', value, mean_b, sd_b) + >>> ans = self.n1.prob(value, mean_b, sd_b) >>> - >>> # mean and sd must be passed in through construct - >>> ans = self.n2('prob', value, mean_a, sd_a) + >>> # mean and sd must be passed in during function calls + >>> ans = self.n2.prob(value, mean_a, sd_a) >>> - >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' - >>> # Will return [0.0] - >>> ans = self.n1('mean') - >>> # Will return mean_b - >>> ans = self.n1('mean', mean_b, sd_b) + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # will return [0.0] + >>> ans = self.n1.mean() + >>> # will return mean_b + >>> ans = self.n1.mean(mean_b, sd_b) >>> - >>> # mean and sd must be passed in through construct - >>> ans = self.n2('mean', mean_a, sd_a) + >>> # mean and sd must be passed during function calls + >>> ans = self.n2.mean(mean_a, sd_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar - >>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b) - >>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) + >>> ans = self.n1.kl_loss('Normal', mean_b, sd_b) + >>> ans = self.n1.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a) >>> - >>> # Additional mean and sd must be passed in through construct - >>> ans = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) + >>> # Additional mean and sd must be passed + >>> ans = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a) >>> - >>> # Sample Usage - >>> ans = self.n1('sample') - >>> ans = self.n1('sample', (2,3)) - >>> ans = self.n1('sample', (2,3), mean_b, sd_b) - >>> ans = self.n2('sample', (2,3), mean_a, sd_a) + >>> # Sample + >>> ans = self.n1.sample() + >>> ans = self.n1.sample((2,3)) + >>> ans = self.n1.sample((2,3), mean_b, sd_b) + >>> ans = self.n2.sample((2,3), mean_a, sd_a) """ def __init__(self, @@ -114,7 +114,7 @@ class Normal(Distribution): self.const = P.ScalarToArray() self.erf = P.Erf() self.exp = P.Exp() - self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step + self.expm1 = self._expm1_by_step self.fill = P.Fill() self.log = P.Log() self.shape = P.Shape() @@ -135,67 +135,57 @@ class Normal(Distribution): """ return self.exp(x) - 1.0 - def _mean(self, name='mean', mean=None, sd=None): + def _mean(self, mean=None, sd=None): """ Mean of the distribution. """ - if name == 'mean': - mean = self._mean_value if mean is None or sd is None else mean - return mean - return None + mean = self._mean_value if mean is None or sd is None else mean + return mean - def _mode(self, name='mode', mean=None, sd=None): + def _mode(self, mean=None, sd=None): """ Mode of the distribution. """ - if name == 'mode': - mean = self._mean_value if mean is None or sd is None else mean - return mean - return None + mean = self._mean_value if mean is None or sd is None else mean + return mean - def _sd(self, name='sd', mean=None, sd=None): + def _sd(self, mean=None, sd=None): """ Standard deviation of the distribution. """ - if name in self._variance_functions: - sd = self._sd_value if mean is None or sd is None else sd - return sd - return None + sd = self._sd_value if mean is None or sd is None else sd + return sd - def _entropy(self, name='entropy', sd=None): + def _entropy(self, sd=None): r""" Evaluate entropy. .. math:: H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) """ - if name == 'entropy': - sd = self._sd_value if sd is None else sd - return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) - return None + sd = self._sd_value if sd is None else sd + return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) - def _cross_entropy(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): + def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): r""" Evaluate cross_entropy between normal distributions. Args: - name (str): name of the funtion passed in from construct. Should always be "cross_entropy". dist (str): type of the distributions. Should be "Normal" in this case. mean_b (Tensor): mean of distribution b. sd_b (Tensor): standard deviation distribution b. mean_a (Tensor): mean of distribution a. Default: self._mean_value. sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. """ - if name == 'cross_entropy' and dist == 'Normal': - return self._entropy(sd=sd_a) + self._kl_loss(name, dist, mean_b, sd_b, mean_a, sd_a) + if dist == 'Normal': + return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a) return None - def _log_prob(self, name, value, mean=None, sd=None): + def _log_prob(self, value, mean=None, sd=None): r""" Evaluate log probability. Args: - name (str): name of the funtion passed in from construct. value (Tensor): value to be evaluated. mean (Tensor): mean of the distribution. Default: self._mean_value. sd (Tensor): standard deviation the distribution. Default: self._sd_value. @@ -203,20 +193,17 @@ class Normal(Distribution): .. math:: L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) """ - if name in self._prob_functions: - mean = self._mean_value if mean is None else mean - sd = self._sd_value if sd is None else sd - unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) - neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) - return unnormalized_log_prob + neg_normalization - return None + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) + return unnormalized_log_prob + neg_normalization - def _cdf(self, name, value, mean=None, sd=None): + def _cdf(self, value, mean=None, sd=None): r""" Evaluate cdf of given value. Args: - name (str): name of the funtion passed in from construct. Should always be "cdf". value (Tensor): value to be evaluated. mean (Tensor): mean of the distribution. Default: self._mean_value. sd (Tensor): standard deviation the distribution. Default: self._sd_value. @@ -224,20 +211,17 @@ class Normal(Distribution): .. math:: cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) """ - if name in self._cdf_survival_functions: - mean = self._mean_value if mean is None else mean - sd = self._sd_value if sd is None else sd - sqrt2 = self.sqrt(self.const(2.0)) - adjusted = (value - mean) / (sd * sqrt2) - return 0.5 * (1.0 + self.erf(adjusted)) - return None + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + sqrt2 = self.sqrt(self.const(2.0)) + adjusted = (value - mean) / (sd * sqrt2) + return 0.5 * (1.0 + self.erf(adjusted)) - def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): + def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): r""" Evaluate Normal-Normal kl divergence, i.e. KL(a||b). Args: - name (str): name of the funtion passed in from construct. dist (str): type of the distributions. Should be "Normal" in this case. mean_b (Tensor): mean of distribution b. sd_b (Tensor): standard deviation distribution b. @@ -248,7 +232,7 @@ class Normal(Distribution): KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) """ - if name in self._divergence_functions and dist == 'Normal': + if dist == 'Normal': mean_a = self._mean_value if mean_a is None else mean_a sd_a = self._sd_value if sd_a is None else sd_a diff_log_scale = self.log(sd_a) - self.log(sd_b) @@ -256,12 +240,11 @@ class Normal(Distribution): return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale return None - def _sample(self, name, shape=(), mean=None, sd=None): + def _sample(self, shape=(), mean=None, sd=None): """ Sampling. Args: - name (str): name of the function. Should always be 'sample' when passed in from construct. shape (tuple): shape of the sample. Default: (). mean (Tensor): mean of the samples. Default: self._mean_value. sd (Tensor): standard deviation of the samples. Default: self._sd_value. @@ -269,14 +252,12 @@ class Normal(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - if name == 'sample': - mean = self._mean_value if mean is None else mean - sd = self._sd_value if sd is None else sd - batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) - sample_shape = shape + batch_shape - mean_zero = self.const(0.0) - sd_one = self.const(1.0) - sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample = mean + sample_norm * sd - return sample - return None + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) + sample_shape = shape + batch_shape + mean_zero = self.const(0.0) + sd_one = self.const(1.0) + sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) + sample = mean + sample_norm * sd + return sample diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 0c0d73f3e..2fc459f56 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -35,55 +35,56 @@ class Uniform(Distribution): Examples: >>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0 - >>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32) + >>> import mindspore.nn.probability.distribution as msd + >>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32) >>> >>> # The following creates two independent Uniform distributions - >>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32) + >>> u = msd.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32) >>> >>> # A Uniform distribution can be initilized without arguments - >>> # In this case, high and low must be passed in through construct. - >>> n = nn.Uniform(dtype=mstype.float32) + >>> # In this case, high and low must be passed in through args during function calls. + >>> u = msd.Uniform(dtype=mstype.float32) >>> >>> # To use Uniform in a network >>> class net(Cell): >>> def __init__(self) >>> super(net, self).__init__(): - >>> self.u1 = nn.Uniform(0.0, 1.0, dtype=mstype.float32) - >>> self.u2 = nn.Uniform(dtype=mstype.float32) + >>> self.u1 = msd.Uniform(0.0, 1.0, dtype=mstype.float32) + >>> self.u2 = msd.Uniform(dtype=mstype.float32) >>> >>> # All the following calls in construct are valid >>> def construct(self, value, low_b, high_b, low_a, high_a): >>> >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function - >>> ans = self.u1('prob', value) + >>> ans = self.u1.prob(value) >>> # Evaluate with the respect to distribution b - >>> ans = self.u1('prob', value, low_b, high_b) + >>> ans = self.u1.prob(value, low_b, high_b) >>> - >>> # High and low must be passed in through construct - >>> ans = self.u2('prob', value, low_a, high_a) + >>> # High and low must be passed in during function calls + >>> ans = self.u2.prob(value, low_a, high_a) >>> - >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' - >>> # Will return [0.0] - >>> ans = self.u1('mean') - >>> # Will return low_b - >>> ans = self.u1('mean', low_b, high_b) + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return 0.5 + >>> ans = self.u1.mean() + >>> # Will return (low_b + high_b) / 2 + >>> ans = self.u1.mean(low_b, high_b) >>> - >>> # High and low must be passed in through construct - >>> ans = self.u2('mean', low_a, high_a) + >>> # High and low must be passed in during function calls + >>> ans = self.u2.mean(low_a, high_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar - >>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b) - >>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) + >>> ans = self.u1.kl_loss('Uniform', low_b, high_b) + >>> ans = self.u1.kl_loss('Uniform', low_b, high_b, low_a, high_a) >>> - >>> # Additional high and low must be passed in through construct - >>> ans = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) + >>> # Additional high and low must be passed + >>> ans = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a) >>> - >>> # Sample Usage - >>> ans = self.u1('sample') - >>> ans = self.u1('sample', (2,3)) - >>> ans = self.u1('sample', (2,3), low_b, high_b) - >>> ans = self.u2('sample', (2,3), low_a, high_a) + >>> # Sample + >>> ans = self.u1.sample() + >>> ans = self.u1.sample((2,3)) + >>> ans = self.u1.sample((2,3), low_b, high_b) + >>> ans = self.u2.sample((2,3), low_a, high_a) """ def __init__(self, @@ -142,73 +143,64 @@ class Uniform(Distribution): """ return self._high - def _range(self, name='range', low=None, high=None): + def _range(self, low=None, high=None): r""" Return the range of the distribution. .. math:: range(U) = high -low """ - if name == 'range': - low = self.low if low is None else low - high = self.high if high is None else high - return high - low - return None + low = self.low if low is None else low + high = self.high if high is None else high + return high - low - def _mean(self, name='mean', low=None, high=None): + def _mean(self, low=None, high=None): r""" .. math:: MEAN(U) = \fract{low + high}{2}. """ - if name == 'mean': - low = self.low if low is None else low - high = self.high if high is None else high - return (low + high) / 2. - return None + low = self.low if low is None else low + high = self.high if high is None else high + return (low + high) / 2. - def _var(self, name='var', low=None, high=None): + + def _var(self, low=None, high=None): r""" .. math:: VAR(U) = \fract{(high -low) ^ 2}{12}. """ - if name in self._variance_functions: - low = self.low if low is None else low - high = self.high if high is None else high - return self.sq(high - low) / 12.0 - return None + low = self.low if low is None else low + high = self.high if high is None else high + return self.sq(high - low) / 12.0 - def _entropy(self, name='entropy', low=None, high=None): + def _entropy(self, low=None, high=None): r""" .. math:: H(U) = \log(high - low). """ - if name == 'entropy': - low = self.low if low is None else low - high = self.high if high is None else high - return self.log(high - low) - return None + low = self.low if low is None else low + high = self.high if high is None else high + return self.log(high - low) - def _cross_entropy(self, name, dist, low_b, high_b, low_a=None, high_a=None): + def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None): """ Evaluate cross_entropy between Uniform distributoins. Args: - name (str): name of the funtion. dist (str): type of the distributions. Should be "Uniform" in this case. low_b (Tensor): lower bound of distribution b. high_b (Tensor): upper bound of distribution b. low_a (Tensor): lower bound of distribution a. Default: self.low. high_a (Tensor): upper bound of distribution a. Default: self.high. """ - if name == 'cross_entropy' and dist == 'Uniform': - return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a) + if dist == 'Uniform': + return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a) return None - def _prob(self, name, value, low=None, high=None): + def _prob(self, value, low=None, high=None): r""" pdf of Uniform distribution. Args: - name (str): name of the function. value (Tensor): value to be evaluated. low (Tensor): lower bound of the distribution. Default: self.low. high (Tensor): upper bound of the distribution. Default: self.high. @@ -218,32 +210,29 @@ class Uniform(Distribution): pdf(x) = \fract{1.0}{high -low} if low <= x <= high; pdf(x) = 0 if x > high; """ - if name in self._prob_functions: - low = self.low if low is None else low - high = self.high if high is None else high - ones = self.fill(self.dtype, self.shape(value), 1.0) - prob = ones / (high - low) - broadcast_shape = self.shape(prob) - zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) - comp_lo = self.less(value, low) - comp_hi = self.lessequal(value, high) - less_than_low = self.select(comp_lo, zeros, prob) - return self.select(comp_hi, less_than_low, zeros) - return None + low = self.low if low is None else low + high = self.high if high is None else high + ones = self.fill(self.dtype, self.shape(value), 1.0) + prob = ones / (high - low) + broadcast_shape = self.shape(prob) + zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) + comp_lo = self.less(value, low) + comp_hi = self.lessequal(value, high) + less_than_low = self.select(comp_lo, zeros, prob) + return self.select(comp_hi, less_than_low, zeros) - def _kl_loss(self, name, dist, low_b, high_b, low_a=None, high_a=None): + def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None): """ Evaluate uniform-uniform kl divergence, i.e. KL(a||b). Args: - name (str): name of the funtion. dist (str): type of the distributions. Should be "Uniform" in this case. low_b (Tensor): lower bound of distribution b. high_b (Tensor): upper bound of distribution b. low_a (Tensor): lower bound of distribution a. Default: self.low. high_a (Tensor): upper bound of distribution a. Default: self.high. """ - if name in self._divergence_functions and dist == 'Uniform': + if dist == 'Uniform': low_a = self.low if low_a is None else low_a high_a = self.high if high_a is None else high_a kl = self.log(high_b - low_b) / self.log(high_a - low_a) @@ -251,12 +240,11 @@ class Uniform(Distribution): return self.select(comp, kl, self.log(self.zeroslike(kl))) return None - def _cdf(self, name, value, low=None, high=None): + def _cdf(self, value, low=None, high=None): r""" cdf of Uniform distribution. Args: - name (str): name of the function. value (Tensor): value to be evaluated. low (Tensor): lower bound of the distribution. Default: self.low. high (Tensor): upper bound of the distribution. Default: self.high. @@ -266,25 +254,22 @@ class Uniform(Distribution): cdf(x) = \fract{x - low}{high -low} if low <= x <= high; cdf(x) = 1 if x > high; """ - if name in self._cdf_survival_functions: - low = self.low if low is None else low - high = self.high if high is None else high - prob = (value - low) / (high - low) - broadcast_shape = self.shape(prob) - zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) - ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0) - comp_lo = self.less(value, low) - comp_hi = self.less(value, high) - less_than_low = self.select(comp_lo, zeros, prob) - return self.select(comp_hi, less_than_low, ones) - return None + low = self.low if low is None else low + high = self.high if high is None else high + prob = (value - low) / (high - low) + broadcast_shape = self.shape(prob) + zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) + ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0) + comp_lo = self.less(value, low) + comp_hi = self.less(value, high) + less_than_low = self.select(comp_lo, zeros, prob) + return self.select(comp_hi, less_than_low, ones) - def _sample(self, name, shape=(), low=None, high=None): + def _sample(self, shape=(), low=None, high=None): """ Sampling. Args: - name (str): name of the function. Should always be 'sample' when passed in from construct. shape (tuple): shape of the sample. Default: (). low (Tensor): lower bound of the distribution. Default: self.low. high (Tensor): upper bound of the distribution. Default: self.high. @@ -292,13 +277,11 @@ class Uniform(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - if name == 'sample': - low = self.low if low is None else low - high = self.high if high is None else high - broadcast_shape = self.shape(low + high) - l_zero = self.const(0.0) - h_one = self.const(1.0) - sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) - sample = (high - low) * sample_uniform + low - return sample - return None + low = self.low if low is None else low + high = self.high if high is None else high + broadcast_shape = self.shape(low + high) + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) + sample = (high - low) * sample_uniform + low + return sample diff --git a/tests/st/ops/ascend/test_distribution/test_bernoulli.py b/tests/st/ops/ascend/test_distribution/test_bernoulli.py index 98c6d979e..2dc2300f5 100644 --- a/tests/st/ops/ascend/test_distribution/test_bernoulli.py +++ b/tests/st/ops/ascend/test_distribution/test_bernoulli.py @@ -19,7 +19,6 @@ import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -32,9 +31,8 @@ class Prob(nn.Cell): super(Prob, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('prob', x_) + return self.b.prob(x_) def test_pmf(): """ @@ -57,9 +55,8 @@ class LogProb(nn.Cell): super(LogProb, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('log_prob', x_) + return self.b.log_prob(x_) def test_log_likelihood(): """ @@ -81,9 +78,8 @@ class KL(nn.Cell): super(KL, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('kl_loss', 'Bernoulli', x_) + return self.b.kl_loss('Bernoulli', x_) def test_kl_loss(): """ @@ -107,9 +103,8 @@ class Basics(nn.Cell): super(Basics, self).__init__() self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32) - @ms_function def construct(self): - return self.b('mean'), self.b('sd'), self.b('mode') + return self.b.mean(), self.b.sd(), self.b.mode() def test_basics(): """ @@ -134,9 +129,8 @@ class Sampling(nn.Cell): self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) self.shape = shape - @ms_function def construct(self, probs=None): - return self.b('sample', self.shape, probs) + return self.b.sample(self.shape, probs) def test_sample(): """ @@ -155,9 +149,8 @@ class CDF(nn.Cell): super(CDF, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('cdf', x_) + return self.b.cdf(x_) def test_cdf(): """ @@ -171,7 +164,6 @@ def test_cdf(): tol = 1e-6 assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() - class LogCDF(nn.Cell): """ Test class: log cdf of bernoulli distributions. @@ -180,9 +172,8 @@ class LogCDF(nn.Cell): super(LogCDF, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('log_cdf', x_) + return self.b.log_cdf(x_) def test_logcdf(): """ @@ -205,9 +196,8 @@ class SF(nn.Cell): super(SF, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('survival_function', x_) + return self.b.survival_function(x_) def test_survival(): """ @@ -230,9 +220,8 @@ class LogSF(nn.Cell): super(LogSF, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('log_survival', x_) + return self.b.log_survival(x_) def test_log_survival(): """ @@ -254,9 +243,8 @@ class EntropyH(nn.Cell): super(EntropyH, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self): - return self.b('entropy') + return self.b.entropy() def test_entropy(): """ @@ -277,12 +265,11 @@ class CrossEntropy(nn.Cell): super(CrossEntropy, self).__init__() self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - entropy = self.b('entropy') - kl_loss = self.b('kl_loss', 'Bernoulli', x_) + entropy = self.b.entropy() + kl_loss = self.b.kl_loss('Bernoulli', x_) h_sum_kl = entropy + kl_loss - cross_entropy = self.b('cross_entropy', 'Bernoulli', x_) + cross_entropy = self.b.cross_entropy('Bernoulli', x_) return h_sum_kl - cross_entropy def test_cross_entropy(): diff --git a/tests/st/ops/ascend/test_distribution/test_exponential.py b/tests/st/ops/ascend/test_distribution/test_exponential.py index d46fa87bd..ba1689c6f 100644 --- a/tests/st/ops/ascend/test_distribution/test_exponential.py +++ b/tests/st/ops/ascend/test_distribution/test_exponential.py @@ -19,7 +19,6 @@ import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -32,9 +31,8 @@ class Prob(nn.Cell): super(Prob, self).__init__() self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.e('prob', x_) + return self.e.prob(x_) def test_pdf(): """ @@ -56,9 +54,8 @@ class LogProb(nn.Cell): super(LogProb, self).__init__() self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.e('log_prob', x_) + return self.e.log_prob(x_) def test_log_likelihood(): """ @@ -80,9 +77,8 @@ class KL(nn.Cell): super(KL, self).__init__() self.e = msd.Exponential([1.5], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.e('kl_loss', 'Exponential', x_) + return self.e.kl_loss('Exponential', x_) def test_kl_loss(): """ @@ -104,9 +100,8 @@ class Basics(nn.Cell): super(Basics, self).__init__() self.e = msd.Exponential([0.5], dtype=dtype.float32) - @ms_function def construct(self): - return self.e('mean'), self.e('sd'), self.e('mode') + return self.e.mean(), self.e.sd(), self.e.mode() def test_basics(): """ @@ -131,9 +126,8 @@ class Sampling(nn.Cell): self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32) self.shape = shape - @ms_function def construct(self, rate=None): - return self.e('sample', self.shape, rate) + return self.e.sample(self.shape, rate) def test_sample(): """ @@ -154,9 +148,8 @@ class CDF(nn.Cell): super(CDF, self).__init__() self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.e('cdf', x_) + return self.e.cdf(x_) def test_cdf(): """ @@ -178,9 +171,8 @@ class LogCDF(nn.Cell): super(LogCDF, self).__init__() self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.e('log_cdf', x_) + return self.e.log_cdf(x_) def test_log_cdf(): """ @@ -202,9 +194,8 @@ class SF(nn.Cell): super(SF, self).__init__() self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.e('survival_function', x_) + return self.e.survival_function(x_) def test_survival(): """ @@ -226,9 +217,8 @@ class LogSF(nn.Cell): super(LogSF, self).__init__() self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.e('log_survival', x_) + return self.e.log_survival(x_) def test_log_survival(): """ @@ -250,9 +240,8 @@ class EntropyH(nn.Cell): super(EntropyH, self).__init__() self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) - @ms_function def construct(self): - return self.e('entropy') + return self.e.entropy() def test_entropy(): """ @@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell): super(CrossEntropy, self).__init__() self.e = msd.Exponential([1.0], dtype=dtype.float32) - @ms_function def construct(self, x_): - entropy = self.e('entropy') - kl_loss = self.e('kl_loss', 'Exponential', x_) + entropy = self.e.entropy() + kl_loss = self.e.kl_loss('Exponential', x_) h_sum_kl = entropy + kl_loss - cross_entropy = self.e('cross_entropy', 'Exponential', x_) + cross_entropy = self.e.cross_entropy('Exponential', x_) return h_sum_kl - cross_entropy def test_cross_entropy(): diff --git a/tests/st/ops/ascend/test_distribution/test_geometric.py b/tests/st/ops/ascend/test_distribution/test_geometric.py index e4770ff6e..6b2a5ba84 100644 --- a/tests/st/ops/ascend/test_distribution/test_geometric.py +++ b/tests/st/ops/ascend/test_distribution/test_geometric.py @@ -19,7 +19,6 @@ import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -32,9 +31,8 @@ class Prob(nn.Cell): super(Prob, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.g('prob', x_) + return self.g.prob(x_) def test_pmf(): """ @@ -56,9 +54,8 @@ class LogProb(nn.Cell): super(LogProb, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.g('log_prob', x_) + return self.g.log_prob(x_) def test_log_likelihood(): """ @@ -80,9 +77,8 @@ class KL(nn.Cell): super(KL, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.g('kl_loss', 'Geometric', x_) + return self.g.kl_loss('Geometric', x_) def test_kl_loss(): """ @@ -106,9 +102,8 @@ class Basics(nn.Cell): super(Basics, self).__init__() self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32) - @ms_function def construct(self): - return self.g('mean'), self.g('sd'), self.g('mode') + return self.g.mean(), self.g.sd(), self.g.mode() def test_basics(): """ @@ -133,9 +128,8 @@ class Sampling(nn.Cell): self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) self.shape = shape - @ms_function def construct(self, probs=None): - return self.g('sample', self.shape, probs) + return self.g.sample(self.shape, probs) def test_sample(): """ @@ -154,9 +148,8 @@ class CDF(nn.Cell): super(CDF, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.g('cdf', x_) + return self.g.cdf(x_) def test_cdf(): """ @@ -178,9 +171,8 @@ class LogCDF(nn.Cell): super(LogCDF, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.g('log_cdf', x_) + return self.g.log_cdf(x_) def test_logcdf(): """ @@ -202,9 +194,8 @@ class SF(nn.Cell): super(SF, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.g('survival_function', x_) + return self.g.survival_function(x_) def test_survival(): """ @@ -226,9 +217,8 @@ class LogSF(nn.Cell): super(LogSF, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.g('log_survival', x_) + return self.g.log_survival(x_) def test_log_survival(): """ @@ -250,9 +240,8 @@ class EntropyH(nn.Cell): super(EntropyH, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self): - return self.g('entropy') + return self.g.entropy() def test_entropy(): """ @@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell): super(CrossEntropy, self).__init__() self.g = msd.Geometric(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - entropy = self.g('entropy') - kl_loss = self.g('kl_loss', 'Geometric', x_) + entropy = self.g.entropy() + kl_loss = self.g.kl_loss('Geometric', x_) h_sum_kl = entropy + kl_loss - ans = self.g('cross_entropy', 'Geometric', x_) + ans = self.g.cross_entropy('Geometric', x_) return h_sum_kl - ans def test_cross_entropy(): diff --git a/tests/st/ops/ascend/test_distribution/test_normal.py b/tests/st/ops/ascend/test_distribution/test_normal.py index f196a7cef..ee851281e 100644 --- a/tests/st/ops/ascend/test_distribution/test_normal.py +++ b/tests/st/ops/ascend/test_distribution/test_normal.py @@ -19,7 +19,6 @@ import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -32,9 +31,8 @@ class Prob(nn.Cell): super(Prob, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.n('prob', x_) + return self.n.prob(x_) def test_pdf(): """ @@ -55,9 +53,8 @@ class LogProb(nn.Cell): super(LogProb, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.n('log_prob', x_) + return self.n.log_prob(x_) def test_log_likelihood(): """ @@ -79,9 +76,8 @@ class KL(nn.Cell): super(KL, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) - @ms_function def construct(self, x_, y_): - return self.n('kl_loss', 'Normal', x_, y_) + return self.n.kl_loss('Normal', x_, y_) def test_kl_loss(): @@ -113,9 +109,8 @@ class Basics(nn.Cell): super(Basics, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) - @ms_function def construct(self): - return self.n('mean'), self.n('sd'), self.n('mode') + return self.n.mean(), self.n.sd(), self.n.mode() def test_basics(): """ @@ -139,9 +134,8 @@ class Sampling(nn.Cell): self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) self.shape = shape - @ms_function def construct(self, mean=None, sd=None): - return self.n('sample', self.shape, mean, sd) + return self.n.sample(self.shape, mean, sd) def test_sample(): """ @@ -163,9 +157,8 @@ class CDF(nn.Cell): super(CDF, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.n('cdf', x_) + return self.n.cdf(x_) def test_cdf(): @@ -187,9 +180,8 @@ class LogCDF(nn.Cell): super(LogCDF, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.n('log_cdf', x_) + return self.n.log_cdf(x_) def test_log_cdf(): """ @@ -210,9 +202,8 @@ class SF(nn.Cell): super(SF, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.n('survival_function', x_) + return self.n.survival_function(x_) def test_survival(): """ @@ -233,9 +224,8 @@ class LogSF(nn.Cell): super(LogSF, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.n('log_survival', x_) + return self.n.log_survival(x_) def test_log_survival(): """ @@ -256,9 +246,8 @@ class EntropyH(nn.Cell): super(EntropyH, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self): - return self.n('entropy') + return self.n.entropy() def test_entropy(): """ @@ -279,12 +268,11 @@ class CrossEntropy(nn.Cell): super(CrossEntropy, self).__init__() self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) - @ms_function def construct(self, x_, y_): - entropy = self.n('entropy') - kl_loss = self.n('kl_loss', 'Normal', x_, y_) + entropy = self.n.entropy() + kl_loss = self.n.kl_loss('Normal', x_, y_) h_sum_kl = entropy + kl_loss - cross_entropy = self.n('cross_entropy', 'Normal', x_, y_) + cross_entropy = self.n.cross_entropy('Normal', x_, y_) return h_sum_kl - cross_entropy def test_cross_entropy(): @@ -297,3 +285,40 @@ def test_cross_entropy(): diff = cross_entropy(mean, sd) tol = 1e-6 assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() + +class Net(nn.Cell): + """ + Test class: expand single distribution instance to multiple graphs + by specifying the attributes. + """ + + def __init__(self): + super(Net, self).__init__() + self.normal = msd.Normal(0., 1., dtype=dtype.float32) + + def construct(self, x_, y_): + kl = self.normal.kl_loss('Normal', x_, y_) + prob = self.normal.prob(kl) + return prob + +def test_multiple_graphs(): + """ + Test multiple graphs case. + """ + prob = Net() + mean_a = np.array([0.0]).astype(np.float32) + sd_a = np.array([1.0]).astype(np.float32) + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + ans = prob(Tensor(mean_b), Tensor(sd_b)) + + diff_log_scale = np.log(sd_a) - np.log(sd_b) + squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) + expect_kl_loss = 0.5 * squared_diff + 0.5 * \ + np.expm1(2 * diff_log_scale) - diff_log_scale + + norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0])) + expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32) + + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expect_prob) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_normal_new_api.py b/tests/st/ops/ascend/test_distribution/test_normal_new_api.py deleted file mode 100644 index 1860ba53c..000000000 --- a/tests/st/ops/ascend/test_distribution/test_normal_new_api.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""test cases for new api of normal distribution""" -import numpy as np -from scipy import stats -import mindspore.nn as nn -import mindspore.nn.probability.distribution as msd -from mindspore import dtype -from mindspore import Tensor -import mindspore.context as context - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - -class Net(nn.Cell): - """ - Test class: new api of normal distribution. - """ - - def __init__(self): - super(Net, self).__init__() - self.normal = msd.Normal(0., 1., dtype=dtype.float32) - - def construct(self, x_, y_): - kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_) - prob = self.normal.prob('prob', kl) - return prob - - -def test_new_api(): - """ - Test new api of normal distribution. - """ - prob = Net() - mean_a = np.array([0.0]).astype(np.float32) - sd_a = np.array([1.0]).astype(np.float32) - mean_b = np.array([1.0]).astype(np.float32) - sd_b = np.array([1.0]).astype(np.float32) - ans = prob(Tensor(mean_b), Tensor(sd_b)) - - diff_log_scale = np.log(sd_a) - np.log(sd_b) - squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) - expect_kl_loss = 0.5 * squared_diff + 0.5 * \ - np.expm1(2 * diff_log_scale) - diff_log_scale - - norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0])) - expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32) - - tol = 1e-6 - assert (np.abs(ans.asnumpy() - expect_prob) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_uniform.py b/tests/st/ops/ascend/test_distribution/test_uniform.py index 357ad5f04..5e54f2cdc 100644 --- a/tests/st/ops/ascend/test_distribution/test_uniform.py +++ b/tests/st/ops/ascend/test_distribution/test_uniform.py @@ -19,7 +19,6 @@ import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -32,9 +31,8 @@ class Prob(nn.Cell): super(Prob, self).__init__() self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.u('prob', x_) + return self.u.prob(x_) def test_pdf(): """ @@ -56,9 +54,8 @@ class LogProb(nn.Cell): super(LogProb, self).__init__() self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.u('log_prob', x_) + return self.u.log_prob(x_) def test_log_likelihood(): """ @@ -80,9 +77,8 @@ class KL(nn.Cell): super(KL, self).__init__() self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32) - @ms_function def construct(self, x_, y_): - return self.u('kl_loss', 'Uniform', x_, y_) + return self.u.kl_loss('Uniform', x_, y_) def test_kl_loss(): """ @@ -106,9 +102,8 @@ class Basics(nn.Cell): super(Basics, self).__init__() self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32) - @ms_function def construct(self): - return self.u('mean'), self.u('sd') + return self.u.mean(), self.u.sd() def test_basics(): """ @@ -131,9 +126,8 @@ class Sampling(nn.Cell): self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32) self.shape = shape - @ms_function def construct(self, low=None, high=None): - return self.u('sample', self.shape, low, high) + return self.u.sample(self.shape, low, high) def test_sample(): """ @@ -155,9 +149,8 @@ class CDF(nn.Cell): super(CDF, self).__init__() self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.u('cdf', x_) + return self.u.cdf(x_) def test_cdf(): """ @@ -179,9 +172,8 @@ class LogCDF(nn.Cell): super(LogCDF, self).__init__() self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.u('log_cdf', x_) + return self.u.log_cdf(x_) class SF(nn.Cell): """ @@ -191,9 +183,8 @@ class SF(nn.Cell): super(SF, self).__init__() self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.u('survival_function', x_) + return self.u.survival_function(x_) class LogSF(nn.Cell): """ @@ -203,9 +194,8 @@ class LogSF(nn.Cell): super(LogSF, self).__init__() self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.u('log_survival', x_) + return self.u.log_survival(x_) class EntropyH(nn.Cell): """ @@ -215,9 +205,8 @@ class EntropyH(nn.Cell): super(EntropyH, self).__init__() self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32) - @ms_function def construct(self): - return self.u('entropy') + return self.u.entropy() def test_entropy(): """ @@ -238,12 +227,11 @@ class CrossEntropy(nn.Cell): super(CrossEntropy, self).__init__() self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32) - @ms_function def construct(self, x_, y_): - entropy = self.u('entropy') - kl_loss = self.u('kl_loss', 'Uniform', x_, y_) + entropy = self.u.entropy() + kl_loss = self.u.kl_loss('Uniform', x_, y_) h_sum_kl = entropy + kl_loss - cross_entropy = self.u('cross_entropy', 'Uniform', x_, y_) + cross_entropy = self.u.cross_entropy('Uniform', x_, y_) return h_sum_kl - cross_entropy def test_log_cdf(): diff --git a/tests/ut/python/nn/distribution/test_bernoulli.py b/tests/ut/python/nn/distribution/test_bernoulli.py index 3ddbe9bc5..d34455dbe 100644 --- a/tests/ut/python/nn/distribution/test_bernoulli.py +++ b/tests/ut/python/nn/distribution/test_bernoulli.py @@ -49,12 +49,12 @@ class BernoulliProb(nn.Cell): self.b = msd.Bernoulli(0.5, dtype=dtype.int32) def construct(self, value): - prob = self.b('prob', value) - log_prob = self.b('log_prob', value) - cdf = self.b('cdf', value) - log_cdf = self.b('log_cdf', value) - sf = self.b('survival_function', value) - log_sf = self.b('log_survival', value) + prob = self.b.prob(value) + log_prob = self.b.log_prob(value) + cdf = self.b.cdf(value) + log_cdf = self.b.log_cdf(value) + sf = self.b.survival_function(value) + log_sf = self.b.log_survival(value) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_bernoulli_prob(): @@ -75,12 +75,12 @@ class BernoulliProb1(nn.Cell): self.b = msd.Bernoulli(dtype=dtype.int32) def construct(self, value, probs): - prob = self.b('prob', value, probs) - log_prob = self.b('log_prob', value, probs) - cdf = self.b('cdf', value, probs) - log_cdf = self.b('log_cdf', value, probs) - sf = self.b('survival_function', value, probs) - log_sf = self.b('log_survival', value, probs) + prob = self.b.prob(value, probs) + log_prob = self.b.log_prob(value, probs) + cdf = self.b.cdf(value, probs) + log_cdf = self.b.log_cdf(value, probs) + sf = self.b.survival_function(value, probs) + log_sf = self.b.log_survival(value, probs) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_bernoulli_prob1(): @@ -103,8 +103,8 @@ class BernoulliKl(nn.Cell): self.b2 = msd.Bernoulli(dtype=dtype.int32) def construct(self, probs_b, probs_a): - kl1 = self.b1('kl_loss', 'Bernoulli', probs_b) - kl2 = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) + kl1 = self.b1.kl_loss('Bernoulli', probs_b) + kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a) return kl1 + kl2 def test_kl(): @@ -127,8 +127,8 @@ class BernoulliCrossEntropy(nn.Cell): self.b2 = msd.Bernoulli(dtype=dtype.int32) def construct(self, probs_b, probs_a): - h1 = self.b1('cross_entropy', 'Bernoulli', probs_b) - h2 = self.b2('cross_entropy', 'Bernoulli', probs_b, probs_a) + h1 = self.b1.cross_entropy('Bernoulli', probs_b) + h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a) return h1 + h2 def test_cross_entropy(): @@ -150,11 +150,11 @@ class BernoulliBasics(nn.Cell): self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) def construct(self): - mean = self.b('mean') - sd = self.b('sd') - var = self.b('var') - mode = self.b('mode') - entropy = self.b('entropy') + mean = self.b.mean() + sd = self.b.sd() + var = self.b.var() + mode = self.b.mode() + entropy = self.b.entropy() return mean + sd + var + mode + entropy def test_bascis(): @@ -164,3 +164,28 @@ def test_bascis(): net = BernoulliBasics() ans = net() assert isinstance(ans, Tensor) + +class BernoulliConstruct(nn.Cell): + """ + Bernoulli distribution: going through construct. + """ + def __init__(self): + super(BernoulliConstruct, self).__init__() + self.b = msd.Bernoulli(0.5, dtype=dtype.int32) + self.b1 = msd.Bernoulli(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.b('prob', value) + prob1 = self.b('prob', value, probs) + prob2 = self.b1('prob', value, probs) + return prob + prob1 + prob2 + +def test_bernoulli_construct(): + """ + Test probability function going through construct. + """ + net = BernoulliConstruct() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_exponential.py b/tests/ut/python/nn/distribution/test_exponential.py index 280ed70e5..43aa42827 100644 --- a/tests/ut/python/nn/distribution/test_exponential.py +++ b/tests/ut/python/nn/distribution/test_exponential.py @@ -50,12 +50,12 @@ class ExponentialProb(nn.Cell): self.e = msd.Exponential(0.5, dtype=dtype.float32) def construct(self, value): - prob = self.e('prob', value) - log_prob = self.e('log_prob', value) - cdf = self.e('cdf', value) - log_cdf = self.e('log_cdf', value) - sf = self.e('survival_function', value) - log_sf = self.e('log_survival', value) + prob = self.e.prob(value) + log_prob = self.e.log_prob(value) + cdf = self.e.cdf(value) + log_cdf = self.e.log_cdf(value) + sf = self.e.survival_function(value) + log_sf = self.e.log_survival(value) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_exponential_prob(): @@ -76,12 +76,12 @@ class ExponentialProb1(nn.Cell): self.e = msd.Exponential(dtype=dtype.float32) def construct(self, value, rate): - prob = self.e('prob', value, rate) - log_prob = self.e('log_prob', value, rate) - cdf = self.e('cdf', value, rate) - log_cdf = self.e('log_cdf', value, rate) - sf = self.e('survival_function', value, rate) - log_sf = self.e('log_survival', value, rate) + prob = self.e.prob(value, rate) + log_prob = self.e.log_prob(value, rate) + cdf = self.e.cdf(value, rate) + log_cdf = self.e.log_cdf(value, rate) + sf = self.e.survival_function(value, rate) + log_sf = self.e.log_survival(value, rate) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_exponential_prob1(): @@ -104,8 +104,8 @@ class ExponentialKl(nn.Cell): self.e2 = msd.Exponential(dtype=dtype.float32) def construct(self, rate_b, rate_a): - kl1 = self.e1('kl_loss', 'Exponential', rate_b) - kl2 = self.e2('kl_loss', 'Exponential', rate_b, rate_a) + kl1 = self.e1.kl_loss('Exponential', rate_b) + kl2 = self.e2.kl_loss('Exponential', rate_b, rate_a) return kl1 + kl2 def test_kl(): @@ -128,8 +128,8 @@ class ExponentialCrossEntropy(nn.Cell): self.e2 = msd.Exponential(dtype=dtype.float32) def construct(self, rate_b, rate_a): - h1 = self.e1('cross_entropy', 'Exponential', rate_b) - h2 = self.e2('cross_entropy', 'Exponential', rate_b, rate_a) + h1 = self.e1.cross_entropy('Exponential', rate_b) + h2 = self.e2.cross_entropy('Exponential', rate_b, rate_a) return h1 + h2 def test_cross_entropy(): @@ -151,11 +151,11 @@ class ExponentialBasics(nn.Cell): self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32) def construct(self): - mean = self.e('mean') - sd = self.e('sd') - var = self.e('var') - mode = self.e('mode') - entropy = self.e('entropy') + mean = self.e.mean() + sd = self.e.sd() + var = self.e.var() + mode = self.e.mode() + entropy = self.e.entropy() return mean + sd + var + mode + entropy def test_bascis(): @@ -165,3 +165,29 @@ def test_bascis(): net = ExponentialBasics() ans = net() assert isinstance(ans, Tensor) + + +class ExpConstruct(nn.Cell): + """ + Exponential distribution: going through construct. + """ + def __init__(self): + super(ExpConstruct, self).__init__() + self.e = msd.Exponential(0.5, dtype=dtype.float32) + self.e1 = msd.Exponential(dtype=dtype.float32) + + def construct(self, value, rate): + prob = self.e('prob', value) + prob1 = self.e('prob', value, rate) + prob2 = self.e1('prob', value, rate) + return prob + prob1 + prob2 + +def test_exp_construct(): + """ + Test probability function going through construct. + """ + net = ExpConstruct() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_geometric.py b/tests/ut/python/nn/distribution/test_geometric.py index c6cdd6516..b705aae78 100644 --- a/tests/ut/python/nn/distribution/test_geometric.py +++ b/tests/ut/python/nn/distribution/test_geometric.py @@ -50,12 +50,12 @@ class GeometricProb(nn.Cell): self.g = msd.Geometric(0.5, dtype=dtype.int32) def construct(self, value): - prob = self.g('prob', value) - log_prob = self.g('log_prob', value) - cdf = self.g('cdf', value) - log_cdf = self.g('log_cdf', value) - sf = self.g('survival_function', value) - log_sf = self.g('log_survival', value) + prob = self.g.prob(value) + log_prob = self.g.log_prob(value) + cdf = self.g.cdf(value) + log_cdf = self.g.log_cdf(value) + sf = self.g.survival_function(value) + log_sf = self.g.log_survival(value) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_geometric_prob(): @@ -76,12 +76,12 @@ class GeometricProb1(nn.Cell): self.g = msd.Geometric(dtype=dtype.int32) def construct(self, value, probs): - prob = self.g('prob', value, probs) - log_prob = self.g('log_prob', value, probs) - cdf = self.g('cdf', value, probs) - log_cdf = self.g('log_cdf', value, probs) - sf = self.g('survival_function', value, probs) - log_sf = self.g('log_survival', value, probs) + prob = self.g.prob(value, probs) + log_prob = self.g.log_prob(value, probs) + cdf = self.g.cdf(value, probs) + log_cdf = self.g.log_cdf(value, probs) + sf = self.g.survival_function(value, probs) + log_sf = self.g.log_survival(value, probs) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_geometric_prob1(): @@ -105,8 +105,8 @@ class GeometricKl(nn.Cell): self.g2 = msd.Geometric(dtype=dtype.int32) def construct(self, probs_b, probs_a): - kl1 = self.g1('kl_loss', 'Geometric', probs_b) - kl2 = self.g2('kl_loss', 'Geometric', probs_b, probs_a) + kl1 = self.g1.kl_loss('Geometric', probs_b) + kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a) return kl1 + kl2 def test_kl(): @@ -129,8 +129,8 @@ class GeometricCrossEntropy(nn.Cell): self.g2 = msd.Geometric(dtype=dtype.int32) def construct(self, probs_b, probs_a): - h1 = self.g1('cross_entropy', 'Geometric', probs_b) - h2 = self.g2('cross_entropy', 'Geometric', probs_b, probs_a) + h1 = self.g1.cross_entropy('Geometric', probs_b) + h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a) return h1 + h2 def test_cross_entropy(): @@ -152,11 +152,11 @@ class GeometricBasics(nn.Cell): self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32) def construct(self): - mean = self.g('mean') - sd = self.g('sd') - var = self.g('var') - mode = self.g('mode') - entropy = self.g('entropy') + mean = self.g.mean() + sd = self.g.sd() + var = self.g.var() + mode = self.g.mode() + entropy = self.g.entropy() return mean + sd + var + mode + entropy def test_bascis(): @@ -166,3 +166,29 @@ def test_bascis(): net = GeometricBasics() ans = net() assert isinstance(ans, Tensor) + + +class GeoConstruct(nn.Cell): + """ + Bernoulli distribution: going through construct. + """ + def __init__(self): + super(GeoConstruct, self).__init__() + self.g = msd.Geometric(0.5, dtype=dtype.int32) + self.g1 = msd.Geometric(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.g('prob', value) + prob1 = self.g('prob', value, probs) + prob2 = self.g1('prob', value, probs) + return prob + prob1 + prob2 + +def test_geo_construct(): + """ + Test probability function going through construct. + """ + net = GeoConstruct() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_normal.py b/tests/ut/python/nn/distribution/test_normal.py index 559855ee4..f569aa67a 100644 --- a/tests/ut/python/nn/distribution/test_normal.py +++ b/tests/ut/python/nn/distribution/test_normal.py @@ -50,12 +50,12 @@ class NormalProb(nn.Cell): self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32) def construct(self, value): - prob = self.normal('prob', value) - log_prob = self.normal('log_prob', value) - cdf = self.normal('cdf', value) - log_cdf = self.normal('log_cdf', value) - sf = self.normal('survival_function', value) - log_sf = self.normal('log_survival', value) + prob = self.normal.prob(value) + log_prob = self.normal.log_prob(value) + cdf = self.normal.cdf(value) + log_cdf = self.normal.log_cdf(value) + sf = self.normal.survival_function(value) + log_sf = self.normal.log_survival(value) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_normal_prob(): @@ -77,12 +77,12 @@ class NormalProb1(nn.Cell): self.normal = msd.Normal() def construct(self, value, mean, sd): - prob = self.normal('prob', value, mean, sd) - log_prob = self.normal('log_prob', value, mean, sd) - cdf = self.normal('cdf', value, mean, sd) - log_cdf = self.normal('log_cdf', value, mean, sd) - sf = self.normal('survival_function', value, mean, sd) - log_sf = self.normal('log_survival', value, mean, sd) + prob = self.normal.prob(value, mean, sd) + log_prob = self.normal.log_prob(value, mean, sd) + cdf = self.normal.cdf(value, mean, sd) + log_cdf = self.normal.log_cdf(value, mean, sd) + sf = self.normal.survival_function(value, mean, sd) + log_sf = self.normal.log_survival(value, mean, sd) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_normal_prob1(): @@ -106,8 +106,8 @@ class NormalKl(nn.Cell): self.n2 = msd.Normal(dtype=dtype.float32) def construct(self, mean_b, sd_b, mean_a, sd_a): - kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b) - kl2 = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) + kl1 = self.n1.kl_loss('Normal', mean_b, sd_b) + kl2 = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a) return kl1 + kl2 def test_kl(): @@ -132,8 +132,8 @@ class NormalCrossEntropy(nn.Cell): self.n2 = msd.Normal(dtype=dtype.float32) def construct(self, mean_b, sd_b, mean_a, sd_a): - h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b) - h2 = self.n2('cross_entropy', 'Normal', mean_b, sd_b, mean_a, sd_a) + h1 = self.n1.cross_entropy('Normal', mean_b, sd_b) + h2 = self.n2.cross_entropy('Normal', mean_b, sd_b, mean_a, sd_a) return h1 + h2 def test_cross_entropy(): @@ -157,10 +157,10 @@ class NormalBasics(nn.Cell): self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32) def construct(self): - mean = self.n('mean') - sd = self.n('sd') - mode = self.n('mode') - entropy = self.n('entropy') + mean = self.n.mean() + sd = self.n.sd() + mode = self.n.mode() + entropy = self.n.entropy() return mean + sd + mode + entropy def test_bascis(): @@ -170,3 +170,30 @@ def test_bascis(): net = NormalBasics() ans = net() assert isinstance(ans, Tensor) + + +class NormalConstruct(nn.Cell): + """ + Normal distribution: going through construct. + """ + def __init__(self): + super(NormalConstruct, self).__init__() + self.normal = msd.Normal(3.0, 4.0) + self.normal1 = msd.Normal() + + def construct(self, value, mean, sd): + prob = self.normal('prob', value) + prob1 = self.normal('prob', value, mean, sd) + prob2 = self.normal1('prob', value, mean, sd) + return prob + prob1 + prob2 + +def test_normal_construct(): + """ + Test probability function going through construct. + """ + net = NormalConstruct() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + ans = net(value, mean, sd) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_uniform.py b/tests/ut/python/nn/distribution/test_uniform.py index 2cc91f016..a631998e8 100644 --- a/tests/ut/python/nn/distribution/test_uniform.py +++ b/tests/ut/python/nn/distribution/test_uniform.py @@ -60,12 +60,12 @@ class UniformProb(nn.Cell): self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) def construct(self, value): - prob = self.u('prob', value) - log_prob = self.u('log_prob', value) - cdf = self.u('cdf', value) - log_cdf = self.u('log_cdf', value) - sf = self.u('survival_function', value) - log_sf = self.u('log_survival', value) + prob = self.u.prob(value) + log_prob = self.u.log_prob(value) + cdf = self.u.cdf(value) + log_cdf = self.u.log_cdf(value) + sf = self.u.survival_function(value) + log_sf = self.u.log_survival(value) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_uniform_prob(): @@ -86,12 +86,12 @@ class UniformProb1(nn.Cell): self.u = msd.Uniform(dtype=dtype.float32) def construct(self, value, low, high): - prob = self.u('prob', value, low, high) - log_prob = self.u('log_prob', value, low, high) - cdf = self.u('cdf', value, low, high) - log_cdf = self.u('log_cdf', value, low, high) - sf = self.u('survival_function', value, low, high) - log_sf = self.u('log_survival', value, low, high) + prob = self.u.prob(value, low, high) + log_prob = self.u.log_prob(value, low, high) + cdf = self.u.cdf(value, low, high) + log_cdf = self.u.log_cdf(value, low, high) + sf = self.u.survival_function(value, low, high) + log_sf = self.u.log_survival(value, low, high) return prob + log_prob + cdf + log_cdf + sf + log_sf def test_uniform_prob1(): @@ -115,8 +115,8 @@ class UniformKl(nn.Cell): self.u2 = msd.Uniform(dtype=dtype.float32) def construct(self, low_b, high_b, low_a, high_a): - kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b) - kl2 = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) + kl1 = self.u1.kl_loss('Uniform', low_b, high_b) + kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a) return kl1 + kl2 def test_kl(): @@ -141,8 +141,8 @@ class UniformCrossEntropy(nn.Cell): self.u2 = msd.Uniform(dtype=dtype.float32) def construct(self, low_b, high_b, low_a, high_a): - h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b) - h2 = self.u2('cross_entropy', 'Uniform', low_b, high_b, low_a, high_a) + h1 = self.u1.cross_entropy('Uniform', low_b, high_b) + h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a) return h1 + h2 def test_cross_entropy(): @@ -166,10 +166,10 @@ class UniformBasics(nn.Cell): self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) def construct(self): - mean = self.u('mean') - sd = self.u('sd') - var = self.u('var') - entropy = self.u('entropy') + mean = self.u.mean() + sd = self.u.sd() + var = self.u.var() + entropy = self.u.entropy() return mean + sd + var + entropy def test_bascis(): @@ -179,3 +179,30 @@ def test_bascis(): net = UniformBasics() ans = net() assert isinstance(ans, Tensor) + + +class UniConstruct(nn.Cell): + """ + Unifrom distribution: going through construct. + """ + def __init__(self): + super(UniConstruct, self).__init__() + self.u = msd.Uniform(-4.0, 4.0) + self.u1 = msd.Uniform() + + def construct(self, value, low, high): + prob = self.u('prob', value) + prob1 = self.u('prob', value, low, high) + prob2 = self.u1('prob', value, low, high) + return prob + prob1 + prob2 + +def test_uniform_construct(): + """ + Test probability function going through construct. + """ + net = UniConstruct() + value = Tensor([-5.0, 0.0, 1.0, 5.0], dtype=dtype.float32) + low = Tensor([-1.0], dtype=dtype.float32) + high = Tensor([1.0], dtype=dtype.float32) + ans = net(value, low, high) + assert isinstance(ans, Tensor) -- GitLab