From 6d98ba3074799c21dd577a273d6a33c59d555596 Mon Sep 17 00:00:00 2001 From: phlrain Date: Mon, 3 Dec 2018 15:51:56 +0800 Subject: [PATCH] add cudnn lstm test=release/1.2 --- paddle/fluid/operators/cudnn_lstm_op.cc | 218 ++++++++ paddle/fluid/operators/cudnn_lstm_op.cu.cc | 485 ++++++++++++++++++ .../tests/unittests/test_lstm_cudnn_op.py | 192 +++++++ 3 files changed, 895 insertions(+) create mode 100644 paddle/fluid/operators/cudnn_lstm_op.cc create mode 100644 paddle/fluid/operators/cudnn_lstm_op.cu.cc create mode 100644 python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc new file mode 100644 index 000000000..e63d57be5 --- /dev/null +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -0,0 +1,218 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class CudnnLSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("W"), + "Input(Weight) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("InitH"), + "Input(init_h) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("InitC"), + "Input(init_c) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Cache"), + "Input(Cache) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("last_h"), + "Output(last_h) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("last_c"), + "Output(last_c) of LSTM should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3."); + + ctx->SetOutputDim("Out", ctx->GetInputDim("Input")); + ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH")); + ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC")); + } +}; + +class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "(Tensor) RNN input tensor, which support variable-time length input " + "sequence." + "The shape of the Tensor MUST be ( seq_len * batch_size * input_size)" + "seq_len is the total time step in this mini-batch (CAN be change in " + "different batch)" + "batch_size is the instance number of this batch" + "input_size is the hidden size of the input." + "input_hidden_size and the hidden_size in the next may not be same"); + AddInput("InitH", + "(Tensor) the initial hidden state of the LSTM" + "input. This is a tensor with shape (num_layers x batch_size x " + "hidden_size)" + "and When is_bidirec is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size)"); + AddInput("InitC", + "(Tensor) the initial cell state of the LSTm " + "input. This is a tensor with shape (num_layers x batch_size x " + "hidden_size)" + "and When is_bidirec is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size)"); + AddInput("W", + "(Tensor) the learnable hidden-hidden weights." + " The shape is (N), where N is total weight size of the LSTM. " + " cudnn concatenate all the weight to one Tensor"); + AddInput("Cache", + "The cache of dropout op, a RAW type variable including random " + "number generator states and some descriptors, which is used in " + "cudnn kernel.") + .AsDispensable(); + AddOutput("Out", + "(Tensor) the hidden state of LSTM operator. " + "The shape is ( seq_len x batch_size x hidden_size) if " + "is_bidirec is False" + "and When is_bidirec is True, the shape will be ( seq_len x " + "batch_size x hidden_size * 2) "); + AddOutput("last_h", + "(Tensor) the hidden state of the last step. " + "The shape is ( num_layers x batch_size x hidden_size) if " + "is_bidirec is False" + "and When is_bidirec is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size)"); + AddOutput("last_c", + "(Tensor) the cell state of the last step" + "The shape is ( num_layers x batch_size x hidden_size) if " + "is_bidirec is False" + "and When is_bidirect is True, the shape will be (num_layers*2 x " + "batch_size x hidden_size*2)"); + AddAttr("max_len", + "max length of the LSTM op" + "the first dim of the Input can NOT be greater than max_len") + .SetDefault(20); + AddAttr( + "dropout_prob", + "dropout prob of the dropout op" + "the dropout ONLY work between lstm layers, not between time steps" + "There is no dropout work on the Out tensor") + .SetDefault(0.0); + AddAttr("is_bidirec", + "is_bidirec" + "if it is bidirection rnn" + "The will affect the shape of the Out, last_h, and last_c") + .SetDefault(false); + AddAttr("input_size", "input size ot the Input Tensor").SetDefault(10); + AddAttr("hidden_size", "hidden size of the LSTM").SetDefault(100); + AddAttr("num_layers", "the total layer number of the LSTM") + .SetDefault(1); + AddAttr("is_test", "True if in test phase.").SetDefault(false); + AddAttr("seed", "seed to used if fix_seed is True").SetDefault(-1); + AddComment(R"DOC( +CUDNN LSTM implementation + +A four-gate Long Short-Term Memory network with no peephole connections. +In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1, +the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations: + +$$ i_t = sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$ + +$$ f_t = sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$ + +$$ o_t = sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$ + +$$ \\tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$ + +$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$ + +$$ h_t = o_t \\odot tanh(c_t) $$ + +- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix + of weights from the input gate to the input) +- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector). +- sigmoid is the logistic sigmoid function. +- $i, f, o$ and $c$ are the input gate, forget gate, output gate, + and cell activation vectors, respectively, all of which have the same size as + the cell output activation vector $h$. +- The $\odot$ is the element-wise product of the vectors. +- `tanh` is the activation functions. +- $\tilde{c_t}$ is also called candidate hidden state, + which is computed based on the current input and the previous hidden state. + +Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication, +X represensts a matrix multiplication + + +)DOC"); + } +}; + +class CudnnLSTMGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("last_h"), + "Input(last_h) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("last_c"), + "Input(last_c) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("Cache"), + "Input(last_c) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("InitH"), + "Input(init_h) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("InitC"), + "Input(init_c) of LSTM should not be null."); + + auto SetOutGradDim = [&ctx](const std::string& name) { + auto g_name = framework::GradVarName(name); + if (ctx->HasOutput(g_name)) { + ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); + } + }; + + SetOutGradDim("Input"); + SetOutGradDim("W"); + SetOutGradDim("InitH"); + SetOutGradDim("InitC"); + } +}; + +template +class NotImpleKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW( + "CPU is not support for this kernel now. Will be add in the future"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); + +REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel); +REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel); diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc new file mode 100644 index 000000000..e01070c7b --- /dev/null +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -0,0 +1,485 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +struct CudnnRNNCache { + CudnnRNNCache() { + x_desc_ = NULL; + y_desc_ = NULL; + dx_desc_ = NULL; + dy_desc_ = NULL; + } + ~CudnnRNNCache() { release(); } + + cudnnRNNDescriptor_t rnn_desc_; + cudnnTensorDescriptor_t *x_desc_; + cudnnTensorDescriptor_t *y_desc_; + cudnnTensorDescriptor_t *dx_desc_; + cudnnTensorDescriptor_t *dy_desc_; + + cudnnTensorDescriptor_t hx_desc_; + cudnnTensorDescriptor_t cx_desc_; + cudnnTensorDescriptor_t hy_desc_; + cudnnTensorDescriptor_t cy_desc_; + + cudnnTensorDescriptor_t dhx_desc_; + cudnnTensorDescriptor_t dcx_desc_; + cudnnTensorDescriptor_t dhy_desc_; + cudnnTensorDescriptor_t dcy_desc_; + + cudnnTensorDescriptor_t output_x_desc_; + cudnnTensorDescriptor_t output_y_desc_; + + cudnnDropoutDescriptor_t dropout_desc_; + + size_t weights_size_; + cudnnFilterDescriptor_t w_desc_; + cudnnFilterDescriptor_t dw_desc_; + + size_t workspace_size_; + size_t reserve_size_; + Tensor reserve_data_; + Tensor workspace_data_; + + Tensor dropout_state_; + + size_t max_length_; + + float dropout_prob_; + bool is_bidirec_; + + int batch_size_; + int input_size_; + int hidden_size_; + int num_layers_; + int seed_; + + void init(cudnnHandle_t handle, const framework::ExecutionContext &ctx, + size_t max_len, int batch_size, int input_size, int hidden_size, + int num_layers, float dropout_prob, bool is_bidirec, int seed, + int weight_numel) { + max_length_ = max_len; + batch_size_ = batch_size; + input_size_ = input_size; + hidden_size_ = hidden_size; + num_layers_ = num_layers; + dropout_prob_ = dropout_prob; + is_bidirec_ = is_bidirec; + seed_ = seed; + + x_desc_ = new cudnnTensorDescriptor_t[max_length_]; + y_desc_ = new cudnnTensorDescriptor_t[max_length_]; + dx_desc_ = new cudnnTensorDescriptor_t[max_length_]; + dy_desc_ = new cudnnTensorDescriptor_t[max_length_]; + int dim_a[3]; + int stride_a[3]; + + for (size_t i = 0; i < max_length_; ++i) { + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i])); + dim_a[0] = batch_size_; + dim_a[1] = input_size_; + dim_a[2] = 1; + + stride_a[0] = dim_a[2] * dim_a[1]; + stride_a[1] = dim_a[2]; + stride_a[2] = 1; + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + + dim_a[0] = batch_size_; + dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_; + dim_a[2] = 1; + + stride_a[0] = dim_a[2] * dim_a[1]; + stride_a[1] = dim_a[2]; + stride_a[2] = 1; + + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + } + + dim_a[0] = num_layers_ * (is_bidirec_ ? 2 : 1); + dim_a[1] = batch_size_; + dim_a[2] = hidden_size_; + + stride_a[0] = dim_a[2] * dim_a[1]; + stride_a[1] = dim_a[2]; + stride_a[2] = 1; + + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hy_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cy_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhy_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_)); + + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + + CUDNN_ENFORCE( + platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_)); + + size_t state_size; + CUDNN_ENFORCE( + platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size); + dropout_state_.Resize({static_cast(state_size)})); + auto *dropout_state_data = + dropout_state_.mutable_data(ctx.GetPlace()); + CUDNN_ENFORCE(platform::dynload::cudnnSetDropoutDescriptor( + dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size, + seed_)); + + CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6( + handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, + CUDNN_LINEAR_INPUT, + is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, + CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT)); + + CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_)); + + CUDNN_ENFORCE(platform::dynload::cudnnGetRNNParamsSize( + handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT)); + + PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel, + "cudnn lstm weight size should be SAME"); + int dim_w[3]; + dim_w[0] = weights_size_ / sizeof(float); + dim_w[1] = 1; + dim_w[2] = 1; + CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor( + w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w)); + CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor( + dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w)); + + CUDNN_ENFORCE(platform::dynload::cudnnGetRNNWorkspaceSize( + handle, rnn_desc_, max_length_, x_desc_, &workspace_size_)); + CUDNN_ENFORCE(platform::dynload::cudnnGetRNNTrainingReserveSize( + handle, rnn_desc_, max_length_, x_desc_, &reserve_size_)); + + reserve_data_.Resize({static_cast(reserve_size_)}); + reserve_data_.mutable_data(ctx.GetPlace()); + + workspace_data_.Resize({static_cast(workspace_size_)}); + workspace_data_.mutable_data(ctx.GetPlace()); + } + + void release() { + for (size_t i = 0; i < max_length_; ++i) { + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i])); + } + + delete[] x_desc_; + delete[] y_desc_; + delete[] dx_desc_; + delete[] dy_desc_; + + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hy_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cy_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcx_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhy_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcy_desc_)); + + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyDropoutDescriptor(dropout_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyRNNDescriptor(rnn_desc_)); + + CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(w_desc_)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(dw_desc_)); + } +}; + +template +class CudnnLSTMGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const Tensor *x = ctx.Input("Input"); + const Tensor *init_h = ctx.Input("InitH"); + const Tensor *init_c = ctx.Input("InitC"); + + auto w = ctx.Input("W"); + + Tensor *out = ctx.Output("Out"); + Tensor *last_h = ctx.Output("last_h"); + Tensor *last_c = ctx.Output("last_c"); + + const T *x_data = x->data(); + const T *init_h_data = init_h->data(); + const T *init_c_data = init_c->data(); + + const T *w_data = w->data(); + + T *out_data = out->mutable_data(ctx.GetPlace()); + T *last_h_data = last_h->mutable_data(ctx.GetPlace()); + T *last_c_data = last_c->mutable_data(ctx.GetPlace()); + + size_t max_len = ctx.Attr("max_len"); + float dropout_prob = ctx.Attr("dropout_prob"); + bool is_bidirec = ctx.Attr("is_bidirec"); + int input_size = ctx.Attr("input_size"); + int hidden_size = ctx.Attr("hidden_size"); + int num_layers = ctx.Attr("num_layers"); + bool is_test = ctx.Attr("is_test"); + + auto &dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto *cache_var = ctx.InputVar("Cache"); + if (!cache_var) { + // The RAW type cache variable wouldn't be created and broadcasted on + // multi-devices before the first running. + // use parent scope to make cache persistable + auto *scope = const_cast(ctx.scope().parent()); + auto cache_var_name = ctx.Inputs("Cache")[0]; + cache_var = scope->Var(cache_var_name); + } + CudnnRNNCache *cudnn_rnn_cache = nullptr; + if (cache_var->IsInitialized()) { + cudnn_rnn_cache = const_cast(cache_var) + ->GetMutable(); + } else { + cudnn_rnn_cache = const_cast(cache_var) + ->GetMutable(); + std::random_device rnd; + int seed = ctx.Attr("seed"); + if (seed == -1) { + seed = rnd(); + } + + auto input_w_numel = w->numel(); + auto batch_size = x->dims()[1]; + cudnn_rnn_cache->init(handle, ctx, max_len, batch_size, input_size, + hidden_size, num_layers, dropout_prob, is_bidirec, + seed, input_w_numel); + } + + auto run_seq_len = x->dims()[0]; + + if (is_test) { + // for inference + CUDNN_ENFORCE(platform::dynload::cudnnRNNForwardInference( + handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, + cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_, + init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data, + cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data, + cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_, + last_c_data, cudnn_rnn_cache->workspace_data_.data(), + cudnn_rnn_cache->workspace_size_)); + } else { + // for train + CUDNN_ENFORCE(platform::dynload::cudnnRNNForwardTraining( + handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, + cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_, + init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data, + cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data, + cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_, + last_c_data, cudnn_rnn_cache->workspace_data_.data(), + cudnn_rnn_cache->workspace_size_, + cudnn_rnn_cache->reserve_data_.data(), + cudnn_rnn_cache->reserve_size_)); + } + } +}; + +template +class CudnnLSTMGPUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *input = ctx.Input("Input"); + auto *weight = ctx.Input("W"); + auto *init_h = ctx.Input("InitH"); + auto *init_c = ctx.Input("InitC"); + // auto * last_h = ctx.Input("last_h"); + // auto * last_c = ctx.Input("last_c"); + auto *out = ctx.Input("Out"); + auto *out_grad = ctx.Input(framework::GradVarName("Out")); + auto *last_h_grad = ctx.Input(framework::GradVarName("last_h")); + auto *last_c_grad = ctx.Input(framework::GradVarName("last_c")); + + // auto* init_h = ctx.Input("init_h"); + // auto* init_c = ctx.Input("init_c"); + + auto *in_grad = ctx.Output(framework::GradVarName("Input")); + auto *weight_grad = ctx.Output(framework::GradVarName("W")); + auto *init_h_grad = ctx.Output(framework::GradVarName("InitH")); + auto *init_c_grad = ctx.Output(framework::GradVarName("InitC")); + + auto &dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto *cache_var = ctx.InputVar("Cache"); + PADDLE_ENFORCE(cache_var->IsInitialized()); + CudnnRNNCache *cudnn_rnn_cache = + const_cast(cache_var) + ->GetMutable(); + + auto input_dims = input->dims(); + auto weight_dims = weight->dims(); + auto init_h_dims = init_h->dims(); + auto init_c_dims = init_c->dims(); + in_grad->mutable_data(ctx.GetPlace()); + weight_grad->mutable_data(ctx.GetPlace()); + math::SetConstant zero; + zero(dev_ctx, in_grad, static_cast(0.0)); + zero(dev_ctx, weight_grad, static_cast(0.0)); + + T *init_h_grad_data = NULL; + if (init_h_grad == nullptr) { + Tensor init_h_grad_temp; + init_h_grad_temp.mutable_data(init_h_dims, ctx.GetPlace()); + zero(dev_ctx, &init_h_grad_temp, static_cast(0.0)); + + init_h_grad_data = init_h_grad_temp.data(); + } else { + init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); + zero(dev_ctx, init_h_grad, static_cast(0.0)); + init_h_grad_data = init_h_grad->data(); + } + + T *init_c_grad_data = NULL; + if (init_c_grad == nullptr) { + Tensor init_c_grad_temp; + init_c_grad_temp.mutable_data(init_c_dims, ctx.GetPlace()); + zero(dev_ctx, &init_c_grad_temp, static_cast(0.0)); + + init_c_grad_data = init_c_grad_temp.data(); + } else { + init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); + zero(dev_ctx, init_c_grad, static_cast(0.0)); + init_c_grad_data = init_c_grad->data(); + } + + const T *last_h_grad_data = NULL; + if (last_h_grad == nullptr) { + Tensor last_h_grad_temp; + last_h_grad_temp.mutable_data(init_h_dims, ctx.GetPlace()); + zero(dev_ctx, &last_h_grad_temp, static_cast(0.0)); + + last_h_grad_data = (const T *)last_h_grad_temp.data(); + } else { + last_h_grad_data = last_h_grad->data(); + } + + const T *last_c_grad_data = NULL; + if (last_c_grad == nullptr) { + Tensor last_c_grad_temp; + last_c_grad_temp.mutable_data(init_c_dims, ctx.GetPlace()); + zero(dev_ctx, &last_c_grad_temp, static_cast(0.0)); + + last_c_grad_data = (const T *)last_c_grad_temp.data(); + } else { + last_c_grad_data = last_c_grad->data(); + } + + const T *out_grad_data = NULL; + if (out_grad == nullptr) { + Tensor out_grad_temp; + out_grad_temp.mutable_data(out->dims(), ctx.GetPlace()); + zero(dev_ctx, &out_grad_temp, static_cast(0.0)); + + out_grad_data = (const T *)out_grad_temp.data(); + } else { + out_grad_data = out_grad->data(); + } + + // zero( dev_ctx, last_h_grad, static_cast(0.0)); + // zero( dev_ctx, last_c_grad, static_cast(0.0)); + + auto out_data = out->data(); + // auto out_grad_data = out_grad->data(); + auto weight_data = weight->data(); + auto init_h_data = init_h->data(); + auto init_c_data = init_c->data(); + auto in_grad_data = in_grad->data(); + + auto work_data = cudnn_rnn_cache->workspace_data_.data(); + auto reserve_data = cudnn_rnn_cache->reserve_data_.data(); + + auto run_seq_len = input_dims[0]; + PADDLE_ENFORCE_LE((size_t)run_seq_len, cudnn_rnn_cache->max_length_, + "cudnn running seq_len CAN not greater max_lengh"); + CUDNN_ENFORCE(platform::dynload::cudnnRNNBackwardData( + handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, + cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->dy_desc_, + out_grad_data, cudnn_rnn_cache->dhy_desc_, last_h_grad_data, + cudnn_rnn_cache->dcy_desc_, last_c_grad_data, cudnn_rnn_cache->w_desc_, + weight_data, cudnn_rnn_cache->hx_desc_, init_h_data, + cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->dx_desc_, + in_grad_data, cudnn_rnn_cache->dhx_desc_, init_h_grad_data, + cudnn_rnn_cache->dcx_desc_, init_c_grad_data, work_data, + cudnn_rnn_cache->workspace_size_, reserve_data, + cudnn_rnn_cache->reserve_size_)); + + CUDNN_ENFORCE(platform::dynload::cudnnRNNBackwardWeights( + handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, + cudnn_rnn_cache->x_desc_, input->data(), cudnn_rnn_cache->hx_desc_, + init_h->data(), cudnn_rnn_cache->y_desc_, out->data(), + cudnn_rnn_cache->workspace_data_.data(), + cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->dw_desc_, + weight_grad->data(), cudnn_rnn_cache->reserve_data_.data(), + cudnn_rnn_cache->reserve_size_)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py new file mode 100644 index 000000000..0e9e2e842 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py @@ -0,0 +1,192 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid.core as core +from op_test import OpTest +import paddle.fluid as fluid + +SIGMOID_THRESHOLD_MIN = -40.0 +SIGMOID_THRESHOLD_MAX = 13.0 +EXP_MAX_INPUT = 40.0 + + +def lstm_naive( + input, + w, ): + seq_len, batch_size, hidden_size = input.shape + + offset = 0 + wi = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + wf = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + wc = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + wo = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + ri = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + rf = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + rc = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + ro = w[offset:offset + hidden_size * hidden_size].reshape( + (hidden_size, hidden_size)).transpose() + offset += hidden_size * hidden_size + + bi_1 = w[offset:offset + hidden_size] + offset += hidden_size + bf_1 = w[offset:offset + hidden_size] + offset += hidden_size + bc_1 = w[offset:offset + hidden_size] + offset += hidden_size + bo_1 = w[offset:offset + hidden_size] + offset += hidden_size + + bi_2 = w[offset:offset + hidden_size] + offset += hidden_size + bf_2 = w[offset:offset + hidden_size] + offset += hidden_size + bc_2 = w[offset:offset + hidden_size] + offset += hidden_size + bo_2 = w[offset:offset + hidden_size] + + def sigmoid(x): + y = np.copy(x) + y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN + y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX + return 1. / (1. + np.exp(-y)) + + def tanh(x): + y = -2. * x + y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT + return (2. / (1. + np.exp(y))) - 1. + + output = [] + pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype) + pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype) + + for i in range(seq_len): + emb_1 = input[i] + + input_gate = sigmoid( + np.matmul(emb_1, wi) + np.matmul(pre_h, ri) + bi_1 + bi_2) + forget_gate = sigmoid( + np.matmul(emb_1, wf) + np.matmul(pre_h, rf) + bf_1 + bf_2) + output_gate = sigmoid( + np.matmul(emb_1, wo) + np.matmul(pre_h, ro) + bo_1 + bo_2) + c_t_temp = tanh( + np.matmul(emb_1, wc) + np.matmul(pre_h, rc) + bc_1 + bc_2) + new_c = input_gate * c_t_temp + forget_gate * pre_c + new_h = output_gate * tanh(new_c) + + pre_h = new_h + pre_c = new_c + + output.append(new_h) + + output = np.concatenate(output, -1) + output = output.reshape((batch_size, -1, hidden_size)) + + output = output.transpose((1, 0, 2)) + + return output, pre_h, pre_c + + +class TestCUDNNLstmOp(OpTest): + def setUp(self): + self.op_type = "cudnn_lstm" + self.dtype = np.float32 + + num_steps = 20 + batch_size = 5 + hidden_size = 20 + + input_weight_size = (hidden_size * hidden_size) * 4 + hidden_weight_size = (hidden_size * hidden_size) * 4 + weight_size = input_weight_size + hidden_weight_size + weight_size += hidden_size * 8 + + input = np.random.uniform( + low=-0.1, high=0.1, size=(num_steps, batch_size, + hidden_size)).astype(self.dtype) + flat_w = np.random.uniform( + low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype) + + output, last_hidden, last_cell = lstm_naive(input, flat_w) + + init_h = np.zeros((batch_size, hidden_size), dtype=np.float32) + init_c = np.zeros((batch_size, hidden_size), dtype=np.float32) + scope = core.Scope() + program = fluid.Program() + block = program.global_block() + + cache_temp = block.create_var( + name="Cache", + persistable=True, + type=core.VarDesc.VarType.RAW, + stop_gradient=True) + self.inputs = { + 'Input': OpTest.np_dtype_to_fluid_dtype(input), + 'W': OpTest.np_dtype_to_fluid_dtype(flat_w), + 'InitH': OpTest.np_dtype_to_fluid_dtype(init_h), + 'InitC': OpTest.np_dtype_to_fluid_dtype(init_c), + } + self.cache_name_list = ['Cache'] + self.attrs = { + 'max_len': num_steps, + 'dropout_prob': 0.0, + 'is_bidirec': False, + 'input_size': hidden_size, + 'hidden_size': hidden_size, + 'num_layers': 1, + } + self.outputs = { + 'Out': output, + "last_h": last_hidden, + 'last_c': last_cell + } + + def test_output_with_place(self): + if self.testcuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-5) + + def test_grad_with_place(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + set(['Input', 'W', 'InitH', 'InitC']), + ['Out', 'last_h', 'last_c'], + max_relative_error=0.02) + + def testcuda(self): + return core.is_compiled_with_cuda() + + +if __name__ == '__main__': + unittest.main() -- GitLab