Commit 4f75adb1 authored by mindspore-ci-bot, committed by Gitee

!4297 sync code incubator to master

Merge pull request !4297 from guozhijian/code_sync_incubator_f3c32baf_to_master_fcfc75a3_0811
@@ -15,6 +15,7 @@
 """aicpu ops"""
 from .init_data_set_queue import _init_data_set_queue_aicpu
 from .embedding_lookup import _embedding_lookup_aicpu
+from .padding import _padding_aicpu
 from .dropout_genmask import _dropout_genmask_aicpu
 from .get_next import _get_next_aicpu
 from .print_tensor import _print_aicpu
@@ -43,3 +44,7 @@ from .laplace import _laplace_aicpu
 from .strided_slice import _strided_slice_aicpu
 from .strided_slice_grad import _strided_slice_grad_aicpu
 from .end_of_sequence import _end_of_sequence_aicpu
+from .fused_sparse_adam import _fused_sparse_adam_aicpu
+from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu
+from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
+from .fused_sparse_proximal_adagrad import _fused_sparse_proximal_adagrad_aicpu
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FusedSparseAdam op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# Op registration for the AiCPU backend: declare the kernel's attributes, its
# 11 inputs and 3 outputs, and the single supported dtype/format combination
# (one DataType entry per input, then per output, in positional order).
fused_sparse_adam_op_info = AiCPURegOp("FusedSparseAdam") \
    .fusion_type("OPAQUE") \
    .attr("use_locking", "bool") \
    .attr("use_nesterov", "bool") \
    .input(0, "var", "required") \
    .input(1, "m", "required") \
    .input(2, "v", "required") \
    .input(3, "beta1_power", "required") \
    .input(4, "beta2_power", "required") \
    .input(5, "lr", "required") \
    .input(6, "beta1", "required") \
    .input(7, "beta2", "required") \
    .input(8, "epsilon", "required") \
    .input(9, "grad", "required") \
    .input(10, "indices", "required") \
    .output(0, "var", "required") \
    .output(1, "m", "required") \
    .output(2, "v", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(fused_sparse_adam_op_info)
def _fused_sparse_adam_aicpu():
    """FusedSparseAdam aicpu register"""
    return
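
For reference, the update this kernel fuses is the standard Adam step, applied only to the rows of var, m and v selected by indices (a summary of the conventional algorithm, stated here for orientation rather than taken from this diff):

.. math::
    m \leftarrow \beta_1 m + (1 - \beta_1) g, \qquad
    v \leftarrow \beta_2 v + (1 - \beta_2) g^2, \qquad
    w \leftarrow w - \frac{lr \cdot \sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m}{\sqrt{v} + \epsilon}

where g is the sparse gradient and the beta1_power/beta2_power inputs carry the decayed powers β1^t and β2^t.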
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FusedSparseFtrl op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

fused_sparse_ftrl_op_info = AiCPURegOp("FusedSparseFtrl") \
    .fusion_type("OPAQUE") \
    .attr("lr", "float") \
    .attr("l1", "float") \
    .attr("l2", "float") \
    .attr("lr_power", "float") \
    .attr("use_locking", "bool") \
    .input(0, "var", "required") \
    .input(1, "accum", "required") \
    .input(2, "linear", "required") \
    .input(3, "grad", "required") \
    .input(4, "indices", "required") \
    .output(0, "var", "required") \
    .output(1, "accum", "required") \
    .output(2, "linear", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(fused_sparse_ftrl_op_info)
def _fused_sparse_ftrl_aicpu():
    """FusedSparseFtrl aicpu register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FusedSparseLazyAdam op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

fused_sparse_lazy_adam_op_info = AiCPURegOp("FusedSparseLazyAdam") \
    .fusion_type("OPAQUE") \
    .attr("use_locking", "bool") \
    .attr("use_nesterov", "bool") \
    .input(0, "var", "required") \
    .input(1, "m", "required") \
    .input(2, "v", "required") \
    .input(3, "beta1_power", "required") \
    .input(4, "beta2_power", "required") \
    .input(5, "lr", "required") \
    .input(6, "beta1", "required") \
    .input(7, "beta2", "required") \
    .input(8, "epsilon", "required") \
    .input(9, "grad", "required") \
    .input(10, "indices", "required") \
    .output(0, "var", "required") \
    .output(1, "m", "required") \
    .output(2, "v", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(fused_sparse_lazy_adam_op_info)
def _fused_sparse_lazy_adam_aicpu():
    """FusedSparseLazyAdam aicpu register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FusedSparseProximalAdagrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

fused_sparse_proximal_adagrad_op_info = AiCPURegOp("FusedSparseProximalAdagrad") \
    .fusion_type("OPAQUE") \
    .attr("use_locking", "bool") \
    .input(0, "var", "required") \
    .input(1, "accum", "required") \
    .input(2, "lr", "required") \
    .input(3, "l1", "required") \
    .input(4, "l2", "required") \
    .input(5, "grad", "required") \
    .input(6, "indices", "required") \
    .output(0, "var", "required") \
    .output(1, "accum", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
                  DataType.F32_Default) \
    .get_op_info()


@op_info_register(fused_sparse_proximal_adagrad_op_info)
def _fused_sparse_proximal_adagrad_aicpu():
    """FusedSparseProximalAdagrad aicpu register"""
    return
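
For orientation, the update fused here is the standard proximal Adagrad step on the rows selected by indices (a summary under the usual formulation, not taken from this diff):

.. math::
    accum \leftarrow accum + g^2, \qquad
    \tilde{w} = w - \frac{lr}{\sqrt{accum}}\, g, \qquad
    w \leftarrow \frac{\operatorname{sign}(\tilde{w})}{1 + \frac{lr}{\sqrt{accum}}\, l_2} \cdot \max\Big(|\tilde{w}| - \frac{lr}{\sqrt{accum}}\, l_1,\ 0\Big)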
@@ -23,6 +23,7 @@ gamma_op_info = AiCPURegOp("Gamma") \
     .input(2, "beta", "required") \
     .output(0, "output", "required") \
     .attr("seed", "int") \
+    .attr("seed2", "int") \
     .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
     .get_op_info()
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Padding op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# One dtype_format entry per supported (input, output) dtype pair, so Padding
# covers the common signed, unsigned, floating-point and bool element types.
padding_op_info = AiCPURegOp("Padding") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .attr("pad_dim_size", "int") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(padding_op_info)
def _padding_aicpu():
    """Padding AiCPU register"""
    return
@@ -22,6 +22,7 @@ poisson_op_info = AiCPURegOp("Poisson") \
     .input(1, "mean", "required") \
     .output(0, "output", "required") \
     .attr("seed", "int") \
+    .attr("seed2", "int") \
     .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
     .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW) \
     .get_op_info()
...
@@ -23,6 +23,7 @@ uniform_int_op_info = AiCPURegOp("UniformInt") \
     .input(2, "b", "required") \
     .output(0, "output", "required") \
     .attr("seed", "int") \
+    .attr("seed2", "int") \
     .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
     .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \
     .get_op_info()
...
@@ -19,12 +19,11 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
 uniform_real_op_info = AiCPURegOp("UniformReal") \
     .fusion_type("OPAQUE") \
     .input(0, "shape", "required") \
-    .input(1, "a", "required") \
-    .input(2, "b", "required") \
     .output(0, "output", "required") \
     .attr("seed", "int") \
+    .attr("seed2", "int") \
-    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
+    .dtype_format(DataType.I32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW) \
     .get_op_info()

 @op_info_register(uniform_real_op_info)
...
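
Since UniformReal now samples only on [0, 1), callers that relied on the removed a/b inputs can rescale the output themselves. A minimal migration sketch (variable names are illustrative), mirroring what the new C.uniform helper does internally:

import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

a = Tensor(1.0, mstype.float32)
b = Tensor(5.0, mstype.float32)
# Affine-rescale the [0, 1) sample onto [a, b).
samples = P.UniformReal(seed=10)((4, 16)) * (b - a) + a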
@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
 from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
-from .random_ops import set_seed, normal, multinomial
+from .random_ops import set_seed, normal, uniform, gamma, poisson, multinomial

 __all__ = [
@@ -50,5 +50,8 @@ __all__ = [
     'zip_operation',
     'set_seed',
     'normal',
+    'uniform',
+    'gamma',
+    'poisson',
     'multinomial',
     'clip_by_value',]
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
-"""Operations for random number generatos."""
+"""Operations for random number generators."""
 from .. import operations as P
 from .. import functional as F
@@ -66,7 +66,6 @@ def get_seed():
 def normal(shape, mean, stddev, seed=0):
     """
     Generates random numbers according to the Normal (or Gaussian) random number distribution.
-    It is defined as:

     Args:
         shape (tuple): The shape of random tensor to be generated.
@@ -137,10 +136,108 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0):
             n_dist = shape(inputs)[-2]
         a = Tensor(0.0, mstype.float32)
         b = Tensor(1.0, mstype.float32)
-        uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
+        random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
         if n_dist != 1:
-            uniform = reshape(uniform, (n_dist, num_sample))
+            random_uniform = reshape(random_uniform, (n_dist, num_sample))
-        vals = P.RealDiv()(P.Log()(uniform), inputs + 1e-6)
+        vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
         _, indices = P.TopK()(vals, num_sample)
         return indices
     return P.Multinomial(seed=seed)(inputs, num_sample)
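
The log(u) / p construction above is the standard exponential-race trick for sampling without replacement: -log(u)/p is exponentially distributed with rate p, so taking the top-k values of log(u)/p picks k distinct categories with the intended probabilities. A NumPy sketch of the same idea (illustrative only, not part of this commit):

import numpy as np

def multinomial_no_replacement(probs, num_sample, rng=None):
    """Draw num_sample distinct category indices with probabilities ~ probs."""
    rng = rng or np.random.default_rng(0)
    u = rng.random(probs.shape)          # independent U(0, 1) per category
    vals = np.log(u) / (probs + 1e-6)    # exponential race: larger wins
    return np.argsort(vals)[..., ::-1][..., :num_sample]  # top-k indices

print(multinomial_no_replacement(np.array([0.1, 0.6, 0.3]), 2))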


def uniform(shape, a, b, seed=0, dtype=mstype.float32):
    """
    Generates random numbers according to the Uniform random number distribution.

    Args:
        shape (tuple): The shape of random tensor to be generated.
        a (Tensor): The a distribution parameter.
            It defines the minimum possibly generated value. With int32 or float32 data type.
            If dtype is int32, only one number is allowed.
        b (Tensor): The b distribution parameter.
            It defines the maximum possibly generated value. With int32 or float32 data type.
            If dtype is int32, only one number is allowed.
        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers.
            Default: 0.
        dtype (mindspore.dtype): Data type of the output, int32 or float32. Default: mstype.float32.

    Returns:
        Tensor. The shape should be the broadcasted shape of input "shape" and the shapes of a and b.
        The dtype is the one designated by the `dtype` argument.

    Examples:
        >>> shape = (4, 16)
        >>> a = Tensor(1.0, mstype.float32)
        >>> b = Tensor(5.0, mstype.float32)
        >>> output = C.uniform(shape, a, b, seed=5)
    """
    a_dtype = F.dtype(a)
    b_dtype = F.dtype(b)
    const_utils.check_tensors_dtype_same(a_dtype, dtype, "uniform")
    const_utils.check_tensors_dtype_same(b_dtype, dtype, "uniform")
    seed1 = get_seed()
    seed2 = seed
    if const_utils.is_same_type(dtype, mstype.int32):
        rnd = P.UniformInt(seed1, seed2)
        value = rnd(shape, a, b)
    else:
        uniform_real = P.UniformReal(seed1, seed2)
        rnd = uniform_real(shape)
        # Rescale the [0, 1) sample onto [a, b) with an affine transform.
        value = rnd * (b - a) + a
    return value
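
A usage sketch for the int32 branch (hypothetical values; with dtype=mstype.int32 the bounds must be scalar tensors, because sampling is delegated to P.UniformInt):

import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import composite as C

low = Tensor(1, mstype.int32)     # scalar bounds only on the int32 path
high = Tensor(10, mstype.int32)
samples = C.uniform((2, 3), low, high, seed=5, dtype=mstype.int32)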


def gamma(shape, alpha, beta, seed=0):
    """
    Generates random numbers according to the Gamma random number distribution.

    Args:
        shape (tuple): The shape of random tensor to be generated.
        alpha (Tensor): The alpha (α) distribution parameter. With float32 data type.
        beta (Tensor): The beta (β) distribution parameter. With float32 data type.
        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers.
            Default: 0.

    Returns:
        Tensor. The shape should be the broadcasted shape of input "shape" and the shapes of alpha and beta.
        The dtype is float32.

    Examples:
        >>> shape = (4, 16)
        >>> alpha = Tensor(1.0, mstype.float32)
        >>> beta = Tensor(1.0, mstype.float32)
        >>> output = C.gamma(shape, alpha, beta, seed=5)
    """
    alpha_dtype = F.dtype(alpha)
    beta_dtype = F.dtype(beta)
    const_utils.check_tensors_dtype_same(alpha_dtype, mstype.float32, "gamma")
    const_utils.check_tensors_dtype_same(beta_dtype, mstype.float32, "gamma")
    seed1 = get_seed()
    seed2 = seed
    random_gamma = P.Gamma(seed1, seed2)
    value = random_gamma(shape, alpha, beta)
    return value


def poisson(shape, mean, seed=0):
    """
    Generates random numbers according to the Poisson random number distribution.

    Args:
        shape (tuple): The shape of random tensor to be generated.
        mean (Tensor): The mean (μ) distribution parameter. With float32 data type.
        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers.
            Default: 0.

    Returns:
        Tensor. The shape should be the broadcasted shape of input "shape" and the shape of mean.
        The dtype is float32.

    Examples:
        >>> shape = (4, 16)
        >>> mean = Tensor(1.0, mstype.float32)
        >>> output = C.poisson(shape, mean, seed=5)
    """
    mean_dtype = F.dtype(mean)
    const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "poisson")
    seed1 = get_seed()
    seed2 = seed
    random_poisson = P.Poisson(seed1, seed2)
    value = random_poisson(shape, mean)
    return value
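
In all three helpers the seed pair is formed the same way: seed1 comes from the global seed set via C.set_seed and seed2 from the per-call seed argument. A reproducibility sketch under that assumption (expected behavior, not asserted by this commit's tests):

import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import composite as C

C.set_seed(20)
first = C.poisson((2, 2), Tensor(5.0, mstype.float32), seed=10)
C.set_seed(20)
second = C.poisson((2, 2), Tensor(5.0, mstype.float32), seed=10)
# first and second should hold identical values for the same seed pair.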
@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                         ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
-                        Shape, Size, Slice, Split, TransShape, ParallelConcat,
+                        Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
                         ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
@@ -147,6 +147,7 @@ __all__ = [
     'GatherV2',
     'SparseGatherV2',
     'EmbeddingLookup',
+    'Padding',
     'Concat',
     'Pack',
     'Unpack',
...
@@ -645,6 +645,46 @@ class SparseGatherV2(GatherV2):
     """


class Padding(PrimitiveWithInfer):
    """
    Extends the last dimension of the input tensor from 1 to pad_dim_size, filling with 0.

    Args:
        pad_dim_size (int): The value by which to extend the last dimension of x; must be positive. Default: 8.

    Inputs:
        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of x should be at least 2.
          The last dimension of x should be 1.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`, where the last dimension equals pad_dim_size.

    Examples:
        >>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
        >>> pad_dim_size = 4
        >>> out = P.Padding(pad_dim_size)(x)
        [[8, 0, 0, 0], [10, 0, 0, 0]]
    """

    @prim_attr_register
    def __init__(self, pad_dim_size=8):
        """init padding"""
        validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name)
        validator.check_integer("pad_dim_size", pad_dim_size, 0, Rel.GT, self.name)
        self.pad_dim_size = pad_dim_size

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        validator.check_integer("rank of x", len(x_shape), 1, Rel.GT, self.name)
        validator.check_integer("last dim of x", x_shape[-1], 1, Rel.EQ, self.name)
        out_shape = x_shape
        out_shape[-1] = self.pad_dim_size
        out = {'shape': out_shape,
               'dtype': x['dtype'],
               'value': None}
        return out
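
As a cross-check of the shape and fill semantics, a NumPy reference for what Padding computes (an illustrative equivalent, not the AiCPU kernel itself):

import numpy as np

def padding_reference(x, pad_dim_size):
    """Extend the last dimension from 1 to pad_dim_size, filling with zeros."""
    pad = np.zeros(x.shape[:-1] + (pad_dim_size - 1,), dtype=x.dtype)
    return np.concatenate([x, pad], axis=-1)

out = padding_reference(np.array([[8], [10]]), 4)
assert (out == np.array([[8, 0, 0, 0], [10, 0, 0, 0]])).all()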


class Split(PrimitiveWithInfer):
    """
    Splits the input tensor into output_num tensors along the given axis.
...
@@ -34,8 +34,7 @@ class StandardNormal(PrimitiveWithInfer):
         - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

     Outputs:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev.
-        The dtype is float32.
+        Tensor. The shape that the input 'shape' denotes. The dtype is float32.

     Examples:
         >>> shape = (4, 16)
@@ -126,8 +125,8 @@ class Gamma(PrimitiveWithInfer):
         \text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}},

     Args:
-        seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
-            Default: 0.
+        seed (int): Random seed. Default: 0.
+        seed2 (int): Random seed2. Default: 0.

     Inputs:
         - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
@@ -149,10 +148,11 @@ class Gamma(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, seed=0):
+    def __init__(self, seed=0, seed2=0):
         """Init Gamma"""
         self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
         validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_value_type('seed2', seed2, [int], self.name)

     def __infer__(self, shape, alpha, beta):
         shape_v = shape["value"]
@@ -180,8 +180,8 @@ class Poisson(PrimitiveWithInfer):
         \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!},

     Args:
-        seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
-            Default: 0.
+        seed (int): Random seed. Default: 0.
+        seed2 (int): Random seed2. Default: 0.

     Inputs:
         - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
@@ -200,10 +200,11 @@ class Poisson(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, seed=0):
+    def __init__(self, seed=0, seed2=0):
         """Init Poisson"""
         self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
         validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_value_type('seed2', seed2, [int], self.name)

     def __infer__(self, shape, mean):
         shape_v = shape["value"]
@@ -223,7 +224,7 @@ class Poisson(PrimitiveWithInfer):
 class UniformInt(PrimitiveWithInfer):
     r"""
-    Produces random integer values i, uniformly distributed on the closed interval [a, b], that is,
+    Produces random integer values i, uniformly distributed on the interval [a, b), that is,
     distributed according to the discrete probability function:

     .. math::
@@ -233,19 +234,18 @@ class UniformInt(PrimitiveWithInfer):
     The number in tensor a should be strictly less than b at any position after broadcasting.

     Args:
-        seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
-            Default: 0.
+        seed (int): Random seed. Default: 0.
+        seed2 (int): Random seed2. Default: 0.

     Inputs:
         - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
         - **a** (Tensor) - The a distribution parameter.
-          It defines the minimum possibly generated value. With int32 data type.
+          It defines the minimum possibly generated value. With int32 data type. Only one number is supported.
         - **b** (Tensor) - The b distribution parameter.
-          It defines the maximum possibly generated value. With int32 data type.
+          It defines the maximum possibly generated value. With int32 data type. Only one number is supported.

     Outputs:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
-        The dtype is int32.
+        Tensor. The shape that the input 'shape' denotes. The dtype is int32.

     Examples:
         >>> shape = (4, 16)
@@ -256,10 +256,11 @@ class UniformInt(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, seed=0):
+    def __init__(self, seed=0, seed2=0):
         """Init UniformInt"""
         self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output'])
         validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_value_type('seed2', seed2, [int], self.name)

     def __infer__(self, shape, a, b):
         shape_v = shape["value"]
@@ -270,10 +271,12 @@ class UniformInt(PrimitiveWithInfer):
             validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
         validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name)
         validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name)
-        broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name)
-        broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
+        a_shape = a['shape']
+        b_shape = b['shape']
+        validator.check("dim of a", len(a_shape), '0(scalar)', 0, Rel.EQ, self.name)
+        validator.check("dim of b", len(b_shape), '0(scalar)', 0, Rel.EQ, self.name)
         out = {
-            'shape': broadcast_shape,
+            'shape': shape_v,
             'dtype': mstype.int32,
             'value': None}
         return out
@@ -281,54 +284,40 @@ class UniformInt(PrimitiveWithInfer):

 class UniformReal(PrimitiveWithInfer):
     r"""
-    Produces random floating-point values i, uniformly distributed on the interval [min(a, b), max(a, b)), that is,
-    distributed according to the probability density function:
-
-    .. math::
-        \text{P}(i|a,b) = \frac{1}{b-a},
+    Produces random floating-point values i, uniformly distributed on the interval [0, 1).

     Args:
-        seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
-            Default: 0.
+        seed (int): Random seed. Default: 0.
+        seed2 (int): Random seed2. Default: 0.

     Inputs:
         - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
-        - **a** (Tensor) - The a distribution parameter.
-          It defines the minimum possibly generated value. With float32 data type.
-        - **b** (Tensor) - The b distribution parameter.
-          It defines the maximum possibly generated value. With float32 data type.

     Outputs:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
-        The dtype is float32.
+        Tensor. The shape that the input 'shape' denotes. The dtype is float32.

     Examples:
         >>> shape = (4, 16)
-        >>> a = Tensor(1.0, mstype.float32)
-        >>> b = Tensor(5.0, mstype.float32)
-        >>> uniform_real = P.UniformReal(seed=10)
-        >>> output = uniform_real(shape, a, b)
+        >>> uniformreal = P.UniformReal(seed=2)
+        >>> output = uniformreal(shape)
     """

     @prim_attr_register
-    def __init__(self, seed=0):
+    def __init__(self, seed=0, seed2=0):
         """Init UniformReal"""
-        self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output'])
+        self.init_prim_io_names(inputs=['shape'], outputs=['output'])
         validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_value_type('seed2', seed2, [int], self.name)

-    def __infer__(self, shape, a, b):
+    def __infer__(self, shape):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
         validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
             validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
-        validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.float32], self.name)
-        validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.float32], self.name)
-        broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name)
-        broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
         out = {
-            'shape': broadcast_shape,
+            'shape': shape_v,
             'dtype': mstype.float32,
             'value': None}
         return out
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

beta1_power = 0.9
beta2_power = 0.999
lr = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.fused_sparse_adam = P.FusedSparseAdam()
        self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")

    def construct(self, grad, indices):
        return self.fused_sparse_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon,
                                      grad, indices)


def test_net():
    gradient = Tensor(np.array([0.22948648, 0.14569908, 0.92861906, 0.66870148])
                      .reshape([2, 1, 2]).astype(np.float32))
    indices = Tensor([0, 1], mstype.int32)
    net = Net()
    output = net(gradient, indices)
    print(output)
    print(net.var.default_input)
    print(net.m.default_input)
    print(net.v.default_input)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

lr = 0.01
l1 = 0.0
l2 = 0.0
lr_power = -0.5


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.fused_sparse_ftrl = P.FusedSparseFtrl(lr=0.1, l1=0.0, l2=0.0, lr_power=-0.5)
        self.var = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="linear")

    def construct(self, grad, indices):
        return self.fused_sparse_ftrl(self.var, self.accum, self.linear, grad, indices)


def test_net():
    gradient = Tensor(np.array([-3, 2, 3, 0, 0, 0, -4, -1, -2])
                      .reshape([3, 3]).astype(np.float32))
    indices = Tensor(np.ones([3]), mstype.int32)
    net = Net()
    output = net(gradient, indices)
    print(output)
    print(net.var.default_input)
    print(net.accum.default_input)
    print(net.linear.default_input)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

beta1_power = 0.9
beta2_power = 0.999
lr = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.fused_sparse_lazy_adam = P.FusedSparseLazyAdam()
        self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")

    def construct(self, grad, indices):
        return self.fused_sparse_lazy_adam(self.var, self.m, self.v, beta1_power, beta2_power,
                                           lr, beta1, beta2, epsilon, grad, indices)


def test_net():
    gradient = Tensor(np.array([0.22948648, 0.14569908, 0.92861906, 0.66870148])
                      .reshape([2, 1, 2]).astype(np.float32))
    indices = Tensor([0, 1], mstype.int32)
    net = Net()
    output = net(gradient, indices)
    print(output)
    print(net.var.default_input)
    print(net.m.default_input)
    print(net.v.default_input)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.fused_sparse_proximal_adagrad = P.FusedSparseProximalAdagrad()
        self.var = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="accum")
        self.lr = 0.01
        self.l1 = 0.0
        self.l2 = 0.0

    def construct(self, grad, indices):
        return self.fused_sparse_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2,
                                                  grad, indices)


def test_net():
    gradient = Tensor(np.array([-3, 2, 3, 0, 0, 0, -4, -1, -2])
                      .reshape([3, 3]).astype(np.float32))
    indices = Tensor(np.ones([3]), mstype.int32)
    net = Net()
    output = net(gradient, indices)
    print(output)
    print(net.var.default_input)
    print(net.accum.default_input)
@@ -24,9 +24,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 class Net(nn.Cell):
-    def __init__(self, shape, seed=0):
+    def __init__(self, shape, seed=0, seed2=0):
         super(Net, self).__init__()
-        self.gamma = P.Gamma(seed=seed)
+        self.gamma = P.Gamma(seed=seed, seed2=seed2)
         self.shape = shape

     def construct(self, alpha, beta):
@@ -38,10 +38,9 @@ def test_net_1D():
     shape = (3, 2, 4)
     alpha = 1.0
     beta = 1.0
-    net = Net(shape, seed)
+    net = Net(shape=shape, seed=seed)
     talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
     output = net(talpha, tbeta)
-    print(output.asnumpy())
     assert output.shape == (3, 2, 4)
@@ -50,8 +49,7 @@ def test_net_ND():
     shape = (3, 1, 2)
     alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
     beta = np.array([1.0]).astype(np.float32)
-    net = Net(shape, seed)
+    net = Net(shape=shape, seed=seed)
     talpha, tbeta = Tensor(alpha), Tensor(beta)
     output = net(talpha, tbeta)
-    print(output.asnumpy())
     assert output.shape == (3, 2, 2)
@@ -24,7 +24,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 class Net(nn.Cell):
-    def __init__(self, shape):
+    def __init__(self, shape, seed=0, seed2=0):
         super(Net, self).__init__()
         self.poisson = P.Poisson()
         self.shape = shape
@@ -36,17 +36,16 @@ class Net(nn.Cell):

 def test_net_1():
     shape = (2, 16)
     mean = np.array([5.0]).astype(np.float32)
-    net = Net(shape)
+    net = Net(shape=shape)
     tmean = Tensor(mean)
     output = net(tmean)
-    print(output.asnumpy())
     assert output.shape == (2, 16)

 def test_net_2():
     shape = (4, 1)
     mean = np.array([5.0, 10.0]).astype(np.float32)
-    net = Net(shape)
+    net = Net(shape=shape)
     tmean = Tensor(mean)
     output = net(tmean)
     print(output.asnumpy())
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-import numpy as np

 import mindspore.context as context
 import mindspore.nn as nn
@@ -24,7 +23,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 class Net(nn.Cell):
-    def __init__(self, shape, seed=0):
+    def __init__(self, shape, seed=0, seed2=0):
         super(Net, self).__init__()
         self.uniformint = P.UniformInt(seed=seed)
         self.shape = shape
@@ -38,10 +37,9 @@ def test_net_1D():
     shape = (3, 2, 4)
     a = 1
     b = 5
-    net = Net(shape, seed)
+    net = Net(shape, seed=seed)
     ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32)
     output = net(ta, tb)
-    print(output.asnumpy())
     assert output.shape == (3, 2, 4)
...
@@ -12,36 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-import numpy as np

 import mindspore.context as context
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.ops import operations as P
-from mindspore.common import dtype as mstype

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 class Net(nn.Cell):
-    def __init__(self, shape, seed=0):
+    def __init__(self, shape, seed=0, seed2=0):
         super(Net, self).__init__()
         self.uniformreal = P.UniformReal(seed=seed)
         self.shape = shape

-    def construct(self, a, b):
-        return self.uniformreal(self.shape, a, b)
+    def construct(self):
+        return self.uniformreal(self.shape)

-def test_net_1D():
+def test_net():
     seed = 10
     shape = (3, 2, 4)
-    a = 1.0
-    b = 5.0
-    net = Net(shape, seed)
-    ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
-    output = net(ta, tb)
-    print(output.asnumpy())
+    net = Net(shape, seed=seed)
+    output = net()
     assert output.shape == (3, 2, 4)
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, alpha, beta):
        C.set_seed(20)
        return C.gamma(self.shape, alpha, beta, self.seed)


def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    alpha = 1.0
    beta = 1.0
    net = Net(shape, seed)
    talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
    output = net(talpha, tbeta)
    assert output.shape == (3, 2, 4)


def test_net_ND():
    seed = 10
    shape = (3, 1, 2)
    alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
    beta = np.array([1.0]).astype(np.float32)
    net = Net(shape, seed)
    talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
    output = net(talpha, tbeta)
    assert output.shape == (3, 2, 2)
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
 import numpy as np
-import pytest

 import mindspore.context as context
 import mindspore.nn as nn
@@ -32,6 +30,7 @@ class Net(nn.Cell):
         self.seed = seed

     def construct(self, mean, stddev):
+        C.set_seed(20)
         return C.normal(self.shape, mean, stddev, self.seed)
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, mean):
        C.set_seed(20)
        return C.poisson(self.shape, mean, self.seed)


def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    mean = 1.0
    net = Net(shape, seed)
    tmean = Tensor(mean, mstype.float32)
    output = net(tmean)
    assert output.shape == (3, 2, 4)


def test_net_ND():
    seed = 10
    shape = (3, 1, 2)
    mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
    net = Net(shape, seed)
    tmean = Tensor(mean, mstype.float32)
    output = net(tmean)
    assert output.shape == (3, 2, 2)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, a, b):
        C.set_seed(20)
        return C.uniform(self.shape, a, b, self.seed)


def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    a = 1.0
    b = 6.0
    net = Net(shape, seed)
    ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
    output = net(ta, tb)
    assert output.shape == (3, 2, 4)


def test_net_ND():
    seed = 10
    shape = (3, 1, 2)
    a = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
    b = np.array([1.0]).astype(np.float32)
    net = Net(shape, seed)
    ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
    output = net(ta, tb)
    assert output.shape == (3, 2, 2)
@@ -43,4 +43,4 @@ def test_net():
     tx, ty = Tensor(x), Tensor(y)
     output = mask(tx, ty)
     print(output.asnumpy())
-    assert ([255, 255, 255, 255] == output.asnumpy()).all()
+    assert ([255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255] == output.asnumpy()).all()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, pad_dim_size):
        super(Net, self).__init__()
        self.padding = P.Padding(pad_dim_size)

    def construct(self, x):
        return self.padding(x)


def test_padding():
    x = Tensor(np.array([[8], [10]]), mstype.int32)
    padding = Net(4)
    out = padding(x)
    assert (out.asnumpy() == [[8, 0, 0, 0], [10, 0, 0, 0]]).all()
@@ -592,44 +592,33 @@ class LaplaceNet(nn.Cell):

 class GammaNet(nn.Cell):
     def __init__(self, shape=None, seed=0):
         super(GammaNet, self).__init__()
-        self.gamma = P.Gamma(seed=seed)
         self.shape = shape
+        self.seed = seed

     def construct(self, alpha, beta):
-        out = self.gamma(self.shape, alpha, beta)
+        out = C.gamma(self.shape, alpha, beta, self.seed)
         return out


 class PoissonNet(nn.Cell):
     def __init__(self, shape=None, seed=0):
         super(PoissonNet, self).__init__()
-        self.poisson = P.Poisson(seed=seed)
         self.shape = shape
+        self.seed = seed

     def construct(self, mean):
-        out = self.poisson(self.shape, mean)
+        out = C.poisson(self.shape, mean, self.seed)
         return out


-class UniformIntNet(nn.Cell):
-    def __init__(self, shape=None, seed=0):
-        super(UniformIntNet, self).__init__()
-        self.uniformint = P.UniformInt(seed=seed)
-        self.shape = shape
-
-    def construct(self, a, b):
-        out = self.uniformint(self.shape, a, b)
-        return out
-
-
-class UniformRealNet(nn.Cell):
+class UniformNet(nn.Cell):
     def __init__(self, shape=None, seed=0):
-        super(UniformRealNet, self).__init__()
-        self.uniformreal = P.UniformReal(seed=seed)
+        super(UniformNet, self).__init__()
         self.shape = shape
+        self.seed = seed

     def construct(self, a, b):
-        out = self.uniformreal(self.shape, a, b)
+        out = C.uniform(self.shape, a, b, self.seed)
         return out

@@ -924,13 +913,9 @@ test_case_math_ops = [
         'block': PoissonNet((3, 2, 4), 0),
         'desc_inputs': [Tensor(2.0, mstype.float32)],
         'skip': ['backward']}),
-    ('UniformInt', {
-        'block': UniformIntNet((3, 2, 4), 0),
-        'desc_inputs': [Tensor(1, mstype.int32), Tensor(15, mstype.int32)],
-        'skip': ['backward']}),
-    ('UniformReal', {
-        'block': UniformRealNet((3, 2, 4), 0),
-        'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(5.0, mstype.float32)],
-        'skip': ['backward']}),
+    ('Uniform', {
+        'block': UniformNet((3, 2, 4), 0),
+        'desc_inputs': [Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32)],
+        'skip': ['backward']}),
     ('RandomChoiceWithMask', {
         'block': P.RandomChoiceWithMask(256),
...