未验证 提交 4ff16eb2 编写于 作者: G GaoWei8 提交者: GitHub

Add padding cudnn interface (#26370)

* add lstm cudnn of padding data and refine cudnn codes
上级 35f53ecd
......@@ -37,41 +37,42 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
auto in_dims = ctx->GetInputDim("Input");
auto init_dims = ctx->GetInputDim("InitH");
auto init_h_dims = ctx->GetInputDim("InitH");
auto init_c_dims = ctx->GetInputDim("InitC");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_dims.size(), 3,
PADDLE_ENFORCE_EQ(init_h_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_dims.size()));
init_h_dims.size()));
PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_dims[1] is %d.",
in_dims[1], init_dims[1]));
PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2],
PADDLE_ENFORCE_EQ(
in_dims[1], init_h_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_h_dims[1] is %d.",
in_dims[1], init_h_dims[1]));
PADDLE_ENFORCE_EQ(init_c_dims, init_h_dims,
platform::errors::InvalidArgument(
"The in_dims[2] (Input dims) and init_dims[2] (InitH "
"dims) should be equal. But "
"received in_dims[2] is %d and init_dims[2] is %d.",
in_dims[2], init_dims[2]));
"The InitC dims and InitH "
"dims should be equal. But "
"received init_c_dims is %d and init_h_dims is %d.",
init_c_dims, init_h_dims));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
auto last_dims = init_dims;
last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0];
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("LastH", last_dims);
ctx->SetOutputDim("LastC", last_dims);
ctx->SetOutputDim("LastH", init_c_dims);
ctx->SetOutputDim("LastC", init_h_dims);
}
protected:
......@@ -95,7 +96,7 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"different batch)"
"batch_size is the instance number of this batch"
"input_size is the hidden size of the input."
"input_hidden_size and the hidden_size in the next may not be same");
"input_size and the hidden_size in the next may not be same");
AddInput("InitH",
"(Tensor) the initial hidden state of the LSTM"
"input. This is a tensor with shape (num_layers x batch_size x "
......@@ -154,6 +155,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(1);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddAttr<std::vector<int>>("sequence_length",
"(vector<int>) When the input data is padding, "
"set this parameter. This parameter represents "
"the variable sequence"
"lengths in a batch. The size of the vector has "
"to equal the batch_size.")
.SetDefault({});
AddComment(R"DOC(
CUDNN LSTM implementation
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
......@@ -55,50 +56,96 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int num_layers = ctx.Attr<int>("num_layers");
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");
auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache();
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
int weight_numel = w->numel();
bool state_initialized = state_out->IsInitialized() ? true : false;
auto input_w_numel = w->numel();
auto seq_len = x->dims()[0];
auto batch_size = x->dims()[1];
auto input_dim = x->dims()[2];
size_t workspace_size;
size_t reserve_size;
bool state_initialized = state_out->IsInitialized() ? true : false;
cudnnDataType_t cudnn_type = platform::ToCudnnDataType(
framework::ToDataType(std::type_index(typeid(T))));
cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size,
input_dim, hidden_size, num_layers, dropout_prob,
is_bidirec, seed, input_w_numel, &reserve_size,
state_out, state_initialized, cudnn_type);
platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
state_initialized, is_bidirec);
rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
&reserve_size, state_out);
framework::Tensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
auto *reserve_data = reserve->mutable_data<uint8_t>(
{static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
if (is_test) {
// for inference
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
x_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_,
w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_,
last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data,
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_));
if (sequence_length.empty()) {
// for inference
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
last_h_data, rnn.cy_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size));
} else {
#if CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRNNForwardInferenceEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr, platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
} else {
// for train
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
x_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_,
w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_,
last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data,
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_, reserve_data, reserve_size));
if (sequence_length.empty()) {
// for train
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
last_h_data, rnn.cy_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
} else {
#if CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRNNForwardTrainingEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr, platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
delete cudnn_rnn_cache;
}
};
......@@ -156,44 +203,74 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
int seed = ctx.Attr<int>("seed");
auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");
CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache();
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
int input_size = input->dims()[2];
int weight_numel = weight->numel();
auto input_w_numel = weight->numel();
auto seq_len = input_dims[0];
auto batch_size = input->dims()[1];
auto input_dim = input->dims()[2];
size_t workspace_size;
size_t reserve_size;
cudnnDataType_t cudnn_type = platform::ToCudnnDataType(
framework::ToDataType(std::type_index(typeid(T))));
cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size,
input_dim, hidden_size, num_layers, dropout_prob,
is_bidirec, seed, input_w_numel, &reserve_size,
const_cast<Tensor *>(state_out), true, cudnn_type);
auto work_data = cudnn_rnn_cache->workspace_data_.data<uint8_t>();
platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
true, is_bidirec);
rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
&reserve_size, const_cast<Tensor *>(state_out));
framework::Tensor workspace_data_;
workspace_data_.Resize({static_cast<int64_t>(workspace_size)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
const uint8_t *reserve_data = reserve->data<uint8_t>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->y_desc_,
out_data, cudnn_rnn_cache->y_desc_, out_grad_data,
cudnn_rnn_cache->hy_desc_, last_h_grad_data, cudnn_rnn_cache->cy_desc_,
last_c_grad_data, cudnn_rnn_cache->w_desc_, weight_data,
cudnn_rnn_cache->hx_desc_, init_h_data, cudnn_rnn_cache->cx_desc_,
init_c_data, cudnn_rnn_cache->x_desc_, in_grad_data,
cudnn_rnn_cache->hx_desc_, init_h_grad_data, cudnn_rnn_cache->cx_desc_,
init_c_grad_data, work_data, cudnn_rnn_cache->workspace_size_,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
input->data<T>(), cudnn_rnn_cache->hx_desc_, init_h->data<T>(),
cudnn_rnn_cache->y_desc_, out->data<T>(),
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->w_desc_,
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
delete cudnn_rnn_cache;
if (sequence_length.empty()) {
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, rnn.rnn_desc(), seq_length, rnn.y_desc(), out_data,
rnn.y_desc(), out_grad_data, rnn.hy_desc(), last_h_grad_data,
rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, rnn.x_desc(),
in_grad_data, rnn.hx_desc(), init_h_grad_data, rnn.cx_desc(),
init_c_grad_data, workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), input->data<T>(),
rnn.hx_desc(), init_h->data<T>(), rnn.y_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
} else {
#if CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
out_grad_data, nullptr, nullptr, rnn.hy_desc(), last_h_grad_data,
rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
rnn.x_seq_desc(), in_grad_data, rnn.hx_desc(), init_h_grad_data,
rnn.cx_desc(), init_c_grad_data, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.hx_desc(), init_h->data<T>(), rnn.y_seq_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.w_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr,
platform::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
#endif
}
}
};
......
......@@ -273,11 +273,116 @@ class ScopedTensorDescriptor {
groups);
}
// Configures the owned tensor descriptor as an N-d tensor with element type
// `cudnn_type`, per-dimension extents `dim` and strides `stride`, and returns
// the descriptor handle. The rank passed to cuDNN is `dim.size()`.
// NOTE(review): `stride` is presumably expected to have as many entries as
// `dim`; nothing here checks that - confirm against callers.
inline cudnnTensorDescriptor_t descriptor(const cudnnDataType_t cudnn_type,
                                          const std::vector<int>& dim,
                                          const std::vector<int>& stride) {
  const auto rank = dim.size();
  const int* extents = dim.data();
  const int* strides = stride.data();
  PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
      desc_, cudnn_type, rank, extents, strides));
  return desc_;
}
// Typed convenience overload: looks up the cuDNN element type that
// CudnnDataType<T> maps T to and forwards to the overload taking an explicit
// cudnnDataType_t.
template <typename T>
inline cudnnTensorDescriptor_t descriptor(const std::vector<int>& dim,
                                          const std::vector<int>& stride) {
  const cudnnDataType_t element_type = CudnnDataType<T>::type;
  return descriptor(element_type, dim, stride);
}
private:
cudnnTensorDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
};
// Owns a cudnnRNNDataDescriptor_t: the descriptor is created in the
// constructor and destroyed in the destructor. `descriptor(...)` (re)configures
// it for an unpacked (padded) RNN input/output and returns the handle.
class ScopedRNNTensorDescriptor {
 public:
  ScopedRNNTensorDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateRNNDataDescriptor(&desc_));
  }

  ~ScopedRNNTensorDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyRNNDataDescriptor(desc_));
  }

  // Sets up the descriptor.
  //   cudnn_type      element type of the RNN data.
  //   max_seq_length  padded (maximum) sequence length.
  //   batch_size      number of sequences in the batch.
  //   input_size      size of the innermost (feature) dimension.
  //   time_major      true selects the SEQ_MAJOR unpacked layout, false the
  //                   BATCH_MAJOR unpacked layout.
  //   seq_length      per-sequence lengths; must hold exactly `batch_size`
  //                   entries because cuDNN reads that many from the array.
  // Raises InvalidArgument when seq_length.size() != batch_size.
  inline cudnnRNNDataDescriptor_t descriptor(
      const cudnnDataType_t cudnn_type, int max_seq_length, int batch_size,
      int input_size, bool time_major, const std::vector<int>& seq_length) {
    // cuDNN documents paddingFill as being interpreted with the tensor's data
    // type, so for a double tensor it reads 8 bytes. A 4-byte float would be
    // over-read; a zero-valued double is large enough for every supported
    // type, and an all-zero bit pattern is 0 for half, float and double.
    static double padding_fill = 0.0;

    // The raw pointer below is consumed together with `batch_size`; reject a
    // vector of a different size instead of letting cuDNN read past its end.
    PADDLE_ENFORCE_EQ(
        seq_length.size(), static_cast<size_t>(batch_size),
        platform::errors::InvalidArgument(
            "The size of sequence_length should be equal to batch_size. But "
            "received the size of sequence_length is %d and batch_size is %d.",
            seq_length.size(), batch_size));

    const cudnnRNNDataLayout_t layout =
        time_major ? CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED
                   : CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;

    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetRNNDataDescriptor(
        desc_, cudnn_type, layout, max_seq_length, batch_size, input_size,
        seq_length.data(), static_cast<void*>(&padding_fill)));
    return desc_;
  }

  // Typed convenience overload: derives the cuDNN element type from T.
  template <typename T>
  inline cudnnRNNDataDescriptor_t descriptor(
      int max_length, int batch_size, int input_size, bool time_major,
      const std::vector<int>& seq_length) {
    return descriptor(CudnnDataType<T>::type, max_length, batch_size,
                      input_size, time_major, seq_length);
  }

 private:
  cudnnRNNDataDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedRNNTensorDescriptor);
};
// Owns a cudnnDropoutDescriptor_t: created in the constructor, destroyed in
// the destructor. `descriptor(...)` either initialises the dropout state or
// restores the descriptor from an already initialised state buffer.
class ScopedDropoutDescriptor {
 public:
  ScopedDropoutDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateDropoutDescriptor(&desc_));
  }

  ~ScopedDropoutDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyDropoutDescriptor(desc_));
  }

  // Configures and returns the dropout descriptor.
  //   initialized == false: cudnnSetDropoutDescriptor is called with the
  //     caller-provided `state_size` and `seed`.
  //   initialized == true: the incoming `state_size` is ignored; the size is
  //     taken from dims()[0] of `dropout_state_` and
  //     cudnnRestoreDropoutDescriptor is called with a seed argument of 0.
  // `place` is not used by this method.
  inline cudnnDropoutDescriptor_t descriptor(const cudnnHandle_t& handle,
                                             const platform::Place& place,
                                             bool initialized,
                                             float dropout_prob_,
                                             framework::Tensor* dropout_state_,
                                             int seed, size_t state_size) {
    // Fetched up front in both modes, exactly as before.
    auto* state_buffer = dropout_state_->data<uint8_t>();

    if (initialized) {
      const size_t restored_size = dropout_state_->dims()[0];
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnRestoreDropoutDescriptor(
          desc_, handle, dropout_prob_, state_buffer, restored_size, 0));
      return desc_;
    }

    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetDropoutDescriptor(
        desc_, handle, dropout_prob_, state_buffer, state_size, seed));
    return desc_;
  }

 private:
  cudnnDropoutDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
};
// Owns a cudnnRNNDescriptor_t for the lifetime of the object. The descriptor
// is only created and destroyed here; configuring it is left to the caller,
// which obtains the raw handle through descriptor().
class ScopedRNNDescriptor {
 public:
  ScopedRNNDescriptor() {
    const auto create_status = dynload::cudnnCreateRNNDescriptor(&desc_);
    PADDLE_ENFORCE_CUDA_SUCCESS(create_status);
  }

