提交 6d98ba30 编写于 作者: P phlrain

add cudnn lstm

test=release/1.2
上级 cc441ee1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class CudnnLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitH"),
"Input(init_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitC"),
"Input(init_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cache"),
"Input(Cache) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("last_h"),
"Output(last_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("last_c"),
"Output(last_c) of LSTM should not be null.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3.");
ctx->SetOutputDim("Out", ctx->GetInputDim("Input"));
ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH"));
ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC"));
}
};
class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Input",
"(Tensor) RNN input tensor, which support variable-time length input "
"sequence."
"The shape of the Tensor MUST be ( seq_len * batch_size * input_size)"
"seq_len is the total time step in this mini-batch (CAN be change in "
"different batch)"
"batch_size is the instance number of this batch"
"input_size is the hidden size of the input."
"input_hidden_size and the hidden_size in the next may not be same");
AddInput("InitH",
"(Tensor) the initial hidden state of the LSTM"
"input. This is a tensor with shape (num_layers x batch_size x "
"hidden_size)"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)");
AddInput("InitC",
"(Tensor) the initial cell state of the LSTm "
"input. This is a tensor with shape (num_layers x batch_size x "
"hidden_size)"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)");
AddInput("W",
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
AddInput("Cache",
"The cache of dropout op, a RAW type variable including random "
"number generator states and some descriptors, which is used in "
"cudnn kernel.")
.AsDispensable();
AddOutput("Out",
"(Tensor) the hidden state of LSTM operator. "
"The shape is ( seq_len x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be ( seq_len x "
"batch_size x hidden_size * 2) ");
AddOutput("last_h",
"(Tensor) the hidden state of the last step. "
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)");
AddOutput("last_c",
"(Tensor) the cell state of the last step"
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirect is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size*2)");
AddAttr<int>("max_len",
"max length of the LSTM op"
"the first dim of the Input can NOT be greater than max_len")
.SetDefault(20);
AddAttr<float>(
"dropout_prob",
"dropout prob of the dropout op"
"the dropout ONLY work between lstm layers, not between time steps"
"There is no dropout work on the Out tensor")
.SetDefault(0.0);
AddAttr<bool>("is_bidirec",
"is_bidirec"
"if it is bidirection rnn"
"The will affect the shape of the Out, last_h, and last_c")
.SetDefault(false);
AddAttr<int>("input_size", "input size ot the Input Tensor").SetDefault(10);
AddAttr<int>("hidden_size", "hidden size of the LSTM").SetDefault(100);
AddAttr<int>("num_layers", "the total layer number of the LSTM")
.SetDefault(1);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(-1);
AddComment(R"DOC(
CUDNN LSTM implementation
A four-gate Long Short-Term Memory network with no peephole connections.
In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1,
the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations:
$$ i_t = sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$
$$ f_t = sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$
$$ o_t = sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$
$$ \\tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$
$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$
$$ h_t = o_t \\odot tanh(c_t) $$
- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The $\odot$ is the element-wise product of the vectors.
- `tanh` is the activation functions.
- $\tilde{c_t}$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication,
X represensts a matrix multiplication
)DOC");
}
};
class CudnnLSTMGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("last_h"),
"Input(last_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("last_c"),
"Input(last_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cache"),
"Input(last_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitH"),
"Input(init_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitC"),
"Input(init_c) of LSTM should not be null.");
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
if (ctx->HasOutput(g_name)) {
ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
}
};
SetOutGradDim("Input");
SetOutGradDim("W");
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
};
template <typename T>
class NotImpleKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(
"CPU is not support for this kernel now. Will be add in the future");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
struct CudnnRNNCache {
CudnnRNNCache() {
x_desc_ = NULL;
y_desc_ = NULL;
dx_desc_ = NULL;
dy_desc_ = NULL;
}
~CudnnRNNCache() { release(); }
cudnnRNNDescriptor_t rnn_desc_;
cudnnTensorDescriptor_t *x_desc_;
cudnnTensorDescriptor_t *y_desc_;
cudnnTensorDescriptor_t *dx_desc_;
cudnnTensorDescriptor_t *dy_desc_;
cudnnTensorDescriptor_t hx_desc_;
cudnnTensorDescriptor_t cx_desc_;
cudnnTensorDescriptor_t hy_desc_;
cudnnTensorDescriptor_t cy_desc_;
cudnnTensorDescriptor_t dhx_desc_;
cudnnTensorDescriptor_t dcx_desc_;
cudnnTensorDescriptor_t dhy_desc_;
cudnnTensorDescriptor_t dcy_desc_;
cudnnTensorDescriptor_t output_x_desc_;
cudnnTensorDescriptor_t output_y_desc_;
cudnnDropoutDescriptor_t dropout_desc_;
size_t weights_size_;
cudnnFilterDescriptor_t w_desc_;
cudnnFilterDescriptor_t dw_desc_;
size_t workspace_size_;
size_t reserve_size_;
Tensor reserve_data_;
Tensor workspace_data_;
Tensor dropout_state_;
size_t max_length_;
float dropout_prob_;
bool is_bidirec_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
int seed_;
void init(cudnnHandle_t handle, const framework::ExecutionContext &ctx,
size_t max_len, int batch_size, int input_size, int hidden_size,
int num_layers, float dropout_prob, bool is_bidirec, int seed,
int weight_numel) {
max_length_ = max_len;
batch_size_ = batch_size;
input_size_ = input_size;
hidden_size_ = hidden_size;
num_layers_ = num_layers;
dropout_prob_ = dropout_prob;
is_bidirec_ = is_bidirec;
seed_ = seed;
x_desc_ = new cudnnTensorDescriptor_t[max_length_];
y_desc_ = new cudnnTensorDescriptor_t[max_length_];
dx_desc_ = new cudnnTensorDescriptor_t[max_length_];
dy_desc_ = new cudnnTensorDescriptor_t[max_length_];
int dim_a[3];
int stride_a[3];
for (size_t i = 0; i < max_length_; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i]));
dim_a[0] = batch_size_;
dim_a[1] = input_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dim_a[0] = batch_size_;
dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
}
dim_a[0] = num_layers_ * (is_bidirec_ ? 2 : 1);
dim_a[1] = batch_size_;
dim_a[2] = hidden_size_;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_));
size_t state_size;
CUDNN_ENFORCE(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size);
dropout_state_.Resize({static_cast<int64_t>(state_size)}));
auto *dropout_state_data =
dropout_state_.mutable_data<uint8_t>(ctx.GetPlace());
CUDNN_ENFORCE(platform::dynload::cudnnSetDropoutDescriptor(
dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size,
seed_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT));
PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel,
"cudnn lstm weight size should be SAME");
int dim_w[3];
dim_w[0] = weights_size_ / sizeof(float);
dim_w[1] = 1;
dim_w[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_, max_length_, x_desc_, &workspace_size_));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_, max_length_, x_desc_, &reserve_size_));
reserve_data_.Resize({static_cast<int64_t>(reserve_size_)});
reserve_data_.mutable_data<uint8_t>(ctx.GetPlace());
workspace_data_.Resize({static_cast<int64_t>(workspace_size_)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
}
void release() {
for (size_t i = 0; i < max_length_; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i]));
}
delete[] x_desc_;
delete[] y_desc_;
delete[] dx_desc_;
delete[] dy_desc_;
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcy_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyDropoutDescriptor(dropout_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyRNNDescriptor(rnn_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(w_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(dw_desc_));
}
};
template <typename T>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const Tensor *x = ctx.Input<Tensor>("Input");
const Tensor *init_h = ctx.Input<Tensor>("InitH");
const Tensor *init_c = ctx.Input<Tensor>("InitC");
auto w = ctx.Input<Tensor>("W");
Tensor *out = ctx.Output<Tensor>("Out");
Tensor *last_h = ctx.Output<Tensor>("last_h");
Tensor *last_c = ctx.Output<Tensor>("last_c");
const T *x_data = x->data<T>();
const T *init_h_data = init_h->data<T>();
const T *init_c_data = init_c->data<T>();
const T *w_data = w->data<T>();
T *out_data = out->mutable_data<T>(ctx.GetPlace());
T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
size_t max_len = ctx.Attr<int>("max_len");
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int input_size = ctx.Attr<int>("input_size");
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
bool is_test = ctx.Attr<bool>("is_test");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto *cache_var = ctx.InputVar("Cache");
if (!cache_var) {
// The RAW type cache variable wouldn't be created and broadcasted on
// multi-devices before the first running.
// use parent scope to make cache persistable
auto *scope = const_cast<framework::Scope *>(ctx.scope().parent());
auto cache_var_name = ctx.Inputs("Cache")[0];
cache_var = scope->Var(cache_var_name);
}
CudnnRNNCache *cudnn_rnn_cache = nullptr;
if (cache_var->IsInitialized()) {
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
} else {
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
std::random_device rnd;
int seed = ctx.Attr<int>("seed");
if (seed == -1) {
seed = rnd();
}
auto input_w_numel = w->numel();
auto batch_size = x->dims()[1];
cudnn_rnn_cache->init(handle, ctx, max_len, batch_size, input_size,
hidden_size, num_layers, dropout_prob, is_bidirec,
seed, input_w_numel);
}
auto run_seq_len = x->dims()[0];
if (is_test) {
// for inference
CUDNN_ENFORCE(platform::dynload::cudnnRNNForwardInference(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_,
init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data,
cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data,
cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_,
last_c_data, cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_));
} else {
// for train
CUDNN_ENFORCE(platform::dynload::cudnnRNNForwardTraining(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_,
init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data,
cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data,
cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_,
last_c_data, cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_,
cudnn_rnn_cache->reserve_data_.data<uint8_t>(),
cudnn_rnn_cache->reserve_size_));
}
}
};
template <typename T>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input = ctx.Input<Tensor>("Input");
auto *weight = ctx.Input<Tensor>("W");
auto *init_h = ctx.Input<Tensor>("InitH");
auto *init_c = ctx.Input<Tensor>("InitC");
// auto * last_h = ctx.Input<Tensor>("last_h");
// auto * last_c = ctx.Input<Tensor>("last_c");
auto *out = ctx.Input<Tensor>("Out");
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("last_h"));
auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("last_c"));
// auto* init_h = ctx.Input<Tensor>("init_h");
// auto* init_c = ctx.Input<Tensor>("init_c");
auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto *cache_var = ctx.InputVar("Cache");
PADDLE_ENFORCE(cache_var->IsInitialized());
CudnnRNNCache *cudnn_rnn_cache =
const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
auto input_dims = input->dims();
auto weight_dims = weight->dims();
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();
in_grad->mutable_data<T>(ctx.GetPlace());
weight_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
zero(dev_ctx, in_grad, static_cast<T>(0.0));
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
T *init_h_grad_data = NULL;
if (init_h_grad == nullptr) {
Tensor init_h_grad_temp;
init_h_grad_temp.mutable_data<T>(init_h_dims, ctx.GetPlace());
zero(dev_ctx, &init_h_grad_temp, static_cast<T>(0.0));
init_h_grad_data = init_h_grad_temp.data<T>();
} else {
init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
zero(dev_ctx, init_h_grad, static_cast<T>(0.0));
init_h_grad_data = init_h_grad->data<T>();
}
T *init_c_grad_data = NULL;
if (init_c_grad == nullptr) {
Tensor init_c_grad_temp;
init_c_grad_temp.mutable_data<T>(init_c_dims, ctx.GetPlace());
zero(dev_ctx, &init_c_grad_temp, static_cast<T>(0.0));
init_c_grad_data = init_c_grad_temp.data<T>();
} else {
init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
zero(dev_ctx, init_c_grad, static_cast<T>(0.0));
init_c_grad_data = init_c_grad->data<T>();
}
const T *last_h_grad_data = NULL;
if (last_h_grad == nullptr) {
Tensor last_h_grad_temp;
last_h_grad_temp.mutable_data<T>(init_h_dims, ctx.GetPlace());
zero(dev_ctx, &last_h_grad_temp, static_cast<T>(0.0));
last_h_grad_data = (const T *)last_h_grad_temp.data<T>();
} else {
last_h_grad_data = last_h_grad->data<T>();
}
const T *last_c_grad_data = NULL;
if (last_c_grad == nullptr) {
Tensor last_c_grad_temp;
last_c_grad_temp.mutable_data<T>(init_c_dims, ctx.GetPlace());
zero(dev_ctx, &last_c_grad_temp, static_cast<T>(0.0));
last_c_grad_data = (const T *)last_c_grad_temp.data<T>();
} else {
last_c_grad_data = last_c_grad->data<T>();
}
const T *out_grad_data = NULL;
if (out_grad == nullptr) {
Tensor out_grad_temp;
out_grad_temp.mutable_data<T>(out->dims(), ctx.GetPlace());
zero(dev_ctx, &out_grad_temp, static_cast<T>(0.0));
out_grad_data = (const T *)out_grad_temp.data<T>();
} else {
out_grad_data = out_grad->data<T>();
}
// zero( dev_ctx, last_h_grad, static_cast<T>(0.0));
// zero( dev_ctx, last_c_grad, static_cast<T>(0.0));
auto out_data = out->data<T>();
// auto out_grad_data = out_grad->data<T>();
auto weight_data = weight->data<T>();
auto init_h_data = init_h->data<T>();
auto init_c_data = init_c->data<T>();
auto in_grad_data = in_grad->data<T>();
auto work_data = cudnn_rnn_cache->workspace_data_.data<uint8_t>();
auto reserve_data = cudnn_rnn_cache->reserve_data_.data<uint8_t>();
auto run_seq_len = input_dims[0];
PADDLE_ENFORCE_LE((size_t)run_seq_len, cudnn_rnn_cache->max_length_,
"cudnn running seq_len CAN not greater max_lengh");
CUDNN_ENFORCE(platform::dynload::cudnnRNNBackwardData(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->dy_desc_,
out_grad_data, cudnn_rnn_cache->dhy_desc_, last_h_grad_data,
cudnn_rnn_cache->dcy_desc_, last_c_grad_data, cudnn_rnn_cache->w_desc_,
weight_data, cudnn_rnn_cache->hx_desc_, init_h_data,
cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->dx_desc_,
in_grad_data, cudnn_rnn_cache->dhx_desc_, init_h_grad_data,
cudnn_rnn_cache->dcx_desc_, init_c_grad_data, work_data,
cudnn_rnn_cache->workspace_size_, reserve_data,
cudnn_rnn_cache->reserve_size_));
CUDNN_ENFORCE(platform::dynload::cudnnRNNBackwardWeights(
handle, cudnn_rnn_cache->rnn_desc_, run_seq_len,
cudnn_rnn_cache->x_desc_, input->data<T>(), cudnn_rnn_cache->hx_desc_,
init_h->data<T>(), cudnn_rnn_cache->y_desc_, out->data<T>(),
cudnn_rnn_cache->workspace_data_.data<uint8_t>(),
cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->dw_desc_,
weight_grad->data<T>(), cudnn_rnn_cache->reserve_data_.data<uint8_t>(),
cudnn_rnn_cache->reserve_size_));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def lstm_naive(
input,
w, ):
seq_len, batch_size, hidden_size = input.shape
offset = 0
wi = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
wf = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
wc = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
wo = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
ri = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
rf = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
rc = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
ro = w[offset:offset + hidden_size * hidden_size].reshape(
(hidden_size, hidden_size)).transpose()
offset += hidden_size * hidden_size
bi_1 = w[offset:offset + hidden_size]
offset += hidden_size
bf_1 = w[offset:offset + hidden_size]
offset += hidden_size
bc_1 = w[offset:offset + hidden_size]
offset += hidden_size
bo_1 = w[offset:offset + hidden_size]
offset += hidden_size
bi_2 = w[offset:offset + hidden_size]
offset += hidden_size
bf_2 = w[offset:offset + hidden_size]
offset += hidden_size
bc_2 = w[offset:offset + hidden_size]
offset += hidden_size
bo_2 = w[offset:offset + hidden_size]
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
output = []
pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype)
pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype)
for i in range(seq_len):
emb_1 = input[i]
input_gate = sigmoid(
np.matmul(emb_1, wi) + np.matmul(pre_h, ri) + bi_1 + bi_2)
forget_gate = sigmoid(
np.matmul(emb_1, wf) + np.matmul(pre_h, rf) + bf_1 + bf_2)
output_gate = sigmoid(
np.matmul(emb_1, wo) + np.matmul(pre_h, ro) + bo_1 + bo_2)
c_t_temp = tanh(
np.matmul(emb_1, wc) + np.matmul(pre_h, rc) + bc_1 + bc_2)
new_c = input_gate * c_t_temp + forget_gate * pre_c
new_h = output_gate * tanh(new_c)
pre_h = new_h
pre_c = new_c
output.append(new_h)
output = np.concatenate(output, -1)
output = output.reshape((batch_size, -1, hidden_size))
output = output.transpose((1, 0, 2))
return output, pre_h, pre_c
class TestCUDNNLstmOp(OpTest):
def setUp(self):
self.op_type = "cudnn_lstm"
self.dtype = np.float32
num_steps = 20
batch_size = 5
hidden_size = 20
input_weight_size = (hidden_size * hidden_size) * 4
hidden_weight_size = (hidden_size * hidden_size) * 4
weight_size = input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
input = np.random.uniform(
low=-0.1, high=0.1, size=(num_steps, batch_size,
hidden_size)).astype(self.dtype)
flat_w = np.random.uniform(
low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)
output, last_hidden, last_cell = lstm_naive(input, flat_w)
init_h = np.zeros((batch_size, hidden_size), dtype=np.float32)
init_c = np.zeros((batch_size, hidden_size), dtype=np.float32)
scope = core.Scope()
program = fluid.Program()
block = program.global_block()
cache_temp = block.create_var(
name="Cache",
persistable=True,
type=core.VarDesc.VarType.RAW,
stop_gradient=True)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'W': OpTest.np_dtype_to_fluid_dtype(flat_w),
'InitH': OpTest.np_dtype_to_fluid_dtype(init_h),
'InitC': OpTest.np_dtype_to_fluid_dtype(init_c),
}
self.cache_name_list = ['Cache']
self.attrs = {
'max_len': num_steps,
'dropout_prob': 0.0,
'is_bidirec': False,
'input_size': hidden_size,
'hidden_size': hidden_size,
'num_layers': 1,
}
self.outputs = {
'Out': output,
"last_h": last_hidden,
'last_c': last_cell
}
def test_output_with_place(self):
if self.testcuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
def test_grad_with_place(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'W', 'InitH', 'InitC']),
['Out', 'last_h', 'last_c'],
max_relative_error=0.02)
def testcuda(self):
return core.is_compiled_with_cuda()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册