/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Forward kernel of the `cudnn_lstm` operator.
//
// Inputs  : "Input", "InitH", "InitC", "W".
// Outputs : "Out", "LastH", "LastC", "Reserve", "StateOut".
// Attrs   : dropout_prob, is_bidirec, hidden_size, num_layers, is_test, seed,
//           sequence_length.
//
// "Input" is indexed as [seq_length, batch_size, input_size]. Depending on
// `is_test` and on whether `sequence_length` is empty, one of four cuDNN entry
// points is called: cudnnRNNForwardInference / cudnnRNNForwardTraining when
// `sequence_length` is empty, and their *Ex counterparts otherwise (the *Ex
// calls are only compiled in when CUDNN_VERSION >= 7201).
//
// NOTE(review): the template argument lists in this file were reconstructed
// from the declared variable types; the less obvious ones are flagged below.
template <typename T>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *x = ctx.Input<Tensor>("Input");
    const Tensor *init_h = ctx.Input<Tensor>("InitH");
    const Tensor *init_c = ctx.Input<Tensor>("InitC");

    auto w = ctx.Input<Tensor>("W");

    Tensor *out = ctx.Output<Tensor>("Out");
    Tensor *last_h = ctx.Output<Tensor>("LastH");
    Tensor *last_c = ctx.Output<Tensor>("LastC");
    Tensor *reserve = ctx.Output<Tensor>("Reserve");
    Tensor *state_out = ctx.Output<Tensor>("StateOut");

    const T *x_data = x->data<T>();
    const T *init_h_data = init_h->data<T>();
    const T *init_c_data = init_c->data<T>();

    const T *w_data = w->data<T>();

    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    bool is_test = ctx.Attr<bool>("is_test");
    int seed = ctx.Attr<int>("seed");
    // NOTE(review): element type assumed to be int -- confirm against the op
    // definition of the "sequence_length" attribute.
    auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");

    // NOTE(review): device context type assumed to be CUDADeviceContext (it
    // must expose cudnn_handle()) -- confirm.
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    int seq_length = x->dims()[0];
    int batch_size = x->dims()[1];
    int input_size = x->dims()[2];
    int weight_numel = w->numel();
    bool state_initialized = state_out->IsInitialized();

    // Filled in by rnn.Create() below; zero-initialised so they never hold an
    // indeterminate value.
    size_t workspace_size = 0;
    size_t reserve_size = 0;

    platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                                num_layers, dropout_prob, seed, weight_numel,
                                state_initialized, is_bidirec);
    rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
                  &reserve_size, state_out);

    // Scratch workspace handed to cuDNN, sized in bytes by rnn.Create().
    // NOTE(review): int64_t / uint8_t reconstructed -- confirm.
    framework::Tensor workspace_data_;
    workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
    workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());

    // Reserve space is only passed to the training entry points below.
    auto *reserve_data = reserve->mutable_data<uint8_t>(
        {static_cast<int64_t>(reserve_size)}, ctx.GetPlace());

    if (is_test) {
      if (sequence_length.empty()) {
        // for inference
        // This interface is used when the input/output is unpadded.
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
            handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
            rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
            rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
            last_h_data, rnn.cy_desc(), last_c_data,
            workspace_data_.data<uint8_t>(), workspace_size));
      } else {
#if CUDNN_VERSION >= 7201
        // for inference
        // This interface is used when the input/output is padded.
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnRNNForwardInferenceEx(
                handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
                init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
                rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
                rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
                nullptr, nullptr, nullptr, nullptr,
                workspace_data_.data<uint8_t>(), workspace_size));
#else
        // Unconditional failure: the padded path needs the *Ex API.
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardInferenceEx, but it only works when "
            "the version of cudnn is larger than 7.2.1"));
#endif
      }
    } else {
      if (sequence_length.empty()) {
        // for train
        // This interface is used when the input/output is unpadded.
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
            handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
            rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
            rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
            last_h_data, rnn.cy_desc(), last_c_data,
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
      } else {
#if CUDNN_VERSION >= 7201
        // for train
        // This interface is used when the input/output is padded.
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnRNNForwardTrainingEx(
                handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
                init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
                rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
                rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
                nullptr, nullptr, nullptr, nullptr,
                workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
                reserve_size));
#else
        // Unconditional failure: the padded path needs the *Ex API.
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardTrainingEx, but it only works when "
            "the version of cudnn is larger than 7.2.1"));
#endif
      }
    }
  }
};