  ~ScopedRNNDescriptor() PADDLE_MAY_THROW {
    const auto destroy_status = dynload::cudnnDestroyRNNDescriptor(desc_);
    PADDLE_ENFORCE_CUDA_SUCCESS(destroy_status);
  }

  // Returns the raw descriptor handle (ownership stays with this object).
  inline cudnnRNNDescriptor_t descriptor() { return desc_; }

 private:
  cudnnRNNDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedRNNDescriptor);
};
class ScopedFilterDescriptor {
public:
ScopedFilterDescriptor() {
......@@ -319,6 +424,167 @@ class ScopedFilterDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
// Bundles every cuDNN descriptor used by the LSTM kernels (per-step x/y
// tensor descriptors, padded x/y RNN data descriptors, hx/cx/hy/cy state
// descriptors, dropout, weight filter and RNN descriptors). The constructor
// only records the problem sizes; Create<T>() configures the descriptors and
// reports the workspace and reserve sizes.
class ScopedRNNBase {
 public:
  ScopedRNNBase(int seq_length, int batch_size, int input_size, int hidden_size,
                int num_layers, float dropout_prob, int seed, int weight_numel,
                bool initialized, bool is_bidirec)
      : seq_length_(seq_length),
        batch_size_(batch_size),
        input_size_(input_size),
        hidden_size_(hidden_size),
        num_layers_(num_layers),
        dropout_prob_(dropout_prob),
        seed_(seed),
        weight_numel_(weight_numel),
        initialized_(initialized),
        is_bidirec_(is_bidirec) {}

  // Configures all descriptors for element type T.
  //   handle           cuDNN handle used for the queries and setters.
  //   place            where the dropout state is allocated when it has not
  //                    been initialised yet.
  //   sequence_length  per-sequence lengths; when non-empty the padded
  //                    x/y RNN data descriptors are set up and padded IO is
  //                    enabled on the RNN descriptor.
  //   workspace_size   out: size reported by cudnnGetRNNWorkspaceSize.
  //   reserve_size     out: size reported by cudnnGetRNNTrainingReserveSize.
  //   dropout_state    dropout state tensor; allocated here when the object
  //                    was constructed with initialized == false.
  // Raises InvalidArgument when the weight size computed by cuDNN differs
  // from sizeof(T) * weight_numel.
  template <typename T>
  void Create(const cudnnHandle_t& handle, const platform::Place& place,
              const std::vector<int>& sequence_length, size_t* workspace_size,
              size_t* reserve_size, framework::Tensor* dropout_state) {
    const int num_directions = is_bidirec_ ? 2 : 1;
    cudnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;

    // ------------------- cudnn x, y descriptors ---------------------
    std::vector<int> dims_x = {batch_size_, input_size_, 1};
    std::vector<int> strides_x = {input_size_, 1, 1};

    std::vector<int> dims_y = {batch_size_, hidden_size_ * num_directions, 1};
    std::vector<int> strides_y = {hidden_size_ * num_directions, 1, 1};

    // One entry per time step; every entry is the same descriptor handle.
    for (int i = 0; i < seq_length_; ++i) {
      x_desc_.emplace_back(x_d.descriptor<T>(dims_x, strides_x));
      y_desc_.emplace_back(y_d.descriptor<T>(dims_y, strides_y));
    }

    if (!sequence_length.empty()) {
      x_seq_desc_ = x_seq_d.descriptor<T>(seq_length_, batch_size_, input_size_,
                                          true, sequence_length);
      y_seq_desc_ = y_seq_d.descriptor<T>(seq_length_, batch_size_,
                                          hidden_size_ * num_directions, true,
                                          sequence_length);
    }

    // ------------------- cudnn hx, hy, cx, cy descriptors----------
    std::vector<int> dims_hx = {num_layers_ * num_directions, batch_size_,
                                hidden_size_};
    std::vector<int> strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1};

    hx_desc_ = hx_d.descriptor<T>(dims_hx, strides_hx);
    cx_desc_ = cx_d.descriptor<T>(dims_hx, strides_hx);
    hy_desc_ = hy_d.descriptor<T>(dims_hx, strides_hx);
    cy_desc_ = cy_d.descriptor<T>(dims_hx, strides_hx);

    // ------------------- cudnn dropout descriptors ---------------------
    // Zero-initialised: when the state is already initialised the query below
    // is skipped, yet the value is still passed (by value) to
    // dropout_d.descriptor(), which must not read an indeterminate variable.
    size_t state_size = 0;
    if (!initialized_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cudnnDropoutGetStatesSize(handle, &state_size));
      dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
                                           place);
    }
    dropout_desc_ =
        dropout_d.descriptor(handle, place, initialized_, dropout_prob_,
                             dropout_state, seed_, state_size);

