Commit 4fe8b3d3 authored by Xun Deng

fix checktensor in pynative mode

Parent b8da525f
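
In graph mode, CheckTensor is a PrimitiveWithInfer whose type check runs at compile time through its infer method, and the value it returns from a call is not meant to be used; in PyNative mode the primitive's `__call__` executes eagerly, so it must do the check itself and hand the input back. The diff threads a mode-aware helper through Bijector and Distribution. A minimal standalone sketch of the pattern, assuming GRAPH_MODE == 0 as reported by `context.get_context('mode')`, with the absolute module path inferred from the relative imports in this diff:

```python
from mindspore import context
from mindspore.nn.probability.distribution._utils.utils import CheckTensor

class ModeAwareValueCheck:  # hypothetical host class; Bijector/Distribution play this role below
    def __init__(self):
        self.context_mode = context.get_context('mode')  # GRAPH_MODE == 0
        self.checktensor = CheckTensor()

    def _check_value(self, value, name):
        """Check availability of value as a Tensor, as in the diff below."""
        if self.context_mode == 0:
            # Graph mode: the primitive's __infer__ validates the dtype;
            # the call's return value is a compile-time artifact, so pass
            # the original `value` through.
            self.checktensor(value, name)
            return value
        # PyNative mode: __call__ runs eagerly and returns the checked input.
        return self.checktensor(value, name)
```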
......@@ -13,8 +13,10 @@
# limitations under the License.
# ============================================================================
"""Bijector"""
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import CheckTensor
from ..distribution import Distribution
from ..distribution import TransformedDistribution
......@@ -40,7 +42,7 @@ class Bijector(Cell):
Constructor of bijector class.
"""
super(Bijector, self).__init__()
validator.check_value_type('name', name, [str], 'Bijector')
validator.check_value_type('name', name, [str], type(self).__name__)
validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name)
validator.check_value_type('is_injective', is_injective, [bool], name)
self._name = name
......@@ -53,6 +55,9 @@ class Bijector(Cell):
self._is_constant_jacobian = is_constant_jacobian
self._is_injective = is_injective
self.context_mode = context.get_context('mode')
self.checktensor = CheckTensor()
@property
def name(self):
return self._name
......@@ -73,6 +78,15 @@ class Bijector(Cell):
def is_injective(self):
return self._is_injective
def _check_value(self, value, name):
"""
Check availability of value as a Tensor.
"""
if self.context_mode == 0:
self.checktensor(value, name)
return value
return self.checktensor(value, name)
def forward(self, *args, **kwargs):
"""
Forward transformation: transform the input value to another distribution.
......
......@@ -16,7 +16,6 @@
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ..distribution._utils.utils import CheckTensor
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from .bijector import Bijector
......@@ -66,8 +65,6 @@ class PowerTransform(Bijector):
self.log = log_generic
self.log1p = log1p_generic
self.checktensor = CheckTensor()
@property
def power(self):
return self._power
......@@ -80,13 +77,13 @@ class PowerTransform(Bijector):
return shape
def _forward(self, x):
self.checktensor(x, 'value')
x = self._check_value(x, 'value')
if self.power == 0:
return self.exp(x)
return self.exp(self.log1p(x * self.power) / self.power)
def _inverse(self, y):
self.checktensor(y, 'value')
y = self._check_value(y, 'value')
if self.power == 0:
return self.log(y)
return self.expm1(self.log(y) * self.power) / self.power
......@@ -103,7 +100,7 @@ class PowerTransform(Bijector):
f'(x) = e^{\frac{\log(xc + 1)}{c}} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
"""
self.checktensor(x, 'value')
x = self._check_value(x, 'value')
if self.power == 0:
return x
return (1. / self.power - 1) * self.log1p(x * self.power)
......@@ -120,5 +117,5 @@ class PowerTransform(Bijector):
f'(y) = \frac{e^{c\log(y)}}{y}
\log(f'(y)) = \log(\frac{e^{c\log(y)}}{y}) = (c - 1) * \log(y)
"""
self.checktensor(y, 'value')
y = self._check_value(y, 'value')
return (self.power - 1) * self.log(y)
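
The PowerTransform formulas above are easy to sanity-check outside MindSpore. A NumPy sketch (plain functions standing in for the bijector's methods) confirming that `_inverse` undoes `_forward` and that the analytic log-Jacobian matches a numerical derivative:

```python
import numpy as np

c = 0.5                                 # the bijector's `power`
x = np.array([0.3, 1.7, 4.0])

forward = lambda x: np.exp(np.log1p(x * c) / c)      # f(x) = (1 + cx)^{1/c}
inverse = lambda y: np.expm1(np.log(y) * c) / c      # f^{-1}(y) = (y^c - 1) / c
fwd_log_jac = lambda x: (1. / c - 1) * np.log1p(x * c)

y = forward(x)
assert np.allclose(inverse(y), x)                    # round trip
eps = 1e-6
numeric = (forward(x + eps) - forward(x - eps)) / (2 * eps)
assert np.allclose(fwd_log_jac(x), np.log(numeric), atol=1e-4)
```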
......@@ -15,7 +15,7 @@
"""Scalar Affine Bijector"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.utils import cast_to_tensor
from ..distribution._utils.custom_ops import log_generic
from .bijector import Bijector
......@@ -57,8 +57,8 @@ class ScalarAffine(Bijector):
Constructor of scalar affine bijector.
"""
param = dict(locals())
validator.check_value_type('scale', scale, [int, float], name)
validator.check_value_type('shift', shift, [int, float], name)
validator.check_value_type('scale', scale, [int, float], type(self).__name__)
validator.check_value_type('shift', shift, [int, float], type(self).__name__)
self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)
super(ScalarAffine, self).__init__(
......@@ -71,8 +71,6 @@ class ScalarAffine(Bijector):
self.abs = P.Abs()
self.log = log_generic
self.checktensor = CheckTensor()
@property
def scale(self):
return self._scale
......@@ -93,7 +91,7 @@ class ScalarAffine(Bijector):
.. math::
f(x) = a * x + b
"""
self.checktensor(x, 'value')
x = self._check_value(x, 'value')
return self.scale * x + self.shift
def _inverse(self, y):
......@@ -101,7 +99,7 @@ class ScalarAffine(Bijector):
.. math::
f(y) = \frac{y - b}{a}
"""
self.checktensor(y, 'value')
y = self._check_value(y, 'value')
return (y - self.shift) / self.scale
def _forward_log_jacobian(self, x):
......@@ -111,7 +109,7 @@ class ScalarAffine(Bijector):
f'(x) = a
\log(f'(x)) = \log(a)
"""
self.checktensor(x, 'value')
x = self._check_value(x, 'value')
return self.log(self.abs(self.scale))
def _inverse_log_jacobian(self, y):
......@@ -121,5 +119,5 @@ class ScalarAffine(Bijector):
f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a)
"""
self.checktensor(y, 'value')
y = self._check_value(y, 'value')
return -1. * self.log(self.abs(self.scale))
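
Likewise for ScalarAffine: the identities above reduce to constants, which a few lines of plain NumPy confirm:

```python
import numpy as np

a, b = 2.0, 0.5
x = np.array([0.0, 1.0, -3.0])

y = a * x + b                              # _forward
assert np.allclose((y - b) / a, x)         # _inverse round trip

eps = 1e-6
numeric = ((a * (x + eps) + b) - (a * (x - eps) + b)) / (2 * eps)
assert np.allclose(np.log(np.abs(numeric)), np.log(np.abs(a)))  # constant log|a|
```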
......@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.utils import cast_to_tensor
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
from .bijector import Bijector
......@@ -57,7 +57,7 @@ class Softplus(Bijector):
sharpness=1.0,
name='Softplus'):
param = dict(locals())
validator.check_value_type('sharpness', sharpness, [int, float], name)
validator.check_value_type('sharpness', sharpness, [int, float], type(self).__name__)
super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness)
......@@ -76,7 +76,6 @@ class Softplus(Bijector):
self.softplus = self._softplus
self.inverse_softplus = self._inverse_softplus
self.checktensor = CheckTensor()
self.threshold = np.log(np.finfo(np.float32).eps) + 1
self.tiny = np.exp(self.threshold)
......@@ -119,7 +118,7 @@ class Softplus(Bijector):
return shape
def _forward(self, x):
self.checktensor(x, 'value')
x = self._check_value(x, 'value')
scaled_value = self.sharpness * x
return self.softplus(scaled_value) / self.sharpness
......@@ -129,7 +128,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx})}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
self.checktensor(y, 'value')
y = self._check_value(y, 'value')
scaled_value = self.sharpness * y
return self.inverse_softplus(scaled_value) / self.sharpness
......@@ -140,7 +139,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
"""
self.checktensor(x, 'value')
x = self._check_value(x, 'value')
scaled_value = self.sharpness * x
return self.log_sigmoid(scaled_value)
......@@ -151,6 +150,6 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
"""
self.checktensor(y, 'value')
y = self._check_value(y, 'value')
scaled_value = self.sharpness * y
return scaled_value - self.inverse_softplus(scaled_value)
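
The Softplus identities check out the same way; note that log f'(x) = log sigmoid(kx), which is why the code above can reuse LogSigmoid. A NumPy sketch:

```python
import numpy as np

k = 2.0                                    # the bijector's `sharpness`
x = np.array([0.1, 1.0, 3.0])

softplus = lambda z: np.log1p(np.exp(z))
forward = lambda x: softplus(k * x) / k            # f(x) = log(1 + e^{kx}) / k
inverse = lambda y: np.log(np.expm1(k * y)) / k    # f^{-1}(y) = log(e^{ky} - 1) / k

y = forward(x)
assert np.allclose(inverse(y), x)

log_jac = k * x - softplus(k * x)                  # = log sigmoid(kx)
eps = 1e-6
numeric = (forward(x + eps) - forward(x - eps)) / (2 * eps)
assert np.allclose(log_jac, np.log(numeric), atol=1e-4)
```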
......@@ -342,7 +342,7 @@ class CheckTuple(PrimitiveWithInfer):
# Pynative mode
if isinstance(x, tuple):
return x
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
raise TypeError(f"For {name}, input type should be a tuple.")
class CheckTensor(PrimitiveWithInfer):
......@@ -365,4 +365,6 @@ class CheckTensor(PrimitiveWithInfer):
return out
def __call__(self, x, name):
return
if isinstance(x, Tensor):
return x
raise TypeError(f"For {name}, input type should be a Tensor.")
......@@ -99,7 +99,7 @@ class Bernoulli(Distribution):
"""
param = dict(locals())
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, "Bernoulli")
check_type(dtype, valid_dtype, type(self).__name__)
super(Bernoulli, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
if probs is not None:
......@@ -144,7 +144,10 @@ class Bernoulli(Distribution):
Check availability of distribution specific args probs1.
"""
if probs1 is not None:
self.checktensor(probs1, 'probs1')
if self.context_mode == 0:
self.checktensor(probs1, 'probs1')
else:
probs1 = self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
......@@ -210,7 +213,7 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
probs1 = self._check_param(probs1)
probs0 = 1.0 - probs1
......@@ -229,7 +232,7 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k < 1;
cdf(k) = 1 if k >= 1;
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
......@@ -257,7 +260,7 @@ class Bernoulli(Distribution):
probs0_a * \log(\frac{probs0_a}{probs0_b})
"""
check_distribution_name(dist, 'Bernoulli')
self.checktensor(probs1_b, 'probs1_b')
probs1_b = self._check_value(probs1_b, 'probs1_b')
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self._check_param(probs1)
probs0_a = 1.0 - probs1_a
......
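
With the returned-value contract in place, a Bernoulli distribution can be exercised directly in PyNative mode. A sketch using the public API (argument names follow the constructor in this diff):

```python
import numpy as np
from mindspore import Tensor, context
import mindspore.nn.probability.distribution as msd

context.set_context(mode=context.PYNATIVE_MODE)

b = msd.Bernoulli(probs=0.7)
value = Tensor(np.array([0., 1., 1.], dtype=np.float32))
print(b.prob(value))   # expected: [0.3 0.7 0.7]
```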
......@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""basic"""
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
......@@ -54,7 +55,7 @@ class Distribution(Cell):
Constructor of distribution class.
"""
super(Distribution, self).__init__()
validator.check_value_type('name', name, [str], 'distribution_name')
validator.check_value_type('name', name, [str], type(self).__name__)
validator.check_integer('seed', seed, 0, Rel.GE, name)
self._name = name
......@@ -81,6 +82,7 @@ class Distribution(Cell):
self._set_log_survival()
self._set_cross_entropy()
self.context_mode = context.get_context('mode')
self.checktuple = CheckTuple()
self.checktensor = CheckTensor()
......@@ -108,6 +110,15 @@ class Distribution(Cell):
def broadcast_shape(self):
return self._broadcast_shape
def _check_value(self, value, name):
"""
Check availability of value as a Tensor.
"""
if self.context_mode == 0:
self.checktensor(value, name)
return value
return self.checktensor(value, name)
def _set_prob(self):
"""
Set probability function based on the availability of _prob and _log_likelihood.
......
......@@ -100,7 +100,7 @@ class Exponential(Distribution):
"""
param = dict(locals())
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Exponential")
check_type(dtype, valid_dtype, type(self).__name__)
super(Exponential, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
if rate is not None:
......@@ -146,7 +146,10 @@ class Exponential(Distribution):
Check availability of distribution specific args rate.
"""
if rate is not None:
self.checktensor(rate, 'rate')
if self.context_mode == 0:
self.checktensor(rate, 'rate')
else:
rate = self.checktensor(rate, 'rate')
return self.cast(rate, self.parameter_type)
return self.rate if self.rate is not None else raise_none_error('rate')
......@@ -210,7 +213,7 @@ class Exponential(Distribution):
.. math::
pdf(x) = \lambda * \exp(-1 * \lambda * x) if x >= 0 else 0
"""
self.checktensor(value, "value")
value = self._check_value(value, "value")
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
prob = self.exp(self.log(rate) - rate * value)
......@@ -232,7 +235,7 @@ class Exponential(Distribution):
.. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
cdf = 1.0 - self.exp(-1. * rate * value)
......@@ -251,7 +254,7 @@ class Exponential(Distribution):
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
check_distribution_name(dist, 'Exponential')
self.checktensor(rate_b, 'rate_b')
rate_b = self._check_value(rate_b, 'rate_b')
rate_b = self.cast(rate_b, self.parameter_type)
rate_a = self._check_param(rate)
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
......
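
The KL expression computed above, log(λa) - log(λb) + λb/λa - 1, can be cross-checked by numerical integration:

```python
import numpy as np

rate_a, rate_b = 2.0, 0.5
kl_closed = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0

x = np.linspace(1e-8, 60.0, 1_000_000)
pdf = lambda r: r * np.exp(-r * x)
kl_numeric = np.trapz(pdf(rate_a) * (np.log(pdf(rate_a)) - np.log(pdf(rate_b))), x)
assert np.isclose(kl_closed, kl_numeric, atol=1e-3)
```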
......@@ -102,7 +102,7 @@ class Geometric(Distribution):
"""
param = dict(locals())
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, "Geometric")
check_type(dtype, valid_dtype, type(self).__name__)
super(Geometric, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
if probs is not None:
......@@ -150,7 +150,10 @@ class Geometric(Distribution):
Check availability of distribution specific args probs1.
"""
if probs1 is not None:
self.checktensor(probs1, 'probs1')
if self.context_mode == 0:
self.checktensor(probs1, 'probs1')
else:
probs1 = self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
......@@ -211,7 +214,7 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0.
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
......@@ -233,7 +236,7 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0.
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
......@@ -256,7 +259,7 @@ class Geometric(Distribution):
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
"""
check_distribution_name(dist, 'Geometric')
self.checktensor(probs1_b, 'probs1_b')
probs1_b = self._check_value(probs1_b, 'probs1_b')
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self._check_param(probs1)
probs0_a = 1.0 - probs1_a
......
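
Similarly, the Geometric KL formula in the docstring above agrees with direct summation over the support (k counts failures before the first success, so k >= 0):

```python
import numpy as np

p_a, p_b = 0.6, 0.3
q_a, q_b = 1.0 - p_a, 1.0 - p_b

kl_closed = np.log(p_a / p_b) + (q_a / p_a) * np.log(q_a / q_b)

k = np.arange(0, 200)
pmf_a = q_a ** k * p_a
pmf_b = q_b ** k * p_b
assert np.isclose(kl_closed, np.sum(pmf_a * np.log(pmf_a / pmf_b)))
```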
......@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
......@@ -102,12 +102,12 @@ class Normal(Distribution):
"""
param = dict(locals())
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Normal")
check_type(dtype, valid_dtype, type(self).__name__)
super(Normal, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
if mean is not None and sd is not None:
self._mean_value = convert_to_batch(mean, self.broadcast_shape, self.parameter_type)
self._sd_value = convert_to_batch(sd, self.broadcast_shape, self.parameter_type)
self._mean_value = cast_to_tensor(mean, self.parameter_type)
self._sd_value = cast_to_tensor(sd, self.parameter_type)
check_greater_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
......@@ -139,12 +139,18 @@ class Normal(Distribution):
Check availability of distribution specific args mean and sd.
"""
if mean is not None:
self.checktensor(mean, 'mean')
if self.context_mode == 0:
self.checktensor(mean, 'mean')
else:
mean = self.checktensor(mean, 'mean')
mean = self.cast(mean, self.parameter_type)
else:
mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
if sd is not None:
self.checktensor(sd, 'sd')
if self.context_mode == 0:
self.checktensor(sd, 'sd')
else:
sd = self.checktensor(sd, 'sd')
sd = self.cast(sd, self.parameter_type)
else:
sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
......@@ -210,7 +216,7 @@ class Normal(Distribution):
.. math::
L(x) = -\frac{(x - \mu)^2}{2 \sigma^2} - \log(\sqrt{2 \pi \sigma^2})
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
mean, sd = self._check_param(mean, sd)
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
......@@ -229,7 +235,7 @@ class Normal(Distribution):
.. math::
cdf(x) = 0.5 * (1 + Erf(\frac{x - \mu}{\sigma \sqrt{2}}))
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
mean, sd = self._check_param(mean, sd)
sqrt2 = self.sqrt(self.const(2.0))
......@@ -252,8 +258,8 @@ class Normal(Distribution):
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
check_distribution_name(dist, 'Normal')
self.checktensor(mean_b, 'mean_b')
self.checktensor(sd_b, 'sd_b')
mean_b = self._check_value(mean_b, 'mean_b')
sd_b = self._check_value(sd_b, 'sd_b')
mean_b = self.cast(mean_b, self.parameter_type)
sd_b = self.cast(sd_b, self.parameter_type)
mean_a, sd_a = self._check_param(mean, sd)
......
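
The truncated docstring fragment above appears to correspond to the closed-form Gaussian KL the code computes: 0.5 * ((μa - μb)/σb)^2 + 0.5 * expm1(2Δ) - Δ with Δ = log σa - log σb. A NumPy sketch showing it matches the textbook form:

```python
import numpy as np

def kl_normal(mu_a, sd_a, mu_b, sd_b):
    d = np.log(sd_a) - np.log(sd_b)
    return 0.5 * ((mu_a - mu_b) / sd_b) ** 2 + 0.5 * np.expm1(2.0 * d) - d

# Textbook form: log(sd_b/sd_a) + (sd_a^2 + (mu_a - mu_b)^2) / (2 sd_b^2) - 0.5
def kl_reference(mu_a, sd_a, mu_b, sd_b):
    return (np.log(sd_b / sd_a)
            + (sd_a ** 2 + (mu_a - mu_b) ** 2) / (2.0 * sd_b ** 2) - 0.5)

assert np.isclose(kl_normal(0.0, 1.0, 1.0, 2.0), kl_reference(0.0, 1.0, 1.0, 2.0))
```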
......@@ -46,10 +46,10 @@ class TransformedDistribution(Distribution):
Constructor of transformed_distribution class.
"""
param = dict(locals())
validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], name)
validator.check_value_type('distribution', distribution, [Distribution], name)
validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], type(self).__name__)
validator.check_value_type('distribution', distribution, [Distribution], type(self).__name__)
valid_dtype = mstype.number_type
check_type(dtype, valid_dtype, "transformed_distribution")
check_type(dtype, valid_dtype, type(self).__name__)
super(TransformedDistribution, self).__init__(seed, dtype, name, param)
self._bijector = bijector
......
......@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
raise_none_error
from ._utils.custom_ops import exp_generic, log_generic
......@@ -101,12 +101,12 @@ class Uniform(Distribution):
"""
param = dict(locals())
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Uniform")
check_type(dtype, valid_dtype, type(self).__name__)
super(Uniform, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
if low is not None and high is not None:
self._low = convert_to_batch(low, self.broadcast_shape, dtype)
self._high = convert_to_batch(high, self.broadcast_shape, dtype)
self._low = cast_to_tensor(low, dtype)
self._high = cast_to_tensor(high, dtype)
check_greater(self.low, self.high, "low value", "high value")
else:
self._low = low
......@@ -142,12 +142,18 @@ class Uniform(Distribution):
Check availability of distribution specific args low and high.
"""
if low is not None:
self.checktensor(low, 'low')
if self.context_mode == 0:
self.checktensor(low, 'low')
else:
low = self.checktensor(low, 'low')
low = self.cast(low, self.parameter_type)
else:
low = self.low if self.low is not None else raise_none_error('low')
if high is not None:
self.checktensor(high, 'high')
if self.context_mode == 0:
self.checktensor(high, 'high')
else:
high = self.checktensor(high, 'high')
high = self.cast(high, self.parameter_type)
else:
high = self.high if self.high is not None else raise_none_error('high')
......@@ -231,7 +237,7 @@ class Uniform(Distribution):
pdf(x) = \frac{1.0}{high - low} if low <= x <= high;
pdf(x) = 0 if x > high;
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
low, high = self._check_param(low, high)
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
......@@ -255,9 +261,9 @@ class Uniform(Distribution):
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
check_distribution_name(dist, 'Uniform')
self.checktensor(low_b, 'low_b')
low_b = self._check_value(low_b, 'low_b')
low_b = self.cast(low_b, self.parameter_type)
self.checktensor(high_b, 'high_b')
high_b = self._check_value(high_b, 'high_b')
high_b = self.cast(high_b, self.parameter_type)
low_a, high_a = self._check_param(low, high)
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
......@@ -278,7 +284,7 @@ class Uniform(Distribution):
cdf(x) = \frac{x - low}{high - low} if low <= x <= high;
cdf(x) = 1 if x > high;
"""
self.checktensor(value, 'value')
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
low, high = self._check_param(low, high)
prob = (value - low) / (high - low)
......
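
Taken together, the changes make the PyNative path both usable and strict. A closing sketch (hypothetical session; error text follows the new messages in this diff):

```python
import numpy as np
from mindspore import Tensor, context
import mindspore.nn.probability.distribution as msd

context.set_context(mode=context.PYNATIVE_MODE)

u = msd.Uniform(low=0.0, high=2.0)
print(u.cdf(Tensor(np.array([0.5, 1.5], np.float32))))  # [0.25 0.75]

try:
    u.cdf([0.5, 1.5])   # a plain list is rejected up front
except TypeError as e:
    print(e)            # For value, input type should be a Tensor.
```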