提交 415dad3a 编写于 作者: X Xun Deng

added some parameter checking

上级 9ad82f79
......@@ -17,11 +17,14 @@ Distribution operation utility functions.
"""
from .utils import *
__all__ = ['convert_to_batch',
'cast_to_tensor',
'check_greater',
'check_greater_equal_zero',
'check_greater_zero',
'calc_broadcast_shape_from_param',
'check_scalar_from_param',
'check_prob']
__all__ = [
'convert_to_batch',
'cast_to_tensor',
'check_greater',
'check_greater_equal_zero',
'check_greater_zero',
'calc_broadcast_shape_from_param',
'check_scalar_from_param',
'check_prob',
'check_type',
]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -23,7 +22,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
def cast_to_tensor(t, dtype=mstype.float32):
def cast_to_tensor(t, hint_dtype=mstype.float32):
"""
Cast an user input value into a Tensor of dtype.
If the input t is of type Parameter, t is directly returned as a Parameter.
......@@ -41,25 +40,26 @@ def cast_to_tensor(t, dtype=mstype.float32):
if isinstance(t, Parameter):
return t
if isinstance(t, Tensor):
if t.dtype != hint_dtype:
raise TypeError(f"Input tensor should be type {hint_dtype}.")
#check if the Tensor in shape of Tensor(4)
if t.dim() == 0:
value = t.asnumpy()
return Tensor([t], dtype=dtype)
return Tensor([value], dtype=hint_dtype)
#convert the type of tensor to dtype
t.set_dtype(dtype)
return t
if isinstance(t, (list, np.ndarray)):
return Tensor(t, dtype=dtype)
return Tensor(t, dtype=hint_dtype)
if np.isscalar(t):
return Tensor([t], dtype=dtype)
return Tensor([t], dtype=hint_dtype)
raise RuntimeError("Input type is not supported.")
def convert_to_batch(t, batch_shape, dtype):
def convert_to_batch(t, batch_shape, hint_dtype):
"""
Convert a Tensor to a given batch shape.
Args:
t (Tensor, Parameter): Tensor to be converted.
t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted.
batch_shape (tuple): desired batch shape.
dtype (mindspore.dtype): desired dtype.
......@@ -71,9 +71,8 @@ def convert_to_batch(t, batch_shape, dtype):
"""
if isinstance(t, Parameter):
return t
if isinstance(t, Tensor):
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=dtype)
return Tensor(np.broadcast_to(t, batch_shape), dtype=dtype)
t = cast_to_tensor(t, hint_dtype)
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=hint_dtype)
def check_scalar_from_param(params):
"""
......@@ -85,6 +84,8 @@ def check_scalar_from_param(params):
Notes: String parameters are excluded.
"""
for value in params.values():
if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
return params['distribution'].is_scalar_batch
if isinstance(value, Parameter):
return False
if isinstance(value, (str, type(params['dtype']))):
......@@ -108,6 +109,8 @@ def calc_broadcast_shape_from_param(params):
"""
broadcast_shape = []
for value in params.values():
if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
return params['distribution'].broadcast_shape
if isinstance(value, (str, type(params['dtype']))):
continue
if value is None:
......@@ -251,3 +254,7 @@ def check_tensor_type(name, inputs, valid_type):
inputs = P.DType()(inputs)
if inputs not in valid_type:
raise TypeError(f"{name} dtype is invalid")
def check_type(data_type, value_type, name):
if not data_type in value_type:
raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid")
......@@ -16,7 +16,7 @@
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ._utils.utils import cast_to_tensor, check_prob, check_type
class Bernoulli(Distribution):
"""
......@@ -95,13 +95,14 @@ class Bernoulli(Distribution):
Constructor of Bernoulli distribution.
"""
param = dict(locals())
super(Bernoulli, self).__init__(dtype, name, param)
valid_dtype = mstype.int_type + mstype.uint_type
check_type(dtype, valid_dtype, "Bernoulli")
super(Bernoulli, self).__init__(seed, dtype, name, param)
if probs is not None:
self._probs = cast_to_tensor(probs, dtype=mstype.float32)
self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
check_prob(self.probs)
else:
self._probs = probs
self.seed = seed
# ops needed for the class
self.cast = P.Cast()
......@@ -231,8 +232,8 @@ class Bernoulli(Distribution):
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
.. math::
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b})
KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
probs0_a * \log(\frac{probs0_a}{probs0_b})
"""
if dist == 'Bernoulli':
probs1_a = self.probs if probs1_a is None else probs1_a
......
......@@ -14,6 +14,7 @@
# ============================================================================
"""basic"""
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
class Distribution(Cell):
......@@ -38,6 +39,7 @@ class Distribution(Cell):
original distribuion.
"""
def __init__(self,
seed,
dtype,
name,
param):
......@@ -46,7 +48,11 @@ class Distribution(Cell):
Constructor of distribution class.
"""
super(Distribution, self).__init__()
validator.check_value_type('name', name, [str], 'distribution_name')
validator.check_value_type('seed', seed, [int], name)
self._name = name
self._seed = seed
self._dtype = dtype
self._parameters = {}
# parsing parameters
......@@ -77,6 +83,10 @@ class Distribution(Cell):
def dtype(self):
return self._dtype
@property
def seed(self):
return self._seed
@property
def parameters(self):
return self._parameters
......@@ -85,6 +95,10 @@ class Distribution(Cell):
def is_scalar_batch(self):
return self._is_scalar_batch
@property
def broadcast_shape(self):
return self._broadcast_shape
def _set_prob(self):
"""
Set probability funtion based on the availability of _prob and _log_likehood.
......
......@@ -17,7 +17,7 @@ import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
class Exponential(Distribution):
"""
......@@ -96,9 +96,11 @@ class Exponential(Distribution):
Constructor of Exponential distribution.
"""
param = dict(locals())
super(Exponential, self).__init__(dtype, name, param)
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Exponential")
super(Exponential, self).__init__(seed, dtype, name, param)
if rate is not None:
self._rate = cast_to_tensor(rate, mstype.float32)
self._rate = cast_to_tensor(rate, dtype)
check_greater_zero(self._rate, "rate")
else:
self._rate = rate
......@@ -135,7 +137,7 @@ class Exponential(Distribution):
def _mean(self, rate=None):
r"""
.. math::
MEAN(EXP) = \fract{1.0}{\lambda}.
MEAN(EXP) = \frac{1.0}{\lambda}.
"""
rate = self.rate if rate is None else rate
return 1.0 / rate
......@@ -152,7 +154,7 @@ class Exponential(Distribution):
def _sd(self, rate=None):
r"""
.. math::
sd(EXP) = \fract{1.0}{\lambda}.
sd(EXP) = \frac{1.0}{\lambda}.
"""
rate = self.rate if rate is None else rate
return 1.0 / rate
......
......@@ -17,7 +17,7 @@ import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ._utils.utils import cast_to_tensor, check_prob, check_type
class Geometric(Distribution):
"""
......@@ -97,9 +97,11 @@ class Geometric(Distribution):
Constructor of Geometric distribution.
"""
param = dict(locals())
super(Geometric, self).__init__(dtype, name, param)
valid_dtype = mstype.int_type + mstype.uint_type
check_type(dtype, valid_dtype, "Geometric")
super(Geometric, self).__init__(seed, dtype, name, param)
if probs is not None:
self._probs = cast_to_tensor(probs, dtype=mstype.float32)
self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
check_prob(self._probs)
else:
self._probs = probs
......@@ -154,7 +156,7 @@ class Geometric(Distribution):
def _var(self, probs1=None):
r"""
.. math::
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
"""
probs1 = self.probs if probs1 is None else probs1
return (1.0 - probs1) / self.sq(probs1)
......@@ -162,7 +164,7 @@ class Geometric(Distribution):
def _entropy(self, probs=None):
r"""
.. math::
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
"""
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
......@@ -244,7 +246,7 @@ class Geometric(Distribution):
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
.. math::
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
"""
if dist == 'Geometric':
probs1_a = self.probs if probs1_a is None else probs1_a
......
......@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero
from ._utils.utils import convert_to_batch, check_greater_equal_zero, check_type
class Normal(Distribution):
......@@ -100,15 +100,17 @@ class Normal(Distribution):
Constructor of normal distribution.
"""
param = dict(locals())
super(Normal, self).__init__(dtype, name, param)
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Normal")
super(Normal, self).__init__(seed, dtype, name, param)
if mean is not None and sd is not None:
self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype)
self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype)
self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
check_greater_equal_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
self._sd_value = sd
self.seed = seed
#ops needed for the class
self.const = P.ScalarToArray()
......@@ -191,7 +193,7 @@ class Normal(Distribution):
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
.. math::
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
......@@ -229,7 +231,7 @@ class Normal(Distribution):
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
.. math::
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
if dist == 'Normal':
......
......@@ -14,7 +14,11 @@
# ============================================================================
"""Transformed Distribution"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import check_type
class TransformedDistribution(Distribution):
"""
......@@ -35,12 +39,19 @@ class TransformedDistribution(Distribution):
def __init__(self,
bijector,
distribution,
dtype,
seed=0,
name="transformed_distribution"):
"""
Constructor of transformed_distribution class.
"""
param = dict(locals())
super(TransformedDistribution, self).__init__(distribution.dtype, name, param)
validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], name)
validator.check_value_type('distribution', distribution, [Distribution], name)
valid_dtype = mstype.number_type
check_type(dtype, valid_dtype, "transformed_distribution")
super(TransformedDistribution, self).__init__(seed, dtype, name, param)
self._bijector = bijector
self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian
......
......@@ -16,7 +16,7 @@
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater
from ._utils.utils import convert_to_batch, check_greater, check_type
class Uniform(Distribution):
"""
......@@ -97,10 +97,12 @@ class Uniform(Distribution):
Constructor of Uniform distribution.
"""
param = dict(locals())
super(Uniform, self).__init__(dtype, name, param)
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Uniform")
super(Uniform, self).__init__(seed, dtype, name, param)
if low is not None and high is not None:
self._low = convert_to_batch(low, self._broadcast_shape, dtype)
self._high = convert_to_batch(high, self._broadcast_shape, dtype)
self._low = convert_to_batch(low, self.broadcast_shape, dtype)
self._high = convert_to_batch(high, self.broadcast_shape, dtype)
check_greater(self.low, self.high, "low value", "high value")
else:
self._low = low
......@@ -156,7 +158,7 @@ class Uniform(Distribution):
def _mean(self, low=None, high=None):
r"""
.. math::
MEAN(U) = \fract{low + high}{2}.
MEAN(U) = \frac{low + high}{2}.
"""
low = self.low if low is None else low
high = self.high if high is None else high
......@@ -166,7 +168,7 @@ class Uniform(Distribution):
def _var(self, low=None, high=None):
r"""
.. math::
VAR(U) = \fract{(high -low) ^ 2}{12}.
VAR(U) = \frac{(high -low) ^ 2}{12}.
"""
low = self.low if low is None else low
high = self.high if high is None else high
......@@ -207,7 +209,7 @@ class Uniform(Distribution):
.. math::
pdf(x) = 0 if x < low;
pdf(x) = \fract{1.0}{high -low} if low <= x <= high;
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high;
"""
low = self.low if low is None else low
......@@ -251,7 +253,7 @@ class Uniform(Distribution):
.. math::
cdf(x) = 0 if x < low;
cdf(x) = \fract{x - low}{high -low} if low <= x <= high;
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high;
"""
low = self.low if low is None else low
......
......@@ -31,6 +31,18 @@ def test_arguments():
b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
assert isinstance(b, msd.Distribution)
def test_type():
with pytest.raises(TypeError):
msd.Bernoulli([0.1], dtype=dtype.float32)
def test_name():
with pytest.raises(TypeError):
msd.Bernoulli([0.1], name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.Bernoulli([0.1], seed='seed')
def test_prob():
"""
Invalid probability.
......
......@@ -32,6 +32,18 @@ def test_arguments():
e = msd.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32)
assert isinstance(e, msd.Distribution)
def test_type():
with pytest.raises(TypeError):
msd.Exponential([0.1], dtype=dtype.int32)
def test_name():
with pytest.raises(TypeError):
msd.Exponential([0.1], name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.Exponential([0.1], seed='seed')
def test_rate():
"""
Invalid rate.
......
......@@ -32,6 +32,18 @@ def test_arguments():
g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
assert isinstance(g, msd.Distribution)
def test_type():
with pytest.raises(TypeError):
msd.Geometric([0.1], dtype=dtype.float32)
def test_name():
with pytest.raises(TypeError):
msd.Geometric([0.1], name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.Geometric([0.1], seed='seed')
def test_prob():
"""
Invalid probability.
......
......@@ -30,6 +30,17 @@ def test_normal_shape_errpr():
with pytest.raises(ValueError):
msd.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
def test_type():
with pytest.raises(TypeError):
msd.Normal(0., 1., dtype=dtype.int32)
def test_name():
with pytest.raises(TypeError):
msd.Normal(0., 1., name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.Normal(0., 1., seed='seed')
def test_arguments():
"""
......
......@@ -30,6 +30,17 @@ def test_uniform_shape_errpr():
with pytest.raises(ValueError):
msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
def test_type():
with pytest.raises(TypeError):
msd.Uniform(0., 1., dtype=dtype.int32)
def test_name():
with pytest.raises(TypeError):
msd.Uniform(0., 1., name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.Uniform(0., 1., seed='seed')
def test_arguments():
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册