    // ------------------- cudnn rnn descriptors ---------------------
    rnn_desc_ = rnn_d.descriptor();

#if CUDNN_VERSION >= 6000
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
        handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
        CUDNN_LINEAR_INPUT,
        is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
        CUDNN_RNN_ALGO_STANDARD, cudnn_type));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor(
        rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
        is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
        cudnn_type));
#endif
    if (!sequence_length.empty()) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
          rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));
    }

    // ------------------- cudnn weights_size ---------------------
    size_t weights_size = 0;
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
        handle, rnn_desc_, x_desc_[0], &weights_size, cudnn_type));

    PADDLE_ENFORCE_EQ(
        weights_size, sizeof(T) * weight_numel_,
        platform::errors::InvalidArgument(
            "The cudnn lstm and setting weight size should be same."));

    // ------------------- cudnn weight descriptors ---------------------
    platform::DataLayout layout = platform::DataLayout::kNCHW;
    const int dim_tmp = static_cast<int>(weights_size / sizeof(T));
    std::vector<int> dim_w = {dim_tmp, 1, 1};
    w_desc_ = w_d.descriptor<T>(layout, dim_w);

    // ------------------- cudnn workspace, reserve size ---------------------
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize(
        handle, rnn_desc_, seq_length_, x_desc_.data(), workspace_size));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnGetRNNTrainingReserveSize(
            handle, rnn_desc_, seq_length_, x_desc_.data(), reserve_size));
  }

  cudnnTensorDescriptor_t* x_desc() { return x_desc_.data(); }
  cudnnTensorDescriptor_t* y_desc() { return y_desc_.data(); }
  // The two padded-data descriptors are only configured when Create() was
  // given a non-empty sequence_length; otherwise these return nullptr.
  cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_; }
  cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_; }
  cudnnTensorDescriptor_t hx_desc() { return hx_desc_; }
  cudnnTensorDescriptor_t cx_desc() { return cx_desc_; }
  cudnnTensorDescriptor_t hy_desc() { return hy_desc_; }
  cudnnTensorDescriptor_t cy_desc() { return cy_desc_; }
  cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_; }
  cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_; }
  cudnnFilterDescriptor_t w_desc() { return w_desc_; }

 private:
  int seq_length_;
  int batch_size_;
  int input_size_;
  int hidden_size_;
  int num_layers_;
  float dropout_prob_;
  int seed_;
  int weight_numel_;
  bool initialized_;
  bool is_bidirec_;

  std::vector<cudnnTensorDescriptor_t> x_desc_;
  std::vector<cudnnTensorDescriptor_t> y_desc_;
  // Set only for padded input; default to nullptr so the accessors never hand
  // out an indeterminate handle.
  cudnnRNNDataDescriptor_t x_seq_desc_ = nullptr;
  cudnnRNNDataDescriptor_t y_seq_desc_ = nullptr;
  // A tensor descriptor describing the initial hidden state of the RNN.
  cudnnTensorDescriptor_t hx_desc_;
  // A tensor descriptor describing the initial cell state for LSTM networks.
  cudnnTensorDescriptor_t cx_desc_;
  // A tensor descriptor describing the final hidden state of the RNN.
  cudnnTensorDescriptor_t hy_desc_;
  // A tensor descriptor describing the final cell state for LSTM networks.
  cudnnTensorDescriptor_t cy_desc_;
  cudnnDropoutDescriptor_t dropout_desc_;
  cudnnFilterDescriptor_t w_desc_;
  cudnnRNNDescriptor_t rnn_desc_;

  // Owners of the handles above; declaration order is kept unchanged.
  ScopedTensorDescriptor x_d;
  ScopedTensorDescriptor y_d;
  ScopedRNNTensorDescriptor x_seq_d;
  ScopedRNNTensorDescriptor y_seq_d;
  ScopedTensorDescriptor hx_d;
  ScopedTensorDescriptor cx_d;
  ScopedTensorDescriptor hy_d;
  ScopedTensorDescriptor cy_d;
  ScopedDropoutDescriptor dropout_d;
  ScopedFilterDescriptor w_d;
  ScopedRNNDescriptor rnn_d;
};
class ScopedConvolutionDescriptor {
public:
ScopedConvolutionDescriptor() {
......
......@@ -101,6 +101,9 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDataDescriptor); \
__macro(cudnnDestroyRNNDataDescriptor); \
__macro(cudnnSetRNNDataDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
......@@ -109,6 +112,11 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx);
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import math
import paddle.fluid.core as core
from op_test import OpTest
......@@ -27,120 +28,372 @@ SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def lstm_naive(input, w):
    """NumPy reference for a single-layer, unidirectional LSTM.

    ``input`` is unpacked as ``(seq_len, batch_size, hidden_size)``, so the
    feature size of the input is taken to equal the hidden size. ``w`` is a
    flat array read in this order: eight ``hidden_size x hidden_size``
    matrices (input weights for the i, f, c, o gates, then recurrent weights
    for the i, f, c, o gates), followed by eight bias vectors of length
    ``hidden_size`` (first bias set i, f, c, o, then second bias set).

    Both states start at zero. Returns ``(output, last_h, last_c)`` where
    ``output`` has shape ``(seq_len, batch_size, hidden_size)`` and the two
    states have shape ``(1, batch_size, hidden_size)``.
    """
    seq_len, batch_size, hidden_size = input.shape
    mat_numel = hidden_size * hidden_size

