diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index 16e2ca464b5c4de6aa65109cd794d17e4dcd6a2a..7081490fd1bf0e26cb8aa90d69a76a5476cef044 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -24,34 +24,62 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("W"), - "Input(Weight) of LSTM should not be null."); - - PADDLE_ENFORCE(ctx->HasInput("InitH"), - "Input(init_h) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("InitC"), - "Input(init_c) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Cache"), - "Input(Cache) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("last_h"), - "Output(last_h) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("last_c"), - "Output(last_c) of LSTM should not be null."); + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM"); + + OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("StateOut"), "Output", "StateOut", + "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM"); auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3."); + auto init_dims = ctx->GetInputDim("InitH"); + PADDLE_ENFORCE_EQ(in_dims.size(), 3, + platform::errors::InvalidArgument( + "The rank of Input in CudnnLSTM must be 3. But " + "received Input's rank is %d.", + in_dims.size())); + PADDLE_ENFORCE_EQ(init_dims.size(), 3, + platform::errors::InvalidArgument( + "The rank of InitH in CudnnLSTM must be 3. But " + "received InitH's rank is %d.", + init_dims.size())); + + PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1], + platform::errors::InvalidArgument( + "The in_dims[1] (Input dims) and init_dims[1] (InitH " + "dims) should be equal. But " + "received in_dims[1] is %d and init_dims[1] is %d.", + in_dims[1], init_dims[1])); + PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2], + platform::errors::InvalidArgument( + "The in_dims[2] (Input dims) and init_dims[2] (InitH " + "dims) should be equal. But " + "received in_dims[2] is %d and init_dims[2] is %d.", + in_dims[2], init_dims[2])); auto out_dims = in_dims; auto hidden_size = ctx->Attrs().Get("hidden_size"); - out_dims[2] = hidden_size; + bool is_bidirec = ctx->Attrs().Get("is_bidirec"); + out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size; + auto last_dims = init_dims; + last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0]; ctx->SetOutputDim("Out", out_dims); - ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH")); - ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC")); + ctx->SetOutputDim("LastH", last_dims); + ctx->SetOutputDim("LastC", last_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -84,33 +112,31 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) the learnable hidden-hidden weights." " The shape is (N), where N is total weight size of the LSTM. " " cudnn concatenate all the weight to one Tensor"); - AddInput("Cache", - "The cache of dropout op, a RAW type variable including random " - "number generator states and some descriptors, which is used in " - "cudnn kernel.") - .AsDispensable(); + AddOutput("Reserve", + "(Tensor, a temporary output Tensor to store the reserve_data " + "of cudnn kernel.") + .AsIntermediate(); + AddOutput("StateOut", + "Share memory with State. " + "Store the global drop state when training"); AddOutput("Out", "(Tensor) the hidden state of LSTM operator. " "The shape is ( seq_len x batch_size x hidden_size) if " "is_bidirec is False" "and When is_bidirec is True, the shape will be ( seq_len x " "batch_size x hidden_size * 2) "); - AddOutput("last_h", + AddOutput("LastH", "(Tensor) the hidden state of the last step. " "The shape is ( num_layers x batch_size x hidden_size) if " "is_bidirec is False" "and When is_bidirec is True, the shape will be (num_layers*2 x " "batch_size x hidden_size)"); - AddOutput("last_c", + AddOutput("LastC", "(Tensor) the cell state of the last step" "The shape is ( num_layers x batch_size x hidden_size) if " "is_bidirec is False" "and When is_bidirect is True, the shape will be (num_layers*2 x " "batch_size x hidden_size*2)"); - AddAttr("max_len", - "max length of the LSTM op" - "the first dim of the Input can NOT be greater than max_len") - .SetDefault(20); AddAttr( "dropout_prob", "dropout prob of the dropout op" @@ -120,14 +146,14 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("is_bidirec", "is_bidirec" "if it is bidirectional rnn" - "The will affect the shape of the Out, last_h, and last_c") + "The will affect the shape of the Out, LastH, and LastC") .SetDefault(false); AddAttr("input_size", "input size ot the Input Tensor").SetDefault(10); AddAttr("hidden_size", "hidden size of the LSTM").SetDefault(100); AddAttr("num_layers", "the total layer number of the LSTM") .SetDefault(1); AddAttr("is_test", "True if in test phase.").SetDefault(false); - AddAttr("seed", "seed to used if fix_seed is True").SetDefault(-1); + AddAttr("seed", "seed to used if fix_seed is True").SetDefault(0); AddComment(R"DOC( CUDNN LSTM implementation @@ -172,16 +198,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Cache"), - "Input(last_c) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("InitH"), - "Input(init_h) of LSTM should not be null."); - - PADDLE_ENFORCE(ctx->HasInput("InitC"), - "Input(init_c) of LSTM should not be null."); + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad"); + OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad"); + OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad"); + OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad"); auto SetOutGradDim = [&ctx](const std::string& name) { auto g_name = framework::GradVarName(name); @@ -195,6 +215,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { SetOutGradDim("InitH"); SetOutGradDim("InitC"); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } }; template @@ -209,13 +235,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("InitH", this->Input("InitH")); op->SetInput("InitC", this->Input("InitC")); op->SetInput("W", this->Input("W")); - if (this->HasInput("Cache")) { - op->SetInput("Cache", this->Input("Cache")); - } + op->SetInput("Reserve", this->Output("Reserve")); + op->SetInput("StateOut", this->Output("StateOut")); op->SetInput("Out", this->Output("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetInput(framework::GradVarName("last_c"), this->OutputGrad("last_c")); - op->SetInput(framework::GradVarName("last_h"), this->OutputGrad("last_h")); + op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC")); + op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH")); op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); op->SetOutput(framework::GradVarName("W"), this->InputGrad("W")); diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index 579dddee8e82183b778f03595bb4657002262073..37e5e518ea2af9bb437775c8fa7e86816bb1d8ae 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/cudnn_rnn_cache.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cudnn_desc.h" namespace paddle { namespace operators { @@ -33,8 +34,10 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { auto w = ctx.Input("W"); Tensor *out = ctx.Output("Out"); - Tensor *last_h = ctx.Output("last_h"); - Tensor *last_c = ctx.Output("last_c"); + Tensor *last_h = ctx.Output("LastH"); + Tensor *last_c = ctx.Output("LastC"); + Tensor *reserve = ctx.Output("Reserve"); + Tensor *state_out = ctx.Output("StateOut"); const T *x_data = x->data(); const T *init_h_data = init_h->data(); @@ -46,72 +49,56 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { T *last_h_data = last_h->mutable_data(ctx.GetPlace()); T *last_c_data = last_c->mutable_data(ctx.GetPlace()); - size_t max_len = ctx.Attr("max_len"); float dropout_prob = ctx.Attr("dropout_prob"); bool is_bidirec = ctx.Attr("is_bidirec"); - int input_size = ctx.Attr("input_size"); int hidden_size = ctx.Attr("hidden_size"); int num_layers = ctx.Attr("num_layers"); bool is_test = ctx.Attr("is_test"); + int seed = ctx.Attr("seed"); auto &dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - auto *cache_var = ctx.InputVar("Cache"); - if (!cache_var) { - // The RAW type cache variable wouldn't be created and broadcasted on - // multi-devices before the first running. - // use parent scope to make cache persistable - auto *scope = const_cast(ctx.scope().parent()); - auto cache_var_name = ctx.InputNames("Cache")[0]; - cache_var = scope->Var(cache_var_name); - } - CudnnRNNCache *cudnn_rnn_cache = nullptr; - if (cache_var->IsInitialized()) { - // const_cast is usually bad. - cudnn_rnn_cache = const_cast(cache_var) - ->GetMutable(); - } else { - // const_cast is usually bad. - cudnn_rnn_cache = const_cast(cache_var) - ->GetMutable(); - std::random_device rnd; - int seed = ctx.Attr("seed"); - if (seed == -1) { - seed = rnd(); - } - - auto input_w_numel = w->numel(); - auto batch_size = x->dims()[1]; - cudnn_rnn_cache->init(handle, ctx.GetPlace(), max_len, batch_size, - input_size, hidden_size, num_layers, dropout_prob, - is_bidirec, seed, input_w_numel); - } - auto run_seq_len = x->dims()[0]; + CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache(); + + auto input_w_numel = w->numel(); + auto seq_len = x->dims()[0]; + auto batch_size = x->dims()[1]; + auto input_dim = x->dims()[2]; + size_t reserve_size; + bool state_initialized = state_out->IsInitialized() ? true : false; + cudnnDataType_t cudnn_type = platform::ToCudnnDataType( + framework::ToDataType(std::type_index(typeid(T)))); + cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size, + input_dim, hidden_size, num_layers, dropout_prob, + is_bidirec, seed, input_w_numel, &reserve_size, + state_out, state_initialized, cudnn_type); + + auto *reserve_data = reserve->mutable_data( + {static_cast(reserve_size)}, ctx.GetPlace()); if (is_test) { // for inference PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_, - init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data, - cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data, - cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_, - last_c_data, cudnn_rnn_cache->workspace_data_.data(), + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_, + x_data, cudnn_rnn_cache->hx_desc_, init_h_data, + cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_, + w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_, + last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data, + cudnn_rnn_cache->workspace_data_.data(), cudnn_rnn_cache->workspace_size_)); } else { // for train PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_, - init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data, - cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data, - cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_, - last_c_data, cudnn_rnn_cache->workspace_data_.data(), - cudnn_rnn_cache->workspace_size_, - cudnn_rnn_cache->reserve_data_.data(), - cudnn_rnn_cache->reserve_size_)); + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_, + x_data, cudnn_rnn_cache->hx_desc_, init_h_data, + cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_, + w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_, + last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data, + cudnn_rnn_cache->workspace_data_.data(), + cudnn_rnn_cache->workspace_size_, reserve_data, reserve_size)); } + delete cudnn_rnn_cache; } }; @@ -123,15 +110,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto *weight = ctx.Input("W"); auto *init_h = ctx.Input("InitH"); auto *init_c = ctx.Input("InitC"); - // auto * last_h = ctx.Input("last_h"); - // auto * last_c = ctx.Input("last_c"); + auto *reserve = ctx.Input("Reserve"); + auto *state_out = ctx.Input("StateOut"); + auto *out = ctx.Input("Out"); auto *out_grad = ctx.Input(framework::GradVarName("Out")); - auto *last_h_grad = ctx.Input(framework::GradVarName("last_h")); - auto *last_c_grad = ctx.Input(framework::GradVarName("last_c")); - - // auto* init_h = ctx.Input("init_h"); - // auto* init_c = ctx.Input("init_c"); + auto *last_h_grad = ctx.Input(framework::GradVarName("LastH")); + auto *last_c_grad = ctx.Input(framework::GradVarName("LastC")); auto *in_grad = ctx.Output(framework::GradVarName("Input")); auto *weight_grad = ctx.Output(framework::GradVarName("W")); @@ -140,116 +125,75 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto &dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - auto *cache_var = ctx.InputVar("Cache"); - PADDLE_ENFORCE(cache_var->IsInitialized()); - CudnnRNNCache *cudnn_rnn_cache = - const_cast(cache_var) - ->GetMutable(); auto input_dims = input->dims(); auto init_h_dims = init_h->dims(); auto init_c_dims = init_c->dims(); - in_grad->mutable_data(ctx.GetPlace()); - weight_grad->mutable_data(ctx.GetPlace()); - math::SetConstant zero; - zero(dev_ctx, in_grad, static_cast(0.0)); - zero(dev_ctx, weight_grad, static_cast(0.0)); - - T *init_h_grad_data = NULL; - if (init_h_grad == nullptr) { - Tensor init_h_grad_temp; - init_h_grad_temp.mutable_data(init_h_dims, ctx.GetPlace()); - zero(dev_ctx, &init_h_grad_temp, static_cast(0.0)); - - init_h_grad_data = init_h_grad_temp.data(); - } else { - init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); - zero(dev_ctx, init_h_grad, static_cast(0.0)); - init_h_grad_data = init_h_grad->data(); - } - - T *init_c_grad_data = NULL; - if (init_c_grad == nullptr) { - Tensor init_c_grad_temp; - init_c_grad_temp.mutable_data(init_c_dims, ctx.GetPlace()); - zero(dev_ctx, &init_c_grad_temp, static_cast(0.0)); - init_c_grad_data = init_c_grad_temp.data(); - } else { - init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); - zero(dev_ctx, init_c_grad, static_cast(0.0)); - init_c_grad_data = init_c_grad->data(); - } + auto *weight_data = weight->data(); + auto *init_h_data = init_h->data(); + auto *init_c_data = init_c->data(); + auto *out_data = out->data(); + auto *out_grad_data = out_grad->data(); + auto *last_h_grad_data = last_h_grad->data(); + auto *last_c_grad_data = last_c_grad->data(); - const T *last_h_grad_data = NULL; - if (last_h_grad == nullptr) { - Tensor last_h_grad_temp; - last_h_grad_temp.mutable_data(init_h_dims, ctx.GetPlace()); - zero(dev_ctx, &last_h_grad_temp, static_cast(0.0)); - - last_h_grad_data = (const T *)last_h_grad_temp.data(); - } else { - last_h_grad_data = last_h_grad->data(); - } - - const T *last_c_grad_data = NULL; - if (last_c_grad == nullptr) { - Tensor last_c_grad_temp; - last_c_grad_temp.mutable_data(init_c_dims, ctx.GetPlace()); - zero(dev_ctx, &last_c_grad_temp, static_cast(0.0)); - - last_c_grad_data = (const T *)last_c_grad_temp.data(); - } else { - last_c_grad_data = last_c_grad->data(); - } + math::SetConstant zero; + weight_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, weight_grad, static_cast(0.0)); - const T *out_grad_data = NULL; - if (out_grad == nullptr) { - Tensor out_grad_temp; - out_grad_temp.mutable_data(out->dims(), ctx.GetPlace()); - zero(dev_ctx, &out_grad_temp, static_cast(0.0)); + in_grad->mutable_data(input_dims, ctx.GetPlace()); + auto *in_grad_data = in_grad->data(); - out_grad_data = (const T *)out_grad_temp.data(); - } else { - out_grad_data = out_grad->data(); - } + init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); + auto *init_h_grad_data = init_h_grad->data(); - // zero( dev_ctx, last_h_grad, static_cast(0.0)); - // zero( dev_ctx, last_c_grad, static_cast(0.0)); + init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); + auto *init_c_grad_data = init_c_grad->data(); - auto out_data = out->data(); - // auto out_grad_data = out_grad->data(); - auto weight_data = weight->data(); - auto init_h_data = init_h->data(); - auto init_c_data = init_c->data(); - auto in_grad_data = in_grad->data(); + float dropout_prob = ctx.Attr("dropout_prob"); + bool is_bidirec = ctx.Attr("is_bidirec"); + int hidden_size = ctx.Attr("hidden_size"); + int num_layers = ctx.Attr("num_layers"); + int seed = ctx.Attr("seed"); + + CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache(); + + auto input_w_numel = weight->numel(); + auto seq_len = input_dims[0]; + auto batch_size = input->dims()[1]; + auto input_dim = input->dims()[2]; + size_t reserve_size; + cudnnDataType_t cudnn_type = platform::ToCudnnDataType( + framework::ToDataType(std::type_index(typeid(T)))); + cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size, + input_dim, hidden_size, num_layers, dropout_prob, + is_bidirec, seed, input_w_numel, &reserve_size, + const_cast(state_out), true, cudnn_type); auto work_data = cudnn_rnn_cache->workspace_data_.data(); - auto reserve_data = cudnn_rnn_cache->reserve_data_.data(); + const uint8_t *reserve_data = reserve->data(); - auto run_seq_len = input_dims[0]; - PADDLE_ENFORCE_LE((size_t)run_seq_len, cudnn_rnn_cache->max_length_, - "cudnn running seq_len CAN not greater max_lengh"); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->dy_desc_, - out_grad_data, cudnn_rnn_cache->dhy_desc_, last_h_grad_data, - cudnn_rnn_cache->dcy_desc_, last_c_grad_data, cudnn_rnn_cache->w_desc_, - weight_data, cudnn_rnn_cache->hx_desc_, init_h_data, - cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->dx_desc_, - in_grad_data, cudnn_rnn_cache->dhx_desc_, init_h_grad_data, - cudnn_rnn_cache->dcx_desc_, init_c_grad_data, work_data, - cudnn_rnn_cache->workspace_size_, reserve_data, - cudnn_rnn_cache->reserve_size_)); + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->y_desc_, + out_data, cudnn_rnn_cache->y_desc_, out_grad_data, + cudnn_rnn_cache->hy_desc_, last_h_grad_data, cudnn_rnn_cache->cy_desc_, + last_c_grad_data, cudnn_rnn_cache->w_desc_, weight_data, + cudnn_rnn_cache->hx_desc_, init_h_data, cudnn_rnn_cache->cx_desc_, + init_c_data, cudnn_rnn_cache->x_desc_, in_grad_data, + cudnn_rnn_cache->hx_desc_, init_h_grad_data, cudnn_rnn_cache->cx_desc_, + init_c_grad_data, work_data, cudnn_rnn_cache->workspace_size_, + const_cast(reserve_data), reserve_size)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->x_desc_, input->data(), cudnn_rnn_cache->hx_desc_, - init_h->data(), cudnn_rnn_cache->y_desc_, out->data(), + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_, + input->data(), cudnn_rnn_cache->hx_desc_, init_h->data(), + cudnn_rnn_cache->y_desc_, out->data(), cudnn_rnn_cache->workspace_data_.data(), - cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->dw_desc_, - weight_grad->data(), cudnn_rnn_cache->reserve_data_.data(), - cudnn_rnn_cache->reserve_size_)); + cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->w_desc_, + weight_grad->data(), const_cast(reserve_data), + reserve_size)); + delete cudnn_rnn_cache; } }; @@ -257,5 +201,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel); -REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel, + ops::CudnnLSTMGPUKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel, + ops::CudnnLSTMGPUGradKernel); diff --git a/paddle/fluid/operators/cudnn_rnn_cache.h b/paddle/fluid/operators/cudnn_rnn_cache.h index cd33338abc6223a0ae122cbb60f040562b48a761..13a3e7d09b9f628f31bb9ff3b6137acf6d929c5c 100644 --- a/paddle/fluid/operators/cudnn_rnn_cache.h +++ b/paddle/fluid/operators/cudnn_rnn_cache.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/cudnn_helper.h" @@ -24,16 +25,12 @@ struct CudnnRNNCache { CudnnRNNCache() { x_desc_ = NULL; y_desc_ = NULL; - dx_desc_ = NULL; - dy_desc_ = NULL; } ~CudnnRNNCache() { release(); } cudnnRNNDescriptor_t rnn_desc_; cudnnTensorDescriptor_t *x_desc_; cudnnTensorDescriptor_t *y_desc_; - cudnnTensorDescriptor_t *dx_desc_; - cudnnTensorDescriptor_t *dy_desc_; cudnnTensorDescriptor_t hx_desc_; cudnnTensorDescriptor_t cx_desc_; @@ -55,13 +52,9 @@ struct CudnnRNNCache { cudnnFilterDescriptor_t dw_desc_; size_t workspace_size_; - size_t reserve_size_; - framework::Tensor reserve_data_; framework::Tensor workspace_data_; - framework::Tensor dropout_state_; - - size_t max_length_; + size_t seq_length_; float dropout_prob_; bool is_bidirec_; @@ -72,10 +65,12 @@ struct CudnnRNNCache { int num_layers_; int seed_; - void init(cudnnHandle_t handle, const platform::Place &place, size_t max_len, + void init(cudnnHandle_t handle, const platform::Place &place, size_t seq_len, int batch_size, int input_size, int hidden_size, int num_layers, - float dropout_prob, bool is_bidirec, int seed, int weight_numel) { - max_length_ = max_len; + float dropout_prob, bool is_bidirec, int seed, int weight_numel, + size_t *reserve_size_, framework::Tensor *dropout_state_, + bool initialized, cudnnDataType_t cudnn_type) { + seq_length_ = seq_len; batch_size_ = batch_size; input_size_ = input_size; hidden_size_ = hidden_size; @@ -84,55 +79,34 @@ struct CudnnRNNCache { is_bidirec_ = is_bidirec; seed_ = seed; - x_desc_ = new cudnnTensorDescriptor_t[max_length_]; - y_desc_ = new cudnnTensorDescriptor_t[max_length_]; - dx_desc_ = new cudnnTensorDescriptor_t[max_length_]; - dy_desc_ = new cudnnTensorDescriptor_t[max_length_]; - int dim_a[3]; - int stride_a[3]; + const auto numDirections = is_bidirec_ ? 2 : 1; + auto cudnn_size = + cudnn_type == CUDNN_DATA_FLOAT ? sizeof(float) : sizeof(double); + + x_desc_ = new cudnnTensorDescriptor_t[seq_length_]; + y_desc_ = new cudnnTensorDescriptor_t[seq_length_]; + std::vector dims = {batch_size_, input_size_, 1}; + std::vector strides = {input_size_, 1, 1}; + + std::vector dims_y = {batch_size_, hidden_size_ * numDirections, 1}; + std::vector strides_y = {hidden_size_ * numDirections, 1, 1}; - for (size_t i = 0; i < max_length_; ++i) { + for (size_t i = 0; i < seq_length_; ++i) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i])); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i])); - dim_a[0] = batch_size_; - dim_a[1] = input_size_; - dim_a[2] = 1; - - stride_a[0] = dim_a[2] * dim_a[1]; - stride_a[1] = dim_a[2]; - stride_a[2] = 1; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); - - dim_a[0] = batch_size_; - dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_; - dim_a[2] = 1; - - stride_a[0] = dim_a[2] * dim_a[1]; - stride_a[1] = dim_a[2]; - stride_a[2] = 1; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + x_desc_[i], cudnn_type, 3, dims.data(), strides.data())); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + y_desc_[i], cudnn_type, 3, dims_y.data(), strides_y.data())); } - dim_a[0] = num_layers_ * (is_bidirec_ ? 2 : 1); - dim_a[1] = batch_size_; - dim_a[2] = hidden_size_; - - stride_a[0] = dim_a[2] * dim_a[1]; - stride_a[1] = dim_a[2]; - stride_a[2] = 1; + std::vector dims_hx = {num_layers_ * numDirections, batch_size_, + hidden_size_}; + std::vector strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1}; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_)); @@ -152,33 +126,44 @@ struct CudnnRNNCache { platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + hx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + cx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + hy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + cy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dhx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dcx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dhy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dcy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_)); size_t state_size; - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); - dropout_state_.Resize({static_cast(state_size)}); - auto *dropout_state_data = dropout_state_.mutable_data(place); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetDropoutDescriptor( - dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size, - seed_)); + if (!initialized) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); + dropout_state_->Resize({static_cast(state_size)}); + uint8_t *dropout_state_data = + dropout_state_->mutable_data(place); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetDropoutDescriptor( + dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size, + seed_)); + } else { + uint8_t *dropout_state_data = dropout_state_->data(); + auto dropout_state_dims = dropout_state_->dims(); + state_size = dropout_state_dims[0]; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnRestoreDropoutDescriptor( + dropout_desc_, handle, dropout_prob_, dropout_state_data, + state_size, 0)); + } PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_)); @@ -188,12 +173,12 @@ struct CudnnRNNCache { handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, - CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT)); + CUDNN_RNN_ALGO_STANDARD, cudnn_type)); #else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor( rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, - CUDNN_DATA_FLOAT)); + cudnn_type)); #endif PADDLE_ENFORCE_CUDA_SUCCESS( @@ -202,48 +187,42 @@ struct CudnnRNNCache { platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( - handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT)); + handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type)); + + PADDLE_ENFORCE_EQ( + weights_size_, cudnn_size * weight_numel, + platform::errors::InvalidArgument( + "The cudnn lstm and setting weight size should be same.")); - PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel, - "cudnn lstm weight size should be SAME"); int dim_w[3]; - dim_w[0] = weights_size_ / sizeof(float); + dim_w[0] = weights_size_ / cudnn_size; dim_w[1] = 1; dim_w[2] = 1; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor( - w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w)); + w_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor( - dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w)); + dw_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize( - handle, rnn_desc_, max_length_, x_desc_, &workspace_size_)); + handle, rnn_desc_, seq_length_, x_desc_, &workspace_size_)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnGetRNNTrainingReserveSize( - handle, rnn_desc_, max_length_, x_desc_, &reserve_size_)); - - reserve_data_.Resize({static_cast(reserve_size_)}); - reserve_data_.mutable_data(place); + handle, rnn_desc_, seq_length_, x_desc_, reserve_size_)); workspace_data_.Resize({static_cast(workspace_size_)}); workspace_data_.mutable_data(place); } void release() { - for (size_t i = 0; i < max_length_; ++i) { + for (size_t i = 0; i < seq_length_; ++i) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i])); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i])); } delete[] x_desc_; delete[] y_desc_; - delete[] dx_desc_; - delete[] dy_desc_; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_)); diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 0eb28f0c0c3561f98891ff2a0ab5a26a20b07fb4..ebeb14e940e5fd904e506bca565c4aeae84c93cf 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -100,6 +100,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(cudnnCreateDropoutDescriptor); \ __macro(cudnnDropoutGetStatesSize); \ __macro(cudnnSetDropoutDescriptor); \ + __macro(cudnnRestoreDropoutDescriptor); \ __macro(cudnnCreateRNNDescriptor); \ __macro(cudnnGetRNNParamsSize); \ __macro(cudnnGetRNNWorkspaceSize); \ diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index 4ec0770aaf0a559b9619e74f7c9178b2fd61062f..bc1368b562d7b354ce34dc87679fd8a0c5a3d012 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -2213,9 +2213,9 @@ def lstm(input, input ( :ref:`api_guide_Variable_en` ): LSTM input tensor, 3-D Tensor of shape :math:`[batch\_size, seq\_len, input\_dim]` . Data type is float32 or float64 init_h( :ref:`api_guide_Variable_en` ): The initial hidden state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` . If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64. + max_len (int): This parameter has no effect and will be discarded. init_c( :ref:`api_guide_Variable_en` ): The initial cell state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` . If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64. - max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len. hidden_size (int): hidden size of the LSTM. num_layers (int): total layers number of the LSTM. dropout_prob(float, optional): dropout prob, dropout ONLY work between rnn layers, NOT between time steps @@ -2256,7 +2256,6 @@ def lstm(input, data = fluid.data(name='x', shape=[None, 100], dtype='int64') emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True) batch_size = 20 - max_len = 100 dropout_prob = 0.2 input_size = 100 hidden_size = 150 @@ -2309,9 +2308,11 @@ def lstm(input, out = helper.create_variable_for_type_inference(dtype) last_h = helper.create_variable_for_type_inference(dtype) last_c = helper.create_variable_for_type_inference(dtype) - - cache = helper.create_variable( - persistable=True, type=core.VarDesc.VarType.RAW, stop_gradient=True) + reserve = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + state_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + state_out.persistable = True helper.append_op( type='cudnn_lstm', @@ -2320,15 +2321,15 @@ def lstm(input, 'InitH': init_h, 'InitC': init_c, 'W': weight, - 'Cache': cache, }, outputs={ 'Out': out, - 'last_h': last_h, - 'last_c': last_c, + 'LastH': last_h, + 'LastC': last_c, + 'Reserve': reserve, + 'StateOut': state_out, }, attrs={ - 'max_len': max_len, 'is_bidirec': is_bidirec, 'input_size': input_size, 'hidden_size': hidden_size, diff --git a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py index d4189eca0369702120f079dae9067a58da1e9597..90430bbce4d1896c8fdbb829230f2ad8a691adff 100644 --- a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py @@ -20,15 +20,14 @@ import numpy as np import paddle.fluid.core as core from op_test import OpTest import paddle.fluid as fluid +import paddle.fluid.layers as layers SIGMOID_THRESHOLD_MIN = -40.0 SIGMOID_THRESHOLD_MAX = 13.0 EXP_MAX_INPUT = 40.0 -def lstm_naive( - input, - w, ): +def lstm_naive(input, w): seq_len, batch_size, hidden_size = input.shape offset = 0 @@ -86,8 +85,8 @@ def lstm_naive( return (2. / (1. + np.exp(y))) - 1. output = [] - pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype) - pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype) + pre_h = np.zeros((1, batch_size, hidden_size), dtype=input.dtype) + pre_c = np.zeros((1, batch_size, hidden_size), dtype=input.dtype) for i in range(seq_len): emb_1 = input[i] @@ -110,7 +109,6 @@ def lstm_naive( output = np.concatenate(output, -1) output = output.reshape((batch_size, -1, hidden_size)) - output = output.transpose((1, 0, 2)) return output, pre_h, pre_c @@ -119,11 +117,12 @@ def lstm_naive( @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestCUDNNLstmOp(OpTest): + # TODO(GaoWei8):when input dtype is fp64, precision threshold should be removed. def setUp(self): self.op_type = "cudnn_lstm" - self.dtype = np.float32 + self.dtype = np.float64 - num_steps = 20 + seq_length = 20 batch_size = 5 hidden_size = 20 @@ -133,33 +132,24 @@ class TestCUDNNLstmOp(OpTest): weight_size += hidden_size * 8 input = np.random.uniform( - low=-0.1, high=0.1, size=(num_steps, batch_size, + low=-0.1, high=0.1, size=(seq_length, batch_size, hidden_size)).astype(self.dtype) flat_w = np.random.uniform( low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype) output, last_hidden, last_cell = lstm_naive(input, flat_w) - init_h = np.zeros((batch_size, hidden_size), dtype=np.float32) - init_c = np.zeros((batch_size, hidden_size), dtype=np.float32) - scope = core.Scope() - program = fluid.Program() - block = program.global_block() - - cache_temp = block.create_var( - name="Cache", - persistable=True, - type=core.VarDesc.VarType.RAW, - stop_gradient=True) + init_h = np.zeros((1, batch_size, hidden_size), dtype=np.float64) + init_c = np.zeros((1, batch_size, hidden_size), dtype=np.float64) + state_out = np.ndarray((300)).astype("uint8") + self.inputs = { - 'Input': OpTest.np_dtype_to_fluid_dtype(input), - 'W': OpTest.np_dtype_to_fluid_dtype(flat_w), - 'InitH': OpTest.np_dtype_to_fluid_dtype(init_h), - 'InitC': OpTest.np_dtype_to_fluid_dtype(init_c), + 'Input': input, + 'W': flat_w, + 'InitH': init_h, + 'InitC': init_c } - self.cache_name_list = ['Cache'] self.attrs = { - 'max_len': num_steps, 'dropout_prob': 0.0, 'is_bidirec': False, 'input_size': hidden_size, @@ -168,22 +158,61 @@ class TestCUDNNLstmOp(OpTest): } self.outputs = { 'Out': output, - "last_h": last_hidden, - 'last_c': last_cell + "LastH": last_hidden, + 'LastC': last_cell, + 'Reserve': np.ndarray((400)).astype("uint8"), + 'StateOut': state_out } def test_output_with_place(self): # depend on the scope structure place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_dygraph=False) + self.check_output_with_place( + place, no_check_set=['Reserve', 'StateOut']) def test_grad_with_place(self): # depend on the scope structure place = core.CUDAPlace(0) self.check_grad_with_place( place, - set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'last_h', 'last_c'], - check_dygraph=False) + set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'LastH', 'LastC'], + max_relative_error=1e-4) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNlstmAPI(unittest.TestCase): + def test_lstm(self): + seq_len = 20 + batch_size = 5 + hidden_size = 20 + dropout_prob = 0.0 + num_layers = 1 + input = fluid.data( + name='input', + shape=[seq_len, batch_size, hidden_size], + dtype='float64') + init_h = layers.fill_constant([num_layers, batch_size, hidden_size], + 'float64', 0.0) + init_c = layers.fill_constant([num_layers, batch_size, hidden_size], + 'float64', 0.0) + rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len, + hidden_size, num_layers, + dropout_prob) + exe = fluid.Executor(fluid.CUDAPlace(0)) + exe.run(fluid.default_startup_program()) + input_i = np.random.uniform( + low=-0.1, high=0.1, size=(seq_len, batch_size, + hidden_size)).astype("float64") + out = exe.run(fluid.default_main_program(), + feed={'input': input_i}, + fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0']) + + output, last_hidden, last_cell = lstm_naive(input_i, out[3]) + + self.assertTrue(np.allclose(output, out[0], atol=1e-5)) + self.assertTrue(np.allclose(last_hidden, out[1], atol=1e-5)) + self.assertTrue(np.allclose(last_cell, out[2], atol=1e-5)) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index b8258f3153a801dfc78db5f43325c0dce5c4b611..0de0eeb464ad700abb2144e49a822582b8653589 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -26,4 +26,5 @@ no_check_set_white_list = [ 'cross_entropy2', 'seed', 'amp_check_finite_and_scale', + 'cudnn_lstm', ] diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index ce6868b5c70ae1218df48f899f936f57f6734582..5300ab935a3405f9f76c08a7f2ece8bad33367ac 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -41,7 +41,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'unpool', \ 'yolov3_loss', \ 'inverse', \ - 'bilateral_slice' + 'bilateral_slice',\ + 'cudnn_lstm' ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp']