提交 b9d3495c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1081 format nn. and change quant.py

Merge pull request !1081 from SanjayChan/formast
...@@ -17,24 +17,26 @@ Layer. ...@@ -17,24 +17,26 @@ Layer.
The high-level components(Cells) used to construct the neural network. The high-level components(Cells) used to construct the neural network.
""" """
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, GlobalBatchNorm from .activation import *
from .container import SequentialCell, CellList from .normalization import *
from .conv import Conv2d, Conv2dTranspose from .container import *
from .lstm import LSTM from .conv import *
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold from .lstm import *
from .embedding import Embedding from .basic import *
from .pooling import AvgPool2d, MaxPool2d, AvgPool1d from .embedding import *
from .image import ImageGradients, SSIM, PSNR from .pooling import *
from .image import *
from .quant import *
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', __all__ = []
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', __all__.extend(activation.__all__)
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', __all__.extend(normalization.__all__)
'SequentialCell', 'CellList', __all__.extend(container.__all__)
'Conv2d', 'Conv2dTranspose', __all__.extend(conv.__all__)
'LSTM', __all__.extend(lstm.__all__)
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', __all__.extend(basic.__all__)
'Embedding', __all__.extend(embedding.__all__)
'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'Pad', 'Unfold', __all__.extend(pooling.__all__)
'ImageGradients', 'SSIM', 'PSNR', __all__.extend(image.__all__)
] __all__.extend(quant.__all__)
...@@ -22,6 +22,21 @@ from mindspore.common.tensor import Tensor ...@@ -22,6 +22,21 @@ from mindspore.common.tensor import Tensor
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from ..cell import Cell from ..cell import Cell
__all__ = ['Softmax',
'LogSoftmax',
'ReLU',
'ReLU6',
'Tanh',
'GELU',
'Sigmoid',
'PReLU',
'get_activation',
'LeakyReLU',
'HSigmoid',
'HSwish',
'ELU',
]
class Softmax(Cell): class Softmax(Cell):
r""" r"""
...@@ -54,6 +69,7 @@ class Softmax(Cell): ...@@ -54,6 +69,7 @@ class Softmax(Cell):
>>> softmax(input_x) >>> softmax(input_x)
[0.03168 0.01166 0.0861 0.636 0.2341] [0.03168 0.01166 0.0861 0.636 0.2341]
""" """
def __init__(self, axis=-1): def __init__(self, axis=-1):
super(Softmax, self).__init__() super(Softmax, self).__init__()
self.softmax = P.Softmax(axis) self.softmax = P.Softmax(axis)
...@@ -128,6 +144,7 @@ class ELU(Cell): ...@@ -128,6 +144,7 @@ class ELU(Cell):
>>> elu(input_x) >>> elu(input_x)
""" """
def __init__(self, alpha=1.0): def __init__(self, alpha=1.0):
super(ELU, self).__init__() super(ELU, self).__init__()
self.elu = P.Elu(alpha) self.elu = P.Elu(alpha)
...@@ -156,6 +173,7 @@ class ReLU(Cell): ...@@ -156,6 +173,7 @@ class ReLU(Cell):
>>> relu(input_x) >>> relu(input_x)
[0. 2. 0. 2. 0.] [0. 2. 0. 2. 0.]
""" """
def __init__(self): def __init__(self):
super(ReLU, self).__init__() super(ReLU, self).__init__()
self.relu = P.ReLU() self.relu = P.ReLU()
...@@ -184,6 +202,7 @@ class ReLU6(Cell): ...@@ -184,6 +202,7 @@ class ReLU6(Cell):
>>> relu6(input_x) >>> relu6(input_x)
[0. 0. 0. 2. 1.] [0. 0. 0. 2. 1.]
""" """
def __init__(self): def __init__(self):
super(ReLU6, self).__init__() super(ReLU6, self).__init__()
self.relu6 = P.ReLU6() self.relu6 = P.ReLU6()
...@@ -221,6 +240,7 @@ class LeakyReLU(Cell): ...@@ -221,6 +240,7 @@ class LeakyReLU(Cell):
[[-0.2 4. -1.6] [[-0.2 4. -1.6]
[ 2 -1. 9.]] [ 2 -1. 9.]]
""" """
def __init__(self, alpha=0.2): def __init__(self, alpha=0.2):
super(LeakyReLU, self).__init__() super(LeakyReLU, self).__init__()
self.greater_equal = P.GreaterEqual() self.greater_equal = P.GreaterEqual()
...@@ -262,6 +282,7 @@ class Tanh(Cell): ...@@ -262,6 +282,7 @@ class Tanh(Cell):
>>> tanh(input_x) >>> tanh(input_x)
[0.7617 0.964 0.995 0.964 0.7617] [0.7617 0.964 0.995 0.964 0.7617]
""" """
def __init__(self): def __init__(self):
super(Tanh, self).__init__() super(Tanh, self).__init__()
self.tanh = P.Tanh() self.tanh = P.Tanh()
...@@ -293,6 +314,7 @@ class GELU(Cell): ...@@ -293,6 +314,7 @@ class GELU(Cell):
[[-1.5880802e-01 3.9999299e+00 -3.1077917e-21] [[-1.5880802e-01 3.9999299e+00 -3.1077917e-21]
[ 1.9545976e+00 -2.2918017e-07 9.0000000e+00]] [ 1.9545976e+00 -2.2918017e-07 9.0000000e+00]]
""" """
def __init__(self): def __init__(self):
super(GELU, self).__init__() super(GELU, self).__init__()
self.gelu = P.Gelu() self.gelu = P.Gelu()
...@@ -322,6 +344,7 @@ class Sigmoid(Cell): ...@@ -322,6 +344,7 @@ class Sigmoid(Cell):
>>> sigmoid(input_x) >>> sigmoid(input_x)
[0.2688 0.11914 0.5 0.881 0.7305] [0.2688 0.11914 0.5 0.881 0.7305]
""" """
def __init__(self): def __init__(self):
super(Sigmoid, self).__init__() super(Sigmoid, self).__init__()
self.sigmoid = P.Sigmoid() self.sigmoid = P.Sigmoid()
...@@ -410,6 +433,7 @@ class HSwish(Cell): ...@@ -410,6 +433,7 @@ class HSwish(Cell):
>>> hswish(input_x) >>> hswish(input_x)
""" """
def __init__(self): def __init__(self):
super(HSwish, self).__init__() super(HSwish, self).__init__()
self.hswish = P.HSwish() self.hswish = P.HSwish()
...@@ -443,6 +467,7 @@ class HSigmoid(Cell): ...@@ -443,6 +467,7 @@ class HSigmoid(Cell):
>>> hsigmoid(input_x) >>> hsigmoid(input_x)
""" """
def __init__(self): def __init__(self):
super(HSigmoid, self).__init__() super(HSigmoid, self).__init__()
self.hsigmoid = P.HSigmoid() self.hsigmoid = P.HSigmoid()
......
...@@ -30,6 +30,7 @@ from ..cell import Cell ...@@ -30,6 +30,7 @@ from ..cell import Cell
from .activation import get_activation from .activation import get_activation
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold']
class Dropout(Cell): class Dropout(Cell):
r""" r"""
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
"""container""" """container"""
from collections import OrderedDict from collections import OrderedDict
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ..cell import Cell from ..cell import Cell
__all__ = ['SequentialCell', 'CellList']
def _valid_index(cell_num, index): def _valid_index(cell_num, index):
if not isinstance(index, int): if not isinstance(index, int):
......
...@@ -21,6 +21,7 @@ from mindspore._checkparam import check_bool, twice, check_int_positive, check_i ...@@ -21,6 +21,7 @@ from mindspore._checkparam import check_bool, twice, check_int_positive, check_i
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from ..cell import Cell from ..cell import Cell
__all__ = ['Conv2d', 'Conv2dTranspose']
class _Conv(Cell): class _Conv(Cell):
""" """
......
...@@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer ...@@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
__all__ = ['Embedding']
class Embedding(Cell): class Embedding(Cell):
r""" r"""
......
...@@ -23,6 +23,7 @@ from mindspore._checkparam import Validator as validator ...@@ -23,6 +23,7 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from ..cell import Cell from ..cell import Cell
__all__ = ['ImageGradients', 'SSIM', 'PSNR']
class ImageGradients(Cell): class ImageGradients(Cell):
r""" r"""
......
...@@ -19,6 +19,7 @@ from mindspore.common.parameter import Parameter ...@@ -19,6 +19,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
__all__ = ['LSTM']
class LSTM(Cell): class LSTM(Cell):
r""" r"""
......
...@@ -29,6 +29,8 @@ from mindspore._checkparam import check_int_positive ...@@ -29,6 +29,8 @@ from mindspore._checkparam import check_int_positive
from ..cell import Cell from ..cell import Cell
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm']
class _BatchNorm(Cell): class _BatchNorm(Cell):
"""Batch Normalization base class.""" """Batch Normalization base class."""
@cell_attr_register @cell_attr_register
......
...@@ -21,6 +21,7 @@ from ... import context ...@@ -21,6 +21,7 @@ from ... import context
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Rel from ..._checkparam import Rel
__all__ = ['AvgPool2d', 'MaxPool2d', 'AvgPool1d']
class _PoolNd(Cell): class _PoolNd(Cell):
"""N-D AvgPool""" """N-D AvgPool"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册