Unverified · Commit 52889e38 authored by H huangjiyi and committed by GitHub

move cudnn_lstm kernel to phi (#53730)

* update

* fix bug

* test

* test

* update

* update mutable_data

* fix bug

* update

* fix bug

* update output type reg

* update

* update
Parent 5b054d2f
@@ -292,15 +292,6 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T, typename DeviceContext>
class NotImpleKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"CPU is not support for this kernel now. Will be add in the future"));
}
};
} // namespace operators
} // namespace paddle
@@ -312,9 +303,6 @@ REGISTER_OPERATOR(cudnn_lstm,
ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm, CPU, ALL_LAYOUT, ops::NotImpleKernel, float) {}
// TODO(Shixiaowei02) Add ModifyInput support
REGISTER_OP_VERSION(cudnn_lstm)
.AddCheckpoint(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/miopen_lstm_cache.h"
#endif
namespace paddle {
namespace operators {
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
bool continuous = true;
for (size_t i = 0; i < weight_list.size() - 1; ++i) {
auto *in_data = weight_list[i]->template data<T>();
auto *in_after_data = weight_list[i + 1]->template data<T>();
auto in_size = weight_list[i]->numel();
bool temp = in_data + in_size == in_after_data;
continuous = continuous && temp;
}
return continuous;
}
int size_sum(const std::vector<const phi::DenseTensor *> &weight_list) {
int size = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
auto in_size = weight_list[i]->numel();
size += in_size;
}
return size;
}
template <typename T>
void weight_to_tensor(const platform::Place &place,
gpuStream_t stream,
const std::vector<const phi::DenseTensor *> &weight_list,
phi::DenseTensor *weight) {
auto weight_data = weight->data<T>();
int weight_offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
const T *in_data = weight_list[i]->data<T>();
auto in_size = weight_list[i]->numel();
memory::Copy(weight->place(),
weight_data + weight_offset,
weight_list[i]->place(),
in_data,
in_size * sizeof(T),
stream);
weight_offset += in_size;
}
}
template <typename T>
void weight_to_tensor_list(
const platform::Place &place,
gpuStream_t stream,
std::vector<phi::DenseTensor *> *weight_grad,
const std::vector<const phi::DenseTensor *> &weight_input,
const phi::DenseTensor *weight) {
int weight_offset = 0;
auto *weight_data = weight->data<T>();
for (size_t i = 0; i < weight_input.size(); ++i) {
auto in_size = weight_input[i]->numel();
T *weight_grad_data = (*weight_grad)[i]->mutable_data<T>(place);
const T *src = weight_data + weight_offset;
memory::Copy((*weight_grad)[i]->place(),
weight_grad_data,
weight->place(),
src,
in_size * sizeof(T),
stream);
weight_offset += in_size;
}
}
template <typename T>
#ifdef PADDLE_WITH_HIP
void LSTMInferece(const bool &has_seq_length,
const miopenHandle_t &handle,
#else
void LSTMInferece(const bool &has_seq_length,
const cudnnHandle_t &handle,
#endif
const int &seq_length,
ScopedRNNBase *rnn,
const T *x_data,
const T *init_h_data,
const T *init_c_data,
const T *w_data,
T *out_data,
T *last_h_data,
T *last_c_data,
phi::DenseTensor *workspace_data,
const size_t &workspace_size) {
if (!has_seq_length) {
// for inference
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardInference(
handle,
rnn->rnn_desc(),
seq_length,
rnn->x_descs(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_descs(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
workspace_data->data<uint8_t>(),
workspace_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle,
rnn->rnn_desc(),
seq_length,
rnn->x_descs(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_descs(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
workspace_data->data<uint8_t>(),
workspace_size));
#endif
} else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
handle,
rnn->rnn_desc(),
rnn->x_seq_desc(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_seq_desc(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
workspace_data->data<uint8_t>(),
workspace_size));
#else
// The cuDNN version has to be >= 7.2.1
PADDLE_THROW(platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
template <typename T, typename DeviceContext>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const phi::DenseTensor *x = ctx.Input<phi::DenseTensor>("Input");
const phi::DenseTensor *init_h = ctx.Input<phi::DenseTensor>("InitH");
const phi::DenseTensor *init_c = ctx.Input<phi::DenseTensor>("InitC");
phi::DenseTensor *out = ctx.Output<phi::DenseTensor>("Out");
phi::DenseTensor *last_h = ctx.Output<phi::DenseTensor>("LastH");
phi::DenseTensor *last_c = ctx.Output<phi::DenseTensor>("LastC");
phi::DenseTensor *reserve = ctx.Output<phi::DenseTensor>("Reserve");
phi::DenseTensor *state_out = ctx.Output<phi::DenseTensor>("StateOut");
const T *x_data = x->data<T>();
const T *init_h_data = init_h->data<T>();
const T *init_c_data = init_c->data<T>();
T *out_data = out->mutable_data<T>(ctx.GetPlace());
T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");
if (!is_test) {
if (seed == 0) {
// If no seed is specified, use the global Generator to generate one.
int device_id = ctx.GetPlace().GetDeviceId();
auto gen_cuda = phi::DefaultCUDAGenerator(device_id);
seed = static_cast<int>(gen_cuda->Random64());
}
// Otherwise, use the seed specified by `ctx.Attr<int>("seed")`.
}
bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<phi::DenseTensor>("SequenceLength");
SequenceLength = phi::GetVectorFromTensor<int>(sequence_length);
}
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
auto handle = dev_ctx.cudnn_handle();
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
bool state_initialized = state_out->IsInitialized() ? true : false;
size_t workspace_size;
size_t reserve_size;
phi::DenseTensor weight_whole;
T *w_data = nullptr;
int weight_numel;
bool w_initialized = false;
auto place = ctx.GetPlace();
auto stream =
reinterpret_cast<const phi::GPUContext &>(ctx.device_context())
.stream();
if (is_test && ctx.HasInput("W")) {
auto *W = ctx.Input<phi::DenseTensor>("W");
w_initialized = W->IsInitialized() ? true : false;
weight_numel = W->numel();
}
if (!w_initialized) {
auto weight_list = ctx.MultiInput<phi::DenseTensor>("WeightList");
bool continuous =
is_continuous<T, std::vector<const phi::DenseTensor *>>(weight_list);
weight_numel = size_sum(weight_list);
if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not continuous, "
"less efficient calculation will be called. Please call "
"flatten_parameters() to make the input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
w_data = weight_whole.data<T>();
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
size_t len = weight_list[i]->numel();
auto dim = weight_list[i]->dims();
const_cast<phi::DenseTensor *>(weight_list[i])
->ShareDataWith(
weight_whole.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
}
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
} else {
auto *W = ctx.Input<phi::DenseTensor>("W");
w_data = const_cast<T *>(W->data<T>());
}
ScopedRNNBase rnn(seq_length,
batch_size,
input_size,
hidden_size,
num_layers,
dropout_prob,
seed,
weight_numel,
state_initialized,
is_bidirec);
rnn.Create<T>(handle,
ctx.GetPlace(),
SequenceLength,
&workspace_size,
&reserve_size,
state_out);
phi::DenseTensor workspace_data_;
workspace_data_.mutable_data<uint8_t>(
{static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
auto *reserve_data = reserve->mutable_data<uint8_t>(
{static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
if (is_test) {
LSTMInferece<T>(has_seq_length,
handle,
seq_length,
&rnn,
x_data,
init_h_data,
init_c_data,
w_data,
out_data,
last_h_data,
last_c_data,
&workspace_data_,
workspace_size);
} else {
if (!has_seq_length) {
// for train
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardTraining(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_descs(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_descs(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#endif
} else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTrainingEx(
handle,
rnn.rnn_desc(),
rnn.x_seq_desc(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_seq_desc(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#else
PADDLE_THROW(platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
}
};
template <typename T, typename DeviceContext>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input = ctx.Input<phi::DenseTensor>("Input");
auto *init_h = ctx.Input<phi::DenseTensor>("InitH");
auto *init_c = ctx.Input<phi::DenseTensor>("InitC");
auto *reserve = ctx.Input<phi::DenseTensor>("Reserve");
auto *state_out = ctx.Input<phi::DenseTensor>("StateOut");
auto weight_list = ctx.MultiInput<phi::DenseTensor>("WeightList");
auto *out = ctx.Input<phi::DenseTensor>("Out");
auto *out_grad = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *last_h_grad =
ctx.Input<phi::DenseTensor>(framework::GradVarName("LastH"));
auto *last_c_grad =
ctx.Input<phi::DenseTensor>(framework::GradVarName("LastC"));
auto *in_grad =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Input"));
auto *init_h_grad =
ctx.Output<phi::DenseTensor>(framework::GradVarName("InitH"));
auto *init_c_grad =
ctx.Output<phi::DenseTensor>(framework::GradVarName("InitC"));
auto weight_grad_list =
ctx.MultiOutput<phi::DenseTensor>(framework::GradVarName("WeightList"));
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
auto handle = dev_ctx.cudnn_handle();
auto input_dims = input->dims();
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();
auto *init_h_data = init_h->data<T>();
auto *init_c_data = init_c->data<T>();
auto *out_data = out->data<T>();
auto *out_grad_data = out_grad->data<T>();
auto *last_h_grad_data = last_h_grad->data<T>();
auto *last_c_grad_data = last_c_grad->data<T>();
auto place = ctx.GetPlace();
int weight_numel = size_sum(weight_list);
bool continuous =
is_continuous<T, std::vector<const phi::DenseTensor *>>(weight_list);
auto stream =
reinterpret_cast<const phi::GPUContext &>(ctx.device_context())
.stream();
phi::DenseTensor weight_whole;
T *weight_data = nullptr;
if (!continuous) {
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(weight_list[0]->data<T>());
}
phi::DenseTensor weight_grad;
phi::funcs::SetConstant<phi::GPUContext, T> zero;
weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
T *weight_grad_data = weight_grad.data<T>();
int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
auto dim = weight_grad_list[i]->dims();
weight_grad_list[i]
->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
auto *in_grad_data = in_grad->data<T>();
if (init_h_grad) init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad ? init_h_grad->data<T>() : nullptr;
if (init_c_grad) init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad ? init_c_grad->data<T>() : nullptr;
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
int seed = ctx.Attr<int>("seed");
bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<phi::DenseTensor>("SequenceLength");
SequenceLength = phi::GetVectorFromTensor<int>(sequence_length);
}
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
int input_size = input->dims()[2];
size_t workspace_size;
size_t reserve_size;
ScopedRNNBase rnn(seq_length,
batch_size,
input_size,
hidden_size,
num_layers,
dropout_prob,
seed,
weight_numel,
true,
is_bidirec);
rnn.Create<T>(handle,
ctx.GetPlace(),
SequenceLength,
&workspace_size,
&reserve_size,
const_cast<phi::DenseTensor *>(state_out));
phi::DenseTensor workspace_data_;
workspace_data_.mutable_data<uint8_t>(
{static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
const uint8_t *reserve_data = reserve->data<uint8_t>();
if (!has_seq_length) {
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardData(
handle,
rnn.rnn_desc(),
seq_length,
rnn.y_descs(),
out_data,
rnn.y_descs(),
out_grad_data,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_descs(),
in_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardWeights(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
input->data<T>(),
rnn.init_h_desc(),
init_h->data<T>(),
rnn.y_descs(),
out->data<T>(),
rnn.weight_desc(),
weight_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle,
rnn.rnn_desc(),
seq_length,
rnn.y_descs(),
out_data,
rnn.y_descs(),
out_grad_data,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_descs(),
in_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
input->data<T>(),
rnn.init_h_desc(),
init_h->data<T>(),
rnn.y_descs(),
out->data<T>(),
workspace_data_.data<uint8_t>(),
workspace_size,
rnn.weight_desc(),
weight_grad_data,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#endif
} else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
handle,
rnn.rnn_desc(),
rnn.y_seq_desc(),
out_data,
rnn.y_seq_desc(),
out_grad_data,
nullptr,
nullptr,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_seq_desc(),
in_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
nullptr,
nullptr,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
handle,
rnn.rnn_desc(),
rnn.x_seq_desc(),
input->data<T>(),
rnn.init_h_desc(),
init_h->data<T>(),
rnn.y_seq_desc(),
out->data<T>(),
workspace_data_.data<uint8_t>(),
workspace_size,
rnn.weight_desc(),
weight_grad_data,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#else
PADDLE_THROW(platform::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
#endif
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm_grad, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUGradKernel, float) {}
#else
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(cudnn_lstm_grad,
GPU,
ALL_LAYOUT,
ops::CudnnLSTMGPUGradKernel,
float,
double) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cudnn_lstm_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CudnnLSTMKernel(
const Context& ctx,
const DenseTensor& x,
const DenseTensor& init_h,
const DenseTensor& init_c,
const paddle::optional<DenseTensor>& w,
const paddle::optional<std::vector<const DenseTensor*>>& weight_list,
const paddle::optional<DenseTensor>& sequence_length,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
DenseTensor* out,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* reserve,
DenseTensor* state_out) {
PADDLE_THROW(phi::errors::Unimplemented(
"CPU is not support for cudnn_lstm now. Will be add in the future"));
}
} // namespace phi
PD_REGISTER_KERNEL(cudnn_lstm, CPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void CudnnLSTMGradKernel(
const Context& ctx,
const DenseTensor& x,
const DenseTensor& init_h,
const DenseTensor& init_c,
const paddle::optional<std::vector<const DenseTensor*>>& weight_list,
const paddle::optional<DenseTensor>& sequence_length,
const DenseTensor& out,
const DenseTensor& reserve,
const DenseTensor& state_out,
const DenseTensor& out_grad,
const DenseTensor& last_h_grad,
const DenseTensor& last_c_grad,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
DenseTensor* x_grad,
DenseTensor* init_h_grad,
DenseTensor* init_c_grad,
std::vector<DenseTensor*> weight_list_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void CudnnLSTMKernel(
const Context& ctx,
const DenseTensor& x,
const DenseTensor& init_h,
const DenseTensor& init_c,
const paddle::optional<DenseTensor>& w,
const paddle::optional<std::vector<const DenseTensor*>>& weight_list,
const paddle::optional<DenseTensor>& sequence_length,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
DenseTensor* out,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* reserve,
DenseTensor* state_out);
} // namespace phi
@@ -16,12 +16,13 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace operators {
namespace phi {
class ScopedRNNBase {
public:
@@ -48,13 +49,13 @@ class ScopedRNNBase {
template <typename T>
void Create(const cudnnHandle_t& handle,
const platform::Place& place,
const phi::Place& place,
const std::vector<int>& sequence_length,
size_t* workspace_size,
size_t* reserve_size,
phi::DenseTensor* dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
cudnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
cudnnDataType_t cudnn_type = phi::backends::gpu::CudnnDataType<T>::type;
// ------------------- cudnn x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
@@ -91,9 +92,11 @@ class ScopedRNNBase {
size_t state_size;
if (!initialized_) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
phi::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
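// Allocate the dropout state buffer through the GPUContext obtained from the
// global DeviceContextPool for the given place.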
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(pool.Get(place));
dropout_state->Resize({static_cast<int64_t>(state_size)});
dev_ctx->template Alloc<uint8_t>(dropout_state);
}
dropout_desc_.descriptor(handle,
place,
@@ -104,7 +107,7 @@ class ScopedRNNBase {
state_size);
// ------------------- cudnn rnn descriptors ---------------------
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNDescriptor_v6(
handle,
rnn_desc_.desc(),
hidden_size_,
@@ -118,38 +121,35 @@ class ScopedRNNBase {
#if CUDNN_VERSION >= 7201
if (!sequence_length.empty()) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNPaddingMode(
rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
}
#endif
// ------------------- cudnn weights_size ---------------------
size_t weights_size_;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type));
PADDLE_ENFORCE_EQ(
weights_size_,
sizeof(T) * weight_numel_,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The cudnn lstm and setting weight size should be same."));
// ------------------- cudnn weight descriptors ---------------------
platform::DataLayout layout = platform::DataLayout::kNCHW;
phi::backends::gpu::DataLayout layout =
phi::backends::gpu::DataLayout::kNCHW;
int dim_tmp = weights_size_ / sizeof(T);
std::vector<int> dim_w = {dim_tmp, 1, 1};
weight_desc_.descriptor<T>(layout, dim_w);
// ------------------- cudnn workspace, reserve size ---------------------
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetRNNWorkspaceSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetRNNTrainingReserveSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
reserve_size));
phi::dynload::cudnnGetRNNWorkspaceSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), reserve_size));
}
cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); }
cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); }
@@ -179,20 +179,19 @@ class ScopedRNNBase {
std::vector<cudnnTensorDescriptor_t> x_descs_;
std::vector<cudnnTensorDescriptor_t> y_descs_;
platform::ScopedTensorDescriptor x_desc_;
platform::ScopedTensorDescriptor y_desc_;
phi::backends::gpu::ScopedTensorDescriptor x_desc_;
phi::backends::gpu::ScopedTensorDescriptor y_desc_;
#if CUDNN_VERSION >= 7201
platform::ScopedRNNTensorDescriptor x_seq_desc_;
platform::ScopedRNNTensorDescriptor y_seq_desc_;
phi::backends::gpu::ScopedRNNTensorDescriptor x_seq_desc_;
phi::backends::gpu::ScopedRNNTensorDescriptor y_seq_desc_;
#endif
platform::ScopedTensorDescriptor init_h_desc_;
platform::ScopedTensorDescriptor init_c_desc_;
platform::ScopedTensorDescriptor last_h_desc_;
platform::ScopedTensorDescriptor last_c_desc_;
platform::ScopedDropoutDescriptor dropout_desc_;
platform::ScopedFilterDescriptor weight_desc_;
platform::ScopedRNNDescriptor rnn_desc_;
phi::backends::gpu::ScopedTensorDescriptor init_h_desc_;
phi::backends::gpu::ScopedTensorDescriptor init_c_desc_;
phi::backends::gpu::ScopedTensorDescriptor last_h_desc_;
phi::backends::gpu::ScopedTensorDescriptor last_c_desc_;
phi::backends::gpu::ScopedDropoutDescriptor dropout_desc_;
phi::backends::gpu::ScopedFilterDescriptor weight_desc_;
phi::backends::gpu::ScopedRNNDescriptor rnn_desc_;
};
} // namespace operators
} // namespace paddle
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cudnn_lstm_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/cudnn_lstm_utils.h"
namespace phi {
template <typename T, typename Context>
void CudnnLSTMGradKernel(
const Context &ctx,
const DenseTensor &x,
const DenseTensor &init_h,
const DenseTensor &init_c,
const paddle::optional<std::vector<const DenseTensor *>> &weight_list,
const paddle::optional<DenseTensor> &sequence_length,
const DenseTensor &out,
const DenseTensor &reserve,
const DenseTensor &state_out,
const DenseTensor &out_grad,
const DenseTensor &last_h_grad,
const DenseTensor &last_c_grad,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
DenseTensor *x_grad,
DenseTensor *init_h_grad,
DenseTensor *init_c_grad,
std::vector<DenseTensor *> weight_grad_list) {
auto input_dims = x.dims();
auto init_h_dims = init_h.dims();
auto init_c_dims = init_c.dims();
auto *init_h_data = init_h.data<T>();
auto *init_c_data = init_c.data<T>();
auto *out_data = out.data<T>();
auto *out_grad_data = out_grad.data<T>();
auto *last_h_grad_data = last_h_grad.data<T>();
auto *last_c_grad_data = last_c_grad.data<T>();
auto running_weight_list = *weight_list.get_ptr();
int weight_numel = size_sum(running_weight_list);
bool continuous = is_continuous<T, std::vector<const phi::DenseTensor *>>(
running_weight_list);
auto handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();
auto stream = ctx.stream();
phi::DenseTensor weight_whole;
T *weight_data = nullptr;
if (!continuous) {
weight_whole.Resize({weight_numel});
ctx.template Alloc<T>(&weight_whole);
weight_to_tensor<T>(place, stream, running_weight_list, &weight_whole);
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(running_weight_list[0]->data<T>());
}
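// Accumulate all weight gradients into one flat, zero-initialized buffer;
// each tensor in weight_grad_list then shares a slice of that buffer.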
phi::DenseTensor weight_grad;
phi::funcs::SetConstant<phi::GPUContext, T> zero;
weight_grad.Resize({weight_numel});
ctx.template Alloc<T>(&weight_grad);
zero(ctx, &weight_grad, static_cast<T>(0.0));
T *weight_grad_data = weight_grad.data<T>();
int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
auto dim = weight_grad_list[i]->dims();
weight_grad_list[i]
->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
x_grad->Resize(input_dims);
ctx.template Alloc<T>(x_grad);
auto *in_grad_data = x_grad->data<T>();
if (init_h_grad) {
init_h_grad->Resize(init_h_dims);
ctx.template Alloc<T>(init_h_grad);
}
auto *init_h_grad_data = init_h_grad ? init_h_grad->data<T>() : nullptr;
if (init_c_grad) {
init_c_grad->Resize(init_c_dims);
ctx.template Alloc<T>(init_c_grad);
}
auto *init_c_grad_data = init_c_grad ? init_c_grad->data<T>() : nullptr;
auto running_seq_length = sequence_length.get_ptr();
bool has_seq_length = running_seq_length != nullptr;
std::vector<int> SequenceLength;
if (has_seq_length) {
SequenceLength = phi::GetVectorFromTensor<int>(running_seq_length);
}
int seq_length = input_dims[0];
int batch_size = x.dims()[1];
int input_size = x.dims()[2];
size_t workspace_size;
size_t reserve_size;
ScopedRNNBase rnn(seq_length,
batch_size,
input_size,
hidden_size,
num_layers,
dropout_prob,
seed,
weight_numel,
true,
is_bidirec);
rnn.Create<T>(handle,
ctx.GetPlace(),
SequenceLength,
&workspace_size,
&reserve_size,
const_cast<phi::DenseTensor *>(&state_out));
phi::DenseTensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
ctx.template Alloc<uint8_t>(&workspace_data_);
const uint8_t *reserve_data = reserve.data<uint8_t>();
if (!has_seq_length) {
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenRNNBackwardData(handle,
rnn.rnn_desc(),
seq_length,
rnn.y_descs(),
out_data,
rnn.y_descs(),
out_grad_data,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_descs(),
in_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenRNNBackwardWeights(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x.data<T>(),
rnn.init_h_desc(),
init_h.data<T>(),
rnn.y_descs(),
out.data<T>(),
rnn.weight_desc(),
weight_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnRNNBackwardData(handle,
rnn.rnn_desc(),
seq_length,
rnn.y_descs(),
out_data,
rnn.y_descs(),
out_grad_data,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_descs(),
in_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeights(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x.data<T>(),
rnn.init_h_desc(),
init_h.data<T>(),
rnn.y_descs(),
out.data<T>(),
workspace_data_.data<uint8_t>(),
workspace_size,
rnn.weight_desc(),
weight_grad_data,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#endif
} else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardDataEx(
handle,
rnn.rnn_desc(),
rnn.y_seq_desc(),
out_data,
rnn.y_seq_desc(),
out_grad_data,
nullptr,
nullptr,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_seq_desc(),
in_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
nullptr,
nullptr,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeightsEx(
handle,
rnn.rnn_desc(),
rnn.x_seq_desc(),
x.data<T>(),
rnn.init_h_desc(),
init_h.data<T>(),
rnn.y_seq_desc(),
out.data<T>(),
workspace_data_.data<uint8_t>(),
workspace_size,
rnn.weight_desc(),
weight_grad_data,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#else
PADDLE_THROW(phi::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
#endif
}
}
} // namespace phi
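// MIOpen does not support double, so the HIP build registers a float-only
// grad kernel.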
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(
cudnn_lstm_grad, GPU, ALL_LAYOUT, phi::CudnnLSTMGradKernel, float) {}
#else
PD_REGISTER_KERNEL(
cudnn_lstm_grad, GPU, ALL_LAYOUT, phi::CudnnLSTMGradKernel, float, double) {
}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cudnn_lstm_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/cudnn_lstm_utils.h"
namespace phi {
template <typename T>
#ifdef PADDLE_WITH_HIP
void LSTMInferece(const bool &has_seq_length,
const miopenHandle_t &handle,
#else
void LSTMInferece(const bool &has_seq_length,
const cudnnHandle_t &handle,
#endif
const int &seq_length,
ScopedRNNBase *rnn,
const T *x_data,
const T *init_h_data,
const T *init_c_data,
const T *w_data,
T *out_data,
T *last_h_data,
T *last_c_data,
phi::DenseTensor *workspace_data,
const size_t &workspace_size) {
if (!has_seq_length) {
// for inference
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenRNNForwardInference(handle,
rnn->rnn_desc(),
seq_length,
rnn->x_descs(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_descs(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
workspace_data->data<uint8_t>(),
workspace_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnRNNForwardInference(handle,
rnn->rnn_desc(),
seq_length,
rnn->x_descs(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_descs(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
workspace_data->data<uint8_t>(),
workspace_size));
#endif
} else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNForwardInferenceEx(
handle,
rnn->rnn_desc(),
rnn->x_seq_desc(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_seq_desc(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
workspace_data->data<uint8_t>(),
workspace_size));
#else
// The cuDNN version has to be >= 7.2.1
PADDLE_THROW(phi::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
template <typename T, typename Context>
void CudnnLSTMKernel(
const Context &ctx,
const DenseTensor &x,
const DenseTensor &init_h,
const DenseTensor &init_c,
const paddle::optional<DenseTensor> &w,
const paddle::optional<std::vector<const DenseTensor *>> &weight_list,
const paddle::optional<DenseTensor> &sequence_length,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
DenseTensor *out,
DenseTensor *last_h,
DenseTensor *last_c,
DenseTensor *reserve,
DenseTensor *state_out) {
const T *x_data = x.data<T>();
const T *init_h_data = init_h.data<T>();
const T *init_c_data = init_c.data<T>();
T *out_data = ctx.template Alloc<T>(out);
T *last_h_data = ctx.template Alloc<T>(last_h);
T *last_c_data = ctx.template Alloc<T>(last_c);
if (!is_test) {
if (seed == 0) {
// If no seed is specified, use the global Generator to generate one.
int device_id = ctx.GetPlace().GetDeviceId();
auto gen_cuda = phi::DefaultCUDAGenerator(device_id);
seed = static_cast<int>(gen_cuda->Random64());
}
}
auto *running_sequence_length = sequence_length.get_ptr();
bool has_seq_length = running_sequence_length != nullptr;
std::vector<int> SequenceLength;
if (has_seq_length) {
SequenceLength = phi::GetVectorFromTensor<int>(running_sequence_length);
}
auto handle = ctx.cudnn_handle();
int seq_length = x.dims()[0];
int batch_size = x.dims()[1];
int input_size = x.dims()[2];
bool state_initialized = state_out->IsInitialized() ? true : false;
size_t workspace_size;
size_t reserve_size;
phi::DenseTensor weight_whole;
T *w_data = nullptr;
int weight_numel;
bool w_initialized = false;
auto place = ctx.GetPlace();
auto stream = ctx.stream();
auto *running_w = w.get_ptr();
if (is_test && running_w != nullptr) {
w_initialized = running_w->IsInitialized() ? true : false;
weight_numel = running_w->numel();
}
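// When the flat weight `W` is not provided or not initialized, flatten
// `WeightList` into the single contiguous buffer that cuDNN expects.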
if (!w_initialized) {
auto running_weight_list = *weight_list.get_ptr();
bool continuous = is_continuous<T, std::vector<const phi::DenseTensor *>>(
running_weight_list);
weight_numel = size_sum(running_weight_list);
if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not continuous, "
"less efficient calculation will be called. Please call "
"flatten_parameters() to make the input memory continuous.";
weight_whole.Resize({weight_numel});
ctx.template Alloc<T>(&weight_whole);
weight_to_tensor<T>(place, stream, running_weight_list, &weight_whole);
w_data = weight_whole.data<T>();
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < running_weight_list.size(); ++i) {
size_t len = running_weight_list[i]->numel();
auto dim = running_weight_list[i]->dims();
const_cast<phi::DenseTensor *>(running_weight_list[i])
->ShareDataWith(
weight_whole.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
}
} else {
w_data = const_cast<T *>(running_weight_list[0]->data<T>());
}
} else {
w_data = const_cast<T *>(running_w->data<T>());
}
ScopedRNNBase rnn(seq_length,
batch_size,
input_size,
hidden_size,
num_layers,
dropout_prob,
seed,
weight_numel,
state_initialized,
is_bidirec);
rnn.Create<T>(handle,
ctx.GetPlace(),
SequenceLength,
&workspace_size,
&reserve_size,
state_out);
phi::DenseTensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
ctx.template Alloc<uint8_t>(&workspace_data_);
reserve->Resize({static_cast<int64_t>(reserve_size)});
auto *reserve_data = ctx.template Alloc<uint8_t>(reserve);
if (is_test) {
LSTMInferece<T>(has_seq_length,
handle,
seq_length,
&rnn,
x_data,
init_h_data,
init_c_data,
w_data,
out_data,
last_h_data,
last_c_data,
&workspace_data_,
workspace_size);
} else {
if (!has_seq_length) {
// for train
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenRNNForwardTraining(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_descs(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnRNNForwardTraining(handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_descs(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#endif
} else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNForwardTrainingEx(
handle,
rnn.rnn_desc(),
rnn.x_seq_desc(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_seq_desc(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#else
PADDLE_THROW(phi::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
}
} // namespace phi
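// Reserve (output 3) and StateOut (output 4) hold raw byte buffers, so their
// data type is fixed to UINT8 regardless of T. MIOpen does not support
// double, so the HIP build registers float only.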
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) {
kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
}
#else
PD_REGISTER_KERNEL(
cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float, double) {
kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/kernels/gpu/cudnn_lstm_cache.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/phi/kernels/gpu/miopen_lstm_cache.h"
#endif
namespace phi {
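// Returns true when the tensors in weight_list occupy one contiguous block of
// memory, i.e. each tensor's data ends exactly where the next tensor's begins.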
template <typename T, typename Type>
inline bool is_continuous(const Type &weight_list) {
bool continuous = true;
for (size_t i = 0; i < weight_list.size() - 1; ++i) {
auto *in_data = weight_list[i]->template data<T>();
auto *in_after_data = weight_list[i + 1]->template data<T>();
auto in_size = weight_list[i]->numel();
bool temp = in_data + in_size == in_after_data;
continuous = continuous && temp;
}
return continuous;
}
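// Returns the total number of elements across all tensors in weight_list.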
inline int size_sum(const std::vector<const phi::DenseTensor *> &weight_list) {
int size = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
auto in_size = weight_list[i]->numel();
size += in_size;
}
return size;
}
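// Copies every tensor in weight_list back-to-back into the pre-allocated flat
// `weight` tensor, using copies issued on the given stream.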
template <typename T>
inline void weight_to_tensor(
const phi::Place &place,
gpuStream_t stream,
const std::vector<const phi::DenseTensor *> &weight_list,
phi::DenseTensor *weight) {
auto weight_data = weight->data<T>();
int weight_offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
const T *in_data = weight_list[i]->data<T>();
auto in_size = weight_list[i]->numel();
memory_utils::Copy(weight->place(),
weight_data + weight_offset,
weight_list[i]->place(),
in_data,
in_size * sizeof(T),
stream);
weight_offset += in_size;
}
}
} // namespace phi
@@ -16,11 +16,13 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/dynload/miopen.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace operators {
namespace phi {
class ScopedRNNBase {
public:
@@ -47,13 +49,13 @@ class ScopedRNNBase {
template <typename T>
void Create(const miopenHandle_t& handle,
const platform::Place& place,
const phi::Place& place,
const std::vector<int>& sequence_length,
size_t* workspace_size,
size_t* reserve_size,
phi::DenseTensor* dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
miopenDataType_t miopen_type = platform::CudnnDataType<T>::type;
miopenDataType_t miopen_type = phi::backends::gpu::CudnnDataType<T>::type;
// ------------------- miopen x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
@@ -78,9 +80,11 @@ class ScopedRNNBase {
size_t state_size;
if (!initialized_) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
phi::dynload::miopenDropoutGetStatesSize(handle, &state_size));
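// Allocate the dropout state buffer through the GPUContext obtained from the
// global DeviceContextPool for the given place.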
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(pool.Get(place));
dropout_state->Resize({static_cast<int64_t>(state_size)});
dev_ctx->template Alloc<uint8_t>(dropout_state);
}
dropout_desc_.descriptor(handle,
place,
@@ -91,7 +95,7 @@ class ScopedRNNBase {
state_size);
// ------------------- miopen rnn descriptors ---------------------
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSetRNNDescriptor_V2(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSetRNNDescriptor_V2(
rnn_desc_.desc(),
hidden_size_,
num_layers_,
@@ -105,31 +109,28 @@ class ScopedRNNBase {
// ------------------- miopen weights_size ---------------------
size_t weights_size_;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenGetRNNParamsSize(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenGetRNNParamsSize(
handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, miopen_type));
PADDLE_ENFORCE_EQ(
weights_size_,
sizeof(T) * weight_numel_,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The miopen lstm and setting weight size should be same."));
// ------------------- miopen weight descriptors ---------------------
platform::DataLayout layout = platform::DataLayout::kNCHW;
phi::backends::gpu::DataLayout layout =
phi::backends::gpu::DataLayout::kNCHW;
int dim_tmp = weights_size_ / sizeof(T);
std::vector<int> dim_w = {dim_tmp, 1, 1};
weight_desc_.descriptor<T>(layout, dim_w);
// ------------------- miopen workspace, reserve size ---------------------
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenGetRNNWorkspaceSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenGetRNNTrainingReserveSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
reserve_size));
phi::dynload::miopenGetRNNWorkspaceSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenGetRNNTrainingReserveSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), reserve_size));
}
miopenTensorDescriptor_t* x_descs() { return x_descs_.data(); }
miopenTensorDescriptor_t* y_descs() { return y_descs_.data(); }
@@ -155,16 +156,15 @@ class ScopedRNNBase {
std::vector<miopenTensorDescriptor_t> x_descs_;
std::vector<miopenTensorDescriptor_t> y_descs_;
platform::ScopedTensorDescriptor x_desc_;
platform::ScopedTensorDescriptor y_desc_;
platform::ScopedTensorDescriptor init_h_desc_;
platform::ScopedTensorDescriptor init_c_desc_;
platform::ScopedTensorDescriptor last_h_desc_;
platform::ScopedTensorDescriptor last_c_desc_;
platform::ScopedDropoutDescriptor dropout_desc_;
platform::ScopedFilterDescriptor weight_desc_;
platform::ScopedRNNDescriptor rnn_desc_;
phi::backends::gpu::ScopedTensorDescriptor x_desc_;
phi::backends::gpu::ScopedTensorDescriptor y_desc_;
phi::backends::gpu::ScopedTensorDescriptor init_h_desc_;
phi::backends::gpu::ScopedTensorDescriptor init_c_desc_;
phi::backends::gpu::ScopedTensorDescriptor last_h_desc_;
phi::backends::gpu::ScopedTensorDescriptor last_c_desc_;
phi::backends::gpu::ScopedDropoutDescriptor dropout_desc_;
phi::backends::gpu::ScopedFilterDescriptor weight_desc_;
phi::backends::gpu::ScopedRNNDescriptor rnn_desc_;
};
} // namespace operators
} // namespace paddle
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CudnnLSTMOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"cudnn_lstm",
{"Input", "InitH", "InitC", "W", "WeightList", "SequenceLength"},
{"dropout_prob",
"is_bidirec",
"hidden_size",
"num_layers",
"is_test",
"seed"},
{"Out", "LastH", "LastC", "Reserve", "StateOut"});
}
KernelSignature CudnnLSTMGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"cudnn_lstm_grad",
{"Input",
"InitH",
"InitC",
"WeightList",
"SequenceLength",
"Out",
"Reserve",
"StateOut",
"Out@GRAD",
"LastH@GRAD",
"LastC@GRAD"},
{"dropout_prob",
"is_bidirec",
"hidden_size",
"num_layers",
"is_test",
"seed"},
{"Input@GRAD", "InitH@GRAD", "InitC@GRAD", "WeightList@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm, phi::CudnnLSTMOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm_grad,
phi::CudnnLSTMGradOpArgumentMapping);