diff --git a/paddle/fluid/operators/rnn_op.cc b/paddle/fluid/operators/rnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..dfdd32e10b9a99eec5f37133715daad1c4d82896
--- /dev/null
+++ b/paddle/fluid/operators/rnn_op.cc
@@ -0,0 +1,255 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class RNNOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "RNN");
+    OP_INOUT_CHECK(ctx->HasInputs("PreState"), "Input", "PreState", "RNN");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "RNN");
+    OP_INOUT_CHECK(ctx->HasOutputs("State"), "Output", "State", "RNN");
+
+    auto in_dims = ctx->GetInputDim("Input");
+    auto pre_state_dims = ctx->GetInputsDim("PreState");
+
+    PADDLE_ENFORCE_EQ(in_dims.size(), 3,
+                      platform::errors::InvalidArgument(
+                          "The rank of Input in RNN must be 3. But "
+                          "received Input's rank is %d.",
+                          in_dims.size()));
+
+    if (ctx->HasInput("SequenceLength")) {
+      auto seq_dims = ctx->GetInputDim("SequenceLength");
+      PADDLE_ENFORCE_EQ(
+          in_dims[1], seq_dims[0],
+          platform::errors::InvalidArgument(
+              "The size of SequenceLength has to equal the batch_size. But "
+              "received batch_size is %d and the size of SequenceLength is %d.",
+              in_dims[1], seq_dims[0]));
+    }
+
+    PADDLE_ENFORCE_EQ(pre_state_dims[0].size(), 3,
+                      platform::errors::InvalidArgument(
+                          "The rank of PreState in RNN must be 3. But "
+                          "the received rank is %d.",
+                          pre_state_dims[0].size()));
+    size_t i = 0;
+    for (; i < pre_state_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(
+          in_dims[1], pre_state_dims[i][1],
+          platform::errors::InvalidArgument(
+              "The second dimension size (representing the batch size) of "
+              "Input and PreState should be equal. But received %d and %d.",
+              in_dims[1], pre_state_dims[i][1]));
+      PADDLE_ENFORCE_EQ(
+          pre_state_dims[0], pre_state_dims[i],
+          platform::errors::InvalidArgument(
+              "The dims of all tensors in PreState should be the same. But "
+              "received PreState[0] is %s and PreState[%d] is %s.",
+              pre_state_dims[0], i, pre_state_dims[i]));
+    }
+    auto mode = ctx->Attrs().Get<std::string>("mode");
+    size_t num_state = mode == "LSTM" ? 2 : 1;
+    PADDLE_ENFORCE_EQ(
+        i, num_state,
+        platform::errors::InvalidArgument(
+            "The number of tensors in PreState of %s should be %d, "
+            "but received %d.",
+            mode, num_state, i));
+
+    auto out_dims = in_dims;
+    auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
+    bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
+    out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
+    ctx->SetOutputDim("Out", out_dims);
+    ctx->SetOutputsDim("State", pre_state_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
+  }
+};
+
+class RNNOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput(
+        "Input",
+        "(Tensor) RNN input tensor, which supports variable-length input "
+        "sequences. "
+        "The shape of the Tensor MUST be (seq_len x batch_size x input_size). "
+        "seq_len is the total time step of this mini-batch (it CAN change "
+        "across batches). "
+        "batch_size is the instance number of this batch. "
+        "input_size is the hidden size of the input. "
+        "input_size and the hidden_size of the next layer may differ.");
+    AddInput("PreState",
+             "(Tensor) the initial hidden states of the RNN. "
+             "Each tensor has shape (num_layers x batch_size x hidden_size); "
+             "when is_bidirec is True, the shape will be (num_layers*2 x "
+             "batch_size x hidden_size).")
+        .AsDuplicable();
+    AddInput("WeightList",
+             "(vector<Tensor>) stores weight and bias data when the weights "
+             "use the list format. ")
+        .AsDuplicable();
+    AddInput("SequenceLength",
+             "(Tensor) When the input data is padded, "
+             "set this parameter. It represents "
+             "the variable sequence lengths in a batch. "
+             "The size of the vector has to equal the batch_size.")
+        .AsDispensable();
+    AddOutput("DropoutState",
+              "Stores the global dropout state when training, needed by the "
+              "cudnn rnn kernel.")
+        .AsDispensable();
+    // maybe intermediate outputs need to be added for the cpu kernel
+    AddOutput("Reserve",
+              "(Tensor) a temporary output Tensor to store the reserve_data "
+              "of the cudnn kernel.")
+        .AsIntermediate();
+    AddOutput("Out",
+              "(Tensor) the output (hidden state sequence) of the RNN "
+              "operator. "
+              "The shape is (seq_len x batch_size x hidden_size) if "
+              "is_bidirec is False; when is_bidirec is True, the shape will "
+              "be (seq_len x batch_size x hidden_size * 2).");
+    AddOutput("State",
+              "(Tensor) the hidden state of the last step. "
+              "The shape is (num_layers x batch_size x hidden_size) if "
+              "is_bidirec is False; when is_bidirec is True, the shape will "
+              "be (num_layers*2 x batch_size x hidden_size).")
+        .AsDuplicable();
+    AddAttr<float>(
+        "dropout_prob",
+        "dropout probability of the dropout op. "
+        "Dropout ONLY works between rnn layers, not between time steps. "
+        "No dropout is applied to the Out tensor.")
+        .SetDefault(0.0);
+    AddAttr<bool>("is_bidirec", "whether it is a bidirectional rnn")
+        .SetDefault(false);
+    AddAttr<int>("input_size", "input size of the Input Tensor").SetDefault(10);
+    AddAttr<int>("hidden_size", "hidden size of rnn").SetDefault(100);
+    AddAttr<int>("num_layers", "the total layer number").SetDefault(1);
+    AddAttr<std::string>(
+        "mode",
+        "(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH.");
+    AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
+    AddAttr<int>("seed", "seed to be used if fix_seed is True").SetDefault(0);
+    AddComment(R"DOC(
+RNN operator, which supports LSTM, GRU, RNN_RELU and RNN_TANH cells selected
+by the `mode` attribute, backed by the cuDNN RNN kernels on CUDA devices.
+)DOC");
+  }
+};
+
+class RNNGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "RNN");
+    OP_INOUT_CHECK(ctx->HasInputs("PreState"), "Input", "PreState", "RNN");
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "RNN");
+    // OP_INOUT_CHECK(ctx->HasInputs("State"), "Input", "State", "RNN");
+
+    auto SetOutGradDim = [&ctx](const std::string& name) {
+      auto g_name = framework::GradVarName(name);
+      if (ctx->HasOutput(g_name)) {
+        ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
+      }
+    };
+
+    SetOutGradDim("Input");
+    if (ctx->HasOutputs(framework::GradVarName("WeightList"))) {
+      ctx->SetOutputsDim(framework::GradVarName("WeightList"),
+                         ctx->GetInputsDim("WeightList"));
+    }
+    if (ctx->HasOutputs(framework::GradVarName("PreState"))) {
+      ctx->SetOutputsDim(framework::GradVarName("PreState"),
+                         ctx->GetInputsDim("PreState"));
+    }
+  }
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
+
+template <typename T>
+class RNNGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("rnn_grad");
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("PreState", this->Input("PreState"));
+    op->SetInput("WeightList", this->Input("WeightList"));
+    if (this->HasInput("SequenceLength")) {
+      op->SetInput("SequenceLength", this->Input("SequenceLength"));
+    }
+    op->SetInput("DropoutState", this->Output("DropoutState"));
+    op->SetInput("Reserve", this->Output("Reserve"));
+    op->SetInput("Out", this->Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput(framework::GradVarName("State"), this->OutputGrad("State"));
+
+    op->SetOutput(framework::GradVarName("WeightList"),
+                  this->InputGrad("WeightList", false));
+
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("PreState"),
+                  this->InputGrad("PreState", false));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+template <typename T>
+class NotImpleKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "CPU is not supported for this kernel now; it will be added in the "
+        "future."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
+                  ops::RNNGradOpMaker<paddle::framework::OpDesc>,
+                  ops::RNNGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp);
+
+REGISTER_OP_CPU_KERNEL(rnn, ops::NotImpleKernel<float>);
+REGISTER_OP_CPU_KERNEL(rnn_grad, ops::NotImpleKernel<float>);
diff --git a/paddle/fluid/operators/rnn_op.cu.cc b/paddle/fluid/operators/rnn_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..568db79722324fe6a5ca446d6ed56ea76c74a82c
--- /dev/null
+++ b/paddle/fluid/operators/rnn_op.cu.cc
@@ -0,0 +1,630 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cudnn.h>
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/utils.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/dynload/cudnn.h"
+
+namespace paddle {
+namespace platform {
+class CUDADeviceContext;
+struct CUDAPlace;
+}  // namespace platform
+}  // namespace paddle
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class RNNDescriptors {
+ public:
+  RNNDescriptors(int seq_length, int batch_size, int input_size,
+                 int hidden_size, int num_layers, float dropout_prob, int seed,
+                 int weight_numel, cudnnRNNMode_t mode, bool is_bidirec,
+                 bool is_test)
+      : seq_length_(seq_length),
+        batch_size_(batch_size),
+        input_size_(input_size),
+        hidden_size_(hidden_size),
+        num_layers_(num_layers),
+        dropout_prob_(dropout_prob),
+        seed_(seed),
+        weight_numel_(weight_numel),
+        mode_(mode),
+        is_bidirec_(is_bidirec),
+        is_test_(is_test) {}
+
+  template <typename T>
+  void Create(const cudnnHandle_t &handle, const platform::Place &place,
+              const std::vector<int> &sequence_length, size_t *workspace_size,
+              size_t *reserve_size, framework::Tensor *dropout_state) {
+    int numDirections = is_bidirec_ ?
2 : 1; + cudnnDataType_t cudnn_type = platform::CudnnDataType::type; + + // ------------------- cudnn x, y descriptors --------------------- + std::vector dims_x = {batch_size_, input_size_, 1}; + std::vector strides_x = {input_size_, 1, 1}; + std::vector dims_y = {batch_size_, hidden_size_ * numDirections, 1}; + std::vector strides_y = {hidden_size_ * numDirections, 1, 1}; + for (int i = 0; i < seq_length_; ++i) { + x_descs_.emplace_back(x_desc_.descriptor(dims_x, strides_x)); + y_descs_.emplace_back(y_desc_.descriptor(dims_y, strides_y)); + } + +#if CUDNN_VERSION >= 7201 + if (!sequence_length.empty()) { + x_seq_desc_.descriptor(seq_length_, batch_size_, input_size_, true, + sequence_length); + y_seq_desc_.descriptor(seq_length_, batch_size_, + hidden_size_ * numDirections, true, + sequence_length); + } +#endif + + // ------------------- cudnn hx, hy, cx, cy descriptors---------- + std::vector dims_hx = {num_layers_ * numDirections, batch_size_, + hidden_size_}; + std::vector strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1}; + init_h_desc_.descriptor(dims_hx, strides_hx); + init_c_desc_.descriptor(dims_hx, strides_hx); + last_h_desc_.descriptor(dims_hx, strides_hx); + last_c_desc_.descriptor(dims_hx, strides_hx); + + // ------------------- cudnn dropout descriptors --------------------- + size_t state_size; + if (!is_test_ && !dropout_state->IsInitialized()) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); + dropout_state->mutable_data({static_cast(state_size)}, + place); + } + dropout_desc_.descriptor(handle, place, dropout_state->IsInitialized(), + dropout_prob_, is_test_ ? nullptr : dropout_state, + seed_, state_size); + +// ------------------- cudnn rnn descriptors --------------------- +#if CUDNN_VERSION >= 6000 + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6( + handle, rnn_desc_.desc(), hidden_size_, num_layers_, + dropout_desc_.desc(), CUDNN_LINEAR_INPUT, + is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, mode_, + CUDNN_RNN_ALGO_STANDARD, cudnn_type)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor( + rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(), + CUDNN_LINEAR_INPUT, + is_bidirec_ ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, mode_, + cudnn_type)); +#endif + +#if CUDNN_VERSION >= 7201 + if (!sequence_length.empty()) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode( + rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); + } +#endif + + // ------------------- cudnn weights_size --------------------- + size_t weights_size_; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( + handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); + PADDLE_ENFORCE_EQ( + weights_size_, sizeof(T) * weight_numel_, + platform::errors::InvalidArgument( + "The cudnn rnn and setting weight size should be same.")); + // ------------------- cudnn weight descriptors --------------------- + platform::DataLayout layout = platform::DataLayout::kNCHW; + int dim_tmp = weights_size_ / sizeof(T); + std::vector dim_w = {dim_tmp, 1, 1}; + weight_desc_.descriptor(layout, dim_w); + // ------------------- cudnn workspace, reserve size --------------------- + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + workspace_size)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnGetRNNTrainingReserveSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), + reserve_size)); + } + cudnnTensorDescriptor_t *x_descs() { return x_descs_.data(); } + cudnnTensorDescriptor_t *y_descs() { return y_descs_.data(); } +#if CUDNN_VERSION >= 7201 + cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); } + cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); } +#endif + cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); } + cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); } + cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); } + cudnnTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); } + cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); } + cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); } + cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); } + + private: + int seq_length_; + int batch_size_; + int input_size_; + int hidden_size_; + int num_layers_; + float dropout_prob_; + int seed_; + int weight_numel_; + cudnnRNNMode_t mode_; + bool is_bidirec_; + bool is_test_; + std::vector x_descs_; + std::vector y_descs_; + + platform::ScopedTensorDescriptor x_desc_; + platform::ScopedTensorDescriptor y_desc_; +#if CUDNN_VERSION >= 7201 + platform::ScopedRNNTensorDescriptor x_seq_desc_; + platform::ScopedRNNTensorDescriptor y_seq_desc_; +#endif + platform::ScopedTensorDescriptor init_h_desc_; + platform::ScopedTensorDescriptor init_c_desc_; + platform::ScopedTensorDescriptor last_h_desc_; + platform::ScopedTensorDescriptor last_c_desc_; + platform::ScopedDropoutDescriptor dropout_desc_; + platform::ScopedFilterDescriptor weight_desc_; + platform::ScopedRNNDescriptor rnn_desc_; +}; + +template +bool is_continuous(const Type &weight_list) { + bool continuous = true; + for (size_t i = 0; i < weight_list.size() - 1; ++i) { + auto *in_data = weight_list[i]->template data(); + auto *in_after_data = weight_list[i + 1]->template data(); + auto in_size = weight_list[i]->numel(); + bool temp = in_data + in_size == in_after_data; + continuous = continuous && temp; + } + return continuous; +} + +template +void weight_to_tensor(const platform::Place &place, cudaStream_t stream, + const std::vector &weight_list, + Tensor *weight) { + auto weight_data = 
weight->data(); + int weight_offset = 0; + for (size_t i = 0; i < weight_list.size(); ++i) { + const T *in_data = weight_list[i]->data(); + auto in_size = weight_list[i]->numel(); + + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight->place()), + weight_data + weight_offset, + BOOST_GET_CONST(platform::CUDAPlace, weight_list[i]->place()), + in_data, in_size * sizeof(T), stream); + weight_offset += in_size; + } +} + +template +void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream, + std::vector *weight_grad, + const std::vector &weight_input, + const Tensor *weight) { + int weight_offset = 0; + auto *weight_data = weight->data(); + for (size_t i = 0; i < weight_input.size(); ++i) { + auto in_size = weight_input[i]->numel(); + T *weight_grad_data = (*weight_grad)[i]->mutable_data(place); + const T *src = weight_data + weight_offset; + + memory::Copy( + BOOST_GET_CONST(platform::CUDAPlace, (*weight_grad)[i]->place()), + weight_grad_data, BOOST_GET_CONST(platform::CUDAPlace, weight->place()), + src, in_size * sizeof(T), stream); + weight_offset += in_size; + } +} + +template +class RNNCudnnKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const Tensor *x = ctx.Input("Input"); + auto pre_state = ctx.MultiInput("PreState"); + + Tensor *out = ctx.Output("Out"); + auto state = ctx.MultiOutput("State"); + Tensor *reserve = ctx.Output("Reserve"); + Tensor *state_out = ctx.Output("DropoutState"); + + float dropout_prob = ctx.Attr("dropout_prob"); + bool is_bidirec = ctx.Attr("is_bidirec"); + int hidden_size = ctx.Attr("hidden_size"); + int num_layers = ctx.Attr("num_layers"); + auto mode = ctx.Attr("mode"); + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + if (mode == "LSTM") + rnn_mode = CUDNN_LSTM; + else if (mode == "GRU") + rnn_mode = CUDNN_GRU; + else if (mode == "RNN_RELU") + rnn_mode = CUDNN_RNN_RELU; + else if (mode == "RNN_TANH") + rnn_mode = CUDNN_RNN_TANH; + else + PADDLE_THROW(platform::errors::InvalidArgument( + "rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: " + "%s.", + mode)); + + bool is_test = ctx.Attr("is_test"); + int seed = ctx.Attr("seed"); + if (!is_test) { + int device_id = + BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + if (gen_cuda->GetIsInitPy() && seed == 0) { + // If perform `manual_seed` in python and inner seed is not specified + // (equals 0), use global generator generated seed. 
+ seed = static_cast(gen_cuda->Random64()); + } else if (seed == 0) { + // use random generated seed + std::random_device rd; + seed = rd(); + } // else use `ctx.Attr("seed")` specified seed + } + + const T *x_data = x->data(); + const T *init_h_data = pre_state[0]->data(); + const T *init_c_data = nullptr; + T *out_data = out->mutable_data(ctx.GetPlace()); + T *last_h_data = state[0]->mutable_data(ctx.GetPlace()); + T *last_c_data = nullptr; + if (rnn_mode == CUDNN_LSTM) { + init_c_data = pre_state[1]->data(); + last_c_data = state[1]->mutable_data(ctx.GetPlace()); + } + + bool has_seq_length = ctx.HasInput("SequenceLength"); + std::vector SequenceLength; + if (has_seq_length) { + auto *sequence_length = ctx.Input("SequenceLength"); + SequenceLength = operators::GetDataFromTensor(sequence_length); + } + + auto &dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + + int seq_length = x->dims()[0]; + int batch_size = x->dims()[1]; + int input_size = x->dims()[2]; + + size_t workspace_size; + size_t reserve_size; + Tensor weight_whole; + T *w_data = nullptr; + auto place = ctx.GetPlace(); + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + auto weight_list = ctx.MultiInput("WeightList"); + auto weight_numel = std::accumulate( + weight_list.begin(), weight_list.end(), 0, + [](int64_t num, const Tensor *t) { return num + t->numel(); }); + bool continuous = + is_continuous>(weight_list); + if (!continuous) { + LOG_FIRST_N(WARNING, 2) + << "If the memory space of the Input WeightList is not continuous, " + "less efficient calculation will be called. Please call " + "flatten_parameters() to make the input memory continuous."; + weight_whole.mutable_data({weight_numel}, place); + weight_to_tensor(place, stream, weight_list, &weight_whole); + w_data = weight_whole.data(); + if (is_test) { // maybe also reset small weights' ptr for training + int offset = 0; + for (size_t i = 0; i < weight_list.size(); ++i) { + size_t len = weight_list[i]->numel(); + auto dim = weight_list[i]->dims(); + const_cast(weight_list[i]) + ->ShareDataWith( + weight_whole.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } + } + } else { + w_data = const_cast(weight_list[0]->data()); + } + + RNNDescriptors rnn(seq_length, batch_size, input_size, hidden_size, + num_layers, dropout_prob, seed, weight_numel, rnn_mode, + is_bidirec, is_test); + rnn.Create(handle, ctx.GetPlace(), SequenceLength, &workspace_size, + &reserve_size, state_out); + + framework::Tensor workspace_data_; + workspace_data_.mutable_data( + {static_cast(workspace_size)}, ctx.GetPlace()); + + auto *reserve_data = reserve->mutable_data( + {static_cast(reserve_size)}, ctx.GetPlace()); + + if (is_test) { + RNNInferece(has_seq_length, handle, seq_length, &rnn, x_data, init_h_data, + init_c_data, w_data, out_data, last_h_data, last_c_data, + &workspace_data_, workspace_size); + } else { + if (!has_seq_length) { + // for train + // This interface is used when the input/output is unpadded. 
+ PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining( + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.weight_desc(), w_data, rnn.y_descs(), out_data, + rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, + workspace_data_.data(), workspace_size, reserve_data, + reserve_size)); + } else { +#if CUDNN_VERSION >= 7201 + // for train + // This interface is used when the input/output is padded. + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnRNNForwardTrainingEx( + handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.weight_desc(), w_data, rnn.y_seq_desc(), out_data, + rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, workspace_data_.data(), workspace_size, + reserve_data, reserve_size)); +#else + PADDLE_THROW(platform::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardTrainingEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); +#endif + } + } + } + + void RNNInferece(const bool &has_seq_length, const cudnnHandle_t &handle, + const int &seq_length, RNNDescriptors *rnn, const T *x_data, + const T *init_h_data, const T *init_c_data, const T *w_data, + T *out_data, T *last_h_data, T *last_c_data, + framework::Tensor *workspace_data, + const size_t &workspace_size) const { + if (!has_seq_length) { + // for inference + // This interface is used when the input/output is unpadded. + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference( + handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data, + rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data, + rnn->weight_desc(), w_data, rnn->y_descs(), out_data, + rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data, + workspace_data->data(), workspace_size)); + } else { +#if CUDNN_VERSION >= 7201 + // for inference + // This interface is used when the input/output is padded. 
+ PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx( + handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data, + rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data, + rnn->weight_desc(), w_data, rnn->y_seq_desc(), out_data, + rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, workspace_data->data(), workspace_size)); +#else + // CUDNN VERSION has to >=7.2.1 + PADDLE_THROW(platform::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardInferenceEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); +#endif + } + } +}; + +template +class RNNGradCudnnKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *input = ctx.Input("Input"); + auto pre_state = ctx.MultiInput("PreState"); + auto weight_list = ctx.MultiInput("WeightList"); + auto *state_out = ctx.Input("DropoutState"); + auto *reserve = ctx.Input("Reserve"); + auto *out = ctx.Input("Out"); + // auto state = ctx.MultiInput("State"); + + auto *out_grad = ctx.Input(framework::GradVarName("Out")); + auto state_grad = ctx.MultiInput(framework::GradVarName("State")); + + auto *in_grad = ctx.Output(framework::GradVarName("Input")); + auto pre_state_grad = + ctx.MultiOutput(framework::GradVarName("PreState")); + auto weight_grad_list = + ctx.MultiOutput(framework::GradVarName("WeightList")); + + float dropout_prob = ctx.Attr("dropout_prob"); + bool is_bidirec = ctx.Attr("is_bidirec"); + int hidden_size = ctx.Attr("hidden_size"); + int num_layers = ctx.Attr("num_layers"); + auto mode = ctx.Attr("mode"); + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + if (mode == "LSTM") + rnn_mode = CUDNN_LSTM; + else if (mode == "GRU") + rnn_mode = CUDNN_GRU; + else if (mode == "RNN_RELU") + rnn_mode = CUDNN_RNN_RELU; + else if (mode == "RNN_TANH") + rnn_mode = CUDNN_RNN_TANH; + else + PADDLE_THROW(platform::errors::InvalidArgument( + "rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: " + "%s.", + mode)); + bool is_test = ctx.Attr("is_test"); + int seed = ctx.Attr("seed"); + + auto &dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + + auto place = ctx.GetPlace(); + auto weight_numel = std::accumulate( + weight_list.begin(), weight_list.end(), 0, + [](int64_t num, const Tensor *t) { return num + t->numel(); }); + bool continuous = + is_continuous>(weight_list); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + Tensor weight_whole; + T *weight_data = nullptr; + + if (!continuous) { + weight_whole.mutable_data({weight_numel}, place); + weight_to_tensor(place, stream, weight_list, &weight_whole); + weight_data = weight_whole.data(); + } else { + weight_data = const_cast(weight_list[0]->data()); + } + + Tensor weight_grad; + math::SetConstant zero; + weight_grad.mutable_data({weight_numel}, ctx.GetPlace()); + zero(dev_ctx, &weight_grad, static_cast(0.0)); + T *weight_grad_data = weight_grad.data(); + + int offset = 0; + for (size_t i = 0; i < weight_grad_list.size(); ++i) { + size_t len = weight_grad_list[i]->numel(); + auto dim = weight_grad_list[i]->dims(); + weight_grad_list[i] + ->ShareDataWith(weight_grad.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } + + auto *init_h_data = pre_state[0]->data(); + // auto *last_h_data = state[0]->data(); + auto *last_h_grad_data = state_grad[0]->data(); + const T *init_c_data 
= nullptr; + // const T *last_c_data = nullptr; + const T *last_c_grad_data = nullptr; + T *init_h_grad_data = + pre_state_grad.size() != 0 && pre_state_grad[0] + ? pre_state_grad[0]->mutable_data(ctx.GetPlace()) + : nullptr; + T *init_c_grad_data = nullptr; + if (rnn_mode == CUDNN_LSTM) { + init_c_data = pre_state[1]->data(); + // last_c_data = state[1]->data(); + last_c_grad_data = state_grad[1]->data(); + init_c_grad_data = + pre_state_grad.size() != 0 && pre_state_grad[1] + ? pre_state_grad[1]->mutable_data(ctx.GetPlace()) + : nullptr; + } + auto *out_data = out->data(); + auto *out_grad_data = out_grad->data(); + // maybe need check exist + auto *in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + + bool has_seq_length = ctx.HasInput("SequenceLength"); + std::vector SequenceLength; + if (has_seq_length) { + auto *sequence_length = ctx.Input("SequenceLength"); + SequenceLength = operators::GetDataFromTensor(sequence_length); + } + + auto input_dims = input->dims(); + int seq_length = input_dims[0]; + int batch_size = input_dims[1]; + int input_size = input_dims[2]; + + size_t workspace_size; + size_t reserve_size; + + RNNDescriptors rnn(seq_length, batch_size, input_size, hidden_size, + num_layers, dropout_prob, seed, weight_numel, rnn_mode, + is_bidirec, is_test); + + rnn.Create(handle, ctx.GetPlace(), SequenceLength, &workspace_size, + &reserve_size, const_cast(state_out)); + + framework::Tensor workspace_data_; + workspace_data_.mutable_data( + {static_cast(workspace_size)}, ctx.GetPlace()); + const uint8_t *reserve_data = reserve->data(); + + if (!has_seq_length) { + // This interface is used when the input/output is unpadded. + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData( + handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data, + rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data, + rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data, + rnn.init_c_desc(), init_c_grad_data, workspace_data_.data(), + workspace_size, const_cast(reserve_data), reserve_size)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights( + handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data(), + rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data(), + workspace_data_.data(), workspace_size, rnn.weight_desc(), + weight_grad_data, const_cast(reserve_data), reserve_size)); + } else { +#if CUDNN_VERSION >= 7201 + // for train + // This interface is used when the input/output is padded. 
+ PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx( + handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(), + out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data, + rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data, + rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data, + rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data, + rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr, + workspace_data_.data(), workspace_size, + const_cast(reserve_data), reserve_size)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx( + handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data(), + rnn.init_h_desc(), init_h_data, rnn.y_seq_desc(), out->data(), + workspace_data_.data(), workspace_size, rnn.weight_desc(), + weight_grad_data, const_cast(reserve_data), reserve_size)); +#else + PADDLE_THROW(platform::errors::Unavailable( + "The padded input of rnn is supported by cudnnRNNBackwardDataEx, " + "cudnnRNNBackwardWeightsEx, but it only works when the version " + "of cudnn is larger than 7.2.1")); +#endif + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(rnn, ops::RNNCudnnKernel, + ops::RNNCudnnKernel); +REGISTER_OP_CUDA_KERNEL(rnn_grad, ops::RNNGradCudnnKernel, + ops::RNNGradCudnnKernel); diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index e983e36895353c215af19937980946a33c242b8c..e591852cc9580901afbd773f0c03c16f9d9167c9 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -361,6 +361,12 @@ class ScopedDropoutDescriptor { float dropout_prob_, framework::Tensor* dropout_state_, int seed, size_t state_size) { + if (dropout_state_ == nullptr) { // for no dropout or test + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetDropoutDescriptor( + desc_, handle, 0 /* dropout */, nullptr, 0 /* state_size */, + 0 /* seed */)); + return desc_; + } auto* dropout_state_data = dropout_state_->data(); if (!initialized) { PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetDropoutDescriptor( diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py index 2eec265b5d27ae4e1e49e1eaa165bf90ef154686..87bdee8a91d21bc5fb0344f578b3ff595767be76 100644 --- a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py @@ -93,10 +93,14 @@ class TestSimpleRNN(unittest.TestCase): np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + def test_predict(self): + predict_test_util(self.place, "SimpleRNN") + def runTest(self): self.test_with_initial_state() self.test_with_zero_state() self.test_with_input_lengths() + self.test_predict() class TestGRU(unittest.TestCase): @@ -175,10 +179,14 @@ class TestGRU(unittest.TestCase): np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + def test_predict(self): + predict_test_util(self.place, "GRU") + def runTest(self): self.test_with_initial_state() self.test_with_zero_state() self.test_with_input_lengths() + self.test_predict() class TestLSTM(unittest.TestCase): @@ -258,61 +266,7 @@ class TestLSTM(unittest.TestCase): np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) def test_predict(self): - place = paddle.set_device(self.place) - 
paddle.seed(123) - np.random.seed(123) - - class Net(paddle.nn.Layer): - def __init__(self): - super(Net, self).__init__() - self.rnn1 = paddle.nn.LSTM( - 16, 32, 2, direction="bidirectional", dropout=0.1) - - def forward(self, input): - return self.rnn1(input) - - x = paddle.randn((4, 10, 16)) - x.stop_gradient = False - seq_len = paddle.to_tensor(np.array([10, 6, 8, 5])) - mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype) - mask = paddle.unsqueeze(mask, [2]) - rnn = Net() - y, (h, c) = rnn(x) - y = y * mask - loss = paddle.mean(y) - loss.backward() - optimizer = paddle.optimizer.Adam( - learning_rate=0.1, parameters=rnn.parameters()) - optimizer.step() - rnn.eval() - y, (h, c) = rnn(x) - # `jit.to_static` would include a train_program, eval mode might cause - # some errors currently, such as dropout grad op gets `is_test == True`. - rnn.train() - - rnn = paddle.jit.to_static( - rnn, - [paddle.static.InputSpec( - shape=[None, None, 16], dtype=x.dtype)]) - paddle.jit.save(rnn, "./inference/lstm_infer") - - paddle.enable_static() - - new_scope = paddle.static.Scope() - with paddle.static.scope_guard(new_scope): - exe = paddle.static.Executor(place) - [inference_program, feed_target_names, - fetch_targets] = paddle.static.load_inference_model( - dirname="./inference", - executor=exe, - model_filename="lstm_infer.pdmodel", - params_filename="lstm_infer.pdiparams") - results = exe.run(inference_program, - feed={feed_target_names[0]: x.numpy()}, - fetch_list=fetch_targets) - np.testing.assert_equal( - y.numpy(), results[0]) # eval results equal predict results - paddle.disable_static() + predict_test_util(self.place, "LSTM") def runTest(self): self.test_with_initial_state() @@ -321,6 +275,66 @@ class TestLSTM(unittest.TestCase): self.test_predict() +def predict_test_util(place, mode): + place = paddle.set_device(place) + paddle.seed(123) + np.random.seed(123) + + class Net(paddle.nn.Layer): + def __init__(self): + super(Net, self).__init__() + self.rnn = getattr(paddle.nn, mode)(16, + 32, + 2, + direction="bidirectional", + dropout=0.1) + + def forward(self, input): + return self.rnn(input) + + x = paddle.randn((4, 10, 16)) + x.stop_gradient = False + seq_len = paddle.to_tensor(np.array([10, 6, 8, 5])) + mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype) + mask = paddle.unsqueeze(mask, [2]) + rnn = Net() + y, _ = rnn(x) + y = y * mask + loss = paddle.mean(y) + loss.backward() + optimizer = paddle.optimizer.Adam( + learning_rate=0.1, parameters=rnn.parameters()) + optimizer.step() + rnn.eval() + y, _ = rnn(x) + # `jit.to_static` would include a train_program, eval mode might cause + # some errors currently, such as dropout grad op gets `is_test == True`. 
+ rnn.train() + + rnn = paddle.jit.to_static( + rnn, [paddle.static.InputSpec( + shape=[None, None, 16], dtype=x.dtype)]) + paddle.jit.save(rnn, "./inference/%s_infer" % mode) + + paddle.enable_static() + + new_scope = paddle.static.Scope() + with paddle.static.scope_guard(new_scope): + exe = paddle.static.Executor(place) + [inference_program, feed_target_names, + fetch_targets] = paddle.static.load_inference_model( + dirname="./inference", + executor=exe, + model_filename="%s_infer.pdmodel" % mode, + params_filename="%s_infer.pdiparams" % mode) + results = exe.run(inference_program, + feed={feed_target_names[0]: x.numpy()}, + fetch_list=fetch_targets) + np.testing.assert_equal( + y.numpy(), results[0]) # eval results equal predict results + paddle.disable_static() + + def load_tests(loader, tests, pattern): suite = unittest.TestSuite() devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 33904524862d4edb16318c64cf9638c7e9156b60..ee989f27ebf72a20a16b575ee9be30b694f0d483 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -990,7 +990,6 @@ class RNNBase(LayerList): self.could_use_cudnn &= direction != "backward" self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * ( 2 if direction == "bidirectional" else 1) - self.could_use_cudnn &= mode == "LSTM" # currently only support LSTM # Expose params as RNN's attribute, which can make it compatible when # replacing small ops composed rnn with cpp rnn kernel. @@ -1062,22 +1061,18 @@ class RNNBase(LayerList): def _cudnn_impl(self, inputs, initial_states, sequence_length): if not self.time_major: inputs = paddle.tensor.transpose(inputs, [1, 0, 2]) - # unify LSTM/GRU/SimpleRNN later, currently only support LSTM - # TODO(guosheng): use `core.ops.cudnn_lstm` in dygraph mode if support - # specify output, since `dropout_state` should be a persistable tensor - # rather than a temporary on. out = self._helper.create_variable_for_type_inference(inputs.dtype) - last_h = self._helper.create_variable_for_type_inference(inputs.dtype) - last_c = self._helper.create_variable_for_type_inference(inputs.dtype) + state = [ + self._helper.create_variable_for_type_inference(inputs.dtype) + for i in range(self.state_components) + ] reserve = self._helper.create_variable_for_type_inference( dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True) inputs = { 'Input': inputs, - # 'W': self._flat_weight, # would be unused_var 'WeightList': self._all_weights, - 'InitH': initial_states[0], - 'InitC': initial_states[1], + 'PreState': initial_states, 'SequenceLength': sequence_length } attrs = { @@ -1086,23 +1081,22 @@ class RNNBase(LayerList): 'input_size': self.input_size, 'hidden_size': self.hidden_size, 'num_layers': self.num_layers, + 'mode': self.mode, 'is_test': not self.training } outputs = { 'Out': out, - 'LastH': last_h, - 'LastC': last_c, + 'State': state, 'Reserve': reserve, - 'StateOut': self._dropout_state, + 'DropoutState': self._dropout_state, } self._helper.append_op( - type="cudnn_lstm", inputs=inputs, outputs=outputs, attrs=attrs) + type="rnn", inputs=inputs, outputs=outputs, attrs=attrs) out = paddle.tensor.transpose(out, [1, 0, 2]) if not self.time_major else out - states = (last_h, last_c) - return out, states + return out, tuple(state) if len(state) > 1 else state[0] def forward(self, inputs, initial_states=None, sequence_length=None): batch_index = 1 if self.time_major else 0
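Usage note (not part of the diff): after this change, `paddle.nn.SimpleRNN`, `paddle.nn.GRU` and `paddle.nn.LSTM` all route through the single `rnn` op on CUDA devices, with the cell type picked by the `mode` attribute instead of a dedicated `cudnn_lstm` path. A minimal sketch mirroring the test above; the device string, shapes and sequence lengths are illustrative assumptions:

    import numpy as np
    import paddle

    paddle.set_device("gpu")  # the cudnn-backed "rnn" op added here needs a CUDA place
    rnn = paddle.nn.GRU(16, 32, num_layers=2, direction="bidirectional", dropout=0.1)

    x = paddle.randn((4, 10, 16))                        # [batch, seq_len, input_size], time_major=False
    seq_len = paddle.to_tensor(np.array([10, 6, 8, 5]))  # per-sample valid lengths of the padded input
    y, h = rnn(x, sequence_length=seq_len)               # dispatches the unified "rnn" op with mode="GRU"

    print(y.shape)  # [4, 10, 64]: hidden_size * 2 because direction="bidirectional"
    print(h.shape)  # [4, 4, 32]:  (num_layers * 2, batch_size, hidden_size)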