Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d7de0442
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d7de0442
编写于
6月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1780 add op BasicLSTMCell
Merge pull request !1780 from zhaozhenlong/op/lstm-open
上级
96ebda91
270f79c8
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
436 addition
and
2 deletion
+436
-2
mindspore/_checkparam.py
mindspore/_checkparam.py
+3
-0
mindspore/ops/_grad/grad_nn_ops.py
mindspore/ops/_grad/grad_nn_ops.py
+22
-0
mindspore/ops/_op_impl/tbe/__init__.py
mindspore/ops/_op_impl/tbe/__init__.py
+4
-0
mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
+57
-0
mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py
mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py
+50
-0
mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
+42
-0
mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
+41
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+3
-2
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+103
-0
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+106
-0
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+5
-0
未找到文件。
mindspore/_checkparam.py
浏览文件 @
d7de0442
...
...
@@ -299,6 +299,9 @@ class Validator:
def
get_typename
(
t
):
return
t
.
__name__
if
hasattr
(
t
,
'__name__'
)
else
str
(
t
)
if
isinstance
(
arg_type
,
type
(
mstype
.
tensor
)):
arg_type
=
arg_type
.
element_type
()
if
arg_type
in
valid_types
:
return
arg_type
type_names
=
[
get_typename
(
t
)
for
t
in
valid_types
]
...
...
mindspore/ops/_grad/grad_nn_ops.py
浏览文件 @
d7de0442
...
...
@@ -697,3 +697,25 @@ def get_bprop_ctc_loss(self):
return
grad
,
zeros_like
(
labels_indices
),
zeros_like
(
labels_values
),
zeros_like
(
sequence_length
)
return
bprop
@
bprop_getters
.
register
(
P
.
BasicLSTMCell
)
def
get_bprop_basic_lstm_cell
(
self
):
"""Grad definition for `BasicLSTMCell` operation."""
basic_lstm_cell_cstate_grad
=
G
.
BasicLSTMCellCStateGrad
(
forget_bias
=
self
.
forget_bias
,
activation
=
self
.
activation
)
basic_lstm_cell_weight_grad
=
G
.
BasicLSTMCellWeightGrad
()
basic_lstm_cell_input_grad
=
G
.
BasicLSTMCellInputGrad
(
keep_prob
=
self
.
keep_prob
)
def
bprop
(
x
,
h
,
c
,
w
,
b
,
out
,
dout
):
_
,
_
,
it
,
jt
,
ft
,
ot
,
tanhct
=
out
dct
,
dht
,
_
,
_
,
_
,
_
,
_
=
dout
dgate
,
dct_1
=
basic_lstm_cell_cstate_grad
(
c
,
dht
,
dct
,
it
,
jt
,
ft
,
ot
,
tanhct
)
dxt
,
dht
=
basic_lstm_cell_input_grad
(
dgate
,
w
)
dw
,
db
=
basic_lstm_cell_weight_grad
(
F
.
depend
(
x
,
dxt
),
h
,
dgate
)
return
dxt
,
dht
,
dct_1
,
dw
,
db
return
bprop
mindspore/ops/_op_impl/tbe/__init__.py
浏览文件 @
d7de0442
...
...
@@ -230,3 +230,7 @@ from .atan_grad import _atan_grad_tbe
from
.atanh
import
_atanh_tbe
from
.cosh
import
_cosh_tbe
from
.sinh
import
_sinh_tbe
from
.basic_lstm_cell
import
_basic_lstm_cell_tbe
from
.basic_lstm_cell_c_state_grad
import
_basic_lstm_cell_c_state_grad_tbe
from
.basic_lstm_cell_weight_grad
import
_basic_lstm_cell_weight_grad_tbe
from
.basic_lstm_cell_input_grad
import
_basic_lstm_cell_input_grad_tbe
mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
0 → 100644
浏览文件 @
d7de0442
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCell op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
basic_lstm_cell_op_info
=
TBERegOp
(
"BasicLSTMCell"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"basic_lstm_cell.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"basic_lstm_cell"
)
\
.
attr
(
"keep_prob"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"forget_bias"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"state_is_tuple"
,
"optional"
,
"bool"
,
"true"
)
\
.
attr
(
"activation"
,
"optional"
,
"str"
,
"all"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"h"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"c"
,
False
,
"required"
,
"all"
)
\
.
input
(
3
,
"w"
,
False
,
"required"
,
"all"
)
\
.
input
(
4
,
"b"
,
False
,
"required"
,
"all"
)
\
.
input
(
5
,
"mask"
,
False
,
"optional"
,
"all"
)
\
.
output
(
0
,
"ct"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"ht"
,
False
,
"required"
,
"all"
)
\
.
output
(
2
,
"it"
,
False
,
"optional"
,
"all"
)
\
.
output
(
3
,
"jt"
,
False
,
"optional"
,
"all"
)
\
.
output
(
4
,
"ft"
,
False
,
"optional"
,
"all"
)
\
.
output
(
5
,
"ot"
,
False
,
"optional"
,
"all"
)
\
.
output
(
6
,
"tanhct"
,
False
,
"optional"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F16_FracZ
,
DataType
.
F32_Default
,
DataType
.
U8_Default
,
DataType
.
F32_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_Default
,
DataType
.
U8_Default
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
get_op_info
()
@
op_info_register
(
basic_lstm_cell_op_info
)
def
_basic_lstm_cell_tbe
():
"""BasicLSTMCell TBE register"""
return
mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py
0 → 100644
浏览文件 @
d7de0442
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellCStateGrad op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
basic_lstm_cell_c_state_grad_op_info
=
TBERegOp
(
"BasicLSTMCellCStateGrad"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"basic_lstm_cell_c_state_grad.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"basic_lstm_cell_c_state_grad"
)
\
.
attr
(
"forget_bias"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"activation"
,
"optional"
,
"str"
,
"all"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"c"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"dht"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"dct"
,
False
,
"required"
,
"all"
)
\
.
input
(
3
,
"it"
,
False
,
"required"
,
"all"
)
\
.
input
(
4
,
"ft"
,
False
,
"required"
,
"all"
)
\
.
input
(
5
,
"jt"
,
False
,
"required"
,
"all"
)
\
.
input
(
6
,
"ot"
,
False
,
"required"
,
"all"
)
\
.
input
(
7
,
"tanhct"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"dgate"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"dct_1"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
get_op_info
()
@
op_info_register
(
basic_lstm_cell_c_state_grad_op_info
)
def
_basic_lstm_cell_c_state_grad_tbe
():
"""BasicLSTMCellCStateGrad TBE register"""
return
mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
0 → 100644
浏览文件 @
d7de0442
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellInputGrad op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
basic_lstm_cell_input_grad_op_info
=
TBERegOp
(
"BasicLSTMCellInputGrad"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"basic_lstm_cell_input_grad.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"basic_lstm_cell_input_grad"
)
\
.
attr
(
"keep_prob"
,
"optional"
,
"float"
,
"all"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"dgate"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"w"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"dropout_mask"
,
False
,
"optional"
,
"all"
)
\
.
output
(
0
,
"dxt"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"dht"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracZ
,
DataType
.
U8_Default
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracZ
,
DataType
.
U8_Default
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
get_op_info
()
@
op_info_register
(
basic_lstm_cell_input_grad_op_info
)
def
_basic_lstm_cell_input_grad_tbe
():
"""BasicLSTMCellInputGrad TBE register"""
return
mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
0 → 100644
浏览文件 @
d7de0442
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellWeightGrad op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
basic_lstm_cell_weight_grad_op_info
=
TBERegOp
(
"BasicLSTMCellWeightGrad"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"basic_lstm_cell_weight_grad.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"basic_lstm_cell_weight_grad"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"h"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"dgate"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"dw"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"db"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracZ
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_Default
)
\
.
get_op_info
()
@
op_info_register
(
basic_lstm_cell_weight_grad_op_info
)
def
_basic_lstm_cell_weight_grad_tbe
():
"""BasicLSTMCellWeightGrad TBE register"""
return
mindspore/ops/operations/__init__.py
浏览文件 @
d7de0442
...
...
@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
SparseSoftmaxCrossEntropyWithLogits
,
Tanh
,
TopK
,
BinaryCrossEntropy
,
SparseApplyAdagrad
,
LARSUpdate
,
ApplyFtrl
,
SparseApplyFtrl
,
ApplyProximalAdagrad
,
SparseApplyProximalAdagrad
,
ApplyRMSProp
,
ApplyCenteredRMSProp
)
ApplyRMSProp
,
ApplyCenteredRMSProp
,
BasicLSTMCell
)
from
.other_ops
import
Assign
,
IOU
,
BoundingBoxDecode
,
BoundingBoxEncode
,
CheckValid
,
MakeRefKey
,
CheckBprop
from
.
import
_quant_ops
from
._quant_ops
import
*
...
...
@@ -287,7 +287,8 @@ __all__ = [
"BesselI0e"
,
"BesselI1e"
,
"Atan"
,
"Atanh"
"Atanh"
,
"BasicLSTMCell"
]
__all__
.
extend
(
_quant_ops
.
__all__
)
...
...
mindspore/ops/operations/_grad_ops.py
浏览文件 @
d7de0442
...
...
@@ -1173,3 +1173,106 @@ class AtanGrad(PrimitiveWithInfer):
args
=
{
"x"
:
x
,
"dout"
:
dout
}
validator
.
check_tensor_type_same
(
args
,
mstype
.
number_type
,
self
.
name
)
return
x
class
BasicLSTMCellCStateGrad
(
PrimitiveWithInfer
):
"""Computes the state gradients of BasicLSTMCell."""
@
prim_attr_register
def
__init__
(
self
,
forget_bias
,
activation
):
self
.
forget_bias
=
validator
.
check_value_type
(
"forget_bias"
,
forget_bias
,
[
float
],
self
.
name
)
self
.
activation
=
validator
.
check_string
(
"activation"
,
activation
,
[
'tanh'
],
self
.
name
)
def
infer_shape
(
self
,
c_shape
,
dht_shape
,
dct_shape
,
it_shape
,
jt_shape
,
ft_shape
,
ot_shape
,
tanhct_shape
):
# dhy and dcy should be same shape
validator
.
check_integer
(
"c rank"
,
len
(
c_shape
),
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dht rank"
,
len
(
dht_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dct rank"
,
len
(
dct_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"it rank"
,
len
(
it_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"jt rank"
,
len
(
jt_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"ft rank"
,
len
(
ft_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"ot rank"
,
len
(
ot_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"tanhct rank"
,
len
(
tanhct_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dht shape"
,
dht_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dct shape"
,
dct_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"it shape"
,
it_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"jt shape"
,
jt_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"ft shape"
,
ft_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"ot shape"
,
ot_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"tanhct shape"
,
tanhct_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
dgate_shape
=
(
c_shape
[
0
],
4
*
c_shape
[
1
])
dct_1_shape
=
c_shape
return
(
dgate_shape
,
dct_1_shape
)
def
infer_dtype
(
self
,
c_dtype
,
dht_dtype
,
dct_dtype
,
it_dtype
,
jt_dtype
,
ft_dtype
,
ot_dtype
,
tanhct_dtype
):
validator
.
check_subclass
(
"c"
,
c_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"dht"
,
dht_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"dct"
,
dct_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"it"
,
it_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"jt"
,
jt_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"ft"
,
ft_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"ot"
,
ot_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"tanhct"
,
tanhct_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_type_name
(
"c"
,
c_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"dht"
,
dht_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"dct"
,
dct_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"it"
,
it_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"jt"
,
jt_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"ft"
,
ft_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"ot"
,
ot_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"tanhct"
,
tanhct_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
return
(
c_dtype
,
c_dtype
)
class
BasicLSTMCellWeightGrad
(
PrimitiveWithInfer
):
"""Computes the weight gradients of BasicLSTM."""
@
prim_attr_register
def
__init__
(
self
):
pass
def
infer_shape
(
self
,
x_shape
,
h_shape
,
dgate_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"h rank"
,
len
(
h_shape
),
" x rank"
,
len
(
x_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dgate rank"
,
len
(
dgate_shape
),
"x rank"
,
len
(
x_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"h_shape[0]"
,
h_shape
[
0
],
"x_shape[0]"
,
x_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dgate_shape[0]"
,
dgate_shape
[
0
],
"h_shape[0]"
,
h_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dgate_shape[1]"
,
dgate_shape
[
1
],
"4*h_shape[1]"
,
4
*
h_shape
[
1
],
Rel
.
EQ
,
self
.
name
)
dw_shape
=
(
dgate_shape
[
1
],
x_shape
[
1
]
+
h_shape
[
1
],
1
,
1
)
db_shape
=
(
dgate_shape
[
1
],
1
,
1
,
1
)
return
(
dw_shape
,
db_shape
)
def
infer_dtype
(
self
,
x_dtype
,
h_dtype
,
dgate_dtype
):
validator
.
check_subclass
(
"x"
,
x_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"h"
,
h_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"dgate"
,
dgate_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_type_name
(
"x"
,
x_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"h"
,
h_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"dgate"
,
dgate_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
return
(
x_dtype
,
x_dtype
)
class
BasicLSTMCellInputGrad
(
PrimitiveWithInfer
):
"""Computes the input gradients of BasicLSTM."""
@
prim_attr_register
def
__init__
(
self
,
keep_prob
):
self
.
keep_prob
=
validator
.
check_value_type
(
"keep_prob"
,
keep_prob
,
[
float
],
self
.
name
)
self
.
keep_prob
=
validator
.
check_number_range
(
"keep_prob"
,
keep_prob
,
0.0
,
1.0
,
Rel
.
INC_BOTH
,
self
.
name
)
def
infer_shape
(
self
,
dgate_shape
,
w_shape
):
validator
.
check_integer
(
"dgate rank"
,
len
(
dgate_shape
),
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"w rank"
,
len
(
w_shape
),
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dgate_shape[1]"
,
dgate_shape
[
1
],
"w_shape[0]"
,
w_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
dxt_shape
=
(
dgate_shape
[
0
],
w_shape
[
1
]
-
w_shape
[
0
]
//
4
)
dht_shape
=
(
dgate_shape
[
0
],
dgate_shape
[
1
]
//
4
)
return
(
dxt_shape
,
dht_shape
)
def
infer_dtype
(
self
,
dgate_dtype
,
w_dtype
):
validator
.
check_subclass
(
"dgate"
,
dgate_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"w"
,
w_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_type_name
(
"dgate"
,
dgate_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"w"
,
w_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
return
(
dgate_dtype
,
dgate_dtype
)
mindspore/ops/operations/nn_ops.py
浏览文件 @
d7de0442
...
...
@@ -3363,3 +3363,109 @@ class CTCLoss(PrimitiveWithInfer):
validator
.
check_tensor_type_same
({
"labels_values_dtype"
:
labels_values
},
[
mstype
.
int32
],
self
.
name
)
validator
.
check_tensor_type_same
({
"sequence_length_dtype"
:
sequence_length
},
[
mstype
.
int32
],
self
.
name
)
return
inputs
,
inputs
class
BasicLSTMCell
(
PrimitiveWithInfer
):
r
"""
Performs the long short term memory(LSTM) on the input.
.. math::
\begin{array}{ll} \\
i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
\tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
h_t = o_t * \tanh(c_t) \\
\end{array}
Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
are learnable weights between the output and the input in the formula. For instance,
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
Details can be found in paper `LONG SHORT-TERM MEMORY
<https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
`Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
<https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.
Args:
keep_prob (float): If not 1.0, append `Dropout` layer on the outputs of each
LSTM layer except the last layer. Default 1.0. The range of dropout is [0.0, 1.0].
forget_bias (float): Add forget bias to forget gate biases in order to decrease former scale. Default to 1.0.
state_is_tuple (bool): If True, state is tensor tuple, containing h and c; If False, one tensor,
need split first. Default to True.
activation (str): Activation. Default to "tanh".
Inputs:
- **x** (Tensor) - Current words. Tensor of shape (`batch_size`, `input_size`).
- **h** (Tensor) - Hidden state last moment. Tensor of shape (`batch_size`, `hidden_size`).
- **c** (Tensor) - Cell state last moment. Tensor of shape (`batch_size`, `hidden_size`).
- **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
- **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
Outputs:
- **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
- **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
- **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
- **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
- **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
- **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
- **tanhct** (Tensor) - Forward :math:`tanh c_t` cache at moment `t`.
Tensor of shape (`batch_size`, `4 x hidden_size`).
Examples:
'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]],
'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
>>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
>>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
>>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
>>> lstm(x, h, c, w, b)
"""
@
prim_attr_register
def
__init__
(
self
,
keep_prob
=
1.0
,
forget_bias
=
1.0
,
state_is_tuple
=
True
,
activation
=
'tanh'
):
self
.
keep_prob
=
validator
.
check_value_type
(
"keep_prob"
,
keep_prob
,
[
float
],
self
.
name
)
self
.
keep_prob
=
validator
.
check_number_range
(
"keep_prob"
,
keep_prob
,
0.0
,
1.0
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
forget_bias
=
validator
.
check_value_type
(
"forget_bias"
,
forget_bias
,
[
float
],
self
.
name
)
self
.
state_is_tuple
=
validator
.
check_value_type
(
"state_is_tuple"
,
state_is_tuple
,
[
bool
],
self
.
name
)
self
.
activation
=
validator
.
check_string
(
"activation"
,
activation
,
[
'tanh'
],
self
.
name
)
def
infer_shape
(
self
,
x_shape
,
h_shape
,
c_shape
,
w_shape
,
b_shape
):
# (batch_size, input_size)
validator
.
check_integer
(
"x_shape"
,
len
(
x_shape
),
2
,
Rel
.
EQ
,
self
.
name
)
# h and c should be same shape
validator
.
check_integer
(
"h_shape"
,
len
(
h_shape
),
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"h rank"
,
len
(
h_shape
),
"c rank"
,
len
(
c_shape
),
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"h shape"
,
h_shape
,
"c shape"
,
c_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"w rank"
,
len
(
w_shape
),
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"b rank"
,
len
(
b_shape
),
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"w_shape[0]"
,
w_shape
[
0
],
"4*h_shape[1]"
,
4
*
h_shape
[
1
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"w_shape[1]"
,
w_shape
[
1
],
"x_shape[1]+h_shape[1]"
,
x_shape
[
1
]
+
h_shape
[
1
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"b_shape[0]"
,
b_shape
[
0
],
"4*h_shape[1]"
,
4
*
h_shape
[
1
],
Rel
.
EQ
,
self
.
name
)
ct_shape
=
c_shape
ht_shape
=
h_shape
it_shape
=
h_shape
jt_shape
=
h_shape
ft_shape
=
h_shape
ot_shape
=
h_shape
tanhct_shape
=
h_shape
return
(
ct_shape
,
ht_shape
,
it_shape
,
jt_shape
,
ft_shape
,
ot_shape
,
tanhct_shape
)
def
infer_dtype
(
self
,
x_dtype
,
h_dtype
,
c_dtype
,
w_dtype
,
b_dtype
):
validator
.
check_subclass
(
"x"
,
x_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"h"
,
h_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"c"
,
c_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"w"
,
w_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_subclass
(
"b"
,
b_dtype
,
[
mstype
.
tensor
],
self
.
name
)
validator
.
check_type_name
(
"x"
,
x_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"h"
,
h_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"c"
,
c_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"w"
,
w_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_type_name
(
"b"
,
b_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
return
(
x_dtype
,
x_dtype
,
x_dtype
,
x_dtype
,
x_dtype
,
x_dtype
,
x_dtype
)
tests/ut/python/ops/test_ops.py
浏览文件 @
d7de0442
...
...
@@ -891,6 +891,11 @@ test_case_nn_ops = [
'desc_inputs'
:
[[
128
,
64
,
32
,
32
],
[
128
,
64
,
32
,
32
],
[
64
],
[
64
],
[
64
]],
'desc_bprop'
:
[[
128
,
64
,
32
,
32
],
[
64
],
[
64
],
[
64
],
[
64
]],
'skip'
:
[
'backward'
]}),
(
'BasicLSTMCell'
,
{
'block'
:
P
.
BasicLSTMCell
(
keep_prob
=
1.0
,
forget_bias
=
1.0
,
state_is_tuple
=
True
,
activation
=
'tanh'
),
'desc_inputs'
:
[[
128
,
128
],
[
128
,
128
],
[
128
,
128
],
[
512
,
256
,
1
,
1
],[
512
,
1
,
1
,
1
]],
'desc_bprop'
:
[[
128
,
128
],
[
128
,
128
],
[
128
,
128
],
[
128
,
128
],
[
128
,
128
],
[
128
,
128
],
[
128
,
128
]],
'skip'
:
[]}),
(
'TopK'
,
{
'block'
:
P
.
TopK
(),
'desc_const'
:
[
5
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录