Commit 270f79c8 authored by zhaozhenlong

add op BasicLSTMCell vm

Parent 0b94376b
......@@ -299,6 +299,9 @@ class Validator:
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
......
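For context, the validator change above unwraps a tensor type to its element type before the membership test, which appears to be what lets the new primitives below pass tensor dtypes straight to `check_type_name`. A minimal sketch of the intent (not part of the diff), assuming this MindSpore version exposes the `mstype.tensor_type` constructor:

```python
from mindspore.common import dtype as mstype

# A tensor dtype such as Tensor[float16] is reduced to its element type
# (float16) before it is compared against the list of valid types.
arg_type = mstype.tensor_type(mstype.float16)   # assumed constructor for a tensor dtype
if isinstance(arg_type, type(mstype.tensor)):
    arg_type = arg_type.element_type()
assert arg_type == mstype.float16
```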
......@@ -697,3 +697,25 @@ def get_bprop_ctc_loss(self):
return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)
return bprop
@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
"""Grad definition for `BasicLSTMCell` operation."""
basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
forget_bias=self.forget_bias,
activation=self.activation
)
basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)
def bprop(x, h, c, w, b, out, dout):
_, _, it, jt, ft, ot, tanhct = out
dct, dht, _, _, _, _, _ = dout
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
dxt, dht = basic_lstm_cell_input_grad(dgate, w)
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
return dxt, dht, dct_1, dw, db
return bprop
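A minimal sketch (not part of the diff) of how the new bprop is exercised end to end: wrapping the primitive in a Cell and requesting gradients with respect to all inputs yields exactly the `(dxt, dht, dct_1, dw, db)` tuple returned above. The `Net`/`GradWrap` names are illustrative, the `GradOperation` signature varies across MindSpore versions, and running this requires an Ascend backend.

```python
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.ops.composite as C
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    """Forward network: a single BasicLSTMCell step."""
    def __init__(self):
        super(Net, self).__init__()
        self.lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0,
                                    state_is_tuple=True, activation='tanh')

    def construct(self, x, h, c, w, b):
        return self.lstm(x, h, c, w, b)

class GradWrap(nn.Cell):
    """Returns gradients of the wrapped network w.r.t. all of its inputs."""
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.grad = C.GradOperation('grad_all', get_all=True)

    def construct(self, x, h, c, w, b):
        return self.grad(self.network)(x, h, c, w, b)

x = Tensor(np.random.rand(128, 128).astype(np.float16))
h = Tensor(np.random.rand(128, 128).astype(np.float16))
c = Tensor(np.random.rand(128, 128).astype(np.float16))
w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))

# grads == (dxt, dht, dct_1, dw, db), produced by the registered bprop above.
grads = GradWrap(Net())(x, h, c, w, b)
```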
......@@ -230,3 +230,7 @@ from .atan_grad import _atan_grad_tbe
from .atanh import _atanh_tbe
from .cosh import _cosh_tbe
from .sinh import _sinh_tbe
from .basic_lstm_cell import _basic_lstm_cell_tbe
from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCell op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell") \
.attr("keep_prob", "optional", "float", "all") \
.attr("forget_bias", "optional", "float", "all") \
.attr("state_is_tuple", "optional", "bool", "true") \
.attr("activation", "optional", "str", "all") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "h", False, "required", "all") \
.input(2, "c", False, "required", "all") \
.input(3, "w", False, "required", "all") \
.input(4, "b", False, "required", "all") \
.input(5, "mask", False, "optional", "all") \
.output(0, "ct", False, "required", "all") \
.output(1, "ht", False, "required", "all") \
.output(2, "it", False, "optional", "all") \
.output(3, "jt", False, "optional", "all") \
.output(4, "ft", False, "optional", "all") \
.output(5, "ot", False, "optional", "all") \
.output(6, "tanhct", False, "optional", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ,
DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(basic_lstm_cell_op_info)
def _basic_lstm_cell_tbe():
"""BasicLSTMCell TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellCStateGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell_c_state_grad.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell_c_state_grad") \
.attr("forget_bias", "optional", "float", "all") \
.attr("activation", "optional", "str", "all") \
.partial_flag(True) \
.input(0, "c", False, "required", "all") \
.input(1, "dht", False, "required", "all") \
.input(2, "dct", False, "required", "all") \
.input(3, "it", False, "required", "all") \
.input(4, "ft", False, "required", "all") \
.input(5, "jt", False, "required", "all") \
.input(6, "ot", False, "required", "all") \
.input(7, "tanhct", False, "required", "all") \
.output(0, "dgate", False, "required", "all") \
.output(1, "dct_1", False, "required", "all") \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(basic_lstm_cell_c_state_grad_op_info)
def _basic_lstm_cell_c_state_grad_tbe():
"""BasicLSTMCellCStateGrad TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellInputGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell_input_grad.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell_input_grad") \
.attr("keep_prob", "optional", "float", "all") \
.partial_flag(True) \
.input(0, "dgate", False, "required", "all") \
.input(1, "w", False, "required", "all") \
.input(2, "dropout_mask", False, "optional", "all") \
.output(0, "dxt", False, "required", "all") \
.output(1, "dht", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(basic_lstm_cell_input_grad_op_info)
def _basic_lstm_cell_input_grad_tbe():
"""BasicLSTMCellInputGrad TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellWeightGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("basic_lstm_cell_weight_grad.so") \
.compute_cost(10) \
.kernel_name("basic_lstm_cell_weight_grad") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "h", False, "required", "all") \
.input(2, "dgate", False, "required", "all") \
.output(0, "dw", False, "required", "all") \
.output(1, "db", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
DataType.F32_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
DataType.F16_Default) \
.get_op_info()
@op_info_register(basic_lstm_cell_weight_grad_op_info)
def _basic_lstm_cell_weight_grad_tbe():
"""BasicLSTMCellWeightGrad TBE register"""
return
......@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
ApplyProximalAdagrad, SparseApplyProximalAdagrad,
ApplyRMSProp, ApplyCenteredRMSProp)
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
from . import _quant_ops
from ._quant_ops import *
......@@ -287,7 +287,8 @@ __all__ = [
"BesselI0e",
"BesselI1e",
"Atan",
"Atanh"
"Atanh",
"BasicLSTMCell"
]
__all__.extend(_quant_ops.__all__)
......
......@@ -1173,3 +1173,106 @@ class AtanGrad(PrimitiveWithInfer):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x
class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
"""Computes the state gradients of BasicLSTMCell."""
@prim_attr_register
def __init__(self, forget_bias, activation):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
# dht and dct should have the same shape as c
validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)
dgate_shape = (c_shape[0], 4 * c_shape[1])
dct_1_shape = c_shape
return (dgate_shape, dct_1_shape)
def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name)
validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name)
validator.check_subclass("it", it_dtype, [mstype.tensor], self.name)
validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name)
validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name)
validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name)
validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name)
validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
return (c_dtype, c_dtype)
class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
"""Computes the weight gradients of BasicLSTM."""
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, x_shape, h_shape, dgate_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
db_shape = (dgate_shape[1], 1, 1, 1)
return (dw_shape, db_shape)
def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_subclass("h", h_dtype, mstype.tensor, self.name)
validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
return (x_dtype, x_dtype)
class BasicLSTMCellInputGrad(PrimitiveWithInfer):
"""Computes the input gradients of BasicLSTM."""
@prim_attr_register
def __init__(self, keep_prob):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
def infer_shape(self, dgate_shape, w_shape):
validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
return (dxt_shape, dht_shape)
def infer_dtype(self, dgate_dtype, w_dtype):
validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
validator.check_subclass("w", w_dtype, mstype.tensor, self.name)
validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
return (dgate_dtype, dgate_dtype)
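The non-obvious step in `infer_shape` above is recovering `input_size` from the fused weight: since `w` has shape `(4 * hidden_size, input_size + hidden_size, 1, 1)`, the expression `w_shape[1] - w_shape[0] // 4` equals `input_size`. A framework-free sketch with the sizes used elsewhere in this commit:

```python
# Shape bookkeeping only: how BasicLSTMCellInputGrad.infer_shape derives dxt and dht.
dgate_shape = (128, 512)        # (batch_size, 4 * hidden_size)
w_shape = (512, 256, 1, 1)      # (4 * hidden_size, input_size + hidden_size, 1, 1)

hidden_size = w_shape[0] // 4            # 128
input_size = w_shape[1] - hidden_size    # 256 - 128 = 128

dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)   # (128, 128)
dht_shape = (dgate_shape[0], dgate_shape[1] // 4)            # (128, 128)

assert dxt_shape == (128, input_size) and dht_shape == (128, hidden_size)
```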
......@@ -3363,3 +3363,109 @@ class CTCLoss(PrimitiveWithInfer):
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
return inputs, inputs
class BasicLSTMCell(PrimitiveWithInfer):
r"""
Applies the long short-term memory (LSTM) cell to the input.
.. math::
\begin{array}{ll} \\
i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
\tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
h_t = o_t * \tanh(c_t) \\
\end{array}
Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
are learnable weights between the output and the input in the formula. For instance,
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
Details can be found in paper `LONG SHORT-TERM MEMORY
<https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
`Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
<https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.
Args:
keep_prob (float): If not 1.0, append a `Dropout` layer on the outputs of each
    LSTM layer except the last layer. The valid range is [0.0, 1.0]. Default: 1.0.
forget_bias (float): Value added to the forget gate bias to reduce the scale of
    forgetting at the beginning of training. Default: 1.0.
state_is_tuple (bool): If True, the state is a tuple of two tensors containing `h` and `c`;
    if False, the state is a single tensor and needs to be split first. Default: True.
activation (str): Activation function. Only 'tanh' is currently supported. Default: 'tanh'.
Inputs:
- **x** (Tensor) - Current input. Tensor of shape (`batch_size`, `input_size`).
- **h** (Tensor) - Hidden state at the last moment. Tensor of shape (`batch_size`, `hidden_size`).
- **c** (Tensor) - Cell state at the last moment. Tensor of shape (`batch_size`, `hidden_size`).
- **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
- **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
Outputs:
- **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
- **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
- **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
- **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
- **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
- **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
- **tanhct** (Tensor) - Forward :math:`\tanh(c_t)` cache at moment `t`.
  Tensor of shape (`batch_size`, `hidden_size`).
Examples:
>>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
>>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
>>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
>>> lstm(x, h, c, w, b)
"""
@prim_attr_register
def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
# (batch_size, input_size)
validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
# h and c should be same shape
validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4*h_shape[1], Rel.EQ, self.name)
ct_shape = c_shape
ht_shape = h_shape
it_shape = h_shape
jt_shape = h_shape
ft_shape = h_shape
ot_shape = h_shape
tanhct_shape = h_shape
return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
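As a plain-Python sketch (not part of the diff) of the constraints that `BasicLSTMCell.infer_shape` enforces; the concrete sizes mirror the docstring example and the test case added below:

```python
# Shape constraints of BasicLSTMCell: batch 128, input 128, hidden 128.
batch_size, input_size, hidden_size = 128, 128, 128

x_shape = (batch_size, input_size)
h_shape = c_shape = (batch_size, hidden_size)
w_shape = (4 * hidden_size, input_size + hidden_size, 1, 1)   # (512, 256, 1, 1)
b_shape = (4 * hidden_size, 1, 1, 1)                          # (512, 1, 1, 1)

assert w_shape[0] == 4 * h_shape[1]
assert w_shape[1] == x_shape[1] + h_shape[1]
assert b_shape[0] == 4 * h_shape[1]

# All seven outputs (ct, ht, it, jt, ft, ot, tanhct) share the hidden-state shape.
output_shapes = tuple(h_shape for _ in range(7))
```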
......@@ -891,6 +891,11 @@ test_case_nn_ops = [
'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
'skip': ['backward']}),
('BasicLSTMCell', {
'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]],
'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
'skip': []}),
('TopK', {
'block': P.TopK(),
'desc_const': [5],
......