Unverified Commit 1fbee267 authored by GaoWei8, committed by GitHub

remove scope in cudnn lstm (#25188)

Parent da29760d
@@ -24,34 +24,62 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitH"),
"Input(init_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitC"),
"Input(init_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cache"),
"Input(Cache) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("last_h"),
"Output(last_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("last_c"),
"Output(last_c) of LSTM should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("StateOut"), "Output", "StateOut",
"CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3.");
auto init_dims = ctx->GetInputDim("InitH");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_dims.size()));
PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_dims[1] is %d.",
in_dims[1], init_dims[1]));
PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2],
platform::errors::InvalidArgument(
"The in_dims[2] (Input dims) and init_dims[2] (InitH "
"dims) should be equal. But "
"received in_dims[2] is %d and init_dims[2] is %d.",
in_dims[2], init_dims[2]));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
out_dims[2] = hidden_size;
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
auto last_dims = init_dims;
last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0];
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH"));
ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC"));
ctx->SetOutputDim("LastH", last_dims);
ctx->SetOutputDim("LastC", last_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
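For orientation, the shape contract enforced by InferShape above can be summarized in plain Python; this is an illustrative sketch, not part of the op:

def cudnn_lstm_out_shapes(in_shape, init_h_shape, hidden_size, is_bidirec):
    # Input and InitH are both checked to be rank 3 and must agree on
    # dims [1] and [2]; Out keeps Input's leading dims.
    seq_len, batch_size, _ = in_shape
    num_dirs = 2 if is_bidirec else 1
    out_shape = (seq_len, batch_size, hidden_size * num_dirs)
    # LastH/LastC keep InitH's shape, with dim 0 doubled when bidirectional.
    last_shape = (init_h_shape[0] * num_dirs, init_h_shape[1], init_h_shape[2])
    return {"Out": out_shape, "LastH": last_shape, "LastC": last_shape}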
@@ -84,33 +112,31 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
AddInput("Cache",
"The cache of dropout op, a RAW type variable including random "
"number generator states and some descriptors, which is used in "
"cudnn kernel.")
.AsDispensable();
AddOutput("Reserve",
"(Tensor, a temporary output Tensor to store the reserve_data "
"of cudnn kernel.")
.AsIntermediate();
AddOutput("StateOut",
"Share memory with State. "
"Store the global drop state when training");
AddOutput("Out",
"(Tensor) the hidden state of LSTM operator. "
"The shape is ( seq_len x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be ( seq_len x "
"batch_size x hidden_size * 2) ");
AddOutput("last_h",
AddOutput("LastH",
"(Tensor) the hidden state of the last step. "
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)");
AddOutput("last_c",
AddOutput("LastC",
"(Tensor) the cell state of the last step"
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirect is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size*2)");
AddAttr<int>("max_len",
"max length of the LSTM op"
"the first dim of the Input can NOT be greater than max_len")
.SetDefault(20);
AddAttr<float>(
"dropout_prob",
"dropout prob of the dropout op"
@@ -120,14 +146,14 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("is_bidirec",
"is_bidirec"
"if it is bidirectional rnn"
"The will affect the shape of the Out, last_h, and last_c")
"The will affect the shape of the Out, LastH, and LastC")
.SetDefault(false);
AddAttr<int>("input_size", "input size ot the Input Tensor").SetDefault(10);
AddAttr<int>("hidden_size", "hidden size of the LSTM").SetDefault(100);
AddAttr<int>("num_layers", "the total layer number of the LSTM")
.SetDefault(1);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(-1);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddComment(R"DOC(
CUDNN LSTM implementation
@@ -172,16 +198,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cache"),
"Input(last_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitH"),
"Input(init_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitC"),
"Input(init_c) of LSTM should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
@@ -195,6 +215,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
@@ -209,13 +235,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
if (this->HasInput("Cache")) {
op->SetInput("Cache", this->Input("Cache"));
}
op->SetInput("Reserve", this->Output("Reserve"));
op->SetInput("StateOut", this->Output("StateOut"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput(framework::GradVarName("last_c"), this->OutputGrad("last_c"));
op->SetInput(framework::GradVarName("last_h"), this->OutputGrad("last_h"));
op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
......
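A compact summary of the grad wiring set up above (plain Python, names copied from the code; the backward op now reads the forward op's Reserve and StateOut outputs instead of a scope-persisted Cache):

cudnn_lstm_grad_inputs = {
    "Input": "forward input",
    "InitH": "forward input",
    "InitC": "forward input",
    "W": "forward input",
    "Reserve": "forward output (cuDNN reserve space)",
    "StateOut": "forward output (dropout RNG state)",
    "Out": "forward output",
    "Out@GRAD": "gradient of Out",
    "LastH@GRAD": "gradient of LastH",
    "LastC@GRAD": "gradient of LastC",
}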
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_desc.h"
namespace paddle {
namespace operators {
@@ -33,8 +34,10 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
auto w = ctx.Input<Tensor>("W");
Tensor *out = ctx.Output<Tensor>("Out");
Tensor *last_h = ctx.Output<Tensor>("last_h");
Tensor *last_c = ctx.Output<Tensor>("last_c");
Tensor *last_h = ctx.Output<Tensor>("LastH");
Tensor *last_c = ctx.Output<Tensor>("LastC");
Tensor *reserve = ctx.Output<Tensor>("Reserve");
Tensor *state_out = ctx.Output<Tensor>("StateOut");
const T *x_data = x->data<T>();
const T *init_h_data = init_h->data<T>();
@@ -46,72 +49,56 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
size_t max_len = ctx.Attr<int>("max_len");
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int input_size = ctx.Attr<int>("input_size");
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto *cache_var = ctx.InputVar("Cache");
if (!cache_var) {
// The RAW type cache variable wouldn't be created and broadcasted on
// multi-devices before the first running.
// use parent scope to make cache persistable
auto *scope = const_cast<framework::Scope *>(ctx.scope().parent());
auto cache_var_name = ctx.InputNames("Cache")[0];
cache_var = scope->Var(cache_var_name);
}
CudnnRNNCache *cudnn_rnn_cache = nullptr;
if (cache_var->IsInitialized()) {
// const_cast is usually bad.
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
} else {
// const_cast is usually bad.
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
std::random_device rnd;
int seed = ctx.Attr<int>("seed");
if (seed == -1) {
seed = rnd();
}
auto input_w_numel = w->numel();
auto batch_size = x->dims()[1];
cudnn_rnn_cache->init(handle, ctx.GetPlace(), max_len, batch_size,
input_size, hidden_size, num_layers, dropout_prob,
is_bidirec, seed, input_w_numel);
}
auto run_seq_len = x->dims()[0];
CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache();
auto input_w_numel = w->numel();
auto seq_len = x->dims()[0];
auto batch_size = x->dims()[1];
auto input_dim = x->dims()[2];
size_t reserve_size;
bool state_initialized = state_out->IsInitialized();
cudnnDataType_t cudnn_type = platform::ToCudnnDataType(
framework::ToDataType(std::type_index(typeid(T))));
cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size,
input_dim, hidden_size, num_layers, dropout_prob,
is_bidirec, seed, input_w_numel, &reserve_size,
state_out, state_initialized, cudnn_type);
auto *reserve_data = reserve->mutable_data<uint8_t>(
{static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
if (is_test) {
// for inference
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_,
init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data,
cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data,
cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_,
last_c_data, cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
x_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_,
w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_,
last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data,
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_));
} else {
// for train
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_,
init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data,
cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data,
cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_,
last_c_data, cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_,
cudnn_rnn_cache->reserve_data_.data<uint8_t>(),
cudnn_rnn_cache->reserve_size_));
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
x_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_,
w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_,
last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data,
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_, reserve_data, reserve_size));
}
delete cudnn_rnn_cache;
}
};
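// The CudnnRNNCache is rebuilt on every invocation (new/delete above); the
// state that must persist across calls lives in the Reserve and StateOut
// tensors instead of a RAW Cache variable held in the parent scope.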
@@ -123,15 +110,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto *weight = ctx.Input<Tensor>("W");
auto *init_h = ctx.Input<Tensor>("InitH");
auto *init_c = ctx.Input<Tensor>("InitC");
// auto * last_h = ctx.Input<Tensor>("last_h");
// auto * last_c = ctx.Input<Tensor>("last_c");
auto *reserve = ctx.Input<Tensor>("Reserve");
auto *state_out = ctx.Input<Tensor>("StateOut");
auto *out = ctx.Input<Tensor>("Out");
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("last_h"));
auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("last_c"));
// auto* init_h = ctx.Input<Tensor>("init_h");
// auto* init_c = ctx.Input<Tensor>("init_c");
auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));
auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
@@ -140,116 +125,75 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto *cache_var = ctx.InputVar("Cache");
PADDLE_ENFORCE(cache_var->IsInitialized());
CudnnRNNCache *cudnn_rnn_cache =
const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
auto input_dims = input->dims();
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();
in_grad->mutable_data<T>(ctx.GetPlace());
weight_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
zero(dev_ctx, in_grad, static_cast<T>(0.0));
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
T *init_h_grad_data = NULL;
if (init_h_grad == nullptr) {
Tensor init_h_grad_temp;
init_h_grad_temp.mutable_data<T>(init_h_dims, ctx.GetPlace());
zero(dev_ctx, &init_h_grad_temp, static_cast<T>(0.0));
init_h_grad_data = init_h_grad_temp.data<T>();
} else {
init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
zero(dev_ctx, init_h_grad, static_cast<T>(0.0));
init_h_grad_data = init_h_grad->data<T>();
}
T *init_c_grad_data = NULL;
if (init_c_grad == nullptr) {
Tensor init_c_grad_temp;
init_c_grad_temp.mutable_data<T>(init_c_dims, ctx.GetPlace());
zero(dev_ctx, &init_c_grad_temp, static_cast<T>(0.0));
init_c_grad_data = init_c_grad_temp.data<T>();
} else {
init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
zero(dev_ctx, init_c_grad, static_cast<T>(0.0));
init_c_grad_data = init_c_grad->data<T>();
}
auto *weight_data = weight->data<T>();
auto *init_h_data = init_h->data<T>();
auto *init_c_data = init_c->data<T>();
auto *out_data = out->data<T>();
auto *out_grad_data = out_grad->data<T>();
auto *last_h_grad_data = last_h_grad->data<T>();
auto *last_c_grad_data = last_c_grad->data<T>();
const T *last_h_grad_data = NULL;
if (last_h_grad == nullptr) {
Tensor last_h_grad_temp;
last_h_grad_temp.mutable_data<T>(init_h_dims, ctx.GetPlace());
zero(dev_ctx, &last_h_grad_temp, static_cast<T>(0.0));
last_h_grad_data = (const T *)last_h_grad_temp.data<T>();
} else {
last_h_grad_data = last_h_grad->data<T>();
}
const T *last_c_grad_data = NULL;
if (last_c_grad == nullptr) {
Tensor last_c_grad_temp;
last_c_grad_temp.mutable_data<T>(init_c_dims, ctx.GetPlace());
zero(dev_ctx, &last_c_grad_temp, static_cast<T>(0.0));
last_c_grad_data = (const T *)last_c_grad_temp.data<T>();
} else {
last_c_grad_data = last_c_grad->data<T>();
}
math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
weight_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
const T *out_grad_data = NULL;
if (out_grad == nullptr) {
Tensor out_grad_temp;
out_grad_temp.mutable_data<T>(out->dims(), ctx.GetPlace());
zero(dev_ctx, &out_grad_temp, static_cast<T>(0.0));
in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
auto *in_grad_data = in_grad->data<T>();
out_grad_data = (const T *)out_grad_temp.data<T>();
} else {
out_grad_data = out_grad->data<T>();
}
init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad->data<T>();
// zero( dev_ctx, last_h_grad, static_cast<T>(0.0));
// zero( dev_ctx, last_c_grad, static_cast<T>(0.0));
init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad->data<T>();
auto out_data = out->data<T>();
// auto out_grad_data = out_grad->data<T>();
auto weight_data = weight->data<T>();
auto init_h_data = init_h->data<T>();
auto init_c_data = init_c->data<T>();
auto in_grad_data = in_grad->data<T>();
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
int seed = ctx.Attr<int>("seed");
CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache();
auto input_w_numel = weight->numel();
auto seq_len = input_dims[0];
auto batch_size = input->dims()[1];
auto input_dim = input->dims()[2];
size_t reserve_size;
cudnnDataType_t cudnn_type = platform::ToCudnnDataType(
framework::ToDataType(std::type_index(typeid(T))));
cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size,
input_dim, hidden_size, num_layers, dropout_prob,
is_bidirec, seed, input_w_numel, &reserve_size,
const_cast<Tensor *>(state_out), true, cudnn_type);
auto work_data = cudnn_rnn_cache->workspace_data_.data<uint8_t>();
auto reserve_data = cudnn_rnn_cache->reserve_data_.data<uint8_t>();
const uint8_t *reserve_data = reserve->data<uint8_t>();
auto run_seq_len = input_dims[0];
PADDLE_ENFORCE_LE((size_t)run_seq_len, cudnn_rnn_cache->max_length_,
"cudnn running seq_len can not be greater than max_length");
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->dy_desc_,
out_grad_data, cudnn_rnn_cache->dhy_desc_, last_h_grad_data,
cudnn_rnn_cache->dcy_desc_, last_c_grad_data, cudnn_rnn_cache->w_desc_,
weight_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->dx_desc_,
in_grad_data, cudnn_rnn_cache->dhx_desc_, init_h_grad_data,
cudnn_rnn_cache->dcx_desc_, init_c_grad_data, work_data,
cudnn_rnn_cache->workspace_size_, reserve_data,
cudnn_rnn_cache->reserve_size_));
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->y_desc_,
out_data, cudnn_rnn_cache->y_desc_, out_grad_data,
cudnn_rnn_cache->hy_desc_, last_h_grad_data, cudnn_rnn_cache->cy_desc_,
last_c_grad_data, cudnn_rnn_cache->w_desc_, weight_data,
cudnn_rnn_cache->hx_desc_, init_h_data, cudnn_rnn_cache->cx_desc_,
init_c_data, cudnn_rnn_cache->x_desc_, in_grad_data,
cudnn_rnn_cache->hx_desc_, init_h_grad_data, cudnn_rnn_cache->cx_desc_,
init_c_grad_data, work_data, cudnn_rnn_cache->workspace_size_,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->x_desc_, input->data<T>(), cudnn_rnn_cache->hx_desc_,
init_h->data<T>(), cudnn_rnn_cache->y_desc_, out->data<T>(),
handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_,
input->data<T>(), cudnn_rnn_cache->hx_desc_, init_h->data<T>(),
cudnn_rnn_cache->y_desc_, out->data<T>(),
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->dw_desc_,
weight_grad->data<T>(), cudnn_rnn_cache->reserve_data_.data<uint8_t>(),
cudnn_rnn_cache->reserve_size_));
cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->w_desc_,
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
delete cudnn_rnn_cache;
}
};
@@ -257,5 +201,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
ops::CudnnLSTMGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
ops::CudnnLSTMGPUGradKernel<double>);
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cudnn_helper.h"
@@ -24,16 +25,12 @@ struct CudnnRNNCache {
CudnnRNNCache() {
x_desc_ = NULL;
y_desc_ = NULL;
dx_desc_ = NULL;
dy_desc_ = NULL;
}
~CudnnRNNCache() { release(); }
cudnnRNNDescriptor_t rnn_desc_;
cudnnTensorDescriptor_t *x_desc_;
cudnnTensorDescriptor_t *y_desc_;
cudnnTensorDescriptor_t *dx_desc_;
cudnnTensorDescriptor_t *dy_desc_;
cudnnTensorDescriptor_t hx_desc_;
cudnnTensorDescriptor_t cx_desc_;
@@ -55,13 +52,9 @@ struct CudnnRNNCache {
cudnnFilterDescriptor_t dw_desc_;
size_t workspace_size_;
size_t reserve_size_;
framework::Tensor reserve_data_;
framework::Tensor workspace_data_;
framework::Tensor dropout_state_;
size_t max_length_;
size_t seq_length_;
float dropout_prob_;
bool is_bidirec_;
@@ -72,10 +65,12 @@ struct CudnnRNNCache {
int num_layers_;
int seed_;
void init(cudnnHandle_t handle, const platform::Place &place, size_t max_len,
void init(cudnnHandle_t handle, const platform::Place &place, size_t seq_len,
int batch_size, int input_size, int hidden_size, int num_layers,
float dropout_prob, bool is_bidirec, int seed, int weight_numel) {
max_length_ = max_len;
float dropout_prob, bool is_bidirec, int seed, int weight_numel,
size_t *reserve_size_, framework::Tensor *dropout_state_,
bool initialized, cudnnDataType_t cudnn_type) {
seq_length_ = seq_len;
batch_size_ = batch_size;
input_size_ = input_size;
hidden_size_ = hidden_size;
@@ -84,55 +79,34 @@ struct CudnnRNNCache {
is_bidirec_ = is_bidirec;
seed_ = seed;
x_desc_ = new cudnnTensorDescriptor_t[max_length_];
y_desc_ = new cudnnTensorDescriptor_t[max_length_];
dx_desc_ = new cudnnTensorDescriptor_t[max_length_];
dy_desc_ = new cudnnTensorDescriptor_t[max_length_];
int dim_a[3];
int stride_a[3];
const auto numDirections = is_bidirec_ ? 2 : 1;
auto cudnn_size =
cudnn_type == CUDNN_DATA_FLOAT ? sizeof(float) : sizeof(double);
x_desc_ = new cudnnTensorDescriptor_t[seq_length_];
y_desc_ = new cudnnTensorDescriptor_t[seq_length_];
std::vector<int> dims = {batch_size_, input_size_, 1};
std::vector<int> strides = {input_size_, 1, 1};
std::vector<int> dims_y = {batch_size_, hidden_size_ * numDirections, 1};
std::vector<int> strides_y = {hidden_size_ * numDirections, 1, 1};
for (size_t i = 0; i < max_length_; ++i) {
for (size_t i = 0; i < seq_length_; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i]));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i]));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i]));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i]));
dim_a[0] = batch_size_;
dim_a[1] = input_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dim_a[0] = batch_size_;
dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
x_desc_[i], cudnn_type, 3, dims.data(), strides.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
y_desc_[i], cudnn_type, 3, dims_y.data(), strides_y.data()));
}
dim_a[0] = num_layers_ * (is_bidirec_ ? 2 : 1);
dim_a[1] = batch_size_;
dim_a[2] = hidden_size_;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
hidden_size_};
std::vector<int> strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1};
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_));
@@ -152,33 +126,44 @@ struct CudnnRNNCache {
platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
hx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
cx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
hy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
cy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dhx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dcx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dhy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dcy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_));
size_t state_size;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state_.Resize({static_cast<int64_t>(state_size)});
auto *dropout_state_data = dropout_state_.mutable_data<uint8_t>(place);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetDropoutDescriptor(
dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size,
seed_));
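// First run: query the dropout state size, allocate the state tensor and
// seed it via cudnnSetDropoutDescriptor. Subsequent runs (e.g. the backward
// pass) restore the descriptor from the saved state instead of re-seeding,
// so forward and backward apply the same dropout mask.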
if (!initialized) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state_->Resize({static_cast<int64_t>(state_size)});
uint8_t *dropout_state_data =
dropout_state_->mutable_data<uint8_t>(place);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetDropoutDescriptor(
dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size,
seed_));
} else {
uint8_t *dropout_state_data = dropout_state_->data<uint8_t>();
auto dropout_state_dims = dropout_state_->dims();
state_size = dropout_state_dims[0];
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRestoreDropoutDescriptor(
dropout_desc_, handle, dropout_prob_, dropout_state_data,
state_size, 0));
}
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
@@ -188,12 +173,12 @@ struct CudnnRNNCache {
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
CUDNN_RNN_ALGO_STANDARD, cudnn_type));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_DATA_FLOAT));
cudnn_type));
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -202,48 +187,42 @@ struct CudnnRNNCache {
platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT));
handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type));
PADDLE_ENFORCE_EQ(
weights_size_, cudnn_size * weight_numel,
platform::errors::InvalidArgument(
"The cudnn lstm and setting weight size should be same."));
PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel,
"cudnn lstm weight size should be SAME");
int dim_w[3];
dim_w[0] = weights_size_ / sizeof(float);
dim_w[0] = weights_size_ / cudnn_size;
dim_w[1] = 1;
dim_w[2] = 1;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor(
w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
w_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor(
dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
dw_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_, max_length_, x_desc_, &workspace_size_));
handle, rnn_desc_, seq_length_, x_desc_, &workspace_size_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_, max_length_, x_desc_, &reserve_size_));
reserve_data_.Resize({static_cast<int64_t>(reserve_size_)});
reserve_data_.mutable_data<uint8_t>(place);
handle, rnn_desc_, seq_length_, x_desc_, reserve_size_));
workspace_data_.Resize({static_cast<int64_t>(workspace_size_)});
workspace_data_.mutable_data<uint8_t>(place);
}
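// Sizing note: workspace_data_ is scratch memory owned by this cache and
// reused per call, while the reserve buffer must survive from forward to
// backward, so only its size is reported through the reserve_size_ pointer
// and the kernel allocates it as the op's Reserve output.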
void release() {
for (size_t i = 0; i < max_length_; ++i) {
for (size_t i = 0; i < seq_length_; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i]));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i]));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i]));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i]));
}
delete[] x_desc_;
delete[] y_desc_;
delete[] dx_desc_;
delete[] dy_desc_;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_));
......
@@ -100,6 +100,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnCreateDropoutDescriptor); \
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
......
@@ -2213,9 +2213,9 @@ def lstm(input,
input ( :ref:`api_guide_Variable_en` ): LSTM input tensor, 3-D Tensor of shape :math:`[batch\_size, seq\_len, input\_dim]` . Data type is float32 or float64
init_h( :ref:`api_guide_Variable_en` ): The initial hidden state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
max_len (int): This parameter has no effect and will be discarded.
init_c( :ref:`api_guide_Variable_en` ): The initial cell state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len.
hidden_size (int): hidden size of the LSTM.
num_layers (int): total layers number of the LSTM.
dropout_prob(float, optional): dropout prob, dropout ONLY works between rnn layers, NOT between time steps
@@ -2256,7 +2256,6 @@ def lstm(input,
data = fluid.data(name='x', shape=[None, 100], dtype='int64')
emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True)
batch_size = 20
max_len = 100
dropout_prob = 0.2
input_size = 100
hidden_size = 150
@@ -2309,9 +2308,11 @@ def lstm(input,
out = helper.create_variable_for_type_inference(dtype)
last_h = helper.create_variable_for_type_inference(dtype)
last_c = helper.create_variable_for_type_inference(dtype)
cache = helper.create_variable(
persistable=True, type=core.VarDesc.VarType.RAW, stop_gradient=True)
reserve = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
state_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
state_out.persistable = True
helper.append_op(
type='cudnn_lstm',
@@ -2320,15 +2321,15 @@ def lstm(input,
'InitH': init_h,
'InitC': init_c,
'W': weight,
'Cache': cache,
},
outputs={
'Out': out,
'last_h': last_h,
'last_c': last_c,
'LastH': last_h,
'LastC': last_c,
'Reserve': reserve,
'StateOut': state_out,
},
attrs={
'max_len': max_len,
'is_bidirec': is_bidirec,
'input_size': input_size,
'hidden_size': hidden_size,
......
@@ -20,15 +20,14 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def lstm_naive(
input,
w, ):
def lstm_naive(input, w):
seq_len, batch_size, hidden_size = input.shape
offset = 0
@@ -86,8 +85,8 @@ def lstm_naive(
return (2. / (1. + np.exp(y))) - 1.
output = []
pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype)
pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype)
pre_h = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)
pre_c = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)
for i in range(seq_len):
emb_1 = input[i]
@@ -110,7 +109,6 @@ def lstm_naive(
output = np.concatenate(output, -1)
output = output.reshape((batch_size, -1, hidden_size))
output = output.transpose((1, 0, 2))
return output, pre_h, pre_c
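The gate arithmetic inside lstm_naive is elided by the diff view; for orientation, one time step of a single-layer LSTM looks roughly like the sketch below. This is illustrative only, not the exact elided code, and assumes numpy as np, the clipped sigmoid/tanh helpers defined above, and cuDNN's i, f, g, o gate order:

def lstm_step(x_t, pre_h, pre_c, w_ih, w_hh, b_ih, b_hh):
    gates = np.dot(x_t, w_ih.T) + np.dot(pre_h, w_hh.T) + b_ih + b_hh
    i, f, g, o = np.split(gates, 4, axis=-1)
    c = sigmoid(f) * pre_c + sigmoid(i) * tanh(g)
    h = sigmoid(o) * tanh(c)
    return h, c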
@@ -119,11 +117,12 @@ def lstm_naive(
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp(OpTest):
# TODO(GaoWei8): when input dtype is fp64, precision threshold should be removed.
def setUp(self):
self.op_type = "cudnn_lstm"
self.dtype = np.float32
self.dtype = np.float64
num_steps = 20
seq_length = 20
batch_size = 5
hidden_size = 20
@@ -133,33 +132,24 @@ class TestCUDNNLstmOp(OpTest):
weight_size += hidden_size * 8
input = np.random.uniform(
low=-0.1, high=0.1, size=(num_steps, batch_size,
low=-0.1, high=0.1, size=(seq_length, batch_size,
hidden_size)).astype(self.dtype)
flat_w = np.random.uniform(
low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)
output, last_hidden, last_cell = lstm_naive(input, flat_w)
init_h = np.zeros((batch_size, hidden_size), dtype=np.float32)
init_c = np.zeros((batch_size, hidden_size), dtype=np.float32)
scope = core.Scope()
program = fluid.Program()
block = program.global_block()
cache_temp = block.create_var(
name="Cache",
persistable=True,
type=core.VarDesc.VarType.RAW,
stop_gradient=True)
init_h = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
init_c = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'W': OpTest.np_dtype_to_fluid_dtype(flat_w),
'InitH': OpTest.np_dtype_to_fluid_dtype(init_h),
'InitC': OpTest.np_dtype_to_fluid_dtype(init_c),
'Input': input,
'W': flat_w,
'InitH': init_h,
'InitC': init_c
}
self.cache_name_list = ['Cache']
self.attrs = {
'max_len': num_steps,
'dropout_prob': 0.0,
'is_bidirec': False,
'input_size': hidden_size,
@@ -168,22 +158,61 @@ class TestCUDNNLstmOp(OpTest):
}
self.outputs = {
'Out': output,
"last_h": last_hidden,
'last_c': last_cell
"LastH": last_hidden,
'LastC': last_cell,
'Reserve': np.ndarray((400)).astype("uint8"),
'StateOut': state_out
}
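# Reserve and StateOut are opaque cuDNN byte buffers; the arrays above are
# placeholders, and their contents are excluded from comparison via the
# no_check_set list used below.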
def test_output_with_place(self):
# depend on the scope structure
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5, check_dygraph=False)
self.check_output_with_place(
place, no_check_set=['Reserve', 'StateOut'])
def test_grad_with_place(self):
# depend on the scope structure
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'last_h', 'last_c'],
check_dygraph=False)
set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'LastH', 'LastC'],
max_relative_error=1e-4)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNlstmAPI(unittest.TestCase):
def test_lstm(self):
seq_len = 20
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 1
input = fluid.data(
name='input',
shape=[seq_len, batch_size, hidden_size],
dtype='float64')
init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
hidden_size, num_layers,
dropout_prob)
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
input_i = np.random.uniform(
low=-0.1, high=0.1, size=(seq_len, batch_size,
hidden_size)).astype("float64")
out = exe.run(fluid.default_main_program(),
feed={'input': input_i},
fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])
output, last_hidden, last_cell = lstm_naive(input_i, out[3])
self.assertTrue(np.allclose(output, out[0], atol=1e-5))
self.assertTrue(np.allclose(last_hidden, out[1], atol=1e-5))
self.assertTrue(np.allclose(last_cell, out[2], atol=1e-5))
if __name__ == '__main__':
......
@@ -26,4 +26,5 @@ no_check_set_white_list = [
'cross_entropy2',
'seed',
'amp_check_finite_and_scale',
'cudnn_lstm',
]
@@ -41,7 +41,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'unpool', \
'yolov3_loss', \
'inverse', \
'bilateral_slice'
'bilateral_slice',\
'cudnn_lstm'
]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp']