diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 29c6c2f593d54098cb9869004c628d10f95c0cd0..bc08c0730be4bf7abf6b022a47a793639368fcb1 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -194,6 +194,7 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None)) +paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index db62377898339def415a13d185f85f34de326d7f..213cd8a9ce094512cea6f6405492ec8feff11516 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -111,7 +111,23 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(cudnnFindConvolutionForwardAlgorithmEx); \ 
__macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \ __macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \ - __macro(cudnnGetErrorString); + __macro(cudnnGetErrorString); \ + __macro(cudnnCreateDropoutDescriptor); \ + __macro(cudnnDropoutGetStatesSize); \ + __macro(cudnnSetDropoutDescriptor); \ + __macro(cudnnCreateRNNDescriptor); \ + __macro(cudnnSetRNNDescriptor); \ + __macro(cudnnGetRNNParamsSize); \ + __macro(cudnnGetRNNWorkspaceSize); \ + __macro(cudnnGetRNNTrainingReserveSize); \ + __macro(cudnnRNNForwardTraining); \ + __macro(cudnnRNNBackwardData); \ + __macro(cudnnRNNBackwardWeights); \ + __macro(cudnnRNNForwardInference); \ + __macro(cudnnDestroyDropoutDescriptor); \ + __macro(cudnnDestroyRNNDescriptor); \ + __macro(cudnnSetRNNDescriptor_v6); + CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 29a0de29dcaa6e9510c30e9a1186d2b1b88246f6..831a84b140ca733531dda7859fefe2ac9b35c2ce 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -169,6 +169,7 @@ __all__ = [ 'log_loss', 'add_position_encoding', 'bilinear_tensor_product', + 'lstm', ] @@ -472,6 +473,168 @@ def dynamic_lstm(input, return hidden, cell +def lstm(input, + init_h, + init_c, + max_len, + hidden_size, + num_layers, + dropout_prob=0.0, + is_bidirec=False, + is_test=False, + name=None, + default_initializer=None, + seed=-1): + """ + If Device is GPU, This op will use cudnn LSTM implementation + + A four-gate Long Short-Term Memory network with no peephole connections. 
+ In the forward pass the output $h_t$ and cell output $c_t$ for a given iteration can be computed from the recurrent input $h_{t-1}$, + the cell input $c_{t-1}$ and the previous layer input $x_t$ given matrices W, R and biases bW, bR from the following equations: + + $$ i_t = \\sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$ + + $$ f_t = \\sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$ + + $$ o_t = \\sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$ + + $$ \\tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$ + + $$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$ + + $$ h_t = o_t \\odot tanh(c_t) $$ + + - W terms denote weight matrices (e.g. $W_{ix}$ is the matrix + of weights from the input to the input gate) + - The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vectors). + - sigmoid is the logistic sigmoid function. + - $i, f, o$ and $c$ are the input gate, forget gate, output gate, + and cell activation vectors, respectively, all of which have the same size as + the cell output activation vector $h$. + - The $\\odot$ is the element-wise product of the vectors. + - `tanh` is the activation function. + - $\\tilde{c_t}$ is also called the candidate hidden state, + which is computed based on the current input and the previous hidden state. + + Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication, + X represents a matrix multiplication + + + Args: + input (Variable): LSTM input tensor, shape MUST be ( seq_len x batch_size x input_size ) + init_h(Variable): The initial hidden state of the LSTM + This is a tensor with shape ( num_layers x batch_size x hidden_size) + if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size) + init_c(Variable): The initial cell state of the LSTM. 
+ This is a tensor with shape ( num_layers x batch_size x hidden_size ) + if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size) + max_len (int): max length of LSTM. The first dim of the input tensor CANNOT be greater than max_len + hidden_size (int): hidden size of the LSTM + num_layers (int): total number of layers of the LSTM + dropout_prob(float|0.0): dropout prob, dropout ONLY works between rnn layers, NOT between time steps + NO dropout is applied to the rnn output of the last RNN layer + is_bidirec (bool): If it is bidirectional + is_test (bool): If it is in test phase + name (str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + default_initializer(Initializer|None): Which initializer to use to initialize the weight + If set None, the default initializer will be used + seed(int): Seed for dropout in LSTM, If it's -1, dropout will use a random seed + + + Returns: + rnn_out(Tensor): result of LSTM hidden, shape is (seq_len x batch_size x hidden_size) + if is_bidirec set to True, shape will be ( seq_len x batch_size x hidden_size*2) + last_h(Tensor): the hidden state of the last step of LSTM + shape is ( num_layers x batch_size x hidden_size ) + if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size) + last_c(Tensor): the cell state of the last step of LSTM + shape is ( num_layers x batch_size x hidden_size ) + if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size) + + + Examples: + .. 
code-block:: python + + input = embedding + batch_size = 20 + max_len = 100 + dropout_prob = 0.2 + input_size = 100 + hidden_size = 150 + num_layers = 1 + init_h = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0 ) + init_c = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0 ) + + rnn_out, last_h, last_c = layers.lstm( input, init_h, init_c, \ + max_len, hidden_size, num_layers, \ + dropout_prob=dropout_prob) + """ + + helper = LayerHelper('cudnn_lstm', **locals()) + + dtype = input.dtype + input_shape = list(input.shape) + input_size = input_shape[-1] + weight_size = 0 + for i in range(num_layers): + if i == 0: + input_weight_size = (input_size * hidden_size) * 4 + else: + if is_bidirec: + input_weight_size = (hidden_size * 2 * hidden_size) * 4 + else: + input_weight_size = (hidden_size * hidden_size) * 4 + + hidden_weight_size = (hidden_size * hidden_size) * 4 + + if is_bidirec: + weight_size += (input_weight_size + hidden_weight_size) * 2 + weight_size += hidden_size * 8 * 2 + else: + weight_size += input_weight_size + hidden_weight_size + weight_size += hidden_size * 8 + + weight = helper.create_parameter( + attr=helper.param_attr, + shape=[weight_size], + dtype=dtype, + default_initializer=default_initializer) + + out = helper.create_variable_for_type_inference(dtype) + last_h = helper.create_variable_for_type_inference(dtype) + last_c = helper.create_variable_for_type_inference(dtype) + + cache = helper.create_variable( + persistable=True, type=core.VarDesc.VarType.RAW, stop_gradient=True) + + helper.append_op( + type='cudnn_lstm', + inputs={ + 'Input': input, + 'InitH': init_h, + 'InitC': init_c, + 'W': weight, + 'Cache': cache, + }, + outputs={ + 'Out': out, + 'last_h': last_h, + 'last_c': last_c, + }, + attrs={ + 'max_len': max_len, + 'is_bidirec': is_bidirec, + 'input_size': input_size, + 'hidden_size': hidden_size, + 'num_layers': num_layers, + 'is_test': 
is_test, + 'dropout_prob': dropout_prob, + 'seed': seed, + }) + return out, last_h, last_c + + def dynamic_lstmp(input, size, proj_size, diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 271b9c740fd99554e9a7aa8d476a52cf6385b1d9..76a707efdc0804be0316ab12c347ffed6199529a 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -216,6 +216,15 @@ class OpTest(unittest.TestCase): self.dtype) outputs = append_input_output(block, op_proto, self.outputs, False, self.dtype) + + if hasattr(self, "cache_name_list"): + for name in self.cache_name_list: + inputs[name] = block.create_var( + name=name, + persistable=True, + type=core.VarDesc.VarType.RAW, + stop_gradient=True) + op = block.append_op( type=self.op_type, inputs=inputs, @@ -428,8 +437,17 @@ class OpTest(unittest.TestCase): op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_attrs = self.attrs if hasattr(self, "attrs") else dict() - self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, - op_attrs) + + cache_list = None + if hasattr(self, "cache_name_list"): + cache_list = self.cache_name_list + self.op = create_op( + self.scope, + self.op_type, + op_inputs, + op_outputs, + op_attrs, + cache_list=cache_list) if no_grad_set is None: no_grad_set = set() diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index 34fbb1b549cf5fc5f75bcc0715e5c83665f1d200..dc3b2cb8bc15836a4bf067caa05c3a37a917ecad 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -20,7 +20,7 @@ import paddle.fluid.core as core from paddle.fluid.op import Operator -def create_op(scope, op_type, inputs, outputs, attrs): +def create_op(scope, op_type, inputs, outputs, attrs, cache_list=None): kwargs = dict() op_maker = 
core.op_proto_and_checker_maker @@ -43,6 +43,11 @@ def create_op(scope, op_type, inputs, outputs, attrs): __create_var__(in_name, sub_in_name) else: __create_var__(in_name, in_name) + if cache_list != None and isinstance(cache_list, list): + for name in cache_list: + kwargs[name] = [] + scope.var(name) + kwargs[name].append(name) for out_name, out_dup in Operator.get_op_outputs(op_type): if out_name in outputs: