From fa9d3fa5bf8ccab35a499ace24f08a32cc9c03c0 Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Fri, 16 Oct 2020 17:10:41 +0800 Subject: [PATCH] Incorporate cudnn_lstm into LSTM api (#27217) * Incorporate cudnn_lstm into LSTM api. test=develop * Make coalesce_tensor support alignment optionally. test=develop * Reorganize RNN apis. test=develop * Fix cudnn rnn layout conversion. test=develop * Add sequence_length support for RNN cudnn implement. Add optional init_h and init_c gradient for cudnn_lstm_op. test=develop * Use create_parameter for rnn cudnn impl. test=develop * Move `self._flat_weight = self.create_parameter()` in RNNBase to main_program. test=develop * Update RNN api unittest to use set_device. test=develop * Fix set_place for unit tests of RNN apis. test=develop * Fix use_align in coalesce_tensor_op. test=develop * Adjust RNN apis arguments according to comments. test=develop * Polish documents for SimpleRNN apis. test=develop * Refine random seed in cudnn_lstm_op. Expose rnn params from sublayers to RNN. test=develop * Fix RNN saving for jit.save. Refine cudnn_lstm dropout behavior. test=develop * Fix doc of GRU. test=develop * Use ShareDataWith to avoid copying for cudnn_lstm_op test. test=develop * Remove updates on cudnn_lstm temporarily. test=develop * Use ShareDataWith to avoid copying for cudnn_lstm_op test. test=develop * Refine random seed in cudnn_lstm_op. test=develop * Fix test_lstm by adjust ConcreteProgram buffer getter. test=develop * Use create_parameter instead of create_var for rnn._flat_weight for static graph usage. test=develop * Remove W input for cudnn_lstm to pass unused_var_check. test=develop * Add test_predict for RNN unit tests coverage. test=develop * Fix code style of rnn. test=develop * Fix F.rnn usage in rnn.py. test=develop --- paddle/fluid/operators/coalesce_tensor_op.cc | 31 +- paddle/fluid/operators/cudnn_lstm_op.cu.cc | 42 +- paddle/fluid/operators/save_op.cc | 1 + paddle/fluid/operators/save_op.cu | 1 + .../dygraph_to_static/program_translator.py | 18 +- .../tests/unittests/rnn/test_rnn_nets.py | 82 +++- .../unittests/rnn/test_rnn_nets_static.py | 21 +- python/paddle/nn/layer/rnn.py | 441 ++++++++++++------ 8 files changed, 444 insertions(+), 193 deletions(-) diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index d67d90c348..a7c0f12711 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -67,6 +67,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel { } auto in_tensors = context.MultiInput("Input"); + bool use_align = context.Attr("use_align"); if (context.Attr("check_name")) { for (size_t i = 0; i < in_var_names.size(); ++i) { @@ -93,7 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel { context.Attr("dtype")); size_t size_of_dtype = framework::SizeOfType(dtype); GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype, - context.GetPlace()); + context.GetPlace(), use_align); // Alloc the continuous space auto fused_tensor = context.Output("FusedOutput"); @@ -111,8 +112,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel { framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx, &sub_tensor); - offset += platform::Alignment(len * size_of_dtype, context.GetPlace()) / - size_of_dtype; + offset += + use_align + ? 
platform::Alignment(len * size_of_dtype, context.GetPlace()) / + size_of_dtype + : len; } } else if (context.Attr("set_constant")) { math::SetConstant set_constant; @@ -131,8 +135,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel { ->ShareDataWith(fused_tensor->Slice( static_cast(offset), static_cast(offset + len))) .Resize(dim); - len = platform::Alignment(len * size_of_dtype, context.GetPlace()) / - size_of_dtype; + len = use_align + ? platform::Alignment(len * size_of_dtype, context.GetPlace()) / + size_of_dtype + : len; offset += len; ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")" << " address: " << out_tensors[i]->data() << ", "; @@ -144,7 +150,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel { void GetMemSizeAndDtype( const std::vector &lod_tensors, const std::vector var_names, size_t *numel, - const size_t &size_of_dtype, const platform::Place &place) const { + const size_t &size_of_dtype, const platform::Place &place, + const bool use_align = true) const { PADDLE_ENFORCE_EQ( lod_tensors.size(), var_names.size(), platform::errors::InvalidArgument( @@ -167,9 +174,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel { ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims() << ") " << " addres:" << lod_tensors[i]->data() << ", "; - *numel += platform::Alignment(static_cast(size) * size_of_dtype, - place) / - size_of_dtype; + *numel += use_align + ? platform::Alignment( + static_cast(size) * size_of_dtype, place) / + size_of_dtype + : static_cast(size); } VLOG(10) << ss.str(); @@ -223,6 +232,10 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker { "Whether to check the name of Input and Output to ensure " "they are the same separately.") .SetDefault(false); + AddAttr("use_align", + "Whether to consider memory chunk and take alignment into " + "account for inputs and outputs.") + .SetDefault(true); AddComment(R"DOC( CoalesceTensor Operator. diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index bea7d9c02c..e935a3c0aa 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/cudnn_lstm_cache.h" #include "paddle/fluid/operators/math/math_function.h" @@ -156,6 +157,21 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { bool is_test = ctx.Attr("is_test"); int seed = ctx.Attr("seed"); + if (!is_test) { + int device_id = + BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + if (gen_cuda->GetIsInitPy() && seed == 0) { + // If perform `manual_seed` in python and inner seed is not specified + // (equals 0), use global generator generated seed. 
+ seed = static_cast(gen_cuda->Random64()); + } else if (seed == 0) { + // use random generated seed + std::random_device rd; + seed = rd(); + } // else use `ctx.Attr("seed")` specified seed + } + bool has_seq_length = ctx.HasInput("SequenceLength"); std::vector SequenceLength; if (has_seq_length) { @@ -194,13 +210,25 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { if (!continuous) { LOG_FIRST_N(WARNING, 2) - << "If the memory space of the Input WeightList is not " - "continuous, less efficient calculation will be " - "called. Please call coalesce_tensor op to make the " - "input memory continuous."; + << "If the memory space of the Input WeightList is not continuous, " + "less efficient calculation will be called. Please call " + "flatten_parameters() to make the input memory continuous."; weight_whole.mutable_data({weight_numel}, place); weight_to_tensor(place, stream, weight_list, &weight_whole); w_data = weight_whole.data(); + if (is_test) { // maybe also reset small weights' ptr for training + int offset = 0; + for (size_t i = 0; i < weight_list.size(); ++i) { + size_t len = weight_list[i]->numel(); + auto dim = weight_list[i]->dims(); + const_cast(weight_list[i]) + ->ShareDataWith( + weight_whole.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } + } } else { w_data = const_cast(weight_list[0]->data()); } @@ -226,12 +254,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { LSTMInferece(has_seq_length, handle, seq_length, &rnn, x_data, init_h_data, init_c_data, w_data, out_data, last_h_data, last_c_data, &workspace_data_, workspace_size); - if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) { - auto *W = const_cast(ctx.Input("W")); - auto weight_list = ctx.MultiInput("WeightList"); - W->mutable_data({weight_numel}, place); - weight_to_tensor(place, stream, weight_list, W); - } } else { if (!has_seq_length) { // for train diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index c2a58b4199..f619f3d59c 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -89,6 +89,7 @@ REGISTER_OP_CPU_KERNEL( save, ops::SaveOpKernel, ops::SaveOpKernel, ops::SaveOpKernel, + ops::SaveOpKernel, ops::SaveOpKernel, ops::SaveOpKernel, ops::SaveOpKernel); diff --git a/paddle/fluid/operators/save_op.cu b/paddle/fluid/operators/save_op.cu index 0a778a694e..5c8c5a7545 100644 --- a/paddle/fluid/operators/save_op.cu +++ b/paddle/fluid/operators/save_op.cu @@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL( save, ops::SaveOpKernel, ops::SaveOpKernel, ops::SaveOpKernel, + ops::SaveOpKernel, ops::SaveOpKernel, ops::SaveOpKernel, ops::SaveOpKernel`_ for more details. 
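
For reference, the dropout-seed priority that the cudnn_lstm kernel change above implements can be summarized in plain Python. This is an illustrative sketch only, not part of the patch: `resolve_seed`, `manual_seed_set`, and `global_seed` are hypothetical names standing in for the op's `seed` attribute, the state of the framework's default CUDA generator after a Python-side `manual_seed`, and the `is_test` attribute.

    import random

    def resolve_seed(op_seed_attr, manual_seed_set, global_seed, is_test):
        # Inference (`is_test`) does not consume randomness; the attribute is
        # returned unchanged, mirroring the `if (!is_test)` guard in the kernel.
        if is_test:
            return op_seed_attr
        # A manual seed was set in Python and no per-op seed was given (== 0):
        # reuse the seed drawn from the global CUDA generator for reproducibility.
        if manual_seed_set and op_seed_attr == 0:
            return global_seed
        # No seed anywhere: fall back to a non-deterministic OS seed
        # (std::random_device in the C++ kernel).
        if op_seed_attr == 0:
            return random.SystemRandom().randint(0, 2**31 - 1)
        # An explicitly specified `seed` attribute always wins.
        return op_seed_attr
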
@@ -807,13 +813,14 @@ class RNN(Layer): initial_states=None, sequence_length=None, **kwargs): - final_outputs, final_states = paddle.fluid.layers.rnn(self.cell, - inputs, - initial_states=initial_states, - sequence_length=sequence_length, - time_major=self.time_major, - is_reverse=self.is_reverse, - **kwargs) + final_outputs, final_states = paddle.fluid.layers.rnn( + self.cell, + inputs, + initial_states=initial_states, + sequence_length=sequence_length, + time_major=self.time_major, + is_reverse=self.is_reverse, + **kwargs) return final_outputs, final_states @@ -909,18 +916,194 @@ class BiRNN(Layer): assert len(initial_states) == 2, \ "length of initial_states should be 2 when it is a list/tuple" - outputs, final_states = paddle.fluid.layers.birnn(self.cell_fw, self.cell_bw, inputs, - initial_states, sequence_length, - self.time_major, **kwargs) + outputs, final_states = paddle.fluid.layers.birnn( + self.cell_fw, self.cell_bw, inputs, initial_states, sequence_length, + self.time_major, **kwargs) return outputs, final_states -class RNNMixin(LayerList): +class RNNBase(LayerList): r""" - A Mixin class for RNN networks. It provides `forward` method for SimpleRNN, - LSTM and GRU. + RNNBase class for RNN networks. It provides `forward`, `flatten_parameters` + and other common methods for SimpleRNN, LSTM and GRU. """ + def __init__(self, + mode, + input_size, + hidden_size, + num_layers=1, + direction="forward", + time_major=False, + dropout=0., + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None): + super(RNNBase, self).__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + self.time_major = time_major + self.num_layers = num_layers + self.state_components = 2 if mode == "LSTM" else 1 + + kwargs = { + "weight_ih_attr": weight_ih_attr, + "weight_hh_attr": weight_hh_attr, + "bias_ih_attr": bias_ih_attr, + "bias_hh_attr": bias_hh_attr + } + + if mode == "LSTM": + rnn_cls = LSTMCell + elif mode == "GRU": + rnn_cls = GRUCell + else: + rnn_cls = SimpleRNNCell + kwargs["activation"] = self.activation + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = rnn_cls(input_size, hidden_size, **kwargs) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = rnn_cls(hidden_size, hidden_size, **kwargs) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = rnn_cls(input_size, hidden_size, **kwargs) + cell_bw = rnn_cls(input_size, hidden_size, **kwargs) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) + cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.could_use_cudnn = get_device().startswith( + "gpu:") and get_cudnn_version() + self.could_use_cudnn &= direction != "backward" + self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * ( + 2 if direction == "bidirectional" else 1) + self.could_use_cudnn &= mode == "LSTM" # currently only support LSTM + + # Expose params as RNN's attribute, which can make it compatible when + # replacing small ops composed rnn with cpp rnn kernel. 
+ # Moreover, `jit.to_static` assumes params are added by current layer + # and wouldn't include sublayer's params in current layer, which also + # requires these params are added to current layer for `jit.save`. + param_names = [] + for layer in range(self.num_layers): + for direction in range(self.num_directions): + suffix = '_reverse' if direction == 1 else '' + param_names.extend(['weight_ih_l{}{}', 'weight_hh_l{}{}']) + if bias_ih_attr != False: param_names.append('bias_ih_l{}{}') + if bias_hh_attr != False: param_names.append('bias_hh_l{}{}') + param_names = [x.format(layer, suffix) for x in param_names] + for name, param in zip(param_names, self.parameters()): + setattr(self, name, param) + + self.flatten_parameters() + + def flatten_parameters(self): + """ + Resets parameter data pointer to address in continuous memory block for + cudnn usage. + """ + if self.could_use_cudnn: + # layer.parameters() is depth first and ordered + # for i in layer: for j in direct: w_ih, w_hh, b_ih, b_hh + # need to reorganize to cudnn param layout: + # all bias following all weights + params = self.parameters(include_sublayers=False) + shape = [np.prod(param.shape) for param in params] + self._all_weights = [None] * len(params) + for i, param in enumerate(params): + offset = 0 if i % 4 < 2 else (2 * self.num_layers * + self.num_directions) + layer_idx = i // 4 + self._all_weights[offset + layer_idx * 2 + i % 2] = param + # Wrap using a list to avoid registed into params and saving, maybe + # need a better way to handle this later. Use `create_parameter` to + # add both to main_program and startup_program for static-graph. + # Use Constant initializer to avoid make effect on random generator. + self._flat_weight = [ + self.create_parameter( + shape=[np.sum(shape)], + dtype=params[0].dtype, + default_initializer=I.Constant(0.0)) + ] + # dropout state may also can be hided and avoid saving + # should dropout state be persistable for static-graph + self._dropout_state = self.create_variable( + dtype=fluid.core.VarDesc.VarType.UINT8) + # for static-graph, append coalesce_tensor into startup program + with fluid.program_guard(fluid.default_startup_program(), + fluid.default_startup_program()): + with framework.no_grad(): + self._helper.append_op( + type="coalesce_tensor", + inputs={"Input": self._all_weights}, + outputs={ + "Output": self._all_weights, + "FusedOutput": self._flat_weight + }, + attrs={ + "copy_data": True, + "use_align": False, + "dtype": params[0].dtype + }) + + def _cudnn_impl(self, inputs, initial_states, sequence_length): + if not self.time_major: + inputs = paddle.tensor.transpose(inputs, [1, 0, 2]) + # unify LSTM/GRU/SimpleRNN later, currently only support LSTM + # TODO(guosheng): use `core.ops.cudnn_lstm` in dygraph mode if support + # specify output, since `dropout_state` should be a persistable tensor + # rather than a temporary on. 
+ out = self._helper.create_variable_for_type_inference(inputs.dtype) + last_h = self._helper.create_variable_for_type_inference(inputs.dtype) + last_c = self._helper.create_variable_for_type_inference(inputs.dtype) + reserve = self._helper.create_variable_for_type_inference( + dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True) + + inputs = { + 'Input': inputs, + # 'W': self._flat_weight, # would be unused_var + 'WeightList': self._all_weights, + 'InitH': initial_states[0], + 'InitC': initial_states[1], + 'SequenceLength': sequence_length + } + attrs = { + 'dropout_prob': self.dropout, + 'is_bidirec': self.num_directions == 2, + 'input_size': self.input_size, + 'hidden_size': self.hidden_size, + 'num_layers': self.num_layers, + 'is_test': not self.training + } + + outputs = { + 'Out': out, + 'LastH': last_h, + 'LastC': last_c, + 'Reserve': reserve, + 'StateOut': self._dropout_state, + } + + self._helper.append_op( + type="cudnn_lstm", inputs=inputs, outputs=outputs, attrs=attrs) + out = paddle.tensor.transpose(out, + [1, 0, 2]) if not self.time_major else out + states = (last_h, last_c) + return out, states + def forward(self, inputs, initial_states=None, sequence_length=None): batch_index = 1 if self.time_major else 0 dtype = inputs.dtype @@ -937,6 +1120,10 @@ class RNNMixin(LayerList): for _ in range(self.state_components) ]) + if self.could_use_cudnn: + # Add CPU kernel and dispatch in backend later + return self._cudnn_impl(inputs, initial_states, sequence_length) + states = split_states(initial_states, self.num_directions == 2, self.state_components) final_states = [] @@ -957,7 +1144,7 @@ class RNNMixin(LayerList): return outputs, final_states -class SimpleRNN(RNNMixin): +class SimpleRNN(RNNBase): r""" Multilayer Elman network(SimpleRNN). It takes input sequences and initial states as inputs, and returns the output sequences and the final states. @@ -970,22 +1157,28 @@ class SimpleRNN(RNNMixin): .. math:: - h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) + h_{t} & = act(W_{ih}x_{t} + b_{ih} + W_{hh}h{t-1} + b_{hh}) y_{t} & = h_{t} + + where :math:`act` is for :attr:`activation` , and * is the elemetwise + multiplication operator. + + Using key word arguments to construct is recommended. Parameters: input_size (int): The input size for the first layer's cell. hidden_size (int): The hidden size for each layer's cell. num_layers (int, optional): Number of layers. Defaults to 1. - activation (str, optional): The activation in each SimpleRNN cell. It can be - `tanh` or `relu`. Defaults to `tanh`. direction (str, optional): The direction of the network. It can be "forward", - "backward" and "bidirectional". Defaults to "forward". - dropout (float, optional): The droput probability. Dropout is applied to the - input of each layer except for the first layer. Defaults to 0. + "backward" and "bidirectional". When "bidirectional", the way to merge + outputs of forward and backward is concatenating. Defaults to "forward". time_major (bool, optional): Whether the first dimension of the input means the time steps. Defaults to False. + dropout (float, optional): The droput probability. Dropout is applied to the + input of each layer except for the first layer. Defaults to 0. + activation (str, optional): The activation in each SimpleRNN cell. It can be + `tanh` or `relu`. Defaults to `tanh`. weight_ih_attr (ParamAttr, optional): The parameter attribute for `weight_ih` of each cell. Defaults to None. 
weight_hh_attr (ParamAttr, optional): The parameter attribute for @@ -1002,7 +1195,7 @@ class SimpleRNN(RNNMixin): If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`. initial_states (Tensor, optional): the initial state. The shape is - `[num_lauers * num_directions, batch_size, hidden_size]`. + `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used. sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. @@ -1020,10 +1213,21 @@ class SimpleRNN(RNNMixin): Note that `num_directions` is 2 if direction is "bidirectional" else 1. final_states (Tensor): final states. The shape is - `[num_lauers * num_directions, batch_size, hidden_size]`. + `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. + Attributes: + weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, + If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise, + the shape is `[hidden_size, num_directions * hidden_size]`. + weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, + with shape `[hidden_size, hidden_size]`. + bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, + with shape `[hidden_size]`. + bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, + with shape `[hidden_size]`. + Examples: .. code-block:: python @@ -1048,59 +1252,28 @@ class SimpleRNN(RNNMixin): input_size, hidden_size, num_layers=1, - activation="tanh", direction="forward", - dropout=0., time_major=False, + dropout=0., + activation="tanh", weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, bias_hh_attr=None, name=None): - super(SimpleRNN, self).__init__() - - if direction in ["forward", "backward"]: - is_reverse = direction == "backward" - cell = SimpleRNNCell(input_size, hidden_size, activation, - weight_ih_attr, weight_hh_attr, bias_ih_attr, - bias_hh_attr) - self.append(RNN(cell, is_reverse, time_major)) - for i in range(1, num_layers): - cell = SimpleRNNCell(hidden_size, hidden_size, activation, - weight_ih_attr, weight_hh_attr, - bias_ih_attr, bias_hh_attr) - self.append(RNN(cell, is_reverse, time_major)) - elif direction == "bidirectional": - cell_fw = SimpleRNNCell(input_size, hidden_size, activation, - weight_ih_attr, weight_hh_attr, - bias_ih_attr, bias_hh_attr) - cell_bw = SimpleRNNCell(input_size, hidden_size, activation, - weight_ih_attr, weight_hh_attr, - bias_ih_attr, bias_hh_attr) - self.append(BiRNN(cell_fw, cell_bw, time_major)) - for i in range(1, num_layers): - cell_fw = SimpleRNNCell( - 2 * hidden_size, hidden_size, activation, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - cell_bw = SimpleRNNCell( - 2 * hidden_size, hidden_size, activation, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(BiRNN(cell_fw, cell_bw, time_major)) + if activation == "tanh": + mode = "RNN_TANH" + elif activation == "relu": + mode = "RNN_RELU" else: - raise ValueError( - "direction should be forward, backward or bidirectional, " - "received direction = {}".format(direction)) - - self.input_size = input_size - self.hidden_size = hidden_size - self.dropout = dropout - self.num_directions = 2 if direction == "bidirectional" else 1 - self.time_major = time_major - self.num_layers = num_layers - self.state_components = 1 + raise ValueError("Unknown 
activation '{}'".format(activation)) + self.activation = activation + super(SimpleRNN, self).__init__( + mode, input_size, hidden_size, num_layers, direction, time_major, + dropout, weight_ih_attr, weight_hh_attr, bias_ih_attr, bias_hh_attr) -class LSTM(RNNMixin): +class LSTM(RNNBase): r""" Multilayer LSTM. It takes a sequence and an initial state as inputs, and returns the output sequences and the final states. @@ -1130,16 +1303,19 @@ class LSTM(RNNMixin): where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise multiplication operator. + Using key word arguments to construct is recommended. + Parameters: input_size (int): The input size for the first layer's cell. hidden_size (int): The hidden size for each layer's cell. num_layers (int, optional): Number of layers. Defaults to 1. - direction (str, optional): The direction of the network. It can be - "forward", "backward" and "bidirectional". Defaults to "forward". - dropout (float, optional): The droput probability. Dropout is applied - to the input of each layer except for the first layer. Defaults to 0. + direction (str, optional): The direction of the network. It can be "forward", + "backward" and "bidirectional". When "bidirectional", the way to merge + outputs of forward and backward is concatenating. Defaults to "forward". time_major (bool, optional): Whether the first dimension of the input means the time steps. Defaults to False. + dropout (float, optional): The droput probability. Dropout is applied + to the input of each layer except for the first layer. Defaults to 0. weight_ih_attr (ParamAttr, optional): The parameter attribute for `weight_ih` of each cell. Default: None. weight_hh_attr (ParamAttr, optional): The parameter attribute for @@ -1156,7 +1332,7 @@ class LSTM(RNNMixin): If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`. initial_states (tuple, optional): the initial state, a tuple of (h, c), - the shape of each is `[num_lauers * num_directions, batch_size, hidden_size]`. + the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used. sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. @@ -1175,10 +1351,21 @@ class LSTM(RNNMixin): else 1. final_states (tuple): the final state, a tuple of two tensors, h and c. The shape of each is - `[num_lauers * num_directions, batch_size, hidden_size]`. + `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. + Attributes: + weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, + If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise, + the shape is `[hidden_size, num_directions * hidden_size]`. + weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, + with shape `[hidden_size, hidden_size]`. + bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, + with shape `[hidden_size]`. + bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, + with shape `[hidden_size]`. + Examples: .. 
code-block:: python @@ -1207,51 +1394,19 @@ class LSTM(RNNMixin): hidden_size, num_layers=1, direction="forward", - dropout=0., time_major=False, + dropout=0., weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, bias_hh_attr=None, name=None): - super(LSTM, self).__init__() - - if direction in ["forward", "backward"]: - is_reverse = direction == "backward" - cell = LSTMCell(input_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(RNN(cell, is_reverse, time_major)) - for i in range(1, num_layers): - cell = LSTMCell(hidden_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(RNN(cell, is_reverse, time_major)) - elif direction == "bidirectional": - cell_fw = LSTMCell(input_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - cell_bw = LSTMCell(input_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(BiRNN(cell_fw, cell_bw, time_major)) - for i in range(1, num_layers): - cell_fw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - cell_bw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(BiRNN(cell_fw, cell_bw, time_major)) - else: - raise ValueError( - "direction should be forward, backward or bidirectional, " - "received direction = {}".format(direction)) - - self.input_size = input_size - self.hidden_size = hidden_size - self.dropout = dropout - self.num_directions = 2 if direction == "bidirectional" else 1 - self.time_major = time_major - self.num_layers = num_layers - self.state_components = 2 + super(LSTM, self).__init__( + "LSTM", input_size, hidden_size, num_layers, direction, time_major, + dropout, weight_ih_attr, weight_hh_attr, bias_ih_attr, bias_hh_attr) -class GRU(RNNMixin): +class GRU(RNNBase): r""" Multilayer GRU. It takes input sequencse and initial states as inputs, and returns the output sequences and the final states. @@ -1277,16 +1432,19 @@ class GRU(RNNMixin): where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise multiplication operator. + Using key word arguments to construct is recommended. + Parameters: input_size (int): The input size for the first layer's cell. hidden_size (int): The hidden size for each layer's cell. num_layers (int, optional): Number of layers. Defaults to 1. - direction (str, optional): The direction of the network. It can be - "forward", "backward" and "bidirectional". Defaults to "forward". - dropout (float, optional): The droput probability. Dropout is applied - to the input of each layer except for the first layer. Defaults to 0. + direction (str, optional): The direction of the network. It can be "forward", + "backward" and "bidirectional". When "bidirectional", the way to merge + outputs of forward and backward is concatenating. Defaults to "forward". time_major (bool, optional): Whether the first dimension of the input means the time steps. Defaults to False. + dropout (float, optional): The droput probability. Dropout is applied + to the input of each layer except for the first layer. Defaults to 0. weight_ih_attr (ParamAttr, optional): The parameter attribute for `weight_ih` of each cell. Default: None. 
weight_hh_attr (ParamAttr, optional): The parameter attribute for @@ -1303,7 +1461,7 @@ class GRU(RNNMixin): If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`. initial_states (Tensor, optional): the initial state. The shape is - `[num_lauers * num_directions, batch_size, hidden_size]`. + `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used. Defaults to None. sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 @@ -1322,10 +1480,21 @@ class GRU(RNNMixin): Note that `num_directions` is 2 if direction is "bidirectional" else 1. final_states (Tensor): final states. The shape is - `[num_lauers * num_directions, batch_size, hidden_size]`. + `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. + Attributes: + weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, + If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise, + the shape is `[hidden_size, num_directions * hidden_size]`. + weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, + with shape `[hidden_size, hidden_size]`. + bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, + with shape `[hidden_size]`. + bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, + with shape `[hidden_size]`. + Examples: .. code-block:: python @@ -1351,45 +1520,13 @@ class GRU(RNNMixin): hidden_size, num_layers=1, direction="forward", - dropout=0., time_major=False, + dropout=0., weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, bias_hh_attr=None, name=None): - super(GRU, self).__init__() - - if direction in ["forward", "backward"]: - is_reverse = direction == "backward" - cell = GRUCell(input_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(RNN(cell, is_reverse, time_major)) - for i in range(1, num_layers): - cell = GRUCell(hidden_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(RNN(cell, is_reverse, time_major)) - elif direction == "bidirectional": - cell_fw = GRUCell(input_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - cell_bw = GRUCell(input_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(BiRNN(cell_fw, cell_bw, time_major)) - for i in range(1, num_layers): - cell_fw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - cell_bw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr, - weight_hh_attr, bias_ih_attr, bias_hh_attr) - self.append(BiRNN(cell_fw, cell_bw, time_major)) - else: - raise ValueError( - "direction should be forward, backward or bidirectional, " - "received direction = {}".format(direction)) - - self.input_size = input_size - self.hidden_size = hidden_size - self.dropout = dropout - self.num_directions = 2 if direction == "bidirectional" else 1 - self.time_major = time_major - self.num_layers = num_layers - self.state_components = 1 + super(GRU, self).__init__( + "GRU", input_size, hidden_size, num_layers, direction, time_major, + dropout, weight_ih_attr, weight_hh_attr, bias_ih_attr, bias_hh_attr) -- GitLab
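
Usage note: with the reorganized argument order introduced by this patch (`time_major` now precedes `dropout`, keyword arguments recommended), a typical call to the new `paddle.nn.LSTM` looks like the sketch below. This is a minimal example assembled from the Parameters documented above; the concrete shapes are illustrative and are not taken from the patch's elided Examples blocks.

    import paddle

    # 2-layer bidirectional LSTM; keyword arguments are recommended since the
    # positional order of time_major/dropout changed in this patch.
    lstm = paddle.nn.LSTM(input_size=16, hidden_size=32, num_layers=2,
                          direction="bidirectional", time_major=False, dropout=0.)

    x = paddle.randn([4, 23, 16])           # [batch_size, time_steps, input_size]
    prev_h = paddle.randn([2 * 2, 4, 32])   # [num_layers * num_directions, batch_size, hidden_size]
    prev_c = paddle.randn([2 * 2, 4, 32])

    y, (h, c) = lstm(x, (prev_h, prev_c))
    print(y.shape)   # [4, 23, 64]: forward and backward outputs concatenated on the last axis
    print(h.shape)   # [4, 4, 32]
    print(c.shape)   # [4, 4, 32]

When run on a GPU build with cuDNN available, `RNNBase.forward` dispatches to the fused `cudnn_lstm` kernel through `_cudnn_impl` (using the flattened weights prepared by `flatten_parameters`); otherwise it falls back to the per-cell `RNN`/`BiRNN` path shown above.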