提交 acf46baf 编写于 作者: P pkuliuliu

add Normal op

上级 e9670f3c
...@@ -25,3 +25,4 @@ from .squeeze import _squeeze_aicpu ...@@ -25,3 +25,4 @@ from .squeeze import _squeeze_aicpu
from .expand_dims import _expand_dims_aicpu from .expand_dims import _expand_dims_aicpu
from .random_choice_with_mask import _random_choice_with_mask_aicpu from .random_choice_with_mask import _random_choice_with_mask_aicpu
from .pack import _pack_aicpu from .pack import _pack_aicpu
from .normal import _normal_aicpu
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Normal op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
normal_op_info = AiCPURegOp("Normal") \
.fusion_type("OPAQUE") \
.input(0, "shape", "required") \
.input(1, "mean", "required") \
.input(2, "stddev", "required") \
.output(0, "y", "required") \
.attr("seed", "int") \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.get_op_info()
@op_info_register(normal_op_info)
def _normal_aicpu():
"""Normal AiCPU register"""
return
...@@ -53,7 +53,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2 ...@@ -53,7 +53,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh) Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh)
from .random_ops import (RandomChoiceWithMask) from .random_ops import (RandomChoiceWithMask, Normal)
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D, BiasAdd, Conv2D,
DepthwiseConv2dNative, DepthwiseConv2dNative,
...@@ -163,6 +163,7 @@ __all__ = [ ...@@ -163,6 +163,7 @@ __all__ = [
'HSigmoid', 'HSigmoid',
'Tanh', 'Tanh',
'RandomChoiceWithMask', 'RandomChoiceWithMask',
'Normal',
'ResizeBilinear', 'ResizeBilinear',
'ScalarSummary', 'ScalarSummary',
'ImageSummary', 'ImageSummary',
......
...@@ -64,3 +64,47 @@ class RandomChoiceWithMask(PrimitiveWithInfer): ...@@ -64,3 +64,47 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
return (mstype.int32, mstype.bool_) return (mstype.int32, mstype.bool_)
class Normal(PrimitiveWithInfer):
"""
Generates random samples from a normal(Gaussian) distribution.
Args:
seed (int): Random seed. Default: 0.
Inputs:
- **shape** (tuple[int]) - The shape of output tensor. Only constant value is allowed.
- **mean** (Tensor) - The mean of the distribution, with float32 data type.
- **stddev** (Tensor) - The standard deviation of the distribution, with float32 data type.
Outputs:
Tensor, with the given shape from the specific distribution and float32 data type.
Examples:
>>> normal = P.Normal()
>>> mean = Tensor(0., mstype.float32)
>>> stddev = Tensor(1., mstype.float32)
>>> out = normal((32, 3, 3), mean, stddev)
"""
@prim_attr_register
def __init__(self, seed=0):
"""Init Normal"""
validator.check_value_type("seed", seed, [int], self.name)
def __infer__(self, shape, mean, stddev):
shape_value = shape["value"]
if shape_value is None:
raise ValueError(f"For {self.name}, shape must be const.")
validator.check_value_type("shape", shape_value, [tuple], self.name)
for i, shape_i in enumerate(shape_value):
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GE, self.name)
validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
validator.check_tensor_type_same({"stddev": stddev["dtype"]}, [mstype.float32], self.name)
out = {"shape": shape_value,
"dtype": mstype.float32,
"value": None}
return out
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self, shape=None, mean=0.0, stddev=1.0, seed=0):
super(Net, self).__init__()
self._mean = Tensor(mean, mstype.float32)
self._stddev = Tensor(stddev, mstype.float32)
self._normal = P.Normal(seed=seed)
self._shape = shape
def construct(self):
return self._normal(self._shape, self._mean, self._stddev)
def test_net_3x2x4():
mean = 0.0
stddev = 1.0
seed = 0
net = Net((3, 2, 4), mean, stddev, seed)
out = net()
assert out.shape == (3, 2, 4)
...@@ -399,6 +399,19 @@ class InplaceSubNet(nn.Cell): ...@@ -399,6 +399,19 @@ class InplaceSubNet(nn.Cell):
return out return out
class NormalNet(nn.Cell):
def __init__(self, shape=None, mean=0.0, stddev=1.0, seed=0):
super(NormalNet, self).__init__()
self.normal = P.Normal(seed=seed)
self.shape = shape
self.mean = Tensor(mean, mstype.float32)
self.stddev = Tensor(stddev, mstype.float32)
def construct(self):
out = self.normal(self.shape, self.mean, self.stddev)
return out
test_case_math_ops = [ test_case_math_ops = [
('BitwiseAnd', { ('BitwiseAnd', {
'block': P.BitwiseAnd(), 'block': P.BitwiseAnd(),
...@@ -895,6 +908,10 @@ test_case_math_ops = [ ...@@ -895,6 +908,10 @@ test_case_math_ops = [
'desc_inputs': [Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mstype.float16), Tensor([0.0, 5.0], mstype.float16)], 'desc_inputs': [Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mstype.float16), Tensor([0.0, 5.0], mstype.float16)],
'desc_bprop': [], 'desc_bprop': [],
'skip': ['backward']}), 'skip': ['backward']}),
('Normal', {
'block': NormalNet((3, 2, 4), 0.0, 1.0, 0),
'desc_inputs': [],
'skip': ['backward']}),
] ]
test_case_nn_ops = [ test_case_nn_ops = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册