提交 dc11fa9f 编写于 作者: X Xun Deng

Fixed CheckTuple issues and error message

上级 04decda0
...@@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter ...@@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore import context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability as msp import mindspore.nn.probability as msp
...@@ -273,7 +274,8 @@ def check_type(data_type, value_type, name): ...@@ -273,7 +274,8 @@ def check_type(data_type, value_type, name):
@constexpr @constexpr
def raise_none_error(name): def raise_none_error(name):
raise ValueError(f"{name} should be specified. Value cannot be None") raise TypeError(f"the type {name} should be subclass of Tensor."
f" It should not be None since it is not specified during initialization.")
@constexpr @constexpr
def raise_not_impl_error(name): def raise_not_impl_error(name):
...@@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer): ...@@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer):
def __infer__(self, x, name): def __infer__(self, x, name):
if not isinstance(x['dtype'], tuple): if not isinstance(x['dtype'], tuple):
raise TypeError("Input type should be a tuple: " + name["value"]) raise TypeError(f"For {name['value']}, Input type should b a tuple.")
out = {'shape': None, out = {'shape': None,
'dtype': None, 'dtype': None,
'value': None} 'value': x["value"]}
return out return out
def __call__(self, *args): def __call__(self, x, name):
return if context.get_context("mode") == 0:
return x["value"]
#Pynative mode
if isinstance(x, tuple):
return x
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
class CheckTensor(PrimitiveWithInfer): class CheckTensor(PrimitiveWithInfer):
""" """
...@@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer): ...@@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer):
'value': None} 'value': None}
return out return out
def __call__(self, *args): def __call__(self, x, name):
return return
...@@ -18,7 +18,6 @@ from mindspore.ops import operations as P ...@@ -18,7 +18,6 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
from ._utils.custom_ops import log_by_step from ._utils.custom_ops import log_by_step
class Bernoulli(Distribution): class Bernoulli(Distribution):
...@@ -125,9 +124,6 @@ class Bernoulli(Distribution): ...@@ -125,9 +124,6 @@ class Bernoulli(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'probs = {self.probs}' str_info = f'probs = {self.probs}'
...@@ -279,7 +275,7 @@ class Bernoulli(Distribution): ...@@ -279,7 +275,7 @@ class Bernoulli(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
self.checktuple(shape, 'shape') shape = self.checktuple(shape, 'shape')
probs1 = self._check_param(probs1) probs1 = self._check_param(probs1)
origin_shape = shape + self.shape(probs1) origin_shape = shape + self.shape(probs1)
if origin_shape == (): if origin_shape == ():
......
...@@ -17,6 +17,7 @@ from mindspore.nn.cell import Cell ...@@ -17,6 +17,7 @@ from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
from ._utils.utils import CheckTuple, CheckTensor
class Distribution(Cell): class Distribution(Cell):
""" """
...@@ -79,6 +80,9 @@ class Distribution(Cell): ...@@ -79,6 +80,9 @@ class Distribution(Cell):
self._set_log_survival() self._set_log_survival()
self._set_cross_entropy() self._set_cross_entropy()
self.checktuple = CheckTuple()
self.checktensor = CheckTensor()
@property @property
def name(self): def name(self):
return self._name return self._name
......
...@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype ...@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
from ._utils.custom_ops import log_by_step from ._utils.custom_ops import log_by_step
class Exponential(Distribution): class Exponential(Distribution):
...@@ -127,8 +126,6 @@ class Exponential(Distribution): ...@@ -127,8 +126,6 @@ class Exponential(Distribution):
self.sq = P.Square() self.sq = P.Square()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
...@@ -270,7 +267,7 @@ class Exponential(Distribution): ...@@ -270,7 +267,7 @@ class Exponential(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
self.checktuple(shape, 'shape') shape = self.checktuple(shape, 'shape')
rate = self._check_param(rate) rate = self._check_param(rate)
origin_shape = shape + self.shape(rate) origin_shape = shape + self.shape(rate)
if origin_shape == (): if origin_shape == ():
......
...@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype ...@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
from ._utils.custom_ops import log_by_step from ._utils.custom_ops import log_by_step
class Geometric(Distribution): class Geometric(Distribution):
...@@ -131,8 +130,6 @@ class Geometric(Distribution): ...@@ -131,8 +130,6 @@ class Geometric(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
...@@ -278,7 +275,7 @@ class Geometric(Distribution): ...@@ -278,7 +275,7 @@ class Geometric(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
self.checktuple(shape, 'shape') shape = self.checktuple(shape, 'shape')
probs1 = self._check_param(probs1) probs1 = self._check_param(probs1)
origin_shape = shape + self.shape(probs1) origin_shape = shape + self.shape(probs1)
if origin_shape == (): if origin_shape == ():
......
...@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype ...@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
from ._utils.custom_ops import log_by_step, expm1_by_step from ._utils.custom_ops import log_by_step, expm1_by_step
class Normal(Distribution): class Normal(Distribution):
...@@ -128,9 +127,6 @@ class Normal(Distribution): ...@@ -128,9 +127,6 @@ class Normal(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
...@@ -277,7 +273,7 @@ class Normal(Distribution): ...@@ -277,7 +273,7 @@ class Normal(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
self.checktuple(shape, 'shape') shape = self.checktuple(shape, 'shape')
mean, sd = self._check_param(mean, sd) mean, sd = self._check_param(mean, sd)
batch_shape = self.shape(mean + sd) batch_shape = self.shape(mean + sd)
origin_shape = shape + batch_shape origin_shape = shape + batch_shape
......
...@@ -116,4 +116,4 @@ class TransformedDistribution(Distribution): ...@@ -116,4 +116,4 @@ class TransformedDistribution(Distribution):
if not self.is_linear_transformation: if not self.is_linear_transformation:
raise_not_impl_error("mean") raise_not_impl_error("mean")
return self.bijector("forward", self.distribution("mean")) return self.bijector("forward", self.distribution("mean", *args, **kwargs))
...@@ -19,7 +19,6 @@ from mindspore.common import dtype as mstype ...@@ -19,7 +19,6 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
from ._utils.custom_ops import log_by_step from ._utils.custom_ops import log_by_step
class Uniform(Distribution): class Uniform(Distribution):
...@@ -131,9 +130,6 @@ class Uniform(Distribution): ...@@ -131,9 +130,6 @@ class Uniform(Distribution):
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'low = {self.low}, high = {self.high}' str_info = f'low = {self.low}, high = {self.high}'
...@@ -306,7 +302,7 @@ class Uniform(Distribution): ...@@ -306,7 +302,7 @@ class Uniform(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
self.checktuple(shape, 'shape') shape = self.checktuple(shape, 'shape')
low, high = self._check_param(low, high) low, high = self._check_param(low, high)
broadcast_shape = self.shape(low + high) broadcast_shape = self.shape(low + high)
origin_shape = shape + broadcast_shape origin_shape = shape + broadcast_shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册