Commit d7de0442 authored by mindspore-ci-bot, committed by Gitee

!1780 add op BasicLSTMCell

Merge pull request !1780 from zhaozhenlong/op/lstm-open
...@@ -299,6 +299,9 @@ class Validator:
        def get_typename(t):
            return t.__name__ if hasattr(t, '__name__') else str(t)

        if isinstance(arg_type, type(mstype.tensor)):
            arg_type = arg_type.element_type()
        if arg_type in valid_types:
            return arg_type
        type_names = [get_typename(t) for t in valid_types]
...
...@@ -697,3 +697,25 @@ def get_bprop_ctc_loss(self):
        return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)

    return bprop
@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Grad definition for `BasicLSTMCell` operation."""
    basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
        forget_bias=self.forget_bias,
        activation=self.activation
    )
    basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
    basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
        dxt, dht = basic_lstm_cell_input_grad(dgate, w)
        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
        return dxt, dht, dct_1, dw, db

    return bprop
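
Aside: the cell-state backward that `BasicLSTMCellCStateGrad` computes can be written out from the standard LSTM equations (see the `BasicLSTMCell` docstring further down). The NumPy sketch below is only an illustration under that assumption, with an i/j/f/o gate ordering; it is not the Ascend TBE kernel.

import numpy as np

def lstm_cstate_grad_ref(c, dht, dct, it, jt, ft, ot, tanhct):
    """Illustrative cell-state backward (gate order i, j, f, o assumed).

    `c` is the previous cell state c_{t-1}; every array has shape (batch, hidden).
    """
    dc = dct + dht * ot * (1.0 - tanhct ** 2)              # gradient reaching c_t via h_t and dct
    dit = dc * jt * it * (1.0 - it)                        # sigmoid gate: d(pre) = d(y) * y * (1 - y)
    djt = dc * it * (1.0 - jt ** 2)                        # tanh candidate: d(pre) = d(y) * (1 - y^2)
    dft = dc * c * ft * (1.0 - ft)
    dot = dht * tanhct * ot * (1.0 - ot)
    dgate = np.concatenate([dit, djt, dft, dot], axis=1)   # (batch, 4 * hidden), matching infer_shape
    dct_1 = dc * ft                                        # gradient flowing back to c_{t-1}
    return dgate, dct_1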
...@@ -230,3 +230,7 @@ from .atan_grad import _atan_grad_tbe
from .atanh import _atanh_tbe
from .cosh import _cosh_tbe
from .sinh import _sinh_tbe
from .basic_lstm_cell import _basic_lstm_cell_tbe
from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCell op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell") \
.attr("keep_prob", "optional", "float", "all") \
.attr("forget_bias", "optional", "float", "all") \
.attr("state_is_tuple", "optional", "bool", "true") \
.attr("activation", "optional", "str", "all") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "h", False, "required", "all") \
.input(2, "c", False, "required", "all") \
.input(3, "w", False, "required", "all") \
.input(4, "b", False, "required", "all") \
.input(5, "mask", False, "optional", "all") \
.output(0, "ct", False, "required", "all") \
.output(1, "ht", False, "required", "all") \
.output(2, "it", False, "optional", "all") \
.output(3, "jt", False, "optional", "all") \
.output(4, "ft", False, "optional", "all") \
.output(5, "ot", False, "optional", "all") \
.output(6, "tanhct", False, "optional", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ,
DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(basic_lstm_cell_op_info)
def _basic_lstm_cell_tbe():
    """BasicLSTMCell TBE register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellCStateGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell_c_state_grad.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell_c_state_grad") \
.attr("forget_bias", "optional", "float", "all") \
.attr("activation", "optional", "str", "all") \
.partial_flag(True) \
.input(0, "c", False, "required", "all") \
.input(1, "dht", False, "required", "all") \
.input(2, "dct", False, "required", "all") \
.input(3, "it", False, "required", "all") \
.input(4, "ft", False, "required", "all") \
.input(5, "jt", False, "required", "all") \
.input(6, "ot", False, "required", "all") \
.input(7, "tanhct", False, "required", "all") \
.output(0, "dgate", False, "required", "all") \
.output(1, "dct_1", False, "required", "all") \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(basic_lstm_cell_c_state_grad_op_info)
def _basic_lstm_cell_c_state_grad_tbe():
    """BasicLSTMCellCStateGrad TBE register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellInputGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell_input_grad.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell_input_grad") \
.attr("keep_prob", "optional", "float", "all") \
.partial_flag(True) \
.input(0, "dgate", False, "required", "all") \
.input(1, "w", False, "required", "all") \
.input(2, "dropout_mask", False, "optional", "all") \
.output(0, "dxt", False, "required", "all") \
.output(1, "dht", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(basic_lstm_cell_input_grad_op_info)
def _basic_lstm_cell_input_grad_tbe():
    """BasicLSTMCellInputGrad TBE register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellWeightGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell_weight_grad.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell_weight_grad") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "h", False, "required", "all") \
.input(2, "dgate", False, "required", "all") \
.output(0, "dw", False, "required", "all") \
.output(1, "db", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
DataType.F32_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
DataType.F16_Default) \
.get_op_info()
@op_info_register(basic_lstm_cell_weight_grad_op_info)
def _basic_lstm_cell_weight_grad_tbe():
    """BasicLSTMCellWeightGrad TBE register"""
    return
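
All four registrations above follow the same TBERegOp builder pattern: attributes, then positional inputs and outputs, then one dtype_format row per supported combination, whose entries map onto the declared inputs followed by the declared outputs in order. The schematic below restates that pattern for a hypothetical single-input, single-output op; the op name, binary file, and kernel name are made up purely for illustration.

from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Hypothetical op: each dtype_format row gives the dtype/layout of input "x", then of output "y".
example_op_info = TBERegOp("ExampleOp") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("example_op.so") \
    .compute_cost(10) \
    .kernel_name("example_op") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register(example_op_info)
def _example_op_tbe():
    """ExampleOp TBE register (illustrative only)."""
    return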
...@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                     ApplyProximalAdagrad, SparseApplyProximalAdagrad,
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
from . import _quant_ops
from ._quant_ops import *
...@@ -287,7 +287,8 @@ __all__ = [
    "BesselI0e",
    "BesselI1e",
    "Atan",
    "Atanh",
    "BasicLSTMCell"
]

__all__.extend(_quant_ops.__all__)
...
...@@ -1173,3 +1173,106 @@ class AtanGrad(PrimitiveWithInfer):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x

class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
    """Computes the cell-state gradients of BasicLSTMCell."""

    @prim_attr_register
    def __init__(self, forget_bias, activation):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.activation = validator.check_string("activation", activation, ['tanh'], self.name)

    def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
        # dht and dct must have the same shape as c
        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
        validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)

        dgate_shape = (c_shape[0], 4 * c_shape[1])
        dct_1_shape = c_shape
        return (dgate_shape, dct_1_shape)

    def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
        validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name)
        validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name)
        validator.check_subclass("it", it_dtype, [mstype.tensor], self.name)
        validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name)
        validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name)
        validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name)
        validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name)
        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
        return (c_dtype, c_dtype)

class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
    """Computes the weight gradients of BasicLSTMCell."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, x_shape, h_shape, dgate_shape):
        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
        validator.check("h rank", len(h_shape), "x rank", len(x_shape), Rel.EQ, self.name)
        validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
        validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
        validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
        validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
        dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
        db_shape = (dgate_shape[1], 1, 1, 1)
        return (dw_shape, db_shape)

    def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
        validator.check_subclass("h", h_dtype, mstype.tensor, self.name)
        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
        return (x_dtype, x_dtype)

class BasicLSTMCellInputGrad(PrimitiveWithInfer):
    """Computes the input gradients of BasicLSTMCell."""

    @prim_attr_register
    def __init__(self, keep_prob):
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)

    def infer_shape(self, dgate_shape, w_shape):
        validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
        dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
        dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
        return (dxt_shape, dht_shape)

    def infer_dtype(self, dgate_dtype, w_dtype):
        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
        validator.check_subclass("w", w_dtype, mstype.tensor, self.name)
        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
        return (dgate_dtype, dgate_dtype)
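
The shape relations encoded by these three infer_shape methods can be traced with the dimensions used by the test case at the end of this change (batch 128, input_size 128, hidden_size 128). The following plain-Python walk-through simply evaluates the formulas above:

batch, input_size, hidden = 128, 128, 128

# BasicLSTMCellCStateGrad: (batch, hidden) states in, dgate and dct_1 out.
dgate_shape = (batch, 4 * hidden)                                # (128, 512)
dct_1_shape = (batch, hidden)                                    # (128, 128)

# BasicLSTMCellWeightGrad: dgate_shape[1] must equal 4 * h_shape[1].
dw_shape = (4 * hidden, input_size + hidden, 1, 1)               # (512, 256, 1, 1)
db_shape = (4 * hidden, 1, 1, 1)                                 # (512, 1, 1, 1)

# BasicLSTMCellInputGrad: w has shape (4*hidden, input_size + hidden, 1, 1).
dxt_shape = (batch, (input_size + hidden) - (4 * hidden) // 4)   # (128, 128)
dht_shape = (batch, (4 * hidden) // 4)                           # (128, 128)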
...@@ -3363,3 +3363,109 @@ class CTCLoss(PrimitiveWithInfer):
        validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
        validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
        return inputs, inputs

class BasicLSTMCell(PrimitiveWithInfer):
    r"""
    Performs the long short-term memory (LSTM) cell computation on the input.

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
            f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
            \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
            o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
            c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
            h_t = o_t * \tanh(c_t) \\
        \end{array}

    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
    are learnable weights between the output and the input in the formula. For instance,
    :math:`W_{ix}, b_{ix}` are the weight and bias used to transform the input :math:`x` into :math:`i`.
    Details can be found in the papers `LONG SHORT-TERM MEMORY
    <https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
    `Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
    <https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.

    Args:
        keep_prob (float): If not 1.0, a `Dropout` layer is applied with this keep probability.
            The valid range is [0.0, 1.0]. Default: 1.0.
        forget_bias (float): Bias added to the forget gate in order to reduce the scale of forgetting
            at the beginning of training. Default: 1.0.
        state_is_tuple (bool): If True, the state is a tuple of two tensors, h and c; if False, it is
            a single tensor that must be split first. Default: True.
        activation (str): Activation function. Default: 'tanh'.

    Inputs:
        - **x** (Tensor) - Current input. Tensor of shape (`batch_size`, `input_size`).
        - **h** (Tensor) - Hidden state at the previous step. Tensor of shape (`batch_size`, `hidden_size`).
        - **c** (Tensor) - Cell state at the previous step. Tensor of shape (`batch_size`, `hidden_size`).
        - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).

    Outputs:
        - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
        - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
        - **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
        - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
        - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `4 x hidden_size`).
        - **tanhct** (Tensor) - Forward :math:`\tanh(c_t)` cache at moment `t`.
          Tensor of shape (`batch_size`, `4 x hidden_size`).

    Examples:
        >>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
        >>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
        >>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
        >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
        >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
        >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
        >>> lstm(x, h, c, w, b)
    """

    @prim_attr_register
    def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
        self.activation = validator.check_string("activation", activation, ['tanh'], self.name)

    def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
        # x: (batch_size, input_size)
        validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
        # h and c should be same shape
        validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
        validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)

        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
        validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
        validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
        validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)

        ct_shape = c_shape
        ht_shape = h_shape
        it_shape = h_shape
        jt_shape = h_shape
        ft_shape = h_shape
        ot_shape = h_shape
        tanhct_shape = h_shape
        return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)

    def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
        validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
        validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
        validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
        validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
        return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
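
For reference, the documented equations and shape contract can be exercised with a small NumPy model. This is an illustration only: it assumes an i/j/f/o gate ordering inside w and b and that forget_bias is added to the forget-gate pre-activation; it is not the Ascend kernel, and its gate caches have shape (batch, hidden), matching infer_shape above.

import numpy as np

def basic_lstm_cell_ref(x, h, c, w, b, forget_bias=1.0):
    """NumPy illustration of the documented BasicLSTMCell math (gate order assumed)."""
    hidden = h.shape[1]
    w2d = w.reshape(4 * hidden, -1)                     # (4*hidden, input_size + hidden)
    gates = np.concatenate([x, h], axis=1) @ w2d.T + b.reshape(-1)
    i, j, f, o = np.split(gates, 4, axis=1)
    it = 1.0 / (1.0 + np.exp(-i))
    jt = np.tanh(j)
    ft = 1.0 / (1.0 + np.exp(-(f + forget_bias)))       # forget_bias shifts the forget gate
    ot = 1.0 / (1.0 + np.exp(-o))
    ct = ft * c + it * jt
    tanhct = np.tanh(ct)
    ht = ot * tanhct
    return ct, ht, it, jt, ft, ot, tanhct

# Shapes from the docstring example: batch 128, input_size 128, hidden_size 128.
x = np.random.rand(128, 128).astype(np.float32)
h = np.random.rand(128, 128).astype(np.float32)
c = np.random.rand(128, 128).astype(np.float32)
w = np.random.rand(512, 256, 1, 1).astype(np.float32)
b = np.random.rand(512, 1, 1, 1).astype(np.float32)
ct, ht, *_ = basic_lstm_cell_ref(x, h, c, w, b)          # ct, ht: (128, 128)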
...@@ -891,6 +891,11 @@ test_case_nn_ops = [
        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'skip': ['backward']}),
    ('BasicLSTMCell', {
        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]],
        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
        'skip': []}),
    ('TopK', {
        'block': P.TopK(),
        'desc_const': [5],
...