    def _matrix(index):
        # index-th square matrix of the flat buffer, transposed after reshape.
        begin = index * mat_numel
        block = w[begin:begin + mat_numel]
        return block.reshape((hidden_size, hidden_size)).transpose()

    wi, wf, wc, wo, ri, rf, rc, ro = [_matrix(k) for k in range(8)]

    bias_begin = 8 * mat_numel

    def _bias(index):
        # index-th bias vector; biases follow the eight matrices.
        begin = bias_begin + index * hidden_size
        return w[begin:begin + hidden_size]

    bi_1, bf_1, bc_1, bo_1, bi_2, bf_2, bc_2, bo_2 = [
        _bias(k) for k in range(8)
    ]

    def sigmoid(x):
        # Clamp the pre-activation into the threshold range before exp().
        clipped = np.copy(x)
        clipped[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
        clipped[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
        return 1. / (1. + np.exp(-clipped))

    def tanh(x):
        # tanh(x) = 2 / (1 + exp(-2x)) - 1, with the exponent capped above.
        exponent = -2. * x
        exponent[exponent > EXP_MAX_INPUT] = EXP_MAX_INPUT
        return (2. / (1. + np.exp(exponent))) - 1.

    def _pre_activation(x_t, h_prev, w_in, w_rec, b_1, b_2):
        return np.matmul(x_t, w_in) + np.matmul(h_prev, w_rec) + b_1 + b_2

    pre_h = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)
    pre_c = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)

