提交 90cb7851 编写于 作者: X Xun Deng

move nn/distribution to nn/probability/distribution

上级 1f4222ed
......@@ -17,14 +17,14 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct Neural Networks.
"""
from . import layer, loss, optim, metrics, wrap, distribution
from . import layer, loss, optim, metrics, wrap, probability
from .cell import Cell, GraphKernel
from .layer import *
from .loss import *
from .optim import *
from .metrics import *
from .wrap import *
from .distribution import *
from .probability import *
__all__ = ["Cell", "GraphKernel"]
......@@ -33,7 +33,7 @@ __all__.extend(loss.__all__)
__all__.extend(optim.__all__)
__all__.extend(metrics.__all__)
__all__.extend(wrap.__all__)
__all__.extend(distribution.__all__)
__all__.extend(probability.__all__)
__all__.sort()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Probability.
The high-level components(Distributions) used to construct the probabilistic network.
"""
from .distribution import *
__all__ = []
__all__.extend(distribution.__all__)
......@@ -16,9 +16,9 @@
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import _utils as utils
from ....common.tensor import Tensor
from ....common.parameter import Parameter
from ....common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
def cast_to_tensor(t, dtype=mstype.float32):
"""
......
......@@ -13,10 +13,10 @@
# limitations under the License.
# ============================================================================
"""Bernoulli Distribution"""
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
class Bernoulli(Distribution):
"""
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""basic"""
from ..cell import Cell
from mindspore.nn.cell import Cell
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
class Distribution(Cell):
......
......@@ -15,8 +15,8 @@
"""Exponential Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ...common import dtype as mstype
from ._utils.utils import cast_to_tensor, check_greater_zero
class Exponential(Distribution):
......
......@@ -15,9 +15,9 @@
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
class Geometric(Distribution):
"""
......
......@@ -16,10 +16,11 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.context import get_context
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero
from ...common import dtype as mstype
from ...context import get_context
class Normal(Distribution):
"""
......
......@@ -14,8 +14,8 @@
# ============================================================================
"""Uniform Distribution"""
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ...common import dtype as mstype
from ._utils.utils import convert_to_batch, check_greater
class Uniform(Distribution):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册