diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py
index 0c101bf1a88f5db781706c6bf52f040a43683197..d8ca5a9845a8648f0089396a757f9507d485e763 100644
--- a/mindspore/_checkparam.py
+++ b/mindspore/_checkparam.py
@@ -299,6 +299,9 @@ class Validator:
         def get_typename(t):
             return t.__name__ if hasattr(t, '__name__') else str(t)
 
+        if isinstance(arg_type, type(mstype.tensor)):
+            arg_type = arg_type.element_type()
+
         if arg_type in valid_types:
             return arg_type
         type_names = [get_typename(t) for t in valid_types]
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 9f1ccdf5a9d76aaed4efda2f676073176e1fe730..260f3c509f141be2f29dcb4dbca6468c75cb3b4a 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -697,3 +697,25 @@ def get_bprop_ctc_loss(self):
         return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)
 
     return bprop
+
+
+@bprop_getters.register(P.BasicLSTMCell)
+def get_bprop_basic_lstm_cell(self):
+    """Grad definition for `BasicLSTMCell` operation."""
+    basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
+        forget_bias=self.forget_bias,
+        activation=self.activation
+    )
+
+    basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
+
+    basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)
+
+    def bprop(x, h, c, w, b, out, dout):
+        _, _, it, jt, ft, ot, tanhct = out
+        dct, dht, _, _, _, _, _ = dout
+        dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
+        dxt, dht = basic_lstm_cell_input_grad(dgate, w)
+        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
+        return dxt, dht, dct_1, dw, db
+    return bprop
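Note on the backward chain above: BasicLSTMCellCStateGrad turns (dht, dct) and the cached gate activations into gate pre-activation gradients plus the gradient flowing to c_{t-1}; BasicLSTMCellInputGrad backs those gate gradients through the packed weight; BasicLSTMCellWeightGrad accumulates dw and db. The F.depend(x, dxt) only enforces that the weight-grad kernel runs after the input-grad kernel; it does not change any value. A minimal NumPy sketch of the underlying math is below; the (i, j, f, o) packing order inside dgate and the flattened 2-D view of w are assumptions for illustration, not taken from the TBE kernels.

    import numpy as np

    def basic_lstm_cell_backward(x, h, c_prev, w, it, jt, ft, ot, tanhct, dht, dct):
        """Reference backward pass for one LSTM cell step (illustrative only).

        Shapes: x (N, X), h/c_prev and every gate (N, H), w (4*H, X + H),
        dht/dct (N, H).  Returns dxt, dht_prev, dct_1, dw, db.
        """
        # BasicLSTMCellCStateGrad: gradients w.r.t. the gate pre-activations and c_{t-1}.
        dot = dht * tanhct                              # grad w.r.t. o_t
        dc = dct + dht * ot * (1.0 - tanhct ** 2)       # total grad w.r.t. c_t
        dit = dc * jt                                   # grad w.r.t. i_t
        djt = dc * it                                   # grad w.r.t. candidate j_t
        dft = dc * c_prev                               # grad w.r.t. f_t
        dct_1 = dc * ft                                 # grad flowing to c_{t-1}
        # Back through the gate nonlinearities (sigmoid for i/f/o, tanh for j).
        dgate = np.concatenate(
            [dit * it * (1 - it), djt * (1 - jt ** 2), dft * ft * (1 - ft), dot * ot * (1 - ot)],
            axis=1)                                     # (N, 4*H), assumed (i, j, f, o) order

        # BasicLSTMCellInputGrad: back through the matmul against the packed weight.
        dxh = dgate @ w                                 # (N, X + H)
        dxt, dht_prev = dxh[:, :x.shape[1]], dxh[:, x.shape[1]:]

        # BasicLSTMCellWeightGrad: gradients of the packed weight and bias.
        xh = np.concatenate([x, h], axis=1)             # (N, X + H)
        dw = dgate.T @ xh                               # (4*H, X + H)
        db = dgate.sum(axis=0)                          # (4*H,)
        return dxt, dht_prev, dct_1, dw, db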
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 7a404cde987f6d04897bf41fdb1ac9e37fa8186e..339b3a65156dd9ca33a65f85320deebbcd8cd048 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -230,3 +230,7 @@ from .atan_grad import _atan_grad_tbe
 from .atanh import _atanh_tbe
 from .cosh import _cosh_tbe
 from .sinh import _sinh_tbe
+from .basic_lstm_cell import _basic_lstm_cell_tbe
+from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
+from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
+from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
new file mode 100644
index 0000000000000000000000000000000000000000..76ad1e460765d7d2aa8907ebfc9a293f4f045669
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BasicLSTMCell op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("basic_lstm_cell.so") \
+    .compute_cost(10) \
+    .kernel_name("basic_lstm_cell") \
+    .attr("keep_prob", "optional", "float", "all") \
+    .attr("forget_bias", "optional", "float", "all") \
+    .attr("state_is_tuple", "optional", "bool", "true") \
+    .attr("activation", "optional", "str", "all") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "h", False, "required", "all") \
+    .input(2, "c", False, "required", "all") \
+    .input(3, "w", False, "required", "all") \
+    .input(4, "b", False, "required", "all") \
+    .input(5, "mask", False, "optional", "all") \
+    .output(0, "ct", False, "required", "all") \
+    .output(1, "ht", False, "required", "all") \
+    .output(2, "it", False, "optional", "all") \
+    .output(3, "jt", False, "optional", "all") \
+    .output(4, "ft", False, "optional", "all") \
+    .output(5, "ot", False, "optional", "all") \
+    .output(6, "tanhct", False, "optional", "all") \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ,
+                  DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
+                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
+                  DataType.F32_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
+                  DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ) \
+    .get_op_info()
+
+
+@op_info_register(basic_lstm_cell_op_info)
+def _basic_lstm_cell_tbe():
+    """BasicLSTMCell TBE register"""
+    return
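The dtype_format rows above are positional: each row lists one supported dtype/format combination for the six inputs followed by the seven outputs, in the order they are declared. A throwaway snippet pairing the first row with the tensor names (names taken from the registration above; the snippet is only a reading aid, not part of the patch):

    # Map the first dtype_format row onto the declared input/output names.
    names = ["x", "h", "c", "w", "b", "mask",
             "ct", "ht", "it", "jt", "ft", "ot", "tanhct"]
    row = ["F16_FracNZ", "F16_FracNZ", "F32_FracNZ", "F16_FracZ", "F32_Default", "U8_Default",
           "F32_FracNZ", "F16_FracNZ", "F32_FracNZ", "F32_FracNZ", "F32_FracNZ", "F32_FracNZ",
           "F32_FracNZ"]
    for name, fmt in zip(names, row):
        print(f"{name}: {fmt}")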
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..099756ad35da0a62e0319074295d25112638261c
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py
@@ -0,0 +1,50 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BasicLSTMCellCStateGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("basic_lstm_cell_c_state_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("basic_lstm_cell_c_state_grad") \
+    .attr("forget_bias", "optional", "float", "all") \
+    .attr("activation", "optional", "str", "all") \
+    .partial_flag(True) \
+    .input(0, "c", False, "required", "all") \
+    .input(1, "dht", False, "required", "all") \
+    .input(2, "dct", False, "required", "all") \
+    .input(3, "it", False, "required", "all") \
+    .input(4, "ft", False, "required", "all") \
+    .input(5, "jt", False, "required", "all") \
+    .input(6, "ot", False, "required", "all") \
+    .input(7, "tanhct", False, "required", "all") \
+    .output(0, "dgate", False, "required", "all") \
+    .output(1, "dct_1", False, "required", "all") \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
+                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F32_FracNZ, DataType.F16_FracNZ) \
+    .get_op_info()
+
+
+@op_info_register(basic_lstm_cell_c_state_grad_op_info)
+def _basic_lstm_cell_c_state_grad_tbe():
+    """BasicLSTMCellCStateGrad TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..d976d1143b779f230be5445fdbf453696842cb31
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BasicLSTMCellInputGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("basic_lstm_cell_input_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("basic_lstm_cell_input_grad") \
+    .attr("keep_prob", "optional", "float", "all") \
+    .partial_flag(True) \
+    .input(0, "dgate", False, "required", "all") \
+    .input(1, "w", False, "required", "all") \
+    .input(2, "dropout_mask", False, "optional", "all") \
+    .output(0, "dxt", False, "required", "all") \
+    .output(1, "dht", False, "required", "all") \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F32_FracNZ,
+                  DataType.F32_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ) \
+    .get_op_info()
+
+
+@op_info_register(basic_lstm_cell_input_grad_op_info)
+def _basic_lstm_cell_input_grad_tbe():
+    """BasicLSTMCellInputGrad TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..83726bc51052b94f57172229a86ecdf0b51f0bfc
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BasicLSTMCellWeightGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("basic_lstm_cell_weight_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("basic_lstm_cell_weight_grad") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "h", False, "required", "all") \
+    .input(2, "dgate", False, "required", "all") \
+    .output(0, "dw", False, "required", "all") \
+    .output(1, "db", False, "required", "all") \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
+                  DataType.F32_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
+                  DataType.F16_Default) \
+    .get_op_info()
+
+
+@op_info_register(basic_lstm_cell_weight_grad_op_info)
+def _basic_lstm_cell_weight_grad_tbe():
+    """BasicLSTMCellWeightGrad TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 77fbff629daaa0e763d174cd32c849b38366c7ec..4e1639894b2f6383cf9b5aec9f08fa4245185a86 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
                      TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                      ApplyProximalAdagrad, SparseApplyProximalAdagrad,
-                     ApplyRMSProp, ApplyCenteredRMSProp)
+                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
 from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
 from . import _quant_ops
 from ._quant_ops import *
@@ -287,7 +287,8 @@ __all__ = [
     "BesselI0e",
     "BesselI1e",
     "Atan",
-    "Atanh"
+    "Atanh",
+    "BasicLSTMCell"
 ]
 
 __all__.extend(_quant_ops.__all__)
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 6a2bf43e838ac4c47fb7033d87bc782fc0a5577e..008f5f0edb999afd75b776c237077bb818e82902 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1173,3 +1173,106 @@ class AtanGrad(PrimitiveWithInfer):
         args = {"x": x, "dout": dout}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
         return x
+
+
+class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
+    """Computes the state gradients of BasicLSTMCell."""
+
+    @prim_attr_register
+    def __init__(self, forget_bias, activation):
+        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
+        self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+
+    def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
+        # dht, dct and the cached gates should all have the same shape as c
+        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
+        validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
+        validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
+        validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
+        validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
+        validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
+        validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
+        validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)
+
+        dgate_shape = (c_shape[0], 4 * c_shape[1])
+        dct_1_shape = c_shape
+
+        return (dgate_shape, dct_1_shape)
+
+    def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
+        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("it", it_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name)
+        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
+        return (c_dtype, c_dtype)
+
+
+class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
+    """Computes the weight gradients of BasicLSTM."""
+
+    @prim_attr_register
+    def __init__(self):
+        pass
+
+    def infer_shape(self, x_shape, h_shape, dgate_shape):
+        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
+        validator.check("h rank", len(h_shape), "x rank", len(x_shape), Rel.EQ, self.name)
+        validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
+        validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
+        validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
+        dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
+        db_shape = (dgate_shape[1], 1, 1, 1)
+        return (dw_shape, db_shape)
+
+    def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
+        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
+        validator.check_subclass("h", h_dtype, mstype.tensor, self.name)
+        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
+        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
+        return (x_dtype, x_dtype)
+
+
+class BasicLSTMCellInputGrad(PrimitiveWithInfer):
+    """Computes the input gradients of BasicLSTM."""
+
+    @prim_attr_register
+    def __init__(self, keep_prob):
+        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
+        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
+
+    def infer_shape(self, dgate_shape, w_shape):
+        validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
+        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
+        dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
+        dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
+        return (dxt_shape, dht_shape)
+
+    def infer_dtype(self, dgate_dtype, w_dtype):
+        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
+        validator.check_subclass("w", w_dtype, mstype.tensor, self.name)
+        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
+        return (dgate_dtype, dgate_dtype)
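Following the infer_shape rules above with the sizes used by the unit test at the end of this patch (batch 128, input_size 128, hidden_size 128) gives the expected gradient shapes. The snippet below is plain arithmetic for sanity checking, not part of the patch:

    # Shape bookkeeping for the grad primitives, using the sizes from the test case.
    batch, input_size, hidden_size = 128, 128, 128

    c_shape = (batch, hidden_size)
    dgate_shape = (c_shape[0], 4 * c_shape[1])                     # BasicLSTMCellCStateGrad -> (128, 512)
    w_shape = (4 * hidden_size, input_size + hidden_size, 1, 1)    # (512, 256, 1, 1)

    # BasicLSTMCellInputGrad
    dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)     # (128, 128)
    dht_shape = (dgate_shape[0], dgate_shape[1] // 4)              # (128, 128)

    # BasicLSTMCellWeightGrad
    dw_shape = (dgate_shape[1], input_size + hidden_size, 1, 1)    # (512, 256, 1, 1)
    db_shape = (dgate_shape[1], 1, 1, 1)                           # (512, 1, 1, 1)

    assert dxt_shape == (128, 128) and dht_shape == (128, 128)
    assert dw_shape == (512, 256, 1, 1) and db_shape == (512, 1, 1, 1)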
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index ed7237b04c2548f85aa576cf7069ef14465f4d35..5426afb53c3cd1363f71b8be5296f06ead4b08a2 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -3363,3 +3363,109 @@ class CTCLoss(PrimitiveWithInfer):
         validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
         validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
         return inputs, inputs
+
+
+class BasicLSTMCell(PrimitiveWithInfer):
+    r"""
+    Performs one step of the long short-term memory (LSTM) cell computation on the input.
+
+    .. math::
+        \begin{array}{ll} \\
+            i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
+            f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
+            \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
+            o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
+            c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
+            h_t = o_t * \tanh(c_t) \\
+        \end{array}
+
+    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
+    are learnable weights between the output and the input in the formula. For instance,
+    :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
+    Details can be found in the papers `LONG SHORT-TERM MEMORY
+    `_ and
+    `Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
+    `_.
+
+    Args:
+        keep_prob (float): Keep probability for dropout; if 1.0, no dropout is applied.
+            The value must be in the range [0.0, 1.0]. Default: 1.0.
+        forget_bias (float): Bias added to the forget gate in order to reduce the scale of forgetting
+            at the beginning of training. Default: 1.0.
+        state_is_tuple (bool): If True, the state is a tuple of two tensors, containing h and c;
+            if False, the state is a single tensor and needs to be split first. Default: True.
+        activation (str): Activation function of the cell. Only 'tanh' is currently supported. Default: "tanh".
+
+    Inputs:
+        - **x** (Tensor) - Current input. Tensor of shape (`batch_size`, `input_size`).
+        - **h** (Tensor) - Hidden state of the previous step. Tensor of shape (`batch_size`, `hidden_size`).
+        - **c** (Tensor) - Cell state of the previous step. Tensor of shape (`batch_size`, `hidden_size`).
+        - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
+        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
+
+    Outputs:
+        - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
+        - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+        - **jt** (Tensor) - Forward :math:`\tilde{c}_t` (candidate) cache at moment `t`.
+          Tensor of shape (`batch_size`, `hidden_size`).
+        - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+        - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+        - **tanhct** (Tensor) - Forward :math:`\tanh(c_t)` cache at moment `t`.
+          Tensor of shape (`batch_size`, `hidden_size`).
+
+    Examples:
+        >>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
+        >>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
+        >>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
+        >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
+        >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
+        >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
+        >>> lstm(x, h, c, w, b)
+    """
+
+    @prim_attr_register
+    def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
+        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
+        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
+        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
+        self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
+        self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+
+    def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
+        # x: (batch_size, input_size)
+        validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
+
+        # h and c should be same shape
+        validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
+        validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
+        validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
+        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
+        validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
+        validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
+        validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
+        ct_shape = c_shape
+        ht_shape = h_shape
+        it_shape = h_shape
+        jt_shape = h_shape
+        ft_shape = h_shape
+        ot_shape = h_shape
+        tanhct_shape = h_shape
+
+        return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
+
+    def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
+        validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
+        validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
+        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
+        validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
+        return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
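For context, a hedged sketch of how the primitive could be wrapped in a Cell that owns the packed parameters. The class name SingleLSTMStep, the random initialization, and the assumption that the kernel's gate packing matches this layout are illustrative only and not part of this patch:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Parameter, Tensor
    from mindspore.ops import operations as P

    class SingleLSTMStep(nn.Cell):
        """Hypothetical wrapper that owns the packed LSTM weight and bias."""
        def __init__(self, input_size, hidden_size):
            super(SingleLSTMStep, self).__init__()
            self.lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0,
                                        state_is_tuple=True, activation='tanh')
            # Packed weight (4*hidden, input+hidden, 1, 1) and bias (4*hidden, 1, 1, 1),
            # matching the shapes checked by infer_shape. Random init is a placeholder.
            self.w = Parameter(Tensor(np.random.rand(
                4 * hidden_size, input_size + hidden_size, 1, 1).astype(np.float16)), name='w')
            self.b = Parameter(Tensor(np.random.rand(
                4 * hidden_size, 1, 1, 1).astype(np.float16)), name='b')

        def construct(self, x, h, c):
            out = self.lstm(x, h, c, self.w, self.b)
            # out is (ct, ht, it, jt, ft, ot, tanhct); return the new state.
            return out[1], out[0]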
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 4a0abcea170924e8e3c29bf5254abba1fa2af8ce..91315f692efff7fd282d7603e59ab1e95d289ef1 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -891,6 +891,11 @@ test_case_nn_ops = [
         'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
+    ('BasicLSTMCell', {
+        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
+        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]],
+        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
+        'skip': []}),
     ('TopK', {
         'block': P.TopK(),
         'desc_const': [5],