提交 c2385e2e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3615 Move nn/distribution to nn/probability/distribution

Merge pull request !3615 from XunDeng/pp_poc_v3
...@@ -17,14 +17,14 @@ Neural Networks Cells. ...@@ -17,14 +17,14 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct Neural Networks. Pre-defined building blocks or computing units to construct Neural Networks.
""" """
from . import layer, loss, optim, metrics, wrap, distribution from . import layer, loss, optim, metrics, wrap, probability
from .cell import Cell, GraphKernel from .cell import Cell, GraphKernel
from .layer import * from .layer import *
from .loss import * from .loss import *
from .optim import * from .optim import *
from .metrics import * from .metrics import *
from .wrap import * from .wrap import *
from .distribution import * from .probability import *
__all__ = ["Cell", "GraphKernel"] __all__ = ["Cell", "GraphKernel"]
...@@ -33,7 +33,7 @@ __all__.extend(loss.__all__) ...@@ -33,7 +33,7 @@ __all__.extend(loss.__all__)
__all__.extend(optim.__all__) __all__.extend(optim.__all__)
__all__.extend(metrics.__all__) __all__.extend(metrics.__all__)
__all__.extend(wrap.__all__) __all__.extend(wrap.__all__)
__all__.extend(distribution.__all__) __all__.extend(probability.__all__)
__all__.sort() __all__.sort()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Probability.
The high-level components(Distributions) used to construct the probabilistic network.
"""
from .distribution import *
__all__ = []
__all__.extend(distribution.__all__)
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
"""Utitly functions to help distribution class.""" """Utitly functions to help distribution class."""
import numpy as np import numpy as np
from mindspore.ops import _utils as utils from mindspore.ops import _utils as utils
from ....common.tensor import Tensor from mindspore.common.tensor import Tensor
from ....common.parameter import Parameter from mindspore.common.parameter import Parameter
from ....common import dtype as mstype from mindspore.common import dtype as mstype
def cast_to_tensor(t, dtype=mstype.float32): def cast_to_tensor(t, dtype=mstype.float32):
""" """
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Bernoulli Distribution""" """Bernoulli Distribution"""
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
class Bernoulli(Distribution): class Bernoulli(Distribution):
""" """
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""basic""" """basic"""
from ..cell import Cell from mindspore.nn.cell import Cell
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
class Distribution(Cell): class Distribution(Cell):
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
"""Exponential Distribution""" """Exponential Distribution"""
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ...common import dtype as mstype
from ._utils.utils import cast_to_tensor, check_greater_zero from ._utils.utils import cast_to_tensor, check_greater_zero
class Exponential(Distribution): class Exponential(Distribution):
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
"""Geometric Distribution""" """Geometric Distribution"""
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
class Geometric(Distribution): class Geometric(Distribution):
""" """
......
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.context import get_context
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero from ._utils.utils import convert_to_batch, check_greater_equal_zero
from ...common import dtype as mstype
from ...context import get_context
class Normal(Distribution): class Normal(Distribution):
""" """
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# ============================================================================ # ============================================================================
"""Uniform Distribution""" """Uniform Distribution"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ...common import dtype as mstype
from ._utils.utils import convert_to_batch, check_greater from ._utils.utils import convert_to_batch, check_greater
class Uniform(Distribution): class Uniform(Distribution):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册