    output = []
    for step in range(seq_len):
        x_t = input[step]
        input_gate = sigmoid(_pre_activation(x_t, pre_h, wi, ri, bi_1, bi_2))
        forget_gate = sigmoid(_pre_activation(x_t, pre_h, wf, rf, bf_1, bf_2))
        output_gate = sigmoid(_pre_activation(x_t, pre_h, wo, ro, bo_1, bo_2))
        candidate = tanh(_pre_activation(x_t, pre_h, wc, rc, bc_1, bc_2))

        pre_c = input_gate * candidate + forget_gate * pre_c
        pre_h = output_gate * tanh(pre_c)
        output.append(pre_h)

    # Each step contributes (1, batch, hidden); join on the last axis and
    # rearrange to (seq_len, batch, hidden).
    output = np.concatenate(output, -1)
    output = output.reshape((batch_size, -1, hidden_size))
    output = output.transpose((1, 0, 2))

    return output, pre_h, pre_c
class LayerMixin(object):
    """Makes instances callable: calling the object runs its ``forward``."""

    def __call__(self, *args, **kwargs):
        # All positional and keyword arguments are passed through untouched.
        forward_fn = self.forward
        return forward_fn(*args, **kwargs)
class LayerListMixin(LayerMixin):
    """A callable layer that also holds an ordered list of sub-layers."""

    def __init__(self, layers=None):
        # A falsy argument (None, empty sequence) yields an empty list;
        # otherwise the given layers are copied into a new list.
        if layers:
            self._layers = list(layers)
        else:
            self._layers = []

    def append(self, layer):
        """Add ``layer`` at the end of the list."""
        self._layers.append(layer)

    def __iter__(self):
        """Iterate over the held layers in insertion order."""
        return iter(self._layers)
class LSTMCell(LayerMixin):
    """NumPy LSTM cell whose weights and biases are all initialised to one.

    Parameters are float64. ``weight_ih`` has shape
    ``(4 * hidden_size, input_size)`` and ``weight_hh`` has shape
    ``(4 * hidden_size, hidden_size)``; the four row chunks are consumed as
    the input, forget, cell-candidate and output gates, in that order.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.dtype = np.float64
        self.parameters = dict()

        # Computed but not used below; kept so construction behaves the same.
        std = 1.0 / math.sqrt(hidden_size)

        gate_rows = 4 * hidden_size
        self.weight_ih = np.ones((gate_rows, input_size), dtype=self.dtype)
        self.weight_hh = np.ones((gate_rows, hidden_size)).astype(self.dtype)
        self.parameters['weight_ih'] = self.weight_ih
        self.parameters['weight_hh'] = self.weight_hh

        if not bias:
            self.bias_ih = None
            self.bias_hh = None
        else:
            self.bias_ih = np.ones((gate_rows)).astype(self.dtype)
            self.bias_hh = np.ones((gate_rows)).astype(self.dtype)
            self.parameters['bias_ih'] = self.bias_ih
            self.parameters['bias_hh'] = self.bias_hh

    def init_state(self, inputs):
        """Return zero ``(h, c)`` states of shape ``(batch, hidden_size)``.

        The batch size is read from ``inputs.shape[0]`` and the dtype from
        ``inputs``.
        """
        state_shape = (inputs.shape[0], self.hidden_size)
        init_h = np.zeros(state_shape, dtype=inputs.dtype)
        init_c = np.zeros(state_shape, dtype=inputs.dtype)
        return init_h, init_c

    def forward(self, inputs, hx=None):
        """Run one step and return ``(h, (h, c))``.

