From 52889e385f2d5e8e831a5add98e26c36a4e3e164 Mon Sep 17 00:00:00 2001 From: huangjiyi <43315610+huangjiyi@users.noreply.github.com> Date: Tue, 16 May 2023 15:19:56 +0800 Subject: [PATCH] move cudnn_lstm kernel to phi (#53730) * update * fix bug * test * test * update * update mutable_data * fix bug * update * fix bug * update output type reg * update * update --- paddle/fluid/operators/cudnn_lstm_op.cc | 12 - paddle/fluid/operators/cudnn_lstm_op.cu.cc | 743 ------------------ paddle/phi/kernels/cpu/cudnn_lstm_kernel.cc | 49 ++ paddle/phi/kernels/cudnn_lstm_grad_kernel.h | 47 ++ paddle/phi/kernels/cudnn_lstm_kernel.h | 45 ++ .../kernels/gpu}/cudnn_lstm_cache.h | 77 +- .../phi/kernels/gpu/cudnn_lstm_grad_kernel.cu | 312 ++++++++ paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu | 374 +++++++++ paddle/phi/kernels/gpu/cudnn_lstm_utils.h | 76 ++ .../kernels/gpu}/miopen_lstm_cache.h | 70 +- paddle/phi/ops/compat/cudnn_lstm_sig.cc | 59 ++ 11 files changed, 1035 insertions(+), 829 deletions(-) delete mode 100644 paddle/fluid/operators/cudnn_lstm_op.cu.cc create mode 100644 paddle/phi/kernels/cpu/cudnn_lstm_kernel.cc create mode 100644 paddle/phi/kernels/cudnn_lstm_grad_kernel.h create mode 100644 paddle/phi/kernels/cudnn_lstm_kernel.h rename paddle/{fluid/operators => phi/kernels/gpu}/cudnn_lstm_cache.h (72%) create mode 100644 paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu create mode 100644 paddle/phi/kernels/gpu/cudnn_lstm_utils.h rename paddle/{fluid/operators => phi/kernels/gpu}/miopen_lstm_cache.h (71%) create mode 100644 paddle/phi/ops/compat/cudnn_lstm_sig.cc diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index c0386a36901..a3cd48cc50c 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -292,15 +292,6 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { } }; -template -class NotImpleKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "CPU is not support for this kernel now. Will be add in the future")); - } -}; - } // namespace operators } // namespace paddle @@ -312,9 +303,6 @@ REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMGradOpMaker); REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); -PD_REGISTER_STRUCT_KERNEL( - cudnn_lstm, CPU, ALL_LAYOUT, ops::NotImpleKernel, float) {} - // TODO(Shixiaowei02) Add ModifyInput support REGISTER_OP_VERSION(cudnn_lstm) .AddCheckpoint( diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc deleted file mode 100644 index dd9f8bc6297..00000000000 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ /dev/null @@ -1,743 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/generator.h" -#include "paddle/phi/core/tensor_utils.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/operators/cudnn_lstm_cache.h" -#endif -#ifdef PADDLE_WITH_HIP -#include "paddle/fluid/operators/miopen_lstm_cache.h" -#endif - -namespace paddle { -namespace operators { - -template -bool is_continuous(const Type &weight_list) { - bool continuous = true; - for (size_t i = 0; i < weight_list.size() - 1; ++i) { - auto *in_data = weight_list[i]->template data(); - auto *in_after_data = weight_list[i + 1]->template data(); - auto in_size = weight_list[i]->numel(); - bool temp = in_data + in_size == in_after_data; - continuous = continuous && temp; - } - return continuous; -} - -int size_sum(const std::vector &weight_list) { - int size = 0; - for (size_t i = 0; i < weight_list.size(); ++i) { - auto in_size = weight_list[i]->numel(); - size += in_size; - } - return size; -} - -template -void weight_to_tensor(const platform::Place &place, - gpuStream_t stream, - const std::vector &weight_list, - phi::DenseTensor *weight) { - auto weight_data = weight->data(); - int weight_offset = 0; - for (size_t i = 0; i < weight_list.size(); ++i) { - const T *in_data = weight_list[i]->data(); - auto in_size = weight_list[i]->numel(); - - memory::Copy(weight->place(), - weight_data + weight_offset, - weight_list[i]->place(), - in_data, - in_size * sizeof(T), - stream); - weight_offset += in_size; - } -} - -template -void weight_to_tensor_list( - const platform::Place &place, - gpuStream_t stream, - std::vector *weight_grad, - const std::vector &weight_input, - const phi::DenseTensor *weight) { - int weight_offset = 0; - auto *weight_data = weight->data(); - for (size_t i = 0; i < weight_input.size(); ++i) { - auto in_size = weight_input[i]->numel(); - T *weight_grad_data = (*weight_grad)[i]->mutable_data(place); - const T *src = weight_data + weight_offset; - - memory::Copy((*weight_grad)[i]->place(), - weight_grad_data, - weight->place(), - src, - in_size * sizeof(T), - stream); - weight_offset += in_size; - } -} - -template -#ifdef PADDLE_WITH_HIP -void LSTMInferece(const bool &has_seq_length, - const miopenHandle_t &handle, -#else -void LSTMInferece(const bool &has_seq_length, - const cudnnHandle_t &handle, -#endif - const int &seq_length, - ScopedRNNBase *rnn, - const T *x_data, - const T *init_h_data, - const T *init_c_data, - const T *w_data, - T *out_data, - T *last_h_data, - T *last_c_data, - phi::DenseTensor *workspace_data, - const size_t &workspace_size) { - if (!has_seq_length) { -// for inference -// This interface is used when the input/output is unpadded. 
-#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardInference( - handle, - rnn->rnn_desc(), - seq_length, - rnn->x_descs(), - x_data, - rnn->init_h_desc(), - init_h_data, - rnn->init_c_desc(), - init_c_data, - rnn->weight_desc(), - w_data, - rnn->y_descs(), - out_data, - rnn->last_h_desc(), - last_h_data, - rnn->last_c_desc(), - last_c_data, - workspace_data->data(), - workspace_size)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInference( - handle, - rnn->rnn_desc(), - seq_length, - rnn->x_descs(), - x_data, - rnn->init_h_desc(), - init_h_data, - rnn->init_c_desc(), - init_c_data, - rnn->weight_desc(), - w_data, - rnn->y_descs(), - out_data, - rnn->last_h_desc(), - last_h_data, - rnn->last_c_desc(), - last_c_data, - workspace_data->data(), - workspace_size)); -#endif - } else { -#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 - // for inference - // This interface is used when the input/output is padded. - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx( - handle, - rnn->rnn_desc(), - rnn->x_seq_desc(), - x_data, - rnn->init_h_desc(), - init_h_data, - rnn->init_c_desc(), - init_c_data, - rnn->weight_desc(), - w_data, - rnn->y_seq_desc(), - out_data, - rnn->last_h_desc(), - last_h_data, - rnn->last_c_desc(), - last_c_data, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - workspace_data->data(), - workspace_size)); -#else - // CUDNN VERSION has to >=7.2.1 - PADDLE_THROW(platform::errors::Unavailable( - "The padded input is supported by " - "cudnnRNNForwardInferenceEx, but it only works when " - "the version of cudnn is larger than 7.2.1")); -#endif - } -} - -template -class CudnnLSTMGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - const phi::DenseTensor *x = ctx.Input("Input"); - const phi::DenseTensor *init_h = ctx.Input("InitH"); - const phi::DenseTensor *init_c = ctx.Input("InitC"); - - phi::DenseTensor *out = ctx.Output("Out"); - phi::DenseTensor *last_h = ctx.Output("LastH"); - phi::DenseTensor *last_c = ctx.Output("LastC"); - phi::DenseTensor *reserve = ctx.Output("Reserve"); - phi::DenseTensor *state_out = ctx.Output("StateOut"); - - const T *x_data = x->data(); - const T *init_h_data = init_h->data(); - const T *init_c_data = init_c->data(); - - T *out_data = out->mutable_data(ctx.GetPlace()); - T *last_h_data = last_h->mutable_data(ctx.GetPlace()); - T *last_c_data = last_c->mutable_data(ctx.GetPlace()); - - float dropout_prob = ctx.Attr("dropout_prob"); - bool is_bidirec = ctx.Attr("is_bidirec"); - int hidden_size = ctx.Attr("hidden_size"); - int num_layers = ctx.Attr("num_layers"); - bool is_test = ctx.Attr("is_test"); - int seed = ctx.Attr("seed"); - - if (!is_test) { - if (seed == 0) { - // If not specify seed, use global Generator to generate seed. 
- int device_id = ctx.GetPlace().GetDeviceId(); - auto gen_cuda = phi::DefaultCUDAGenerator(device_id); - seed = static_cast(gen_cuda->Random64()); - } - // else use `ctx.Attr("seed")` specified seed - } - - bool has_seq_length = ctx.HasInput("SequenceLength"); - std::vector SequenceLength; - if (has_seq_length) { - auto *sequence_length = ctx.Input("SequenceLength"); - SequenceLength = phi::GetVectorFromTensor(sequence_length); - } - - auto &dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - - int seq_length = x->dims()[0]; - int batch_size = x->dims()[1]; - int input_size = x->dims()[2]; - bool state_initialized = state_out->IsInitialized() ? true : false; - - size_t workspace_size; - size_t reserve_size; - phi::DenseTensor weight_whole; - T *w_data = nullptr; - int weight_numel; - bool w_initialized = false; - auto place = ctx.GetPlace(); - auto stream = - reinterpret_cast(ctx.device_context()) - .stream(); - if (is_test && ctx.HasInput("W")) { - auto *W = ctx.Input("W"); - w_initialized = W->IsInitialized() ? true : false; - weight_numel = W->numel(); - } - if (!w_initialized) { - auto weight_list = ctx.MultiInput("WeightList"); - bool continuous = - is_continuous>(weight_list); - weight_numel = size_sum(weight_list); - - if (!continuous) { - LOG_FIRST_N(WARNING, 2) - << "If the memory space of the Input WeightList is not continuous, " - "less efficient calculation will be called. Please call " - "flatten_parameters() to make the input memory continuous."; - weight_whole.mutable_data({weight_numel}, place); - weight_to_tensor(place, stream, weight_list, &weight_whole); - w_data = weight_whole.data(); - if (is_test) { // maybe also reset small weights' ptr for training - int offset = 0; - for (size_t i = 0; i < weight_list.size(); ++i) { - size_t len = weight_list[i]->numel(); - auto dim = weight_list[i]->dims(); - const_cast(weight_list[i]) - ->ShareDataWith( - weight_whole.Slice(static_cast(offset), - static_cast(offset + len))) - .Resize(dim); - offset += len; - } - } - } else { - w_data = const_cast(weight_list[0]->data()); - } - } else { - auto *W = ctx.Input("W"); - w_data = const_cast(W->data()); - } - - ScopedRNNBase rnn(seq_length, - batch_size, - input_size, - hidden_size, - num_layers, - dropout_prob, - seed, - weight_numel, - state_initialized, - is_bidirec); - rnn.Create(handle, - ctx.GetPlace(), - SequenceLength, - &workspace_size, - &reserve_size, - state_out); - - phi::DenseTensor workspace_data_; - workspace_data_.mutable_data( - {static_cast(workspace_size)}, ctx.GetPlace()); - - auto *reserve_data = reserve->mutable_data( - {static_cast(reserve_size)}, ctx.GetPlace()); - - if (is_test) { - LSTMInferece(has_seq_length, - handle, - seq_length, - &rnn, - x_data, - init_h_data, - init_c_data, - w_data, - out_data, - last_h_data, - last_c_data, - &workspace_data_, - workspace_size); - } else { - if (!has_seq_length) { -// for train -// This interface is used when the input/output is unpadded. 
-#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardTraining( - handle, - rnn.rnn_desc(), - seq_length, - rnn.x_descs(), - x_data, - rnn.init_h_desc(), - init_h_data, - rnn.init_c_desc(), - init_c_data, - rnn.weight_desc(), - w_data, - rnn.y_descs(), - out_data, - rnn.last_h_desc(), - last_h_data, - rnn.last_c_desc(), - last_c_data, - workspace_data_.data(), - workspace_size, - reserve_data, - reserve_size)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTraining( - handle, - rnn.rnn_desc(), - seq_length, - rnn.x_descs(), - x_data, - rnn.init_h_desc(), - init_h_data, - rnn.init_c_desc(), - init_c_data, - rnn.weight_desc(), - w_data, - rnn.y_descs(), - out_data, - rnn.last_h_desc(), - last_h_data, - rnn.last_c_desc(), - last_c_data, - workspace_data_.data(), - workspace_size, - reserve_data, - reserve_size)); -#endif - } else { -#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 - // for train - // This interface is used when the input/output is padded. - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTrainingEx( - handle, - rnn.rnn_desc(), - rnn.x_seq_desc(), - x_data, - rnn.init_h_desc(), - init_h_data, - rnn.init_c_desc(), - init_c_data, - rnn.weight_desc(), - w_data, - rnn.y_seq_desc(), - out_data, - rnn.last_h_desc(), - last_h_data, - rnn.last_c_desc(), - last_c_data, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - workspace_data_.data(), - workspace_size, - reserve_data, - reserve_size)); -#else - PADDLE_THROW(platform::errors::Unavailable( - "The padded input is supported by " - "cudnnRNNForwardTrainingEx, but it only works when " - "the version of cudnn is larger than 7.2.1")); -#endif - } - } - } -}; - -template -class CudnnLSTMGPUGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *input = ctx.Input("Input"); - auto *init_h = ctx.Input("InitH"); - auto *init_c = ctx.Input("InitC"); - auto *reserve = ctx.Input("Reserve"); - auto *state_out = ctx.Input("StateOut"); - auto weight_list = ctx.MultiInput("WeightList"); - - auto *out = ctx.Input("Out"); - auto *out_grad = ctx.Input(framework::GradVarName("Out")); - auto *last_h_grad = - ctx.Input(framework::GradVarName("LastH")); - auto *last_c_grad = - ctx.Input(framework::GradVarName("LastC")); - - auto *in_grad = - ctx.Output(framework::GradVarName("Input")); - auto *init_h_grad = - ctx.Output(framework::GradVarName("InitH")); - auto *init_c_grad = - ctx.Output(framework::GradVarName("InitC")); - auto weight_grad_list = - ctx.MultiOutput(framework::GradVarName("WeightList")); - - auto &dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - - auto input_dims = input->dims(); - auto init_h_dims = init_h->dims(); - auto init_c_dims = init_c->dims(); - - auto *init_h_data = init_h->data(); - auto *init_c_data = init_c->data(); - auto *out_data = out->data(); - auto *out_grad_data = out_grad->data(); - auto *last_h_grad_data = last_h_grad->data(); - auto *last_c_grad_data = last_c_grad->data(); - - auto place = ctx.GetPlace(); - int weight_numel = size_sum(weight_list); - bool continuous = - is_continuous>(weight_list); - - auto stream = - reinterpret_cast(ctx.device_context()) - .stream(); - phi::DenseTensor weight_whole; - T *weight_data = nullptr; - - if (!continuous) { - weight_whole.mutable_data({weight_numel}, place); - weight_to_tensor(place, stream, weight_list, &weight_whole); - weight_data = 
weight_whole.data(); - } else { - weight_data = const_cast(weight_list[0]->data()); - } - - phi::DenseTensor weight_grad; - phi::funcs::SetConstant zero; - weight_grad.mutable_data({weight_numel}, ctx.GetPlace()); - zero(dev_ctx, &weight_grad, static_cast(0.0)); - T *weight_grad_data = weight_grad.data(); - - int offset = 0; - for (size_t i = 0; i < weight_grad_list.size(); ++i) { - size_t len = weight_grad_list[i]->numel(); - auto dim = weight_grad_list[i]->dims(); - weight_grad_list[i] - ->ShareDataWith(weight_grad.Slice(static_cast(offset), - static_cast(offset + len))) - .Resize(dim); - offset += len; - } - - in_grad->mutable_data(input_dims, ctx.GetPlace()); - auto *in_grad_data = in_grad->data(); - - if (init_h_grad) init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); - auto *init_h_grad_data = init_h_grad ? init_h_grad->data() : nullptr; - - if (init_c_grad) init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); - auto *init_c_grad_data = init_c_grad ? init_c_grad->data() : nullptr; - - float dropout_prob = ctx.Attr("dropout_prob"); - bool is_bidirec = ctx.Attr("is_bidirec"); - int hidden_size = ctx.Attr("hidden_size"); - int num_layers = ctx.Attr("num_layers"); - int seed = ctx.Attr("seed"); - - bool has_seq_length = ctx.HasInput("SequenceLength"); - std::vector SequenceLength; - if (has_seq_length) { - auto *sequence_length = ctx.Input("SequenceLength"); - SequenceLength = phi::GetVectorFromTensor(sequence_length); - } - - int seq_length = input_dims[0]; - int batch_size = input->dims()[1]; - int input_size = input->dims()[2]; - - size_t workspace_size; - size_t reserve_size; - - ScopedRNNBase rnn(seq_length, - batch_size, - input_size, - hidden_size, - num_layers, - dropout_prob, - seed, - weight_numel, - true, - is_bidirec); - - rnn.Create(handle, - ctx.GetPlace(), - SequenceLength, - &workspace_size, - &reserve_size, - const_cast(state_out)); - - phi::DenseTensor workspace_data_; - workspace_data_.mutable_data( - {static_cast(workspace_size)}, ctx.GetPlace()); - const uint8_t *reserve_data = reserve->data(); - - if (!has_seq_length) { -// This interface is used when the input/output is unpadded. 
-#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardData( - handle, - rnn.rnn_desc(), - seq_length, - rnn.y_descs(), - out_data, - rnn.y_descs(), - out_grad_data, - rnn.last_h_desc(), - last_h_grad_data, - rnn.last_c_desc(), - last_c_grad_data, - rnn.weight_desc(), - weight_data, - rnn.init_h_desc(), - init_h_data, - rnn.init_c_desc(), - init_c_data, - rnn.x_descs(), - in_grad_data, - rnn.init_h_desc(), - init_h_grad_data, - rnn.init_c_desc(), - init_c_grad_data, - workspace_data_.data(), - workspace_size, - const_cast(reserve_data), - reserve_size)); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardWeights( - handle, - rnn.rnn_desc(), - seq_length, - rnn.x_descs(), - input->data(), - rnn.init_h_desc(), - init_h->data(), - rnn.y_descs(), - out->data(), - rnn.weight_desc(), - weight_grad_data, - workspace_data_.data(), - workspace_size, - const_cast(reserve_data), - reserve_size)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardData( - handle, - rnn.rnn_desc(), - seq_length, - rnn.y_descs(), - out_data, - rnn.y_descs(), - out_grad_data, - rnn.last_h_desc(), - last_h_grad_data, - rnn.last_c_desc(), - last_c_grad_data, - rnn.weight_desc(), - weight_data, - rnn.init_h_desc(), - init_h_data, - rnn.init_c_desc(), - init_c_data, - rnn.x_descs(), - in_grad_data, - rnn.init_h_desc(), - init_h_grad_data, - rnn.init_c_desc(), - init_c_grad_data, - workspace_data_.data(), - workspace_size, - const_cast(reserve_data), - reserve_size)); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeights( - handle, - rnn.rnn_desc(), - seq_length, - rnn.x_descs(), - input->data(), - rnn.init_h_desc(), - init_h->data(), - rnn.y_descs(), - out->data(), - workspace_data_.data(), - workspace_size, - rnn.weight_desc(), - weight_grad_data, - const_cast(reserve_data), - reserve_size)); -#endif - } else { -#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 - // for train - // This interface is used when the input/output is padded. 
- PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx( - handle, - rnn.rnn_desc(), - rnn.y_seq_desc(), - out_data, - rnn.y_seq_desc(), - out_grad_data, - nullptr, - nullptr, - rnn.last_h_desc(), - last_h_grad_data, - rnn.last_c_desc(), - last_c_grad_data, - rnn.weight_desc(), - weight_data, - rnn.init_h_desc(), - init_h_data, - rnn.init_c_desc(), - init_c_data, - rnn.x_seq_desc(), - in_grad_data, - rnn.init_h_desc(), - init_h_grad_data, - rnn.init_c_desc(), - init_c_grad_data, - nullptr, - nullptr, - workspace_data_.data(), - workspace_size, - const_cast(reserve_data), - reserve_size)); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx( - handle, - rnn.rnn_desc(), - rnn.x_seq_desc(), - input->data(), - rnn.init_h_desc(), - init_h->data(), - rnn.y_seq_desc(), - out->data(), - workspace_data_.data(), - workspace_size, - rnn.weight_desc(), - weight_grad_data, - const_cast(reserve_data), - reserve_size)); -#else - PADDLE_THROW(platform::errors::Unavailable( - "The padded input of rnn is supported by cudnnRNNBackwardDataEx, " - "cudnnRNNBackwardWeightsEx, but it only works when the version " - "of cudnn is larger than 7.2.1")); -#endif - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -#ifdef PADDLE_WITH_HIP -// MIOPEN do not support double -PD_REGISTER_STRUCT_KERNEL( - cudnn_lstm, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUKernel, float) {} -PD_REGISTER_STRUCT_KERNEL( - cudnn_lstm_grad, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUGradKernel, float) {} -#else -PD_REGISTER_STRUCT_KERNEL( - cudnn_lstm, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUKernel, float, double) {} -PD_REGISTER_STRUCT_KERNEL(cudnn_lstm_grad, - GPU, - ALL_LAYOUT, - ops::CudnnLSTMGPUGradKernel, - float, - double) {} -#endif diff --git a/paddle/phi/kernels/cpu/cudnn_lstm_kernel.cc b/paddle/phi/kernels/cpu/cudnn_lstm_kernel.cc new file mode 100644 index 00000000000..cd709fe2bf4 --- /dev/null +++ b/paddle/phi/kernels/cpu/cudnn_lstm_kernel.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cudnn_lstm_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CudnnLSTMKernel( + const Context& ctx, + const DenseTensor& x, + const DenseTensor& init_h, + const DenseTensor& init_c, + const paddle::optional& w, + const paddle::optional>& weight_list, + const paddle::optional& sequence_length, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + DenseTensor* out, + DenseTensor* last_h, + DenseTensor* last_c, + DenseTensor* reserve, + DenseTensor* state_out) { + PADDLE_THROW(phi::errors::Unimplemented( + "CPU is not support for cudnn_lstm now. 
Will be add in the future")); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cudnn_lstm, CPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) {} diff --git a/paddle/phi/kernels/cudnn_lstm_grad_kernel.h b/paddle/phi/kernels/cudnn_lstm_grad_kernel.h new file mode 100644 index 00000000000..848d3376ee3 --- /dev/null +++ b/paddle/phi/kernels/cudnn_lstm_grad_kernel.h @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void CudnnLSTMGradKernel( + const Context& ctx, + const DenseTensor& x, + const DenseTensor& init_h, + const DenseTensor& init_c, + const paddle::optional>& weight_list, + const paddle::optional& sequence_length, + const DenseTensor& out, + const DenseTensor& reserve, + const DenseTensor& state_out, + const DenseTensor& out_grad, + const DenseTensor& last_h_grad, + const DenseTensor& last_c_grad, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + DenseTensor* x_grad, + DenseTensor* init_h_grad, + DenseTensor* init_c_grad, + std::vector weight_list_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cudnn_lstm_kernel.h b/paddle/phi/kernels/cudnn_lstm_kernel.h new file mode 100644 index 00000000000..441808235fd --- /dev/null +++ b/paddle/phi/kernels/cudnn_lstm_kernel.h @@ -0,0 +1,45 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void CudnnLSTMKernel( + const Context& ctx, + const DenseTensor& x, + const DenseTensor& init_h, + const DenseTensor& init_c, + const paddle::optional& w, + const paddle::optional>& weight_list, + const paddle::optional& sequence_length, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + DenseTensor* out, + DenseTensor* last_h, + DenseTensor* last_c, + DenseTensor* reserve, + DenseTensor* state_out); + +} // namespace phi diff --git a/paddle/fluid/operators/cudnn_lstm_cache.h b/paddle/phi/kernels/gpu/cudnn_lstm_cache.h similarity index 72% rename from paddle/fluid/operators/cudnn_lstm_cache.h rename to paddle/phi/kernels/gpu/cudnn_lstm_cache.h index 32f1b46dbbd..197049452f9 100644 --- a/paddle/fluid/operators/cudnn_lstm_cache.h +++ b/paddle/phi/kernels/gpu/cudnn_lstm_cache.h @@ -16,12 +16,13 @@ limitations under the License. */ #include -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/dynload/cudnn.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/dynload/cudnn.h" +#include "paddle/phi/backends/gpu/forwards.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/core/dense_tensor.h" -namespace paddle { -namespace operators { +namespace phi { class ScopedRNNBase { public: @@ -48,13 +49,13 @@ class ScopedRNNBase { template void Create(const cudnnHandle_t& handle, - const platform::Place& place, + const phi::Place& place, const std::vector& sequence_length, size_t* workspace_size, size_t* reserve_size, phi::DenseTensor* dropout_state) { int numDirections = is_bidirec_ ? 
2 : 1; - cudnnDataType_t cudnn_type = platform::CudnnDataType::type; + cudnnDataType_t cudnn_type = phi::backends::gpu::CudnnDataType::type; // ------------------- cudnn x, y descriptors --------------------- std::vector dims_x = {batch_size_, input_size_, 1}; @@ -91,9 +92,11 @@ class ScopedRNNBase { size_t state_size; if (!initialized_) { PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); - dropout_state->mutable_data({static_cast(state_size)}, - place); + phi::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* dev_ctx = reinterpret_cast(pool.Get(place)); + dropout_state->Resize({static_cast(state_size)}); + dev_ctx->template Alloc(dropout_state); } dropout_desc_.descriptor(handle, place, @@ -104,7 +107,7 @@ class ScopedRNNBase { state_size); // ------------------- cudnn rnn descriptors --------------------- - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNDescriptor_v6( handle, rnn_desc_.desc(), hidden_size_, @@ -118,38 +121,35 @@ class ScopedRNNBase { #if CUDNN_VERSION >= 7201 if (!sequence_length.empty()) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNPaddingMode( rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); } #endif // ------------------- cudnn weights_size --------------------- size_t weights_size_; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNParamsSize( handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); PADDLE_ENFORCE_EQ( weights_size_, sizeof(T) * weight_numel_, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The cudnn lstm and setting weight size should be same.")); // ------------------- cudnn weight descriptors --------------------- - platform::DataLayout layout = platform::DataLayout::kNCHW; + phi::backends::gpu::DataLayout layout = + phi::backends::gpu::DataLayout::kNCHW; int dim_tmp = weights_size_ / sizeof(T); std::vector dim_w = {dim_tmp, 1, 1}; weight_desc_.descriptor(layout, dim_w); // ------------------- cudnn workspace, reserve size --------------------- PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetRNNWorkspaceSize(handle, - rnn_desc_.desc(), - seq_length_, - x_descs_.data(), - workspace_size)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetRNNTrainingReserveSize(handle, - rnn_desc_.desc(), - seq_length_, - x_descs_.data(), - reserve_size)); + phi::dynload::cudnnGetRNNWorkspaceSize(handle, + rnn_desc_.desc(), + seq_length_, + x_descs_.data(), + workspace_size)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNTrainingReserveSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), reserve_size)); } cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); } cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); } @@ -179,20 +179,19 @@ class ScopedRNNBase { std::vector x_descs_; std::vector y_descs_; - platform::ScopedTensorDescriptor x_desc_; - platform::ScopedTensorDescriptor y_desc_; + phi::backends::gpu::ScopedTensorDescriptor x_desc_; + phi::backends::gpu::ScopedTensorDescriptor y_desc_; #if CUDNN_VERSION >= 7201 - platform::ScopedRNNTensorDescriptor x_seq_desc_; - platform::ScopedRNNTensorDescriptor y_seq_desc_; + phi::backends::gpu::ScopedRNNTensorDescriptor x_seq_desc_; + 
phi::backends::gpu::ScopedRNNTensorDescriptor y_seq_desc_; #endif - platform::ScopedTensorDescriptor init_h_desc_; - platform::ScopedTensorDescriptor init_c_desc_; - platform::ScopedTensorDescriptor last_h_desc_; - platform::ScopedTensorDescriptor last_c_desc_; - platform::ScopedDropoutDescriptor dropout_desc_; - platform::ScopedFilterDescriptor weight_desc_; - platform::ScopedRNNDescriptor rnn_desc_; + phi::backends::gpu::ScopedTensorDescriptor init_h_desc_; + phi::backends::gpu::ScopedTensorDescriptor init_c_desc_; + phi::backends::gpu::ScopedTensorDescriptor last_h_desc_; + phi::backends::gpu::ScopedTensorDescriptor last_c_desc_; + phi::backends::gpu::ScopedDropoutDescriptor dropout_desc_; + phi::backends::gpu::ScopedFilterDescriptor weight_desc_; + phi::backends::gpu::ScopedRNNDescriptor rnn_desc_; }; -} // namespace operators -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu b/paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu new file mode 100644 index 00000000000..661a1dd90e7 --- /dev/null +++ b/paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu @@ -0,0 +1,312 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cudnn_lstm_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/cudnn_lstm_utils.h" + +namespace phi { + +template +void CudnnLSTMGradKernel( + const Context &ctx, + const DenseTensor &x, + const DenseTensor &init_h, + const DenseTensor &init_c, + const paddle::optional> &weight_list, + const paddle::optional &sequence_length, + const DenseTensor &out, + const DenseTensor &reserve, + const DenseTensor &state_out, + const DenseTensor &out_grad, + const DenseTensor &last_h_grad, + const DenseTensor &last_c_grad, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + DenseTensor *x_grad, + DenseTensor *init_h_grad, + DenseTensor *init_c_grad, + std::vector weight_grad_list) { + auto input_dims = x.dims(); + auto init_h_dims = init_h.dims(); + auto init_c_dims = init_c.dims(); + + auto *init_h_data = init_h.data(); + auto *init_c_data = init_c.data(); + auto *out_data = out.data(); + auto *out_grad_data = out_grad.data(); + auto *last_h_grad_data = last_h_grad.data(); + auto *last_c_grad_data = last_c_grad.data(); + + auto running_weight_list = *weight_list.get_ptr(); + int weight_numel = size_sum(running_weight_list); + bool continuous = is_continuous>( + running_weight_list); + + auto handle = ctx.cudnn_handle(); + auto place = ctx.GetPlace(); + auto stream = ctx.stream(); + phi::DenseTensor weight_whole; + T *weight_data = nullptr; + + if (!continuous) { + weight_whole.Resize({weight_numel}); + ctx.template Alloc(&weight_whole); + weight_to_tensor(place, stream, running_weight_list, &weight_whole); + weight_data = weight_whole.data(); + } else { + weight_data = const_cast(running_weight_list[0]->data()); + } + + phi::DenseTensor 
weight_grad; + phi::funcs::SetConstant zero; + weight_grad.Resize({weight_numel}); + ctx.template Alloc(&weight_grad); + zero(ctx, &weight_grad, static_cast(0.0)); + T *weight_grad_data = weight_grad.data(); + + int offset = 0; + for (size_t i = 0; i < weight_grad_list.size(); ++i) { + size_t len = weight_grad_list[i]->numel(); + auto dim = weight_grad_list[i]->dims(); + weight_grad_list[i] + ->ShareDataWith(weight_grad.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } + + x_grad->Resize(input_dims); + ctx.template Alloc(x_grad); + auto *in_grad_data = x_grad->data(); + + if (init_h_grad) { + init_h_grad->Resize(init_h_dims); + ctx.template Alloc(init_h_grad); + } + auto *init_h_grad_data = init_h_grad ? init_h_grad->data() : nullptr; + + if (init_c_grad) { + init_c_grad->Resize(init_c_dims); + ctx.template Alloc(init_c_grad); + } + auto *init_c_grad_data = init_c_grad ? init_c_grad->data() : nullptr; + + auto running_seq_length = sequence_length.get_ptr(); + bool has_seq_length = running_seq_length != nullptr; + std::vector SequenceLength; + if (has_seq_length) { + SequenceLength = phi::GetVectorFromTensor(running_seq_length); + } + + int seq_length = input_dims[0]; + int batch_size = x.dims()[1]; + int input_size = x.dims()[2]; + + size_t workspace_size; + size_t reserve_size; + + ScopedRNNBase rnn(seq_length, + batch_size, + input_size, + hidden_size, + num_layers, + dropout_prob, + seed, + weight_numel, + true, + is_bidirec); + + rnn.Create(handle, + ctx.GetPlace(), + SequenceLength, + &workspace_size, + &reserve_size, + const_cast(&state_out)); + + phi::DenseTensor workspace_data_; + workspace_data_.Resize({static_cast(workspace_size)}); + ctx.template Alloc(&workspace_data_); + const uint8_t *reserve_data = reserve.data(); + + if (!has_seq_length) { +// This interface is used when the input/output is unpadded. 
+#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenRNNBackwardData(handle, + rnn.rnn_desc(), + seq_length, + rnn.y_descs(), + out_data, + rnn.y_descs(), + out_grad_data, + rnn.last_h_desc(), + last_h_grad_data, + rnn.last_c_desc(), + last_c_grad_data, + rnn.weight_desc(), + weight_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.x_descs(), + in_grad_data, + rnn.init_h_desc(), + init_h_grad_data, + rnn.init_c_desc(), + init_c_grad_data, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenRNNBackwardWeights( + handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_descs(), + out.data(), + rnn.weight_desc(), + weight_grad_data, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNBackwardData(handle, + rnn.rnn_desc(), + seq_length, + rnn.y_descs(), + out_data, + rnn.y_descs(), + out_grad_data, + rnn.last_h_desc(), + last_h_grad_data, + rnn.last_c_desc(), + last_c_grad_data, + rnn.weight_desc(), + weight_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.x_descs(), + in_grad_data, + rnn.init_h_desc(), + init_h_grad_data, + rnn.init_c_desc(), + init_c_grad_data, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeights( + handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_descs(), + out.data(), + workspace_data_.data(), + workspace_size, + rnn.weight_desc(), + weight_grad_data, + const_cast(reserve_data), + reserve_size)); +#endif + } else { +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 + // for train + // This interface is used when the input/output is padded. 
+ PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardDataEx( + handle, + rnn.rnn_desc(), + rnn.y_seq_desc(), + out_data, + rnn.y_seq_desc(), + out_grad_data, + nullptr, + nullptr, + rnn.last_h_desc(), + last_h_grad_data, + rnn.last_c_desc(), + last_c_grad_data, + rnn.weight_desc(), + weight_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.x_seq_desc(), + in_grad_data, + rnn.init_h_desc(), + init_h_grad_data, + rnn.init_c_desc(), + init_c_grad_data, + nullptr, + nullptr, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeightsEx( + handle, + rnn.rnn_desc(), + rnn.x_seq_desc(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_seq_desc(), + out.data(), + workspace_data_.data(), + workspace_size, + rnn.weight_desc(), + weight_grad_data, + const_cast(reserve_data), + reserve_size)); +#else + PADDLE_THROW(phi::errors::Unavailable( + "The padded input of rnn is supported by cudnnRNNBackwardDataEx, " + "cudnnRNNBackwardWeightsEx, but it only works when the version " + "of cudnn is larger than 7.2.1")); +#endif + } +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL( + cudnn_lstm_grad, GPU, ALL_LAYOUT, phi::CudnnLSTMGradKernel, float) {} +#else +PD_REGISTER_KERNEL( + cudnn_lstm_grad, GPU, ALL_LAYOUT, phi::CudnnLSTMGradKernel, float, double) { +} +#endif diff --git a/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu b/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu new file mode 100644 index 00000000000..e3f2b780f3f --- /dev/null +++ b/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu @@ -0,0 +1,374 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cudnn_lstm_kernel.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/cudnn_lstm_utils.h" + +namespace phi { + +template +#ifdef PADDLE_WITH_HIP +void LSTMInferece(const bool &has_seq_length, + const miopenHandle_t &handle, +#else +void LSTMInferece(const bool &has_seq_length, + const cudnnHandle_t &handle, +#endif + const int &seq_length, + ScopedRNNBase *rnn, + const T *x_data, + const T *init_h_data, + const T *init_c_data, + const T *w_data, + T *out_data, + T *last_h_data, + T *last_c_data, + phi::DenseTensor *workspace_data, + const size_t &workspace_size) { + if (!has_seq_length) { +// for inference +// This interface is used when the input/output is unpadded. 
+#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenRNNForwardInference(handle, + rnn->rnn_desc(), + seq_length, + rnn->x_descs(), + x_data, + rnn->init_h_desc(), + init_h_data, + rnn->init_c_desc(), + init_c_data, + rnn->weight_desc(), + w_data, + rnn->y_descs(), + out_data, + rnn->last_h_desc(), + last_h_data, + rnn->last_c_desc(), + last_c_data, + workspace_data->data(), + workspace_size)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForwardInference(handle, + rnn->rnn_desc(), + seq_length, + rnn->x_descs(), + x_data, + rnn->init_h_desc(), + init_h_data, + rnn->init_c_desc(), + init_c_data, + rnn->weight_desc(), + w_data, + rnn->y_descs(), + out_data, + rnn->last_h_desc(), + last_h_data, + rnn->last_c_desc(), + last_c_data, + workspace_data->data(), + workspace_size)); +#endif + } else { +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 + // for inference + // This interface is used when the input/output is padded. + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNForwardInferenceEx( + handle, + rnn->rnn_desc(), + rnn->x_seq_desc(), + x_data, + rnn->init_h_desc(), + init_h_data, + rnn->init_c_desc(), + init_c_data, + rnn->weight_desc(), + w_data, + rnn->y_seq_desc(), + out_data, + rnn->last_h_desc(), + last_h_data, + rnn->last_c_desc(), + last_c_data, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + workspace_data->data(), + workspace_size)); +#else + // CUDNN VERSION has to >=7.2.1 + PADDLE_THROW(phi::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardInferenceEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); +#endif + } +} + +template +void CudnnLSTMKernel( + const Context &ctx, + const DenseTensor &x, + const DenseTensor &init_h, + const DenseTensor &init_c, + const paddle::optional &w, + const paddle::optional> &weight_list, + const paddle::optional &sequence_length, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + DenseTensor *out, + DenseTensor *last_h, + DenseTensor *last_c, + DenseTensor *reserve, + DenseTensor *state_out) { + const T *x_data = x.data(); + const T *init_h_data = init_h.data(); + const T *init_c_data = init_c.data(); + + T *out_data = ctx.template Alloc(out); + T *last_h_data = ctx.template Alloc(last_h); + T *last_c_data = ctx.template Alloc(last_c); + + if (!is_test) { + if (seed == 0) { + // If not specify seed, use global Generator to generate seed. + int device_id = ctx.GetPlace().GetDeviceId(); + auto gen_cuda = phi::DefaultCUDAGenerator(device_id); + seed = static_cast(gen_cuda->Random64()); + } + } + + auto *running_sequence_length = sequence_length.get_ptr(); + bool has_seq_length = running_sequence_length != nullptr; + std::vector SequenceLength; + if (has_seq_length) { + SequenceLength = phi::GetVectorFromTensor(running_sequence_length); + } + + auto handle = ctx.cudnn_handle(); + + int seq_length = x.dims()[0]; + int batch_size = x.dims()[1]; + int input_size = x.dims()[2]; + bool state_initialized = state_out->IsInitialized() ? true : false; + + size_t workspace_size; + size_t reserve_size; + phi::DenseTensor weight_whole; + T *w_data = nullptr; + int weight_numel; + bool w_initialized = false; + auto place = ctx.GetPlace(); + auto stream = ctx.stream(); + auto *running_w = w.get_ptr(); + if (is_test && running_w != nullptr) { + w_initialized = running_w->IsInitialized() ? 
true : false; + weight_numel = running_w->numel(); + } + if (!w_initialized) { + auto running_weight_list = *weight_list.get_ptr(); + bool continuous = is_continuous>( + running_weight_list); + weight_numel = size_sum(running_weight_list); + + if (!continuous) { + LOG_FIRST_N(WARNING, 2) + << "If the memory space of the Input WeightList is not continuous, " + "less efficient calculation will be called. Please call " + "flatten_parameters() to make the input memory continuous."; + weight_whole.Resize({weight_numel}); + ctx.template Alloc(&weight_whole); + weight_to_tensor(place, stream, running_weight_list, &weight_whole); + w_data = weight_whole.data(); + if (is_test) { // maybe also reset small weights' ptr for training + int offset = 0; + for (size_t i = 0; i < running_weight_list.size(); ++i) { + size_t len = running_weight_list[i]->numel(); + auto dim = running_weight_list[i]->dims(); + const_cast(running_weight_list[i]) + ->ShareDataWith( + weight_whole.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } + } + } else { + w_data = const_cast(running_weight_list[0]->data()); + } + } else { + w_data = const_cast(running_w->data()); + } + + ScopedRNNBase rnn(seq_length, + batch_size, + input_size, + hidden_size, + num_layers, + dropout_prob, + seed, + weight_numel, + state_initialized, + is_bidirec); + rnn.Create(handle, + ctx.GetPlace(), + SequenceLength, + &workspace_size, + &reserve_size, + state_out); + + phi::DenseTensor workspace_data_; + workspace_data_.Resize({static_cast(workspace_size)}); + ctx.template Alloc(&workspace_data_); + + reserve->Resize({static_cast(reserve_size)}); + auto *reserve_data = ctx.template Alloc(reserve); + + if (is_test) { + LSTMInferece(has_seq_length, + handle, + seq_length, + &rnn, + x_data, + init_h_data, + init_c_data, + w_data, + out_data, + last_h_data, + last_c_data, + &workspace_data_, + workspace_size); + } else { + if (!has_seq_length) { +// for train +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenRNNForwardTraining( + handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.weight_desc(), + w_data, + rnn.y_descs(), + out_data, + rnn.last_h_desc(), + last_h_data, + rnn.last_c_desc(), + last_c_data, + workspace_data_.data(), + workspace_size, + reserve_data, + reserve_size)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForwardTraining(handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.weight_desc(), + w_data, + rnn.y_descs(), + out_data, + rnn.last_h_desc(), + last_h_data, + rnn.last_c_desc(), + last_c_data, + workspace_data_.data(), + workspace_size, + reserve_data, + reserve_size)); +#endif + } else { +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 + // for train + // This interface is used when the input/output is padded. 
+ PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNForwardTrainingEx( + handle, + rnn.rnn_desc(), + rnn.x_seq_desc(), + x_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.weight_desc(), + w_data, + rnn.y_seq_desc(), + out_data, + rnn.last_h_desc(), + last_h_data, + rnn.last_c_desc(), + last_c_data, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + workspace_data_.data(), + workspace_size, + reserve_data, + reserve_size)); +#else + PADDLE_THROW(phi::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardTrainingEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); +#endif + } + } +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) { + kernel->OutputAt(3).SetDataType(phi::DataType::UINT8); + kernel->OutputAt(4).SetDataType(phi::DataType::UINT8); +} +#else +PD_REGISTER_KERNEL( + cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float, double) { + kernel->OutputAt(3).SetDataType(phi::DataType::UINT8); + kernel->OutputAt(4).SetDataType(phi::DataType::UINT8); +} +#endif diff --git a/paddle/phi/kernels/gpu/cudnn_lstm_utils.h b/paddle/phi/kernels/gpu/cudnn_lstm_utils.h new file mode 100644 index 00000000000..e5fc5184945 --- /dev/null +++ b/paddle/phi/kernels/gpu/cudnn_lstm_utils.h @@ -0,0 +1,76 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/generator.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/phi/kernels/gpu/cudnn_lstm_cache.h" +#endif +#ifdef PADDLE_WITH_HIP +#include "paddle/phi/kernels/gpu/miopen_lstm_cache.h" +#endif + +namespace phi { + +template +inline bool is_continuous(const Type &weight_list) { + bool continuous = true; + for (size_t i = 0; i < weight_list.size() - 1; ++i) { + auto *in_data = weight_list[i]->template data(); + auto *in_after_data = weight_list[i + 1]->template data(); + auto in_size = weight_list[i]->numel(); + bool temp = in_data + in_size == in_after_data; + continuous = continuous && temp; + } + return continuous; +} + +inline int size_sum(const std::vector &weight_list) { + int size = 0; + for (size_t i = 0; i < weight_list.size(); ++i) { + auto in_size = weight_list[i]->numel(); + size += in_size; + } + return size; +} + +template +inline void weight_to_tensor( + const phi::Place &place, + gpuStream_t stream, + const std::vector &weight_list, + phi::DenseTensor *weight) { + auto weight_data = weight->data(); + int weight_offset = 0; + for (size_t i = 0; i < weight_list.size(); ++i) { + const T *in_data = weight_list[i]->data(); + auto in_size = weight_list[i]->numel(); + + memory_utils::Copy(weight->place(), + weight_data + weight_offset, + weight_list[i]->place(), + in_data, + in_size * sizeof(T), + stream); + weight_offset += in_size; + } +} + +} // namespace phi diff --git a/paddle/fluid/operators/miopen_lstm_cache.h b/paddle/phi/kernels/gpu/miopen_lstm_cache.h similarity index 71% rename from paddle/fluid/operators/miopen_lstm_cache.h rename to paddle/phi/kernels/gpu/miopen_lstm_cache.h index a9a6482fd48..1ce6acc44fc 100644 --- a/paddle/fluid/operators/miopen_lstm_cache.h +++ b/paddle/phi/kernels/gpu/miopen_lstm_cache.h @@ -16,11 +16,13 @@ limitations under the License. */ #include -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/dynload/miopen.h" +#include "paddle/phi/backends/gpu/forwards.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/core/dense_tensor.h" -namespace paddle { -namespace operators { +namespace phi { class ScopedRNNBase { public: @@ -47,13 +49,13 @@ class ScopedRNNBase { template void Create(const miopenHandle_t& handle, - const platform::Place& place, + const phi::Place& place, const std::vector& sequence_length, size_t* workspace_size, size_t* reserve_size, phi::DenseTensor* dropout_state) { int numDirections = is_bidirec_ ? 
2 : 1; - miopenDataType_t miopen_type = platform::CudnnDataType::type; + miopenDataType_t miopen_type = phi::backends::gpu::CudnnDataType::type; // ------------------- miopen x, y descriptors --------------------- std::vector dims_x = {batch_size_, input_size_, 1}; @@ -78,9 +80,11 @@ class ScopedRNNBase { size_t state_size; if (!initialized_) { PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenDropoutGetStatesSize(handle, &state_size)); - dropout_state->mutable_data({static_cast(state_size)}, - place); + phi::dynload::miopenDropoutGetStatesSize(handle, &state_size)); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* dev_ctx = reinterpret_cast(pool.Get(place)); + dropout_state->Resize({static_cast(state_size)}); + dev_ctx->template Alloc(dropout_state); } dropout_desc_.descriptor(handle, place, @@ -91,7 +95,7 @@ class ScopedRNNBase { state_size); // ------------------- miopen rnn descriptors --------------------- - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSetRNNDescriptor_V2( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSetRNNDescriptor_V2( rnn_desc_.desc(), hidden_size_, num_layers_, @@ -105,31 +109,28 @@ class ScopedRNNBase { // ------------------- miopen weights_size --------------------- size_t weights_size_; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenGetRNNParamsSize( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenGetRNNParamsSize( handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, miopen_type)); PADDLE_ENFORCE_EQ( weights_size_, sizeof(T) * weight_numel_, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The miopen lstm and setting weight size should be same.")); // ------------------- miopen weight descriptors --------------------- - platform::DataLayout layout = platform::DataLayout::kNCHW; + phi::backends::gpu::DataLayout layout = + phi::backends::gpu::DataLayout::kNCHW; int dim_tmp = weights_size_ / sizeof(T); std::vector dim_w = {dim_tmp, 1, 1}; weight_desc_.descriptor(layout, dim_w); // ------------------- miopen workspace, reserve size --------------------- PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenGetRNNWorkspaceSize(handle, - rnn_desc_.desc(), - seq_length_, - x_descs_.data(), - workspace_size)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenGetRNNTrainingReserveSize(handle, - rnn_desc_.desc(), - seq_length_, - x_descs_.data(), - reserve_size)); + phi::dynload::miopenGetRNNWorkspaceSize(handle, + rnn_desc_.desc(), + seq_length_, + x_descs_.data(), + workspace_size)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenGetRNNTrainingReserveSize( + handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), reserve_size)); } miopenTensorDescriptor_t* x_descs() { return x_descs_.data(); } miopenTensorDescriptor_t* y_descs() { return y_descs_.data(); } @@ -155,16 +156,15 @@ class ScopedRNNBase { std::vector x_descs_; std::vector y_descs_; - platform::ScopedTensorDescriptor x_desc_; - platform::ScopedTensorDescriptor y_desc_; - platform::ScopedTensorDescriptor init_h_desc_; - platform::ScopedTensorDescriptor init_c_desc_; - platform::ScopedTensorDescriptor last_h_desc_; - platform::ScopedTensorDescriptor last_c_desc_; - platform::ScopedDropoutDescriptor dropout_desc_; - platform::ScopedFilterDescriptor weight_desc_; - platform::ScopedRNNDescriptor rnn_desc_; + phi::backends::gpu::ScopedTensorDescriptor x_desc_; + phi::backends::gpu::ScopedTensorDescriptor y_desc_; + phi::backends::gpu::ScopedTensorDescriptor init_h_desc_; + phi::backends::gpu::ScopedTensorDescriptor init_c_desc_; + 
phi::backends::gpu::ScopedTensorDescriptor last_h_desc_; + phi::backends::gpu::ScopedTensorDescriptor last_c_desc_; + phi::backends::gpu::ScopedDropoutDescriptor dropout_desc_; + phi::backends::gpu::ScopedFilterDescriptor weight_desc_; + phi::backends::gpu::ScopedRNNDescriptor rnn_desc_; }; -} // namespace operators -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/ops/compat/cudnn_lstm_sig.cc b/paddle/phi/ops/compat/cudnn_lstm_sig.cc new file mode 100644 index 00000000000..83e61b396ee --- /dev/null +++ b/paddle/phi/ops/compat/cudnn_lstm_sig.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature CudnnLSTMOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "cudnn_lstm", + {"Input", "InitH", "InitC", "W", "WeightList", "SequenceLength"}, + {"dropout_prob", + "is_bidirec", + "hidden_size", + "num_layers", + "is_test", + "seed"}, + {"Out", "LastH", "LastC", "Reserve", "StateOut"}); +} + +KernelSignature CudnnLSTMGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "cudnn_lstm_grad", + {"Input", + "InitH", + "InitC", + "WeightList", + "SequenceLength", + "Out", + "Reserve", + "StateOut", + "Out@GRAD", + "LastH@GRAD", + "LastC@GRAD"}, + {"dropout_prob", + "is_bidirec", + "hidden_size", + "num_layers", + "is_test", + "seed"}, + {"Input@GRAD", "InitH@GRAD", "InitC@GRAD", "WeightList@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm, phi::CudnnLSTMOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm_grad, + phi::CudnnLSTMGradOpArgumentMapping); -- GitLab