提交 65c6755f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3691 Change distribution import

Merge pull request !3691 from XunDeng/pp_poc_v3
...@@ -24,7 +24,6 @@ from .loss import * ...@@ -24,7 +24,6 @@ from .loss import *
from .optim import * from .optim import *
from .metrics import * from .metrics import *
from .wrap import * from .wrap import *
from .probability import *
__all__ = ["Cell", "GraphKernel"] __all__ = ["Cell", "GraphKernel"]
...@@ -33,7 +32,7 @@ __all__.extend(loss.__all__) ...@@ -33,7 +32,7 @@ __all__.extend(loss.__all__)
__all__.extend(optim.__all__) __all__.extend(optim.__all__)
__all__.extend(metrics.__all__) __all__.extend(metrics.__all__)
__all__.extend(wrap.__all__) __all__.extend(wrap.__all__)
__all__.extend(probability.__all__)
__all__.sort() __all__.sort()
...@@ -15,10 +15,7 @@ ...@@ -15,10 +15,7 @@
""" """
Probability. Probability.
The high-level components(Distributions) used to construct the probabilistic network. The high-level components used to construct the probabilistic network.
""" """
from .distribution import * from . import distribution
__all__ = []
__all__.extend(distribution.__all__)
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test cases for bernoulli distribution""" """test cases for Bernoulli distribution"""
import numpy as np import numpy as np
from scipy import stats from scipy import stats
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
...@@ -29,7 +30,7 @@ class Prob(nn.Cell): ...@@ -29,7 +30,7 @@ class Prob(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Prob, self).__init__() super(Prob, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -54,7 +55,7 @@ class LogProb(nn.Cell): ...@@ -54,7 +55,7 @@ class LogProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -78,7 +79,7 @@ class KL(nn.Cell): ...@@ -78,7 +79,7 @@ class KL(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(KL, self).__init__() super(KL, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -104,7 +105,7 @@ class Basics(nn.Cell): ...@@ -104,7 +105,7 @@ class Basics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Basics, self).__init__() super(Basics, self).__init__()
self.b = nn.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -130,7 +131,7 @@ class Sampling(nn.Cell): ...@@ -130,7 +131,7 @@ class Sampling(nn.Cell):
""" """
def __init__(self, shape, seed=0): def __init__(self, shape, seed=0):
super(Sampling, self).__init__() super(Sampling, self).__init__()
self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape self.shape = shape
@ms_function @ms_function
...@@ -152,7 +153,7 @@ class CDF(nn.Cell): ...@@ -152,7 +153,7 @@ class CDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CDF, self).__init__() super(CDF, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -177,7 +178,7 @@ class LogCDF(nn.Cell): ...@@ -177,7 +178,7 @@ class LogCDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -202,7 +203,7 @@ class SF(nn.Cell): ...@@ -202,7 +203,7 @@ class SF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(SF, self).__init__() super(SF, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -227,7 +228,7 @@ class LogSF(nn.Cell): ...@@ -227,7 +228,7 @@ class LogSF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -251,7 +252,7 @@ class EntropyH(nn.Cell): ...@@ -251,7 +252,7 @@ class EntropyH(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -274,7 +275,7 @@ class CrossEntropy(nn.Cell): ...@@ -274,7 +275,7 @@ class CrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
......
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test cases for exponential distribution""" """test cases for Exponential distribution"""
import numpy as np import numpy as np
from scipy import stats from scipy import stats
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
...@@ -29,7 +30,7 @@ class Prob(nn.Cell): ...@@ -29,7 +30,7 @@ class Prob(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Prob, self).__init__() super(Prob, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -53,7 +54,7 @@ class LogProb(nn.Cell): ...@@ -53,7 +54,7 @@ class LogProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -77,7 +78,7 @@ class KL(nn.Cell): ...@@ -77,7 +78,7 @@ class KL(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(KL, self).__init__() super(KL, self).__init__()
self.e = nn.Exponential([1.5], dtype=dtype.float32) self.e = msd.Exponential([1.5], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -101,7 +102,7 @@ class Basics(nn.Cell): ...@@ -101,7 +102,7 @@ class Basics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Basics, self).__init__() super(Basics, self).__init__()
self.e = nn.Exponential([0.5], dtype=dtype.float32) self.e = msd.Exponential([0.5], dtype=dtype.float32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -127,7 +128,7 @@ class Sampling(nn.Cell): ...@@ -127,7 +128,7 @@ class Sampling(nn.Cell):
""" """
def __init__(self, shape, seed=0): def __init__(self, shape, seed=0):
super(Sampling, self).__init__() super(Sampling, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
self.shape = shape self.shape = shape
@ms_function @ms_function
...@@ -151,7 +152,7 @@ class CDF(nn.Cell): ...@@ -151,7 +152,7 @@ class CDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CDF, self).__init__() super(CDF, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -175,7 +176,7 @@ class LogCDF(nn.Cell): ...@@ -175,7 +176,7 @@ class LogCDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -199,7 +200,7 @@ class SF(nn.Cell): ...@@ -199,7 +200,7 @@ class SF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(SF, self).__init__() super(SF, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -223,7 +224,7 @@ class LogSF(nn.Cell): ...@@ -223,7 +224,7 @@ class LogSF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -247,7 +248,7 @@ class EntropyH(nn.Cell): ...@@ -247,7 +248,7 @@ class EntropyH(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -270,7 +271,7 @@ class CrossEntropy(nn.Cell): ...@@ -270,7 +271,7 @@ class CrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.e = nn.Exponential([1.0], dtype=dtype.float32) self.e = msd.Exponential([1.0], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
......
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
from scipy import stats from scipy import stats
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
...@@ -29,7 +30,7 @@ class Prob(nn.Cell): ...@@ -29,7 +30,7 @@ class Prob(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Prob, self).__init__() super(Prob, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -53,7 +54,7 @@ class LogProb(nn.Cell): ...@@ -53,7 +54,7 @@ class LogProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -77,7 +78,7 @@ class KL(nn.Cell): ...@@ -77,7 +78,7 @@ class KL(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(KL, self).__init__() super(KL, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -103,7 +104,7 @@ class Basics(nn.Cell): ...@@ -103,7 +104,7 @@ class Basics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Basics, self).__init__() super(Basics, self).__init__()
self.g = nn.Geometric([0.5, 0.5], dtype=dtype.int32) self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -129,7 +130,7 @@ class Sampling(nn.Cell): ...@@ -129,7 +130,7 @@ class Sampling(nn.Cell):
""" """
def __init__(self, shape, seed=0): def __init__(self, shape, seed=0):
super(Sampling, self).__init__() super(Sampling, self).__init__()
self.g = nn.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape self.shape = shape
@ms_function @ms_function
...@@ -151,7 +152,7 @@ class CDF(nn.Cell): ...@@ -151,7 +152,7 @@ class CDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CDF, self).__init__() super(CDF, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -175,7 +176,7 @@ class LogCDF(nn.Cell): ...@@ -175,7 +176,7 @@ class LogCDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -199,7 +200,7 @@ class SF(nn.Cell): ...@@ -199,7 +200,7 @@ class SF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(SF, self).__init__() super(SF, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -223,7 +224,7 @@ class LogSF(nn.Cell): ...@@ -223,7 +224,7 @@ class LogSF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -247,7 +248,7 @@ class EntropyH(nn.Cell): ...@@ -247,7 +248,7 @@ class EntropyH(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -270,7 +271,7 @@ class CrossEntropy(nn.Cell): ...@@ -270,7 +271,7 @@ class CrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.g = nn.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
......
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test cases for normal distribution""" """test cases for Normal distribution"""
import numpy as np import numpy as np
from scipy import stats from scipy import stats
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
...@@ -29,7 +30,7 @@ class Prob(nn.Cell): ...@@ -29,7 +30,7 @@ class Prob(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Prob, self).__init__() super(Prob, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -52,7 +53,7 @@ class LogProb(nn.Cell): ...@@ -52,7 +53,7 @@ class LogProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -76,7 +77,7 @@ class KL(nn.Cell): ...@@ -76,7 +77,7 @@ class KL(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(KL, self).__init__() super(KL, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
...@@ -110,7 +111,7 @@ class Basics(nn.Cell): ...@@ -110,7 +111,7 @@ class Basics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Basics, self).__init__() super(Basics, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -135,7 +136,7 @@ class Sampling(nn.Cell): ...@@ -135,7 +136,7 @@ class Sampling(nn.Cell):
""" """
def __init__(self, shape, seed=0): def __init__(self, shape, seed=0):
super(Sampling, self).__init__() super(Sampling, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
self.shape = shape self.shape = shape
@ms_function @ms_function
...@@ -160,7 +161,7 @@ class CDF(nn.Cell): ...@@ -160,7 +161,7 @@ class CDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CDF, self).__init__() super(CDF, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -184,7 +185,7 @@ class LogCDF(nn.Cell): ...@@ -184,7 +185,7 @@ class LogCDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -207,7 +208,7 @@ class SF(nn.Cell): ...@@ -207,7 +208,7 @@ class SF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(SF, self).__init__() super(SF, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -230,7 +231,7 @@ class LogSF(nn.Cell): ...@@ -230,7 +231,7 @@ class LogSF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -253,7 +254,7 @@ class EntropyH(nn.Cell): ...@@ -253,7 +254,7 @@ class EntropyH(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -276,7 +277,7 @@ class CrossEntropy(nn.Cell): ...@@ -276,7 +277,7 @@ class CrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import numpy as np import numpy as np
from scipy import stats from scipy import stats
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
import mindspore.context as context import mindspore.context as context
...@@ -30,7 +31,7 @@ class Net(nn.Cell): ...@@ -30,7 +31,7 @@ class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.normal = nn.Normal(0., 1., dtype=dtype.float32) self.normal = msd.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_): def construct(self, x_, y_):
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_) kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
......
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test cases for uniform distribution""" """test cases for Uniform distribution"""
import numpy as np import numpy as np
from scipy import stats from scipy import stats
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
...@@ -29,7 +30,7 @@ class Prob(nn.Cell): ...@@ -29,7 +30,7 @@ class Prob(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Prob, self).__init__() super(Prob, self).__init__()
self.u = nn.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -53,7 +54,7 @@ class LogProb(nn.Cell): ...@@ -53,7 +54,7 @@ class LogProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.u = nn.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -77,7 +78,7 @@ class KL(nn.Cell): ...@@ -77,7 +78,7 @@ class KL(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(KL, self).__init__() super(KL, self).__init__()
self.u = nn.Uniform([0.0], [1.5], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
...@@ -103,7 +104,7 @@ class Basics(nn.Cell): ...@@ -103,7 +104,7 @@ class Basics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(Basics, self).__init__() super(Basics, self).__init__()
self.u = nn.Uniform([0.0], [3.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -127,7 +128,7 @@ class Sampling(nn.Cell): ...@@ -127,7 +128,7 @@ class Sampling(nn.Cell):
""" """
def __init__(self, shape, seed=0): def __init__(self, shape, seed=0):
super(Sampling, self).__init__() super(Sampling, self).__init__()
self.u = nn.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32) self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
self.shape = shape self.shape = shape
@ms_function @ms_function
...@@ -152,7 +153,7 @@ class CDF(nn.Cell): ...@@ -152,7 +153,7 @@ class CDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CDF, self).__init__() super(CDF, self).__init__()
self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -176,7 +177,7 @@ class LogCDF(nn.Cell): ...@@ -176,7 +177,7 @@ class LogCDF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -188,7 +189,7 @@ class SF(nn.Cell): ...@@ -188,7 +189,7 @@ class SF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(SF, self).__init__() super(SF, self).__init__()
self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -200,7 +201,7 @@ class LogSF(nn.Cell): ...@@ -200,7 +201,7 @@ class LogSF(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_): def construct(self, x_):
...@@ -212,7 +213,7 @@ class EntropyH(nn.Cell): ...@@ -212,7 +213,7 @@ class EntropyH(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.u = nn.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32)
@ms_function @ms_function
def construct(self): def construct(self):
...@@ -235,7 +236,7 @@ class CrossEntropy(nn.Cell): ...@@ -235,7 +236,7 @@ class CrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.u = nn.Uniform([0.0], [1.5], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function @ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
......
...@@ -13,11 +13,12 @@ ...@@ -13,11 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
Test nn.Distribution.Bernoulli. Test nn.probability.distribution.Bernoulli.
""" """
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
...@@ -25,19 +26,19 @@ def test_arguments(): ...@@ -25,19 +26,19 @@ def test_arguments():
""" """
Args passing during initialization. Args passing during initialization.
""" """
b = nn.Bernoulli() b = msd.Bernoulli()
assert isinstance(b, nn.Distribution) assert isinstance(b, msd.Distribution)
b = nn.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
assert isinstance(b, nn.Distribution) assert isinstance(b, msd.Distribution)
def test_prob(): def test_prob():
""" """
Invalid probability. Invalid probability.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Bernoulli([-0.1], dtype=dtype.int32) msd.Bernoulli([-0.1], dtype=dtype.int32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Bernoulli([1.1], dtype=dtype.int32) msd.Bernoulli([1.1], dtype=dtype.int32)
class BernoulliProb(nn.Cell): class BernoulliProb(nn.Cell):
""" """
...@@ -45,7 +46,7 @@ class BernoulliProb(nn.Cell): ...@@ -45,7 +46,7 @@ class BernoulliProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(BernoulliProb, self).__init__() super(BernoulliProb, self).__init__()
self.b = nn.Bernoulli(0.5, dtype=dtype.int32) self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
def construct(self, value): def construct(self, value):
prob = self.b('prob', value) prob = self.b('prob', value)
...@@ -71,7 +72,7 @@ class BernoulliProb1(nn.Cell): ...@@ -71,7 +72,7 @@ class BernoulliProb1(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(BernoulliProb1, self).__init__() super(BernoulliProb1, self).__init__()
self.b = nn.Bernoulli(dtype=dtype.int32) self.b = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs): def construct(self, value, probs):
prob = self.b('prob', value, probs) prob = self.b('prob', value, probs)
...@@ -98,8 +99,8 @@ class BernoulliKl(nn.Cell): ...@@ -98,8 +99,8 @@ class BernoulliKl(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(BernoulliKl, self).__init__() super(BernoulliKl, self).__init__()
self.b1 = nn.Bernoulli(0.7, dtype=dtype.int32) self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
self.b2 = nn.Bernoulli(dtype=dtype.int32) self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
kl1 = self.b1('kl_loss', 'Bernoulli', probs_b) kl1 = self.b1('kl_loss', 'Bernoulli', probs_b)
...@@ -122,8 +123,8 @@ class BernoulliCrossEntropy(nn.Cell): ...@@ -122,8 +123,8 @@ class BernoulliCrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(BernoulliCrossEntropy, self).__init__() super(BernoulliCrossEntropy, self).__init__()
self.b1 = nn.Bernoulli(0.7, dtype=dtype.int32) self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
self.b2 = nn.Bernoulli(dtype=dtype.int32) self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
h1 = self.b1('cross_entropy', 'Bernoulli', probs_b) h1 = self.b1('cross_entropy', 'Bernoulli', probs_b)
...@@ -146,7 +147,7 @@ class BernoulliBasics(nn.Cell): ...@@ -146,7 +147,7 @@ class BernoulliBasics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(BernoulliBasics, self).__init__() super(BernoulliBasics, self).__init__()
self.b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self): def construct(self):
mean = self.b('mean') mean = self.b('mean')
......
...@@ -13,11 +13,12 @@ ...@@ -13,11 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
Test nn.Distribution.Exponential. Test nn.probability.distribution.Exponential.
""" """
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
...@@ -26,19 +27,19 @@ def test_arguments(): ...@@ -26,19 +27,19 @@ def test_arguments():
""" """
Args passing during initialization. Args passing during initialization.
""" """
e = nn.Exponential() e = msd.Exponential()
assert isinstance(e, nn.Distribution) assert isinstance(e, msd.Distribution)
e = nn.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32) e = msd.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32)
assert isinstance(e, nn.Distribution) assert isinstance(e, msd.Distribution)
def test_rate(): def test_rate():
""" """
Invalid rate. Invalid rate.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Exponential([-0.1], dtype=dtype.float32) msd.Exponential([-0.1], dtype=dtype.float32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Exponential([0.0], dtype=dtype.float32) msd.Exponential([0.0], dtype=dtype.float32)
class ExponentialProb(nn.Cell): class ExponentialProb(nn.Cell):
""" """
...@@ -46,7 +47,7 @@ class ExponentialProb(nn.Cell): ...@@ -46,7 +47,7 @@ class ExponentialProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(ExponentialProb, self).__init__() super(ExponentialProb, self).__init__()
self.e = nn.Exponential(0.5, dtype=dtype.float32) self.e = msd.Exponential(0.5, dtype=dtype.float32)
def construct(self, value): def construct(self, value):
prob = self.e('prob', value) prob = self.e('prob', value)
...@@ -72,7 +73,7 @@ class ExponentialProb1(nn.Cell): ...@@ -72,7 +73,7 @@ class ExponentialProb1(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(ExponentialProb1, self).__init__() super(ExponentialProb1, self).__init__()
self.e = nn.Exponential(dtype=dtype.float32) self.e = msd.Exponential(dtype=dtype.float32)
def construct(self, value, rate): def construct(self, value, rate):
prob = self.e('prob', value, rate) prob = self.e('prob', value, rate)
...@@ -99,8 +100,8 @@ class ExponentialKl(nn.Cell): ...@@ -99,8 +100,8 @@ class ExponentialKl(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(ExponentialKl, self).__init__() super(ExponentialKl, self).__init__()
self.e1 = nn.Exponential(0.7, dtype=dtype.float32) self.e1 = msd.Exponential(0.7, dtype=dtype.float32)
self.e2 = nn.Exponential(dtype=dtype.float32) self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a): def construct(self, rate_b, rate_a):
kl1 = self.e1('kl_loss', 'Exponential', rate_b) kl1 = self.e1('kl_loss', 'Exponential', rate_b)
...@@ -123,8 +124,8 @@ class ExponentialCrossEntropy(nn.Cell): ...@@ -123,8 +124,8 @@ class ExponentialCrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(ExponentialCrossEntropy, self).__init__() super(ExponentialCrossEntropy, self).__init__()
self.e1 = nn.Exponential(0.3, dtype=dtype.float32) self.e1 = msd.Exponential(0.3, dtype=dtype.float32)
self.e2 = nn.Exponential(dtype=dtype.float32) self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a): def construct(self, rate_b, rate_a):
h1 = self.e1('cross_entropy', 'Exponential', rate_b) h1 = self.e1('cross_entropy', 'Exponential', rate_b)
...@@ -147,7 +148,7 @@ class ExponentialBasics(nn.Cell): ...@@ -147,7 +148,7 @@ class ExponentialBasics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(ExponentialBasics, self).__init__() super(ExponentialBasics, self).__init__()
self.e = nn.Exponential([0.3, 0.5], dtype=dtype.float32) self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32)
def construct(self): def construct(self):
mean = self.e('mean') mean = self.e('mean')
......
...@@ -13,11 +13,12 @@ ...@@ -13,11 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
Test nn.Distribution.Geometric. Test nn.probability.distribution.Geometric.
""" """
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
...@@ -26,19 +27,19 @@ def test_arguments(): ...@@ -26,19 +27,19 @@ def test_arguments():
""" """
Args passing during initialization. Args passing during initialization.
""" """
g = nn.Geometric() g = msd.Geometric()
assert isinstance(g, nn.Distribution) assert isinstance(g, msd.Distribution)
g = nn.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
assert isinstance(g, nn.Distribution) assert isinstance(g, msd.Distribution)
def test_prob(): def test_prob():
""" """
Invalid probability. Invalid probability.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Geometric([-0.1], dtype=dtype.int32) msd.Geometric([-0.1], dtype=dtype.int32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Geometric([1.1], dtype=dtype.int32) msd.Geometric([1.1], dtype=dtype.int32)
class GeometricProb(nn.Cell): class GeometricProb(nn.Cell):
""" """
...@@ -46,7 +47,7 @@ class GeometricProb(nn.Cell): ...@@ -46,7 +47,7 @@ class GeometricProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(GeometricProb, self).__init__() super(GeometricProb, self).__init__()
self.g = nn.Geometric(0.5, dtype=dtype.int32) self.g = msd.Geometric(0.5, dtype=dtype.int32)
def construct(self, value): def construct(self, value):
prob = self.g('prob', value) prob = self.g('prob', value)
...@@ -72,7 +73,7 @@ class GeometricProb1(nn.Cell): ...@@ -72,7 +73,7 @@ class GeometricProb1(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(GeometricProb1, self).__init__() super(GeometricProb1, self).__init__()
self.g = nn.Geometric(dtype=dtype.int32) self.g = msd.Geometric(dtype=dtype.int32)
def construct(self, value, probs): def construct(self, value, probs):
prob = self.g('prob', value, probs) prob = self.g('prob', value, probs)
...@@ -100,8 +101,8 @@ class GeometricKl(nn.Cell): ...@@ -100,8 +101,8 @@ class GeometricKl(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(GeometricKl, self).__init__() super(GeometricKl, self).__init__()
self.g1 = nn.Geometric(0.7, dtype=dtype.int32) self.g1 = msd.Geometric(0.7, dtype=dtype.int32)
self.g2 = nn.Geometric(dtype=dtype.int32) self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
kl1 = self.g1('kl_loss', 'Geometric', probs_b) kl1 = self.g1('kl_loss', 'Geometric', probs_b)
...@@ -124,8 +125,8 @@ class GeometricCrossEntropy(nn.Cell): ...@@ -124,8 +125,8 @@ class GeometricCrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(GeometricCrossEntropy, self).__init__() super(GeometricCrossEntropy, self).__init__()
self.g1 = nn.Geometric(0.3, dtype=dtype.int32) self.g1 = msd.Geometric(0.3, dtype=dtype.int32)
self.g2 = nn.Geometric(dtype=dtype.int32) self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
h1 = self.g1('cross_entropy', 'Geometric', probs_b) h1 = self.g1('cross_entropy', 'Geometric', probs_b)
...@@ -148,7 +149,7 @@ class GeometricBasics(nn.Cell): ...@@ -148,7 +149,7 @@ class GeometricBasics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(GeometricBasics, self).__init__() super(GeometricBasics, self).__init__()
self.g = nn.Geometric([0.3, 0.5], dtype=dtype.int32) self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
def construct(self): def construct(self):
mean = self.g('mean') mean = self.g('mean')
......
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
Test nn.Distribution.Normal. Test nn.probability.distribution.Normal.
""" """
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
...@@ -27,17 +28,17 @@ def test_normal_shape_errpr(): ...@@ -27,17 +28,17 @@ def test_normal_shape_errpr():
Invalid shapes. Invalid shapes.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) msd.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
def test_arguments(): def test_arguments():
""" """
args passing during initialization. args passing during initialization.
""" """
n = nn.Normal() n = msd.Normal()
assert isinstance(n, nn.Distribution) assert isinstance(n, msd.Distribution)
n = nn.Normal([3.0], [4.0], dtype=dtype.float32) n = msd.Normal([3.0], [4.0], dtype=dtype.float32)
assert isinstance(n, nn.Distribution) assert isinstance(n, msd.Distribution)
class NormalProb(nn.Cell): class NormalProb(nn.Cell):
...@@ -46,7 +47,7 @@ class NormalProb(nn.Cell): ...@@ -46,7 +47,7 @@ class NormalProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(NormalProb, self).__init__() super(NormalProb, self).__init__()
self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self, value): def construct(self, value):
prob = self.normal('prob', value) prob = self.normal('prob', value)
...@@ -73,7 +74,7 @@ class NormalProb1(nn.Cell): ...@@ -73,7 +74,7 @@ class NormalProb1(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(NormalProb1, self).__init__() super(NormalProb1, self).__init__()
self.normal = nn.Normal() self.normal = msd.Normal()
def construct(self, value, mean, sd): def construct(self, value, mean, sd):
prob = self.normal('prob', value, mean, sd) prob = self.normal('prob', value, mean, sd)
...@@ -101,8 +102,8 @@ class NormalKl(nn.Cell): ...@@ -101,8 +102,8 @@ class NormalKl(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(NormalKl, self).__init__() super(NormalKl, self).__init__()
self.n1 = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.n1 = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.n2 = nn.Normal(dtype=dtype.float32) self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a): def construct(self, mean_b, sd_b, mean_a, sd_a):
kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b) kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b)
...@@ -127,8 +128,8 @@ class NormalCrossEntropy(nn.Cell): ...@@ -127,8 +128,8 @@ class NormalCrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(NormalCrossEntropy, self).__init__() super(NormalCrossEntropy, self).__init__()
self.n1 = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.n1 = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.n2 = nn.Normal(dtype=dtype.float32) self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a): def construct(self, mean_b, sd_b, mean_a, sd_a):
h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b) h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b)
...@@ -153,7 +154,7 @@ class NormalBasics(nn.Cell): ...@@ -153,7 +154,7 @@ class NormalBasics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(NormalBasics, self).__init__() super(NormalBasics, self).__init__()
self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self): def construct(self):
mean = self.n('mean') mean = self.n('mean')
......
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
Test nn.Distribution.Uniform. Test nn.probability.distribution.Uniform.
""" """
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
...@@ -27,17 +28,17 @@ def test_uniform_shape_errpr(): ...@@ -27,17 +28,17 @@ def test_uniform_shape_errpr():
Invalid shapes. Invalid shapes.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
def test_arguments(): def test_arguments():
""" """
Args passing during initialization. Args passing during initialization.
""" """
u = nn.Uniform() u = msd.Uniform()
assert isinstance(u, nn.Distribution) assert isinstance(u, msd.Distribution)
u = nn.Uniform([3.0], [4.0], dtype=dtype.float32) u = msd.Uniform([3.0], [4.0], dtype=dtype.float32)
assert isinstance(u, nn.Distribution) assert isinstance(u, msd.Distribution)
def test_invalid_range(): def test_invalid_range():
...@@ -45,9 +46,9 @@ def test_invalid_range(): ...@@ -45,9 +46,9 @@ def test_invalid_range():
Test range of uniform distribution. Test range of uniform distribution.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Uniform(0.0, 0.0, dtype=dtype.float32) msd.Uniform(0.0, 0.0, dtype=dtype.float32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
nn.Uniform(1.0, 0.0, dtype=dtype.float32) msd.Uniform(1.0, 0.0, dtype=dtype.float32)
class UniformProb(nn.Cell): class UniformProb(nn.Cell):
...@@ -56,7 +57,7 @@ class UniformProb(nn.Cell): ...@@ -56,7 +57,7 @@ class UniformProb(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(UniformProb, self).__init__() super(UniformProb, self).__init__()
self.u = nn.Uniform(3.0, 4.0, dtype=dtype.float32) self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self, value): def construct(self, value):
prob = self.u('prob', value) prob = self.u('prob', value)
...@@ -82,7 +83,7 @@ class UniformProb1(nn.Cell): ...@@ -82,7 +83,7 @@ class UniformProb1(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(UniformProb1, self).__init__() super(UniformProb1, self).__init__()
self.u = nn.Uniform(dtype=dtype.float32) self.u = msd.Uniform(dtype=dtype.float32)
def construct(self, value, low, high): def construct(self, value, low, high):
prob = self.u('prob', value, low, high) prob = self.u('prob', value, low, high)
...@@ -110,8 +111,8 @@ class UniformKl(nn.Cell): ...@@ -110,8 +111,8 @@ class UniformKl(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(UniformKl, self).__init__() super(UniformKl, self).__init__()
self.u1 = nn.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.u2 = nn.Uniform(dtype=dtype.float32) self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a): def construct(self, low_b, high_b, low_a, high_a):
kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b) kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b)
...@@ -136,8 +137,8 @@ class UniformCrossEntropy(nn.Cell): ...@@ -136,8 +137,8 @@ class UniformCrossEntropy(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(UniformCrossEntropy, self).__init__() super(UniformCrossEntropy, self).__init__()
self.u1 = nn.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.u2 = nn.Uniform(dtype=dtype.float32) self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a): def construct(self, low_b, high_b, low_a, high_a):
h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b) h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b)
...@@ -162,7 +163,7 @@ class UniformBasics(nn.Cell): ...@@ -162,7 +163,7 @@ class UniformBasics(nn.Cell):
""" """
def __init__(self): def __init__(self):
super(UniformBasics, self).__init__() super(UniformBasics, self).__init__()
self.u = nn.Uniform(3.0, 4.0, dtype=dtype.float32) self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self): def construct(self):
mean = self.u('mean') mean = self.u('mean')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册