未验证 提交 ee1ed42c 编写于 作者: G GaoWei8 提交者: GitHub

change sequence length attribute to input (#27193)

* replace sequence length attr to input
上级 2d8281d5
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
namespace paddle {
namespace operators {
class ScopedRNNBase {
public:
ScopedRNNBase(int seq_length, int batch_size, int input_size, int hidden_size,
int num_layers, float dropout_prob, int seed, int weight_numel,
bool initialized, bool is_bidirec)
: seq_length_(seq_length),
batch_size_(batch_size),
input_size_(input_size),
hidden_size_(hidden_size),
num_layers_(num_layers),
dropout_prob_(dropout_prob),
seed_(seed),
weight_numel_(weight_numel),
initialized_(initialized),
is_bidirec_(is_bidirec) {}
template <typename T>
void Create(const cudnnHandle_t& handle, const platform::Place& place,
const std::vector<int>& sequence_length, size_t* workspace_size,
size_t* reserve_size, framework::Tensor* dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
cudnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
// ------------------- cudnn x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
std::vector<int> strides_x = {input_size_, 1, 1};
std::vector<int> dims_y = {batch_size_, hidden_size_ * numDirections, 1};
std::vector<int> strides_y = {hidden_size_ * numDirections, 1, 1};
for (int i = 0; i < seq_length_; ++i) {
x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x));
y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y));
}
if (!sequence_length.empty()) {
x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true,
sequence_length);
y_seq_desc_.descriptor<T>(seq_length_, batch_size_,
hidden_size_ * numDirections, true,
sequence_length);
}
// ------------------- cudnn hx, hy, cx, cy descriptors----------
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
hidden_size_};
std::vector<int> strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1};
init_h_desc_.descriptor<T>(dims_hx, strides_hx);
init_c_desc_.descriptor<T>(dims_hx, strides_hx);
last_h_desc_.descriptor<T>(dims_hx, strides_hx);
last_c_desc_.descriptor<T>(dims_hx, strides_hx);
// ------------------- cudnn dropout descriptors ---------------------
size_t state_size;
if (!initialized_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
}
dropout_desc_.descriptor(handle, place, initialized_, dropout_prob_,
dropout_state, seed_, state_size);
// ------------------- cudnn rnn descriptors ---------------------
#if CUDNN_VERSION >= 6000
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_.desc(), hidden_size_, num_layers_,
dropout_desc_.desc(), CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, cudnn_type));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(),
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
cudnn_type));
#endif
if (!sequence_length.empty()) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
}
// ------------------- cudnn weights_size ---------------------
size_t weights_size_;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type));
PADDLE_ENFORCE_EQ(
weights_size_, sizeof(T) * weight_numel_,
platform::errors::InvalidArgument(
"The cudnn lstm and setting weight size should be same."));
// ------------------- cudnn weight descriptors ---------------------
platform::DataLayout layout = platform::DataLayout::kNCHW;
int dim_tmp = weights_size_ / sizeof(T);
std::vector<int> dim_w = {dim_tmp, 1, 1};
weight_desc_.descriptor<T>(layout, dim_w);
// ------------------- cudnn workspace, reserve size ---------------------
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(),
reserve_size));
}
cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); }
cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); }
cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); }
cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); }
cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
cudnnTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); }
cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); }
cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); }
cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); }
private:
int seq_length_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
float dropout_prob_;
int seed_;
int weight_numel_;
bool initialized_;
bool is_bidirec_;
std::vector<cudnnTensorDescriptor_t> x_descs_;
std::vector<cudnnTensorDescriptor_t> y_descs_;
platform::ScopedTensorDescriptor x_desc_;
platform::ScopedTensorDescriptor y_desc_;
platform::ScopedRNNTensorDescriptor x_seq_desc_;
platform::ScopedRNNTensorDescriptor y_seq_desc_;
platform::ScopedTensorDescriptor init_h_desc_;
platform::ScopedTensorDescriptor init_c_desc_;
platform::ScopedTensorDescriptor last_h_desc_;
platform::ScopedTensorDescriptor last_c_desc_;
platform::ScopedDropoutDescriptor dropout_desc_;
platform::ScopedFilterDescriptor weight_desc_;
platform::ScopedRNNDescriptor rnn_desc_;
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,16 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
"received InitH's rank is %d.",
init_h_dims.size()));
if (ctx->HasInput("SequenceLength")) {
auto seq_dims = ctx->GetInputDim("SequenceLength");
PADDLE_ENFORCE_EQ(
in_dims[1], seq_dims[0],
platform::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1], seq_dims[0]));
}
PADDLE_ENFORCE_EQ(
in_dims[1], init_h_dims[1],
platform::errors::InvalidArgument(
......@@ -113,6 +123,12 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
AddInput("SequenceLength",
"(Tensor) When the input data is padding, "
"set this parameter. This parameter represents "
"the variable sequence lengths in a batch. "
"The size of the vector has to equal the batch_size.")
.AsDispensable();
AddOutput("Reserve",
"(Tensor, a temporary output Tensor to store the reserve_data "
"of cudnn kernel.")
......@@ -155,13 +171,6 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(1);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddAttr<std::vector<int>>("sequence_length",
"(vector<int>) When the input data is padding, "
"set this parameter. This parameter represents "
"the variable sequence"
"lengths in a batch. The size of the vector has "
"to equal the batch_size.")
.SetDefault({});
AddComment(R"DOC(
CUDNN LSTM implementation
......@@ -243,6 +252,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
if (this->HasInput("SequenceLength")) {
op->SetInput("SequenceLength", this->Input("SequenceLength"));
}
op->SetInput("Reserve", this->Output("Reserve"));
op->SetInput("StateOut", this->Output("StateOut"));
op->SetInput("Out", this->Output("Out"));
......
......@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
......@@ -24,6 +25,43 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename T>
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
const T *init_h_data, const T *init_c_data, const T *w_data,
T *out_data, T *last_h_data, T *last_c_data,
framework::Tensor *workspace_data,
const size_t &workspace_size) {
if (!has_seq_length) {
// for inference
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
workspace_data->data<uint8_t>(), workspace_size));
} else {
#if CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data, rnn->init_h_desc(),
init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(),
w_data, rnn->y_seq_desc(), out_data, rnn->last_h_desc(), last_h_data,
rnn->last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, workspace_data->data<uint8_t>(),
workspace_size));
#else
// CUDNN VERSION has to >=7.2.1
PADDLE_THROW(platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
template <typename T>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
public:
......@@ -56,7 +94,13 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int num_layers = ctx.Attr<int>("num_layers");
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");
auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");
bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
}
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
......@@ -70,58 +114,32 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
size_t workspace_size;
size_t reserve_size;
platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
state_initialized, is_bidirec);
rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
state_initialized, is_bidirec);
rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
&reserve_size, state_out);
framework::Tensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
workspace_data_.mutable_data<uint8_t>(
{static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
auto *reserve_data = reserve->mutable_data<uint8_t>(
{static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
if (is_test) {
if (sequence_length.empty()) {
// for inference
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
last_h_data, rnn.cy_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size));
} else {
#if CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRNNForwardInferenceEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr, platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
init_h_data, init_c_data, w_data, out_data, last_h_data,
last_c_data, &workspace_data_, workspace_size);
} else {
if (sequence_length.empty()) {
if (!has_seq_length) {
// for train
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
last_h_data, rnn.cy_desc(), last_c_data,
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
} else {
......@@ -130,19 +148,18 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRNNForwardTrainingEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.weight_desc(), w_data, rnn.y_seq_desc(), out_data,
rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, workspace_data_.data<uint8_t>(), workspace_size,
reserve_data, reserve_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr, platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
PADDLE_THROW(platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
......@@ -203,7 +220,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
int seed = ctx.Attr<int>("seed");
auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");
bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
}
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
......@@ -213,33 +236,33 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
size_t workspace_size;
size_t reserve_size;
platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
true, is_bidirec);
ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel, true,
is_bidirec);
rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
&reserve_size, const_cast<Tensor *>(state_out));
framework::Tensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
workspace_data_.mutable_data<uint8_t>(
{static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
const uint8_t *reserve_data = reserve->data<uint8_t>();
if (sequence_length.empty()) {
if (!has_seq_length) {
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, rnn.rnn_desc(), seq_length, rnn.y_desc(), out_data,
rnn.y_desc(), out_grad_data, rnn.hy_desc(), last_h_grad_data,
rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, rnn.x_desc(),
in_grad_data, rnn.hx_desc(), init_h_grad_data, rnn.cx_desc(),
init_c_grad_data, workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), input->data<T>(),
rnn.hx_desc(), init_h->data<T>(), rnn.y_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
} else {
......@@ -248,27 +271,25 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
out_grad_data, nullptr, nullptr, rnn.hy_desc(), last_h_grad_data,
rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.x_seq_desc(), in_grad_data, rnn.hx_desc(), init_h_grad_data,
rnn.cx_desc(), init_c_grad_data, nullptr, nullptr,
out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data,
rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.hx_desc(), init_h->data<T>(), rnn.y_seq_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
rnn.weight_desc(), weight_grad->data<T>(),
const_cast<uint8_t *>(reserve_data), reserve_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr,
platform::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
PADDLE_THROW(platform::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
#endif
}
}
......
......@@ -287,6 +287,8 @@ class ScopedTensorDescriptor {
return descriptor(CudnnDataType<T>::type, dim, stride);
}
inline cudnnTensorDescriptor_t desc() { return desc_; }
private:
cudnnTensorDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
......@@ -329,6 +331,8 @@ class ScopedRNNTensorDescriptor {
input_size, time_major, seq_length);
}
inline cudnnRNNDataDescriptor_t desc() { return desc_; }
private:
cudnnRNNDataDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedRNNTensorDescriptor);
......@@ -361,6 +365,7 @@ class ScopedDropoutDescriptor {
}
return desc_;
}
inline cudnnDropoutDescriptor_t desc() { return desc_; }
private:
cudnnDropoutDescriptor_t desc_;
......@@ -376,7 +381,7 @@ class ScopedRNNDescriptor {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyRNNDescriptor(desc_));
}
inline cudnnRNNDescriptor_t descriptor() { return desc_; }
inline cudnnRNNDescriptor_t desc() { return desc_; }
private:
cudnnRNNDescriptor_t desc_;
......@@ -419,172 +424,13 @@ class ScopedFilterDescriptor {
kernel, groups);
}
inline cudnnFilterDescriptor_t desc() { return desc_; }
private:
cudnnFilterDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
class ScopedRNNBase {
public:
ScopedRNNBase(int seq_length, int batch_size, int input_size, int hidden_size,
int num_layers, float dropout_prob, int seed, int weight_numel,
bool initialized, bool is_bidirec)
: seq_length_(seq_length),
batch_size_(batch_size),
input_size_(input_size),
hidden_size_(hidden_size),
num_layers_(num_layers),
dropout_prob_(dropout_prob),
seed_(seed),
weight_numel_(weight_numel),
initialized_(initialized),
is_bidirec_(is_bidirec) {}
template <typename T>
void Create(const cudnnHandle_t& handle, const platform::Place& place,
std::vector<int> sequence_length, size_t* workspace_size,
size_t* reserve_size, framework::Tensor* dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
cudnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
// ------------------- cudnn x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
std::vector<int> strides_x = {input_size_, 1, 1};
std::vector<int> dims_y = {batch_size_, hidden_size_ * numDirections, 1};
std::vector<int> strides_y = {hidden_size_ * numDirections, 1, 1};
for (int i = 0; i < seq_length_; ++i) {
x_desc_.emplace_back(x_d.descriptor<T>(dims_x, strides_x));
y_desc_.emplace_back(y_d.descriptor<T>(dims_y, strides_y));
}
if (!sequence_length.empty()) {
x_seq_desc_ = x_seq_d.descriptor<T>(seq_length_, batch_size_, input_size_,
true, sequence_length);
y_seq_desc_ = y_seq_d.descriptor<T>(seq_length_, batch_size_,
hidden_size_ * numDirections, true,
sequence_length);
}
// ------------------- cudnn hx, hy, cx, cy descriptors----------
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
hidden_size_};
std::vector<int> strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1};
hx_desc_ = hx_d.descriptor<T>(dims_hx, strides_hx);
cx_desc_ = cx_d.descriptor<T>(dims_hx, strides_hx);
hy_desc_ = hy_d.descriptor<T>(dims_hx, strides_hx);
cy_desc_ = cy_d.descriptor<T>(dims_hx, strides_hx);
// ------------------- cudnn dropout descriptors ---------------------
size_t state_size;
if (!initialized_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
}
dropout_desc_ =
dropout_d.descriptor(handle, place, initialized_, dropout_prob_,
dropout_state, seed_, state_size);
// ------------------- cudnn rnn descriptors ---------------------
rnn_desc_ = rnn_d.descriptor();
#if CUDNN_VERSION >= 6000
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, cudnn_type));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
cudnn_type));
#endif
if (!sequence_length.empty()) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));
}
// ------------------- cudnn weights_size ---------------------
size_t weights_size_;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type));
PADDLE_ENFORCE_EQ(
weights_size_, sizeof(T) * weight_numel_,
platform::errors::InvalidArgument(
"The cudnn lstm and setting weight size should be same."));
// ------------------- cudnn weight descriptors ---------------------
platform::DataLayout layout = platform::DataLayout::kNCHW;
int dim_tmp = weights_size_ / sizeof(T);
std::vector<int> dim_w = {dim_tmp, 1, 1};
w_desc_ = w_d.descriptor<T>(layout, dim_w);
// ------------------- cudnn workspace, reserve size ---------------------
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_, seq_length_, x_desc_.data(), workspace_size));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_, seq_length_, x_desc_.data(), reserve_size));
}
cudnnTensorDescriptor_t* x_desc() { return x_desc_.data(); }
cudnnTensorDescriptor_t* y_desc() { return y_desc_.data(); }
cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_; }
cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_; }
cudnnTensorDescriptor_t hx_desc() { return hx_desc_; }
cudnnTensorDescriptor_t cx_desc() { return cx_desc_; }
cudnnTensorDescriptor_t hy_desc() { return hy_desc_; }
cudnnTensorDescriptor_t cy_desc() { return cy_desc_; }
cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_; }
cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_; }
cudnnFilterDescriptor_t w_desc() { return w_desc_; }
private:
int seq_length_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
float dropout_prob_;
int seed_;
int weight_numel_;
bool initialized_;
bool is_bidirec_;
std::vector<cudnnTensorDescriptor_t> x_desc_;
std::vector<cudnnTensorDescriptor_t> y_desc_;
cudnnRNNDataDescriptor_t x_seq_desc_;
cudnnRNNDataDescriptor_t y_seq_desc_;
// A tensor descriptor describing the initial hidden state of the RNN.
cudnnTensorDescriptor_t hx_desc_;
// A tensor descriptor describing the initial cell state for LSTM networks.
cudnnTensorDescriptor_t cx_desc_;
// A tensor descriptor describing the final hidden state of the RNN.
cudnnTensorDescriptor_t hy_desc_;
// A tensor descriptor describing the final cell state for LSTM networks.
cudnnTensorDescriptor_t cy_desc_;
cudnnDropoutDescriptor_t dropout_desc_;
cudnnFilterDescriptor_t w_desc_;
cudnnRNNDescriptor_t rnn_desc_;
ScopedTensorDescriptor x_d;
ScopedTensorDescriptor y_d;
ScopedRNNTensorDescriptor x_seq_d;
ScopedRNNTensorDescriptor y_seq_d;
ScopedTensorDescriptor hx_d;
ScopedTensorDescriptor cx_d;
ScopedTensorDescriptor hy_d;
ScopedTensorDescriptor cy_d;
ScopedDropoutDescriptor dropout_d;
ScopedFilterDescriptor w_d;
ScopedRNNDescriptor rnn_d;
};
class ScopedConvolutionDescriptor {
public:
ScopedConvolutionDescriptor() {
......
......@@ -400,7 +400,8 @@ class TestCUDNNLstmOp(OpTest):
'Input': input,
'W': flat_w,
'InitH': init_h,
'InitC': init_c
'InitC': init_c,
'SequenceLength': self.sequence_length
}
self.attrs = {
'dropout_prob': 0.0,
......@@ -408,7 +409,6 @@ class TestCUDNNLstmOp(OpTest):
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': 1,
'sequence_length': self.sequence_length.tolist()
}
self.outputs = {
'Out': output,
......@@ -436,13 +436,6 @@ class TestCUDNNLstmOp(OpTest):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp2(TestCUDNNLstmOp):
def set_attrs(self):
self.sequence_length = np.array([], dtype=np.int32)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp3(TestCUDNNLstmOp):
def set_attrs(self):
self.num_layers = 2
......
......@@ -26,4 +26,5 @@ NEED_TO_FIX_OP_LIST = [
'squared_l2_distance',
'tree_conv',
'cvm',
'cudnn_lstm',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册