Commit e0009b90 authored by Xun Deng

added type check in distributions and fixed bugs in cast_to_tensor

Parent 1744948d
......@@ -15,6 +15,7 @@
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import _utils as utils
+from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
......@@ -23,7 +24,7 @@ from mindspore.ops import composite as C
import mindspore.nn as nn
import mindspore.nn.probability as msp
-def cast_to_tensor(t, hint_dtype=mstype.float32):
+def cast_to_tensor(t, hint_type=mstype.float32):
"""
    Cast a user input value into a Tensor of dtype.
If the input t is of type Parameter, t is directly returned as a Parameter.
......@@ -38,24 +39,27 @@ def cast_to_tensor(t, hint_dtype=mstype.float32):
Returns:
Tensor.
"""
+    if t is None:
+        raise ValueError('Input cannot be None in cast_to_tensor')
    if isinstance(t, Parameter):
        return t
+    t_type = hint_type
    if isinstance(t, Tensor):
-        if t.dtype != hint_dtype:
-            raise TypeError(f"Input tensor should be type {hint_dtype}.")
        # check whether t is a scalar tensor, e.g. in the form of Tensor(4)
        if t.dim() == 0:
            value = t.asnumpy()
-            return Tensor([value], dtype=hint_dtype)
+            return Tensor([value], dtype=t_type)
        # convert the dtype of the tensor to the target type
-        return t
+        return Tensor(t.asnumpy(), dtype=t_type)
    if isinstance(t, (list, np.ndarray)):
-        return Tensor(t, dtype=hint_dtype)
-    if np.isscalar(t):
-        return Tensor([t], dtype=hint_dtype)
-    raise RuntimeError("Input type is not supported.")
-def convert_to_batch(t, batch_shape, hint_dtype):
+        return Tensor(t, dtype=t_type)
+    if isinstance(t, bool):
+        raise TypeError('Input cannot be of type bool')
+    if isinstance(t, (int, float)):
+        return Tensor([t], dtype=t_type)
+    raise TypeError("Input type is not supported.")
+def convert_to_batch(t, batch_shape, required_type):
"""
Convert a Tensor to a given batch shape.
......@@ -72,8 +76,8 @@ def convert_to_batch(t, batch_shape, hint_dtype):
"""
if isinstance(t, Parameter):
return t
-    t = cast_to_tensor(t, hint_dtype)
-    return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=hint_dtype)
+    t = cast_to_tensor(t, required_type)
+    return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
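For reference, a minimal standalone sketch of the casting and broadcasting rules above, written in plain NumPy; the name `cast_to_array` is illustrative and not part of the MindSpore API:

```python
import numpy as np

def cast_to_array(t, hint_type=np.float32):
    # Mirrors cast_to_tensor above: reject None and bool, cast
    # lists/arrays, and wrap plain numbers into a 1-element array.
    if t is None:
        raise ValueError('Input cannot be None in cast_to_array')
    if isinstance(t, (list, np.ndarray)):
        return np.asarray(t, dtype=hint_type)
    if isinstance(t, bool):
        raise TypeError('Input cannot be of type bool')
    if isinstance(t, (int, float)):
        return np.array([t], dtype=hint_type)
    raise TypeError('Input type is not supported.')

# convert_to_batch then broadcasts the casted value to the batch shape:
batch = np.broadcast_to(cast_to_array(2.0), (3, 2))   # shape (3, 2), all 2.0
```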
def check_scalar_from_param(params):
"""
......@@ -91,7 +95,7 @@ def check_scalar_from_param(params):
return False
if isinstance(value, (str, type(params['dtype']))):
continue
-        elif np.isscalar(value):
+        elif isinstance(value, (int, float)):
continue
else:
return False
......@@ -119,10 +123,11 @@ def calc_broadcast_shape_from_param(params):
if isinstance(value, Parameter):
value_t = value.default_input
else:
-            value_t = cast_to_tensor(value, params['dtype'])
+            value_t = cast_to_tensor(value, mstype.float32)
broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
return tuple(broadcast_shape)
def check_greater_equal_zero(value, name):
"""
    Check if the given Tensor is greater than or equal to zero.
......@@ -155,14 +160,17 @@ def check_greater_zero(value, name):
ValueError: if the input value is less than or equal to zero.
"""
+    if value is None:
+        raise ValueError('input value cannot be None in check_greater_zero')
    if isinstance(value, Parameter):
-        if isinstance(value.default_input, MetaTensor):
+        if not isinstance(value.default_input, Tensor):
return
value = value.default_input
comp = np.less(np.zeros(value.shape), value.asnumpy())
if not comp.all():
raise ValueError(f'{name} should be greater than zero.')
def check_greater(a, b, name_a, name_b):
"""
Check if Tensor b is strictly greater than Tensor a.
......@@ -176,6 +184,8 @@ def check_greater(a, b, name_a, name_b):
Raises:
ValueError: if b is less than or equal to a
"""
+    if a is None or b is None:
+        raise ValueError('input value cannot be None in check_greater')
if isinstance(a, Parameter) or isinstance(b, Parameter):
return
comp = np.less(a.asnumpy(), b.asnumpy())
......@@ -193,6 +203,8 @@ def check_prob(p):
Raises:
ValueError: if p is not a proper probability.
"""
+    if p is None:
+        raise ValueError('input value cannot be None in check_prob')
if isinstance(p, Parameter):
if not isinstance(p.default_input, Tensor):
return
......@@ -259,3 +271,12 @@ def check_tensor_type(name, inputs, valid_type):
def check_type(data_type, value_type, name):
    if data_type not in value_type:
        raise TypeError(f"For {name}, valid types include {value_type}, but {data_type} is invalid")
+@constexpr
+def raise_none_error(name):
+    raise ValueError(f"{name} should be specified. Value cannot be None")
+@constexpr
+def check_distribution_name(name, expected_name):
+    if name != expected_name:
+        raise ValueError(f"Distribution should be {expected_name}.")
......@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type
+from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
class Bernoulli(Distribution):
"""
......@@ -99,8 +99,9 @@ class Bernoulli(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type
check_type(dtype, valid_dtype, "Bernoulli")
super(Bernoulli, self).__init__(seed, dtype, name, param)
+        self.parameter_type = mstype.float32
if probs is not None:
-            self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
+            self._probs = cast_to_tensor(probs, mstype.float32)
check_prob(self.probs)
else:
self._probs = probs
......@@ -111,6 +112,7 @@ class Bernoulli(Distribution):
self.dtypeop = P.DType()
self.erf = P.Erf()
self.exp = P.Exp()
+        self.floor = P.Floor()
self.fill = P.Fill()
self.log = P.Log()
self.less = P.Less()
......@@ -139,14 +141,19 @@ class Bernoulli(Distribution):
.. math::
MEAN(B) = probs1
"""
-        return self.probs if probs1 is None else probs1
+        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs1")
+        return probs1
def _mode(self, probs1=None):
r"""
.. math::
            MODE(B) = 1 if probs1 > 0.5 else 0
"""
-        probs1 = self.probs if probs1 is None else probs1
+        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs1")
prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0)
......@@ -158,7 +165,9 @@ class Bernoulli(Distribution):
.. math::
VAR(B) = probs1 * probs0
"""
-        probs1 = self.probs if probs1 is None else probs1
+        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs1")
probs0 = 1.0 - probs1
return self.exp(self.log(probs0) + self.log(probs1))
......@@ -167,7 +176,9 @@ class Bernoulli(Distribution):
.. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
"""
-        probs1 = self.probs if probs is None else probs
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
probs0 = 1 - probs1
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
......@@ -180,9 +191,8 @@ class Bernoulli(Distribution):
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
"""
-        if dist == 'Bernoulli':
-            return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
-        return None
+        check_distribution_name(dist, 'Bernoulli')
+        return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
def _log_prob(self, value, probs=None):
r"""
......@@ -196,7 +206,13 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
-        probs1 = self.probs if probs is None else probs
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, mstype.float32)
+        value = self.floor(value)
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
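A NumPy sketch of the log-pmf above, useful for checking values offline; note that the new code floors `value` first, so non-integer inputs are truncated (the function name is illustrative):

```python
import numpy as np

def bernoulli_log_prob(value, probs1):
    # log pmf(k) = k * log(probs1) + (1 - k) * log(1 - probs1), k floored
    value = np.floor(np.asarray(value, dtype=np.float32))
    return np.log(probs1) * value + np.log(1.0 - probs1) * (1.0 - value)

bernoulli_log_prob([0., 1.], 0.7)   # -> [log(0.3), log(0.7)]
```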
......@@ -213,7 +229,13 @@ class Bernoulli(Distribution):
            cdf(k) = probs0 if 0 <= k < 1;
            cdf(k) = 1 if k >= 1;
"""
-        probs1 = self.probs if probs is None else probs
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, mstype.float32)
+        value = self.floor(value)
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
......@@ -230,19 +252,23 @@ class Bernoulli(Distribution):
Args:
dist (str): type of the distributions. Should be "Bernoulli" in this case.
-            probs1_b (Tensor): probs1 of distribution b.
-            probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
+            probs1_b (Tensor, Number): probs1 of distribution b.
+            probs1_a (Tensor, Number): probs1 of distribution a. Default: self.probs.
.. math::
KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
probs0_a * \log(\frac{probs0_a}{probs0_b})
"""
-        if dist == 'Bernoulli':
-            probs1_a = self.probs if probs1_a is None else probs1_a
-            probs0_a = 1.0 - probs1_a
-            probs0_b = 1.0 - probs1_b
-            return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
-        return None
+        check_distribution_name(dist, 'Bernoulli')
+        if probs1_b is None:
+            raise_none_error("probs1_b")
+        probs1_b = self.cast(probs1_b, self.parameter_type)
+        probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
+        if probs1_a is None:
+            raise_none_error("probs1_a")
+        probs0_a = 1.0 - probs1_a
+        probs0_b = 1.0 - probs1_b
+        return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
def _sample(self, shape=(), probs=None):
"""
......@@ -250,12 +276,14 @@ class Bernoulli(Distribution):
Args:
shape (tuple): shape of the sample. Default: ().
-            probs (Tensor): probs1 of the samples. Default: self.probs.
+            probs (Tensor, Number): probs1 of the samples. Default: self.probs.
Returns:
Tensor, shape is shape + batch_shape.
"""
-        probs1 = self.probs if probs is None else probs
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)
......
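For reference, the entropy, KL divergence, and cross entropy used by the Bernoulli methods above, as a NumPy sketch with illustrative names:

```python
import numpy as np

def bernoulli_entropy(p):
    return -(1.0 - p) * np.log(1.0 - p) - p * np.log(p)

def bernoulli_kl(p_a, p_b):
    # KL(a||b) = p_a*log(p_a/p_b) + (1-p_a)*log((1-p_a)/(1-p_b))
    return p_a * np.log(p_a / p_b) + (1.0 - p_a) * np.log((1.0 - p_a) / (1.0 - p_b))

# cross entropy H(a, b) = H(a) + KL(a||b), as in _cross_entropy above
h_ab = bernoulli_entropy(0.3) + bernoulli_kl(0.3, 0.5)
```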
......@@ -18,7 +18,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
+from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
+    raise_none_error
class Exponential(Distribution):
"""
......@@ -100,8 +101,9 @@ class Exponential(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Exponential")
super(Exponential, self).__init__(seed, dtype, name, param)
+        self.parameter_type = dtype
if rate is not None:
-            self._rate = cast_to_tensor(rate, dtype)
+            self._rate = cast_to_tensor(rate, self.parameter_type)
check_greater_zero(self._rate, "rate")
else:
self._rate = rate
......@@ -141,16 +143,19 @@ class Exponential(Distribution):
.. math::
MEAN(EXP) = \frac{1.0}{\lambda}.
"""
-        rate = self.rate if rate is None else rate
+        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
+        if rate is None:
+            raise_none_error("rate")
return 1.0 / rate
def _mode(self, rate=None):
r"""
.. math::
MODE(EXP) = 0.
"""
-        rate = self.rate if rate is None else rate
+        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
+        if rate is None:
+            raise_none_error("rate")
return self.fill(self.dtype, self.shape(rate), 0.)
def _sd(self, rate=None):
......@@ -158,7 +163,9 @@ class Exponential(Distribution):
.. math::
sd(EXP) = \frac{1.0}{\lambda}.
"""
-        rate = self.rate if rate is None else rate
+        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
+        if rate is None:
+            raise_none_error("rate")
return 1.0 / rate
def _entropy(self, rate=None):
......@@ -166,7 +173,9 @@ class Exponential(Distribution):
.. math::
H(Exp) = 1 - \log(\lambda).
"""
-        rate = self.rate if rate is None else rate
+        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
+        if rate is None:
+            raise_none_error("rate")
return 1.0 - self.log(rate)
......@@ -179,9 +188,9 @@ class Exponential(Distribution):
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
-        if dist == 'Exponential':
-            return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
-        return None
+        check_distribution_name(dist, 'Exponential')
+        return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
def _prob(self, value, rate=None):
r"""
......@@ -198,7 +207,12 @@ class Exponential(Distribution):
.. math::
            pdf(x) = \lambda * \exp(-\lambda x) if x >= 0 else 0
"""
-        rate = self.rate if rate is None else rate
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, self.dtype)
+        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
+        if rate is None:
+            raise_none_error("rate")
prob = self.exp(self.log(rate) - rate * value)
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
comp = self.less(value, zeros)
......@@ -218,7 +232,12 @@ class Exponential(Distribution):
.. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
"""
-        rate = self.rate if rate is None else rate
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, self.dtype)
+        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
+        if rate is None:
+            raise_none_error("rate")
cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
......@@ -234,10 +253,14 @@ class Exponential(Distribution):
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
-        if dist == 'Exponential':
-            rate_a = self.rate if rate_a is None else rate_a
-            return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
-        return None
+        check_distribution_name(dist, 'Exponential')
+        if rate_b is None:
+            raise_none_error("rate_b")
+        rate_b = self.cast(rate_b, self.parameter_type)
+        rate_a = self.cast(rate_a, self.parameter_type) if rate_a is not None else self.rate
+        if rate_a is None:
+            raise_none_error("rate_a")
+        return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
def _sample(self, shape=(), rate=None):
"""
......@@ -250,7 +273,9 @@ class Exponential(Distribution):
Returns:
Tensor, shape is shape + batch_shape.
"""
-        rate = self.rate if rate is None else rate
+        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
+        if rate is None:
+            raise_none_error("rate")
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)
......
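The exponential formulas above admit a compact NumPy reference (illustrative names; negative inputs map to probability 0, as in `_prob` and `_cdf`):

```python
import numpy as np

def exp_prob(x, rate):
    # pdf(x) = rate * exp(-rate * x) for x >= 0, else 0
    x = np.asarray(x, dtype=np.float32)
    return np.where(x < 0.0, 0.0, np.exp(np.log(rate) - rate * x))

def exp_kl(rate_a, rate_b):
    # KL(a||b) = log(rate_a) - log(rate_b) + rate_b / rate_a - 1
    return np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0
```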
......@@ -18,7 +18,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type
+from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
+    raise_none_error
class Geometric(Distribution):
"""
......@@ -101,8 +102,9 @@ class Geometric(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type
check_type(dtype, valid_dtype, "Geometric")
super(Geometric, self).__init__(seed, dtype, name, param)
+        self.parameter_type = mstype.float32
if probs is not None:
-            self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
+            self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self._probs)
else:
self._probs = probs
......@@ -145,7 +147,9 @@ class Geometric(Distribution):
.. math::
            MEAN(Geo) = \frac{1 - probs1}{probs1}
"""
-        probs1 = self.probs if probs1 is None else probs1
+        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs1")
return (1. - probs1) / probs1
def _mode(self, probs1=None):
......@@ -153,7 +157,9 @@ class Geometric(Distribution):
.. math::
MODE(Geo) = 0
"""
-        probs1 = self.probs if probs1 is None else probs1
+        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs1")
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
def _var(self, probs1=None):
......@@ -161,7 +167,9 @@ class Geometric(Distribution):
.. math::
            VAR(Geo) = \frac{1 - probs1}{probs1^{2}}
"""
-        probs1 = self.probs if probs1 is None else probs1
+        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs1")
return (1.0 - probs1) / self.sq(probs1)
def _entropy(self, probs=None):
......@@ -169,7 +177,9 @@ class Geometric(Distribution):
.. math::
            H(Geo) = \frac{-probs0 \log(probs0) - probs1 \log(probs1)}{probs1}
"""
-        probs1 = self.probs if probs is None else probs
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
......@@ -182,9 +192,8 @@ class Geometric(Distribution):
probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
"""
-        if dist == 'Geometric':
-            return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
-        return None
+        check_distribution_name(dist, 'Geometric')
+        return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
def _prob(self, value, probs=None):
r"""
......@@ -198,14 +207,13 @@ class Geometric(Distribution):
            pmf(k) = probs0^{k} * probs1 if k >= 0;
pmf(k) = 0 if k < 0.
"""
-        probs1 = self.probs if probs is None else probs
-        dtype = self.dtypeop(value)
-        if self.issubclass(dtype, mstype.int_):
-            pass
-        elif self.issubclass(dtype, mstype.float_):
-            value = self.floor(value)
-        else:
-            return None
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, mstype.float32)
+        value = self.floor(value)
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
......@@ -224,15 +232,14 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0.
"""
-        probs1 = self.probs if probs is None else probs
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, mstype.float32)
+        value = self.floor(value)
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
        probs0 = 1.0 - probs1
-        dtype = self.dtypeop(value)
-        if self.issubclass(dtype, mstype.int_):
-            pass
-        elif self.issubclass(dtype, mstype.float_):
-            value = self.floor(value)
-        else:
-            return None
cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
......@@ -251,12 +258,16 @@ class Geometric(Distribution):
.. math::
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
"""
-        if dist == 'Geometric':
-            probs1_a = self.probs if probs1_a is None else probs1_a
-            probs0_a = 1.0 - probs1_a
-            probs0_b = 1.0 - probs1_b
-            return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
-        return None
+        check_distribution_name(dist, 'Geometric')
+        if probs1_b is None:
+            raise_none_error("probs1_b")
+        probs1_b = self.cast(probs1_b, self.parameter_type)
+        probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
+        if probs1_a is None:
+            raise_none_error("probs1_a")
+        probs0_a = 1.0 - probs1_a
+        probs0_b = 1.0 - probs1_b
+        return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
def _sample(self, shape=(), probs=None):
"""
......@@ -269,9 +280,11 @@ class Geometric(Distribution):
Returns:
Tensor, shape is shape + batch_shape.
"""
-        probs = self.probs if probs is None else probs
+        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
+        if probs1 is None:
+            raise_none_error("probs")
minval = self.const(self.minval)
maxval = self.const(1.0)
-        sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed)
-        sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
+        sample_uniform = self.uniform(shape + self.shape(probs1), minval, maxval, self.seed)
+        sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1))
return self.cast(sample, self.dtype)
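`_sample` above is inverse-CDF sampling: draw U ~ Uniform(0, 1) and return floor(log U / log(1 - probs1)). A NumPy sketch with an illustrative name, keeping the lower bound strictly above zero as the graph code does with `minval`:

```python
import numpy as np

def geometric_sample(shape, probs1, seed=None):
    # Number of failures before the first success.
    rng = np.random.default_rng(seed)
    u = rng.uniform(low=np.finfo(np.float32).tiny, high=1.0, size=shape)
    return np.floor(np.log(u) / np.log(1.0 - probs1))
```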
......@@ -18,8 +18,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
-from ._utils.utils import convert_to_batch, check_greater_zero, check_type
+from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
+    raise_none_error
class Normal(Distribution):
"""
......@@ -103,9 +103,10 @@ class Normal(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Normal")
super(Normal, self).__init__(seed, dtype, name, param)
+        self.parameter_type = dtype
if mean is not None and sd is not None:
-            self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
-            self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
+            self._mean_value = convert_to_batch(mean, self.broadcast_shape, self.parameter_type)
+            self._sd_value = convert_to_batch(sd, self.broadcast_shape, self.parameter_type)
check_greater_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
......@@ -113,6 +114,7 @@ class Normal(Distribution):
#ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.erf = P.Erf()
self.exp = P.Exp()
......@@ -141,31 +143,51 @@ class Normal(Distribution):
"""
Mean of the distribution.
"""
-        mean = self._mean_value if mean is None or sd is None else mean
+        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
+        if mean is None:
+            raise_none_error("mean")
+        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
+        if sd is None:
+            raise_none_error("sd")
return mean
def _mode(self, mean=None, sd=None):
"""
Mode of the distribution.
"""
-        mean = self._mean_value if mean is None or sd is None else mean
+        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
+        if mean is None:
+            raise_none_error("mean")
+        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
+        if sd is None:
+            raise_none_error("sd")
return mean
def _sd(self, mean=None, sd=None):
"""
Standard deviation of the distribution.
"""
-        sd = self._sd_value if mean is None or sd is None else sd
+        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
+        if mean is None:
+            raise_none_error("mean")
+        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
+        if sd is None:
+            raise_none_error("sd")
return sd
-    def _entropy(self, sd=None):
+    def _entropy(self, mean=None, sd=None):
r"""
Evaluate entropy.
.. math::
            H(X) = \log(\sqrt{2 \pi e \sigma^{2}})
"""
-        sd = self._sd_value if sd is None else sd
+        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
+        if mean is None:
+            raise_none_error("mean")
+        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
+        if sd is None:
+            raise_none_error("sd")
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
......@@ -179,9 +201,8 @@ class Normal(Distribution):
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
"""
-        if dist == 'Normal':
-            return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
-        return None
+        check_distribution_name(dist, 'Normal')
+        return self._entropy(mean=mean_a, sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
def _log_prob(self, value, mean=None, sd=None):
r"""
......@@ -195,10 +216,17 @@ class Normal(Distribution):
.. math::
            L(x) = -\frac{(x - \mu)^{2}}{2 \sigma^{2}} - \log(\sqrt{2 \pi \sigma^{2}})
"""
-        mean = self._mean_value if mean is None else mean
-        sd = self._sd_value if sd is None else sd
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, self.dtype)
+        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
+        if mean is None:
+            raise_none_error("mean")
+        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
+        if sd is None:
+            raise_none_error("sd")
        unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
-        neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd)
+        neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
return unnormalized_log_prob + neg_normalization
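The rewritten `neg_normalization` is the same constant as before, computed without the square root; in the docstring's notation:

.. math::
    -\log(\sqrt{2 \pi} \sigma) = -\frac{1}{2}\log(2 \pi) - \log(\sigma)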
def _cdf(self, value, mean=None, sd=None):
......@@ -213,8 +241,15 @@ class Normal(Distribution):
.. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
"""
-        mean = self._mean_value if mean is None else mean
-        sd = self._sd_value if sd is None else sd
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, self.dtype)
+        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
+        if mean is None:
+            raise_none_error("mean")
+        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
+        if sd is None:
+            raise_none_error("sd")
sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted))
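The cdf can be cross-checked without MindSpore using `math.erf` (sketch; `np.vectorize` is only for convenience on array inputs):

```python
import math
import numpy as np

def normal_cdf(x, mean, sd):
    # cdf(x) = 0.5 * (1 + erf((x - mean) / (sd * sqrt(2))))
    adjusted = (np.asarray(x) - mean) / (sd * math.sqrt(2.0))
    return 0.5 * (1.0 + np.vectorize(math.erf)(adjusted))

normal_cdf(0.0, 0.0, 1.0)   # -> 0.5
```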
......@@ -234,13 +269,23 @@ class Normal(Distribution):
            KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)})^{2} +
                       0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b)))) - (\log(STD(a)) - \log(STD(b)))
"""
-        if dist == 'Normal':
-            mean_a = self._mean_value if mean_a is None else mean_a
-            sd_a = self._sd_value if sd_a is None else sd_a
-            diff_log_scale = self.log(sd_a) - self.log(sd_b)
-            squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
-            return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
-        return None
+        check_distribution_name(dist, 'Normal')
+        if mean_b is None:
+            raise_none_error("mean_b")
+        if sd_b is None:
+            raise_none_error("sd_b")
+        mean_b = self.cast(mean_b, self.parameter_type)
+        sd_b = self.cast(sd_b, self.parameter_type)
+        mean_a = self.cast(mean_a, self.parameter_type) if mean_a is not None else self._mean_value
+        sd_a = self.cast(sd_a, self.parameter_type) if sd_a is not None else self._sd_value
+        if mean_a is None:
+            raise_none_error("mean_a")
+        if sd_a is None:
+            raise_none_error("sd_a")
+        diff_log_scale = self.log(sd_a) - self.log(sd_b)
+        squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
+        return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
def _sample(self, shape=(), mean=None, sd=None):
"""
......@@ -254,8 +299,12 @@ class Normal(Distribution):
Returns:
Tensor, shape is shape + batch_shape.
"""
-        mean = self._mean_value if mean is None else mean
-        sd = self._sd_value if sd is None else sd
+        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
+        if mean is None:
+            raise_none_error("mean")
+        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
+        if sd is None:
+            raise_none_error("sd")
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape
sample_norm = C.normal(sample_shape, mean, sd, self.seed)
......
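A NumPy reference for the KL term above, using `expm1` for the same numerical stability (illustrative name):

```python
import numpy as np

def normal_kl(mean_a, sd_a, mean_b, sd_b):
    # KL(a||b) = 0.5 * ((mean_a - mean_b) / sd_b)^2
    #            + 0.5 * expm1(2 * d) - d, with d = log(sd_a) - log(sd_b)
    diff_log_scale = np.log(sd_a) - np.log(sd_b)
    squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
    return 0.5 * squared_diff + 0.5 * np.expm1(2.0 * diff_log_scale) - diff_log_scale

normal_kl(0.0, 1.0, 0.0, 1.0)   # -> 0.0
```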
......@@ -17,7 +17,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
-from ._utils.utils import convert_to_batch, check_greater, check_type
+from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
+    raise_none_error
class Uniform(Distribution):
"""
......@@ -101,6 +102,7 @@ class Uniform(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Uniform")
super(Uniform, self).__init__(seed, dtype, name, param)
+        self.parameter_type = dtype
if low is not None and high is not None:
self._low = convert_to_batch(low, self.broadcast_shape, dtype)
self._high = convert_to_batch(high, self.broadcast_shape, dtype)
......@@ -153,8 +155,12 @@ class Uniform(Distribution):
.. math::
range(U) = high -low
"""
-        low = self.low if low is None else low
-        high = self.high if high is None else high
+        low = self.cast(low, self.parameter_type) if low is not None else self.low
+        if low is None:
+            raise_none_error("low")
+        high = self.cast(high, self.parameter_type) if high is not None else self.high
+        if high is None:
+            raise_none_error("high")
return high - low
def _mean(self, low=None, high=None):
......@@ -162,18 +168,25 @@ class Uniform(Distribution):
.. math::
MEAN(U) = \frac{low + high}{2}.
"""
-        low = self.low if low is None else low
-        high = self.high if high is None else high
+        low = self.cast(low, self.parameter_type) if low is not None else self.low
+        if low is None:
+            raise_none_error("low")
+        high = self.cast(high, self.parameter_type) if high is not None else self.high
+        if high is None:
+            raise_none_error("high")
return (low + high) / 2.
def _var(self, low=None, high=None):
r"""
.. math::
VAR(U) = \frac{(high -low) ^ 2}{12}.
"""
-        low = self.low if low is None else low
-        high = self.high if high is None else high
+        low = self.cast(low, self.parameter_type) if low is not None else self.low
+        if low is None:
+            raise_none_error("low")
+        high = self.cast(high, self.parameter_type) if high is not None else self.high
+        if high is None:
+            raise_none_error("high")
return self.sq(high - low) / 12.0
def _entropy(self, low=None, high=None):
......@@ -181,8 +194,12 @@ class Uniform(Distribution):
.. math::
H(U) = \log(high - low).
"""
-        low = self.low if low is None else low
-        high = self.high if high is None else high
+        low = self.cast(low, self.parameter_type) if low is not None else self.low
+        if low is None:
+            raise_none_error("low")
+        high = self.cast(high, self.parameter_type) if high is not None else self.high
+        if high is None:
+            raise_none_error("high")
return self.log(high - low)
def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
......@@ -196,9 +213,8 @@ class Uniform(Distribution):
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
-        if dist == 'Uniform':
-            return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
-        return None
+        check_distribution_name(dist, 'Uniform')
+        return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
def _prob(self, value, low=None, high=None):
r"""
......@@ -214,8 +230,15 @@ class Uniform(Distribution):
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high;
"""
-        low = self.low if low is None else low
-        high = self.high if high is None else high
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, self.dtype)
+        low = self.cast(low, self.parameter_type) if low is not None else self.low
+        if low is None:
+            raise_none_error("low")
+        high = self.cast(high, self.parameter_type) if high is not None else self.high
+        if high is None:
+            raise_none_error("high")
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
prob = self.exp(neg_ones * self.log(high - low))
broadcast_shape = self.shape(prob)
......@@ -236,13 +259,22 @@ class Uniform(Distribution):
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
-        if dist == 'Uniform':
-            low_a = self.low if low_a is None else low_a
-            high_a = self.high if high_a is None else high_a
-            kl = self.log(high_b - low_b) / self.log(high_a - low_a)
-            comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
-            return self.select(comp, kl, self.log(self.zeroslike(kl)))
-        return None
+        check_distribution_name(dist, 'Uniform')
+        if low_b is None:
+            raise_none_error("low_b")
+        if high_b is None:
+            raise_none_error("high_b")
+        low_b = self.cast(low_b, self.parameter_type)
+        high_b = self.cast(high_b, self.parameter_type)
+        low_a = self.cast(low_a, self.parameter_type) if low_a is not None else self.low
+        if low_a is None:
+            raise_none_error("low_a")
+        high_a = self.cast(high_a, self.parameter_type) if high_a is not None else self.high
+        if high_a is None:
+            raise_none_error("high_a")
+        kl = self.log(high_b - low_b) / self.log(high_a - low_a)
+        comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
+        return self.select(comp, kl, self.log(self.zeroslike(kl)))
def _cdf(self, value, low=None, high=None):
r"""
......@@ -258,8 +290,15 @@ class Uniform(Distribution):
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high;
"""
-        low = self.low if low is None else low
-        high = self.high if high is None else high
+        if value is None:
+            raise_none_error("value")
+        value = self.cast(value, self.dtype)
+        low = self.cast(low, self.parameter_type) if low is not None else self.low
+        if low is None:
+            raise_none_error("low")
+        high = self.cast(high, self.parameter_type) if high is not None else self.high
+        if high is None:
+            raise_none_error("high")
prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
......@@ -281,8 +320,12 @@ class Uniform(Distribution):
Returns:
Tensor, shape is shape + batch_shape.
"""
-        low = self.low if low is None else low
-        high = self.high if high is None else high
+        low = self.cast(low, self.parameter_type) if low is not None else self.low
+        if low is None:
+            raise_none_error("low")
+        high = self.cast(high, self.parameter_type) if high is not None else self.high
+        if high is None:
+            raise_none_error("high")
broadcast_shape = self.shape(low + high)
l_zero = self.const(0.0)
h_one = self.const(1.0)
......
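For comparison, the textbook KL between two uniform distributions, finite only when the support of a lies inside the support of b, is the log of the length ratio; a NumPy reference with an illustrative name:

```python
import numpy as np

def uniform_kl(low_a, high_a, low_b, high_b):
    # KL(a||b) = log((high_b - low_b) / (high_a - low_a))
    # when [low_a, high_a] is contained in [low_b, high_b]; +inf otherwise.
    inside = np.logical_and(low_b <= low_a, high_a <= high_b)
    kl = np.log(high_b - low_b) - np.log(high_a - low_a)
    return np.where(inside, kl, np.inf)
```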