        ``hx`` is the previous ``(hidden, cell)`` pair; zero states are used
        when it is None.
        """
        if hx is None:
            hx = self.init_state(inputs)
        pre_hidden, pre_cell = hx

        def _sigmoid(z):
            return 1.0 / (1.0 + np.exp(-z))

        gates = np.matmul(inputs, self.weight_ih.T)
        if self.bias_ih is not None:
            gates = gates + self.bias_ih
        # In-place accumulation, as in the original formulation.
        gates += np.matmul(pre_hidden, self.weight_hh.T)
        if self.bias_hh is not None:
            gates = gates + self.bias_hh

        in_part, forget_part, cand_part, out_part = np.split(gates, 4, -1)

        i = _sigmoid(in_part)
        f = _sigmoid(forget_part)
        o = _sigmoid(out_part)
        c = f * pre_cell + i * np.tanh(cand_part)
        h = o * np.tanh(c)

        return h, (h, c)
def sequence_mask(lengths, max_len=None):
    """Build a boolean mask marking the valid positions of each sequence.

    The result has one trailing axis of size ``max_len`` appended to the shape
    of ``lengths``; entry ``[..., t]`` is True when ``t < lengths[...]``.
    ``max_len`` defaults to the largest length; when given it is asserted to
    be at least that large.
    """
    if max_len is not None:
        assert max_len >= np.max(lengths)
        width = max_len
    else:
        width = np.max(lengths)

    positions = np.arange(width)
    limits = np.expand_dims(lengths, -1)
    return positions < limits
def update_state(mask, new, old):
    """Select between a new and an old state element-wise.

    Where ``mask`` is True the value from ``new`` is taken, otherwise the
    value from ``old``. When ``old`` is a tuple or list, the selection is done
    pair-wise over ``new`` and ``old`` and a tuple is returned.
    """
    if isinstance(old, (tuple, list)):
        return tuple(
            np.where(mask, fresh, kept) for fresh, kept in zip(new, old))
    return np.where(mask, new, old)
def rnn(cell,
        inputs,
        initial_states,
        sequence_length=None,
        time_major=False,
        is_reverse=False):
    """Unroll ``cell`` over the time axis of ``inputs``.

    ``inputs`` is transposed to time-major with the permutation ``[1, 0, 2]``
    unless ``time_major`` is set, and flipped along time when ``is_reverse``
    is set. When ``sequence_length`` is given, steps outside a sequence's
    length produce zero output and leave that sequence's state unchanged.

    Returns ``(outputs, final_state)`` with ``outputs`` in the same layout
    (and time direction) as the incoming ``inputs``.
    """
    if not time_major:
        inputs = np.transpose(inputs, [1, 0, 2])
    if is_reverse:
        inputs = np.flip(inputs, 0)

    mask = None
    if sequence_length is not None:
        # (batch, time) -> (time, batch, 1) so it broadcasts over features.
        mask = np.transpose(sequence_mask(sequence_length), [1, 0])
        mask = np.expand_dims(mask, -1)
        if is_reverse:
            mask = np.flip(mask, 0)

    state = initial_states
    outputs = []
    for t in range(inputs.shape[0]):
        x_t = inputs[t]
        step_mask = None if mask is None else mask[t]

        y, new_state = cell(x_t, state)
        if step_mask is None:
            state = new_state
        else:
            y = np.where(step_mask, y, 0.)
            state = update_state(step_mask, new_state, state)
        outputs.append(y)

    outputs = np.stack(outputs)
    final_state = state

    # Undo the input transformations in reverse order.
    if is_reverse:
        outputs = np.flip(outputs, 0)
    if not time_major:
        outputs = np.transpose(outputs, [1, 0, 2])
    return outputs, final_state
def birnn(cell_fw,
          cell_bw,
          inputs,
          initial_states,
          sequence_length=None,
          time_major=False):
    """Run a forward and a reversed pass over ``inputs`` and join the results.

    ``initial_states`` is unpacked as ``(forward_state, backward_state)``.
    The two output arrays are concatenated along the last axis; the final
    states are returned as a ``(forward, backward)`` tuple.
    """
    init_fw, init_bw = initial_states

    out_fw, last_fw = rnn(
        cell_fw, inputs, init_fw, sequence_length, time_major=time_major)
    out_bw, last_bw = rnn(
        cell_bw,
        inputs,
        init_bw,
        sequence_length,
        time_major=time_major,
        is_reverse=True)

    outputs = np.concatenate((out_fw, out_bw), -1)
    return outputs, (last_fw, last_bw)
def flatten(nested):
    """Return the leaves of a nested list/tuple structure as a flat list."""
    leaves = list(_flatten(nested))
    return leaves
def _flatten(nested):
for item in nested:
if isinstance(item, (list, tuple)):
for subitem in _flatten(item):
yield subitem
else:
yield item
def unstack(array, axis=0):
    """Split ``array`` into a list of sub-arrays along ``axis``.

    The list has ``array.shape[axis]`` entries and each entry has that axis
    removed.
    """
    count = array.shape[axis]
    pieces = []
    for chunk in np.split(array, count, axis):
        pieces.append(np.squeeze(chunk, axis))
    return pieces
def dropout(array, p=0.0):
    """Apply inverted dropout with drop probability ``p``.

    With ``p == 0.0`` the input object itself is returned. Otherwise each
    element is kept when a uniform draw is below ``1 - p`` and the kept
    elements are scaled by ``1 / (1 - p)``.
    """
    if p == 0.0:
        return array

    keep_prob = 1 - p
    draws = np.random.uniform(size=array.shape)
    keep = (draws < keep_prob).astype(array.dtype)
    return array * (keep / keep_prob)
def split_states(states, bidirectional=False, state_components=1):
    """Split stacked states into a per-layer list.

    With one state component, ``states`` is unstacked along axis 0. With
    several components, ``states`` must contain ``state_components`` arrays;
    each is unstacked and the pieces are zipped so every entry groups the
    components of one slice. When ``bidirectional`` is set, consecutive
    entries (even index, odd index) are paired.
    """
    if state_components == 1:
        per_slice = unstack(states)
    else:
        assert len(states) == state_components
        per_slice = list(zip(*[unstack(item) for item in states]))

    if not bidirectional:
        return per_slice
    return list(zip(per_slice[::2], per_slice[1::2]))
def concat_states(states, bidirectional=False, state_components=1):
    """Stack nested per-layer states back into arrays.

