diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 0a71f153434cc89b26d97e1e1fcfae6c2bc88566..550c7ddeb438ec2ab4389bb50d5eb6f13b4f496a 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -187,6 +187,7 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
 paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
 paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
+paddle.fluid.layers.cudnn_lstm ArgSpec(args=['input', 'init_h', 'init_c', 'batch_size', 'max_len', 'dropout_prob', 'input_size', 'hidden_size', 'num_layers', 'is_bidirec', 'dtype', 'is_test', 'name', 'default_initializer', 'fix_seed', 'seed'], varargs=None, keywords=None, defaults=(False, 'float32', False, None, None, False, 0))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
 paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index bfdfdc56b34098dd170f8ca98d27b41759c2f57b..06d3ee9e72527fded0db3e8fbca17b6eaa38304c 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -174,6 +174,14 @@ class ExecutionContext {
     return op_.Inputs(name).size();
   }
 
+  const std::string InputVarName(const std::string& name) const {
+    return op_.Input(name);
+  }
+
+  const std::string OutputVarName(const std::string& name) const {
+    return op_.Output(name);
+  }
+
   size_t OutputSize(const std::string& name) const {
     return op_.Outputs(name).size();
   }
diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cadc5b88308156ab29235b11e8179bb5eede4c3d
--- /dev/null
+++ b/paddle/fluid/operators/cudnn_lstm_op.cc
@@ -0,0 +1,204 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/fluid/operators/cudnn_lstm_op.h" +#include + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +class CudnnLSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("W"), + "Input(Weight) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("InitH"), + "Input(init_h) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("InitC"), + "Input(init_c) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Cache"), + "Input(Cache) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("last_h"), + "Output(last_h) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("last_c"), + "Output(last_c) of LSTM should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3."); + + ctx->SetOutputDim("Out", ctx->GetInputDim("Input")); + ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH")); + ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC")); + } +}; + +class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "(Tensor) RNN input tensor, which support variable-time length input " + "sequence." + "The shape of the Tensor MUST be ( seq_len * batch_size * input_size)" + "seq_len is the total time step in this mini-batch (CAN be change in " + "different batch)" + "batch_size is the instance number of this batch" + "input_size is the hidden size of the input." + "input_hidden_size and the hidden_size in the next may not be same"); + AddInput("InitH", + "(Tensor) the initial hidden state of the LSTM" + "input. This is a tensor with shape (num_layers x batch_size x " + "hidden_size)" + "and When is_bidirec is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size)"); + AddInput("InitC", + "(Tensor) the initial cell state of the LSTm " + "input. This is a tensor with shape (num_layers x batch_size x " + "hidden_size)" + "and When is_bidirec is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size)"); + AddInput("W", + "(Tensor) the learnable hidden-hidden weights." + " The shape is (N), where N is total weight size of the LSTM. " + " cudnn concatenate all the weight to one Tensor"); + AddInput("Cache", + "The cache of dropout op, a RAW type variable including random " + "number generator states and some descriptors, which is used in " + "cudnn kernel.") + .AsDispensable(); + AddOutput("Out", + "(Tensor) the hidden state of LSTM operator. " + "The shape is ( seq_len x batch_size x hidden_size) if " + "is_bidirec is False" + "and When is_bidirec is True, the shape will be ( seq_len x " + "batch_size x hidden_size * 2) "); + AddOutput("last_h", + "(Tensor) the hidden state of the last step. 
" + "The shape is ( num_layers x batch_size x hidden_size) if " + "is_bidirec is False" + "and When is_bidirec is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size)"); + AddOutput("last_c", + "(Tensor) the cell state of the last step" + "The shape is ( num_layers x batch_size x hidden_size) if " + "is_bidirec is False" + "and When is_bidirect is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size*2)"); + AddAttr("max_len", + "max length of the LSTM op" + "the first dim of the Input can NOT be greater than max_len") + .SetDefault(20); + AddAttr( + "dropout_prob", + "dropout prob of the dropout op" + "the dropout ONLY work between lstm layers, not between time steps" + "There is no dropout work on the Out tensor") + .SetDefault(0.0); + AddAttr("is_bidirec", + "is_bidirec" + "if it is bidirection rnn" + "The will affect the shape of the Out, last_h, and last_c") + .SetDefault(false); + AddAttr("input_size", "input size ot the Input Tensor").SetDefault(10); + AddAttr("batch_size", "the instance number the batch").SetDefault(10); + AddAttr("hidden_size", "hidden size of the LSTM").SetDefault(100); + AddAttr("num_layers", "the total layer number of the LSTM") + .SetDefault(1); + AddAttr("is_test", "True if in test phase.").SetDefault(false); + AddAttr("fix_seed", "True if it fix dropout seed").SetDefault(false); + AddAttr("seed", "seed to used if fix_seed is True").SetDefault(0); + AddComment(R"DOC( +CUDNN LSTM implementation + +A four-gate Long Short-Term Memory network with no peephole connections. +In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1, +the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations: + +it = σ(Wi X xt + Ri X ht-1 + bWi + bRi) +ft = σ(Wf X xt + Rf X ht-1 + bWf + bRf) +ot = σ(Wo X xt + Ro X ht-1 + bWo + bRo) +c't = tanh(Wc X xt + Rc X ht-1 + bWc + bRc) +ct = ft * ct-1 + it * c't +ht = ot * tanh(ct) + +Where σ is the sigmoid operator: σ(x) = 1 / (1 + e^-x), * represents a point-wise multiplication, +X represensts a matrix multiplication +and tanh is the hyperbolic tangent function. it, ft, ot, c't represent the input, forget, output and new gates respectively. 
+ + +)DOC"); + } +}; + +class CudnnLSTMGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("last_h"), + "Input(last_h) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("last_c"), + "Input(last_c) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("Cache"), + "Input(last_c) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("InitH"), + "Input(init_h) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("InitC"), + "Input(init_c) of LSTM should not be null."); + + auto SetOutGradDim = [&ctx](const std::string& name) { + auto g_name = framework::GradVarName(name); + if (ctx->HasOutput(g_name)) { + ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); + } + }; + + SetOutGradDim("Input"); + SetOutGradDim("W"); + SetOutGradDim("InitH"); + SetOutGradDim("InitC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); + +REGISTER_OP_CPU_KERNEL( + cudnn_lstm, + ops::CudnnLSTMKernel); + +REGISTER_OP_CPU_KERNEL( + cudnn_lstm_grad, + ops::CudnnLSTMGradKernel); diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..9caf65b53ff9a2e520ad35fc8d09d123c2df801c --- /dev/null +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -0,0 +1,491 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/cudnn_lstm_op.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +struct CudnnRNNCache { + CudnnRNNCache() { + x_desc_ = NULL; + y_desc_ = NULL; + dx_desc_ = NULL; + dy_desc_ = NULL; + } + ~CudnnRNNCache() { release(); } + + cudnnRNNDescriptor_t rnn_desc_; + cudnnTensorDescriptor_t *x_desc_; + cudnnTensorDescriptor_t *y_desc_; + cudnnTensorDescriptor_t *dx_desc_; + cudnnTensorDescriptor_t *dy_desc_; + + cudnnTensorDescriptor_t hx_desc_; + cudnnTensorDescriptor_t cx_desc_; + cudnnTensorDescriptor_t hy_desc_; + cudnnTensorDescriptor_t cy_desc_; + + cudnnTensorDescriptor_t dhx_desc_; + cudnnTensorDescriptor_t dcx_desc_; + cudnnTensorDescriptor_t dhy_desc_; + cudnnTensorDescriptor_t dcy_desc_; + + cudnnTensorDescriptor_t output_x_desc_; + cudnnTensorDescriptor_t output_y_desc_; + + cudnnDropoutDescriptor_t dropout_desc_; + + size_t weights_size_; + cudnnFilterDescriptor_t w_desc_; + cudnnFilterDescriptor_t dw_desc_; + + size_t workspace_size_; + size_t reserve_size_; + Tensor reserve_data_; + Tensor workspace_data_; + + Tensor dropout_state_; + + size_t max_length_; + + float dropout_prob_; + bool is_bidirec_; + + int batch_size_; + int input_size_; + int hidden_size_; + int num_layers_; + int seed_; + + void init(cudnnHandle_t handle, const framework::ExecutionContext &ctx, + size_t max_len, int batch_size, int input_size, int hidden_size, + int num_layers, float dropout_prob, bool is_bidirec, int seed, + int weight_numel) { + max_length_ = max_len; + batch_size_ = batch_size; + input_size_ = input_size; + hidden_size_ = hidden_size; + num_layers_ = num_layers; + dropout_prob_ = dropout_prob; + is_bidirec_ = is_bidirec; + seed_ = seed; + + x_desc_ = new cudnnTensorDescriptor_t[max_length_]; + y_desc_ = new cudnnTensorDescriptor_t[max_length_]; + dx_desc_ = new cudnnTensorDescriptor_t[max_length_]; + dy_desc_ = new cudnnTensorDescriptor_t[max_length_]; + int dim_a[3]; + int stride_a[3]; + + for (size_t i = 0; i < max_length_; ++i) { + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i])); + dim_a[0] = batch_size_; + dim_a[1] = input_size_; + dim_a[2] = 1; + + stride_a[0] = dim_a[2] * dim_a[1]; + stride_a[1] = dim_a[2]; + stride_a[2] = 1; + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + + dim_a[0] = batch_size_; + dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_; + dim_a[2] = 1; + + stride_a[0] = dim_a[2] * dim_a[1]; + stride_a[1] = dim_a[2]; + stride_a[2] = 1; + + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + } + + dim_a[0] = num_layers_ * (is_bidirec_ ? 
+    dim_a[1] = batch_size_;
+    dim_a[2] = hidden_size_;
+
+    stride_a[0] = dim_a[2] * dim_a[1];
+    stride_a[1] = dim_a[2];
+    stride_a[2] = 1;
+
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hy_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cy_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhy_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+        dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
+
+    CUDNN_ENFORCE(
+        platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_));
+
+    size_t state_size;
+    CUDNN_ENFORCE(
+        platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
+    dropout_state_.Resize({static_cast<int64_t>(state_size)});
+    auto *dropout_state_data =
+        dropout_state_.mutable_data<uint8_t>(ctx.GetPlace());
+    CUDNN_ENFORCE(platform::dynload::cudnnSetDropoutDescriptor(
+        dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size,
+        seed_));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
+        handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
+        CUDNN_LINEAR_INPUT,
+        is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
+        CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnGetRNNParamsSize(
+        handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT));
+
+    PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel,
+                      "cudnn lstm weight size should be the same as the "
+                      "weight numel");
+    int dim_w[3];
+    dim_w[0] = weights_size_ / sizeof(float);
+    dim_w[1] = 1;
+    dim_w[2] = 1;
+    CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
+        w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
+    CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
+        dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnGetRNNWorkspaceSize(
+        handle, rnn_desc_, max_length_, x_desc_, &workspace_size_));
+    CUDNN_ENFORCE(platform::dynload::cudnnGetRNNTrainingReserveSize(
+        handle, rnn_desc_, max_length_, x_desc_, &reserve_size_));
+
+    reserve_data_.Resize({static_cast<int64_t>(reserve_size_)});
+    reserve_data_.mutable_data<uint8_t>(ctx.GetPlace());
+
+    workspace_data_.Resize({static_cast<int64_t>(workspace_size_)});
+    workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
+  }
+
+  void release() {
+    for (size_t i = 0; i < max_length_; ++i) {
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i]));
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i]));
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i]));
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i]));
+    }
+
+    delete[] x_desc_;
+    delete[] y_desc_;
+    delete[] dx_desc_;
+    delete[] dy_desc_;
+
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hy_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cy_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcx_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhy_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcy_desc_));
+
+    CUDNN_ENFORCE(
+        platform::dynload::cudnnDestroyDropoutDescriptor(dropout_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyRNNDescriptor(rnn_desc_));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(w_desc_));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(dw_desc_));
+  }
+};
+
+template <typename T>
+class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const Tensor *x = ctx.Input<Tensor>("Input");
+    const Tensor *init_h = ctx.Input<Tensor>("InitH");
+    const Tensor *init_c = ctx.Input<Tensor>("InitC");
+
+    auto w = ctx.Input<Tensor>("W");
+
+    Tensor *out = ctx.Output<Tensor>("Out");
+    Tensor *last_h = ctx.Output<Tensor>("last_h");
+    Tensor *last_c = ctx.Output<Tensor>("last_c");
+
+    const T *x_data = x->data<T>();
+    const T *init_h_data = init_h->data<T>();
+    const T *init_c_data = init_c->data<T>();
+
+    const T *w_data = w->data<T>();
+
+    T *out_data = out->mutable_data<T>(ctx.GetPlace());
+    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
+    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
+
+    size_t max_len = ctx.Attr<int>("max_len");
+    float dropout_prob = ctx.Attr<float>("dropout_prob");
+    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
+    int batch_size = ctx.Attr<int>("batch_size");
+    int input_size = ctx.Attr<int>("input_size");
+    int hidden_size = ctx.Attr<int>("hidden_size");
+    int num_layers = ctx.Attr<int>("num_layers");
+    bool is_test = ctx.Attr<bool>("is_test");
+
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto handle = dev_ctx.cudnn_handle();
+    auto *cache_var = ctx.InputVar("Cache");
+    if (!cache_var) {
+      // The RAW type cache variable wouldn't be created and broadcasted on
+      // multi-devices before the first running.
+      // Use the parent scope to make the cache persistable.
+      auto *scope = const_cast<framework::Scope *>(ctx.scope().parent());
+      auto cache_var_name = ctx.InputVarName("Cache");
+      cache_var = scope->Var(cache_var_name);
+    }
+    CudnnRNNCache *cudnn_rnn_cache = nullptr;
+    if (cache_var->IsInitialized()) {
+      cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
+                            ->GetMutable<CudnnRNNCache>();
+    } else {
+      cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
+                            ->GetMutable<CudnnRNNCache>();
+      std::random_device rnd;
+      int seed = ctx.Attr<bool>("fix_seed") ? ctx.Attr<int>("seed") : rnd();
+
+      auto input_w_numel = w->numel();
+      cudnn_rnn_cache->init(handle, ctx, max_len, batch_size, input_size,
+                            hidden_size, num_layers, dropout_prob, is_bidirec,
+                            seed, input_w_numel);
+    }
+
+    auto run_seq_len = x->dims()[0];
+
+    if (is_test) {
+      // for inference
+      CUDNN_ENFORCE(platform::dynload::cudnnRNNForwardInference(
+          handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
+          cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_,
+          init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data,
+          cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data,
+          cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_,
+          last_c_data, cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
+          cudnn_rnn_cache->workspace_size_));
+    } else {
+      // for train
+      CUDNN_ENFORCE(platform::dynload::cudnnRNNForwardTraining(
+          handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
+          cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_,
+          init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data,
+          cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data,
+          cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_,
+          last_c_data, cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
+          cudnn_rnn_cache->workspace_size_,
+          cudnn_rnn_cache->reserve_data_.data<uint8_t>(),
+          cudnn_rnn_cache->reserve_size_));
+    }
+  }
+};
+
+template <typename T>
+class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *input = ctx.Input<Tensor>("Input");
+    auto *weight = ctx.Input<Tensor>("W");
+    auto *init_h = ctx.Input<Tensor>("InitH");
+    auto *init_c = ctx.Input<Tensor>("InitC");
+    auto *out = ctx.Input<Tensor>("Out");
+    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("last_h"));
+    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("last_c"));
+
+    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
+    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
+    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
+
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto handle = dev_ctx.cudnn_handle();
+    auto *cache_var = ctx.InputVar("Cache");
+    PADDLE_ENFORCE(cache_var->IsInitialized());
+    CudnnRNNCache *cudnn_rnn_cache =
+        const_cast<framework::Variable *>(cache_var)
+            ->GetMutable<CudnnRNNCache>();
+
+    auto input_dims = input->dims();
+    auto init_h_dims = init_h->dims();
+    auto init_c_dims = init_c->dims();
+    in_grad->mutable_data<T>(ctx.GetPlace());
+    weight_grad->mutable_data<T>(ctx.GetPlace());
+    math::SetConstant<platform::CUDADeviceContext, T> zero;
+    zero(dev_ctx, in_grad, static_cast<T>(0.0));
+    zero(dev_ctx, weight_grad, static_cast<T>(0.0));
+
+    // NOTE: the zero-filled fallback buffers are declared at function scope
+    // so their device memory outlives the cuDNN calls below.
+    Tensor init_h_grad_temp, init_c_grad_temp;
+    Tensor last_h_grad_temp, last_c_grad_temp, out_grad_temp;
+
+    T *init_h_grad_data = NULL;
+    if (init_h_grad == nullptr) {
+      init_h_grad_temp.mutable_data<T>(init_h_dims, ctx.GetPlace());
+      zero(dev_ctx, &init_h_grad_temp, static_cast<T>(0.0));
+      init_h_grad_data = init_h_grad_temp.data<T>();
+    } else {
+      init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
+      zero(dev_ctx, init_h_grad, static_cast<T>(0.0));
+      init_h_grad_data = init_h_grad->data<T>();
+    }
+
+    T *init_c_grad_data = NULL;
+    if (init_c_grad == nullptr) {
+      init_c_grad_temp.mutable_data<T>(init_c_dims, ctx.GetPlace());
+      zero(dev_ctx, &init_c_grad_temp, static_cast<T>(0.0));
+      init_c_grad_data = init_c_grad_temp.data<T>();
+    } else {
+      init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
+      zero(dev_ctx, init_c_grad, static_cast<T>(0.0));
+      init_c_grad_data = init_c_grad->data<T>();
+    }
+
+    const T *last_h_grad_data = NULL;
+    if (last_h_grad == nullptr) {
+      last_h_grad_temp.mutable_data<T>(init_h_dims, ctx.GetPlace());
+      zero(dev_ctx, &last_h_grad_temp, static_cast<T>(0.0));
+      last_h_grad_data = last_h_grad_temp.data<T>();
+    } else {
+      last_h_grad_data = last_h_grad->data<T>();
+    }
+
+    const T *last_c_grad_data = NULL;
+    if (last_c_grad == nullptr) {
+      last_c_grad_temp.mutable_data<T>(init_c_dims, ctx.GetPlace());
+      zero(dev_ctx, &last_c_grad_temp, static_cast<T>(0.0));
+      last_c_grad_data = last_c_grad_temp.data<T>();
+    } else {
+      last_c_grad_data = last_c_grad->data<T>();
+    }
+
+    const T *out_grad_data = NULL;
+    if (out_grad == nullptr) {
+      out_grad_temp.mutable_data<T>(out->dims(), ctx.GetPlace());
+      zero(dev_ctx, &out_grad_temp, static_cast<T>(0.0));
+      out_grad_data = out_grad_temp.data<T>();
+    } else {
+      out_grad_data = out_grad->data<T>();
+    }
+
+    auto out_data = out->data<T>();
+    auto weight_data = weight->data<T>();
+    auto init_h_data = init_h->data<T>();
+    auto init_c_data = init_c->data<T>();
+    auto in_grad_data = in_grad->data<T>();
+
+    auto work_data = cudnn_rnn_cache->workspace_data_.data<uint8_t>();
+    auto reserve_data = cudnn_rnn_cache->reserve_data_.data<uint8_t>();
+
+    auto run_seq_len = input_dims[0];
+    PADDLE_ENFORCE_LE(static_cast<size_t>(run_seq_len),
+                      cudnn_rnn_cache->max_length_,
+                      "the cudnn running seq_len CAN not be greater than "
+                      "max_len");
+    CUDNN_ENFORCE(platform::dynload::cudnnRNNBackwardData(
+        handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
+        cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->dy_desc_,
+        out_grad_data, cudnn_rnn_cache->dhy_desc_, last_h_grad_data,
+        cudnn_rnn_cache->dcy_desc_, last_c_grad_data, cudnn_rnn_cache->w_desc_,
+        weight_data, cudnn_rnn_cache->hx_desc_, init_h_data,
+        cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->dx_desc_,
+        in_grad_data, cudnn_rnn_cache->dhx_desc_, init_h_grad_data,
+        cudnn_rnn_cache->dcx_desc_, init_c_grad_data, work_data,
+        cudnn_rnn_cache->workspace_size_, reserve_data,
+        cudnn_rnn_cache->reserve_size_));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnRNNBackwardWeights(
+        handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
+        cudnn_rnn_cache->x_desc_, input->data<T>(), cudnn_rnn_cache->hx_desc_,
+        init_h->data<T>(), cudnn_rnn_cache->y_desc_, out->data<T>(),
+        cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
+        cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->dw_desc_,
+        weight_grad->data<T>(),
+        cudnn_rnn_cache->reserve_data_.data<uint8_t>(),
+        cudnn_rnn_cache->reserve_size_));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    cudnn_lstm,
+    ops::CudnnLSTMGPUKernel<float>);
+REGISTER_OP_CUDA_KERNEL(
+    cudnn_lstm_grad,
+    ops::CudnnLSTMGPUGradKernel<float>);
diff --git a/paddle/fluid/operators/cudnn_lstm_op.h b/paddle/fluid/operators/cudnn_lstm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..fb4b37e46e004b0f9d09e77f71995df1d18f87d1
--- /dev/null
+++ b/paddle/fluid/operators/cudnn_lstm_op.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/fluid/operators/math/lstm_compute.h"
+#include "paddle/fluid/operators/math/sequence2batch.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+// The CPU kernels are empty placeholders; cudnn_lstm only does real work in
+// the CUDA kernels defined in cudnn_lstm_op.cu.cc.
+template <typename DeviceContext, typename T>
+class CudnnLSTMKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {}
+};
+
+template <typename DeviceContext, typename T>
+class CudnnLSTMGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h
index db62377898339def415a13d185f85f34de326d7f..213cd8a9ce094512cea6f6405492ec8feff11516 100644
--- a/paddle/fluid/platform/dynload/cudnn.h
+++ b/paddle/fluid/platform/dynload/cudnn.h
@@ -111,7 +111,23 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(cudnnFindConvolutionForwardAlgorithmEx);        \
   __macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \
   __macro(cudnnFindConvolutionBackwardDataAlgorithmEx);   \
-  __macro(cudnnGetErrorString);
+  __macro(cudnnGetErrorString);                           \
+  __macro(cudnnCreateDropoutDescriptor);                  \
+  __macro(cudnnDropoutGetStatesSize);                     \
+  __macro(cudnnSetDropoutDescriptor);                     \
+  __macro(cudnnCreateRNNDescriptor);                      \
+  __macro(cudnnSetRNNDescriptor);                         \
+  __macro(cudnnGetRNNParamsSize);                         \
+  __macro(cudnnGetRNNWorkspaceSize);                      \
+  __macro(cudnnGetRNNTrainingReserveSize);                \
+  __macro(cudnnRNNForwardTraining);                       \
+  __macro(cudnnRNNBackwardData);                          \
+  __macro(cudnnRNNBackwardWeights);                       \
+  __macro(cudnnRNNForwardInference);                      \
+  __macro(cudnnDestroyDropoutDescriptor);                 \
+  __macro(cudnnDestroyRNNDescriptor);                     \
+  __macro(cudnnSetRNNDescriptor_v6);
+
 CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 
 #define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
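Reviewer note on the weight layout: cudnnGetRNNParamsSize returns one flat buffer that packs the input-hidden weights, the hidden-hidden weights, and eight hidden_size-sized bias vectors per layer and per direction, and the C++ side enforces with PADDLE_ENFORCE_EQ that the Python-allocated parameter matches that size exactly. A minimal standalone sketch of the same size computation, mirroring the loop in the Python layer below (the helper name lstm_weight_numel is hypothetical; the 4-gate / 8-bias layout is the cuDNN LSTM convention this op assumes):

    def lstm_weight_numel(input_size, hidden_size, num_layers, is_bidirec=False):
        # Expected numel of the packed cuDNN LSTM weight: 4 gates, and
        # separate input (bW) and recurrent (bR) biases, i.e. 8 bias
        # vectors of length hidden_size per direction.
        num_dirs = 2 if is_bidirec else 1
        total = 0
        for layer in range(num_layers):
            # Layer 0 reads the raw input; deeper layers read the (possibly
            # direction-concatenated) output of the layer below.
            in_size = input_size if layer == 0 else hidden_size * num_dirs
            input_weights = 4 * in_size * hidden_size
            hidden_weights = 4 * hidden_size * hidden_size
            biases = 8 * hidden_size
            total += num_dirs * (input_weights + hidden_weights + biases)
        return total

    # e.g. one unidirectional layer: 4*I*H + 4*H*H + 8*H
    assert lstm_weight_numel(100, 150, 1) == 4 * 100 * 150 + 4 * 150 * 150 + 8 * 150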
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 7af1f380e701f867921a16d9f0a91bcfad5e23ea..abb82e750520ae4c6d80b7088f3bd11c59bdd3b6 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -169,6 +169,7 @@ __all__ = [
     'log_loss',
     'add_position_encoding',
     'bilinear_tensor_product',
+    'cudnn_lstm',
 ]
 
 
@@ -466,6 +467,157 @@ def dynamic_lstm(input,
     return hidden, cell
 
 
+def cudnn_lstm(input,
+               init_h,
+               init_c,
+               batch_size,
+               max_len,
+               dropout_prob,
+               input_size,
+               hidden_size,
+               num_layers,
+               is_bidirec=False,
+               dtype='float32',
+               is_test=False,
+               name=None,
+               default_initializer=None,
+               fix_seed=False,
+               seed=0):
+    """
+    CUDNN LSTM implementation
+
+    A four-gate Long Short-Term Memory network with no peephole connections.
+    In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1,
+    the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations:
+
+    it = sigmoid(Wi X xt + Ri X ht-1 + bWi + bRi)
+    ft = sigmoid(Wf X xt + Rf X ht-1 + bWf + bRf)
+    ot = sigmoid(Wo X xt + Ro X ht-1 + bWo + bRo)
+    c't = tanh(Wc X xt + Rc X ht-1 + bWc + bRc)
+    ct = ft * ct-1 + it * c't
+    ht = ot * tanh(ct)
+
+    Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication,
+    X represents a matrix multiplication
+    and tanh is the hyperbolic tangent function. it, ft, ot, c't represent the input, forget, output and new gates respectively.
+
+
+    Args:
+        input (Variable): LSTM input tensor, shape MUST be (seq_len x batch_size x input_size)
+        init_h(Variable): The initial hidden state of the LSTM.
+                       This is a tensor with shape (num_layers x batch_size x hidden_size);
+                       if is_bidirec = True, the shape should be (num_layers*2 x batch_size x hidden_size)
+        init_c(Variable): The initial cell state of the LSTM.
+                       This is a tensor with shape (num_layers x batch_size x hidden_size);
+                       if is_bidirec = True, the shape should be (num_layers*2 x batch_size x hidden_size)
+        batch_size (int): total instance number of the batch
+        max_len (int): max length of the LSTM; the first dim of the input tensor CAN NOT be greater than max_len
+        dropout_prob(float): dropout probability; dropout ONLY works between rnn layers, NOT between time steps.
+                       There is NO dropout applied to the output of the last RNN layer
+        input_size (int): hidden size of the input tensor
+        hidden_size (int): hidden size of the LSTM
+        num_layers (int): total layer number of the LSTM
+        is_bidirec (bool): If it is bidirectional
+        dtype (str): Data type. Choices = ["float32", "float64"], default "float32".
+        is_test (bool): If it is in test phase
+        name (str|None): A name for this layer(optional). If set None, the layer
+                       will be named automatically.
+        default_initializer(Initializer|None): The initializer used to initialize the weight.
+                       If set None, the default initializer will be used
+
+
+    Returns:
+        rnn_out(Tensor): result of LSTM hidden, shape is (seq_len x batch_size x hidden_size);
+                       if is_bidirec set to True, shape will be (seq_len x batch_size x hidden_size*2)
+        last_h(Tensor): the hidden state of the last step of LSTM,
+                       shape is (num_layers x batch_size x hidden_size);
+                       if is_bidirec set to True, shape will be (num_layers*2 x batch_size x hidden_size)
+        last_c(Tensor): the cell state of the last step of LSTM,
+                       shape is (num_layers x batch_size x hidden_size);
+                       if is_bidirec set to True, shape will be (num_layers*2 x batch_size x hidden_size)
+
+
+    Examples:
+        .. code-block:: python
+
+            input = embedding
+            batch_size = 20
+            max_len = 100
+            dropout_prob = 0.2
+            input_size = 100
+            hidden_size = 150
+            num_layers = 1
+            init_h = layers.fill_constant(
+                [num_layers, batch_size, hidden_size], 'float32', 0.0)
+            init_c = layers.fill_constant(
+                [num_layers, batch_size, hidden_size], 'float32', 0.0)
+
+            rnn_out, last_h, last_c = layers.cudnn_lstm(
+                input, init_h, init_c, batch_size, max_len, dropout_prob,
+                input_size, hidden_size, num_layers)
+    """
+
+    helper = LayerHelper('cudnn_lstm', **locals())
+
+    # cudnn concatenates all per-layer, per-direction weights and biases
+    # into one flat parameter, so compute its total size here.
+    weight_size = 0
+    for i in range(num_layers):
+        if i == 0:
+            input_weight_size = (input_size * hidden_size) * 4
+        else:
+            if is_bidirec:
+                input_weight_size = (hidden_size * 2 * hidden_size) * 4
+            else:
+                input_weight_size = (hidden_size * hidden_size) * 4
+
+        hidden_weight_size = (hidden_size * hidden_size) * 4
+
+        if is_bidirec:
+            weight_size += (input_weight_size + hidden_weight_size) * 2
+            weight_size += hidden_size * 8 * 2
+        else:
+            weight_size += input_weight_size + hidden_weight_size
+            weight_size += hidden_size * 8
+
+    weight = helper.create_parameter(
+        attr=helper.param_attr,
+        shape=[weight_size],
+        dtype=dtype,
+        default_initializer=default_initializer)
+
+    out = helper.create_variable_for_type_inference(dtype)
+    last_h = helper.create_variable_for_type_inference(dtype)
+    last_c = helper.create_variable_for_type_inference(dtype)
+
+    cache = helper.create_variable(
+        persistable=True, type=core.VarDesc.VarType.RAW, stop_gradient=True)
+
+    helper.append_op(
+        type='cudnn_lstm',
+        inputs={
+            'Input': input,
+            'InitH': init_h,
+            'InitC': init_c,
+            'W': weight,
+            'Cache': cache,
+        },
+        outputs={
+            'Out': out,
+            'last_h': last_h,
+            'last_c': last_c,
+        },
+        attrs={
+            'max_len': max_len,
+            'is_bidirec': is_bidirec,
+            'input_size': input_size,
+            'batch_size': batch_size,
+            'hidden_size': hidden_size,
+            'num_layers': num_layers,
+            'is_test': is_test,
+            'dropout_prob': dropout_prob,
+            'fix_seed': fix_seed,
+            'seed': seed,
+        })
+    return out, last_h, last_c
+
+
 def dynamic_lstmp(input,
                   size,
                   proj_size,
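To exercise the new layer end to end, here is a minimal run sketch, not part of the patch. It assumes a visible CUDA device (only the CUDA kernel does real work) and reuses the shapes from the docstring example; the random input data is purely illustrative:

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    batch_size, max_len, input_size, hidden_size, num_layers = 20, 100, 100, 150, 1

    # Fixed-shape input: (seq_len x batch_size x input_size), seq_len <= max_len.
    x = layers.data(
        name='x', shape=[max_len, batch_size, input_size], append_batch_size=False)
    init_h = layers.fill_constant(
        [num_layers, batch_size, hidden_size], 'float32', 0.0)
    init_c = layers.fill_constant(
        [num_layers, batch_size, hidden_size], 'float32', 0.0)
    rnn_out, last_h, last_c = layers.cudnn_lstm(
        x, init_h, init_c, batch_size, max_len, 0.0, input_size, hidden_size,
        num_layers)

    exe = fluid.Executor(fluid.CUDAPlace(0))
    exe.run(fluid.default_startup_program())
    x_np = np.random.uniform(
        -0.1, 0.1, (max_len, batch_size, input_size)).astype('float32')
    out_np, = exe.run(feed={'x': x_np}, fetch_list=[rnn_out])
    print(out_np.shape)  # expect (max_len, batch_size, hidden_size)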