// Backward kernel of the `cudnn_lstm` operator.
//
// Inputs  : "Input", "W", "InitH", "InitC", "Reserve", "StateOut", "Out" and
//           the gradients of "Out", "LastH", "LastC".
// Outputs : gradients of "Input", "W", "InitH", "InitC".
//
// The weight gradient is zero-filled before the cuDNN calls. When
// `sequence_length` is empty, cudnnRNNBackwardData followed by
// cudnnRNNBackwardWeights is used; otherwise the *Ex variants are used (only
// compiled in when CUDNN_VERSION >= 7201).
template <typename T>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *input = ctx.Input<Tensor>("Input");
    auto *weight = ctx.Input<Tensor>("W");
    auto *init_h = ctx.Input<Tensor>("InitH");
    auto *init_c = ctx.Input<Tensor>("InitC");
    auto *reserve = ctx.Input<Tensor>("Reserve");
    auto *state_out = ctx.Input<Tensor>("StateOut");

    auto *out = ctx.Input<Tensor>("Out");
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));

    // NOTE(review): device context type assumed to be CUDADeviceContext --
    // confirm.
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    auto input_dims = input->dims();
    auto init_h_dims = init_h->dims();
    auto init_c_dims = init_c->dims();

    auto *weight_data = weight->data<T>();
    auto *init_h_data = init_h->data<T>();
    auto *init_c_data = init_c->data<T>();
    auto *out_data = out->data<T>();
    auto *out_grad_data = out_grad->data<T>();
    auto *last_h_grad_data = last_h_grad->data<T>();
    auto *last_c_grad_data = last_c_grad->data<T>();

    // Zero the weight gradient before cuDNN writes into it.
    math::SetConstant<platform::CUDADeviceContext, T> zero;
    weight_grad->mutable_data<T>(ctx.GetPlace());
    zero(dev_ctx, weight_grad, static_cast<T>(0.0));

    in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
    auto *in_grad_data = in_grad->data<T>();

    init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
    auto *init_h_grad_data = init_h_grad->data<T>();

    init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
    auto *init_c_grad_data = init_c_grad->data<T>();

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    int seed = ctx.Attr<int>("seed");
    // NOTE(review): element type assumed to be int -- confirm against the op
    // definition of the "sequence_length" attribute.
    auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");

    int seq_length = input_dims[0];
    int batch_size = input->dims()[1];
    int input_size = input->dims()[2];
    int weight_numel = weight->numel();

    // Filled in by rnn.Create() below; zero-initialised so they never hold an
    // indeterminate value.
    size_t workspace_size = 0;
    size_t reserve_size = 0;

    // `true` occupies the argument slot that the forward kernel fills with
    // `state_initialized`.
    platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                                num_layers, dropout_prob, seed, weight_numel,
                                true, is_bidirec);

    // rnn.Create() takes a mutable Tensor*, while "StateOut" is an input here.
    rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
                  &reserve_size, const_cast<Tensor *>(state_out));

    // NOTE(review): int64_t / uint8_t reconstructed -- confirm.
    framework::Tensor workspace_data_;
    workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
    workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
    const uint8_t *reserve_data = reserve->data<uint8_t>();

    if (sequence_length.empty()) {
      // This interface is used when the input/output is unpadded.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
          handle, rnn.rnn_desc(), seq_length, rnn.y_desc(), out_data,
          rnn.y_desc(), out_grad_data, rnn.hy_desc(), last_h_grad_data,
          rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
          rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, rnn.x_desc(),
          in_grad_data, rnn.hx_desc(), init_h_grad_data, rnn.cx_desc(),
          init_c_grad_data, workspace_data_.data<uint8_t>(), workspace_size,
          const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
          handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), input->data<T>(),
          rnn.hx_desc(), init_h->data<T>(), rnn.y_desc(), out->data<T>(),
          workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
          weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
          reserve_size));
    } else {
#if CUDNN_VERSION >= 7201
      // for train
      // This interface is used when the input/output is padded.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
          handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
          out_grad_data, nullptr, nullptr, rnn.hy_desc(), last_h_grad_data,
          rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
          rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
          rnn.x_seq_desc(), in_grad_data, rnn.hx_desc(), init_h_grad_data,
          rnn.cx_desc(), init_c_grad_data, nullptr, nullptr,
          workspace_data_.data<uint8_t>(), workspace_size,
          const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
          handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
          rnn.hx_desc(), init_h->data<T>(), rnn.y_seq_desc(), out->data<T>(),
          workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
          weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
          reserve_size));
#else
      // Unconditional failure: the padded path needs the *Ex API.
      PADDLE_THROW(platform::errors::Unavailable(
          "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
          "cudnnRNNBackwardWeightsEx, but it only works when the version "
          "of cudnn is larger than 7.2.1"));
#endif
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// NOTE(review): the registered element types were reconstructed as float and
// double (two instantiations are registered per op) -- confirm.
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
                        ops::CudnnLSTMGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
                        ops::CudnnLSTMGPUGradKernel<double>);