提交 e87e1fc6 编写于 作者: X Xun Deng

changed distribution api

上级 6945eb28
......@@ -34,55 +34,56 @@ class Bernoulli(Distribution):
Examples:
>>> # To initialize a Bernoulli distribution of prob 0.5
>>> n = nn.Bernoulli(0.5, dtype=mstype.int32)
>>> import mindspore.nn.probability.distribution as msd
>>> b = msd.Bernoulli(0.5, dtype=mstype.int32)
>>>
>>> # The following creates two independent Bernoulli distributions
>>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32)
>>> b = msd.Bernoulli([0.5, 0.5], dtype=mstype.int32)
>>>
>>> # A Bernoulli distribution can be initilized without arguments
>>> # In this case, probs must be passed in through construct.
>>> n = nn.Bernoulli(dtype=mstype.int32)
>>> # In this case, probs must be passed in through args during function calls.
>>> b = msd.Bernoulli(dtype=mstype.int32)
>>>
>>> # To use Bernoulli distribution in a network
>>> # To use Bernoulli in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.b1 = nn.Bernoulli(0.5, dtype=mstype.int32)
>>> self.b2 = nn.Bernoulli(dtype=mstype.int32)
>>> self.b1 = msd.Bernoulli(0.5, dtype=mstype.int32)
>>> self.b2 = msd.Bernoulli(dtype=mstype.int32)
>>>
>>> # All the following calls in construct are valid
>>> def construct(self, value, probs_b, probs_a):
>>>
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.b1('prob', value)
>>> ans = self.b1.prob(value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.b1('prob', value, probs_b)
>>> ans = self.b1.prob(value, probs_b)
>>>
>>> # probs must be passed in through construct
>>> ans = self.b2('prob', value, probs_a)
>>> # probs must be passed in during function calls
>>> ans = self.b2.prob(value, probs_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage like 'mean'
>>> # Will return [0.0]
>>> ans = self.b1('mean')
>>> # Will return mean_b
>>> ans = self.b1('mean', probs_b)
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # Will return 0.5
>>> ans = self.b1.mean()
>>> # Will return probs_b
>>> ans = self.b1.mean(probs_b)
>>>
>>> # probs must be passed in through construct
>>> ans = self.b2('mean', probs_a)
>>> # probs must be passed in during function calls
>>> ans = self.b2.mean(probs_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b)
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a)
>>> ans = self.b1.kl_loss('Bernoulli', probs_b)
>>> ans = self.b1.kl_loss('Bernoulli', probs_b, probs_a)
>>>
>>> # Additional probs_a must be passed in through construct
>>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a)
>>> # Additional probs_a must be passed in through
>>> ans = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
>>>
>>> # Sample Usage
>>> ans = self.b1('sample')
>>> ans = self.b1('sample', (2,3))
>>> ans = self.b1('sample', (2,3), probs_b)
>>> ans = self.b2('sample', (2,3), probs_a)
>>> # Sample
>>> ans = self.b1.sample()
>>> ans = self.b1.sample((2,3))
>>> ans = self.b1.sample((2,3), probs_b)
>>> ans = self.b2.sample((2,3), probs_a)
"""
def __init__(self,
......@@ -130,71 +131,61 @@ class Bernoulli(Distribution):
"""
return self._probs
def _mean(self, name='mean', probs1=None):
def _mean(self, probs1=None):
r"""
.. math::
MEAN(B) = probs1
"""
if name == 'mean':
return self.probs if probs1 is None else probs1
return None
return self.probs if probs1 is None else probs1
def _mode(self, name='mode', probs1=None):
def _mode(self, probs1=None):
r"""
.. math::
MODE(B) = 1 if probs1 > 0.5 else = 0
"""
if name == 'mode':
probs1 = self.probs if probs1 is None else probs1
prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0)
comp = self.less(0.5, probs1)
return self.select(comp, ones, zeros)
return None
probs1 = self.probs if probs1 is None else probs1
prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0)
comp = self.less(0.5, probs1)
return self.select(comp, ones, zeros)
def _var(self, name='var', probs1=None):
def _var(self, probs1=None):
r"""
.. math::
VAR(B) = probs1 * probs0
"""
if name in self._variance_functions:
probs1 = self.probs if probs1 is None else probs1
probs0 = 1.0 - probs1
return probs0 * probs1
return None
probs1 = self.probs if probs1 is None else probs1
probs0 = 1.0 - probs1
return probs0 * probs1
def _entropy(self, name='entropy', probs=None):
def _entropy(self, probs=None):
r"""
.. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
"""
if name == 'entropy':
probs1 = self.probs if probs is None else probs
probs0 = 1 - probs1
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
return None
probs1 = self.probs if probs is None else probs
probs0 = 1 - probs1
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
def _cross_entropy(self, name, dist, probs1_b, probs1_a=None):
def _cross_entropy(self, dist, probs1_b, probs1_a=None):
"""
Evaluate cross_entropy between Bernoulli distributions.
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
"""
if name == 'cross_entropy' and dist == 'Bernoulli':
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a)
if dist == 'Bernoulli':
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return None
def _prob(self, name, value, probs=None):
def _prob(self, value, probs=None):
r"""
pmf of Bernoulli distribution.
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability of outcome is 1. Default: self.probs.
......@@ -202,18 +193,15 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
if name in self._prob_functions:
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
return (probs1 * value) + (probs0 * (1.0 - value))
return None
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
return (probs1 * value) + (probs0 * (1.0 - value))
def _cdf(self, name, value, probs=None):
def _cdf(self, value, probs=None):
r"""
cdf of Bernoulli distribution.
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
probs (Tensor): probability of outcome is 1. Default: self.probs.
......@@ -222,25 +210,22 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k <1;
cdf(k) = 1 if k >=1;
"""
if name in self._cdf_survival_functions:
probs1 = self.probs if probs is None else probs
prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
comp_zero = self.less(value, 0.0)
comp_one = self.less(value, 1.0)
zeros = self.fill(prob_type, self.shape(value), 0.0)
ones = self.fill(prob_type, self.shape(value), 1.0)
less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones)
return None
probs1 = self.probs if probs is None else probs
prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
comp_zero = self.less(value, 0.0)
comp_one = self.less(value, 1.0)
zeros = self.fill(prob_type, self.shape(value), 0.0)
ones = self.fill(prob_type, self.shape(value), 1.0)
less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones)
def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
def _kl_loss(self, dist, probs1_b, probs1_a=None):
r"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
......@@ -249,31 +234,28 @@ class Bernoulli(Distribution):
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b})
"""
if name in self._divergence_functions and dist == 'Bernoulli':
if dist == 'Bernoulli':
probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
return None
def _sample(self, name, shape=(), probs=None):
def _sample(self, shape=(), probs=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self.probs.
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
probs1 = self.probs if probs is None else probs
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self.dtype)
return sample
return None
probs1 = self.probs if probs is None else probs
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self.dtype)
return sample
......@@ -27,11 +27,7 @@ class Distribution(Cell):
Note:
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Functions should be called through construct when
used inside a network. Arguments should be passed in through *args
in the form of function name followed by additional arguments.
Functions such as cdf and prob, require a value to be passed in while
functions such as mean and sd do not require arguments other than name.
and _log_prob. Arguments should be passed in through *args.
Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution.
......@@ -73,11 +69,6 @@ class Distribution(Cell):
self._set_log_survival()
self._set_cross_entropy()
self._prob_functions = ('prob', 'log_prob')
self._cdf_survival_functions = ('cdf', 'log_cdf', 'survival_function', 'log_survival')
self._variance_functions = ('var', 'sd')
self._divergence_functions = ('kl_loss', 'cross_entropy')
@property
def name(self):
return self._name
......@@ -185,7 +176,7 @@ class Distribution(Cell):
Evaluate the log probability(pdf or pmf) at the given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_prob(*args)
......@@ -204,7 +195,7 @@ class Distribution(Cell):
Evaluate the probability (pdf or pmf) at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_prob(*args)
......@@ -223,7 +214,7 @@ class Distribution(Cell):
Evaluate the cdf at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_cdf(*args)
......@@ -260,7 +251,7 @@ class Distribution(Cell):
Evaluate the log cdf at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_cdf(*args)
......@@ -279,7 +270,7 @@ class Distribution(Cell):
Evaluate the survival function at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_survival(*args)
......@@ -307,7 +298,7 @@ class Distribution(Cell):
Evaluate the log survival function at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_survival(*args)
......@@ -326,7 +317,7 @@ class Distribution(Cell):
Evaluate the KL divergence, i.e. KL(a||b).
Note:
Args must include name of the function, type of the distribution, parameters of distribution b.
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return self._kl_loss(*args)
......@@ -336,7 +327,7 @@ class Distribution(Cell):
Evaluate the mean.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._mean(*args)
......@@ -345,7 +336,7 @@ class Distribution(Cell):
Evaluate the mode.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._mode(*args)
......@@ -354,7 +345,7 @@ class Distribution(Cell):
Evaluate the standard deviation.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._call_sd(*args)
......@@ -363,7 +354,7 @@ class Distribution(Cell):
Evaluate the variance.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._call_var(*args)
......@@ -390,7 +381,7 @@ class Distribution(Cell):
Evaluate the entropy.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._entropy(*args)
......@@ -399,7 +390,7 @@ class Distribution(Cell):
Evaluate the cross_entropy between distribution a and b.
Note:
Args must include name of the function, type of the distribution, parameters of distribution b.
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return self._call_cross_entropy(*args)
......@@ -421,13 +412,13 @@ class Distribution(Cell):
*args (list): arguments passed in through construct.
Note:
Args must include name of the function.
Shape of the sample and dist_spec_args are optional.
Shape of the sample is default to ().
Dist_spec_args are optional.
"""
return self._sample(*args)
def construct(self, *inputs):
def construct(self, name, *args):
"""
Override construct in Cell.
......@@ -437,35 +428,36 @@ class Distribution(Cell):
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
Args:
*inputs (list): inputs[0] is always the name of the function.
"""
if inputs[0] == 'log_prob':
return self._call_log_prob(*inputs)
if inputs[0] == 'prob':
return self._call_prob(*inputs)
if inputs[0] == 'cdf':
return self._call_cdf(*inputs)
if inputs[0] == 'log_cdf':
return self._call_log_cdf(*inputs)
if inputs[0] == 'survival_function':
return self._call_survival(*inputs)
if inputs[0] == 'log_survival':
return self._call_log_survival(*inputs)
if inputs[0] == 'kl_loss':
return self._kl_loss(*inputs)
if inputs[0] == 'mean':
return self._mean(*inputs)
if inputs[0] == 'mode':
return self._mode(*inputs)
if inputs[0] == 'sd':
return self._call_sd(*inputs)
if inputs[0] == 'var':
return self._call_var(*inputs)
if inputs[0] == 'entropy':
return self._entropy(*inputs)
if inputs[0] == 'cross_entropy':
return self._call_cross_entropy(*inputs)
if inputs[0] == 'sample':
return self._sample(*inputs)
name (str): name of the function.
*args (list): list of arguments needed for the function.
"""
if name == 'log_prob':
return self._call_log_prob(*args)
if name == 'prob':
return self._call_prob(*args)
if name == 'cdf':
return self._call_cdf(*args)
if name == 'log_cdf':
return self._call_log_cdf(*args)
if name == 'survival_function':
return self._call_survival(*args)
if name == 'log_survival':
return self._call_log_survival(*args)
if name == 'kl_loss':
return self._kl_loss(*args)
if name == 'mean':
return self._mean(*args)
if name == 'mode':
return self._mode(*args)
if name == 'sd':
return self._call_sd(*args)
if name == 'var':
return self._call_var(*args)
if name == 'entropy':
return self._entropy(*args)
if name == 'cross_entropy':
return self._call_cross_entropy(*args)
if name == 'sample':
return self._sample(*args)
return None
......@@ -35,55 +35,56 @@ class Exponential(Distribution):
Examples:
>>> # To initialize an Exponential distribution of rate 0.5
>>> n = nn.Exponential(0.5, dtype=mstype.float32)
>>> import mindspore.nn.probability.distribution as msd
>>> e = msd.Exponential(0.5, dtype=mstype.float32)
>>>
>>> # The following creates two independent Exponential distributions
>>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32)
>>> e = msd.Exponential([0.5, 0.5], dtype=mstype.float32)
>>>
>>> # A Exponential distribution can be initilized without arguments
>>> # In this case, rate must be passed in through construct.
>>> n = nn.Exponential(dtype=mstype.float32)
>>> # An Exponential distribution can be initilized without arguments
>>> # In this case, rate must be passed in through args during function calls
>>> e = msd.Exponential(dtype=mstype.float32)
>>>
>>> # To use Exponential distribution in a network
>>> # To use Exponential in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.e1 = nn.Exponential(0.5, dtype=mstype.float32)
>>> self.e2 = nn.Exponential(dtype=mstype.float32)
>>> self.e1 = msd.Exponential(0.5, dtype=mstype.float32)
>>> self.e2 = msd.Exponential(dtype=mstype.float32)
>>>
>>> # All the following calls in construct are valid
>>> def construct(self, value, rate_b, rate_a):
>>>
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.e1('prob', value)
>>> ans = self.e1.prob(value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.e1('prob', value, rate_b)
>>> ans = self.e1.prob(value, rate_b)
>>>
>>> # Rate must be passed in through construct
>>> ans = self.e2('prob', value, rate_a)
>>> # Rate must be passed in during function calls
>>> ans = self.e2.prob(value, rate_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
>>> # Will return [0.0]
>>> ans = self.e1('mean')
>>> # Will return mean_b
>>> ans = self.e1('mean', rate_b)
>>> # Functions 'sd', 'var', 'entropy' have the same usage as'mean'
>>> # Will return 2
>>> ans = self.e1.mean()
>>> # Will return 1 / rate_b
>>> ans = self.e1.mean(rate_b)
>>>
>>> # Rate must be passed in through construct
>>> ans = self.e2('mean', rate_a)
>>> # Rate must be passed in during function calls
>>> ans = self.e2.mean(rate_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.e1('kl_loss', 'Exponential', rate_b)
>>> ans = self.e1('kl_loss', 'Exponential', rate_b, rate_a)
>>> ans = self.e1.kl_loss('Exponential', rate_b)
>>> ans = self.e1.kl_loss('Exponential', rate_b, rate_a)
>>>
>>> # Additional rate must be passed in through construct
>>> ans = self.e2('kl_loss', 'Exponential', rate_b, rate_a)
>>> # Additional rate must be passed in
>>> ans = self.e2.kl_loss('Exponential', rate_b, rate_a)
>>>
>>> # Sample Usage
>>> ans = self.e1('sample')
>>> ans = self.e1('sample', (2,3))
>>> ans = self.e1('sample', (2,3), rate_b)
>>> ans = self.e2('sample', (2,3), rate_a)
>>> # Sample
>>> ans = self.e1.sample()
>>> ans = self.e1.sample((2,3))
>>> ans = self.e1.sample((2,3), rate_b)
>>> ans = self.e2.sample((2,3), rate_a)
"""
def __init__(self,
......@@ -131,67 +132,59 @@ class Exponential(Distribution):
"""
return self._rate
def _mean(self, name='mean', rate=None):
def _mean(self, rate=None):
r"""
.. math::
MEAN(EXP) = \fract{1.0}{\lambda}.
"""
if name == 'mean':
rate = self.rate if rate is None else rate
return 1.0 / rate
return None
rate = self.rate if rate is None else rate
return 1.0 / rate
def _mode(self, name='mode', rate=None):
def _mode(self, rate=None):
r"""
.. math::
MODE(EXP) = 0.
"""
if name == 'mode':
rate = self.rate if rate is None else rate
return self.fill(self.dtype, self.shape(rate), 0.)
return None
rate = self.rate if rate is None else rate
return self.fill(self.dtype, self.shape(rate), 0.)
def _sd(self, name='sd', rate=None):
def _sd(self, rate=None):
r"""
.. math::
sd(EXP) = \fract{1.0}{\lambda}.
"""
if name in self._variance_functions:
rate = self.rate if rate is None else rate
return 1.0 / rate
return None
rate = self.rate if rate is None else rate
return 1.0 / rate
def _entropy(self, name='entropy', rate=None):
def _entropy(self, rate=None):
r"""
.. math::
H(Exp) = 1 - \log(\lambda).
"""
rate = self.rate if rate is None else rate
if name == 'entropy':
return 1.0 - self.log(rate)
return None
return 1.0 - self.log(rate)
def _cross_entropy(self, name, dist, rate_b, rate_a=None):
def _cross_entropy(self, dist, rate_b, rate_a=None):
"""
Evaluate cross_entropy between Exponential distributions.
Args:
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
if name == 'cross_entropy' and dist == 'Exponential':
return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a)
if dist == 'Exponential':
return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
return None
def _prob(self, name, value, rate=None):
def _prob(self, value, rate=None):
r"""
pdf of Exponential distribution.
Args:
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate.
......@@ -201,20 +194,17 @@ class Exponential(Distribution):
.. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
"""
if name in self._prob_functions:
rate = self.rate if rate is None else rate
prob = rate * self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, prob)
return None
rate = self.rate if rate is None else rate
prob = rate * self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, prob)
def _cdf(self, name, value, rate=None):
def _cdf(self, value, rate=None):
r"""
cdf of Exponential distribution.
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate.
......@@ -224,45 +214,40 @@ class Exponential(Distribution):
.. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
"""
if name in self._cdf_survival_functions:
rate = self.rate if rate is None else rate
cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)
return None
rate = self.rate if rate is None else rate
cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)
def _kl_loss(self, name, dist, rate_b, rate_a=None):
def _kl_loss(self, dist, rate_b, rate_a=None):
"""
Evaluate exp-exp kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
if name in self._divergence_functions and dist == 'Exponential':
if dist == 'Exponential':
rate_a = self.rate if rate_a is None else rate_a
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
return None
def _sample(self, name, shape=(), rate=None):
def _sample(self, shape=(), rate=None):
"""
Sampling.
Args:
name (str): name of the function.
shape (tuple): shape of the sample. Default: ().
rate (Tensor): rate of the distribution. Default: self.rate.
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
rate = self.rate if rate is None else rate
minval = self.const(self.minval)
maxval = self.const(1.0)
sample = self.uniform(shape + self.shape(rate), minval, maxval)
return -self.log(sample) / rate
return None
rate = self.rate if rate is None else rate
minval = self.const(self.minval)
maxval = self.const(1.0)
sample = self.uniform(shape + self.shape(rate), minval, maxval)
return -self.log(sample) / rate
......@@ -36,55 +36,56 @@ class Geometric(Distribution):
Examples:
>>> # To initialize a Geometric distribution of prob 0.5
>>> n = nn.Geometric(0.5, dtype=mstype.int32)
>>> import mindspore.nn.probability.distribution as msd
>>> n = msd.Geometric(0.5, dtype=mstype.int32)
>>>
>>> # The following creates two independent Geometric distributions
>>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32)
>>> n = msd.Geometric([0.5, 0.5], dtype=mstype.int32)
>>>
>>> # A Geometric distribution can be initilized without arguments
>>> # In this case, probs must be passed in through construct.
>>> n = nn.Geometric(dtype=mstype.int32)
>>> # In this case, probs must be passed in through args during function calls.
>>> n = msd.Geometric(dtype=mstype.int32)
>>>
>>> # To use Geometric distribution in a network
>>> # To use Geometric in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.g1 = nn.Geometric(0.5, dtype=mstype.int32)
>>> self.g2 = nn.Geometric(dtype=mstype.int32)
>>> self.g1 = msd.Geometric(0.5, dtype=mstype.int32)
>>> self.g2 = msd.Geometric(dtype=mstype.int32)
>>>
>>> # Tthe following calls are valid in construct
>>> def construct(self, value, probs_b, probs_a):
>>>
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.g1('prob', value)
>>> ans = self.g1.prob(value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.g1('prob', value, probs_b)
>>> ans = self.g1.prob(value, probs_b)
>>>
>>> # Probs must be passed in through construct
>>> ans = self.g2('prob', value, probs_a)
>>> # Probs must be passed in during function calls
>>> ans = self.g2.prob(value, probs_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
>>> # Will return [0.0]
>>> ans = self.g1('mean')
>>> # Will return mean_b
>>> ans = self.g1('mean', probs_b)
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # Will return 1.0
>>> ans = self.g1.mean()
>>> # Another possible usage
>>> ans = self.g1.mean(probs_b)
>>>
>>> # Probs must be passed in through construct
>>> ans = self.g2('mean', probs_a)
>>> # Probs must be passed in during function calls
>>> ans = self.g2.mean(probs_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.g1('kl_loss', 'Geometric', probs_b)
>>> ans = self.g1('kl_loss', 'Geometric', probs_b, probs_a)
>>> ans = self.g1.kl_loss('Geometric', probs_b)
>>> ans = self.g1.kl_loss('Geometric', probs_b, probs_a)
>>>
>>> # Additional probs must be passed in through construct
>>> ans = self.g2('kl_loss', 'Geometric', probs_b, probs_a)
>>> # Additional probs must be passed in
>>> ans = self.g2.kl_loss('Geometric', probs_b, probs_a)
>>>
>>> # Sample Usage
>>> ans = self.g1('sample')
>>> ans = self.g1('sample', (2,3))
>>> ans = self.g1('sample', (2,3), probs_b)
>>> ans = self.g2('sample', (2,3), probs_a)
>>> # Sample
>>> ans = self.g1.sample()
>>> ans = self.g1.sample((2,3))
>>> ans = self.g1.sample((2,3), probs_b)
>>> ans = self.g2.sample((2,3), probs_a)
"""
def __init__(self,
......@@ -134,67 +135,57 @@ class Geometric(Distribution):
"""
return self._probs
def _mean(self, name='mean', probs1=None):
def _mean(self, probs1=None):
r"""
.. math::
MEAN(Geo) = \fratc{1 - probs1}{probs1}
"""
if name == 'mean':
probs1 = self.probs if probs1 is None else probs1
return (1. - probs1) / probs1
return None
probs1 = self.probs if probs1 is None else probs1
return (1. - probs1) / probs1
def _mode(self, name='mode', probs1=None):
def _mode(self, probs1=None):
r"""
.. math::
MODE(Geo) = 0
"""
if name == 'mode':
probs1 = self.probs if probs1 is None else probs1
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
return None
probs1 = self.probs if probs1 is None else probs1
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
def _var(self, name='var', probs1=None):
def _var(self, probs1=None):
r"""
.. math::
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
"""
if name in self._variance_functions:
probs1 = self.probs if probs1 is None else probs1
return (1.0 - probs1) / self.sq(probs1)
return None
probs1 = self.probs if probs1 is None else probs1
return (1.0 - probs1) / self.sq(probs1)
def _entropy(self, name='entropy', probs=None):
def _entropy(self, probs=None):
r"""
.. math::
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
"""
if name == 'entropy':
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
return None
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
def _cross_entropy(self, name, dist, probs1_b, probs1_a=None):
def _cross_entropy(self, dist, probs1_b, probs1_a=None):
r"""
Evaluate cross_entropy between Geometric distributions.
Args:
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
"""
if name == 'cross_entropy' and dist == 'Geometric':
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a)
if dist == 'Geometric':
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return None
def _prob(self, name, value, probs=None):
def _prob(self, value, probs=None):
r"""
pmf of Geometric distribution.
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self.probs.
......@@ -202,27 +193,24 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0.
"""
if name in self._prob_functions:
probs1 = self.probs if probs is None else probs
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
value = self.floor(value)
else:
return None
pmf = self.pow((1.0 - probs1), value) * probs1
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, pmf)
return None
probs1 = self.probs if probs is None else probs
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
value = self.floor(value)
else:
return None
pmf = self.pow((1.0 - probs1), value) * probs1
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, pmf)
def _cdf(self, name, value, probs=None):
def _cdf(self, value, probs=None):
r"""
cdf of Geometric distribution.
Args:
name (str): name of the function.
value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self.probs.
......@@ -231,28 +219,26 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0.
"""
if name in self._cdf_survival_functions:
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
value = self.floor(value)
else:
return None
cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)
return None
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
value = self.floor(value)
else:
return None
cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)
def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
def _kl_loss(self, dist, probs1_b, probs1_a=None):
r"""
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
......@@ -260,29 +246,26 @@ class Geometric(Distribution):
.. math::
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
"""
if name in self._divergence_functions and dist == 'Geometric':
if dist == 'Geometric':
probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
return None
def _sample(self, name, shape=(), probs=None):
def _sample(self, shape=(), probs=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probability of success. Default: self.probs.
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
probs = self.probs if probs is None else probs
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
return None
probs = self.probs if probs is None else probs
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
......@@ -17,7 +17,6 @@ import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.context import get_context
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero
......@@ -39,55 +38,56 @@ class Normal(Distribution):
Examples:
>>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
>>> import mindspore.nn.probability.distribution as msd
>>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
>>>
>>> # The following creates two independent Normal distributions
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
>>> n = msd.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
>>>
>>> # A normal distribution can be initilize without arguments
>>> # In this case, mean and sd must be passed in through construct.
>>> n = nn.Normal(dtype=mstype.float32)
>>> # A Normal distribution can be initilize without arguments
>>> # In this case, mean and sd must be passed in through args.
>>> n = msd.Normal(dtype=mstype.float32)
>>>
>>> # To use normal in a network
>>> # To use Normal in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32)
>>> self.n2 = nn.Normal(dtype=mstype.float32)
>>> self.n1 = msd.Nomral(0.0, 1.0, dtype=mstype.float32)
>>> self.n2 = msd.Normal(dtype=mstype.float32)
>>>
>>> # The following calls are valid in construct
>>> def construct(self, value, mean_b, sd_b, mean_a, sd_a):
>>>
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.n1('prob', value)
>>> ans = self.n1.prob(value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.n1('prob', value, mean_b, sd_b)
>>> ans = self.n1.prob(value, mean_b, sd_b)
>>>
>>> # mean and sd must be passed in through construct
>>> ans = self.n2('prob', value, mean_a, sd_a)
>>> # mean and sd must be passed in during function calls
>>> ans = self.n2.prob(value, mean_a, sd_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
>>> # Will return [0.0]
>>> ans = self.n1('mean')
>>> # Will return mean_b
>>> ans = self.n1('mean', mean_b, sd_b)
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # will return [0.0]
>>> ans = self.n1.mean()
>>> # will return mean_b
>>> ans = self.n1.mean(mean_b, sd_b)
>>>
>>> # mean and sd must be passed in through construct
>>> ans = self.n2('mean', mean_a, sd_a)
>>> # mean and sd must be passed during function calls
>>> ans = self.n2.mean(mean_a, sd_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b)
>>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
>>> ans = self.n1.kl_loss('Normal', mean_b, sd_b)
>>> ans = self.n1.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
>>>
>>> # Additional mean and sd must be passed in through construct
>>> ans = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
>>> # Additional mean and sd must be passed
>>> ans = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
>>>
>>> # Sample Usage
>>> ans = self.n1('sample')
>>> ans = self.n1('sample', (2,3))
>>> ans = self.n1('sample', (2,3), mean_b, sd_b)
>>> ans = self.n2('sample', (2,3), mean_a, sd_a)
>>> # Sample
>>> ans = self.n1.sample()
>>> ans = self.n1.sample((2,3))
>>> ans = self.n1.sample((2,3), mean_b, sd_b)
>>> ans = self.n2.sample((2,3), mean_a, sd_a)
"""
def __init__(self,
......@@ -114,7 +114,7 @@ class Normal(Distribution):
self.const = P.ScalarToArray()
self.erf = P.Erf()
self.exp = P.Exp()
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
self.expm1 = self._expm1_by_step
self.fill = P.Fill()
self.log = P.Log()
self.shape = P.Shape()
......@@ -135,67 +135,57 @@ class Normal(Distribution):
"""
return self.exp(x) - 1.0
def _mean(self, name='mean', mean=None, sd=None):
def _mean(self, mean=None, sd=None):
"""
Mean of the distribution.
"""
if name == 'mean':
mean = self._mean_value if mean is None or sd is None else mean
return mean
return None
mean = self._mean_value if mean is None or sd is None else mean
return mean
def _mode(self, name='mode', mean=None, sd=None):
def _mode(self, mean=None, sd=None):
"""
Mode of the distribution.
"""
if name == 'mode':
mean = self._mean_value if mean is None or sd is None else mean
return mean
return None
mean = self._mean_value if mean is None or sd is None else mean
return mean
def _sd(self, name='sd', mean=None, sd=None):
def _sd(self, mean=None, sd=None):
"""
Standard deviation of the distribution.
"""
if name in self._variance_functions:
sd = self._sd_value if mean is None or sd is None else sd
return sd
return None
sd = self._sd_value if mean is None or sd is None else sd
return sd
def _entropy(self, name='entropy', sd=None):
def _entropy(self, sd=None):
r"""
Evaluate entropy.
.. math::
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
"""
if name == 'entropy':
sd = self._sd_value if sd is None else sd
return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
return None
sd = self._sd_value if sd is None else sd
return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
def _cross_entropy(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
r"""
Evaluate cross_entropy between normal distributions.
Args:
name (str): name of the funtion passed in from construct. Should always be "cross_entropy".
dist (str): type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b.
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
"""
if name == 'cross_entropy' and dist == 'Normal':
return self._entropy(sd=sd_a) + self._kl_loss(name, dist, mean_b, sd_b, mean_a, sd_a)
if dist == 'Normal':
return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
return None
def _log_prob(self, name, value, mean=None, sd=None):
def _log_prob(self, value, mean=None, sd=None):
r"""
Evaluate log probability.
Args:
name (str): name of the funtion passed in from construct.
value (Tensor): value to be evaluated.
mean (Tensor): mean of the distribution. Default: self._mean_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
......@@ -203,20 +193,17 @@ class Normal(Distribution):
.. math::
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
if name in self._prob_functions:
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
return unnormalized_log_prob + neg_normalization
return None
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
return unnormalized_log_prob + neg_normalization
def _cdf(self, name, value, mean=None, sd=None):
def _cdf(self, value, mean=None, sd=None):
r"""
Evaluate cdf of given value.
Args:
name (str): name of the funtion passed in from construct. Should always be "cdf".
value (Tensor): value to be evaluated.
mean (Tensor): mean of the distribution. Default: self._mean_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
......@@ -224,20 +211,17 @@ class Normal(Distribution):
.. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
"""
if name in self._cdf_survival_functions:
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted))
return None
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted))
def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
r"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion passed in from construct.
dist (str): type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b.
......@@ -248,7 +232,7 @@ class Normal(Distribution):
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
if name in self._divergence_functions and dist == 'Normal':
if dist == 'Normal':
mean_a = self._mean_value if mean_a is None else mean_a
sd_a = self._sd_value if sd_a is None else sd_a
diff_log_scale = self.log(sd_a) - self.log(sd_b)
......@@ -256,12 +240,11 @@ class Normal(Distribution):
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
return None
def _sample(self, name, shape=(), mean=None, sd=None):
def _sample(self, shape=(), mean=None, sd=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
mean (Tensor): mean of the samples. Default: self._mean_value.
sd (Tensor): standard deviation of the samples. Default: self._sd_value.
......@@ -269,14 +252,12 @@ class Normal(Distribution):
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = mean + sample_norm * sd
return sample
return None
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = mean + sample_norm * sd
return sample
......@@ -35,55 +35,56 @@ class Uniform(Distribution):
Examples:
>>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> import mindspore.nn.probability.distribution as msd
>>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
>>>
>>> # The following creates two independent Uniform distributions
>>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
>>> u = msd.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
>>>
>>> # A Uniform distribution can be initilized without arguments
>>> # In this case, high and low must be passed in through construct.
>>> n = nn.Uniform(dtype=mstype.float32)
>>> # In this case, high and low must be passed in through args during function calls.
>>> u = msd.Uniform(dtype=mstype.float32)
>>>
>>> # To use Uniform in a network
>>> class net(Cell):
>>> def __init__(self)
>>> super(net, self).__init__():
>>> self.u1 = nn.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> self.u2 = nn.Uniform(dtype=mstype.float32)
>>> self.u1 = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> self.u2 = msd.Uniform(dtype=mstype.float32)
>>>
>>> # All the following calls in construct are valid
>>> def construct(self, value, low_b, high_b, low_a, high_a):
>>>
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.u1('prob', value)
>>> ans = self.u1.prob(value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.u1('prob', value, low_b, high_b)
>>> ans = self.u1.prob(value, low_b, high_b)
>>>
>>> # High and low must be passed in through construct
>>> ans = self.u2('prob', value, low_a, high_a)
>>> # High and low must be passed in during function calls
>>> ans = self.u2.prob(value, low_a, high_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
>>> # Will return [0.0]
>>> ans = self.u1('mean')
>>> # Will return low_b
>>> ans = self.u1('mean', low_b, high_b)
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # Will return 0.5
>>> ans = self.u1.mean()
>>> # Will return (low_b + high_b) / 2
>>> ans = self.u1.mean(low_b, high_b)
>>>
>>> # High and low must be passed in through construct
>>> ans = self.u2('mean', low_a, high_a)
>>> # High and low must be passed in during function calls
>>> ans = self.u2.mean(low_a, high_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b)
>>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b, low_a, high_a)
>>> ans = self.u1.kl_loss('Uniform', low_b, high_b)
>>> ans = self.u1.kl_loss('Uniform', low_b, high_b, low_a, high_a)
>>>
>>> # Additional high and low must be passed in through construct
>>> ans = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a)
>>> # Additional high and low must be passed
>>> ans = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
>>>
>>> # Sample Usage
>>> ans = self.u1('sample')
>>> ans = self.u1('sample', (2,3))
>>> ans = self.u1('sample', (2,3), low_b, high_b)
>>> ans = self.u2('sample', (2,3), low_a, high_a)
>>> # Sample
>>> ans = self.u1.sample()
>>> ans = self.u1.sample((2,3))
>>> ans = self.u1.sample((2,3), low_b, high_b)
>>> ans = self.u2.sample((2,3), low_a, high_a)
"""
def __init__(self,
......@@ -142,73 +143,64 @@ class Uniform(Distribution):
"""
return self._high
def _range(self, name='range', low=None, high=None):
def _range(self, low=None, high=None):
r"""
Return the range of the distribution.
.. math::
range(U) = high -low
"""
if name == 'range':
low = self.low if low is None else low
high = self.high if high is None else high
return high - low
return None
low = self.low if low is None else low
high = self.high if high is None else high
return high - low
def _mean(self, name='mean', low=None, high=None):
def _mean(self, low=None, high=None):
r"""
.. math::
MEAN(U) = \fract{low + high}{2}.
"""
if name == 'mean':
low = self.low if low is None else low
high = self.high if high is None else high
return (low + high) / 2.
return None
low = self.low if low is None else low
high = self.high if high is None else high
return (low + high) / 2.
def _var(self, name='var', low=None, high=None):
def _var(self, low=None, high=None):
r"""
.. math::
VAR(U) = \fract{(high -low) ^ 2}{12}.
"""
if name in self._variance_functions:
low = self.low if low is None else low
high = self.high if high is None else high
return self.sq(high - low) / 12.0
return None
low = self.low if low is None else low
high = self.high if high is None else high
return self.sq(high - low) / 12.0
def _entropy(self, name='entropy', low=None, high=None):
def _entropy(self, low=None, high=None):
r"""
.. math::
H(U) = \log(high - low).
"""
if name == 'entropy':
low = self.low if low is None else low
high = self.high if high is None else high
return self.log(high - low)
return None
low = self.low if low is None else low
high = self.high if high is None else high
return self.log(high - low)
def _cross_entropy(self, name, dist, low_b, high_b, low_a=None, high_a=None):
def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
"""
Evaluate cross_entropy between Uniform distributoins.
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
if name == 'cross_entropy' and dist == 'Uniform':
return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a)
if dist == 'Uniform':
return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
return None
def _prob(self, name, value, low=None, high=None):
def _prob(self, value, low=None, high=None):
r"""
pdf of Uniform distribution.
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.
......@@ -218,32 +210,29 @@ class Uniform(Distribution):
pdf(x) = \fract{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high;
"""
if name in self._prob_functions:
low = self.low if low is None else low
high = self.high if high is None else high
ones = self.fill(self.dtype, self.shape(value), 1.0)
prob = ones / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
comp_lo = self.less(value, low)
comp_hi = self.lessequal(value, high)
less_than_low = self.select(comp_lo, zeros, prob)
return self.select(comp_hi, less_than_low, zeros)
return None
low = self.low if low is None else low
high = self.high if high is None else high
ones = self.fill(self.dtype, self.shape(value), 1.0)
prob = ones / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
comp_lo = self.less(value, low)
comp_hi = self.lessequal(value, high)
less_than_low = self.select(comp_lo, zeros, prob)
return self.select(comp_hi, less_than_low, zeros)
def _kl_loss(self, name, dist, low_b, high_b, low_a=None, high_a=None):
def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None):
"""
Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
if name in self._divergence_functions and dist == 'Uniform':
if dist == 'Uniform':
low_a = self.low if low_a is None else low_a
high_a = self.high if high_a is None else high_a
kl = self.log(high_b - low_b) / self.log(high_a - low_a)
......@@ -251,12 +240,11 @@ class Uniform(Distribution):
return self.select(comp, kl, self.log(self.zeroslike(kl)))
return None
def _cdf(self, name, value, low=None, high=None):
def _cdf(self, value, low=None, high=None):
r"""
cdf of Uniform distribution.
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.
......@@ -266,25 +254,22 @@ class Uniform(Distribution):
cdf(x) = \fract{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high;
"""
if name in self._cdf_survival_functions:
low = self.low if low is None else low
high = self.high if high is None else high
prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
comp_lo = self.less(value, low)
comp_hi = self.less(value, high)
less_than_low = self.select(comp_lo, zeros, prob)
return self.select(comp_hi, less_than_low, ones)
return None
low = self.low if low is None else low
high = self.high if high is None else high
prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
comp_lo = self.less(value, low)
comp_hi = self.less(value, high)
less_than_low = self.select(comp_lo, zeros, prob)
return self.select(comp_hi, less_than_low, ones)
def _sample(self, name, shape=(), low=None, high=None):
def _sample(self, shape=(), low=None, high=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.
......@@ -292,13 +277,11 @@ class Uniform(Distribution):
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
low = self.low if low is None else low
high = self.high if high is None else high
broadcast_shape = self.shape(low + high)
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
sample = (high - low) * sample_uniform + low
return sample
return None
low = self.low if low is None else low
high = self.high if high is None else high
broadcast_shape = self.shape(low + high)
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
sample = (high - low) * sample_uniform + low
return sample
......@@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
......@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('prob', x_)
return self.b.prob(x_)
def test_pmf():
"""
......@@ -57,9 +55,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('log_prob', x_)
return self.b.log_prob(x_)
def test_log_likelihood():
"""
......@@ -81,9 +78,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('kl_loss', 'Bernoulli', x_)
return self.b.kl_loss('Bernoulli', x_)
def test_kl_loss():
"""
......@@ -107,9 +103,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
@ms_function
def construct(self):
return self.b('mean'), self.b('sd'), self.b('mode')
return self.b.mean(), self.b.sd(), self.b.mode()
def test_basics():
"""
......@@ -134,9 +129,8 @@ class Sampling(nn.Cell):
self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape
@ms_function
def construct(self, probs=None):
return self.b('sample', self.shape, probs)
return self.b.sample(self.shape, probs)
def test_sample():
"""
......@@ -155,9 +149,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('cdf', x_)
return self.b.cdf(x_)
def test_cdf():
"""
......@@ -171,7 +164,6 @@ def test_cdf():
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
class LogCDF(nn.Cell):
"""
Test class: log cdf of bernoulli distributions.
......@@ -180,9 +172,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('log_cdf', x_)
return self.b.log_cdf(x_)
def test_logcdf():
"""
......@@ -205,9 +196,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('survival_function', x_)
return self.b.survival_function(x_)
def test_survival():
"""
......@@ -230,9 +220,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('log_survival', x_)
return self.b.log_survival(x_)
def test_log_survival():
"""
......@@ -254,9 +243,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self):
return self.b('entropy')
return self.b.entropy()
def test_entropy():
"""
......@@ -277,12 +265,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
entropy = self.b('entropy')
kl_loss = self.b('kl_loss', 'Bernoulli', x_)
entropy = self.b.entropy()
kl_loss = self.b.kl_loss('Bernoulli', x_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.b('cross_entropy', 'Bernoulli', x_)
cross_entropy = self.b.cross_entropy('Bernoulli', x_)
return h_sum_kl - cross_entropy
def test_cross_entropy():
......
......@@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
......@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('prob', x_)
return self.e.prob(x_)
def test_pdf():
"""
......@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('log_prob', x_)
return self.e.log_prob(x_)
def test_log_likelihood():
"""
......@@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.e = msd.Exponential([1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('kl_loss', 'Exponential', x_)
return self.e.kl_loss('Exponential', x_)
def test_kl_loss():
"""
......@@ -104,9 +100,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.e = msd.Exponential([0.5], dtype=dtype.float32)
@ms_function
def construct(self):
return self.e('mean'), self.e('sd'), self.e('mode')
return self.e.mean(), self.e.sd(), self.e.mode()
def test_basics():
"""
......@@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, rate=None):
return self.e('sample', self.shape, rate)
return self.e.sample(self.shape, rate)
def test_sample():
"""
......@@ -154,9 +148,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('cdf', x_)
return self.e.cdf(x_)
def test_cdf():
"""
......@@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('log_cdf', x_)
return self.e.log_cdf(x_)
def test_log_cdf():
"""
......@@ -202,9 +194,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('survival_function', x_)
return self.e.survival_function(x_)
def test_survival():
"""
......@@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('log_survival', x_)
return self.e.log_survival(x_)
def test_log_survival():
"""
......@@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self):
return self.e('entropy')
return self.e.entropy()
def test_entropy():
"""
......@@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.e = msd.Exponential([1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
entropy = self.e('entropy')
kl_loss = self.e('kl_loss', 'Exponential', x_)
entropy = self.e.entropy()
kl_loss = self.e.kl_loss('Exponential', x_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.e('cross_entropy', 'Exponential', x_)
cross_entropy = self.e.cross_entropy('Exponential', x_)
return h_sum_kl - cross_entropy
def test_cross_entropy():
......
......@@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
......@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('prob', x_)
return self.g.prob(x_)
def test_pmf():
"""
......@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('log_prob', x_)
return self.g.log_prob(x_)
def test_log_likelihood():
"""
......@@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('kl_loss', 'Geometric', x_)
return self.g.kl_loss('Geometric', x_)
def test_kl_loss():
"""
......@@ -106,9 +102,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
@ms_function
def construct(self):
return self.g('mean'), self.g('sd'), self.g('mode')
return self.g.mean(), self.g.sd(), self.g.mode()
def test_basics():
"""
......@@ -133,9 +128,8 @@ class Sampling(nn.Cell):
self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape
@ms_function
def construct(self, probs=None):
return self.g('sample', self.shape, probs)
return self.g.sample(self.shape, probs)
def test_sample():
"""
......@@ -154,9 +148,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('cdf', x_)
return self.g.cdf(x_)
def test_cdf():
"""
......@@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('log_cdf', x_)
return self.g.log_cdf(x_)
def test_logcdf():
"""
......@@ -202,9 +194,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('survival_function', x_)
return self.g.survival_function(x_)
def test_survival():
"""
......@@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('log_survival', x_)
return self.g.log_survival(x_)
def test_log_survival():
"""
......@@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self):
return self.g('entropy')
return self.g.entropy()
def test_entropy():
"""
......@@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
entropy = self.g('entropy')
kl_loss = self.g('kl_loss', 'Geometric', x_)
entropy = self.g.entropy()
kl_loss = self.g.kl_loss('Geometric', x_)
h_sum_kl = entropy + kl_loss
ans = self.g('cross_entropy', 'Geometric', x_)
ans = self.g.cross_entropy('Geometric', x_)
return h_sum_kl - ans
def test_cross_entropy():
......
......@@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
......@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('prob', x_)
return self.n.prob(x_)
def test_pdf():
"""
......@@ -55,9 +53,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('log_prob', x_)
return self.n.log_prob(x_)
def test_log_likelihood():
"""
......@@ -79,9 +76,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
return self.n('kl_loss', 'Normal', x_, y_)
return self.n.kl_loss('Normal', x_, y_)
def test_kl_loss():
......@@ -113,9 +109,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
@ms_function
def construct(self):
return self.n('mean'), self.n('sd'), self.n('mode')
return self.n.mean(), self.n.sd(), self.n.mode()
def test_basics():
"""
......@@ -139,9 +134,8 @@ class Sampling(nn.Cell):
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, mean=None, sd=None):
return self.n('sample', self.shape, mean, sd)
return self.n.sample(self.shape, mean, sd)
def test_sample():
"""
......@@ -163,9 +157,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('cdf', x_)
return self.n.cdf(x_)
def test_cdf():
......@@ -187,9 +180,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('log_cdf', x_)
return self.n.log_cdf(x_)
def test_log_cdf():
"""
......@@ -210,9 +202,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('survival_function', x_)
return self.n.survival_function(x_)
def test_survival():
"""
......@@ -233,9 +224,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('log_survival', x_)
return self.n.log_survival(x_)
def test_log_survival():
"""
......@@ -256,9 +246,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self):
return self.n('entropy')
return self.n.entropy()
def test_entropy():
"""
......@@ -279,12 +268,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
entropy = self.n('entropy')
kl_loss = self.n('kl_loss', 'Normal', x_, y_)
entropy = self.n.entropy()
kl_loss = self.n.kl_loss('Normal', x_, y_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.n('cross_entropy', 'Normal', x_, y_)
cross_entropy = self.n.cross_entropy('Normal', x_, y_)
return h_sum_kl - cross_entropy
def test_cross_entropy():
......@@ -297,3 +285,40 @@ def test_cross_entropy():
diff = cross_entropy(mean, sd)
tol = 1e-6
assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()
class Net(nn.Cell):
"""
Test class: expand single distribution instance to multiple graphs
by specifying the attributes.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('Normal', x_, y_)
prob = self.normal.prob(kl)
return prob
def test_multiple_graphs():
"""
Test multiple graphs case.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for new api of normal distribution"""
import numpy as np
from scipy import stats
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: new api of normal distribution.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
prob = self.normal.prob('prob', kl)
return prob
def test_new_api():
"""
Test new api of normal distribution.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()
......@@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
......@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('prob', x_)
return self.u.prob(x_)
def test_pdf():
"""
......@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('log_prob', x_)
return self.u.log_prob(x_)
def test_log_likelihood():
"""
......@@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
return self.u('kl_loss', 'Uniform', x_, y_)
return self.u.kl_loss('Uniform', x_, y_)
def test_kl_loss():
"""
......@@ -106,9 +102,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32)
@ms_function
def construct(self):
return self.u('mean'), self.u('sd')
return self.u.mean(), self.u.sd()
def test_basics():
"""
......@@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, low=None, high=None):
return self.u('sample', self.shape, low, high)
return self.u.sample(self.shape, low, high)
def test_sample():
"""
......@@ -155,9 +149,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('cdf', x_)
return self.u.cdf(x_)
def test_cdf():
"""
......@@ -179,9 +172,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('log_cdf', x_)
return self.u.log_cdf(x_)
class SF(nn.Cell):
"""
......@@ -191,9 +183,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('survival_function', x_)
return self.u.survival_function(x_)
class LogSF(nn.Cell):
"""
......@@ -203,9 +194,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('log_survival', x_)
return self.u.log_survival(x_)
class EntropyH(nn.Cell):
"""
......@@ -215,9 +205,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32)
@ms_function
def construct(self):
return self.u('entropy')
return self.u.entropy()
def test_entropy():
"""
......@@ -238,12 +227,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
entropy = self.u('entropy')
kl_loss = self.u('kl_loss', 'Uniform', x_, y_)
entropy = self.u.entropy()
kl_loss = self.u.kl_loss('Uniform', x_, y_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.u('cross_entropy', 'Uniform', x_, y_)
cross_entropy = self.u.cross_entropy('Uniform', x_, y_)
return h_sum_kl - cross_entropy
def test_log_cdf():
......
......@@ -49,12 +49,12 @@ class BernoulliProb(nn.Cell):
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
def construct(self, value):
prob = self.b('prob', value)
log_prob = self.b('log_prob', value)
cdf = self.b('cdf', value)
log_cdf = self.b('log_cdf', value)
sf = self.b('survival_function', value)
log_sf = self.b('log_survival', value)
prob = self.b.prob(value)
log_prob = self.b.log_prob(value)
cdf = self.b.cdf(value)
log_cdf = self.b.log_cdf(value)
sf = self.b.survival_function(value)
log_sf = self.b.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob():
......@@ -75,12 +75,12 @@ class BernoulliProb1(nn.Cell):
self.b = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.b('prob', value, probs)
log_prob = self.b('log_prob', value, probs)
cdf = self.b('cdf', value, probs)
log_cdf = self.b('log_cdf', value, probs)
sf = self.b('survival_function', value, probs)
log_sf = self.b('log_survival', value, probs)
prob = self.b.prob(value, probs)
log_prob = self.b.log_prob(value, probs)
cdf = self.b.cdf(value, probs)
log_cdf = self.b.log_cdf(value, probs)
sf = self.b.survival_function(value, probs)
log_sf = self.b.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob1():
......@@ -103,8 +103,8 @@ class BernoulliKl(nn.Cell):
self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
kl1 = self.b1('kl_loss', 'Bernoulli', probs_b)
kl2 = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a)
kl1 = self.b1.kl_loss('Bernoulli', probs_b)
kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
return kl1 + kl2
def test_kl():
......@@ -127,8 +127,8 @@ class BernoulliCrossEntropy(nn.Cell):
self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
h1 = self.b1('cross_entropy', 'Bernoulli', probs_b)
h2 = self.b2('cross_entropy', 'Bernoulli', probs_b, probs_a)
h1 = self.b1.cross_entropy('Bernoulli', probs_b)
h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
return h1 + h2
def test_cross_entropy():
......@@ -150,11 +150,11 @@ class BernoulliBasics(nn.Cell):
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self):
mean = self.b('mean')
sd = self.b('sd')
var = self.b('var')
mode = self.b('mode')
entropy = self.b('entropy')
mean = self.b.mean()
sd = self.b.sd()
var = self.b.var()
mode = self.b.mode()
entropy = self.b.entropy()
return mean + sd + var + mode + entropy
def test_bascis():
......@@ -164,3 +164,28 @@ def test_bascis():
net = BernoulliBasics()
ans = net()
assert isinstance(ans, Tensor)
class BernoulliConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
self.b1 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.b('prob', value)
prob1 = self.b('prob', value, probs)
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2
def test_bernoulli_construct():
"""
Test probability function going through construct.
"""
net = BernoulliConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)
......@@ -50,12 +50,12 @@ class ExponentialProb(nn.Cell):
self.e = msd.Exponential(0.5, dtype=dtype.float32)
def construct(self, value):
prob = self.e('prob', value)
log_prob = self.e('log_prob', value)
cdf = self.e('cdf', value)
log_cdf = self.e('log_cdf', value)
sf = self.e('survival_function', value)
log_sf = self.e('log_survival', value)
prob = self.e.prob(value)
log_prob = self.e.log_prob(value)
cdf = self.e.cdf(value)
log_cdf = self.e.log_cdf(value)
sf = self.e.survival_function(value)
log_sf = self.e.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_exponential_prob():
......@@ -76,12 +76,12 @@ class ExponentialProb1(nn.Cell):
self.e = msd.Exponential(dtype=dtype.float32)
def construct(self, value, rate):
prob = self.e('prob', value, rate)
log_prob = self.e('log_prob', value, rate)
cdf = self.e('cdf', value, rate)
log_cdf = self.e('log_cdf', value, rate)
sf = self.e('survival_function', value, rate)
log_sf = self.e('log_survival', value, rate)
prob = self.e.prob(value, rate)
log_prob = self.e.log_prob(value, rate)
cdf = self.e.cdf(value, rate)
log_cdf = self.e.log_cdf(value, rate)
sf = self.e.survival_function(value, rate)
log_sf = self.e.log_survival(value, rate)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_exponential_prob1():
......@@ -104,8 +104,8 @@ class ExponentialKl(nn.Cell):
self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a):
kl1 = self.e1('kl_loss', 'Exponential', rate_b)
kl2 = self.e2('kl_loss', 'Exponential', rate_b, rate_a)
kl1 = self.e1.kl_loss('Exponential', rate_b)
kl2 = self.e2.kl_loss('Exponential', rate_b, rate_a)
return kl1 + kl2
def test_kl():
......@@ -128,8 +128,8 @@ class ExponentialCrossEntropy(nn.Cell):
self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a):
h1 = self.e1('cross_entropy', 'Exponential', rate_b)
h2 = self.e2('cross_entropy', 'Exponential', rate_b, rate_a)
h1 = self.e1.cross_entropy('Exponential', rate_b)
h2 = self.e2.cross_entropy('Exponential', rate_b, rate_a)
return h1 + h2
def test_cross_entropy():
......@@ -151,11 +151,11 @@ class ExponentialBasics(nn.Cell):
self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32)
def construct(self):
mean = self.e('mean')
sd = self.e('sd')
var = self.e('var')
mode = self.e('mode')
entropy = self.e('entropy')
mean = self.e.mean()
sd = self.e.sd()
var = self.e.var()
mode = self.e.mode()
entropy = self.e.entropy()
return mean + sd + var + mode + entropy
def test_bascis():
......@@ -165,3 +165,29 @@ def test_bascis():
net = ExponentialBasics()
ans = net()
assert isinstance(ans, Tensor)
class ExpConstruct(nn.Cell):
"""
Exponential distribution: going through construct.
"""
def __init__(self):
super(ExpConstruct, self).__init__()
self.e = msd.Exponential(0.5, dtype=dtype.float32)
self.e1 = msd.Exponential(dtype=dtype.float32)
def construct(self, value, rate):
prob = self.e('prob', value)
prob1 = self.e('prob', value, rate)
prob2 = self.e1('prob', value, rate)
return prob + prob1 + prob2
def test_exp_construct():
"""
Test probability function going through construct.
"""
net = ExpConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)
......@@ -50,12 +50,12 @@ class GeometricProb(nn.Cell):
self.g = msd.Geometric(0.5, dtype=dtype.int32)
def construct(self, value):
prob = self.g('prob', value)
log_prob = self.g('log_prob', value)
cdf = self.g('cdf', value)
log_cdf = self.g('log_cdf', value)
sf = self.g('survival_function', value)
log_sf = self.g('log_survival', value)
prob = self.g.prob(value)
log_prob = self.g.log_prob(value)
cdf = self.g.cdf(value)
log_cdf = self.g.log_cdf(value)
sf = self.g.survival_function(value)
log_sf = self.g.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob():
......@@ -76,12 +76,12 @@ class GeometricProb1(nn.Cell):
self.g = msd.Geometric(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.g('prob', value, probs)
log_prob = self.g('log_prob', value, probs)
cdf = self.g('cdf', value, probs)
log_cdf = self.g('log_cdf', value, probs)
sf = self.g('survival_function', value, probs)
log_sf = self.g('log_survival', value, probs)
prob = self.g.prob(value, probs)
log_prob = self.g.log_prob(value, probs)
cdf = self.g.cdf(value, probs)
log_cdf = self.g.log_cdf(value, probs)
sf = self.g.survival_function(value, probs)
log_sf = self.g.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob1():
......@@ -105,8 +105,8 @@ class GeometricKl(nn.Cell):
self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
kl1 = self.g1('kl_loss', 'Geometric', probs_b)
kl2 = self.g2('kl_loss', 'Geometric', probs_b, probs_a)
kl1 = self.g1.kl_loss('Geometric', probs_b)
kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
return kl1 + kl2
def test_kl():
......@@ -129,8 +129,8 @@ class GeometricCrossEntropy(nn.Cell):
self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
h1 = self.g1('cross_entropy', 'Geometric', probs_b)
h2 = self.g2('cross_entropy', 'Geometric', probs_b, probs_a)
h1 = self.g1.cross_entropy('Geometric', probs_b)
h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
return h1 + h2
def test_cross_entropy():
......@@ -152,11 +152,11 @@ class GeometricBasics(nn.Cell):
self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
def construct(self):
mean = self.g('mean')
sd = self.g('sd')
var = self.g('var')
mode = self.g('mode')
entropy = self.g('entropy')
mean = self.g.mean()
sd = self.g.sd()
var = self.g.var()
mode = self.g.mode()
entropy = self.g.entropy()
return mean + sd + var + mode + entropy
def test_bascis():
......@@ -166,3 +166,29 @@ def test_bascis():
net = GeometricBasics()
ans = net()
assert isinstance(ans, Tensor)
class GeoConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(GeoConstruct, self).__init__()
self.g = msd.Geometric(0.5, dtype=dtype.int32)
self.g1 = msd.Geometric(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.g('prob', value)
prob1 = self.g('prob', value, probs)
prob2 = self.g1('prob', value, probs)
return prob + prob1 + prob2
def test_geo_construct():
"""
Test probability function going through construct.
"""
net = GeoConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)
......@@ -50,12 +50,12 @@ class NormalProb(nn.Cell):
self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self, value):
prob = self.normal('prob', value)
log_prob = self.normal('log_prob', value)
cdf = self.normal('cdf', value)
log_cdf = self.normal('log_cdf', value)
sf = self.normal('survival_function', value)
log_sf = self.normal('log_survival', value)
prob = self.normal.prob(value)
log_prob = self.normal.log_prob(value)
cdf = self.normal.cdf(value)
log_cdf = self.normal.log_cdf(value)
sf = self.normal.survival_function(value)
log_sf = self.normal.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_normal_prob():
......@@ -77,12 +77,12 @@ class NormalProb1(nn.Cell):
self.normal = msd.Normal()
def construct(self, value, mean, sd):
prob = self.normal('prob', value, mean, sd)
log_prob = self.normal('log_prob', value, mean, sd)
cdf = self.normal('cdf', value, mean, sd)
log_cdf = self.normal('log_cdf', value, mean, sd)
sf = self.normal('survival_function', value, mean, sd)
log_sf = self.normal('log_survival', value, mean, sd)
prob = self.normal.prob(value, mean, sd)
log_prob = self.normal.log_prob(value, mean, sd)
cdf = self.normal.cdf(value, mean, sd)
log_cdf = self.normal.log_cdf(value, mean, sd)
sf = self.normal.survival_function(value, mean, sd)
log_sf = self.normal.log_survival(value, mean, sd)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_normal_prob1():
......@@ -106,8 +106,8 @@ class NormalKl(nn.Cell):
self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a):
kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b)
kl2 = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
kl1 = self.n1.kl_loss('Normal', mean_b, sd_b)
kl2 = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
return kl1 + kl2
def test_kl():
......@@ -132,8 +132,8 @@ class NormalCrossEntropy(nn.Cell):
self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a):
h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b)
h2 = self.n2('cross_entropy', 'Normal', mean_b, sd_b, mean_a, sd_a)
h1 = self.n1.cross_entropy('Normal', mean_b, sd_b)
h2 = self.n2.cross_entropy('Normal', mean_b, sd_b, mean_a, sd_a)
return h1 + h2
def test_cross_entropy():
......@@ -157,10 +157,10 @@ class NormalBasics(nn.Cell):
self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self):
mean = self.n('mean')
sd = self.n('sd')
mode = self.n('mode')
entropy = self.n('entropy')
mean = self.n.mean()
sd = self.n.sd()
mode = self.n.mode()
entropy = self.n.entropy()
return mean + sd + mode + entropy
def test_bascis():
......@@ -170,3 +170,30 @@ def test_bascis():
net = NormalBasics()
ans = net()
assert isinstance(ans, Tensor)
class NormalConstruct(nn.Cell):
"""
Normal distribution: going through construct.
"""
def __init__(self):
super(NormalConstruct, self).__init__()
self.normal = msd.Normal(3.0, 4.0)
self.normal1 = msd.Normal()
def construct(self, value, mean, sd):
prob = self.normal('prob', value)
prob1 = self.normal('prob', value, mean, sd)
prob2 = self.normal1('prob', value, mean, sd)
return prob + prob1 + prob2
def test_normal_construct():
"""
Test probability function going through construct.
"""
net = NormalConstruct()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
mean = Tensor([0.0], dtype=dtype.float32)
sd = Tensor([1.0], dtype=dtype.float32)
ans = net(value, mean, sd)
assert isinstance(ans, Tensor)
......@@ -60,12 +60,12 @@ class UniformProb(nn.Cell):
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self, value):
prob = self.u('prob', value)
log_prob = self.u('log_prob', value)
cdf = self.u('cdf', value)
log_cdf = self.u('log_cdf', value)
sf = self.u('survival_function', value)
log_sf = self.u('log_survival', value)
prob = self.u.prob(value)
log_prob = self.u.log_prob(value)
cdf = self.u.cdf(value)
log_cdf = self.u.log_cdf(value)
sf = self.u.survival_function(value)
log_sf = self.u.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob():
......@@ -86,12 +86,12 @@ class UniformProb1(nn.Cell):
self.u = msd.Uniform(dtype=dtype.float32)
def construct(self, value, low, high):
prob = self.u('prob', value, low, high)
log_prob = self.u('log_prob', value, low, high)
cdf = self.u('cdf', value, low, high)
log_cdf = self.u('log_cdf', value, low, high)
sf = self.u('survival_function', value, low, high)
log_sf = self.u('log_survival', value, low, high)
prob = self.u.prob(value, low, high)
log_prob = self.u.log_prob(value, low, high)
cdf = self.u.cdf(value, low, high)
log_cdf = self.u.log_cdf(value, low, high)
sf = self.u.survival_function(value, low, high)
log_sf = self.u.log_survival(value, low, high)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob1():
......@@ -115,8 +115,8 @@ class UniformKl(nn.Cell):
self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a):
kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b)
kl2 = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a)
kl1 = self.u1.kl_loss('Uniform', low_b, high_b)
kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
return kl1 + kl2
def test_kl():
......@@ -141,8 +141,8 @@ class UniformCrossEntropy(nn.Cell):
self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a):
h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b)
h2 = self.u2('cross_entropy', 'Uniform', low_b, high_b, low_a, high_a)
h1 = self.u1.cross_entropy('Uniform', low_b, high_b)
h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a)
return h1 + h2
def test_cross_entropy():
......@@ -166,10 +166,10 @@ class UniformBasics(nn.Cell):
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self):
mean = self.u('mean')
sd = self.u('sd')
var = self.u('var')
entropy = self.u('entropy')
mean = self.u.mean()
sd = self.u.sd()
var = self.u.var()
entropy = self.u.entropy()
return mean + sd + var + entropy
def test_bascis():
......@@ -179,3 +179,30 @@ def test_bascis():
net = UniformBasics()
ans = net()
assert isinstance(ans, Tensor)
class UniConstruct(nn.Cell):
"""
Unifrom distribution: going through construct.
"""
def __init__(self):
super(UniConstruct, self).__init__()
self.u = msd.Uniform(-4.0, 4.0)
self.u1 = msd.Uniform()
def construct(self, value, low, high):
prob = self.u('prob', value)
prob1 = self.u('prob', value, low, high)
prob2 = self.u1('prob', value, low, high)
return prob + prob1 + prob2
def test_uniform_construct():
"""
Test probability function going through construct.
"""
net = UniConstruct()
value = Tensor([-5.0, 0.0, 1.0, 5.0], dtype=dtype.float32)
low = Tensor([-1.0], dtype=dtype.float32)
high = Tensor([1.0], dtype=dtype.float32)
ans = net(value, low, high)
assert isinstance(ans, Tensor)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册