From ee1ed42c9928c913c87030e24e9b4399cb93a355 Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Tue, 15 Sep 2020 11:24:02 +0800 Subject: [PATCH] change sequence length attribute to input (#27193) * replace sequence length attr to input --- paddle/fluid/operators/cudnn_lstm_cache.h | 166 +++++++++++++++ paddle/fluid/operators/cudnn_lstm_op.cc | 26 ++- paddle/fluid/operators/cudnn_lstm_op.cu.cc | 195 ++++++++++-------- paddle/fluid/platform/cudnn_helper.h | 170 +-------------- .../tests/unittests/test_lstm_cudnn_op.py | 11 +- .../white_list/check_shape_white_list.py | 1 + 6 files changed, 304 insertions(+), 265 deletions(-) create mode 100644 paddle/fluid/operators/cudnn_lstm_cache.h diff --git a/paddle/fluid/operators/cudnn_lstm_cache.h b/paddle/fluid/operators/cudnn_lstm_cache.h new file mode 100644 index 0000000000..4b46e2b475 --- /dev/null +++ b/paddle/fluid/operators/cudnn_lstm_cache.h @@ -0,0 +1,166 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/dynload/cudnn.h" + +namespace paddle { +namespace operators { + +class ScopedRNNBase { + public: + ScopedRNNBase(int seq_length, int batch_size, int input_size, int hidden_size, + int num_layers, float dropout_prob, int seed, int weight_numel, + bool initialized, bool is_bidirec) + : seq_length_(seq_length), + batch_size_(batch_size), + input_size_(input_size), + hidden_size_(hidden_size), + num_layers_(num_layers), + dropout_prob_(dropout_prob), + seed_(seed), + weight_numel_(weight_numel), + initialized_(initialized), + is_bidirec_(is_bidirec) {} + + template + void Create(const cudnnHandle_t& handle, const platform::Place& place, + const std::vector& sequence_length, size_t* workspace_size, + size_t* reserve_size, framework::Tensor* dropout_state) { + int numDirections = is_bidirec_ ? 2 : 1; + cudnnDataType_t cudnn_type = platform::CudnnDataType::type; + + // ------------------- cudnn x, y descriptors --------------------- + std::vector dims_x = {batch_size_, input_size_, 1}; + std::vector strides_x = {input_size_, 1, 1}; + std::vector dims_y = {batch_size_, hidden_size_ * numDirections, 1}; + std::vector strides_y = {hidden_size_ * numDirections, 1, 1}; + for (int i = 0; i < seq_length_; ++i) { + x_descs_.emplace_back(x_desc_.descriptor(dims_x, strides_x)); + y_descs_.emplace_back(y_desc_.descriptor(dims_y, strides_y)); + } + if (!sequence_length.empty()) { + x_seq_desc_.descriptor(seq_length_, batch_size_, input_size_, true, + sequence_length); + y_seq_desc_.descriptor(seq_length_, batch_size_, + hidden_size_ * numDirections, true, + sequence_length); + } + + // ------------------- cudnn hx, hy, cx, cy descriptors---------- + std::vector dims_hx = {num_layers_ * numDirections, batch_size_, + hidden_size_}; + std::vector strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1}; + init_h_desc_.descriptor(dims_hx, strides_hx); + init_c_desc_.descriptor(dims_hx, strides_hx); + last_h_desc_.descriptor(dims_hx, strides_hx); + last_c_desc_.descriptor(dims_hx, strides_hx); + + // ------------------- cudnn dropout descriptors --------------------- + size_t state_size; + if (!initialized_) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); + dropout_state->mutable_data({static_cast(state_size)}, + place); + } + dropout_desc_.descriptor(handle, place, initialized_, dropout_prob_, + dropout_state, seed_, state_size); + +// ------------------- cudnn rnn descriptors --------------------- +#if CUDNN_VERSION >= 6000 + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6( + handle, rnn_desc_.desc(), hidden_size_, num_layers_, + dropout_desc_.desc(), CUDNN_LINEAR_INPUT, + is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, + CUDNN_RNN_ALGO_STANDARD, cudnn_type)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor( + rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(), + CUDNN_LINEAR_INPUT, + is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, + cudnn_type)); +#endif + if (!sequence_length.empty()) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode( + rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); + } + + // ------------------- cudnn weights_size --------------------- + size_t weights_size_; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( + handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); + PADDLE_ENFORCE_EQ( + weights_size_, sizeof(T) * weight_numel_, + platform::errors::InvalidArgument( + "The cudnn lstm and setting weight size should be same.")); + // ------------------- cudnn weight descriptors --------------------- + platform::DataLayout layout = platform::DataLayout::kNCHW; + int dim_tmp = weights_size_ / sizeof(T); + std::vector dim_w = {dim_tmp, 1, 1}; + weight_desc_.descriptor(layout, dim_w); + // ------------------- cudnn workspace, reserve size --------------------- + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + workspace_size)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnGetRNNTrainingReserveSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + reserve_size)); + } + cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); } + cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); } + cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); } + cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); } + cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); } + cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); } + cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); } + cudnnTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); } + cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); } + cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); } + cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); } + + private: + int seq_length_; + int batch_size_; + int input_size_; + int hidden_size_; + int num_layers_; + float dropout_prob_; + int seed_; + int weight_numel_; + bool initialized_; + bool is_bidirec_; + std::vector x_descs_; + std::vector y_descs_; + + platform::ScopedTensorDescriptor x_desc_; + platform::ScopedTensorDescriptor y_desc_; + platform::ScopedRNNTensorDescriptor x_seq_desc_; + platform::ScopedRNNTensorDescriptor y_seq_desc_; + platform::ScopedTensorDescriptor init_h_desc_; + platform::ScopedTensorDescriptor init_c_desc_; + platform::ScopedTensorDescriptor last_h_desc_; + platform::ScopedTensorDescriptor last_c_desc_; + platform::ScopedDropoutDescriptor dropout_desc_; + platform::ScopedFilterDescriptor weight_desc_; + platform::ScopedRNNDescriptor rnn_desc_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index cc807f193e..82954bc109 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -51,6 +51,16 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { "received InitH's rank is %d.", init_h_dims.size())); + if (ctx->HasInput("SequenceLength")) { + auto seq_dims = ctx->GetInputDim("SequenceLength"); + PADDLE_ENFORCE_EQ( + in_dims[1], seq_dims[0], + platform::errors::InvalidArgument( + "The size of SequenceLength has to equal the batch_size. But " + "received batch_size is %d and the size of SequenceLength is %d.", + in_dims[1], seq_dims[0])); + } + PADDLE_ENFORCE_EQ( in_dims[1], init_h_dims[1], platform::errors::InvalidArgument( @@ -113,6 +123,12 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) the learnable hidden-hidden weights." " The shape is (N), where N is total weight size of the LSTM. " " cudnn concatenate all the weight to one Tensor"); + AddInput("SequenceLength", + "(Tensor) When the input data is padding, " + "set this parameter. This parameter represents " + "the variable sequence lengths in a batch. " + "The size of the vector has to equal the batch_size.") + .AsDispensable(); AddOutput("Reserve", "(Tensor, a temporary output Tensor to store the reserve_data " "of cudnn kernel.") @@ -155,13 +171,6 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(1); AddAttr("is_test", "True if in test phase.").SetDefault(false); AddAttr("seed", "seed to used if fix_seed is True").SetDefault(0); - AddAttr>("sequence_length", - "(vector) When the input data is padding, " - "set this parameter. This parameter represents " - "the variable sequence" - "lengths in a batch. The size of the vector has " - "to equal the batch_size.") - .SetDefault({}); AddComment(R"DOC( CUDNN LSTM implementation @@ -243,6 +252,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("InitH", this->Input("InitH")); op->SetInput("InitC", this->Input("InitC")); op->SetInput("W", this->Input("W")); + if (this->HasInput("SequenceLength")) { + op->SetInput("SequenceLength", this->Input("SequenceLength")); + } op->SetInput("Reserve", this->Output("Reserve")); op->SetInput("StateOut", this->Output("StateOut")); op->SetInput("Out", this->Output("Out")); diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index f60cd41d9a..6457d9295d 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/cudnn_rnn_cache.h" +#include "paddle/fluid/operators/cudnn_lstm_cache.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/cudnn_desc.h" #include "paddle/fluid/platform/cudnn_helper.h" @@ -24,6 +25,43 @@ namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; +template +void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle, + const int &seq_length, ScopedRNNBase *rnn, const T *x_data, + const T *init_h_data, const T *init_c_data, const T *w_data, + T *out_data, T *last_h_data, T *last_c_data, + framework::Tensor *workspace_data, + const size_t &workspace_size) { + if (!has_seq_length) { + // for inference + // This interface is used when the input/output is unpadded. + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference( + handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data, + rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data, + rnn->weight_desc(), w_data, rnn->y_descs(), out_data, + rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data, + workspace_data->data(), workspace_size)); + } else { +#if CUDNN_VERSION >= 7201 + // for inference + // This interface is used when the input/output is padded. + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx( + handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data, rnn->init_h_desc(), + init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(), + w_data, rnn->y_seq_desc(), out_data, rnn->last_h_desc(), last_h_data, + rnn->last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, workspace_data->data(), + workspace_size)); +#else + // CUDNN VERSION has to >=7.2.1 + PADDLE_THROW(platform::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardInferenceEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); +#endif + } +} + template class CudnnLSTMGPUKernel : public framework::OpKernel { public: @@ -56,7 +94,13 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { int num_layers = ctx.Attr("num_layers"); bool is_test = ctx.Attr("is_test"); int seed = ctx.Attr("seed"); - auto sequence_length = ctx.Attr>("sequence_length"); + + bool has_seq_length = ctx.HasInput("SequenceLength"); + std::vector SequenceLength; + if (has_seq_length) { + auto *sequence_length = ctx.Input("SequenceLength"); + SequenceLength = operators::GetDataFromTensor(sequence_length); + } auto &dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); @@ -70,58 +114,32 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { size_t workspace_size; size_t reserve_size; - platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size, - num_layers, dropout_prob, seed, weight_numel, - state_initialized, is_bidirec); - rnn.Create(handle, ctx.GetPlace(), sequence_length, &workspace_size, + ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size, + num_layers, dropout_prob, seed, weight_numel, + state_initialized, is_bidirec); + rnn.Create(handle, ctx.GetPlace(), SequenceLength, &workspace_size, &reserve_size, state_out); framework::Tensor workspace_data_; - workspace_data_.Resize({static_cast(workspace_size)}); - workspace_data_.mutable_data(ctx.GetPlace()); + workspace_data_.mutable_data( + {static_cast(workspace_size)}, ctx.GetPlace()); auto *reserve_data = reserve->mutable_data( {static_cast(reserve_size)}, ctx.GetPlace()); if (is_test) { - if (sequence_length.empty()) { - // for inference - // This interface is used when the input/output is unpadded. - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference( - handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data, - rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, - rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(), - last_h_data, rnn.cy_desc(), last_c_data, - workspace_data_.data(), workspace_size)); - } else { -#if CUDNN_VERSION >= 7201 - // for inference - // This interface is used when the input/output is padded. - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnRNNForwardInferenceEx( - handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(), - init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data, - rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data, - rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace_data_.data(), workspace_size)); -#else - PADDLE_ENFORCE_NOT_NULL( - nullptr, platform::errors::Unavailable( - "The padded input is supported by " - "cudnnRNNForwardInferenceEx, but it only works when " - "the version of cudnn is larger than 7.2.1")); -#endif - } + LSTMInferece(has_seq_length, handle, seq_length, &rnn, x_data, + init_h_data, init_c_data, w_data, out_data, last_h_data, + last_c_data, &workspace_data_, workspace_size); } else { - if (sequence_length.empty()) { + if (!has_seq_length) { // for train // This interface is used when the input/output is unpadded. PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining( - handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data, - rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, - rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(), - last_h_data, rnn.cy_desc(), last_c_data, + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.weight_desc(), w_data, rnn.y_descs(), out_data, + rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, workspace_data_.data(), workspace_size, reserve_data, reserve_size)); } else { @@ -130,19 +148,18 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { // This interface is used when the input/output is padded. PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnRNNForwardTrainingEx( - handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(), - init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data, - rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data, - rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace_data_.data(), workspace_size, reserve_data, - reserve_size)); + handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.weight_desc(), w_data, rnn.y_seq_desc(), out_data, + rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, workspace_data_.data(), workspace_size, + reserve_data, reserve_size)); #else - PADDLE_ENFORCE_NOT_NULL( - nullptr, platform::errors::Unavailable( - "The padded input is supported by " - "cudnnRNNForwardTrainingEx, but it only works when " - "the version of cudnn is larger than 7.2.1")); + PADDLE_THROW(platform::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardTrainingEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); #endif } } @@ -203,7 +220,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { int hidden_size = ctx.Attr("hidden_size"); int num_layers = ctx.Attr("num_layers"); int seed = ctx.Attr("seed"); - auto sequence_length = ctx.Attr>("sequence_length"); + + bool has_seq_length = ctx.HasInput("SequenceLength"); + std::vector SequenceLength; + if (has_seq_length) { + auto *sequence_length = ctx.Input("SequenceLength"); + SequenceLength = operators::GetDataFromTensor(sequence_length); + } int seq_length = input_dims[0]; int batch_size = input->dims()[1]; @@ -213,33 +236,33 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { size_t workspace_size; size_t reserve_size; - platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size, - num_layers, dropout_prob, seed, weight_numel, - true, is_bidirec); + ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size, + num_layers, dropout_prob, seed, weight_numel, true, + is_bidirec); - rnn.Create(handle, ctx.GetPlace(), sequence_length, &workspace_size, + rnn.Create(handle, ctx.GetPlace(), SequenceLength, &workspace_size, &reserve_size, const_cast(state_out)); framework::Tensor workspace_data_; - workspace_data_.Resize({static_cast(workspace_size)}); - workspace_data_.mutable_data(ctx.GetPlace()); + workspace_data_.mutable_data( + {static_cast(workspace_size)}, ctx.GetPlace()); const uint8_t *reserve_data = reserve->data(); - if (sequence_length.empty()) { + if (!has_seq_length) { // This interface is used when the input/output is unpadded. PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData( - handle, rnn.rnn_desc(), seq_length, rnn.y_desc(), out_data, - rnn.y_desc(), out_grad_data, rnn.hy_desc(), last_h_grad_data, - rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data, - rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, rnn.x_desc(), - in_grad_data, rnn.hx_desc(), init_h_grad_data, rnn.cx_desc(), - init_c_grad_data, workspace_data_.data(), workspace_size, - const_cast(reserve_data), reserve_size)); + handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data, + rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data, + rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data, + rnn.init_c_desc(), init_c_grad_data, workspace_data_.data(), + workspace_size, const_cast(reserve_data), reserve_size)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights( - handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), input->data(), - rnn.hx_desc(), init_h->data(), rnn.y_desc(), out->data(), - workspace_data_.data(), workspace_size, rnn.w_desc(), + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data(), + rnn.init_h_desc(), init_h->data(), rnn.y_descs(), out->data(), + workspace_data_.data(), workspace_size, rnn.weight_desc(), weight_grad->data(), const_cast(reserve_data), reserve_size)); } else { @@ -248,27 +271,25 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { // This interface is used when the input/output is padded. PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx( handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(), - out_grad_data, nullptr, nullptr, rnn.hy_desc(), last_h_grad_data, - rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data, - rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, - rnn.x_seq_desc(), in_grad_data, rnn.hx_desc(), init_h_grad_data, - rnn.cx_desc(), init_c_grad_data, nullptr, nullptr, + out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data, + rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data, + rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr, workspace_data_.data(), workspace_size, const_cast(reserve_data), reserve_size)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx( handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data(), - rnn.hx_desc(), init_h->data(), rnn.y_seq_desc(), out->data(), - workspace_data_.data(), workspace_size, rnn.w_desc(), - weight_grad->data(), const_cast(reserve_data), - reserve_size)); + rnn.init_h_desc(), init_h->data(), rnn.y_seq_desc(), + out->data(), workspace_data_.data(), workspace_size, + rnn.weight_desc(), weight_grad->data(), + const_cast(reserve_data), reserve_size)); #else - PADDLE_ENFORCE_NOT_NULL( - nullptr, - platform::errors::Unavailable( - "The padded input of rnn is supported by cudnnRNNBackwardDataEx, " - "cudnnRNNBackwardWeightsEx, but it only works when the version " - "of cudnn is larger than 7.2.1")); + PADDLE_THROW(platform::errors::Unavailable( + "The padded input of rnn is supported by cudnnRNNBackwardDataEx, " + "cudnnRNNBackwardWeightsEx, but it only works when the version " + "of cudnn is larger than 7.2.1")); #endif } } diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index bbe847e719..bb4c2a89f6 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -287,6 +287,8 @@ class ScopedTensorDescriptor { return descriptor(CudnnDataType::type, dim, stride); } + inline cudnnTensorDescriptor_t desc() { return desc_; } + private: cudnnTensorDescriptor_t desc_; DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor); @@ -329,6 +331,8 @@ class ScopedRNNTensorDescriptor { input_size, time_major, seq_length); } + inline cudnnRNNDataDescriptor_t desc() { return desc_; } + private: cudnnRNNDataDescriptor_t desc_; DISABLE_COPY_AND_ASSIGN(ScopedRNNTensorDescriptor); @@ -361,6 +365,7 @@ class ScopedDropoutDescriptor { } return desc_; } + inline cudnnDropoutDescriptor_t desc() { return desc_; } private: cudnnDropoutDescriptor_t desc_; @@ -376,7 +381,7 @@ class ScopedRNNDescriptor { PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyRNNDescriptor(desc_)); } - inline cudnnRNNDescriptor_t descriptor() { return desc_; } + inline cudnnRNNDescriptor_t desc() { return desc_; } private: cudnnRNNDescriptor_t desc_; @@ -419,172 +424,13 @@ class ScopedFilterDescriptor { kernel, groups); } + inline cudnnFilterDescriptor_t desc() { return desc_; } + private: cudnnFilterDescriptor_t desc_; DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor); }; -class ScopedRNNBase { - public: - ScopedRNNBase(int seq_length, int batch_size, int input_size, int hidden_size, - int num_layers, float dropout_prob, int seed, int weight_numel, - bool initialized, bool is_bidirec) - : seq_length_(seq_length), - batch_size_(batch_size), - input_size_(input_size), - hidden_size_(hidden_size), - num_layers_(num_layers), - dropout_prob_(dropout_prob), - seed_(seed), - weight_numel_(weight_numel), - initialized_(initialized), - is_bidirec_(is_bidirec) {} - - template - void Create(const cudnnHandle_t& handle, const platform::Place& place, - std::vector sequence_length, size_t* workspace_size, - size_t* reserve_size, framework::Tensor* dropout_state) { - int numDirections = is_bidirec_ ? 2 : 1; - cudnnDataType_t cudnn_type = platform::CudnnDataType::type; - - // ------------------- cudnn x, y descriptors --------------------- - std::vector dims_x = {batch_size_, input_size_, 1}; - std::vector strides_x = {input_size_, 1, 1}; - - std::vector dims_y = {batch_size_, hidden_size_ * numDirections, 1}; - std::vector strides_y = {hidden_size_ * numDirections, 1, 1}; - - for (int i = 0; i < seq_length_; ++i) { - x_desc_.emplace_back(x_d.descriptor(dims_x, strides_x)); - y_desc_.emplace_back(y_d.descriptor(dims_y, strides_y)); - } - - if (!sequence_length.empty()) { - x_seq_desc_ = x_seq_d.descriptor(seq_length_, batch_size_, input_size_, - true, sequence_length); - y_seq_desc_ = y_seq_d.descriptor(seq_length_, batch_size_, - hidden_size_ * numDirections, true, - sequence_length); - } - - // ------------------- cudnn hx, hy, cx, cy descriptors---------- - std::vector dims_hx = {num_layers_ * numDirections, batch_size_, - hidden_size_}; - std::vector strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1}; - - hx_desc_ = hx_d.descriptor(dims_hx, strides_hx); - cx_desc_ = cx_d.descriptor(dims_hx, strides_hx); - hy_desc_ = hy_d.descriptor(dims_hx, strides_hx); - cy_desc_ = cy_d.descriptor(dims_hx, strides_hx); - - // ------------------- cudnn dropout descriptors --------------------- - size_t state_size; - if (!initialized_) { - PADDLE_ENFORCE_CUDA_SUCCESS( - dynload::cudnnDropoutGetStatesSize(handle, &state_size)); - dropout_state->mutable_data({static_cast(state_size)}, - place); - } - dropout_desc_ = - dropout_d.descriptor(handle, place, initialized_, dropout_prob_, - dropout_state, seed_, state_size); - - // ------------------- cudnn rnn descriptors --------------------- - rnn_desc_ = rnn_d.descriptor(); - -#if CUDNN_VERSION >= 6000 - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6( - handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - CUDNN_LINEAR_INPUT, - is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, - CUDNN_RNN_ALGO_STANDARD, cudnn_type)); -#else - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor( - rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, - is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, - cudnn_type)); -#endif - if (!sequence_length.empty()) { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode( - rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED)); - } - // ------------------- cudnn weights_size --------------------- - size_t weights_size_; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( - handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type)); - - PADDLE_ENFORCE_EQ( - weights_size_, sizeof(T) * weight_numel_, - platform::errors::InvalidArgument( - "The cudnn lstm and setting weight size should be same.")); - - // ------------------- cudnn weight descriptors --------------------- - platform::DataLayout layout = platform::DataLayout::kNCHW; - int dim_tmp = weights_size_ / sizeof(T); - std::vector dim_w = {dim_tmp, 1, 1}; - w_desc_ = w_d.descriptor(layout, dim_w); - - // ------------------- cudnn workspace, reserve size --------------------- - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize( - handle, rnn_desc_, seq_length_, x_desc_.data(), workspace_size)); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnGetRNNTrainingReserveSize( - handle, rnn_desc_, seq_length_, x_desc_.data(), reserve_size)); - } - - cudnnTensorDescriptor_t* x_desc() { return x_desc_.data(); } - cudnnTensorDescriptor_t* y_desc() { return y_desc_.data(); } - cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_; } - cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_; } - cudnnTensorDescriptor_t hx_desc() { return hx_desc_; } - cudnnTensorDescriptor_t cx_desc() { return cx_desc_; } - cudnnTensorDescriptor_t hy_desc() { return hy_desc_; } - cudnnTensorDescriptor_t cy_desc() { return cy_desc_; } - cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_; } - cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_; } - cudnnFilterDescriptor_t w_desc() { return w_desc_; } - - private: - int seq_length_; - int batch_size_; - int input_size_; - int hidden_size_; - int num_layers_; - float dropout_prob_; - int seed_; - int weight_numel_; - bool initialized_; - bool is_bidirec_; - - std::vector x_desc_; - std::vector y_desc_; - cudnnRNNDataDescriptor_t x_seq_desc_; - cudnnRNNDataDescriptor_t y_seq_desc_; - // A tensor descriptor describing the initial hidden state of the RNN. - cudnnTensorDescriptor_t hx_desc_; - // A tensor descriptor describing the initial cell state for LSTM networks. - cudnnTensorDescriptor_t cx_desc_; - // A tensor descriptor describing the final hidden state of the RNN. - cudnnTensorDescriptor_t hy_desc_; - // A tensor descriptor describing the final cell state for LSTM networks. - cudnnTensorDescriptor_t cy_desc_; - cudnnDropoutDescriptor_t dropout_desc_; - cudnnFilterDescriptor_t w_desc_; - cudnnRNNDescriptor_t rnn_desc_; - - ScopedTensorDescriptor x_d; - ScopedTensorDescriptor y_d; - ScopedRNNTensorDescriptor x_seq_d; - ScopedRNNTensorDescriptor y_seq_d; - ScopedTensorDescriptor hx_d; - ScopedTensorDescriptor cx_d; - ScopedTensorDescriptor hy_d; - ScopedTensorDescriptor cy_d; - ScopedDropoutDescriptor dropout_d; - ScopedFilterDescriptor w_d; - ScopedRNNDescriptor rnn_d; -}; - class ScopedConvolutionDescriptor { public: ScopedConvolutionDescriptor() { diff --git a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py index 1f3dab67f2..29a0fa55f7 100644 --- a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py @@ -400,7 +400,8 @@ class TestCUDNNLstmOp(OpTest): 'Input': input, 'W': flat_w, 'InitH': init_h, - 'InitC': init_c + 'InitC': init_c, + 'SequenceLength': self.sequence_length } self.attrs = { 'dropout_prob': 0.0, @@ -408,7 +409,6 @@ class TestCUDNNLstmOp(OpTest): 'input_size': input_size, 'hidden_size': hidden_size, 'num_layers': 1, - 'sequence_length': self.sequence_length.tolist() } self.outputs = { 'Out': output, @@ -436,13 +436,6 @@ class TestCUDNNLstmOp(OpTest): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestCUDNNLstmOp2(TestCUDNNLstmOp): - def set_attrs(self): - self.sequence_length = np.array([], dtype=np.int32) - - -@unittest.skipIf(not core.is_compiled_with_cuda(), - "core is not compiled with CUDA") -class TestCUDNNLstmOp3(TestCUDNNLstmOp): def set_attrs(self): self.num_layers = 2 diff --git a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py index 227e6cc28f..e19641e710 100644 --- a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py @@ -26,4 +26,5 @@ NEED_TO_FIX_OP_LIST = [ 'squared_l2_distance', 'tree_conv', 'cvm', + 'cudnn_lstm', ] -- GitLab