    The nested structure is flattened first. With one component all leaves
    are stacked into a single array; otherwise the leaves are taken with
    stride ``state_components`` and one stacked array per component is
    returned in a list. ``bidirectional`` is accepted but not consulted.
    """
    leaves = flatten(states)
    if state_components == 1:
        return np.stack(leaves)
    return [
        np.stack(leaves[k::state_components]) for k in range(state_components)
    ]
class RNN(LayerMixin):
    """Layer that unrolls a single cell over a sequence via ``rnn``."""

    def __init__(self, cell, is_reverse=False, time_major=False):
        super(RNN, self).__init__()
        self.cell = cell
        # Give the cell a `call` alias for `forward` when it has none
        # (the non-dygraph `rnn` api uses cell.call).
        if not hasattr(self.cell, "call"):
            self.cell.call = self.cell.forward
        self.is_reverse = is_reverse
        self.time_major = time_major

    def forward(self, inputs, initial_states=None, sequence_length=None):
        """Return ``(outputs, final_states)`` for ``inputs``."""
        outputs, last_states = rnn(
            self.cell,
            inputs,
            initial_states=initial_states,
            sequence_length=sequence_length,
            time_major=self.time_major,
            is_reverse=self.is_reverse)
        return outputs, last_states
class BiRNN(LayerMixin):
    """Layer that runs a forward cell and a backward cell through `birnn`."""

    def __init__(self, cell_fw, cell_bw, time_major=False):
        super(BiRNN, self).__init__()
        self.time_major = time_major
        self.cell_fw = cell_fw
        self.cell_bw = cell_bw

    def forward(self,
                inputs,
                initial_states=None,
                sequence_length=None,
                **kwargs):
        """Call `birnn` with this layer's two cells.

        `initial_states` may be a list/tuple of exactly two entries; any other
        value is used for both directions. Extra keyword arguments are
        accepted and ignored.
        """
        if not isinstance(initial_states, (list, tuple)):
            # Share the same initial state between both directions.
            initial_states = [initial_states, initial_states]
        else:
            assert len(initial_states) == 2, \
                "length of initial_states should be 2 when it is a list/tuple"

        merged, last_states = birnn(self.cell_fw, self.cell_bw, inputs,
                                    initial_states, sequence_length,
                                    self.time_major)
        return merged, last_states
class RNNMixin(LayerListMixin):
    """Shared forward pass for a stack of recurrent layers.

    The class is iterated to obtain its layers and reads `time_major`,
    `num_layers`, `num_directions`, `hidden_size`, `state_components` and
    `dropout` from the instance.
    """

    def forward(self, inputs, initial_states=None, sequence_length=None):
        """Feed `inputs` through every layer in order.

        Returns:
            `(outputs, final_states)`: the last layer's outputs and the
            per-layer final states recombined by `concat_states`.
        """
        batch_size = inputs.shape[1 if self.time_major else 0]
        dtype = inputs.dtype

        if initial_states is None:
            # Default to zero states, one array per state component.
            state_shape = (self.num_layers * self.num_directions, batch_size,
                           self.hidden_size)
            if self.state_components == 1:
                initial_states = np.zeros(state_shape, dtype)
            else:
                zero_states = []
                for _ in range(self.state_components):
                    zero_states.append(np.zeros(state_shape, dtype))
                initial_states = tuple(zero_states)

        is_bidirectional = self.num_directions == 2
        layer_states = split_states(initial_states, is_bidirectional,
                                    self.state_components)

        collected = []
        layer_input = inputs
        for index, layer in enumerate(self):
            # Dropout is applied to the input of every layer but the first.
            if index > 0:
                layer_input = dropout(layer_input, self.dropout)
            outputs, last_state = layer(layer_input, layer_states[index],
                                        sequence_length)
            collected.append(last_state)
            layer_input = outputs

        final_states = concat_states(collected, is_bidirectional,
                                     self.state_components)
        return outputs, final_states
class LSTM(RNNMixin):
    """Stack of LSTM layers built from `LSTMCell`, `RNN` and `BiRNN`.

    Args:
        input_size: input width of the first layer.
        hidden_size: hidden width of every cell.
        num_layers: number of stacked layers; the first layer is always built.
        direction: "forward", "backward" or "bidirectional".
        dropout: stored on the instance as `self.dropout`.
        time_major: forwarded to every layer and stored on the instance.

    Raises:
        ValueError: if `direction` is not one of the three accepted values.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 direction="forward",
                 dropout=0.,
                 time_major=False):
        super(LSTM, self).__init__()

        is_bidirectional = direction == "bidirectional"
        if not is_bidirectional and direction not in ["forward", "backward"]:
            raise ValueError(
                "direction should be forward, backward or bidirectional, "
                "received direction = {}".format(direction))

        # The first layer consumes the raw input; each later layer consumes
        # the previous layer's output (twice as wide when bidirectional).
        stacked_in_size = 2 * hidden_size if is_bidirectional else hidden_size
        layer_in_sizes = [input_size]
        for _ in range(1, num_layers):
            layer_in_sizes.append(stacked_in_size)

        if is_bidirectional:
            for in_size in layer_in_sizes:
                cell_fw = LSTMCell(in_size, hidden_size)
                cell_bw = LSTMCell(in_size, hidden_size)
                self.append(BiRNN(cell_fw, cell_bw, time_major))
        else:
            is_reverse = direction == "backward"
            for in_size in layer_in_sizes:
                cell = LSTMCell(in_size, hidden_size)
                self.append(RNN(cell, is_reverse, time_major))

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_directions = 2 if is_bidirectional else 1
        self.time_major = time_major
        self.num_layers = num_layers
        # Each state is made of two arrays.
        self.state_components = 2
