Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
43567f9b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
43567f9b
编写于
7月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3147 add RNNTLoss and RandomCategorical ops for aicpu
Merge pull request !3147 from yanzhenxiang2020/add_rnnt_cate_open.new
上级
4945d34a
2ae6dfe9
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
294 addition
and
1 deletion
+294
-1
mindspore/ops/_grad/grad_nn_ops.py
mindspore/ops/_grad/grad_nn_ops.py
+10
-0
mindspore/ops/_op_impl/aicpu/__init__.py
mindspore/ops/_op_impl/aicpu/__init__.py
+2
-0
mindspore/ops/_op_impl/aicpu/random_categorical.py
mindspore/ops/_op_impl/aicpu/random_categorical.py
+48
-0
mindspore/ops/_op_impl/aicpu/rnnt_loss.py
mindspore/ops/_op_impl/aicpu/rnnt_loss.py
+37
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+4
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+56
-0
mindspore/ops/operations/random_ops.py
mindspore/ops/operations/random_ops.py
+58
-0
tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py
...s/st/ops/ascend/test_aicpu_ops/test_random_categorical.py
+38
-0
tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py
tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py
+41
-0
未找到文件。
mindspore/ops/_grad/grad_nn_ops.py
浏览文件 @
43567f9b
...
...
@@ -567,6 +567,16 @@ def get_bprop_l2_loss(self):
return
bprop
@
bprop_getters
.
register
(
P
.
RNNTLoss
)
def
get_bprop_rnnt_loss
(
self
):
"""Grad definition for `RNNTLoss` operation."""
def
bprop
(
acts
,
labels
,
act_lens
,
label_lens
,
out
,
dout
):
grad
=
out
[
1
]
return
grad
,
zeros_like
(
labels
),
zeros_like
(
act_lens
),
zeros_like
(
label_lens
)
return
bprop
@
bprop_getters
.
register
(
P
.
PReLU
)
def
get_bprop_prelu
(
self
):
"""Grad definition for `PReLU` operation."""
...
...
mindspore/ops/_op_impl/aicpu/__init__.py
浏览文件 @
43567f9b
...
...
@@ -30,3 +30,5 @@ from .ctcloss import _ctcloss_aicpu
from
.reverse_sequence
import
_reverse_sequence_aicpu
from
.crop_and_resize
import
_crop_and_resize_aicpu
from
.end_of_sequence
import
_end_of_sequence_aicpu
from
.rnnt_loss
import
_rnnt_loss_aicpu
from
.random_categorical
import
_random_categorical_aicpu
mindspore/ops/_op_impl/aicpu/random_categorical.py
0 → 100644
浏览文件 @
43567f9b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RandomCategorical op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
AiCPURegOp
,
DataType
random_categorical_op_info
=
AiCPURegOp
(
"RandomCategorical"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
input
(
0
,
"logits"
,
"required"
)
\
.
input
(
1
,
"num_sample"
,
"required"
)
\
.
input
(
2
,
"seed"
,
"required"
)
\
.
output
(
0
,
"output"
,
"required"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I16_Default
)
\
.
dtype_format
(
DataType
.
F64_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I16_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F64_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I64_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I64_Default
)
\
.
dtype_format
(
DataType
.
F64_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I64_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I16_Default
)
\
.
dtype_format
(
DataType
.
F64_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I16_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F64_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
)
\
.
dtype_format
(
DataType
.
F64_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
,
DataType
.
I64_Default
)
\
.
get_op_info
()
@
op_info_register
(
random_categorical_op_info
)
def
_random_categorical_aicpu
():
"""RandomCategorical AiCPU register"""
return
mindspore/ops/_op_impl/aicpu/rnnt_loss.py
0 → 100644
浏览文件 @
43567f9b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RNNTLoss op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
AiCPURegOp
,
DataType
rnnt_loss_op_info
=
AiCPURegOp
(
"RNNTLoss"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
input
(
0
,
"acts"
,
"required"
)
\
.
input
(
1
,
"labels"
,
"required"
)
\
.
input
(
2
,
"input_lengths"
,
"required"
)
\
.
input
(
3
,
"label_lengths"
,
"required"
)
\
.
output
(
0
,
"costs"
,
"required"
)
\
.
output
(
1
,
"grads"
,
"required"
)
\
.
attr
(
"blank_label"
,
"int"
)
\
.
dtype_format
(
DataType
.
F32_NCHW
,
DataType
.
I32_NCHW
,
DataType
.
I32_NCHW
,
DataType
.
I32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
@
op_info_register
(
rnnt_loss_op_info
)
def
_rnnt_loss_aicpu
():
"""RNNTLoss AiCPU register"""
return
mindspore/ops/operations/__init__.py
浏览文件 @
43567f9b
...
...
@@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Sin
,
Sqrt
,
Rsqrt
,
BesselI0e
,
BesselI1e
,
TruncateDiv
,
TruncateMod
,
Square
,
Sub
,
TensorAdd
,
Sign
,
Round
,
SquareSumAll
,
Atan
,
Atanh
,
Cosh
,
Sinh
,
Eps
,
Tan
)
from
.random_ops
import
(
RandomChoiceWithMask
,
Normal
)
from
.random_ops
import
(
RandomChoiceWithMask
,
Normal
,
RandomCategorical
)
from
.nn_ops
import
(
LSTM
,
SGD
,
Adam
,
SparseApplyAdam
,
SparseApplyLazyAdam
,
ApplyMomentum
,
BatchNorm
,
BiasAdd
,
Conv2D
,
DepthwiseConv2dNative
,
...
...
@@ -70,6 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
ResizeBilinear
,
Sigmoid
,
SigmoidCrossEntropyWithLogits
,
SmoothL1Loss
,
Softmax
,
Softsign
,
Softplus
,
LRN
,
RNNTLoss
,
SoftmaxCrossEntropyWithLogits
,
ROIAlign
,
SparseSoftmaxCrossEntropyWithLogits
,
Tanh
,
TopK
,
BinaryCrossEntropy
,
SparseApplyAdagrad
,
LARSUpdate
,
ApplyFtrl
,
SparseApplyFtrl
,
...
...
@@ -171,6 +172,7 @@ __all__ = [
'Tanh'
,
'RandomChoiceWithMask'
,
'Normal'
,
'RandomCategorical'
,
'ResizeBilinear'
,
'ScalarSummary'
,
'ImageSummary'
,
...
...
@@ -202,6 +204,7 @@ __all__ = [
'SmoothL1Loss'
,
'L2Loss'
,
'CTCLoss'
,
'RNNTLoss'
,
'ReduceAll'
,
'ScalarToArray'
,
'ScalarToTensor'
,
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
43567f9b
...
...
@@ -1736,6 +1736,62 @@ class DataFormatDimMap(PrimitiveWithInfer):
return
x_type
class
RNNTLoss
(
PrimitiveWithInfer
):
"""
Computes the RNNTLoss and its gradient with respect to the softmax outputs.
Args:
blank_label (int): blank label. Default: 0.
Inputs:
- **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`.
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, U-1)`.
- **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
- **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
Outputs:
- **costs** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
- **grads** (Tensor[int32]) - Has the same shape as `acts`.
Examples:
>>> B, T, U, V = 1, 2, 3, 5
>>> acts = np.random.random((B, T, U, V)).astype(np.float32)
>>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
>>> rnnt_loss = P.RNNTLoss(blank_label=blank)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
"""
@
prim_attr_register
def
__init__
(
self
,
blank_label
=
0
):
validator
.
check_value_type
(
'blank_label'
,
blank_label
,
[
int
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'acts'
,
'labels'
,
'input_length'
,
'label_length'
],
outputs
=
[
'costs'
,
'grads'
])
def
infer_shape
(
self
,
acts_shape
,
labels_shape
,
input_length_shape
,
label_length_shape
):
validator
.
check_integer
(
'acts_rank'
,
len
(
acts_shape
),
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
'labels_rank'
,
len
(
labels_shape
),
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
'input_length_rank'
,
len
(
input_length_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
'label_length_rank'
,
len
(
label_length_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'labels shape[0]'
,
labels_shape
[
0
],
'acts shape[0]'
,
acts_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'labels shape[1]'
,
labels_shape
[
1
],
'acts shape[2]-1'
,
acts_shape
[
2
]
-
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'input_length size'
,
input_length_shape
[
0
],
'acts shape[0]'
,
acts_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'label_length size'
,
label_length_shape
[
0
],
'acts shape[0]'
,
acts_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
costs_shape
=
(
acts_shape
[
0
],)
return
(
costs_shape
,
acts_shape
)
def
infer_dtype
(
self
,
acts_type
,
labels_type
,
input_length_type
,
label_length_type
):
validator
.
check_subclass
(
"acts_type"
,
acts_type
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"labels_type"
,
labels_type
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"input_length_type"
,
input_length_type
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"label_length_type"
,
label_length_type
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_tensor_type_same
({
"acts_type"
:
acts_type
},
[
mstype
.
float32
],
self
.
name
)
validator
.
check_tensor_type_same
({
"labels_type"
:
labels_type
},
[
mstype
.
int32
],
self
.
name
)
validator
.
check_tensor_type_same
({
"input_length_type"
:
input_length_type
},
[
mstype
.
int32
],
self
.
name
)
validator
.
check_tensor_type_same
({
"label_length_type"
:
label_length_type
},
[
mstype
.
int32
],
self
.
name
)
return
(
acts_type
,
acts_type
)
class
SGD
(
PrimitiveWithInfer
):
"""
Computes stochastic gradient descent (optionally with momentum).
...
...
mindspore/ops/operations/random_ops.py
浏览文件 @
43567f9b
...
...
@@ -108,3 +108,61 @@ class Normal(PrimitiveWithInfer):
"dtype"
:
mstype
.
float32
,
"value"
:
None
}
return
out
class
RandomCategorical
(
PrimitiveWithInfer
):
"""
Generates random samples from a given categorical distribution tensor.
Args:
dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16,
mindspore.int32, mindspore.int64]. Default: mindspore.int64.
Inputs:
- **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes].
- **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed.
- **seed** (int) - Random seed. Default: 0. Only constant values is allowed.
Outputs:
- **output** (Tensor) - The output Tensor with shape [batch_size, num_samples].
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self, num_sample):
>>> super(Net, self).__init__()
>>> self.random_categorical = P.RandomCategorical(mindspore.int64)
>>> self.num_sample = num_sample
>>> def construct(self, logits, seed=0):
>>> return self.random_categorical(logits, self.num_sample, seed)
>>>
>>> x = np.random.random((10, 5)).astype(np.float32)
>>> net = Net(8)
>>> output = net(Tensor(x))
"""
@
prim_attr_register
def
__init__
(
self
,
dtype
=
mstype
.
int64
):
"""Init RandomCategorical"""
self
.
dtype
=
dtype
valid_values
=
(
mstype
.
int32
,
mstype
.
int16
,
mstype
.
int64
)
validator
.
check_type_name
(
"dtype"
,
dtype
,
valid_values
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'logits'
,
'num_samples'
,
'seed'
],
outputs
=
[
'output'
])
def
__infer__
(
self
,
logits
,
num_samples
,
seed
):
logits_dtype
=
logits
[
'dtype'
]
valid_types
=
(
mstype
.
float32
,
mstype
.
float16
,
mstype
.
float64
)
validator
.
check_tensor_type_same
({
'logits'
:
logits_dtype
},
valid_types
,
self
.
name
)
num_samples_v
=
num_samples
[
'value'
]
seed_v
=
seed
[
'value'
]
validator
.
check_value_type
(
'num_samples'
,
num_samples_v
,
(
int
,),
self
.
name
)
validator
.
check_value_type
(
'seed'
,
seed_v
,
(
int
,),
self
.
name
)
validator
.
check_integer
(
"num_samples"
,
num_samples_v
,
0
,
Rel
.
GT
,
self
.
name
)
x_shape
=
list
(
logits
[
'shape'
])
if
len
(
x_shape
)
!=
2
:
raise
ValueError
(
"RandomCategorical shape should be 2-dimension."
)
ndim
=
len
(
x_shape
)
-
1
x_shape
[
ndim
]
=
num_samples_v
return
{
'shape'
:
(
x_shape
),
'dtype'
:
(
self
.
dtype
),
'value'
:
None
}
tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py
0 → 100644
浏览文件 @
43567f9b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
mindspore
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
num_sample
):
super
(
Net
,
self
).
__init__
()
self
.
random_categorical
=
P
.
RandomCategorical
(
mindspore
.
int64
)
self
.
num_sample
=
num_sample
def
construct
(
self
,
logits
,
seed
=
0
):
return
self
.
random_categorical
(
logits
,
self
.
num_sample
,
seed
)
def
test_net
():
x
=
np
.
random
.
random
((
10
,
5
)).
astype
(
np
.
float32
)
net
=
Net
(
8
)
output
=
net
(
Tensor
(
x
))
print
(
x
)
print
(
output
.
asnumpy
())
#print(output.dtype())
tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py
0 → 100644
浏览文件 @
43567f9b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
rnnt_loss
=
P
.
RNNTLoss
(
blank_label
=
0
)
def
construct
(
self
,
acts
,
labels
,
act_lens
,
label_lens
):
return
self
.
rnnt_loss
(
acts
,
labels
,
act_lens
,
label_lens
)
def
test_net
():
B
,
T
,
U
,
V
=
1
,
2
,
3
,
5
acts
=
np
.
random
.
random
((
B
,
T
,
U
,
V
)).
astype
(
np
.
float32
)
labels
=
np
.
array
([[
np
.
random
.
randint
(
1
,
V
-
1
)
for
_
in
range
(
U
-
1
)]]).
astype
(
np
.
int32
)
input_length
=
np
.
array
([
T
]
*
B
).
astype
(
np
.
int32
)
label_length
=
np
.
array
([
len
(
l
)
for
l
in
labels
]).
astype
(
np
.
int32
)
rnnt_loss
=
Net
()
costs
,
grads
=
rnnt_loss
(
Tensor
(
acts
),
Tensor
(
labels
),
Tensor
(
input_length
),
Tensor
(
label_length
))
print
(
Tensor
(
acts
),
Tensor
(
labels
),
Tensor
(
input_length
),
Tensor
(
label_length
))
print
(
costs
.
asnumpy
())
print
(
grads
.
asnumpy
())
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录