# OpTest case for the "cudnn_lstm" op; skipped when core is not compiled with CUDA.
# NOTE(review): this block looks like a rendered diff with the +/- markers and
# indentation stripped: removed and added lines sit side by side (duplicate
# assignments, dangling argument lines, `......@@` hunk headers). Confirm
# against the real file before relying on any single line.
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp(OpTest):
# TODO(GaoWei8):when input dtype is fp64, precision threshold should be removed.
#TODO(GaoWei8): Need to satisfy the result through the new interface
# Builds the op type, inputs, attributes and expected outputs for the test.
# NOTE(review): set_attrs() is not called anywhere in the visible part of
# setUp, so subclass overrides may have no effect — confirm.
def setUp(self):
self.op_type = "cudnn_lstm"
self.dtype = np.float64
self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
self.num_layers = 1
seq_length = 20
seq_length = 12
batch_size = 5
hidden_size = 20
input_size = 21
hidden_size = 21
input_weight_size = (hidden_size * hidden_size) * 4
hidden_weight_size = (hidden_size * hidden_size) * 4
weight_size = input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
weight_size *= self.num_layers
input = np.random.uniform(
low=-0.1, high=0.1, size=(seq_length, batch_size,
hidden_size)).astype(self.dtype)
flat_w = np.random.uniform(
low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)
output, last_hidden, last_cell = lstm_naive(input, flat_w)
init_h = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
init_c = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
low=-0.1, high=0.1,
size=(seq_length, batch_size, input_size)).astype(self.dtype)
# Zero the trailing entries of the last four time steps; presumably this is
# the padding for the shorter sequences in self.sequence_length — TODO confirm.
input[11][1:][:] = 0
input[10][2:][:] = 0
input[9][3:][:] = 0
input[8][4:][:] = 0
# Expected outputs come from the numpy LSTM reference defined in this file.
rnn1 = LSTM(
input_size,
hidden_size,
self.num_layers,
time_major=True,
direction="forward")
output, (last_hidden, last_cell) = rnn1(
input, sequence_length=self.sequence_length)
flat_w = np.ones((weight_size)).astype(self.dtype)
init_h = np.zeros((self.num_layers, batch_size,
hidden_size)).astype(self.dtype)
init_c = np.zeros((self.num_layers, batch_size,
hidden_size)).astype(self.dtype)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
......@@ -152,9 +405,10 @@ class TestCUDNNLstmOp(OpTest):
self.attrs = {
'dropout_prob': 0.0,
'is_bidirec': False,
'input_size': hidden_size,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': 1,
'sequence_length': self.sequence_length.tolist()
}
self.outputs = {
'Out': output,
......@@ -164,19 +418,33 @@ class TestCUDNNLstmOp(OpTest):
'StateOut': state_out
}
# Hook overridden by the subclasses below; a no-op here.
def set_attrs(self):
pass
# Checks the op outputs on CUDAPlace(0), skipping 'Reserve' and 'StateOut'.
def test_output_with_place(self):
# depend on the scope structure
place = core.CUDAPlace(0)
self.check_output_with_place(
place, no_check_set=['Reserve', 'StateOut'])
# Checks gradients of Out/LastH/LastC w.r.t. Input, W, InitH and InitC.
# NOTE(review): two check_grad_with_place calls appear below; likely the old
# and new versions of a single call — confirm.
def test_grad_with_place(self):
# depend on the scope structure
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'LastH', 'LastC'],
max_relative_error=1e-4)
self.check_grad_with_place(place,
set(['Input', 'W', 'InitH', 'InitC']),
['Out', 'LastH', 'LastC'])
@unittest.skipUnless(core.is_compiled_with_cuda(),
                     "core is not compiled with CUDA")
class TestCUDNNLstmOp2(TestCUDNNLstmOp):
    """TestCUDNNLstmOp variant whose `set_attrs` sets an empty `sequence_length`."""

    def set_attrs(self):
        # Zero-length int32 array.
        self.sequence_length = np.zeros((0, ), dtype=np.int32)
@unittest.skipUnless(core.is_compiled_with_cuda(),
                     "core is not compiled with CUDA")
class TestCUDNNLstmOp3(TestCUDNNLstmOp):
    """TestCUDNNLstmOp variant whose `set_attrs` sets `num_layers` to 2."""

    def set_attrs(self):
        layer_count = 2
        self.num_layers = layer_count
# NOTE(review): fragment of TestCUDNNlstmAPI as shown in a rendered diff. The
# decorator's second line, the method header and several statements are hidden
# behind the `......@@` hunk markers, and removed/added lines appear together.
# Confirm against the real file.
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -198,7 +466,7 @@ class TestCUDNNlstmAPI(unittest.TestCase):
'float64', 0.0)
rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
hidden_size, num_layers,
dropout_prob)
dropout_prob, False, True)
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
input_i = np.random.uniform(
......@@ -208,12 +476,6 @@ class TestCUDNNlstmAPI(unittest.TestCase):
feed={'input': input_i},
fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])
# Compare the fetched results with the lstm_naive reference (atol=1e-5),
# using the fetched weight out[3].
output, last_hidden, last_cell = lstm_naive(input_i, out[3])
self.assertTrue(np.allclose(output, out[0], atol=1e-5))
self.assertTrue(np.allclose(last_hidden, out[1], atol=1e-5))
self.assertTrue(np.allclose(last_cell, out[2], atol=1e-5))
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册