Unverified commit 66cf8b08, authored by zyfncg, committed by GitHub

[Phi] Move Rnn Op from fluid to phi (#41007)

* move rnn kernel to phi

* move infershape of rnn to phi

* fix HIP bug

* rename function

* fix HIP bug

* fix hip bug
Parent 59c4fdac
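The diff below follows the usual fluid-to-phi migration pattern: the hand-written RNNOp::InferShape override is deleted in favor of a functor that forwards to a phi InferMeta function, and the fluid CPU kernel registrations are removed. As a consolidated sketch of the infershape wiring introduced in this diff (the exact signature of phi::RnnInferMeta is defined in paddle/phi/infermeta/multiary.h and is not shown here):

// Sketch of the fluid -> phi infershape wiring, consolidated from this diff.
// RnnInferMeta performs the shape checks previously done in RNNOp::InferShape.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace ops = paddle::operators;

// Bind the phi meta function to a functor the fluid operator can consume.
DECLARE_INFER_SHAPE_FUNCTOR(rnn, RnnInferShapeFunctor,
                            PD_INFER_META(phi::RnnInferMeta));

// The functor is passed as an extra argument to REGISTER_OPERATOR, replacing
// the removed InferShape override.
REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
                  ops::RNNGradOpMaker<paddle::framework::OpDesc>,
                  ops::RNNGradOpMaker<paddle::imperative::OpBase>,
                  RnnInferShapeFunctor);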
@@ -396,7 +396,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
}
- phi::funcs::detail::forward_reset_output(
+ phi::funcs::detail::forward_reset_output<DeviceContext>(
phi::funcs::detail::forward::gru_resetOutput<T>(), gru_value,
frame_size, cur_batch_size, active_gate);
@@ -408,7 +408,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
frame_size * 3);
}
- phi::funcs::detail::forward_final_output(
+ phi::funcs::detail::forward_final_output<DeviceContext>(
phi::funcs::detail::forward::gru_finalOutput<T>(), gru_value,
frame_size, cur_batch_size, active_node, origin_mode);
...
@@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
- #include "paddle/fluid/operators/rnn_op.h"
#include <memory>
#include <string>
+ #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
+ #include "paddle/phi/core/infermeta_utils.h"
+ #include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
@@ -25,69 +27,6 @@ class RNNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "RNN");
OP_INOUT_CHECK(ctx->HasInputs("PreState"), "Input", "PreState", "RNN");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "RNN");
OP_INOUT_CHECK(ctx->HasOutputs("State"), "Output", "State", "RNN");
auto in_dims = ctx->GetInputDim("Input");
auto pre_state_dims = ctx->GetInputsDim("PreState");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in RNN must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
if (ctx->HasInput("SequenceLength")) {
auto seq_dims = ctx->GetInputDim("SequenceLength");
PADDLE_ENFORCE_EQ(
in_dims[1], seq_dims[0],
platform::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1], seq_dims[0]));
}
PADDLE_ENFORCE_EQ(pre_state_dims[0].size(), 3,
platform::errors::InvalidArgument(
"The rank of PreState in RNN must be 3. But "
"the received rank is %d.",
pre_state_dims[0].size()));
size_t i = 0;
for (; i < pre_state_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
in_dims[1], pre_state_dims[i][1],
platform::errors::InvalidArgument(
"The second dimension size (representing for batch size) of "
"Input and PreState should be equal. But received %d and %d.",
in_dims[1], pre_state_dims[i][1]));
PADDLE_ENFORCE_EQ(
pre_state_dims[0], pre_state_dims[i],
platform::errors::InvalidArgument(
"The dims of all tensors in PreState should be same. But "
"received PreState[0] is %s and PreState[%d] is %s.",
pre_state_dims[0], i, pre_state_dims[i]));
}
auto mode = ctx->Attrs().Get<std::string>("mode");
size_t num_state = mode == "LSTM" ? 2 : 1;
PADDLE_ENFORCE_EQ(
i, num_state,
platform::errors::InvalidArgument(
"The number of tensors in PreState of %s should be %d, "
"but received %d.",
mode, 2, i));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputsDim("State", pre_state_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -249,15 +188,11 @@ class NotImpleKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
+ DECLARE_INFER_SHAPE_FUNCTOR(rnn, RnnInferShapeFunctor,
+ PD_INFER_META(phi::RnnInferMeta));
REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
ops::RNNGradOpMaker<paddle::framework::OpDesc>,
- ops::RNNGradOpMaker<paddle::imperative::OpBase>);
+ ops::RNNGradOpMaker<paddle::imperative::OpBase>,
+ RnnInferShapeFunctor);
REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp);
- REGISTER_OP_CPU_KERNEL(
- rnn, ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, float>,
- ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, double>);
- REGISTER_OP_CPU_KERNEL(
- rnn_grad, ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, float>,
- ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, double>);
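The REGISTER_OP_CPU_KERNEL entries removed above are presumably replaced by registrations in the new phi kernel files, which are not part of this excerpt. A minimal sketch of what such a phi-side CPU registration typically looks like, assuming the moved kernel is exposed as phi::RnnKernel in a file such as paddle/phi/kernels/cpu/rnn_kernel.cc (both names are assumptions, not taken from this diff):

// Hypothetical sketch: phi kernels register via PD_REGISTER_KERNEL instead of
// REGISTER_OP_CPU_KERNEL; the kernel and header names below are assumed.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/rnn_kernel.h"

PD_REGISTER_KERNEL(rnn,             // op name
                   CPU,             // backend
                   ALL_LAYOUT,      // data layout
                   phi::RnnKernel,  // kernel function
                   float,
                   double) {}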
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
#ifdef PADDLE_WITH_HIP
using gpuRNNMode_t = miopenRNNMode_t;
using gpuDnnHandle_t = miopenHandle_t;
using gpuDnnDataType_t = miopenDataType_t;
#else
using gpuRNNMode_t = cudnnRNNMode_t;
using gpuDnnHandle_t = cudnnHandle_t;
using gpuDnnDataType_t = cudnnDataType_t;
#endif
class RNNDescriptors {
public:
RNNDescriptors(int seq_length, int batch_size, int input_size,
int hidden_size, int num_layers, float dropout_prob, int seed,
int weight_numel, gpuRNNMode_t mode, bool is_bidirec,
bool is_test)
: seq_length_(seq_length),
batch_size_(batch_size),
input_size_(input_size),
hidden_size_(hidden_size),
num_layers_(num_layers),
dropout_prob_(dropout_prob),
seed_(seed),
weight_numel_(weight_numel),
mode_(mode),
is_bidirec_(is_bidirec),
is_test_(is_test) {}
template <typename T>
void Create(const gpuDnnHandle_t &handle, const platform::Place &place,
const std::vector<int> &sequence_length, size_t *workspace_size,
size_t *reserve_size, framework::Tensor *dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
gpuDnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
// ------------------- cudnn x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
std::vector<int> strides_x = {input_size_, 1, 1};
std::vector<int> dims_y = {batch_size_, hidden_size_ * numDirections, 1};
std::vector<int> strides_y = {hidden_size_ * numDirections, 1, 1};
for (int i = 0; i < seq_length_; ++i) {
x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x));
y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y));
}
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
if (!sequence_length.empty()) {
x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true,
sequence_length);
y_seq_desc_.descriptor<T>(seq_length_, batch_size_,
hidden_size_ * numDirections, true,
sequence_length);
}
#endif
// ------------------- cudnn hx, hy, cx, cy descriptors----------
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
hidden_size_};
std::vector<int> strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1};
init_h_desc_.descriptor<T>(dims_hx, strides_hx);
init_c_desc_.descriptor<T>(dims_hx, strides_hx);
last_h_desc_.descriptor<T>(dims_hx, strides_hx);
last_c_desc_.descriptor<T>(dims_hx, strides_hx);
// ------------------- cudnn dropout descriptors ---------------------
size_t state_size;
bool is_initialized = dropout_state->IsInitialized();
if (!is_test_ && !is_initialized) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
#else
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
#endif
}
dropout_desc_.descriptor(handle, place, is_initialized, dropout_prob_,
is_test_ ? nullptr : dropout_state, seed_,
state_size);
// ------------------- cudnn rnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSetRNNDescriptor_V2(
rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(),
miopenRNNlinear,
is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection, mode_,
miopenRNNwithBias, miopenRNNdefault, cudnn_type));
#elif CUDNN_VERSION >= 6000
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_.desc(), hidden_size_, num_layers_,
dropout_desc_.desc(), CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, mode_,
CUDNN_RNN_ALGO_STANDARD, cudnn_type));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(),
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, mode_,
cudnn_type));
#endif
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
if (!sequence_length.empty()) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
}
#endif
// ------------------- cudnn weights_size ---------------------
size_t weights_size_;
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenGetRNNParamsSize(
handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type));
#endif
PADDLE_ENFORCE_EQ(
weights_size_, sizeof(T) * weight_numel_,
platform::errors::InvalidArgument(
"The cudnn rnn and setting weight size should be same."));
// ------------------- cudnn weight descriptors ---------------------
platform::DataLayout layout = platform::DataLayout::kNCHW;
int dim_tmp = weights_size_ / sizeof(T);
std::vector<int> dim_w = {dim_tmp, 1, 1};
weight_desc_.descriptor<T>(layout, dim_w);
// ------------------- cudnn workspace, reserve size ---------------------
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenGetRNNWorkspaceSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenGetRNNTrainingReserveSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(),
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_.desc(), seq_length_, x_descs_.data(),
reserve_size));
#endif
}
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t *x_descs() { return x_descs_.data(); }
miopenTensorDescriptor_t *y_descs() { return y_descs_.data(); }
miopenTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
miopenTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
miopenTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
miopenTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); }
miopenRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); }
miopenDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); }
miopenTensorDescriptor_t weight_desc() { return weight_desc_.desc(); }
#else
cudnnTensorDescriptor_t *x_descs() { return x_descs_.data(); }
cudnnTensorDescriptor_t *y_descs() { return y_descs_.data(); }
#if CUDNN_VERSION >= 7201
cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); }
cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); }
#endif
cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
cudnnTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); }
cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); }
cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); }
cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); }
#endif
private:
int seq_length_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
float dropout_prob_;
int seed_;
int weight_numel_;
gpuRNNMode_t mode_;
bool is_bidirec_;
bool is_test_;
#ifdef PADDLE_WITH_HIP
std::vector<miopenTensorDescriptor_t> x_descs_;
std::vector<miopenTensorDescriptor_t> y_descs_;
#else
std::vector<cudnnTensorDescriptor_t> x_descs_;
std::vector<cudnnTensorDescriptor_t> y_descs_;
#endif
platform::ScopedTensorDescriptor x_desc_;
platform::ScopedTensorDescriptor y_desc_;
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
platform::ScopedRNNTensorDescriptor x_seq_desc_;
platform::ScopedRNNTensorDescriptor y_seq_desc_;
#endif
platform::ScopedTensorDescriptor init_h_desc_;
platform::ScopedTensorDescriptor init_c_desc_;
platform::ScopedTensorDescriptor last_h_desc_;
platform::ScopedTensorDescriptor last_c_desc_;
platform::ScopedDropoutDescriptor dropout_desc_;
platform::ScopedFilterDescriptor weight_desc_;
platform::ScopedRNNDescriptor rnn_desc_;
};
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
bool continuous = true;
for (size_t i = 0; i < weight_list.size() - 1; ++i) {
auto *in_data = weight_list[i]->template data<T>();
auto *in_after_data = weight_list[i + 1]->template data<T>();
auto in_size = weight_list[i]->numel();
bool temp = in_data + in_size == in_after_data;
continuous = continuous && temp;
}
return continuous;
}
template <typename T>
void weight_to_tensor(const platform::Place &place, gpuStream_t stream,
const std::vector<const Tensor *> &weight_list,
Tensor *weight) {
auto weight_data = weight->data<T>();
int weight_offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
const T *in_data = weight_list[i]->data<T>();
auto in_size = weight_list[i]->numel();
memory::Copy(weight->place(), weight_data + weight_offset,
weight_list[i]->place(), in_data, in_size * sizeof(T), stream);
weight_offset += in_size;
}
}
template <typename T>
void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream,
std::vector<Tensor *> *weight_grad,
const std::vector<const Tensor *> &weight_input,
const Tensor *weight) {
int weight_offset = 0;
auto *weight_data = weight->data<T>();
for (size_t i = 0; i < weight_input.size(); ++i) {
auto in_size = weight_input[i]->numel();
T *weight_grad_data = (*weight_grad)[i]->mutable_data<T>(place);
const T *src = weight_data + weight_offset;
memory::Copy((*weight_grad)[i]->place(), weight_grad_data, weight->place(),
src, in_size * sizeof(T), stream);
weight_offset += in_size;
}
}
#ifdef PADDLE_WITH_HIP
template <typename T>
void weight_list_to_tensor(const platform::Place &place, gpuStream_t stream,
const std::vector<Tensor> &tensor_list,
Tensor *weight_whole, const size_t offset = 0UL) {
size_t weight_offset = offset;
auto weight_data = weight_whole->data<T>();
for (size_t i = 0; i < tensor_list.size(); ++i) {
const T *in_data = tensor_list[i].data<T>();
auto in_size = tensor_list[i].numel();
memory::Copy(weight_whole->place(), weight_data + weight_offset,
tensor_list[i].place(), in_data, in_size * sizeof(T), stream);
weight_offset += in_size;
}
}
template <typename T>
void weight_to_permuted_tensor(const platform::Place &place, gpuStream_t stream,
std::vector<const Tensor *> *weight_list,
Tensor *weight_whole,
const gpuRNNMode_t rnn_mode,
const bool is_bidirec) {
if (is_bidirec) {
for (size_t i = 0; i < weight_list->size(); i += 4) {
auto tmp = (*weight_list)[i + 1];
(*weight_list)[i + 1] = (*weight_list)[i + 2];
(*weight_list)[i + 2] = tmp;
}
}
size_t weight_offset = 0;
for (size_t i = 0; i < weight_list->size(); ++i) {
if (rnn_mode == miopenLSTM) {
std::vector<Tensor> split_tensor = (*weight_list)[i]->Chunk(4, 0);
weight_list_to_tensor<T>(
place, stream,
{split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
weight_whole, weight_offset);
} else if (rnn_mode == miopenGRU) {
std::vector<Tensor> split_tensor = (*weight_list)[i]->Chunk(3, 0);
weight_list_to_tensor<T>(
place, stream, {split_tensor[1], split_tensor[0], split_tensor[2]},
weight_whole, weight_offset);
} else {
weight_list_to_tensor<T>(place, stream, {*(*weight_list)[i]},
weight_whole, weight_offset);
}
weight_offset += (*weight_list)[i]->numel();
}
}
template <typename T>
void tensor_to_permuted_weight(const platform::Place &place, gpuStream_t stream,
const Tensor &tensor,
std::vector<Tensor *> *weight_grad_list,
const gpuRNNMode_t rnn_mode,
const bool is_bidirec) {
if (is_bidirec) {
for (size_t i = 0; i < weight_grad_list->size(); i += 4) {
auto tmp = (*weight_grad_list)[i + 1];
(*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2];
(*weight_grad_list)[i + 2] = tmp;
}
}
size_t weight_offset = 0;
for (size_t i = 0; i < weight_grad_list->size(); ++i) {
auto numel_size = (*weight_grad_list)[i]->numel();
Tensor temp;
temp.mutable_data<T>({numel_size}, place);
temp.ShareDataWith(tensor.Slice(weight_offset, weight_offset + numel_size));
if (rnn_mode == miopenLSTM) {
std::vector<Tensor> split_tensor = temp.Chunk(4, 0);
weight_list_to_tensor<T>(
place, stream,
{split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
(*weight_grad_list)[i]);
} else if (rnn_mode == miopenGRU) {
std::vector<Tensor> split_tensor = temp.Chunk(3, 0);
weight_list_to_tensor<T>(
place, stream, {split_tensor[1], split_tensor[0], split_tensor[2]},
(*weight_grad_list)[i]);
} else {
weight_list_to_tensor<T>(place, stream, {temp}, (*weight_grad_list)[i]);
}
weight_offset += numel_size;
}
if (is_bidirec) {
for (size_t i = 0; i < weight_grad_list->size(); i += 4) {
auto tmp = (*weight_grad_list)[i + 1];
(*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2];
(*weight_grad_list)[i + 2] = tmp;
}
}
}
#endif
template <typename T>
class RNNCudnnKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const Tensor *x = ctx.Input<Tensor>("Input");
auto pre_state = ctx.MultiInput<Tensor>("PreState");
Tensor *out = ctx.Output<Tensor>("Out");
auto state = ctx.MultiOutput<Tensor>("State");
Tensor *reserve = ctx.Output<Tensor>("Reserve");
Tensor *state_out = ctx.Output<Tensor>("DropoutState");
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
auto mode = ctx.Attr<std::string>("mode");
#ifdef PADDLE_WITH_HIP
gpuRNNMode_t rnn_mode = miopenLSTM;
if (mode == "LSTM")
rnn_mode = miopenLSTM;
else if (mode == "GRU")
rnn_mode = miopenGRU;
else if (mode == "RNN_RELU")
rnn_mode = miopenRNNRELU;
else if (mode == "RNN_TANH")
rnn_mode = miopenRNNTANH;
#else
gpuRNNMode_t rnn_mode = CUDNN_LSTM;
if (mode == "LSTM")
rnn_mode = CUDNN_LSTM;
else if (mode == "GRU")
rnn_mode = CUDNN_GRU;
else if (mode == "RNN_RELU")
rnn_mode = CUDNN_RNN_RELU;
else if (mode == "RNN_TANH")
rnn_mode = CUDNN_RNN_TANH;
#endif
else
PADDLE_THROW(platform::errors::InvalidArgument(
"rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: "
"%s.",
mode));
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");
if (!is_test) {
int device_id = ctx.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed == 0) {
// If `manual_seed` has been called in Python and the inner seed is not
// specified (i.e. equals 0), use the seed produced by the global generator.
seed = static_cast<int>(gen_cuda->Random64());
} else if (seed == 0) {
// use a randomly generated seed
std::random_device rd;
seed = rd();
} // otherwise use the seed specified by `ctx.Attr<int>("seed")`
}
const T *x_data = x->data<T>();
const T *init_h_data = pre_state[0]->data<T>();
const T *init_c_data = nullptr;
T *out_data = out->mutable_data<T>(ctx.GetPlace());
T *last_h_data = state[0]->mutable_data<T>(ctx.GetPlace());
T *last_c_data = nullptr;
#ifdef PADDLE_WITH_HIP
if (rnn_mode == miopenLSTM) {
#else
if (rnn_mode == CUDNN_LSTM) {
#endif
init_c_data = pre_state[1]->data<T>();
last_c_data = state[1]->mutable_data<T>(ctx.GetPlace());
}
bool has_seq_length = ctx.HasInput("SequenceLength");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(has_seq_length, false,
platform::errors::InvalidArgument(
"ROCm do not support SequenceLength yet."));
#endif
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
}
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
size_t workspace_size;
size_t reserve_size;
Tensor weight_whole;
T *w_data = nullptr;
auto place = ctx.GetPlace();
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
auto weight_numel = std::accumulate(
weight_list.begin(), weight_list.end(), 0,
[](int64_t num, const Tensor *t) { return num + t->numel(); });
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);
#ifdef PADDLE_WITH_HIP
// Need to permute weight, set continuous to false
continuous = false;
#endif
if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not continuous, "
"less efficient calculation will be called. Please call "
"flatten_parameters() to make the input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place);
#ifdef PADDLE_WITH_HIP
// MIOPEN needs the weight permuted for miopenLSTM or miopenGRU
weight_to_permuted_tensor<T>(place, stream, &weight_list, &weight_whole,
rnn_mode, is_bidirec);
#else
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
#endif
w_data = weight_whole.data<T>();
#ifndef PADDLE_WITH_HIP
// MIOPEN needs the weight permuted, so do not share it with weight_grad
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
size_t len = weight_list[i]->numel();
auto dim = weight_list[i]->dims();
const_cast<Tensor *>(weight_list[i])
->ShareDataWith(
weight_whole.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
}
#endif
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
RNNDescriptors rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel, rnn_mode,
is_bidirec, is_test);
rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
&reserve_size, state_out);
framework::Tensor workspace_data_;
workspace_data_.mutable_data<uint8_t>(
{static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
auto *reserve_data = reserve->mutable_data<uint8_t>(
{static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
if (is_test) {
RNNInferece(has_seq_length, handle, seq_length, &rnn, x_data, init_h_data,
init_c_data, w_data, out_data, last_h_data, last_c_data,
&workspace_data_, workspace_size);
} else {
if (!has_seq_length) {
// for train
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardTraining(
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
reserve_size));
#endif
} else {
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTrainingEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.init_h_desc(),
init_h_data, rnn.init_c_desc(), init_c_data, rnn.weight_desc(),
w_data, rnn.y_seq_desc(), out_data, rnn.last_h_desc(), last_h_data,
rnn.last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, workspace_data_.data<uint8_t>(),
workspace_size, reserve_data, reserve_size));
#else
PADDLE_THROW(platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
}
void RNNInferece(const bool &has_seq_length, const gpuDnnHandle_t &handle,
const int &seq_length, RNNDescriptors *rnn, const T *x_data,
const T *init_h_data, const T *init_c_data, const T *w_data,
T *out_data, T *last_h_data, T *last_c_data,
framework::Tensor *workspace_data,
const size_t &workspace_size) const {
if (!has_seq_length) {
// for inference
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardInference(
handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
workspace_data->data<uint8_t>(), workspace_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInference(
handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
workspace_data->data<uint8_t>(), workspace_size));
#endif
} else {
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data,
rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
rnn->weight_desc(), w_data, rnn->y_seq_desc(), out_data,
rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, workspace_data->data<uint8_t>(), workspace_size));
#else
// requires cuDNN version >= 7.2.1
PADDLE_THROW(platform::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
};
template <typename T>
class RNNGradCudnnKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input = ctx.Input<Tensor>("Input");
auto pre_state = ctx.MultiInput<Tensor>("PreState");
auto weight_list = ctx.MultiInput<Tensor>("WeightList");
auto *state_out = ctx.Input<Tensor>("DropoutState");
auto *reserve = ctx.Input<Tensor>("Reserve");
auto *out = ctx.Input<Tensor>("Out");
// auto state = ctx.MultiInput<Tensor>("State");
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto state_grad = ctx.MultiInput<Tensor>(framework::GradVarName("State"));
auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto pre_state_grad =
ctx.MultiOutput<Tensor>(framework::GradVarName("PreState"));
auto weight_grad_list =
ctx.MultiOutput<Tensor>(framework::GradVarName("WeightList"));
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int hidden_size = ctx.Attr<int>("hidden_size");
int num_layers = ctx.Attr<int>("num_layers");
auto mode = ctx.Attr<std::string>("mode");
#ifdef PADDLE_WITH_HIP
miopenRNNMode_t rnn_mode = miopenLSTM;
if (mode == "LSTM")
rnn_mode = miopenLSTM;
else if (mode == "GRU")
rnn_mode = miopenGRU;
else if (mode == "RNN_RELU")
rnn_mode = miopenRNNRELU;
else if (mode == "RNN_TANH")
rnn_mode = miopenRNNTANH;
#else
cudnnRNNMode_t rnn_mode = CUDNN_LSTM;
if (mode == "LSTM")
rnn_mode = CUDNN_LSTM;
else if (mode == "GRU")
rnn_mode = CUDNN_GRU;
else if (mode == "RNN_RELU")
rnn_mode = CUDNN_RNN_RELU;
else if (mode == "RNN_TANH")
rnn_mode = CUDNN_RNN_TANH;
#endif
else
PADDLE_THROW(platform::errors::InvalidArgument(
"rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: "
"%s.",
mode));
bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto place = ctx.GetPlace();
auto weight_numel = std::accumulate(
weight_list.begin(), weight_list.end(), 0,
[](int64_t num, const Tensor *t) { return num + t->numel(); });
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
Tensor weight_whole;
T *weight_data = nullptr;
#ifdef PADDLE_WITH_HIP
// Need to permute weight, set continuous to false
continuous = false;
#endif
if (!continuous) {
weight_whole.mutable_data<T>({weight_numel}, place);
#ifdef PADDLE_WITH_HIP
// MIOPEN needs the weight permuted for miopenLSTM or miopenGRU
weight_to_permuted_tensor<T>(place, stream, &weight_list, &weight_whole,
rnn_mode, is_bidirec);
#else
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
#endif
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(weight_list[0]->data<T>());
}
Tensor weight_grad;
phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
T *weight_grad_data = weight_grad.data<T>();
#ifdef PADDLE_WITH_HIP
// MIOPEN needs weight_grad_list permuted, so do not share data with
// weight_grad
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
weight_grad_list[i]->mutable_data<T>(ctx.GetPlace());
}
#else
int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
auto dim = weight_grad_list[i]->dims();
weight_grad_list[i]
->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
#endif
Tensor input_grad_value;
if (!in_grad) {
in_grad = &input_grad_value;
in_grad->Resize(input->dims());
}
auto *init_h_data = pre_state[0]->data<T>();
// auto *last_h_data = state[0]->data<T>();
auto *last_h_grad_data = state_grad[0]->data<T>();
const T *init_c_data = nullptr;
// const T *last_c_data = nullptr;
const T *last_c_grad_data = nullptr;
T *init_h_grad_data =
pre_state_grad.size() != 0 && pre_state_grad[0]
? pre_state_grad[0]->mutable_data<T>(ctx.GetPlace())
: nullptr;
T *init_c_grad_data = nullptr;
#ifdef PADDLE_WITH_HIP
if (rnn_mode == miopenLSTM) {
#else
if (rnn_mode == CUDNN_LSTM) {
#endif
init_c_data = pre_state[1]->data<T>();
// last_c_data = state[1]->data<T>();
last_c_grad_data = state_grad[1]->data<T>();
init_c_grad_data =
pre_state_grad.size() != 0 && pre_state_grad[1]
? pre_state_grad[1]->mutable_data<T>(ctx.GetPlace())
: nullptr;
}
auto *out_data = out->data<T>();
auto *out_grad_data = out_grad->data<T>();
// need to check whether in_grad exists
T *in_grad_data = nullptr;
if (in_grad) {
in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
}
bool has_seq_length = ctx.HasInput("SequenceLength");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(has_seq_length, false,
platform::errors::InvalidArgument(
"ROCm do not support SequenceLength yet."));
#endif
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
}
auto input_dims = input->dims();
int seq_length = input_dims[0];
int batch_size = input_dims[1];
int input_size = input_dims[2];
size_t workspace_size;
size_t reserve_size;
RNNDescriptors rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel, rnn_mode,
is_bidirec, is_test);
rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
&reserve_size, const_cast<Tensor *>(state_out));
framework::Tensor workspace_data_;
workspace_data_.mutable_data<uint8_t>(
{static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
const uint8_t *reserve_data = reserve->data<uint8_t>();
if (!has_seq_length) {
if (in_grad) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardData(
handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
#else
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
#endif
}
if (!weight_grad_list.empty()) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data<T>(),
rnn.weight_desc(), weight_grad_data,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
// permute weight grad list from weight grad tensor
tensor_to_permuted_weight<T>(place, stream, weight_grad,
&weight_grad_list, rnn_mode, is_bidirec);
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad_data, const_cast<uint8_t *>(reserve_data),
reserve_size));
#endif
}
} else {
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
if (in_grad) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data,
rnn.y_seq_desc(), out_grad_data, nullptr, nullptr,
rnn.last_h_desc(), last_h_grad_data, rnn.last_c_desc(),
last_c_grad_data, rnn.weight_desc(), weight_data, rnn.init_h_desc(),
init_h_data, rnn.init_c_desc(), init_c_data, rnn.x_seq_desc(),
in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
}
if (!weight_grad_list.empty()) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.init_h_desc(), init_h_data, rnn.y_seq_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad_data, const_cast<uint8_t *>(reserve_data),
reserve_size));
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
#endif
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(rnn, ops::RNNCudnnKernel<float>);
REGISTER_OP_CUDA_KERNEL(rnn_grad, ops::RNNGradCudnnKernel<float>);
#else
REGISTER_OP_CUDA_KERNEL(rnn, ops::RNNCudnnKernel<float>,
ops::RNNCudnnKernel<double>);
REGISTER_OP_CUDA_KERNEL(rnn_grad, ops::RNNGradCudnnKernel<float>,
ops::RNNGradCudnnKernel<double>);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
using TensorList = std::vector<framework::Tensor>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \
inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \
const std::string& mode = ctx.Attr<std::string>("mode"); \
return mode == #MODE_STR; \
}
DEFINE_MODE_DETECTOR(lstm, LSTM);
DEFINE_MODE_DETECTOR(gru, GRU);
DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU);
DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH);
void SwapPoniter(Tensor** a, Tensor** b) {
Tensor* c = *a;
*a = *b;
*b = c;
}
template <typename T>
void create_mask_matrix(const framework::ExecutionContext& context,
const Tensor* sequence_length, Tensor* mask_matrix,
const bool& is_reverse, int* min_seq_len) {
const auto& seq_len_vec = GetDataFromTensor<int>(sequence_length);
const int& table_width = mask_matrix->dims()[0];
Tensor temp;
temp.Resize(phi::make_ddim({mask_matrix->dims()[1], mask_matrix->dims()[0]}));
T* data_temp = temp.mutable_data<T>(context.GetPlace());
std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast<T>(1.0));
*min_seq_len = table_width;
for (unsigned int i = 0; i < seq_len_vec.size(); i++) {
// reset the mask matrix
*min_seq_len = std::min(seq_len_vec[i], *min_seq_len);
if (seq_len_vec[i] == table_width) {
continue;
}
if (is_reverse) {
std::fill(data_temp + i * table_width,
data_temp + (i + 1) * table_width - seq_len_vec[i],
static_cast<T>(0));
} else {
std::fill(data_temp + i * table_width + seq_len_vec[i],
data_temp + (i + 1) * table_width, static_cast<T>(0));
}
}
mask_matrix->mutable_data<T>(context.GetPlace());
std::vector<int> trans_vec;
trans_vec.emplace_back(1);
trans_vec.emplace_back(0);
auto& dev_ctx = context.template device_context<platform::CPUDeviceContext>();
TransCompute<platform::CPUDeviceContext, T>(2, dev_ctx, temp, mask_matrix,
trans_vec);
}
template <typename T>
struct Cell {
virtual ~Cell() {}
virtual void operator()(const platform::CPUDeviceContext* device_ctx,
Tensor* input, const Tensor* weight_hh,
const Tensor* init_h, const Tensor* init_c,
Tensor* last_h, Tensor* last_c, Tensor* last_c_act,
Tensor* output, const Tensor* bias_hh,
Tensor* weight_hh_gru) const {}
};
template <typename T, template <typename> class EigenActivationFunctor,
phi::funcs::detail::ActivationType act_type>
struct SimpleRNNCell : Cell<T> {
void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
const Tensor* weight_hh, const Tensor* init_h,
const Tensor* init_c, Tensor* last_h, Tensor* last_c,
Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
Tensor* weight_hh_gru) const override {
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
// convert the batched matmul to a single matmul, which is faster
blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast<T>(1.0),
input, static_cast<T>(1.0));
auto z = EigenVector<T>::Flatten(
GET_DATA_SAFELY(input, "Input", "z", "Activation"));
auto hidden = EigenVector<T>::Flatten(
GET_DATA_SAFELY(output, "Output", "hidden", "Activation"));
auto* place = device_ctx->eigen_device();
EigenActivationFunctor<T> functor;
functor(*place, z, hidden);
}
};
template <typename T>
struct GRUCell : Cell<T> {
void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
const Tensor* weight_hh, const Tensor* init_h,
const Tensor* init_c, Tensor* last_h, Tensor* last_c,
Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
Tensor* weight_hh_gru) const override {
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
// convert the batched matmul to a single matmul, which is faster
blas.MatMul(*init_h, mat_dim_a, *weight_hh_gru, mat_dim_b,
static_cast<T>(1.0), input, static_cast<T>(1.0));
size_t frame_size = init_h->dims()[2];
size_t batch_size = init_h->dims()[1];
phi::funcs::GRUMetaValue<T> gru_value;
gru_value.gate_weight = weight_hh->data<T>();
gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size;
gru_value.reset_bias = bias_hh->data<T>() + 2 * frame_size;
gru_value.gate_value = input->data<T>();
gru_value.reset_output_value = last_c->data<T>();
gru_value.output_value = output->data<T>();
gru_value.prev_out_value = init_h->data<T>();
auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
phi::funcs::GRUUnitFunctorV2<platform::CPUDeviceContext, T>::compute(
*device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act);
}
};
template <typename T>
struct LSTMCell : Cell<T> {
void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
const Tensor* weight_hh, const Tensor* init_h,
const Tensor* init_c, Tensor* last_h, Tensor* last_c,
Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
Tensor* weight_hh_gru) const override {
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
// convert the batched matmul to a single matmul, which is faster
blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast<T>(1.0),
input, static_cast<T>(1.0));
phi::funcs::LstmMetaValue<T> lstm_value;
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto cell_act = phi::funcs::detail::GetActivationType("tanh_v2");
auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
size_t frame_size = init_h->dims()[2];
size_t batch_size = init_h->dims()[1];
Tensor cell_pre_act;
if (last_c_act == nullptr) { /* is test */
cell_pre_act.mutable_data<T>(init_h->dims(), device_ctx->GetPlace());
last_c_act = &cell_pre_act;
}
lstm_value.prev_state_value = init_c->data<T>();
lstm_value.gate_value = input->data<T>();
lstm_value.output_value = output->data<T>();
lstm_value.state_value = last_c->data<T>();
lstm_value.state_active_value = last_c_act->data<T>();
T cell_clip = 0.0;
phi::funcs::LstmUnitFunctor<platform::CPUDeviceContext, T>::compute(
*device_ctx, lstm_value, frame_size, batch_size, cell_clip, gate_act,
cell_act, cand_act, false);
}
};
template <typename T>
void dropout_helper(const framework::ExecutionContext& context, Tensor* x,
Tensor* y, const Tensor* mask, const float& dropout_prob) {
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto dropout_mask = EigenVector<uint8_t>::Flatten(*mask);
auto in = EigenVector<T>::Flatten(*x);
auto out = EigenVector<T>::Flatten(*y);
if (dropout_prob == 1.0f) {
out.device(place) = static_cast<T>(0) * in;
} else {
out.device(place) =
in * dropout_mask.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
}
template <typename T>
void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
Tensor* x, Tensor* y, Tensor* mask,
const float& dropout_prob,
const int& seed_number, bool is_test,
bool* is_has_reset) {
if (is_test) {
return;
}
size_t size = phi::product(x->dims());
auto* mask_data = mask->data<uint8_t>();
if (!(*is_has_reset)) {
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
std::fill(mask_data, mask_data + size, static_cast<uint8_t>(0));
} else {
auto engine = framework::GetCPURandomEngine(seed_number);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
} else {
mask_data[i] = 1;
}
}
}
*is_has_reset = true;
}
dropout_helper<T>(context, x, y, mask, dropout_prob);
}
template <typename T>
void dropout_cpu_grad_function_inplace(
const framework::ExecutionContext& context, Tensor* grad_x,
const Tensor* mask, const float& dropout_prob) {
dropout_helper<T>(context, grad_x, grad_x, mask, dropout_prob);
}
template <typename T, typename CellType>
struct Layer {
explicit Layer(const CellType& cell) : cell_(cell) {}
virtual ~Layer() {}
void preprocess(const framework::ExecutionContext& context,
const Tensor* input, const Tensor& weight,
const Tensor& bias_ih, const Tensor& bias_hh,
Tensor* cache_input, bool is_test) {
// create the temp input for X * W_ih^T + Bias_ih
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
const int& hidden_size = weight.dims()[0];
cache_input->Resize(
phi::make_ddim({input->dims()[0], input->dims()[1], hidden_size}));
if (is_test) {
cache_input->mutable_data<T>(context.GetPlace());
}
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(input->dims(), 0, false);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(weight.dims(), 0, true);
// convert the batched matmul to a single matmul, which is faster
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
blas.MatMul(*input, mat_dim_a, weight, mat_dim_b, static_cast<T>(1.0),
cache_input, static_cast<T>(0));
auto in = framework::EigenMatrix<T>::Reshape(
*cache_input, cache_input->dims().size() - 1);
auto bias_ih_tmp = framework::EigenMatrix<T>::From(
bias_ih, phi::make_ddim({1, bias_ih.dims()[0]}));
const int& row_num =
phi::product(cache_input->dims()) / cache_input->dims()[2];
in = in + bias_ih_tmp.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
if (is_gru(context)) {
// reset_gate update_gate cell_gate = [1, 1, 0]
Tensor bias_hh_tmp;
bias_hh_tmp.Resize({bias_hh.numel()});
bias_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(bias_hh, context.GetPlace(), dev_ctx, &bias_hh_tmp);
bias_hh_tmp.Resize({3, bias_hh_tmp.numel() / 3});
auto bias_hh_tmp_unbind = Unbind(bias_hh_tmp);
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &bias_hh_tmp_unbind[2], static_cast<T>(0.0));
auto bias_hh_after_mask = framework::EigenMatrix<T>::From(
bias_hh_tmp, phi::make_ddim({1, bias_hh.dims()[0]}));
in = in + bias_hh_after_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
} else {
auto bias_hh_no_mask = framework::EigenMatrix<T>::From(
bias_hh, phi::make_ddim({1, bias_hh.dims()[0]}));
in = in + bias_hh_no_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
}
}
void postprocess(const framework::ExecutionContext& context, Tensor* output,
const Tensor* init_h, const Tensor* init_c, Tensor* last_h,
Tensor* last_c, const Tensor& mask_tensor) {
// in the output, if the mask flag is 0, we return zero data
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto out =
framework::EigenMatrix<T>::Reshape(*output, output->dims().size() - 1);
auto mask = framework::EigenMatrix<T>::From(
mask_tensor, phi::make_ddim({mask_tensor.dims()[1], 1}));
auto pre_h =
framework::EigenMatrix<T>::Reshape(*init_h, init_h->dims().size() - 1);
auto curr_h =
framework::EigenMatrix<T>::Reshape(*last_h, last_h->dims().size() - 1);
auto mask_broadcast =
mask.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[2]));
curr_h.device(place) = out * mask_broadcast + pre_h * (1 - mask_broadcast);
out.device(place) = out * mask_broadcast;
if (is_lstm(context)) {
auto pre_c = framework::EigenMatrix<T>::Reshape(
*init_c, init_c->dims().size() - 1);
auto curr_c = framework::EigenMatrix<T>::Reshape(
*last_c, last_c->dims().size() - 1);
curr_c.device(place) =
curr_c * mask_broadcast + pre_c * (1 - mask_broadcast);
}
}
virtual void operator()(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList last_h,
TensorList last_c, Tensor* output,
const int& layer_idx, const int& gate_num,
Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_test) {}
void RunTestIter(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList* last_h_ptr,
TensorList* last_c_ptr, Tensor* output, int layer_idx,
Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_bidirect, int offset) {
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
const int& time_step = input->dims()[0];
this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
vec[3 + offset * 4], gate_value, true);
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
TensorList mask_tensor_list;
// construct the mask matrix from the sequence lengths
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
Tensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(phi::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
bool has_allocate_mem_c = false;
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
// define the init_h holder for the swap
Tensor init_h_temp;
framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx,
&init_h_temp);
Tensor* init_h_holder = &init_h_temp;
Tensor* last_h_holder = nullptr;
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
Tensor* init_c_holder = nullptr;
const Tensor* init_c_temp_holder = nullptr;
Tensor init_c_temp;
Tensor* last_c_holder = nullptr;
Tensor last_c_temp;
if (is_lstm(context)) {
last_c_holder = &(*last_c_ptr)[layer_idx];
init_c_temp_holder = &init_c[layer_idx];
} else if (is_gru(context)) {
// for reset output value
last_c_temp.Resize(init_h[layer_idx].dims());
last_c_temp.mutable_data<T>(context.GetPlace());
last_c_holder = &last_c_temp;
}
Tensor weight_hh_tmp; // for gru
if (is_gru(context)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
weight_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
bool in_mask = (reverse_flag * i) >= mask_min_length;
if (i > 0) {
if (!has_allocate_mem_c) {
if (is_lstm(context) || is_gru(context)) {
init_c_temp.Resize(init_h[layer_idx].dims());
init_c_temp.mutable_data<T>(context.GetPlace());
init_c_holder = &init_c_temp;
}
has_allocate_mem_c = true;
}
SwapPoniter(&init_c_holder, &last_c_holder);
init_c_temp_holder = init_c_holder;
}
cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
init_c_temp_holder, last_h_holder, last_c_holder, nullptr,
&output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(context, &output_tensors[i], init_h_holder,
init_c_temp_holder, last_h_holder, last_c_holder,
mask_tensor_list[i]);
}
// prepare next step
if (i + 1 < time_step) {
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
&(*last_h_ptr)[layer_idx]);
}
} else {
framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
dev_ctx, &(*last_h_ptr)[layer_idx]);
}
if (time_step % 2 == 0) {
if (is_lstm(context)) {
framework::TensorCopy(*last_c_holder, context.GetPlace(), dev_ctx,
&(*last_c_ptr)[layer_idx]);
}
}
}
void RunIter(const framework::ExecutionContext& context, const Tensor* input,
const TensorList& vec, const TensorList& init_h,
const TensorList& init_c, const Tensor* sequence_length,
TensorList* last_h_ptr, TensorList* last_c_ptr, Tensor* output,
int layer_idx, Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_bidirect, int offset,
bool is_test) {
if (is_test) {
RunTestIter(context, input, vec, init_h, init_c, sequence_length,
last_h_ptr, last_c_ptr, output, layer_idx, gate_value,
cell_value, cell_act_value, is_bidirect, offset);
return;
}
bool is_reverse = false;
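// For a bidirectional layer, each logical layer occupies two consecutive
// state slots: offset 0 is the forward direction, offset 1 the reverse one.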
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
const int& time_step = input->dims()[0];
this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
vec[3 + offset * 4], gate_value, is_test);
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
TensorList mask_tensor_list;
// construct the mask matrix from the sequence lengths
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
Tensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(phi::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
// define the init_h holder used for pointer swapping
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
TensorList cell_value_tensors;
TensorList cell_act_value_tensors;
Tensor init_h_temp;
framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx,
&init_h_temp);
Tensor* init_h_holder = &init_h_temp;
Tensor* last_h_holder = nullptr;
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
const Tensor* init_c_holder = nullptr;
Tensor* last_c_holder = nullptr;
Tensor* last_c_act_holder = nullptr;
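// In training mode the per-step cell states (and, for LSTM, their activated
// values) are written into the reserve buffers so the backward pass can
// reuse them.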
if (is_lstm(context) || is_gru(context)) {
cell_value->Resize({time_step, cell_value->numel() / time_step});
cell_value_tensors = Unbind(*cell_value);
if (is_lstm(context)) {
cell_act_value->Resize(
{time_step, cell_act_value->numel() / time_step});
cell_act_value_tensors = Unbind(*cell_act_value);
}
}
Tensor weight_hh_tmp; // for gru
if (is_gru(context)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
weight_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
bool in_mask = (reverse_flag * i) >= mask_min_length;
if (is_lstm(context)) {
if (i == 0) {
init_c_holder = &init_c[layer_idx];
} else {
init_c_holder = &cell_value_tensors[i - 1];
}
cell_value_tensors[i].Resize(init_c[layer_idx].dims());
cell_act_value_tensors[i].Resize(init_c[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
last_c_act_holder = &cell_act_value_tensors[i];
} else if (is_gru(context)) {
cell_value_tensors[i].Resize(init_h[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
}
cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
init_c_holder, last_h_holder, last_c_holder, last_c_act_holder,
&output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(context, &output_tensors[i], init_h_holder,
init_c_holder, last_h_holder, last_c_holder,
mask_tensor_list[i]);
}
// prepare next step
if (i + 1 < time_step) {
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
&(*last_h_ptr)[layer_idx]);
}
} else {
framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
dev_ctx, &(*last_h_ptr)[layer_idx]);
}
if (is_lstm(context)) {
framework::TensorCopy(cell_value_tensors[time_step - 1],
context.GetPlace(), dev_ctx,
&(*last_c_ptr)[layer_idx]);
}
}
// Cell for the rnn module
CellType cell_;
};
template <typename T, typename CellType>
struct SingleLayer : public Layer<T, CellType> {
explicit SingleLayer(const CellType& cell) : Layer<T, CellType>(cell) {}
void operator()(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList last_h,
TensorList last_c, Tensor* output, const int& layer_idx,
const int& gate_num, Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_test) {
this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h,
&last_c, output, layer_idx, gate_value, cell_value,
cell_act_value, false, 0, is_test);
}
};
template <typename T, typename CellType>
struct BidirLayer : public Layer<T, CellType> {
explicit BidirLayer(const CellType& cell) : Layer<T, CellType>(cell) {}
void operator()(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList last_h,
TensorList last_c, Tensor* output, const int& layer_idx,
const int& gate_num, Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_test) {
TensorList output_vec(2);
Tensor forward_input_w, forward_cell_value, forward_cell_act_value;
Tensor backward_input_w, backward_cell_value, backward_cell_act_value;
int time_step = input->dims()[0];
int batch_size = input->dims()[1];
int hidden_size = output->dims()[2];
for (int i = 0; i < 2; ++i) {
output_vec[i].Resize({time_step, batch_size, hidden_size / 2});
output_vec[i].mutable_data<T>(context.GetPlace());
}
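// output_vec[0]/[1] hold the forward and backward outputs; in training the
// shared gate/cell buffers are also split into a forward half (slice 0) and
// a backward half (slice 1) so each direction records its own reserve data.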
if (!is_test) {
gate_value->Resize({2, gate_value->numel() / 2});
forward_input_w = gate_value->Slice(0, 1);
backward_input_w = gate_value->Slice(1, 2);
if (is_lstm(context) || is_gru(context)) /* for lstm and gru */ {
cell_value->Resize({2, cell_value->numel() / 2});
cell_act_value->Resize({2, cell_act_value->numel() / 2});
forward_cell_value = cell_value->Slice(0, 1);
backward_cell_value = cell_value->Slice(1, 2);
if (is_lstm(context)) {
forward_cell_act_value = cell_act_value->Slice(0, 1);
backward_cell_act_value = cell_act_value->Slice(1, 2);
}
}
}
this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h,
&last_c, &output_vec[0], layer_idx, &forward_input_w,
&forward_cell_value, &forward_cell_act_value, true, 0,
is_test);
this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h,
&last_c, &output_vec[1], layer_idx, &backward_input_w,
&backward_cell_value, &backward_cell_act_value, true, 1,
is_test);
// concat the output results of the two directions
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
paddle::operators::math::ConcatFunctor<platform::CPUDeviceContext, T>
concat_functor;
concat_functor(dev_ctx, output_vec, static_cast<int>(2), output);
}
};
template <typename TensorType>
void SplitReserveData(const framework::ExecutionContext& ctx,
TensorType* reserve_data, Tensor* gate_data,
Tensor* cell_data, Tensor* cell_act_data,
Tensor* hidden_data, int direction_num,
const int& time_step, const int& batch_size,
const int& hidden_size, const int& gate_num,
const int& num_layers) {
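// reserve_data layout along dim 0 (each block has num_layers slices):
// [gate data | cell data (LSTM/GRU) | cell act data (LSTM) | hidden data of
// the first num_layers - 1 layers]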
const int& gate_data_idx = gate_num * num_layers;
const int& cell_data_idx = (gate_num + 1) * num_layers;
const int& cell_act_data_idx = (gate_num + 2) * num_layers;
// simple rnn
int hidden_data_start_idx = gate_data_idx;
*gate_data = reserve_data->Slice(0, gate_data_idx);
if (is_lstm(ctx)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
*cell_act_data = reserve_data->Slice(cell_data_idx, cell_act_data_idx);
hidden_data_start_idx = cell_act_data_idx;
} else if (is_gru(ctx)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
hidden_data_start_idx = cell_data_idx;
}
int hidden_data_idx = hidden_data_start_idx + (num_layers - 1);
if (num_layers > 1) {
*hidden_data = reserve_data->Slice(hidden_data_start_idx, hidden_data_idx);
}
}
template <typename TensorType>
void reset_parameter_vector(const std::vector<TensorType>& raw_params_vec,
const int& num_layers, const int& gate_num,
const bool& is_bidirec,
std::vector<TensorList>* params_vec) {
// the raw parameter sequence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers; reorder it to
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
const int& direction_num = is_bidirec ? 2 : 1;
const int& layer_weight_size = 4 * direction_num;
const int& all_weight_size = num_layers * layer_weight_size;
const int& bias_start_idx = all_weight_size / 2;
for (int i = 0; i < num_layers; i++) {
TensorList tensor_list;
tensor_list.reserve(layer_weight_size);
for (int j = 0; j < layer_weight_size; j++) {
Tensor tensor_holder;
tensor_list.emplace_back(tensor_holder);
}
for (int j = 0; j < layer_weight_size; j++) {
int k = j % 4;
const int& section = j / 4;
int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
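// weights (k = 0, 1) map to raw index i * 2 * direction_num + section * 2 + k % 2;
// biases (k = 2, 3) use the same offset shifted into the second half of the
// raw list by bias_start_idx.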
if (k >= 2) {
tensor_idx += bias_start_idx;
}
tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]);
}
params_vec->emplace_back(tensor_list);
}
}
template <typename CellType, typename T>
void AllocateReserveData(const framework::ExecutionContext& ctx,
Tensor* reserve_data, Tensor* gate_data,
Tensor* cell_data, Tensor* cell_act_data,
Tensor* hidden_data, const Tensor* input,
bool is_bidirec, int num_layers, int gate_num,
int hidden_size) {
const int& direction_num = is_bidirec ? 2 : 1;
const int& time_step = input->dims()[0];
const int& batch_size = input->dims()[1];
const int& block_size = direction_num * time_step * batch_size * hidden_size;
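// hidden_data_idx counts how many blocks of size block_size the reserve
// buffer needs: gate blocks for every layer, plus cell (and cell activation
// for LSTM) blocks, plus the inter-layer hidden outputs.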
int hidden_data_idx = (num_layers - 1);
if (is_lstm(ctx)) {
hidden_data_idx += (gate_num + 2) * num_layers;
} else if (is_gru(ctx)) {
hidden_data_idx += (gate_num + 1) * num_layers;
} else {
hidden_data_idx += gate_num * num_layers;
}
reserve_data->Resize({hidden_data_idx, block_size});
reserve_data->mutable_data<T>(ctx.GetPlace());
SplitReserveData(ctx, reserve_data, gate_data, cell_data, cell_act_data,
hidden_data, direction_num, time_step, batch_size,
hidden_size, gate_num, num_layers);
}
template <typename CellType, template <typename, typename> class LayerT,
template <typename, typename> class SingleLayerT,
template <typename, typename> class BidirLayerT, typename T>
void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input,
const std::vector<const Tensor*> weight_list, const Tensor* init_h,
const Tensor* init_c, const Tensor* sequence_length,
Tensor* last_h, Tensor* last_c, Tensor* output,
Tensor* dropout_mask, const int& num_layers, const int& gate_num,
const int& input_size, const int& hidden_size,
const bool& is_bidirec, const std::string& cell_type,
const float& dropout_prob, bool is_test, const int& seed,
Tensor* reserve_data) {
const int& direction_num = is_bidirec ? 2 : 1;
const auto& init_h_dims = init_h->dims();
PADDLE_ENFORCE_EQ(init_h_dims[0], num_layers * direction_num,
platform::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of init hidden, but received"
" num_layers:%d, dim:%d",
num_layers, init_h_dims[0]));
if (is_lstm(ctx)) {
const auto& init_c_dims = init_c->dims();
PADDLE_ENFORCE_EQ(init_c_dims[0], num_layers * direction_num,
platform::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of cell state hidden, but received"
" num_layers:%d, dim:%d",
num_layers, init_h_dims[0]));
}
CellType cell;
std::vector<TensorList> parameter_lists;
parameter_lists.reserve(num_layers);
reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
&parameter_lists);
Tensor gate_data, cell_data, cell_act_data, hidden_data;
if (!is_test) {
AllocateReserveData<CellType, T>(
ctx, reserve_data, &gate_data, &cell_data, &cell_act_data, &hidden_data,
input, is_bidirec, num_layers, gate_num, hidden_size);
gate_data.Resize({num_layers, gate_data.numel() / num_layers});
cell_data.Resize({num_layers, cell_data.numel() / num_layers});
cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers});
if (num_layers > 1) {
hidden_data.Resize(
{num_layers - 1, hidden_data.numel() / (num_layers - 1)});
}
}
Tensor* input_holder;
Tensor* output_holder = output;
Tensor temp;
bool has_allocate_mem = false;
auto init_h_unbind = Unbind(*init_h);
auto last_h_unbind = Unbind(*last_h);
TensorList init_c_unbind, last_c_unbind;
if (is_lstm(ctx)) {
init_c_unbind = Unbind(*init_c);
last_c_unbind = Unbind(*last_c);
}
Tensor curr_gate_data, curr_cell_data, curr_cell_act_data;
Tensor curr_hidden_data, prev_hidden_data;
bool has_dropout_reset = false;
for (int i = 0; i < num_layers; i++) {
if (!is_test) {
if (cell_data.numel() > 0) /** for lstm, gru **/ {
curr_cell_data = cell_data.Slice(i, i + 1);
}
if (cell_act_data.numel() > 0) /*for lstm*/ {
curr_cell_act_data = cell_act_data.Slice(i, i + 1);
}
curr_gate_data = gate_data.Slice(i, i + 1);
output_holder = output;
if (i < num_layers - 1 && num_layers > 1) {
curr_hidden_data = hidden_data.Slice(i, i + 1);
curr_hidden_data.Resize(output->dims());
output_holder = &curr_hidden_data;
}
}
if (i > 0) {
if (!has_allocate_mem) {
temp.Resize(output->dims());
temp.mutable_data<T>(ctx.GetPlace());
input_holder = &temp;
has_allocate_mem = true;
}
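// In training the input of layer i is the (optionally dropped-out) hidden
// output of layer i - 1 recorded in hidden_data; in inference the output and
// input holders are simply ping-ponged between layers.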
if (!is_test) {
prev_hidden_data = hidden_data.Slice(i - 1, i);
input_holder->Resize(output->dims());
if (dropout_prob != 0) {
dropout_cpu_function_inplace<T>(ctx, &prev_hidden_data, input_holder,
dropout_mask, dropout_prob, seed,
is_test, &has_dropout_reset);
} else {
input_holder = &prev_hidden_data;
input_holder->Resize(output->dims());
}
} else {
SwapPoniter(&output_holder, &input_holder);
}
}
const Tensor* input_temp_holder = input;
if (i > 0) {
input_temp_holder = input_holder;
}
LayerT<T, CellType>* layer;
SingleLayerT<T, CellType> slayer(cell);
BidirLayerT<T, CellType> blayer(cell);
if (is_bidirec) {
layer = &blayer;
} else {
layer = &slayer;
}
(*layer)(ctx, input_temp_holder, parameter_lists[i], init_h_unbind,
init_c_unbind, sequence_length, last_h_unbind, last_c_unbind,
output_holder, i, gate_num, &curr_gate_data, &curr_cell_data,
&curr_cell_act_data, is_test);
}
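// In inference mode the holders are swapped once per layer after the first,
// so with an even number of layers the last result sits in the temporary
// buffer and must be copied back into output.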
if (num_layers % 2 == 0) {
framework::TensorCopy(
*output_holder, ctx.GetPlace(),
ctx.template device_context<platform::CPUDeviceContext>(), output);
}
}
template <typename DeviceContext, typename T>
class RNNCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto pre_state = ctx.MultiInput<Tensor>("PreState");
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
auto state = ctx.MultiOutput<Tensor>("State");
auto* output = ctx.Output<Tensor>("Out");
auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
auto* reserve_data = ctx.Output<Tensor>("Reserve");
const int& num_layers = ctx.Attr<int>("num_layers");
const bool& is_bidirec = ctx.Attr<bool>("is_bidirec");
const int& input_size = ctx.Attr<int>("input_size");
const int& hidden_size = ctx.Attr<int>("hidden_size");
const float& dropout_prob = ctx.Attr<float>("dropout_prob");
const std::string& mode = ctx.Attr<std::string>("mode");
const int& seed = ctx.Attr<int>("seed");
bool is_test = ctx.HasAttr("is_test") ? ctx.Attr<bool>("is_test") : false;
bool has_seq_length = ctx.HasInput("SequenceLength");
const Tensor* sequence_length = nullptr;
if (has_seq_length) {
sequence_length = ctx.Input<Tensor>("SequenceLength");
}
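// DropoutState doubles as the dropout mask buffer: (re)allocate it with the
// output's shape and reset it to all ones before the forward pass.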
if (dropout_mask->IsInitialized()) {
if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
}
dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
phi::funcs::SetConstant<platform::CPUDeviceContext, uint8_t> ones;
ones(dev_ctx, dropout_mask, static_cast<uint8_t>(1));
// init the output and allocate the memory
output->mutable_data<T>(ctx.GetPlace());
int gate_num = 4;
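// gate_num is 4 for LSTM, 3 for GRU and 1 for the simple RNN cells; it is
// adjusted in the branches below.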
state[0]->mutable_data<T>(ctx.GetPlace());
if (is_lstm(ctx)) {
state[1]->mutable_data<T>(ctx.GetPlace());
RnnFunc<LSTMCell<T>, Layer, SingleLayer, BidirLayer, T>(
ctx, input, weight_list, pre_state[0], pre_state[1], sequence_length,
state[0], state[1], output, dropout_mask, num_layers, gate_num,
input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
seed, reserve_data);
} else if (is_rnn_relu(ctx)) {
gate_num = 1;
RnnFunc<SimpleRNNCell<T, ReluCPUFunctor,
phi::funcs::detail::ActivationType::kReLU>,
Layer, SingleLayer, BidirLayer, T>(
ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
state[0], nullptr, output, dropout_mask, num_layers, gate_num,
input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
seed, reserve_data);
} else if (is_rnn_tanh(ctx)) {
gate_num = 1;
RnnFunc<SimpleRNNCell<T, TanhFunctor,
phi::funcs::detail::ActivationType::kTanhV2>,
Layer, SingleLayer, BidirLayer, T>(
ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
state[0], nullptr, output, dropout_mask, num_layers, gate_num,
input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
seed, reserve_data);
} else if (is_gru(ctx)) {
gate_num = 3;
RnnFunc<GRUCell<T>, Layer, SingleLayer, BidirLayer, T>(
ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
state[0], nullptr, output, dropout_mask, num_layers, gate_num,
input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
seed, reserve_data);
}
}
};
template <typename T>
void create_lstm_value(phi::funcs::LstmMetaValue<T>* lstm_value) {
lstm_value->check_ig = nullptr;
lstm_value->check_fg = nullptr;
lstm_value->check_og = nullptr;
}
template <typename T>
void create_lstm_grad(phi::funcs::LstmMetaGrad<T>* lstm_grad) {
lstm_grad->check_ig_grad = nullptr;
lstm_grad->check_fg_grad = nullptr;
lstm_grad->check_og_grad = nullptr;
}
template <typename T>
void create_tensor_by_list(const framework::ExecutionContext& context,
Tensor* dst, const std::vector<T>& v) {
int tensor_size = v.size();
dst->Resize({tensor_size});
dst->mutable_data<T>(context.GetPlace());
int size = v.size();
for (int i = 0; i < size; ++i) {
dst->data<T>()[i] = v[i];
}
}
template <typename T, typename GradCellType>
struct GradLayer {
explicit GradLayer(const GradCellType& cell) : cell_(cell) {}
virtual ~GradLayer() {}
void run_rnn_grad_function(
const framework::ExecutionContext& context,
const platform::CPUDeviceContext& device_ctx, const Tensor* input,
Tensor* input_grad, const Tensor* sequence_length,
std::vector<Tensor>* init_h_unbind, std::vector<Tensor>* init_c_unbind,
std::vector<Tensor>* init_h_grad_unbind,
std::vector<Tensor>* init_c_grad_unbind, Tensor* layer_grad_gate_tensor,
std::vector<Tensor>* layer_gate_tensor_unbind,
std::vector<Tensor>* layer_grad_gate_tensor_unbind,
std::vector<Tensor>* layer_state_tensor_unbind,
std::vector<Tensor>* layer_act_state_tensor_unbind,
std::vector<Tensor>* output_tensor_unbind,
std::vector<Tensor>* output_grad_tensor_unbind,
const TensorList& last_h_grad_unbind,
const TensorList& last_c_grad_unbind,
const std::vector<TensorList>& parameter_lists,
std::vector<TensorList>* weight_list_grad, const int& layer_idx,
const int& time_step, const bool& has_sequence_length,
const bool& is_bidirec, const bool& is_reverse) {
const int& direction_num = is_bidirec ? 2 : 1;
const int& current_reverse_idx = is_reverse ? 1 : 0;
const int& current_layer_idx =
direction_num * layer_idx + current_reverse_idx;
int begin_idx = 0;
if (is_reverse) {
begin_idx = time_step;
}
Tensor mask_matrix;
TensorList mask_tensor_list;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(phi::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
// copy the last_h and last_c gradients so their holders can be swapped
Tensor a, b;
Tensor* dynamic_grad_last_h = &a;
Tensor* dynamic_grad_last_c = &b;
dynamic_grad_last_h->Resize(last_h_grad_unbind[current_layer_idx].dims());
dynamic_grad_last_h->mutable_data<T>(context.GetPlace());
framework::TensorCopy(last_h_grad_unbind[current_layer_idx],
context.GetPlace(), dynamic_grad_last_h);
if (last_c_grad_unbind.size() > 0) {
dynamic_grad_last_c->Resize(last_c_grad_unbind[current_layer_idx].dims());
dynamic_grad_last_c->mutable_data<T>(context.GetPlace());
framework::TensorCopy(last_c_grad_unbind[current_layer_idx],
context.GetPlace(), dynamic_grad_last_c);
} else {
dynamic_grad_last_c = nullptr;
}
Tensor c, d;
Tensor* dynamic_grad_pre_h = &c;
Tensor* dynamic_grad_pre_c = &d;
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
if (init_h_grad_unbind->size() > 0) {
dynamic_grad_pre_h->ShareDataWith(
(*init_h_grad_unbind)[current_layer_idx]);
} else {
dynamic_grad_pre_h->Resize(dynamic_grad_last_h->dims());
dynamic_grad_pre_h->mutable_data<T>(context.GetPlace());
zero(device_ctx, dynamic_grad_pre_h, static_cast<T>(0.0));
}
if (init_c_grad_unbind->size() > 0) {
dynamic_grad_pre_c->ShareDataWith(
(*init_c_grad_unbind)[current_layer_idx]);
} else {
if (is_lstm(context) || is_gru(context)) {
dynamic_grad_pre_c->Resize(dynamic_grad_last_h->dims());
dynamic_grad_pre_c->mutable_data<T>(context.GetPlace());
if (is_gru(context)) {
dynamic_grad_last_c = dynamic_grad_pre_c;
}
} else {
dynamic_grad_pre_c = nullptr;
}
}
if (is_reverse) {
// the input, output, input_grad and output_grad must be reversed,
// and the gate and grad_gate tensors must be reversed as well
std::reverse(layer_gate_tensor_unbind->begin(),
layer_gate_tensor_unbind->end());
std::reverse(layer_grad_gate_tensor_unbind->begin(),
layer_grad_gate_tensor_unbind->end());
/*
if (has_sequence_length) {
std::reverse(mask_tensor_list.begin(), mask_tensor_list.end());
}*/
std::reverse(output_tensor_unbind->begin(), output_tensor_unbind->end());
std::reverse(output_grad_tensor_unbind->begin(),
output_grad_tensor_unbind->end());
}
Tensor* weight_grad =
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 1]);
weight_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, weight_grad, static_cast<T>(0.0));
Tensor* pre_hidden = nullptr;
Tensor* pre_state = nullptr;
Tensor* hidden = nullptr;
if (is_gru(context)) {
zero(device_ctx,
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
static_cast<T>(0.0));
}
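// Walk the time steps backwards: accumulate the gradient flowing in from the
// output, run the cell backward, then swap the last/pre gradient holders so
// the result becomes the incoming gradient of the previous step.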
for (int i = time_step - 1; i >= 0; --i) {
if (has_sequence_length) {
this->mask_preprocess(context, &(*output_grad_tensor_unbind)[i],
dynamic_grad_last_h, dynamic_grad_last_c,
dynamic_grad_pre_h, dynamic_grad_pre_c,
mask_tensor_list[i]);
} else {
this->preprocess(context, &(*output_grad_tensor_unbind)[i],
dynamic_grad_last_h);
}
hidden = &(*output_tensor_unbind)[i];
if (i == 0) {
pre_hidden = &(*init_h_unbind)[current_layer_idx];
if (init_c_unbind->size() > 0) {
pre_state = &(*init_c_unbind)[current_layer_idx];
}
} else {
pre_hidden = &(*output_tensor_unbind)[i - 1];
if (layer_state_tensor_unbind->size() > 0) {
pre_state = &(*layer_state_tensor_unbind)[begin_idx + i - 1];
}
}
this->cell_(
context, &(*layer_gate_tensor_unbind)[i],
&(*layer_state_tensor_unbind)[begin_idx + i],
&(*layer_act_state_tensor_unbind)[begin_idx + i], hidden,
&(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]),
pre_hidden, pre_state, dynamic_grad_last_h, dynamic_grad_last_c,
&(*layer_grad_gate_tensor_unbind)[i], weight_grad, dynamic_grad_pre_h,
dynamic_grad_pre_c,
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
mask_tensor_list[i], has_sequence_length);
SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h);
SwapPoniter(&dynamic_grad_last_c, &dynamic_grad_pre_c);
}
// postprocess the gradients for w_hi, X, bias_hi and bias_hh
this->postprocess(context, *layer_grad_gate_tensor, *input, input_grad,
parameter_lists[layer_idx],
&((*weight_list_grad)[layer_idx]), is_reverse);
// copy the gradients back to init_h and init_c
if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) {
framework::TensorCopy(*dynamic_grad_last_h, context.GetPlace(),
&((*init_h_grad_unbind)[current_layer_idx]));
}
if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) {
framework::TensorCopy(*dynamic_grad_last_c, context.GetPlace(),
&((*init_c_grad_unbind)[current_layer_idx]));
}
}
virtual void operator()(
const framework::ExecutionContext& context, const Tensor* input,
const Tensor* output, const TensorList& init_h_unbind,
const TensorList& init_c_unbind, const TensorList& last_h_grad_unbind,
const TensorList& last_c_grad_unbind,
const TensorList& gate_tensor_unbind,
const TensorList& state_tensor_unbind,
const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
const std::vector<TensorList>& parameter_lists,
const Tensor* sequence_length, Tensor* input_grad,
TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
const std::vector<TensorList>& weight_list_grad, const int& layer_idx,
const int& gate_num) {}
void preprocess(const framework::ExecutionContext& context,
const Tensor* grad_output, Tensor* grad_last_h) {
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto output_grad = framework::EigenMatrix<T>::Reshape(
*grad_output, grad_output->dims().size() - 1);
auto last_h_grad = framework::EigenMatrix<T>::Reshape(
*grad_last_h, grad_last_h->dims().size() - 1);
// the output gradient contributes to the gradient of last_h
last_h_grad.device(place) = last_h_grad + output_grad;
}
void mask_preprocess(const framework::ExecutionContext& context,
const Tensor* grad_output, Tensor* grad_last_h,
Tensor* grad_last_c, Tensor* grad_pre_h,
Tensor* grad_pre_c, const Tensor& mask_tensor) {
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto mask = framework::EigenMatrix<T>::From(
mask_tensor, phi::make_ddim({mask_tensor.dims()[1], 1}));
auto mask_broadcast =
mask.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));
auto last_h_grad = framework::EigenMatrix<T>::Reshape(
*grad_last_h, grad_last_h->dims().size() - 1);
auto pre_h_grad = framework::EigenMatrix<T>::Reshape(
*grad_pre_h, grad_pre_h->dims().size() - 1);
auto output_grad = framework::EigenMatrix<T>::Reshape(
*grad_output, grad_output->dims().size() - 1);
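// Positions whose sequence has already ended (mask == 0) pass the
// accumulated hidden gradient straight through to the previous step, while
// active positions keep it for the cell backward.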
last_h_grad.device(place) = last_h_grad + output_grad * mask_broadcast;
pre_h_grad.device(place) = (1 - mask_broadcast) * last_h_grad;
last_h_grad.device(place) = mask_broadcast * last_h_grad;
if (grad_last_c && grad_pre_c && is_lstm(context)) {
auto last_c_grad = framework::EigenMatrix<T>::Reshape(
*grad_last_c, grad_last_c->dims().size() - 1);
auto pre_c_grad = framework::EigenMatrix<T>::Reshape(
*grad_pre_c, grad_pre_c->dims().size() - 1);
pre_c_grad.device(place) = (1 - mask_broadcast) * last_c_grad;
last_c_grad.device(place) = mask_broadcast * last_c_grad;
}
}
void postprocess(const framework::ExecutionContext& context,
const Tensor& grad_gate, const Tensor& input,
Tensor* input_grad, const TensorList& parameters,
TensorList* grad_parameters, const int& is_reverse) {
// we get the grad_gate step by step, and need to propagate the gradient to
// grad_w_hi, grad_bias_hi and grad_bias_hh
int begin_idx = 0;
if (is_reverse) {
begin_idx = 4;
}
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(device_ctx);
// calc the gradient for the w_hi
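// grad_w_hi = grad_gate^T * input, with the time and batch dims folded
// into a single dimension.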
auto mat_dim_out_grad =
phi::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, true);
auto mat_dim_input =
phi::funcs::CreateMatrixDescriptor(input.dims(), 0, false);
mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_;
mat_dim_out_grad.batch_size_ = 0;
mat_dim_input.height_ *= mat_dim_input.batch_size_;
mat_dim_input.batch_size_ = 0;
blas.MatMul(grad_gate, mat_dim_out_grad, input, mat_dim_input,
static_cast<T>(1.0), &((*grad_parameters)[begin_idx + 0]),
T(0));
// calc the gradient for the X
auto mat_dim_out_grad_new =
phi::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, false);
mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_;
mat_dim_out_grad_new.batch_size_ = 0;
auto mat_dim_parameter =
phi::funcs::CreateMatrixDescriptor(parameters[0].dims(), 0, false);
blas.MatMul(grad_gate, mat_dim_out_grad_new, parameters[begin_idx + 0],
mat_dim_parameter, static_cast<T>(1.0), input_grad, T(1));
// calc the gradient of Bias_hi, Bias_hh
phi::funcs::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
Tensor tmp_grad_gate;
tmp_grad_gate.ShareDataWith(grad_gate);
tmp_grad_gate.Resize(
{grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]});
col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2]));
// Bias_hh
if (!is_gru(context)) {
col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3]));
}
}
GradCellType cell_;
};
template <typename T, typename GradCellType>
struct SingleGradLayer : GradLayer<T, GradCellType> {
// explicit SingleGradLayer(GradCellType& cell) : cell_(cell) {}
explicit SingleGradLayer(const GradCellType& cell)
: GradLayer<T, GradCellType>(cell) {}
virtual ~SingleGradLayer() {}
void operator()(
const framework::ExecutionContext& context, const Tensor* input,
const Tensor* output, std::vector<Tensor>* init_h_unbind,
std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
const TensorList& last_c_grad_unbind,
const TensorList& gate_tensor_unbind,
const TensorList& state_tensor_unbind,
const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
const std::vector<TensorList>& parameter_lists,
const Tensor* sequence_length, Tensor* input_grad,
TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
std::vector<TensorList>* weight_list_grad, const int& layer_idx,
const int& gate_num) {
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
const bool& is_bidirec = context.Attr<bool>("is_bidirec");
const int& time_step = input->dims()[0];
const int& batch_size = input->dims()[1];
const int& direction_num = is_bidirec ? 2 : 1;
const int& hidden_size = context.Attr<int>("hidden_size");
// in this section, create the gate_state_grad for the postprocess calculation
// unbind the output, which has shape [time_step, batch_size, hidden_size]
auto output_tensor_unbind = Unbind(*output);
auto output_grad_tensor_unbind = Unbind(*output_grad);
auto layer_gate_tensor = gate_tensor_unbind[layer_idx];
layer_gate_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size * gate_num});
auto layer_gate_tensor_unbind = Unbind(layer_gate_tensor);
// the gate_tensor and the grad_gate_tensor must be unbound
Tensor layer_grad_gate_tensor;
layer_grad_gate_tensor.Resize(layer_gate_tensor.dims());
layer_grad_gate_tensor.mutable_data<T>(context.GetPlace());
auto layer_grad_gate_tensor_unbind = Unbind(layer_grad_gate_tensor);
Tensor layer_state_tensor;
TensorList layer_state_tensor_unbind;
if (state_tensor_unbind.size() > 0) {
layer_state_tensor = state_tensor_unbind[layer_idx];
layer_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_state_tensor_unbind = Unbind(layer_state_tensor);
}
Tensor layer_act_state_tensor;
TensorList layer_act_state_tensor_unbind;
if (act_state_tensor_unbind.size() > 0) {
layer_act_state_tensor = act_state_tensor_unbind[layer_idx];
layer_act_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor);
}
const bool& has_sequence_length = sequence_length == nullptr ? false : true;
this->run_rnn_grad_function(
context, device_ctx, input, input_grad, sequence_length, init_h_unbind,
init_c_unbind, init_h_grad_unbind, init_c_grad_unbind,
&layer_grad_gate_tensor, &layer_gate_tensor_unbind,
&layer_grad_gate_tensor_unbind, &layer_state_tensor_unbind,
&layer_act_state_tensor_unbind, &output_tensor_unbind,
&output_grad_tensor_unbind, last_h_grad_unbind, last_c_grad_unbind,
parameter_lists, weight_list_grad, layer_idx, time_step,
has_sequence_length, is_bidirec, false);
}
};
template <typename T>
void split_tensor_at_last_dim(const framework::ExecutionContext& context,
const platform::CPUDeviceContext& dev_ctx,
const Tensor* output,
std::vector<Tensor*>* output_vec,
const int& axis) {
std::vector<const framework::Tensor*> shape_refer;
(*output_vec)[0]->Resize(
{output->dims()[0], output->dims()[1], output->dims()[2] / 2});
(*output_vec)[0]->mutable_data<T>(context.GetPlace());
(*output_vec)[1]->Resize(
{output->dims()[0], output->dims()[1], output->dims()[2] / 2});
(*output_vec)[1]->mutable_data<T>(context.GetPlace());
shape_refer.emplace_back((*output_vec)[0]);
shape_refer.emplace_back((*output_vec)[1]);
math::SplitFunctor<platform::CPUDeviceContext, T> functor;
functor(dev_ctx, *output, shape_refer, axis, output_vec);
}
template <typename T, typename GradCellType>
struct BidirGradLayer : GradLayer<T, GradCellType> {
explicit BidirGradLayer(const GradCellType& cell)
: GradLayer<T, GradCellType>(cell) {}
virtual ~BidirGradLayer() {}
void operator()(
const framework::ExecutionContext& context, const Tensor* input,
const Tensor* output, std::vector<Tensor>* init_h_unbind,
std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
const TensorList& last_c_grad_unbind,
const TensorList& gate_tensor_unbind,
const TensorList& state_tensor_unbind,
const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
const std::vector<TensorList>& parameter_lists,
const Tensor* sequence_length, Tensor* input_grad,
TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
std::vector<TensorList>* weight_list_grad, const int& layer_idx,
const int& gate_num) {
const bool& is_bidirec = context.Attr<bool>("is_bidirec");
const int& time_step = input->dims()[0];
const int& batch_size = input->dims()[1];
const int& direction_num = is_bidirec ? 2 : 1;
const int& hidden_size = context.Attr<int>("hidden_size");
// split the output into two tensors: output_forward and output_backward
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
std::vector<Tensor*> output_vec;
Tensor forward_output;
Tensor backward_output;
std::vector<Tensor> forward_output_tensor_unbind;
std::vector<Tensor> backward_output_tensor_unbind;
// in the last layer, the output is used as the last hidden;
// it is just the concat of the forward and backward hidden, so simply
// split it
// in the other layers, the hidden is split along the last dim the same way
output_vec.emplace_back(&forward_output);
output_vec.emplace_back(&backward_output);
split_tensor_at_last_dim<T>(context, device_ctx, output, &output_vec, 2);
forward_output_tensor_unbind = Unbind(*(output_vec[0]));
backward_output_tensor_unbind = Unbind(*(output_vec[1]));
std::vector<Tensor*> output_grad_vec;
Tensor grad_forward_output;
Tensor grad_backward_output;
output_grad_vec.emplace_back(&grad_forward_output);
output_grad_vec.emplace_back(&grad_backward_output);
split_tensor_at_last_dim<T>(context, device_ctx, output_grad,
&output_grad_vec, 2);
auto forward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[0]));
auto backward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[1]));
// the gate_tensor and the grad_gate_tensor must be unbound
auto layer_gate_tensor = gate_tensor_unbind[layer_idx];
layer_gate_tensor.Resize(
{time_step * 2, batch_size, hidden_size * gate_num});
auto layer_forward_gate_tensor = layer_gate_tensor.Slice(0, time_step);
auto layer_backward_gate_tensor =
layer_gate_tensor.Slice(time_step, 2 * time_step);
auto layer_forward_gate_tensor_unbind = Unbind(layer_forward_gate_tensor);
auto layer_backward_gate_tensor_unbind = Unbind(layer_backward_gate_tensor);
Tensor layer_grad_gate_tensor;
layer_grad_gate_tensor.Resize(layer_gate_tensor.dims());
layer_grad_gate_tensor.mutable_data<T>(context.GetPlace());
zero(device_ctx, &layer_grad_gate_tensor, static_cast<T>(0.0));
auto layer_forward_grad_gate_tensor =
layer_grad_gate_tensor.Slice(0, time_step);
auto layer_backward_grad_gate_tensor =
layer_grad_gate_tensor.Slice(time_step, 2 * time_step);
auto layer_forward_grad_gate_tensor_unbind =
Unbind(layer_forward_grad_gate_tensor);
auto layer_backward_grad_gate_tensor_unbind =
Unbind(layer_backward_grad_gate_tensor);
Tensor layer_state_tensor;
TensorList layer_state_tensor_unbind;
if (state_tensor_unbind.size() > 0) {
layer_state_tensor = state_tensor_unbind[layer_idx];
layer_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_state_tensor_unbind = Unbind(layer_state_tensor);
}
Tensor layer_act_state_tensor;
TensorList layer_act_state_tensor_unbind;
if (act_state_tensor_unbind.size() > 0) {
layer_act_state_tensor = act_state_tensor_unbind[layer_idx];
layer_act_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor);
}
const bool& has_sequence_length = sequence_length == nullptr ? false : true;
this->run_rnn_grad_function(
context, device_ctx, input, input_grad, sequence_length, init_h_unbind,
init_c_unbind, init_h_grad_unbind, init_c_grad_unbind,
&layer_forward_grad_gate_tensor, &layer_forward_gate_tensor_unbind,
&layer_forward_grad_gate_tensor_unbind, &layer_state_tensor_unbind,
&layer_act_state_tensor_unbind, &forward_output_tensor_unbind,
&forward_output_grad_tensor_unbind, last_h_grad_unbind,
last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx,
time_step, has_sequence_length, is_bidirec, false);
this->run_rnn_grad_function(
context, device_ctx, input, input_grad, sequence_length, init_h_unbind,
init_c_unbind, init_h_grad_unbind, init_c_grad_unbind,
&layer_backward_grad_gate_tensor, &layer_backward_gate_tensor_unbind,
&layer_backward_grad_gate_tensor_unbind, &layer_state_tensor_unbind,
&layer_act_state_tensor_unbind, &backward_output_tensor_unbind,
&backward_output_grad_tensor_unbind, last_h_grad_unbind,
last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx,
time_step, has_sequence_length, is_bidirec, true);
}
};
template <typename T>
void backup_tensor(const framework::ExecutionContext& context, Tensor* dst,
Tensor* src) {
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
dst->Resize(src->dims());
dst->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*src, device_ctx.GetPlace(), device_ctx, dst);
}
template <typename T>
struct GradCell {
virtual ~GradCell() {}
virtual void operator()(const framework::ExecutionContext& context,
Tensor* gate_tensor, Tensor* state_tensor,
Tensor* act_state_tensor, Tensor* hidden_tensor,
const Tensor* weight_hh, Tensor* pre_hidden,
Tensor* pre_state, Tensor* grad_hidden,
Tensor* grad_state, Tensor* grad_gate,
Tensor* grad_weight_hh, Tensor* grad_pre_hidden,
Tensor* grad_pre_state, Tensor* grad_bias_hh,
const Tensor& mask_tensor,
bool has_sequence_length) const {}
void postprocess_pre_hidden_grad(const framework::ExecutionContext& context,
Tensor* grad_pre_hidden,
Tensor* grad_pre_hidden_bak,
Tensor* grad_pre_state,
Tensor* grad_pre_state_bak,
const Tensor& mask_tensor,
bool has_sequence_length) const {
if (has_sequence_length) {
auto& place =
*context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto mask = framework::EigenMatrix<T>::From(
mask_tensor, phi::make_ddim({mask_tensor.dims()[1], 1}));
auto mask_broadcast =
mask.broadcast(Eigen::DSizes<int, 2>(1, grad_pre_hidden->dims()[2]));
auto pre_hidden_grad = framework::EigenMatrix<T>::Reshape(
*grad_pre_hidden, grad_pre_hidden->dims().size() - 1);
auto pre_hidden_bak_grad = framework::EigenMatrix<T>::Reshape(
*grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1);
pre_hidden_grad.device(place) =
(1 - mask_broadcast) * pre_hidden_bak_grad +
pre_hidden_grad * mask_broadcast;
if (grad_pre_state) {
auto pre_state_grad = framework::EigenMatrix<T>::Reshape(
*grad_pre_state, grad_pre_state->dims().size() - 1);
auto pre_state_bak_grad = framework::EigenMatrix<T>::Reshape(
*grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1);
pre_state_grad.device(place) =
(1 - mask_broadcast) * pre_state_bak_grad +
pre_state_grad * mask_broadcast;
}
}
}
virtual void update_pre_hidden_grad(
const framework::ExecutionContext& context, Tensor* grad_gate,
const Tensor* weight_hh, Tensor* grad_pre_hidden,
Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state,
Tensor* grad_pre_state_bak, const Tensor& mask_tensor,
bool has_sequence_length) const {
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(device_ctx);
Tensor* grad_gate_tmp = grad_gate;
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(grad_gate_tmp->dims(), 0, false);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, false);
blas.MatMul(*grad_gate_tmp, mat_dim_a, *weight_hh, mat_dim_b,
static_cast<T>(1.0), grad_pre_hidden, 0);
postprocess_pre_hidden_grad(context, grad_pre_hidden, grad_pre_hidden_bak,
grad_pre_state, grad_pre_state_bak, mask_tensor,
has_sequence_length);
}
virtual void update_weight_hh_grad(const framework::ExecutionContext& context,
Tensor* grad_gate, Tensor* pre_hidden,
Tensor* grad_weight_hh) const {
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(device_ctx);
auto mat_dim_c =
phi::funcs::CreateMatrixDescriptor(grad_gate->dims(), 0, true);
mat_dim_c.height_ *= mat_dim_c.batch_size_;
mat_dim_c.batch_size_ = 0;
auto mat_dim_d =
phi::funcs::CreateMatrixDescriptor(pre_hidden->dims(), 0, false);
mat_dim_d.height_ *= mat_dim_d.batch_size_;
mat_dim_d.batch_size_ = 0;
blas.MatMul(*grad_gate, mat_dim_c, *pre_hidden, mat_dim_d,
static_cast<T>(1.0), grad_weight_hh, static_cast<T>(1.0));
}
};
template <typename T, template <typename> class EigenActivationBackwardFunctor>
struct SimpleRNNGradCell : GradCell<T> {
void operator()(const framework::ExecutionContext& context,
Tensor* gate_tensor, Tensor* state_tensor,
Tensor* act_state_tensor, Tensor* hidden_tensor,
const Tensor* weight_hh, Tensor* pre_hidden,
Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
Tensor* grad_gate, Tensor* grad_weight_hh,
Tensor* grad_pre_hidden, Tensor* grad_pre_state,
Tensor* grad_bias_hh, const Tensor& mask_tensor,
bool has_sequence_length) const override {
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
Tensor grad_pre_hidden_bak;
if (has_sequence_length) {
backup_tensor<T>(context, &grad_pre_hidden_bak, grad_pre_hidden);
}
// h = act(z)
// update dz
auto dz = EigenVector<T>::Flatten(
GET_DATA_SAFELY(grad_gate, "Output", "dz", "Grad"));
auto dh = EigenVector<T>::Flatten(
GET_DATA_SAFELY(grad_hidden, "Input", "dh", "Grad"));
auto h = EigenVector<T>::Flatten(
GET_DATA_SAFELY(hidden_tensor, "Input", "h", "Value"));
// z is not used by the backward functor itself, but its signature requires it
auto z = EigenVector<T>::Flatten(
GET_DATA_SAFELY(gate_tensor, "Input", "z", "Value"));
auto* place = device_ctx.eigen_device();
EigenActivationBackwardFunctor<T> functor;
functor(*place, z, h, dh, dz);
// update grad_weight_hh, grad_pre_hidden
this->update_pre_hidden_grad(context, grad_gate, weight_hh, grad_pre_hidden,
&grad_pre_hidden_bak, nullptr, nullptr,
mask_tensor, has_sequence_length);
this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh);
}
};
template <typename T>
struct GRUGradCell : GradCell<T> {
void operator()(const framework::ExecutionContext& context,
Tensor* gate_tensor, Tensor* state_tensor,
Tensor* act_state_tensor, Tensor* hidden_tensor,
const Tensor* weight_hh, Tensor* pre_hidden,
Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
Tensor* grad_gate, Tensor* grad_weight_hh,
Tensor* grad_pre_hidden, Tensor* grad_pre_state,
Tensor* grad_bias_hh, const Tensor& mask_tensor,
bool has_sequence_length) const override {
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
size_t frame_size = pre_hidden->dims()[2];
size_t batch_size = pre_hidden->dims()[1];
Tensor grad_pre_hidden_bak;
if (has_sequence_length) {
backup_tensor<T>(context, &grad_pre_hidden_bak, grad_pre_hidden);
}
// zero grad_pre_hidden before the GRU backward fills it
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, grad_pre_hidden, static_cast<T>(0.0));
phi::funcs::GRUMetaValue<T> gru_value;
phi::funcs::GRUMetaGrad<T> gru_grad;
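// weight_hh stores the update/reset gate weights in its first
// 2 * frame_size * frame_size elements and the candidate (state) weights in
// the remaining frame_size * frame_size block; wire the structs accordingly.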
gru_value.gate_value = gate_tensor->data<T>();
gru_value.prev_out_value = pre_hidden->data<T>();
gru_value.reset_output_value = state_tensor->data<T>();
gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size;
gru_value.gate_weight = weight_hh->data<T>();
gru_grad.gate_grad = grad_gate->data<T>();
gru_grad.reset_output_grad = grad_state->data<T>();
gru_grad.prev_out_grad = grad_pre_hidden->data<T>();
gru_grad.output_grad = grad_hidden->data<T>();
gru_grad.gate_weight_grad = grad_weight_hh->data<T>();
gru_grad.state_weight_grad =
grad_weight_hh->data<T>() + 2 * frame_size * frame_size;
gru_grad.bias_hh_grad = grad_bias_hh->data<T>();
auto act_gate = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto act_node = phi::funcs::detail::GetActivationType("tanh_v2");
phi::funcs::GRUUnitGradFunctorV2<platform::CPUDeviceContext, T>::compute(
device_ctx, gru_value, gru_grad, frame_size, batch_size, act_node,
act_gate);
this->postprocess_pre_hidden_grad(context, grad_pre_hidden,
&grad_pre_hidden_bak, nullptr, nullptr,
mask_tensor, has_sequence_length);
}
};
template <typename T>
struct LSTMGradCell : GradCell<T> {
void operator()(const framework::ExecutionContext& context,
Tensor* gate_tensor, Tensor* state_tensor,
Tensor* act_state_tensor, Tensor* hidden_tensor,
const Tensor* weight_hh, Tensor* pre_hidden,
Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
Tensor* grad_gate, Tensor* grad_weight_hh,
Tensor* grad_pre_hidden, Tensor* grad_pre_state,
Tensor* grad_bias_hh, const Tensor& mask_tensor,
bool has_sequence_length) const override {
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
size_t frame_size = state_tensor->dims()[2];
size_t batch_size = state_tensor->dims()[1];
Tensor grad_pre_hidden_bak;
Tensor grad_pre_state_bak;
if (has_sequence_length) {
backup_tensor<T>(context, &grad_pre_hidden_bak, grad_pre_hidden);
backup_tensor<T>(context, &grad_pre_state_bak, grad_pre_state);
}
phi::funcs::LstmMetaValue<T> lstm_value;
phi::funcs::LstmMetaGrad<T> lstm_grad;
create_lstm_value(&lstm_value);
create_lstm_grad(&lstm_grad);
lstm_value.gate_value = gate_tensor->data<T>();
lstm_value.state_value = state_tensor->data<T>();
lstm_value.state_active_value = act_state_tensor->data<T>();
lstm_value.prev_state_value = pre_state->data<T>();
lstm_grad.state_grad = grad_state->data<T>();
lstm_grad.gate_grad = grad_gate->data<T>();
lstm_grad.output_grad = grad_hidden->data<T>();
lstm_grad.prev_state_grad = grad_pre_state->data<T>();
lstm_value.output_value = nullptr;
lstm_grad.state_active_grad = nullptr;
auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto state_act = phi::funcs::detail::GetActivationType("tanh_v2");
auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
T cell_clip = 0.0;
phi::funcs::LstmUnitGradFunctor<platform::CPUDeviceContext, T>::compute(
device_ctx, lstm_value, lstm_grad, frame_size, batch_size, cell_clip,
gate_act, state_act, cand_act, false);
this->update_pre_hidden_grad(
context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak,
grad_pre_state, &grad_pre_state_bak, mask_tensor, has_sequence_length);
this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh);
}
};
template <typename GradCellType,
template <typename, typename> class SingleGradLayerT,
template <typename, typename> class BidirGradLayerT, typename T>
void RnnGradFunc(const framework::ExecutionContext& context,
const int& gate_num) {
// get the tensor pointer for the input
auto* input = context.Input<Tensor>("Input");
auto weight_list = context.MultiInput<Tensor>("WeightList");
auto pre_state = context.MultiInput<Tensor>("PreState");
const Tensor* init_h = pre_state[0];
const Tensor* init_c = nullptr;
if (is_lstm(context)) {
init_c = pre_state[1];
}
auto* reserve_state = context.Input<Tensor>("Reserve");
auto* dropout_state = context.Input<Tensor>("DropoutState");
auto* output = context.Input<Tensor>("Out");
auto* output_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto state_grad = context.MultiInput<Tensor>(framework::GradVarName("State"));
const Tensor* last_h_grad = state_grad[0];
const Tensor* last_c_grad = nullptr;
if (is_lstm(context)) {
last_c_grad = state_grad[1];
}
bool has_seq_length = context.HasInput("SequenceLength");
const Tensor* sequence_length = nullptr;
if (has_seq_length) {
sequence_length = context.Input<Tensor>("SequenceLength");
}
// get the tensor pointer for the output
auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
auto weight_grad_list = context.MultiOutput<framework::Tensor>(
framework::GradVarName("WeightList"));
auto pre_state_grad =
context.MultiOutput<Tensor>(framework::GradVarName("PreState"));
Tensor* init_h_grad = nullptr;
Tensor* init_c_grad = nullptr;
if (pre_state_grad.size() > 0) { // has gradient
init_h_grad = pre_state_grad[0];
if (is_lstm(context)) {
init_c_grad = pre_state_grad[1];
}
}
// get the attributes for the calculation
const int& num_layers = context.Attr<int>("num_layers");
const bool& is_bidirec = context.Attr<bool>("is_bidirec");
const float& dropout_prob = context.Attr<float>("dropout_prob");
bool is_test =
context.HasAttr("is_test") ? context.Attr<bool>("is_test") : false;
// get the input_size, batch_size, time_step, hidden_size
const int& time_step = input->dims()[0];
const int& batch_size = input->dims()[1];
const int& hidden_size = context.Attr<int>("hidden_size");
const int& direction_num = is_bidirec ? 2 : 1;
// allocate the memory and initialize input_grad
Tensor input_grad_value;
if (!input_grad) {
input_grad = &input_grad_value;
}
input_grad->mutable_data<T>(input->dims(), context.GetPlace());
if (init_h_grad) {
init_h_grad->mutable_data<T>(init_h->dims(), context.GetPlace());
}
if (init_c_grad) {
init_c_grad->mutable_data<T>(init_c->dims(), context.GetPlace());
}
// reset the parameter to sorted order and allocate the memory
std::vector<TensorList> parameter_lists;
parameter_lists.reserve(num_layers);
reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
&parameter_lists);
for (unsigned int i = 0; i < weight_grad_list.size(); ++i) {
weight_grad_list[i]->mutable_data<T>(context.GetPlace());
}
std::vector<TensorList> parameter_lists_grad;
parameter_lists_grad.reserve(num_layers);
reset_parameter_vector(weight_grad_list, num_layers, gate_num, is_bidirec,
&parameter_lists_grad);
// split the saved reserve_state into gate / state / act_state / hidden tensors
Tensor gate_tensor;
Tensor state_tensor;
Tensor act_state_tensor;
Tensor hidden_tensor;
SplitReserveData(context, reserve_state, &gate_tensor, &state_tensor,
&act_state_tensor, &hidden_tensor, direction_num, time_step,
batch_size, hidden_size, gate_num, num_layers);
int gate_num_tmp = gate_num;
if (gate_num == 0) {
gate_num_tmp = 1;
}
gate_tensor.Resize({num_layers, time_step * direction_num, batch_size,
hidden_size * gate_num_tmp});
if (state_tensor.numel() > 0) {
state_tensor.Resize(
{num_layers, time_step * direction_num, batch_size, hidden_size});
}
if (act_state_tensor.numel() > 0) {
act_state_tensor.Resize(
{num_layers, time_step * direction_num, batch_size, hidden_size});
}
if (num_layers > 1) {
hidden_tensor.Resize(
{num_layers - 1, time_step, batch_size, hidden_size * direction_num});
}
// unbind
auto last_h_grad_unbind = Unbind(*last_h_grad);
auto gate_tensor_unbind = Unbind(gate_tensor);
TensorList last_c_grad_unbind;
if (last_c_grad) {
last_c_grad_unbind = Unbind(*last_c_grad);
}
TensorList init_h_unbind, init_c_unbind;
TensorList init_h_grad_unbind, init_c_grad_unbind;
TensorList state_tensor_unbind, act_state_tensor_unbind;
TensorList hidden_tensor_unbind;
init_h_unbind = Unbind(*init_h);
if (init_c) {
init_c_unbind = Unbind(*init_c);
}
if (init_h_grad != nullptr) {
init_h_grad_unbind = Unbind(*init_h_grad);
}
if (init_c_grad != nullptr) {
init_c_grad_unbind = Unbind(*init_c_grad);
}
if (state_tensor.numel() > 0) {
state_tensor_unbind = Unbind(state_tensor);
}
if (act_state_tensor.numel() > 0) {
act_state_tensor_unbind = Unbind(act_state_tensor);
}
if (num_layers > 1) {
hidden_tensor_unbind = Unbind(hidden_tensor);
}
// squeeze the hidden first dim
for (unsigned int i = 0; i < hidden_tensor_unbind.size(); i++) {
hidden_tensor_unbind[i].Resize(
phi::slice_ddim(hidden_tensor_unbind[i].dims(), 1,
hidden_tensor_unbind[i].dims().size()));
}
// add the output tensor to the hidden vector
Tensor tmp;
hidden_tensor_unbind.emplace_back(tmp);
hidden_tensor_unbind[num_layers - 1].ShareDataWith(*output);
GradCellType cell;
Tensor layer_input;
Tensor layer_output;
Tensor* layer_input_grad_holder = nullptr;
Tensor tmp_out;
tmp_out.ShareDataWith(*output_grad);
Tensor* layer_output_grad_holder = &tmp_out;
Tensor input_grad_temp;
Tensor output_grad_temp;
bool has_allocate_mem = false;
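// Process the layers from top to bottom: the gradient of each layer's output
// is either the incoming output_grad (top layer) or the input gradient
// produced by the layer above, propagated through the inter-layer dropout.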
for (int i = num_layers - 1; i >= 0; --i) {
// the layer input and output were saved in the forward pass, just reuse the data
if (i > 0) {
if (layer_input.numel() == 0) {
layer_input.Resize(hidden_tensor_unbind[i - 1].dims());
layer_input.mutable_data<T>(context.GetPlace());
}
dropout_helper<T>(context, &hidden_tensor_unbind[i - 1], &layer_input,
dropout_state, dropout_prob);
} else {
layer_input.ShareDataWith(*input);
}
layer_output.ShareDataWith(hidden_tensor_unbind[i]);
if (num_layers == 1) {
layer_input_grad_holder = input_grad;
} else {
if (i == num_layers - 1) {
input_grad_temp.Resize(layer_input.dims());
input_grad_temp.mutable_data<T>(context.GetPlace());
layer_input_grad_holder = &input_grad_temp;
}
}
if (is_bidirec) {
BidirGradLayerT<T, GradCellType> layer(cell);
layer(context, &layer_input, &layer_output, &init_h_unbind,
&init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
layer_output_grad_holder, parameter_lists, sequence_length,
layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
&parameter_lists_grad, i, gate_num_tmp);
} else {
SingleGradLayerT<T, GradCellType> layer(cell);
layer(context, &layer_input, &layer_output, &init_h_unbind,
&init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
layer_output_grad_holder, parameter_lists, sequence_length,
layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
&parameter_lists_grad, i, gate_num_tmp);
}
// calculate the dropout gradient for layer_input_grad_holder;
// dropout_state was saved in the forward pass
if (i > 0) {
if ((!is_test) && (dropout_prob != 0)) {
dropout_cpu_grad_function_inplace<T>(context, layer_input_grad_holder,
dropout_state, dropout_prob);
}
}
if (i - 1 == 0) {
layer_output_grad_holder = input_grad;
} else {
if (!has_allocate_mem) {
output_grad_temp.Resize(layer_input_grad_holder->dims());
output_grad_temp.mutable_data<T>(context.GetPlace());
layer_output_grad_holder = &output_grad_temp;
has_allocate_mem = true;
}
}
SwapPoniter(&layer_input_grad_holder, &layer_output_grad_holder);
}
}
template <typename DeviceContext, typename T>
class RNNCPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int gate_num = 4;
if (is_lstm(ctx)) {
RnnGradFunc<LSTMGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
ctx, gate_num);
} else if (is_gru(ctx)) {
gate_num = 3;
RnnGradFunc<GRUGradCell<T>, SingleGradLayer, BidirGradLayer, T>(ctx,
gate_num);
// run gru
} else if (is_rnn_relu(ctx)) {
gate_num = 1;
RnnGradFunc<SimpleRNNGradCell<T, ReluGradFunctor>, SingleGradLayer,
BidirGradLayer, T>(ctx, gate_num);
// run rnn
} else if (is_rnn_tanh(ctx)) {
gate_num = 1;
RnnGradFunc<SimpleRNNGradCell<T, TanhGradFunctor>, SingleGradLayer,
BidirGradLayer, T>(ctx, gate_num);
}
}
};
} // namespace operators
} // namespace paddle
...@@ -647,7 +647,6 @@ void BindImperative(py::module *m_ptr) {
} else {
act_name = name.cast<std::string>();
}
VLOG(4) << "Init VarBase :" << act_name;
new (&self) imperative::VarBase(act_name);
self.SetPersistable(persistable);
self.SetType(type);
...
...@@ -1082,6 +1082,91 @@ void PsroiPoolInferMeta(const MetaTensor& x, ...@@ -1082,6 +1082,91 @@ void PsroiPoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
MetaTensor* out,
MetaTensor* dropout_state,
std::vector<MetaTensor*> state,
MetaTensor* reserve) {
auto in_dims = x.dims();
PADDLE_ENFORCE_EQ(
in_dims.size(),
3,
phi::errors::InvalidArgument("The rank of Input in RNN must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
if (sequence_length) {
auto seq_dims = sequence_length->dims();
PADDLE_ENFORCE_EQ(
in_dims[1],
seq_dims[0],
phi::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1],
seq_dims[0]));
}
PADDLE_ENFORCE_EQ(pre_state[0]->dims().size(),
3,
phi::errors::InvalidArgument(
"The rank of PreState in RNN must be 3. But "
"the received rank is %d.",
pre_state[0]->dims().size()));
size_t i = 0;
for (; i < pre_state.size(); ++i) {
PADDLE_ENFORCE_EQ(
in_dims[1],
pre_state[i]->dims()[1],
phi::errors::InvalidArgument(
"The second dimension size (representing for batch size) of "
"Input and PreState should be equal. But received %d and %d.",
in_dims[1],
pre_state[i]->dims()[1]));
PADDLE_ENFORCE_EQ(
pre_state[0]->dims(),
pre_state[i]->dims(),
phi::errors::InvalidArgument(
"The dims of all tensors in PreState should be same. But "
"received PreState[0] is %s and PreState[%d] is %s.",
pre_state[0]->dims(),
i,
pre_state[i]->dims()));
}
size_t num_state = mode == "LSTM" ? 2 : 1;
PADDLE_ENFORCE_EQ(i,
num_state,
phi::errors::InvalidArgument(
"The number of tensors in PreState of %s should be %d, "
"but received %d.",
mode,
num_state,
i));
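// Out keeps the input's [seq_len, batch_size] leading dims; only the feature
// dim changes, to hidden_size (doubled for a bidirectional RNN).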
auto out_dims = in_dims;
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
int state_num = pre_state.size();
for (int i = 0; i < state_num; ++i) {
state[i]->set_dims(pre_state[i]->dims());
state[i]->set_dtype(x.dtype());
}
}
void WarpctcInferMeta(const MetaTensor& logits, void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label, const MetaTensor& label,
const paddle::optional<const MetaTensor&> logits_length, const paddle::optional<const MetaTensor&> logits_length,
......
...@@ -214,6 +214,23 @@ void PsroiPoolInferMeta(const MetaTensor& x, ...@@ -214,6 +214,23 @@ void PsroiPoolInferMeta(const MetaTensor& x,
float spatial_scale, float spatial_scale,
MetaTensor* out); MetaTensor* out);
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
MetaTensor* out,
MetaTensor* dropout_state,
std::vector<MetaTensor*> state,
MetaTensor* reserve);
void WarpctcInferMeta(const MetaTensor& logits, void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label, const MetaTensor& label,
const paddle::optional<const MetaTensor&> logits_length, const paddle::optional<const MetaTensor&> logits_length,
......
...@@ -32,7 +32,7 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma ...@@ -32,7 +32,7 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel warpctc_kernel warpctc_grad_kernel) triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper) kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel) kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
...@@ -58,6 +58,8 @@ kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_ ...@@ -58,6 +58,8 @@ kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce) kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse) kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale) kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale) kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale)
...@@ -73,5 +75,5 @@ copy_if_different(${kernel_declare_file} ${kernel_declare_file_final}) ...@@ -73,5 +75,5 @@ copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
# For strings kernels # For strings kernels
add_subdirectory(strings) add_subdirectory(strings)
# 5. kernel autotune # 5. kernel autotune
add_subdirectory(autotune) add_subdirectory(autotune)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/utils.h"
namespace phi {
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \
inline bool is_##MODE_NAME(const std::string& mode) { \
return mode == #MODE_STR; \
}
DEFINE_MODE_DETECTOR(lstm, LSTM);
DEFINE_MODE_DETECTOR(gru, GRU);
DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU);
DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH);
inline void SwapPoniter(DenseTensor** a, DenseTensor** b) {
DenseTensor* c = *a;
*a = *b;
*b = c;
}
template <typename T>
void CreateMaskMatrix(const CPUContext& dev_ctx,
const DenseTensor* sequence_length,
DenseTensor* mask_matrix,
const bool& is_reverse,
int* min_seq_len) {
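// Build a [time_step, batch_size] mask from the per-sample sequence lengths:
// entry (t, b) is 1 while step t lies inside sample b's valid length and 0 in
// the padded region (leading padding when is_reverse is true, trailing
// otherwise). For example, assuming time_step = 4 and seq_len_vec = {4, 2},
// the forward mask columns would be [1,1,1,1] and [1,1,0,0].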
const auto& seq_len_vec =
paddle::operators::GetDataFromTensor<int>(sequence_length);
const int table_width = mask_matrix->dims()[0];
DenseTensor temp =
Empty<T>(dev_ctx, {mask_matrix->dims()[1], mask_matrix->dims()[0]});
T* data_temp = temp.data<T>();
std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast<T>(1.0));
*min_seq_len = table_width;
for (unsigned int i = 0; i < seq_len_vec.size(); i++) {
// reset the mask matrix
*min_seq_len = std::min(seq_len_vec[i], *min_seq_len);
if (seq_len_vec[i] == table_width) {
continue;
}
if (is_reverse) {
std::fill(data_temp + i * table_width,
data_temp + (i + 1) * table_width - seq_len_vec[i],
static_cast<T>(0));
} else {
std::fill(data_temp + i * table_width + seq_len_vec[i],
data_temp + (i + 1) * table_width,
static_cast<T>(0));
}
}
dev_ctx.Alloc<T>(mask_matrix);
std::vector<int> trans_vec;
trans_vec.emplace_back(1);
trans_vec.emplace_back(0);
funcs::TransCompute<CPUContext, T>(2, dev_ctx, temp, mask_matrix, trans_vec);
}
template <typename TensorType>
void ResetParameterVector(const std::vector<TensorType>& raw_params_vec,
int num_layers,
int gate_num,
bool is_bidirec,
std::vector<std::vector<DenseTensor>>* params_vec) {
// the parameter raw sequence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
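// For example, assuming num_layers = 1 and is_bidirec = true, raw indices
// {0,1,2,3} = {FWhi, FWhh, BWhi, BWhh} and {4,5,6,7} = {FBhi, FBhh, BBhi,
// BBhh}, so the reordered layer list is raw[{0,1,4,5,2,3,6,7}], i.e.
// [FWhi, FWhh, FBhi, FBhh, BWhi, BWhh, BBhi, BBhh].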
const int& direction_num = is_bidirec ? 2 : 1;
const int& layer_weight_size = 4 * direction_num;
const int& all_weight_size = num_layers * layer_weight_size;
const int& bias_start_idx = all_weight_size / 2;
for (int i = 0; i < num_layers; i++) {
std::vector<DenseTensor> tensor_list;
tensor_list.reserve(layer_weight_size);
for (int j = 0; j < layer_weight_size; j++) {
DenseTensor tensor_holder;
tensor_list.emplace_back(tensor_holder);
}
for (int j = 0; j < layer_weight_size; j++) {
int k = j % 4;
const int& section = j / 4;
int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
if (k >= 2) {
tensor_idx += bias_start_idx;
}
tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]);
}
params_vec->emplace_back(tensor_list);
}
}
template <typename T>
void DropoutHelper(const CPUContext& dev_ctx,
DenseTensor* x,
DenseTensor* y,
const DenseTensor* mask,
float dropout_prob) {
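// Inverted dropout: surviving elements are scaled by 1 / (1 - dropout_prob)
// here, so no rescaling is needed at inference time; dropout_prob == 1.0 is
// handled separately to avoid a division by zero.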
auto& place = *dev_ctx.eigen_device();
auto dropout_mask = EigenVector<uint8_t>::Flatten(*mask);
auto in = EigenVector<T>::Flatten(*x);
auto out = EigenVector<T>::Flatten(*y);
if (dropout_prob == 1.0f) {
out.device(place) = static_cast<T>(0) * in;
} else {
out.device(place) =
in * dropout_mask.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
}
template <typename T>
void DropoutCpuFunctionInplace(const CPUContext& dev_ctx,
DenseTensor* x,
DenseTensor* y,
DenseTensor* mask,
const float& dropout_prob,
const int& seed_number,
bool is_test,
bool* is_has_reset) {
if (is_test) {
return;
}
size_t size = phi::product(x->dims());
auto* mask_data = mask->data<uint8_t>();
if (!(*is_has_reset)) {
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
std::fill(mask_data, mask_data + size, static_cast<uint8_t>(0));
} else {
auto engine = paddle::framework::GetCPURandomEngine(seed_number);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
} else {
mask_data[i] = 1;
}
}
}
*is_has_reset = true;
}
DropoutHelper<T>(dev_ctx, x, y, mask, dropout_prob);
}
template <typename Context, typename TensorType>
void SplitReserveData(const Context& dev_ctx,
int direction_num,
int time_step,
int batch_size,
int hidden_size,
int gate_num,
int num_layers,
const std::string& mode,
TensorType* reserve_data,
DenseTensor* gate_data,
DenseTensor* cell_data,
DenseTensor* cell_act_data,
DenseTensor* hidden_data) {
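// Layout of reserve_data along dim 0, with G = gate_num and L = num_layers:
//   LSTM:       [0, G*L) gate data, [G*L, (G+1)*L) cell data,
//               [(G+1)*L, (G+2)*L) activated cell data, then L-1 slices of
//               inter-layer hidden data;
//   GRU:        [0, G*L) gate data, [G*L, (G+1)*L) reset-output data, then
//               L-1 slices of hidden data;
//   simple RNN: [0, G*L) gate data, then L-1 slices of hidden data.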
int gate_data_idx = gate_num * num_layers;
int cell_data_idx = (gate_num + 1) * num_layers;
int cell_act_data_idx = (gate_num + 2) * num_layers;
// simple rnn
int hidden_data_start_idx = gate_data_idx;
*gate_data = reserve_data->Slice(0, gate_data_idx);
if (is_lstm(mode)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
*cell_act_data = reserve_data->Slice(cell_data_idx, cell_act_data_idx);
hidden_data_start_idx = cell_act_data_idx;
} else if (is_gru(mode)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
hidden_data_start_idx = cell_data_idx;
}
int hidden_data_idx = hidden_data_start_idx + (num_layers - 1);
if (num_layers > 1) {
*hidden_data = reserve_data->Slice(hidden_data_start_idx, hidden_data_idx);
}
}
template <typename CellType, typename T, typename Context>
void AllocateReserveData(const Context& dev_ctx,
bool is_bidirec,
int num_layers,
int gate_num,
int hidden_size,
const std::string& mode,
DenseTensor* reserve_data,
DenseTensor* gate_data,
DenseTensor* cell_data,
DenseTensor* cell_act_data,
DenseTensor* hidden_data,
const DenseTensor* input) {
int direction_num = is_bidirec ? 2 : 1;
int time_step = input->dims()[0];
int batch_size = input->dims()[1];
int block_size = direction_num * time_step * batch_size * hidden_size;
int hidden_data_idx = (num_layers - 1);
if (is_lstm(mode)) {
hidden_data_idx += (gate_num + 2) * num_layers;
} else if (is_gru(mode)) {
hidden_data_idx += (gate_num + 1) * num_layers;
} else {
hidden_data_idx += gate_num * num_layers;
}
reserve_data->Resize({hidden_data_idx, block_size});
dev_ctx.template Alloc<T>(reserve_data);
SplitReserveData(dev_ctx,
direction_num,
time_step,
batch_size,
hidden_size,
gate_num,
num_layers,
mode,
reserve_data,
gate_data,
cell_data,
cell_act_data,
hidden_data);
}
inline std::vector<DenseTensor> Unbind(const DenseTensor& in) {
int64_t size = in.dims()[0];
std::vector<DenseTensor> tensors;
tensors.reserve(size);
for (int64_t i = 0; i < size; ++i) {
tensors.emplace_back(in.Slice(i, i + 1));
}
return tensors;
}
template <typename CellType,
template <typename, typename> class LayerT,
template <typename, typename> class SingleLayerT,
template <typename, typename> class BidirLayerT,
typename T,
typename Context>
void RnnFunc(const Context& dev_ctx,
const DenseTensor* input,
const std::vector<const DenseTensor*>& weight_list,
const DenseTensor* init_h,
const DenseTensor* init_c,
const DenseTensor* sequence_length,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* output,
DenseTensor* dropout_mask,
int num_layers,
int gate_num,
int input_size,
int hidden_size,
bool is_bidirec,
const std::string& cell_type,
float dropout_prob,
bool is_test,
int seed,
DenseTensor* reserve_data) {
int direction_num = is_bidirec ? 2 : 1;
const auto& init_h_dims = init_h->dims();
PADDLE_ENFORCE_EQ(init_h_dims[0],
num_layers * direction_num,
phi::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of init hidden, but received"
" num_layers:%d, dim:%d",
num_layers,
init_h_dims[0]));
if (is_lstm(cell_type)) {
const auto& init_c_dims = init_c->dims();
PADDLE_ENFORCE_EQ(init_c_dims[0],
num_layers * direction_num,
phi::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of cell state hidden, but received"
" num_layers:%d, dim:%d",
num_layers,
init_h_dims[0]));
}
CellType cell;
std::vector<std::vector<DenseTensor>> parameter_lists;
parameter_lists.reserve(num_layers);
ResetParameterVector(
weight_list, num_layers, gate_num, is_bidirec, &parameter_lists);
DenseTensor gate_data, cell_data, cell_act_data, hidden_data;
if (!is_test) {
AllocateReserveData<CellType, T, Context>(dev_ctx,
is_bidirec,
num_layers,
gate_num,
hidden_size,
cell_type,
reserve_data,
&gate_data,
&cell_data,
&cell_act_data,
&hidden_data,
input);
gate_data.Resize({num_layers, gate_data.numel() / num_layers});
cell_data.Resize({num_layers, cell_data.numel() / num_layers});
cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers});
if (num_layers > 1) {
hidden_data.Resize(
{num_layers - 1, hidden_data.numel() / (num_layers - 1)});
}
}
DenseTensor* input_holder;
DenseTensor* output_holder = output;
bool has_allocate_mem = false;
auto init_h_unbind = Unbind(*init_h);
auto last_h_unbind = Unbind(*last_h);
std::vector<DenseTensor> init_c_unbind, last_c_unbind;
if (is_lstm(cell_type)) {
init_c_unbind = Unbind(*init_c);
last_c_unbind = Unbind(*last_c);
}
DenseTensor curr_gate_data, curr_cell_data, curr_cell_act_data;
DenseTensor curr_hidden_data, prev_hidden_data;
DenseTensor temp;
bool has_dropout_reset = false;
for (int i = 0; i < num_layers; i++) {
if (!is_test) {
if (cell_data.numel() > 0) /** for lstm, gru **/ {
curr_cell_data = cell_data.Slice(i, i + 1);
}
if (cell_act_data.numel() > 0) /*for lstm*/ {
curr_cell_act_data = cell_act_data.Slice(i, i + 1);
}
curr_gate_data = gate_data.Slice(i, i + 1);
output_holder = output;
if (i < num_layers - 1 && num_layers > 1) {
curr_hidden_data = hidden_data.Slice(i, i + 1);
curr_hidden_data.Resize(output->dims());
output_holder = &curr_hidden_data;
}
}
if (i > 0) {
if (!has_allocate_mem) {
temp.Resize(output->dims());
dev_ctx.template Alloc<T>(&temp);
input_holder = &temp;
has_allocate_mem = true;
}
if (!is_test) {
prev_hidden_data = hidden_data.Slice(i - 1, i);
input_holder->Resize(output->dims());
if (dropout_prob != 0) {
DropoutCpuFunctionInplace<T>(dev_ctx,
&prev_hidden_data,
input_holder,
dropout_mask,
dropout_prob,
seed,
is_test,
&has_dropout_reset);
} else {
input_holder = &prev_hidden_data;
input_holder->Resize(output->dims());
}
} else {
SwapPoniter(&output_holder, &input_holder);
}
}
const DenseTensor* input_temp_holder = input;
if (i > 0) {
input_temp_holder = input_holder;
}
LayerT<T, CellType>* layer;
SingleLayerT<T, CellType> slayer(cell);
BidirLayerT<T, CellType> blayer(cell);
if (is_bidirec) {
layer = &blayer;
} else {
layer = &slayer;
}
(*layer)(dev_ctx,
input_temp_holder,
parameter_lists[i],
init_h_unbind,
init_c_unbind,
sequence_length,
last_h_unbind,
last_c_unbind,
output_holder,
i,
gate_num,
&curr_gate_data,
&curr_cell_data,
&curr_cell_act_data,
cell_type,
is_test);
}
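// In the inference path the input/output holders are swapped once per layer,
// so with an even number of layers the last layer's result ends up in the
// temporary buffer; copy it back into the real output tensor here.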
if (num_layers % 2 == 0) {
Copy(dev_ctx, *output_holder, dev_ctx.GetPlace(), false, output);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/rnn_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/cpu/rnn_functor.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
void BackupTensor(const CPUContext& dev_ctx,
DenseTensor* dst,
DenseTensor* src) {
dst->Resize(src->dims());
dev_ctx.Alloc<T>(dst);
Copy(dev_ctx, *src, dev_ctx.GetPlace(), false, dst);
}
template <typename T>
void CreateLstmValue(phi::funcs::LstmMetaValue<T>* lstm_value) {
lstm_value->check_ig = nullptr;
lstm_value->check_fg = nullptr;
lstm_value->check_og = nullptr;
}
template <typename T>
void CreateLstmGrad(phi::funcs::LstmMetaGrad<T>* lstm_grad) {
lstm_grad->check_ig_grad = nullptr;
lstm_grad->check_fg_grad = nullptr;
lstm_grad->check_og_grad = nullptr;
}
template <typename T>
struct GradCell {
virtual ~GradCell() {}
virtual void operator()(const CPUContext& dev_ctx,
DenseTensor* gate_tensor,
DenseTensor* state_tensor,
DenseTensor* act_state_tensor,
DenseTensor* hidden_tensor,
const DenseTensor* weight_hh,
DenseTensor* pre_hidden,
DenseTensor* pre_state,
DenseTensor* grad_hidden,
DenseTensor* grad_state,
DenseTensor* grad_gate,
DenseTensor* grad_weight_hh,
DenseTensor* grad_pre_hidden,
DenseTensor* grad_pre_state,
DenseTensor* grad_bias_hh,
const DenseTensor& mask_tensor,
bool has_sequence_length) const {}
void postprocess_pre_hidden_grad(const CPUContext& dev_ctx,
DenseTensor* grad_pre_hidden,
DenseTensor* grad_pre_hidden_bak,
DenseTensor* grad_pre_state,
DenseTensor* grad_pre_state_bak,
const DenseTensor& mask_tensor,
bool has_sequence_length) const {
if (has_sequence_length) {
auto& place = *dev_ctx.eigen_device();
auto mask = EigenMatrix<T>::From(
mask_tensor, phi::make_ddim({mask_tensor.dims()[1], 1}));
auto mask_broadcast =
mask.broadcast(Eigen::DSizes<int, 2>(1, grad_pre_hidden->dims()[2]));
auto pre_hidden_grad = EigenMatrix<T>::Reshape(
*grad_pre_hidden, grad_pre_hidden->dims().size() - 1);
auto pre_hidden_bak_grad = EigenMatrix<T>::Reshape(
*grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1);
pre_hidden_grad.device(place) =
(1 - mask_broadcast) * pre_hidden_bak_grad +
pre_hidden_grad * mask_broadcast;
if (grad_pre_state) {
auto pre_state_grad = EigenMatrix<T>::Reshape(
*grad_pre_state, grad_pre_state->dims().size() - 1);
auto pre_state_bak_grad = EigenMatrix<T>::Reshape(
*grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1);
pre_state_grad.device(place) =
(1 - mask_broadcast) * pre_state_bak_grad +
pre_state_grad * mask_broadcast;
}
}
}
virtual void update_pre_hidden_grad(const CPUContext& dev_ctx,
DenseTensor* grad_gate,
const DenseTensor* weight_hh,
DenseTensor* grad_pre_hidden,
DenseTensor* grad_pre_hidden_bak,
DenseTensor* grad_pre_state,
DenseTensor* grad_pre_state_bak,
const DenseTensor& mask_tensor,
bool has_sequence_length) const {
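// grad_pre_hidden = grad_gate * weight_hh, computed as a single GEMM over all
// frames; it is then blended with the backed-up gradient on padded steps by
// postprocess_pre_hidden_grad when a sequence mask is present.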
auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
DenseTensor* grad_gate_tmp = grad_gate;
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(grad_gate_tmp->dims(), 0, false);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, false);
blas.MatMul(*grad_gate_tmp,
mat_dim_a,
*weight_hh,
mat_dim_b,
static_cast<T>(1.0),
grad_pre_hidden,
0);
postprocess_pre_hidden_grad(dev_ctx,
grad_pre_hidden,
grad_pre_hidden_bak,
grad_pre_state,
grad_pre_state_bak,
mask_tensor,
has_sequence_length);
}
virtual void update_weight_hh_grad(const CPUContext& dev_ctx,
DenseTensor* grad_gate,
DenseTensor* pre_hidden,
DenseTensor* grad_weight_hh) const {
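// grad_weight_hh += grad_gate^T * pre_hidden; beta is 1.0 so the weight
// gradient accumulates across time steps.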
auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
auto mat_dim_c =
phi::funcs::CreateMatrixDescriptor(grad_gate->dims(), 0, true);
mat_dim_c.height_ *= mat_dim_c.batch_size_;
mat_dim_c.batch_size_ = 0;
auto mat_dim_d =
phi::funcs::CreateMatrixDescriptor(pre_hidden->dims(), 0, false);
mat_dim_d.height_ *= mat_dim_d.batch_size_;
mat_dim_d.batch_size_ = 0;
blas.MatMul(*grad_gate,
mat_dim_c,
*pre_hidden,
mat_dim_d,
static_cast<T>(1.0),
grad_weight_hh,
static_cast<T>(1.0));
}
};
template <typename T, template <typename> class EigenActivationBackwardFunctor>
struct SimpleRNNGradCell : GradCell<T> {
void operator()(const CPUContext& dev_ctx,
DenseTensor* gate_tensor,
DenseTensor* state_tensor,
DenseTensor* act_state_tensor,
DenseTensor* hidden_tensor,
const DenseTensor* weight_hh,
DenseTensor* pre_hidden,
DenseTensor* pre_state,
DenseTensor* grad_hidden,
DenseTensor* grad_state,
DenseTensor* grad_gate,
DenseTensor* grad_weight_hh,
DenseTensor* grad_pre_hidden,
DenseTensor* grad_pre_state,
DenseTensor* grad_bias_hh,
const DenseTensor& mask_tensor,
bool has_sequence_length) const override {
DenseTensor grad_pre_hidden_bak;
if (has_sequence_length) {
BackupTensor<T>(dev_ctx, &grad_pre_hidden_bak, grad_pre_hidden);
}
// h = act(z)
// update dz
auto dz = EigenVector<T>::Flatten(
GET_DATA_SAFELY(grad_gate, "Output", "dz", "Grad"));
auto dh = EigenVector<T>::Flatten(
GET_DATA_SAFELY(grad_hidden, "Input", "dh", "Grad"));
auto h = EigenVector<T>::Flatten(
GET_DATA_SAFELY(hidden_tensor, "Input", "h", "Value"));
// not used by the backward computation, but required by the functor's signature
auto z = EigenVector<T>::Flatten(
GET_DATA_SAFELY(gate_tensor, "Input", "z", "Value"));
auto* place = dev_ctx.eigen_device();
EigenActivationBackwardFunctor<T> functor;
functor(*place, z, h, dh, dz);
// update grad_weight_hh, grad_pre_hidden
this->update_pre_hidden_grad(dev_ctx,
grad_gate,
weight_hh,
grad_pre_hidden,
&grad_pre_hidden_bak,
nullptr,
nullptr,
mask_tensor,
has_sequence_length);
this->update_weight_hh_grad(dev_ctx, grad_gate, pre_hidden, grad_weight_hh);
}
};
template <typename T>
struct GRUGradCell : GradCell<T> {
void operator()(const CPUContext& dev_ctx,
DenseTensor* gate_tensor,
DenseTensor* state_tensor,
DenseTensor* act_state_tensor,
DenseTensor* hidden_tensor,
const DenseTensor* weight_hh,
DenseTensor* pre_hidden,
DenseTensor* pre_state,
DenseTensor* grad_hidden,
DenseTensor* grad_state,
DenseTensor* grad_gate,
DenseTensor* grad_weight_hh,
DenseTensor* grad_pre_hidden,
DenseTensor* grad_pre_state,
DenseTensor* grad_bias_hh,
const DenseTensor& mask_tensor,
bool has_sequence_length) const override {
size_t frame_size = pre_hidden->dims()[2];
size_t batch_size = pre_hidden->dims()[1];
DenseTensor grad_pre_hidden_bak;
if (has_sequence_length) {
BackupTensor<T>(dev_ctx, &grad_pre_hidden_bak, grad_pre_hidden);
}
// zero the gradient of pre_hidden
phi::funcs::SetConstant<CPUContext, T> zero;
zero(dev_ctx, grad_pre_hidden, static_cast<T>(0.0));
phi::funcs::GRUMetaValue<T> gru_value;
phi::funcs::GRUMetaGrad<T> gru_grad;
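// weight_hh layout assumed here (matching the forward GRU computation): the
// first 2 * frame_size * frame_size elements are the update/reset gate
// weights, followed by frame_size * frame_size candidate (state) weights.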
gru_value.gate_value = gate_tensor->data<T>();
gru_value.prev_out_value = pre_hidden->data<T>();
gru_value.reset_output_value = state_tensor->data<T>();
gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size;
gru_value.gate_weight = weight_hh->data<T>();
gru_grad.gate_grad = grad_gate->data<T>();
gru_grad.reset_output_grad = grad_state->data<T>();
gru_grad.prev_out_grad = grad_pre_hidden->data<T>();
gru_grad.output_grad = grad_hidden->data<T>();
gru_grad.gate_weight_grad = grad_weight_hh->data<T>();
gru_grad.state_weight_grad =
grad_weight_hh->data<T>() + 2 * frame_size * frame_size;
gru_grad.bias_hh_grad = grad_bias_hh->data<T>();
auto act_gate = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto act_node = phi::funcs::detail::GetActivationType("tanh_v2");
phi::funcs::GRUUnitGradFunctorV2<CPUContext, T>::compute(dev_ctx,
gru_value,
gru_grad,
frame_size,
batch_size,
act_node,
act_gate);
this->postprocess_pre_hidden_grad(dev_ctx,
grad_pre_hidden,
&grad_pre_hidden_bak,
nullptr,
nullptr,
mask_tensor,
has_sequence_length);
}
};
template <typename T>
struct LSTMGradCell : GradCell<T> {
void operator()(const CPUContext& dev_ctx,
DenseTensor* gate_tensor,
DenseTensor* state_tensor,
DenseTensor* act_state_tensor,
DenseTensor* hidden_tensor,
const DenseTensor* weight_hh,
DenseTensor* pre_hidden,
DenseTensor* pre_state,
DenseTensor* grad_hidden,
DenseTensor* grad_state,
DenseTensor* grad_gate,
DenseTensor* grad_weight_hh,
DenseTensor* grad_pre_hidden,
DenseTensor* grad_pre_state,
DenseTensor* grad_bias_hh,
const DenseTensor& mask_tensor,
bool has_sequence_length) const override {
size_t frame_size = state_tensor->dims()[2];
size_t batch_size = state_tensor->dims()[1];
DenseTensor grad_pre_hidden_bak;
DenseTensor grad_pre_state_bak;
if (has_sequence_length) {
BackupTensor<T>(dev_ctx, &grad_pre_hidden_bak, grad_pre_hidden);
BackupTensor<T>(dev_ctx, &grad_pre_state_bak, grad_pre_state);
}
phi::funcs::LstmMetaValue<T> lstm_value;
phi::funcs::LstmMetaGrad<T> lstm_grad;
CreateLstmValue(&lstm_value);
CreateLstmGrad(&lstm_grad);
lstm_value.gate_value = gate_tensor->data<T>();
lstm_value.state_value = state_tensor->data<T>();
lstm_value.state_active_value = act_state_tensor->data<T>();
lstm_value.prev_state_value = pre_state->data<T>();
lstm_grad.state_grad = grad_state->data<T>();
lstm_grad.gate_grad = grad_gate->data<T>();
lstm_grad.output_grad = grad_hidden->data<T>();
lstm_grad.prev_state_grad = grad_pre_state->data<T>();
lstm_value.output_value = nullptr;
lstm_grad.state_active_grad = nullptr;
auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto state_act = phi::funcs::detail::GetActivationType("tanh_v2");
auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
T cell_clip = 0.0;
phi::funcs::LstmUnitGradFunctor<CPUContext, T>::compute(dev_ctx,
lstm_value,
lstm_grad,
frame_size,
batch_size,
cell_clip,
gate_act,
state_act,
cand_act,
false);
this->update_pre_hidden_grad(dev_ctx,
grad_gate,
weight_hh,
grad_pre_hidden,
&grad_pre_hidden_bak,
grad_pre_state,
&grad_pre_state_bak,
mask_tensor,
has_sequence_length);
this->update_weight_hh_grad(dev_ctx, grad_gate, pre_hidden, grad_weight_hh);
}
};
template <typename T, typename GradCellType>
struct GradLayer {
explicit GradLayer(const GradCellType& cell) : cell_(cell) {}
virtual ~GradLayer() {}
void run_rnn_grad_function(
const CPUContext& dev_ctx,
const DenseTensor* input,
DenseTensor* input_grad,
const DenseTensor* sequence_length,
std::vector<DenseTensor>* init_h_unbind,
std::vector<DenseTensor>* init_c_unbind,
std::vector<DenseTensor>* init_h_grad_unbind,
std::vector<DenseTensor>* init_c_grad_unbind,
DenseTensor* layer_grad_gate_tensor,
std::vector<DenseTensor>* layer_gate_tensor_unbind,
std::vector<DenseTensor>* layer_grad_gate_tensor_unbind,
std::vector<DenseTensor>* layer_state_tensor_unbind,
std::vector<DenseTensor>* layer_act_state_tensor_unbind,
std::vector<DenseTensor>* output_tensor_unbind,
std::vector<DenseTensor>* output_grad_tensor_unbind,
const std::vector<DenseTensor>& last_h_grad_unbind,
const std::vector<DenseTensor>& last_c_grad_unbind,
const std::vector<std::vector<DenseTensor>>& parameter_lists,
std::vector<std::vector<DenseTensor>>* weight_list_grad,
int layer_idx,
int time_step,
bool has_sequence_length,
bool is_bidirec,
bool is_reverse,
const std::string& mode) {
int direction_num = is_bidirec ? 2 : 1;
int current_reverse_idx = is_reverse ? 1 : 0;
int current_layer_idx = direction_num * layer_idx + current_reverse_idx;
int begin_idx = 0;
if (is_reverse) {
begin_idx = time_step;
}
DenseTensor mask_matrix;
std::vector<DenseTensor> mask_tensor_list;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(phi::make_ddim({time_step, input->dims()[1]}));
CreateMaskMatrix<T>(
dev_ctx, sequence_length, &mask_matrix, is_reverse, &mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
// copy the last_h and last_c gradients so that their pointers can be swapped
DenseTensor a, b;
DenseTensor* dynamic_grad_last_h = &a;
DenseTensor* dynamic_grad_last_c = &b;
dynamic_grad_last_h->Resize(last_h_grad_unbind[current_layer_idx].dims());
dev_ctx.Alloc<T>(dynamic_grad_last_h);
Copy(dev_ctx,
last_h_grad_unbind[current_layer_idx],
dev_ctx.GetPlace(),
false,
dynamic_grad_last_h);
if (last_c_grad_unbind.size() > 0) {
dynamic_grad_last_c->Resize(last_c_grad_unbind[current_layer_idx].dims());
dev_ctx.Alloc<T>(dynamic_grad_last_c);
Copy(dev_ctx,
last_c_grad_unbind[current_layer_idx],
dev_ctx.GetPlace(),
false,
dynamic_grad_last_c);
} else {
dynamic_grad_last_c = nullptr;
}
DenseTensor c, d;
DenseTensor* dynamic_grad_pre_h = &c;
DenseTensor* dynamic_grad_pre_c = &d;
phi::funcs::SetConstant<CPUContext, T> zero;
if (init_h_grad_unbind->size() > 0) {
dynamic_grad_pre_h->ShareDataWith(
(*init_h_grad_unbind)[current_layer_idx]);
} else {
dynamic_grad_pre_h->Resize(dynamic_grad_last_h->dims());
dev_ctx.Alloc<T>(dynamic_grad_pre_h);
zero(dev_ctx, dynamic_grad_pre_h, static_cast<T>(0.0));
}
if (init_c_grad_unbind->size() > 0) {
dynamic_grad_pre_c->ShareDataWith(
(*init_c_grad_unbind)[current_layer_idx]);
} else {
if (is_lstm(mode) || is_gru(mode)) {
dynamic_grad_pre_c->Resize(dynamic_grad_last_h->dims());
dev_ctx.Alloc<T>(dynamic_grad_pre_c);
if (is_gru(mode)) {
dynamic_grad_last_c = dynamic_grad_pre_c;
}
} else {
dynamic_grad_pre_c = nullptr;
}
}
if (is_reverse) {
// the input, output, input_grad and output_grad must be reversed,
// and the gate and grad_gate tensors must be reversed as well
std::reverse(layer_gate_tensor_unbind->begin(),
layer_gate_tensor_unbind->end());
std::reverse(layer_grad_gate_tensor_unbind->begin(),
layer_grad_gate_tensor_unbind->end());
/*
if (has_sequence_length) {
std::reverse(mask_tensor_list.begin(), mask_tensor_list.end());
}*/
std::reverse(output_tensor_unbind->begin(), output_tensor_unbind->end());
std::reverse(output_grad_tensor_unbind->begin(),
output_grad_tensor_unbind->end());
}
DenseTensor* weight_grad =
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 1]);
dev_ctx.Alloc<T>(weight_grad);
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
DenseTensor* pre_hidden = nullptr;
DenseTensor* pre_state = nullptr;
DenseTensor* hidden = nullptr;
if (is_gru(mode)) {
zero(dev_ctx,
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
static_cast<T>(0.0));
}
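// Walk backwards through time. After each step the "last" and "pre" gradient
// buffers are swapped, so the gradient w.r.t. this step's previous hidden
// state becomes the incoming hidden-state gradient of the earlier step.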
for (int i = time_step - 1; i >= 0; --i) {
if (has_sequence_length) {
this->mask_preprocess(dev_ctx,
&(*output_grad_tensor_unbind)[i],
dynamic_grad_last_h,
dynamic_grad_last_c,
dynamic_grad_pre_h,
dynamic_grad_pre_c,
mask_tensor_list[i],
mode);
} else {
this->preprocess(
dev_ctx, &(*output_grad_tensor_unbind)[i], dynamic_grad_last_h);
}
hidden = &(*output_tensor_unbind)[i];
if (i == 0) {
pre_hidden = &(*init_h_unbind)[current_layer_idx];
if (init_c_unbind->size() > 0) {
pre_state = &(*init_c_unbind)[current_layer_idx];
}
} else {
pre_hidden = &(*output_tensor_unbind)[i - 1];
if (layer_state_tensor_unbind->size() > 0) {
pre_state = &(*layer_state_tensor_unbind)[begin_idx + i - 1];
}
}
this->cell_(
dev_ctx,
&(*layer_gate_tensor_unbind)[i],
&(*layer_state_tensor_unbind)[begin_idx + i],
&(*layer_act_state_tensor_unbind)[begin_idx + i],
hidden,
&(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]),
pre_hidden,
pre_state,
dynamic_grad_last_h,
dynamic_grad_last_c,
&(*layer_grad_gate_tensor_unbind)[i],
weight_grad,
dynamic_grad_pre_h,
dynamic_grad_pre_c,
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
mask_tensor_list[i],
has_sequence_length);
SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h);
SwapPoniter(&dynamic_grad_last_c, &dynamic_grad_pre_c);
}
// postprocess the gradients for w_hi, X, bias_hi, bias_hh
this->postprocess(dev_ctx,
*layer_grad_gate_tensor,
*input,
input_grad,
parameter_lists[layer_idx],
&((*weight_list_grad)[layer_idx]),
is_reverse,
mode);
// copy the gradient back to init_h_grad / init_c_grad; after an even number
// of time-step swaps the final gradient lives in the local buffer rather than
// in the tensor sharing init_*_grad's memory, so an explicit copy is needed
if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) {
Copy(dev_ctx,
*dynamic_grad_last_h,
dev_ctx.GetPlace(),
false,
&((*init_h_grad_unbind)[current_layer_idx]));
}
if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) {
Copy(dev_ctx,
*dynamic_grad_last_c,
dev_ctx.GetPlace(),
false,
&((*init_c_grad_unbind)[current_layer_idx]));
}
}
virtual void operator()(
const CPUContext& dev_ctx,
const DenseTensor* input,
const DenseTensor* output,
const std::vector<DenseTensor>& init_h_unbind,
const std::vector<DenseTensor>& init_c_unbind,
const std::vector<DenseTensor>& last_h_grad_unbind,
const std::vector<DenseTensor>& last_c_grad_unbind,
const std::vector<DenseTensor>& gate_tensor_unbind,
const std::vector<DenseTensor>& state_tensor_unbind,
const std::vector<DenseTensor>& act_state_tensor_unbind,
const DenseTensor* output_grad,
const std::vector<std::vector<DenseTensor>>& parameter_lists,
const DenseTensor* sequence_length,
DenseTensor* input_grad,
std::vector<DenseTensor>* init_h_grad_unbind,
std::vector<DenseTensor>* init_c_grad_unbind,
const std::vector<std::vector<DenseTensor>>& weight_list_grad,
int layer_idx,
bool is_bidirec,
int hidden_size,
const std::string& mode,
int gate_num) {}
void preprocess(const CPUContext& dev_ctx,
const DenseTensor* grad_output,
DenseTensor* grad_last_h) {
auto& place = *dev_ctx.eigen_device();
auto output_grad =
EigenMatrix<T>::Reshape(*grad_output, grad_output->dims().size() - 1);
auto last_h_grad =
EigenMatrix<T>::Reshape(*grad_last_h, grad_last_h->dims().size() - 1);
// the output gradient contributes to the gradient of last_h
last_h_grad.device(place) = last_h_grad + output_grad;
}
void mask_preprocess(const CPUContext& dev_ctx,
const DenseTensor* grad_output,
DenseTensor* grad_last_h,
DenseTensor* grad_last_c,
DenseTensor* grad_pre_h,
DenseTensor* grad_pre_c,
const DenseTensor& mask_tensor,
const std::string& mode) {
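// Where the mask is 0 (padded steps) the accumulated hidden-state gradient is
// passed straight through to grad_pre_h / grad_pre_c and the corresponding
// entries of grad_last_h / grad_last_c are cleared; where it is 1 the output
// gradient is added to grad_last_h as usual.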
auto& place = *dev_ctx.eigen_device();
auto mask = EigenMatrix<T>::From(
mask_tensor, phi::make_ddim({mask_tensor.dims()[1], 1}));
auto mask_broadcast =
mask.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));
auto last_h_grad =
EigenMatrix<T>::Reshape(*grad_last_h, grad_last_h->dims().size() - 1);
auto pre_h_grad =
EigenMatrix<T>::Reshape(*grad_pre_h, grad_pre_h->dims().size() - 1);
auto output_grad =
EigenMatrix<T>::Reshape(*grad_output, grad_output->dims().size() - 1);
last_h_grad.device(place) = last_h_grad + output_grad * mask_broadcast;
pre_h_grad.device(place) = (1 - mask_broadcast) * last_h_grad;
last_h_grad.device(place) = mask_broadcast * last_h_grad;
if (grad_last_c && grad_pre_c && is_lstm(mode)) {
auto last_c_grad =
EigenMatrix<T>::Reshape(*grad_last_c, grad_last_c->dims().size() - 1);
auto pre_c_grad =
EigenMatrix<T>::Reshape(*grad_pre_c, grad_pre_c->dims().size() - 1);
pre_c_grad.device(place) = (1 - mask_broadcast) * last_c_grad;
last_c_grad.device(place) = mask_broadcast * last_c_grad;
}
}
void postprocess(const CPUContext& dev_ctx,
const DenseTensor& grad_gate,
const DenseTensor& input,
DenseTensor* input_grad,
const std::vector<DenseTensor>& parameters,
std::vector<DenseTensor>* grad_parameters,
int is_reverse,
const std::string& mode) {
// grad_gate was computed step by step above; propagate it to grad_w_hi,
// grad_bias_hi and grad_bias_hh here
int begin_idx = 0;
if (is_reverse) {
begin_idx = 4;
}
auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
// calc the gradient for the w_hi
auto mat_dim_out_grad =
phi::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, true);
auto mat_dim_input =
phi::funcs::CreateMatrixDescriptor(input.dims(), 0, false);
mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_;
mat_dim_out_grad.batch_size_ = 0;
mat_dim_input.height_ *= mat_dim_input.batch_size_;
mat_dim_input.batch_size_ = 0;
blas.MatMul(grad_gate,
mat_dim_out_grad,
input,
mat_dim_input,
static_cast<T>(1.0),
&((*grad_parameters)[begin_idx + 0]),
T(0));
// calc the gradient for the X
auto mat_dim_out_grad_new =
phi::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, false);
mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_;
mat_dim_out_grad_new.batch_size_ = 0;
auto mat_dim_parameter =
phi::funcs::CreateMatrixDescriptor(parameters[0].dims(), 0, false);
blas.MatMul(grad_gate,
mat_dim_out_grad_new,
parameters[begin_idx + 0],
mat_dim_parameter,
static_cast<T>(1.0),
input_grad,
T(1));
// calc the gradient of Bias_hi, Bias_hh
phi::funcs::ColwiseSum<CPUContext, T> col_sum;
DenseTensor tmp_grad_gate;
tmp_grad_gate.ShareDataWith(grad_gate);
tmp_grad_gate.Resize(
{grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]});
col_sum(dev_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2]));
// Bias_hh
if (!is_gru(mode)) {
col_sum(dev_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3]));
}
}
GradCellType cell_;
};
template <typename T, typename GradCellType>
struct SingleGradLayer : GradLayer<T, GradCellType> {
// explicit SingleGradLayer(GradCellType& cell) : cell_(cell) {}
explicit SingleGradLayer(const GradCellType& cell)
: GradLayer<T, GradCellType>(cell) {}
virtual ~SingleGradLayer() {}
void operator()(const CPUContext& dev_ctx,
const DenseTensor* input,
const DenseTensor* output,
std::vector<DenseTensor>* init_h_unbind,
std::vector<DenseTensor>* init_c_unbind,
const std::vector<DenseTensor>& last_h_grad_unbind,
const std::vector<DenseTensor>& last_c_grad_unbind,
const std::vector<DenseTensor>& gate_tensor_unbind,
const std::vector<DenseTensor>& state_tensor_unbind,
const std::vector<DenseTensor>& act_state_tensor_unbind,
const DenseTensor* output_grad,
const std::vector<std::vector<DenseTensor>>& parameter_lists,
const DenseTensor* sequence_length,
DenseTensor* input_grad,
std::vector<DenseTensor>* init_h_grad_unbind,
std::vector<DenseTensor>* init_c_grad_unbind,
std::vector<std::vector<DenseTensor>>* weight_list_grad,
int layer_idx,
bool is_bidirec,
int hidden_size,
const std::string& mode,
int gate_num) {
phi::funcs::SetConstant<CPUContext, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
int time_step = input->dims()[0];
int batch_size = input->dims()[1];
int direction_num = is_bidirec ? 2 : 1;
// in this section, create the gate_state_grad for the postprocess calculation
// and unbind the output, whose shape is [time_step, batch_size, hidden_size]
auto output_tensor_unbind = Unbind(*output);
auto output_grad_tensor_unbind = Unbind(*output_grad);
auto layer_gate_tensor = gate_tensor_unbind[layer_idx];
layer_gate_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size * gate_num});
auto layer_gate_tensor_unbind = Unbind(layer_gate_tensor);
// the gate_tensor and the grad_gate_tensor must be unbind
DenseTensor layer_grad_gate_tensor;
layer_grad_gate_tensor.Resize(layer_gate_tensor.dims());
dev_ctx.Alloc<T>(&layer_grad_gate_tensor);
auto layer_grad_gate_tensor_unbind = Unbind(layer_grad_gate_tensor);
DenseTensor layer_state_tensor;
std::vector<DenseTensor> layer_state_tensor_unbind;
if (state_tensor_unbind.size() > 0) {
layer_state_tensor = state_tensor_unbind[layer_idx];
layer_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_state_tensor_unbind = Unbind(layer_state_tensor);
}
DenseTensor layer_act_state_tensor;
std::vector<DenseTensor> layer_act_state_tensor_unbind;
if (act_state_tensor_unbind.size() > 0) {
layer_act_state_tensor = act_state_tensor_unbind[layer_idx];
layer_act_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor);
}
bool has_sequence_length = sequence_length != nullptr;
this->run_rnn_grad_function(dev_ctx,
input,
input_grad,
sequence_length,
init_h_unbind,
init_c_unbind,
init_h_grad_unbind,
init_c_grad_unbind,
&layer_grad_gate_tensor,
&layer_gate_tensor_unbind,
&layer_grad_gate_tensor_unbind,
&layer_state_tensor_unbind,
&layer_act_state_tensor_unbind,
&output_tensor_unbind,
&output_grad_tensor_unbind,
last_h_grad_unbind,
last_c_grad_unbind,
parameter_lists,
weight_list_grad,
layer_idx,
time_step,
has_sequence_length,
is_bidirec,
false,
mode);
}
};
template <typename T>
void split_tensor_at_last_dim(const CPUContext& dev_ctx,
const DenseTensor* output,
std::vector<DenseTensor*>* output_vec,
int axis) {
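// Split `output` into two equal halves along `axis`; for the 3-D bidirectional
// outputs used here this separates the forward- and backward-direction
// features.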
std::vector<const DenseTensor*> shape_refer;
(*output_vec)[0]->Resize(
{output->dims()[0], output->dims()[1], output->dims()[2] / 2});
dev_ctx.Alloc<T>((*output_vec)[0]);
(*output_vec)[1]->Resize(
{output->dims()[0], output->dims()[1], output->dims()[2] / 2});
dev_ctx.Alloc<T>((*output_vec)[1]);
shape_refer.emplace_back((*output_vec)[0]);
shape_refer.emplace_back((*output_vec)[1]);
funcs::SplitFunctor<CPUContext, T> functor;
functor(dev_ctx, *output, shape_refer, axis, output_vec);
}
template <typename T, typename GradCellType>
struct BidirGradLayer : GradLayer<T, GradCellType> {
explicit BidirGradLayer(const GradCellType& cell)
: GradLayer<T, GradCellType>(cell) {}
virtual ~BidirGradLayer() {}
void operator()(const CPUContext& dev_ctx,
const DenseTensor* input,
const DenseTensor* output,
std::vector<DenseTensor>* init_h_unbind,
std::vector<DenseTensor>* init_c_unbind,
const std::vector<DenseTensor>& last_h_grad_unbind,
const std::vector<DenseTensor>& last_c_grad_unbind,
const std::vector<DenseTensor>& gate_tensor_unbind,
const std::vector<DenseTensor>& state_tensor_unbind,
const std::vector<DenseTensor>& act_state_tensor_unbind,
const DenseTensor* output_grad,
const std::vector<std::vector<DenseTensor>>& parameter_lists,
const DenseTensor* sequence_length,
DenseTensor* input_grad,
std::vector<DenseTensor>* init_h_grad_unbind,
std::vector<DenseTensor>* init_c_grad_unbind,
std::vector<std::vector<DenseTensor>>* weight_list_grad,
int layer_idx,
bool is_bidirec,
int hidden_size,
const std::string& mode,
int gate_num) {
int time_step = input->dims()[0];
int batch_size = input->dims()[1];
int direction_num = is_bidirec ? 2 : 1;
// split the output into two tensors: output_forward and output_backward
phi::funcs::SetConstant<CPUContext, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
std::vector<DenseTensor*> output_vec;
DenseTensor forward_output;
DenseTensor backward_output;
std::vector<DenseTensor> forward_output_tensor_unbind;
std::vector<DenseTensor> backward_output_tensor_unbind;
// in the last layer the output is used as the last hidden state; it is just
// the concatenation of the forward hidden and the backward hidden, so it only
// needs to be split; in the other layers the hidden is split along the rows
// in the same way
output_vec.emplace_back(&forward_output);
output_vec.emplace_back(&backward_output);
split_tensor_at_last_dim<T>(dev_ctx, output, &output_vec, 2);
forward_output_tensor_unbind = Unbind(*(output_vec[0]));
backward_output_tensor_unbind = Unbind(*(output_vec[1]));
std::vector<DenseTensor*> output_grad_vec;
DenseTensor grad_forward_output;
DenseTensor grad_backward_output;
output_grad_vec.emplace_back(&grad_forward_output);
output_grad_vec.emplace_back(&grad_backward_output);
split_tensor_at_last_dim<T>(dev_ctx, output_grad, &output_grad_vec, 2);
auto forward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[0]));
auto backward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[1]));
// the gate_tensor and the grad_gate_tensor must be unbind
auto layer_gate_tensor = gate_tensor_unbind[layer_idx];
layer_gate_tensor.Resize(
{time_step * 2, batch_size, hidden_size * gate_num});
auto layer_forward_gate_tensor = layer_gate_tensor.Slice(0, time_step);
auto layer_backward_gate_tensor =
layer_gate_tensor.Slice(time_step, 2 * time_step);
auto layer_forward_gate_tensor_unbind = Unbind(layer_forward_gate_tensor);
auto layer_backward_gate_tensor_unbind = Unbind(layer_backward_gate_tensor);
DenseTensor layer_grad_gate_tensor;
layer_grad_gate_tensor.Resize(layer_gate_tensor.dims());
dev_ctx.Alloc<T>(&layer_grad_gate_tensor);
zero(dev_ctx, &layer_grad_gate_tensor, static_cast<T>(0.0));
auto layer_forward_grad_gate_tensor =
layer_grad_gate_tensor.Slice(0, time_step);
auto layer_backward_grad_gate_tensor =
layer_grad_gate_tensor.Slice(time_step, 2 * time_step);
auto layer_forward_grad_gate_tensor_unbind =
Unbind(layer_forward_grad_gate_tensor);
auto layer_backward_grad_gate_tensor_unbind =
Unbind(layer_backward_grad_gate_tensor);
DenseTensor layer_state_tensor;
std::vector<DenseTensor> layer_state_tensor_unbind;
if (state_tensor_unbind.size() > 0) {
layer_state_tensor = state_tensor_unbind[layer_idx];
layer_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_state_tensor_unbind = Unbind(layer_state_tensor);
}
DenseTensor layer_act_state_tensor;
std::vector<DenseTensor> layer_act_state_tensor_unbind;
if (act_state_tensor_unbind.size() > 0) {
layer_act_state_tensor = act_state_tensor_unbind[layer_idx];
layer_act_state_tensor.Resize(
{time_step * direction_num, batch_size, hidden_size});
layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor);
}
const bool has_sequence_length = sequence_length != nullptr;
this->run_rnn_grad_function(dev_ctx,
input,
input_grad,
sequence_length,
init_h_unbind,
init_c_unbind,
init_h_grad_unbind,
init_c_grad_unbind,
&layer_forward_grad_gate_tensor,
&layer_forward_gate_tensor_unbind,
&layer_forward_grad_gate_tensor_unbind,
&layer_state_tensor_unbind,
&layer_act_state_tensor_unbind,
&forward_output_tensor_unbind,
&forward_output_grad_tensor_unbind,
last_h_grad_unbind,
last_c_grad_unbind,
parameter_lists,
weight_list_grad,
layer_idx,
time_step,
has_sequence_length,
is_bidirec,
false,
mode);
this->run_rnn_grad_function(dev_ctx,
input,
input_grad,
sequence_length,
init_h_unbind,
init_c_unbind,
init_h_grad_unbind,
init_c_grad_unbind,
&layer_backward_grad_gate_tensor,
&layer_backward_gate_tensor_unbind,
&layer_backward_grad_gate_tensor_unbind,
&layer_state_tensor_unbind,
&layer_act_state_tensor_unbind,
&backward_output_tensor_unbind,
&backward_output_grad_tensor_unbind,
last_h_grad_unbind,
last_c_grad_unbind,
parameter_lists,
weight_list_grad,
layer_idx,
time_step,
has_sequence_length,
is_bidirec,
true,
mode);
}
};
template <typename T>
void dropout_cpu_grad_function_inplace(const CPUContext& dev_ctx,
DenseTensor* grad_x,
const DenseTensor* mask,
float dropout_prob) {
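// The backward pass of dropout reuses the forward mask and scaling, so the
// same helper applies: grad_x = grad_x * mask / (1 - dropout_prob), in place.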
DropoutHelper<T>(dev_ctx, grad_x, grad_x, mask, dropout_prob);
}
template <typename GradCellType,
template <typename, typename> class SingleGradLayerT,
template <typename, typename> class BidirGradLayerT,
typename T>
void RnnGradFunc(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& pre_state,
const std::vector<const DenseTensor*>& weight_list,
paddle::optional<const DenseTensor&> sequence_length,
const DenseTensor& out,
const DenseTensor& dropout_state,
const DenseTensor& reserve,
const DenseTensor& out_grad,
const std::vector<const DenseTensor*>& state_grad,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
int gate_num,
DenseTensor* x_grad,
std::vector<DenseTensor*> pre_state_grad,
std::vector<DenseTensor*> weight_grad_list) {
const DenseTensor* init_h = pre_state[0];
const DenseTensor* init_c = nullptr;
if (is_lstm(mode)) {
init_c = pre_state[1];
}
const DenseTensor* last_h_grad = state_grad[0];
const DenseTensor* last_c_grad = nullptr;
if (is_lstm(mode)) {
last_c_grad = state_grad[1];
}
DenseTensor* init_h_grad = nullptr;
DenseTensor* init_c_grad = nullptr;
if (!pre_state_grad.empty()) { // has gradient
init_h_grad = pre_state_grad[0];
if (is_lstm(mode) && pre_state_grad.size() > 1) {
init_c_grad = pre_state_grad[1];
}
}
// get the time_step, batch_size and direction_num from the input
const int time_step = x.dims()[0];
const int batch_size = x.dims()[1];
const int direction_num = is_bidirec ? 2 : 1;
// allocate the memory and initialize x_grad
DenseTensor x_grad_value;
if (!x_grad) {
x_grad = &x_grad_value;
}
x_grad->Resize(x.dims());
dev_ctx.Alloc<T>(x_grad);
if (init_h_grad) {
init_h_grad->Resize(init_h->dims());
dev_ctx.Alloc<T>(init_h_grad);
}
if (init_c_grad) {
init_c_grad->Resize(init_c->dims());
dev_ctx.Alloc<T>(init_c_grad);
}
// reset the parameter to sorted order and allocate the memory
std::vector<std::vector<DenseTensor>> parameter_lists;
parameter_lists.reserve(num_layers);
ResetParameterVector(
weight_list, num_layers, gate_num, is_bidirec, &parameter_lists);
for (unsigned int i = 0; i < weight_grad_list.size(); ++i) {
dev_ctx.Alloc<T>(weight_grad_list[i]);
}
std::vector<std::vector<DenseTensor>> parameter_lists_grad;
parameter_lists_grad.reserve(num_layers);
ResetParameterVector(weight_grad_list,
num_layers,
gate_num,
is_bidirec,
&parameter_lists_grad);
// split the intermediate states saved in reserve during the forward pass
DenseTensor gate_tensor;
DenseTensor state_tensor;
DenseTensor act_state_tensor;
DenseTensor hidden_tensor;
SplitReserveData(dev_ctx,
direction_num,
time_step,
batch_size,
hidden_size,
gate_num,
num_layers,
mode,
&reserve,
&gate_tensor,
&state_tensor,
&act_state_tensor,
&hidden_tensor);
int gate_num_tmp = gate_num;
if (gate_num == 0) {
gate_num_tmp = 1;
}
gate_tensor.Resize({num_layers,
time_step * direction_num,
batch_size,
hidden_size * gate_num_tmp});
if (state_tensor.numel() > 0) {
state_tensor.Resize(
{num_layers, time_step * direction_num, batch_size, hidden_size});
}
if (act_state_tensor.numel() > 0) {
act_state_tensor.Resize(
{num_layers, time_step * direction_num, batch_size, hidden_size});
}
if (num_layers > 1) {
hidden_tensor.Resize(
{num_layers - 1, time_step, batch_size, hidden_size * direction_num});
}
// unbind
auto last_h_grad_unbind = Unbind(*last_h_grad);
auto gate_tensor_unbind = Unbind(gate_tensor);
std::vector<DenseTensor> last_c_grad_unbind;
if (last_c_grad) {
last_c_grad_unbind = Unbind(*last_c_grad);
}
std::vector<DenseTensor> init_h_unbind, init_c_unbind;
std::vector<DenseTensor> init_h_grad_unbind, init_c_grad_unbind;
std::vector<DenseTensor> state_tensor_unbind, act_state_tensor_unbind;
std::vector<DenseTensor> hidden_tensor_unbind;
init_h_unbind = Unbind(*init_h);
if (init_c) {
init_c_unbind = Unbind(*init_c);
}
if (init_h_grad != nullptr) {
init_h_grad_unbind = Unbind(*init_h_grad);
}
if (init_c_grad != nullptr) {
init_c_grad_unbind = Unbind(*init_c_grad);
}
if (state_tensor.numel() > 0) {
state_tensor_unbind = Unbind(state_tensor);
}
if (act_state_tensor.numel() > 0) {
act_state_tensor_unbind = Unbind(act_state_tensor);
}
if (num_layers > 1) {
hidden_tensor_unbind = Unbind(hidden_tensor);
}
// squeeze the first dim of each hidden tensor
for (unsigned int i = 0; i < hidden_tensor_unbind.size(); i++) {
hidden_tensor_unbind[i].Resize(
phi::slice_ddim(hidden_tensor_unbind[i].dims(),
1,
hidden_tensor_unbind[i].dims().size()));
}
// add the output tensor to the hidden vector
DenseTensor tmp;
hidden_tensor_unbind.emplace_back(tmp);
hidden_tensor_unbind[num_layers - 1].ShareDataWith(out);
GradCellType cell;
DenseTensor layer_input;
DenseTensor layer_output;
DenseTensor* layer_x_grad_holder = nullptr;
DenseTensor tmp_out;
tmp_out.ShareDataWith(out_grad);
DenseTensor* layer_output_grad_holder = &tmp_out;
DenseTensor x_grad_temp;
DenseTensor output_grad_temp;
bool has_allocate_mem = false;
for (int i = num_layers - 1; i >= 0; --i) {
// the layer input and output were saved in the forward pass, just reuse the data
if (i > 0) {
if (layer_input.numel() == 0) {
layer_input.Resize(hidden_tensor_unbind[i - 1].dims());
dev_ctx.Alloc<T>(&layer_input);
}
DropoutHelper<T>(dev_ctx,
&hidden_tensor_unbind[i - 1],
&layer_input,
&dropout_state,
dropout_prob);
} else {
layer_input.ShareDataWith(x);
}
layer_output.ShareDataWith(hidden_tensor_unbind[i]);
if (num_layers == 1) {
layer_x_grad_holder = x_grad;
} else {
if (i == num_layers - 1) {
x_grad_temp.Resize(layer_input.dims());
dev_ctx.Alloc<T>(&x_grad_temp);
layer_x_grad_holder = &x_grad_temp;
}
}
if (is_bidirec) {
BidirGradLayerT<T, GradCellType> layer(cell);
layer(dev_ctx,
&layer_input,
&layer_output,
&init_h_unbind,
&init_c_unbind,
last_h_grad_unbind,
last_c_grad_unbind,
gate_tensor_unbind,
state_tensor_unbind,
act_state_tensor_unbind,
layer_output_grad_holder,
parameter_lists,
sequence_length.get_ptr(),
layer_x_grad_holder,
&init_h_grad_unbind,
&init_c_grad_unbind,
&parameter_lists_grad,
i,
is_bidirec,
hidden_size,
mode,
gate_num_tmp);
} else {
SingleGradLayerT<T, GradCellType> layer(cell);
layer(dev_ctx,
&layer_input,
&layer_output,
&init_h_unbind,
&init_c_unbind,
last_h_grad_unbind,
last_c_grad_unbind,
gate_tensor_unbind,
state_tensor_unbind,
act_state_tensor_unbind,
layer_output_grad_holder,
parameter_lists,
sequence_length.get_ptr(),
layer_x_grad_holder,
&init_h_grad_unbind,
&init_c_grad_unbind,
&parameter_lists_grad,
i,
is_bidirec,
hidden_size,
mode,
gate_num_tmp);
}
// calculate the dropout gradient for the layer_x_grad_holder
// dropout_state was saved in the forward process
if (i > 0) {
if ((!is_test) && (dropout_prob != 0)) {
dropout_cpu_grad_function_inplace<T>(
dev_ctx, layer_x_grad_holder, &dropout_state, dropout_prob);
}
}
if (i - 1 == 0) {
layer_output_grad_holder = x_grad;
} else {
if (!has_allocate_mem) {
output_grad_temp.Resize(layer_x_grad_holder->dims());
dev_ctx.Alloc<T>(&output_grad_temp);
layer_output_grad_holder = &output_grad_temp;
has_allocate_mem = true;
}
}
SwapPoniter(&layer_x_grad_holder, &layer_output_grad_holder);
}
}
template <typename T, typename Context>
void RnnGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& pre_state,
const std::vector<const DenseTensor*>& weight_list,
paddle::optional<const DenseTensor&> sequence_length,
const DenseTensor& out,
const DenseTensor& dropout_state,
const DenseTensor& reserve,
const DenseTensor& out_grad,
const std::vector<const DenseTensor*>& state_grad,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
DenseTensor* x_grad,
std::vector<DenseTensor*> pre_state_grad,
std::vector<DenseTensor*> weight_grad_list) {
int gate_num = 4;
if (is_lstm(mode)) {
RnnGradFunc<LSTMGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
dev_ctx,
x,
pre_state,
weight_list,
sequence_length,
out,
dropout_state,
reserve,
out_grad,
state_grad,
dropout_prob,
is_bidirec,
input_size,
hidden_size,
num_layers,
mode,
seed,
is_test,
gate_num,
x_grad,
pre_state_grad,
weight_grad_list);
} else if (is_gru(mode)) {
gate_num = 3;
RnnGradFunc<GRUGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
dev_ctx,
x,
pre_state,
weight_list,
sequence_length,
out,
dropout_state,
reserve,
out_grad,
state_grad,
dropout_prob,
is_bidirec,
input_size,
hidden_size,
num_layers,
mode,
seed,
is_test,
gate_num,
x_grad,
pre_state_grad,
weight_grad_list);
// run gru
} else if (is_rnn_relu(mode)) {
gate_num = 1;
RnnGradFunc<SimpleRNNGradCell<T, funcs::ReluGradFunctor>,
SingleGradLayer,
BidirGradLayer,
T>(dev_ctx,
x,
pre_state,
weight_list,
sequence_length,
out,
dropout_state,
reserve,
out_grad,
state_grad,
dropout_prob,
is_bidirec,
input_size,
hidden_size,
num_layers,
mode,
seed,
is_test,
gate_num,
x_grad,
pre_state_grad,
weight_grad_list);
// run rnn
} else if (is_rnn_tanh(mode)) {
gate_num = 1;
RnnGradFunc<SimpleRNNGradCell<T, funcs::TanhGradFunctor>,
SingleGradLayer,
BidirGradLayer,
T>(dev_ctx,
x,
pre_state,
weight_list,
sequence_length,
out,
dropout_state,
reserve,
out_grad,
state_grad,
dropout_prob,
is_bidirec,
input_size,
hidden_size,
num_layers,
mode,
seed,
is_test,
gate_num,
x_grad,
pre_state_grad,
weight_grad_list);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
rnn_grad, CPU, ALL_LAYOUT, phi::RnnGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/rnn_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/cpu/rnn_functor.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
struct Cell {
virtual ~Cell() {}
virtual void operator()(const CPUContext* dev_ctx,
DenseTensor* input,
const DenseTensor* weight_hh,
const DenseTensor* init_h,
const DenseTensor* init_c,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* last_c_act,
DenseTensor* output,
const DenseTensor* bias_hh,
DenseTensor* weight_hh_gru) const {}
};
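// SimpleRNNCell: accumulates init_h * weight_hh^T into the pre-computed input
// projection and applies the elementwise activation (relu or tanh) to produce
// the hidden output for one time step.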
template <typename T,
template <typename> class EigenActivationFunctor,
funcs::detail::ActivationType act_type>
struct SimpleRNNCell : Cell<T> {
void operator()(const CPUContext* dev_ctx,
DenseTensor* input,
const DenseTensor* weight_hh,
const DenseTensor* init_h,
const DenseTensor* init_c,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* last_c_act,
DenseTensor* output,
const DenseTensor* bias_hh,
DenseTensor* weight_hh_gru) const override {
auto blas = phi::funcs::GetBlas<CPUContext, T>(*dev_ctx);
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
// convert the batched matmul to a single matmul, which is faster
blas.MatMul(*init_h,
mat_dim_a,
*weight_hh,
mat_dim_b,
static_cast<T>(1.0),
input,
static_cast<T>(1.0));
auto z = EigenVector<T>::Flatten(
GET_DATA_SAFELY(input, "Input", "z", "Activation"));
auto hidden = EigenVector<T>::Flatten(
GET_DATA_SAFELY(output, "Output", "hidden", "Activation"));
auto* place = dev_ctx->eigen_device();
EigenActivationFunctor<T> functor;
functor(*place, z, hidden);
}
};
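// GRUCell: adds init_h * weight_hh_gru^T (a copy of weight_hh whose cell-state
// rows were zeroed by the caller) to the gate buffer, then runs
// GRUUnitFunctorV2 with sigmoid_v2 gate and tanh_v2 candidate activations for
// one time step.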
template <typename T>
struct GRUCell : Cell<T> {
void operator()(const CPUContext* dev_ctx,
DenseTensor* input,
const DenseTensor* weight_hh,
const DenseTensor* init_h,
const DenseTensor* init_c,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* last_c_act,
DenseTensor* output,
const DenseTensor* bias_hh,
DenseTensor* weight_hh_gru) const override {
auto blas = phi::funcs::GetBlas<CPUContext, T>(*dev_ctx);
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
// convert the batched matmul to a single matmul, which is faster
blas.MatMul(*init_h,
mat_dim_a,
*weight_hh_gru,
mat_dim_b,
static_cast<T>(1.0),
input,
static_cast<T>(1.0));
size_t frame_size = init_h->dims()[2];
size_t batch_size = init_h->dims()[1];
phi::funcs::GRUMetaValue<T> gru_value;
gru_value.gate_weight = weight_hh->data<T>();
gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size;
gru_value.reset_bias = bias_hh->data<T>() + 2 * frame_size;
gru_value.gate_value = input->data<T>();
gru_value.reset_output_value = last_c->data<T>();
gru_value.output_value = output->data<T>();
gru_value.prev_out_value = init_h->data<T>();
auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
phi::funcs::GRUUnitFunctorV2<CPUContext, T>::compute(
*dev_ctx, gru_value, frame_size, batch_size, cand_act, gate_act);
}
};
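// LSTMCell: adds init_h * weight_hh^T to the gate buffer and runs
// LstmUnitFunctor with sigmoid/tanh activations; when last_c_act is null
// (test mode) a temporary tensor is allocated for the activated cell state.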
template <typename T>
struct LSTMCell : Cell<T> {
void operator()(const CPUContext* dev_ctx,
DenseTensor* input,
const DenseTensor* weight_hh,
const DenseTensor* init_h,
const DenseTensor* init_c,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* last_c_act,
DenseTensor* output,
const DenseTensor* bias_hh,
DenseTensor* weight_hh_gru) const override {
auto blas = phi::funcs::GetBlas<CPUContext, T>(*dev_ctx);
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
// convert the batched matmul to a single matmul, which is faster
blas.MatMul(*init_h,
mat_dim_a,
*weight_hh,
mat_dim_b,
static_cast<T>(1.0),
input,
static_cast<T>(1.0));
phi::funcs::LstmMetaValue<T> lstm_value;
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
auto gate_act = phi::funcs::detail::GetActivationType("sigmoid_v2");
auto cell_act = phi::funcs::detail::GetActivationType("tanh_v2");
auto cand_act = phi::funcs::detail::GetActivationType("tanh_v2");
size_t frame_size = init_h->dims()[2];
size_t batch_size = init_h->dims()[1];
DenseTensor cell_pre_act;
if (last_c_act == nullptr) { /* is test */
cell_pre_act.Resize(init_h->dims());
dev_ctx->Alloc<T>(&cell_pre_act);
last_c_act = &cell_pre_act;
}
lstm_value.prev_state_value = init_c->data<T>();
lstm_value.gate_value = input->data<T>();
lstm_value.output_value = output->data<T>();
lstm_value.state_value = last_c->data<T>();
lstm_value.state_active_value = last_c_act->data<T>();
T cell_clip = 0.0;
phi::funcs::LstmUnitFunctor<CPUContext, T>::compute(*dev_ctx,
lstm_value,
frame_size,
batch_size,
cell_clip,
gate_act,
cell_act,
cand_act,
false);
}
};
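// Layer is the base class shared by SingleLayer and BidirLayer. preprocess
// computes X * W_ih^T + bias_ih (+ bias_hh); postprocess applies the sequence
// mask so padded steps carry the previous hidden/cell state forward and emit
// zero output; RunIter / RunTestIter loop the cell over the time steps for
// training / inference respectively.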
template <typename T, typename CellType>
struct Layer {
explicit Layer(const CellType& cell) : cell_(cell) {}
virtual ~Layer() {}
void preprocess(const CPUContext& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& bias_ih,
const DenseTensor& bias_hh,
const std::string& mode,
bool is_test,
DenseTensor* cache_input) {
// create the temp input for X * W_ih^T + bias_ih
const int& hidden_size = weight.dims()[0];
cache_input->Resize(
phi::make_ddim({input.dims()[0], input.dims()[1], hidden_size}));
if (is_test) {
dev_ctx.Alloc<T>(cache_input);
}
auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(input.dims(), 0, false);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(weight.dims(), 0, true);
// convert the batched matmul to a single matmul, which is faster
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
blas.MatMul(input,
mat_dim_a,
weight,
mat_dim_b,
static_cast<T>(1.0),
cache_input,
static_cast<T>(0));
auto in =
EigenMatrix<T>::Reshape(*cache_input, cache_input->dims().size() - 1);
auto bias_ih_tmp =
EigenMatrix<T>::From(bias_ih, phi::make_ddim({1, bias_ih.dims()[0]}));
const int row_num =
phi::product(cache_input->dims()) / cache_input->dims()[2];
in = in + bias_ih_tmp.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
if (is_gru(mode)) {
// keep bias_hh for the reset/update gates, zero it for the cell gate: [1, 1, 0]
DenseTensor bias_hh_tmp = Empty<T>(dev_ctx, {bias_hh.numel()});
Copy(dev_ctx, bias_hh, CPUPlace(), false, &bias_hh_tmp);
bias_hh_tmp.Resize({3, bias_hh_tmp.numel() / 3});
auto bias_hh_tmp_unbind = Unbind(bias_hh_tmp);
phi::funcs::SetConstant<CPUContext, T> zero;
zero(dev_ctx, &bias_hh_tmp_unbind[2], static_cast<T>(0.0));
auto bias_hh_after_mask = EigenMatrix<T>::From(
bias_hh_tmp, phi::make_ddim({1, bias_hh.dims()[0]}));
in = in + bias_hh_after_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
} else {
auto bias_hh_no_mask =
EigenMatrix<T>::From(bias_hh, phi::make_ddim({1, bias_hh.dims()[0]}));
in = in + bias_hh_no_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
}
}
void postprocess(const CPUContext& dev_ctx,
DenseTensor* output,
const DenseTensor* init_h,
const DenseTensor* init_c,
DenseTensor* last_h,
DenseTensor* last_c,
const DenseTensor& mask_tensor,
const std::string& mode) {
// in the output, if the mask flag is 0, we return zero data
auto& place = *dev_ctx.eigen_device();
auto out = EigenMatrix<T>::Reshape(*output, output->dims().size() - 1);
auto mask = EigenMatrix<T>::From(
mask_tensor, phi::make_ddim({mask_tensor.dims()[1], 1}));
auto pre_h = EigenMatrix<T>::Reshape(*init_h, init_h->dims().size() - 1);
auto curr_h = EigenMatrix<T>::Reshape(*last_h, last_h->dims().size() - 1);
auto mask_broadcast =
mask.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[2]));
curr_h.device(place) = out * mask_broadcast + pre_h * (1 - mask_broadcast);
out.device(place) = out * mask_broadcast;
if (is_lstm(mode)) {
auto pre_c = EigenMatrix<T>::Reshape(*init_c, init_c->dims().size() - 1);
auto curr_c = EigenMatrix<T>::Reshape(*last_c, last_c->dims().size() - 1);
curr_c.device(place) =
curr_c * mask_broadcast + pre_c * (1 - mask_broadcast);
}
}
virtual void operator()(const CPUContext& dev_ctx,
const DenseTensor* input,
const std::vector<DenseTensor>& vec,
const std::vector<DenseTensor>& init_h,
const std::vector<DenseTensor>& init_c,
const DenseTensor* sequence_length,
std::vector<DenseTensor> last_h,
std::vector<DenseTensor> last_c,
DenseTensor* output,
const int& layer_idx,
const int& gate_num,
DenseTensor* gate_value,
DenseTensor* cell_value,
DenseTensor* cell_act_value,
const std::string& mode,
bool is_test) {}
void RunTestIter(const CPUContext& dev_ctx,
const DenseTensor* input,
const std::vector<DenseTensor>& vec,
const std::vector<DenseTensor>& init_h,
const std::vector<DenseTensor>& init_c,
const DenseTensor* sequence_length,
std::vector<DenseTensor>* last_h_ptr,
std::vector<DenseTensor>* last_c_ptr,
DenseTensor* output,
int layer_idx,
DenseTensor* gate_value,
DenseTensor* cell_value,
DenseTensor* cell_act_value,
bool is_bidirect,
int offset,
const std::string& mode) {
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
const int time_step = input->dims()[0];
this->preprocess(dev_ctx,
*input,
vec[0 + offset * 4],
vec[2 + offset * 4],
vec[3 + offset * 4],
mode,
true,
gate_value);
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
std::vector<DenseTensor> mask_tensor_list;
// construct the mask matrix from the sequence lengths
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
DenseTensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(phi::make_ddim({time_step, input->dims()[1]}));
CreateMaskMatrix<T>(
dev_ctx, sequence_length, &mask_matrix, is_reverse, &mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
bool has_allocate_mem_c = false;
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
// define the init_h holder for the swap
DenseTensor init_h_temp;
Copy(dev_ctx, *&init_h[layer_idx], dev_ctx.GetPlace(), false, &init_h_temp);
DenseTensor* init_h_holder = &init_h_temp;
DenseTensor* last_h_holder = nullptr;
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
DenseTensor* init_c_holder = nullptr;
const DenseTensor* init_c_temp_holder = nullptr;
DenseTensor init_c_temp;
DenseTensor* last_c_holder = nullptr;
DenseTensor last_c_temp;
if (is_lstm(mode)) {
last_c_holder = &(*last_c_ptr)[layer_idx];
init_c_temp_holder = &init_c[layer_idx];
} else if (is_gru(mode)) {
// for reset output value
last_c_temp.Resize(init_h[layer_idx].dims());
dev_ctx.Alloc<T>(&last_c_temp);
last_c_holder = &last_c_temp;
}
DenseTensor weight_hh_tmp; // for gru
if (is_gru(mode)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
dev_ctx.Alloc<T>(&weight_hh_tmp);
Copy(dev_ctx,
vec[1 + offset * 4],
dev_ctx.GetPlace(),
false,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
phi::funcs::SetConstant<CPUContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
bool in_mask = (reverse_flag * i) >= mask_min_length;
if (i > 0) {
if (!has_allocate_mem_c) {
if (is_lstm(mode) || is_gru(mode)) {
init_c_temp.Resize(init_h[layer_idx].dims());
dev_ctx.Alloc<T>(&init_c_temp);
init_c_holder = &init_c_temp;
}
has_allocate_mem_c = true;
}
SwapPoniter(&init_c_holder, &last_c_holder);
init_c_temp_holder = init_c_holder;
}
cell_(&dev_ctx,
&input_tensors[i],
&vec[1 + offset * 4],
init_h_holder,
init_c_temp_holder,
last_h_holder,
last_c_holder,
nullptr,
&output_tensors[i],
&vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(dev_ctx,
&output_tensors[i],
init_h_holder,
init_c_temp_holder,
last_h_holder,
last_c_holder,
mask_tensor_list[i],
mode);
}
// prepare next step
if (i + 1 < time_step) {
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
Copy(dev_ctx,
*last_h_holder,
dev_ctx.GetPlace(),
false,
&(*last_h_ptr)[layer_idx]);
}
} else {
Copy(dev_ctx,
output_tensors[time_step - 1],
dev_ctx.GetPlace(),
false,
&(*last_h_ptr)[layer_idx]);
}
if (time_step % 2 == 0) {
if (is_lstm(mode)) {
Copy(dev_ctx,
*last_c_holder,
dev_ctx.GetPlace(),
false,
&(*last_c_ptr)[layer_idx]);
}
}
}
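// RunIter is the training-time counterpart of RunTestIter: it additionally
// keeps the per-step cell states and activated cell states (cell_value /
// cell_act_value) so the backward pass can reuse them.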
void RunIter(const CPUContext& dev_ctx,
const DenseTensor* input,
const std::vector<DenseTensor>& vec,
const std::vector<DenseTensor>& init_h,
const std::vector<DenseTensor>& init_c,
const DenseTensor* sequence_length,
std::vector<DenseTensor>* last_h_ptr,
std::vector<DenseTensor>* last_c_ptr,
DenseTensor* output,
int layer_idx,
DenseTensor* gate_value,
DenseTensor* cell_value,
DenseTensor* cell_act_value,
bool is_bidirect,
int offset,
const std::string& mode,
bool is_test) {
if (is_test) {
RunTestIter(dev_ctx,
input,
vec,
init_h,
init_c,
sequence_length,
last_h_ptr,
last_c_ptr,
output,
layer_idx,
gate_value,
cell_value,
cell_act_value,
is_bidirect,
offset,
mode);
return;
}
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
const int time_step = input->dims()[0];
this->preprocess(dev_ctx,
*input,
vec[0 + offset * 4],
vec[2 + offset * 4],
vec[3 + offset * 4],
mode,
is_test,
gate_value);
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
std::vector<DenseTensor> mask_tensor_list;
// construct the mask matrix from the sequence lengths
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
DenseTensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(phi::make_ddim({time_step, input->dims()[1]}));
CreateMaskMatrix<T>(
dev_ctx, sequence_length, &mask_matrix, is_reverse, &mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
// define the init_h holder for the swap
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
std::vector<DenseTensor> cell_value_tensors;
std::vector<DenseTensor> cell_act_value_tensors;
DenseTensor init_h_temp;
Copy(dev_ctx, *&init_h[layer_idx], dev_ctx.GetPlace(), false, &init_h_temp);
DenseTensor* init_h_holder = &init_h_temp;
DenseTensor* last_h_holder = nullptr;
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
const DenseTensor* init_c_holder = nullptr;
DenseTensor* last_c_holder = nullptr;
DenseTensor* last_c_act_holder = nullptr;
if (is_lstm(mode) || is_gru(mode)) {
cell_value->Resize({time_step, cell_value->numel() / time_step});
cell_value_tensors = Unbind(*cell_value);
if (is_lstm(mode)) {
cell_act_value->Resize(
{time_step, cell_act_value->numel() / time_step});
cell_act_value_tensors = Unbind(*cell_act_value);
}
}
DenseTensor weight_hh_tmp; // for gru
if (is_gru(mode)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
dev_ctx.Alloc<T>(&weight_hh_tmp);
Copy(dev_ctx,
vec[1 + offset * 4],
dev_ctx.GetPlace(),
false,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
phi::funcs::SetConstant<CPUContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
bool in_mask = (reverse_flag * i) >= mask_min_length;
if (is_lstm(mode)) {
if (i == 0) {
init_c_holder = &init_c[layer_idx];
} else {
init_c_holder = &cell_value_tensors[i - 1];
}
cell_value_tensors[i].Resize(init_c[layer_idx].dims());
cell_act_value_tensors[i].Resize(init_c[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
last_c_act_holder = &cell_act_value_tensors[i];
} else if (is_gru(mode)) {
cell_value_tensors[i].Resize(init_h[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
}
cell_(&dev_ctx,
&input_tensors[i],
&vec[1 + offset * 4],
init_h_holder,
init_c_holder,
last_h_holder,
last_c_holder,
last_c_act_holder,
&output_tensors[i],
&vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(dev_ctx,
&output_tensors[i],
init_h_holder,
init_c_holder,
last_h_holder,
last_c_holder,
mask_tensor_list[i],
mode);
}
// prepare next step
if (i + 1 < time_step) {
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
Copy(dev_ctx,
*last_h_holder,
dev_ctx.GetPlace(),
false,
&(*last_h_ptr)[layer_idx]);
}
} else {
Copy(dev_ctx,
output_tensors[time_step - 1],
dev_ctx.GetPlace(),
false,
&(*last_h_ptr)[layer_idx]);
}
if (is_lstm(mode)) {
Copy(dev_ctx,
cell_value_tensors[time_step - 1],
dev_ctx.GetPlace(),
false,
&(*last_c_ptr)[layer_idx]);
}
}
// Cell for the rnn module
CellType cell_;
};
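// SingleLayer runs one direction of a layer; BidirLayer runs the forward and
// reverse passes into separate buffers and concatenates them along the hidden
// dimension to form the layer output.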
template <typename T, typename CellType>
struct SingleLayer : public Layer<T, CellType> {
explicit SingleLayer(const CellType& cell) : Layer<T, CellType>(cell) {}
void operator()(const CPUContext& dev_ctx,
const DenseTensor* input,
const std::vector<DenseTensor>& vec,
const std::vector<DenseTensor>& init_h,
const std::vector<DenseTensor>& init_c,
const DenseTensor* sequence_length,
std::vector<DenseTensor> last_h,
std::vector<DenseTensor> last_c,
DenseTensor* output,
const int& layer_idx,
const int& gate_num,
DenseTensor* gate_value,
DenseTensor* cell_value,
DenseTensor* cell_act_value,
const std::string& mode,
bool is_test) {
this->RunIter(dev_ctx,
input,
vec,
init_h,
init_c,
sequence_length,
&last_h,
&last_c,
output,
layer_idx,
gate_value,
cell_value,
cell_act_value,
false,
0,
mode,
is_test);
}
};
template <typename T, typename CellType>
struct BidirLayer : public Layer<T, CellType> {
explicit BidirLayer(const CellType& cell) : Layer<T, CellType>(cell) {}
void operator()(const CPUContext& dev_ctx,
const DenseTensor* input,
const std::vector<DenseTensor>& vec,
const std::vector<DenseTensor>& init_h,
const std::vector<DenseTensor>& init_c,
const DenseTensor* sequence_length,
std::vector<DenseTensor> last_h,
std::vector<DenseTensor> last_c,
DenseTensor* output,
const int& layer_idx,
const int& gate_num,
DenseTensor* gate_value,
DenseTensor* cell_value,
DenseTensor* cell_act_value,
const std::string& mode,
bool is_test) {
std::vector<DenseTensor> output_vec(2);
DenseTensor forward_input_w, forward_cell_value, forward_cell_act_value;
DenseTensor backward_input_w, backward_cell_value, backward_cell_act_value;
int time_step = input->dims()[0];
int batch_size = input->dims()[1];
int hidden_size = output->dims()[2];
for (int i = 0; i < 2; ++i) {
output_vec[i].Resize({time_step, batch_size, hidden_size / 2});
dev_ctx.Alloc<T>(&output_vec[i]);
}
if (!is_test) {
gate_value->Resize({2, gate_value->numel() / 2});
forward_input_w = gate_value->Slice(0, 1);
backward_input_w = gate_value->Slice(1, 2);
if (is_lstm(mode) || is_gru(mode)) /* for lstm and gru */ {
cell_value->Resize({2, cell_value->numel() / 2});
cell_act_value->Resize({2, cell_act_value->numel() / 2});
forward_cell_value = cell_value->Slice(0, 1);
backward_cell_value = cell_value->Slice(1, 2);
if (is_lstm(mode)) {
forward_cell_act_value = cell_act_value->Slice(0, 1);
backward_cell_act_value = cell_act_value->Slice(1, 2);
}
}
}
this->RunIter(dev_ctx,
input,
vec,
init_h,
init_c,
sequence_length,
&last_h,
&last_c,
&output_vec[0],
layer_idx,
&forward_input_w,
&forward_cell_value,
&forward_cell_act_value,
true,
0,
mode,
is_test);
this->RunIter(dev_ctx,
input,
vec,
init_h,
init_c,
sequence_length,
&last_h,
&last_c,
&output_vec[1],
layer_idx,
&backward_input_w,
&backward_cell_value,
&backward_cell_act_value,
true,
1,
mode,
is_test);
// concat the forward and backward output results
funcs::ConcatFunctor<CPUContext, T> concat_functor;
concat_functor(dev_ctx, output_vec, static_cast<int>(2), output);
}
};
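// RnnKernel: clears a stale dropout_state and refills it with ones at the
// output shape, allocates the outputs, and dispatches RnnFunc with the cell
// type and gate number matching the rnn mode (LSTM: 4, GRU: 3, simple RNN: 1).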
template <typename T, typename Context>
void RnnKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& pre_state,
const std::vector<const DenseTensor*>& weight_list,
paddle::optional<const DenseTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
DenseTensor* out,
DenseTensor* dropout_state,
std::vector<DenseTensor*> state,
DenseTensor* reserve) {
if (dropout_state->IsInitialized()) {
if (dropout_state->numel() != out->numel()) dropout_state->clear();
}
const auto& out_dim = out->dims();
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
// init the output and allocate the memory
dev_ctx.template Alloc<T>(out);
int gate_num = 4;
dev_ctx.template Alloc<T>(state[0]);
if (is_lstm(mode)) {
dev_ctx.template Alloc<T>(state[1]);
RnnFunc<LSTMCell<T>, Layer, SingleLayer, BidirLayer, T>(
dev_ctx,
&x,
weight_list,
pre_state[0],
pre_state[1],
sequence_length.get_ptr(),
state[0],
state[1],
out,
dropout_state,
num_layers,
gate_num,
input_size,
hidden_size,
is_bidirec,
mode,
dropout_prob,
is_test,
seed,
reserve);
} else if (is_rnn_relu(mode)) {
gate_num = 1;
RnnFunc<SimpleRNNCell<T,
funcs::ReluCPUFunctor,
phi::funcs::detail::ActivationType::kReLU>,
Layer,
SingleLayer,
BidirLayer,
T>(dev_ctx,
&x,
weight_list,
pre_state[0],
nullptr,
sequence_length.get_ptr(),
state[0],
nullptr,
out,
dropout_state,
num_layers,
gate_num,
input_size,
hidden_size,
is_bidirec,
mode,
dropout_prob,
is_test,
seed,
reserve);
} else if (is_rnn_tanh(mode)) {
gate_num = 1;
RnnFunc<SimpleRNNCell<T,
funcs::TanhFunctor,
phi::funcs::detail::ActivationType::kTanhV2>,
Layer,
SingleLayer,
BidirLayer,
T>(dev_ctx,
&x,
weight_list,
pre_state[0],
nullptr,
sequence_length.get_ptr(),
state[0],
nullptr,
out,
dropout_state,
num_layers,
gate_num,
input_size,
hidden_size,
is_bidirec,
mode,
dropout_prob,
is_test,
seed,
reserve);
} else if (is_gru(mode)) {
gate_num = 3;
RnnFunc<GRUCell<T>, Layer, SingleLayer, BidirLayer, T>(
dev_ctx,
&x,
weight_list,
pre_state[0],
nullptr,
sequence_length.get_ptr(),
state[0],
nullptr,
out,
dropout_state,
num_layers,
gate_num,
input_size,
hidden_size,
is_bidirec,
mode,
dropout_prob,
is_test,
seed,
reserve);
}
}
} // namespace phi
PD_REGISTER_KERNEL(rnn, CPU, ALL_LAYOUT, phi::RnnKernel, float, double) {}
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h" #include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/gru_compute.h" #include "paddle/phi/kernels/funcs/gru_compute.h"
...@@ -283,11 +283,10 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -283,11 +283,10 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
#endif #endif
} }
template <typename T> template <typename T, typename Context>
inline void forward_reset_outputV2( inline void forward_reset_outputV2(const Context &context,
const paddle::platform::CPUDeviceContext &context, phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaValue<T> value, int frame_size) {
int frame_size) {
auto &place = *context.eigen_device(); auto &place = *context.eigen_device();
auto value_reset_gate = auto value_reset_gate =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size)); typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
...@@ -297,23 +296,20 @@ inline void forward_reset_outputV2( ...@@ -297,23 +296,20 @@ inline void forward_reset_outputV2(
value.reset_output_value, Array1(frame_size)); value.reset_output_value, Array1(frame_size));
auto value_reset_bias = auto value_reset_bias =
typename EigenVector<T>::ConstType(value.reset_bias, Array1(frame_size)); typename EigenVector<T>::ConstType(value.reset_bias, Array1(frame_size));
paddle::operators::SigmoidFunctor<T>()( SigmoidFunctor<T>()(place, value_reset_gate, value_reset_gate);
place, value_reset_gate, value_reset_gate); SigmoidFunctor<T>()(place, value_update_gate, value_update_gate);
paddle::operators::SigmoidFunctor<T>()(
place, value_update_gate, value_update_gate);
value_reset_output.device(place) = value_reset_output.device(place) =
(value_reset_output + value_reset_bias) * value_reset_gate; (value_reset_output + value_reset_bias) * value_reset_gate;
} }
template <class OpResetOutput, typename T> template <typename Context, class OpResetOutput, typename T>
inline void forward_reset_output( inline void forward_reset_output(OpResetOutput op_reset_output,
OpResetOutput op_reset_output, phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaValue<T> value, int frame_size,
int frame_size, int batch_size,
int batch_size, ActivationType active_gate,
ActivationType active_gate, bool old_version = true,
bool old_version = true, const Context *context = nullptr) {
const paddle::platform::CPUDeviceContext *context = nullptr) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (!old_version) { if (!old_version) {
// use eigen // use eigen
...@@ -348,11 +344,10 @@ inline void forward_reset_output( ...@@ -348,11 +344,10 @@ inline void forward_reset_output(
} }
} }
template <typename T> template <typename T, typename Context>
inline void forward_final_outputV2( inline void forward_final_outputV2(const Context &context,
const paddle::platform::CPUDeviceContext &context, phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaValue<T> value, int frame_size) {
int frame_size) {
auto &place = *context.eigen_device(); auto &place = *context.eigen_device();
auto value_update_gate = typename EigenVector<T>::Type( auto value_update_gate = typename EigenVector<T>::Type(
value.gate_value + frame_size, Array1(frame_size)); value.gate_value + frame_size, Array1(frame_size));
...@@ -360,8 +355,7 @@ inline void forward_final_outputV2( ...@@ -360,8 +355,7 @@ inline void forward_final_outputV2(
value.gate_value + 2 * frame_size, Array1(frame_size)); value.gate_value + 2 * frame_size, Array1(frame_size));
auto value_output = auto value_output =
typename EigenVector<T>::Type(value.output_value, Array1(frame_size)); typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
paddle::operators::TanhFunctor<T>()( TanhFunctor<T>()(place, value_frame_state, value_frame_state);
place, value_frame_state, value_frame_state);
value_output.device(place) = value_output.device(place) =
(static_cast<T>(1.0) - value_update_gate) * value_frame_state; (static_cast<T>(1.0) - value_update_gate) * value_frame_state;
if (value.prev_out_value) { if (value.prev_out_value) {
...@@ -372,16 +366,15 @@ inline void forward_final_outputV2( ...@@ -372,16 +366,15 @@ inline void forward_final_outputV2(
} }
} }
template <class OpFinalOutput, typename T> template <typename Context, class OpFinalOutput, typename T>
inline void forward_final_output( inline void forward_final_output(OpFinalOutput op_final_output,
OpFinalOutput op_final_output, phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaValue<T> value, int frame_size,
int frame_size, int batch_size,
int batch_size, ActivationType active_node,
ActivationType active_node, bool origin_mode,
bool origin_mode, bool old_version = true,
bool old_version = true, const Context *context = nullptr) {
const paddle::platform::CPUDeviceContext *context = nullptr) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (!old_version) { if (!old_version) {
// eigen // eigen
...@@ -871,8 +864,8 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad, ...@@ -871,8 +864,8 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad,
} }
} }
template <typename T> template <typename T, typename Context>
inline void gru_backward(const paddle::platform::CPUDeviceContext &context, inline void gru_backward(const Context &context,
phi::funcs::GRUMetaValue<T> value, phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaGrad<T> grad, phi::funcs::GRUMetaGrad<T> grad,
int frame_size) { int frame_size) {
...@@ -901,14 +894,13 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context, ...@@ -901,14 +894,13 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
if (value.prev_out_value) { if (value.prev_out_value) {
auto value_prev_out = typename EigenVector<T>::ConstType( auto value_prev_out = typename EigenVector<T>::ConstType(
value.prev_out_value, Array1(frame_size)); value.prev_out_value, Array1(frame_size));
paddle::operators::SigmoidGradFunctor<T>()( SigmoidGradFunctor<T>()(place,
place, 1 /*useless*/,
1 /*useless*/, value_update_gate,
value_update_gate, (value_prev_out - value_frame_state) * grad_output,
(value_prev_out - value_frame_state) * grad_output, grad_update_gate);
grad_update_gate);
} else { } else {
paddle::operators::SigmoidGradFunctor<T>()( SigmoidGradFunctor<T>()(
place, place,
1 /*useless*/, 1 /*useless*/,
value_update_gate, value_update_gate,
...@@ -921,13 +913,12 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context, ...@@ -921,13 +913,12 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
grad_prev_out.device(place) = grad_prev_out.device(place) =
grad_prev_out + grad_output * value_update_gate; grad_prev_out + grad_output * value_update_gate;
} }
paddle::operators::TanhGradFunctor<T>()( TanhGradFunctor<T>()(place,
place, 1 /*useless*/,
1 /*useless*/, value_frame_state,
value_frame_state, grad_output * (static_cast<T>(1.0) - value_update_gate),
grad_output * (static_cast<T>(1.0) - value_update_gate), grad_frame_state);
grad_frame_state); SigmoidGradFunctor<T>()(
paddle::operators::SigmoidGradFunctor<T>()(
place, place,
1 /*useless*/, 1 /*useless*/,
value_reset_gate, value_reset_gate,
...@@ -938,8 +929,8 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context, ...@@ -938,8 +929,8 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
} }
} }
template <class OpGruGrad, typename T> template <class OpGruGrad, typename T, typename Context>
inline void cpu_gru_backward(const paddle::platform::CPUDeviceContext &context, inline void cpu_gru_backward(const Context &context,
OpGruGrad op_gru_grad, OpGruGrad op_gru_grad,
phi::funcs::GRUMetaValue<T> value, phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaGrad<T> grad, phi::funcs::GRUMetaGrad<T> grad,
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h" #include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h" #include "paddle/phi/kernels/funcs/lstm_compute.h"
...@@ -409,11 +409,10 @@ void avx_lstm_backward_one_sequence(Op op, ...@@ -409,11 +409,10 @@ void avx_lstm_backward_one_sequence(Op op,
#endif #endif
} }
template <class T> template <class T, class Context>
void eigen_lstm_forward_one_sequence( void eigen_lstm_forward_one_sequence(const Context &context,
const paddle::platform::CPUDeviceContext &context, phi::funcs::LstmMetaValue<T> value,
phi::funcs::LstmMetaValue<T> value, int frame_size) {
int frame_size) {
auto eigen_value_ig = auto eigen_value_ig =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size)); typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto eigen_value_fg = typename EigenVector<T>::Type( auto eigen_value_fg = typename EigenVector<T>::Type(
...@@ -430,10 +429,10 @@ void eigen_lstm_forward_one_sequence( ...@@ -430,10 +429,10 @@ void eigen_lstm_forward_one_sequence(
typename EigenVector<T>::Type(value.output_value, Array1(frame_size)); typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
auto &place = *context.eigen_device(); auto &place = *context.eigen_device();
paddle::operators::TanhFunctor<T>()(place, eigen_value_in, eigen_value_in); TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
paddle::operators::SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig); SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
paddle::operators::SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg); SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
paddle::operators::SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og); SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);
eigen_state.device(place) = eigen_value_in * eigen_value_ig; eigen_state.device(place) = eigen_value_in * eigen_value_ig;
if (value.prev_state_value) { if (value.prev_state_value) {
...@@ -442,16 +441,15 @@ void eigen_lstm_forward_one_sequence( ...@@ -442,16 +441,15 @@ void eigen_lstm_forward_one_sequence(
eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg; eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg;
} }
paddle::operators::TanhFunctor<T>()(place, eigen_state, eigen_state_act); TanhFunctor<T>()(place, eigen_state, eigen_state_act);
eigen_output.device(place) = eigen_value_og * eigen_state_act; eigen_output.device(place) = eigen_value_og * eigen_state_act;
} }
template <class T> template <class T, class Context>
void eigen_lstm_backward_one_sequence( void eigen_lstm_backward_one_sequence(const Context &context,
const paddle::platform::CPUDeviceContext &context, phi::funcs::LstmMetaValue<T> value,
phi::funcs::LstmMetaValue<T> value, phi::funcs::LstmMetaGrad<T> grad,
phi::funcs::LstmMetaGrad<T> grad, int frame_size) {
int frame_size) {
auto eigen_value_ig = auto eigen_value_ig =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size)); typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto eigen_value_fg = typename EigenVector<T>::Type( auto eigen_value_fg = typename EigenVector<T>::Type(
...@@ -477,38 +475,35 @@ void eigen_lstm_backward_one_sequence( ...@@ -477,38 +475,35 @@ void eigen_lstm_backward_one_sequence(
typename EigenVector<T>::Type(grad.state_grad, Array1(frame_size)); typename EigenVector<T>::Type(grad.state_grad, Array1(frame_size));
auto &place = *context.eigen_device(); auto &place = *context.eigen_device();
paddle::operators::SigmoidGradFunctor<T>()( SigmoidGradFunctor<T>()(place,
place, 1 /*useless*/,
1 /*useless*/, eigen_value_og,
eigen_value_og, eigen_grad_output * eigen_state_act,
eigen_grad_output * eigen_state_act, eigen_grad_og);
eigen_grad_og);
eigen_grad_state.device(place) = eigen_grad_state.device(place) =
eigen_grad_state + eigen_grad_state +
eigen_grad_output * eigen_value_og * eigen_grad_output * eigen_value_og *
(static_cast<T>(1) - eigen_state_act * eigen_state_act); (static_cast<T>(1) - eigen_state_act * eigen_state_act);
paddle::operators::TanhGradFunctor<T>()(place, TanhGradFunctor<T>()(place,
1, 1,
eigen_value_in, eigen_value_in,
eigen_grad_state * eigen_value_ig, eigen_grad_state * eigen_value_ig,
eigen_grad_in); eigen_grad_in);
paddle::operators::SigmoidGradFunctor<T>()(place, SigmoidGradFunctor<T>()(place,
1, 1,
eigen_value_ig, eigen_value_ig,
eigen_grad_state * eigen_value_in, eigen_grad_state * eigen_value_in,
eigen_grad_ig); eigen_grad_ig);
if (value.prev_state_value) { if (value.prev_state_value) {
auto eigen_prev_state = typename EigenVector<T>::ConstType( auto eigen_prev_state = typename EigenVector<T>::ConstType(
value.prev_state_value, Array1(frame_size)); value.prev_state_value, Array1(frame_size));
paddle::operators::SigmoidGradFunctor<T>()( SigmoidGradFunctor<T>()(place,
place, 1,
1, eigen_value_fg,
eigen_value_fg, eigen_grad_state * eigen_prev_state,
eigen_grad_state * eigen_prev_state, eigen_grad_fg);
eigen_grad_fg);
} else { } else {
paddle::operators::SigmoidGradFunctor<T>()( SigmoidGradFunctor<T>()(place, 1, eigen_value_fg, 0, eigen_grad_fg);
place, 1, eigen_value_fg, 0, eigen_grad_fg);
} }
if (grad.prev_state_grad) { if (grad.prev_state_grad) {
auto eigen_grad_pre_state = auto eigen_grad_pre_state =
...@@ -517,8 +512,8 @@ void eigen_lstm_backward_one_sequence( ...@@ -517,8 +512,8 @@ void eigen_lstm_backward_one_sequence(
} }
} }
template <class T, class Op> template <class T, class Op, class Context>
void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context, void cpu_lstm_forward(const Context &context,
Op op, Op op,
phi::funcs::LstmMetaValue<T> value, phi::funcs::LstmMetaValue<T> value,
int frame_size, int frame_size,
...@@ -552,8 +547,8 @@ void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context, ...@@ -552,8 +547,8 @@ void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context,
} }
} }
template <class T, class Op> template <class T, class Op, class Context>
void cpu_lstm_backward(const paddle::platform::CPUDeviceContext &context, void cpu_lstm_backward(const Context &context,
Op op, Op op,
phi::funcs::LstmMetaValue<T> value, phi::funcs::LstmMetaValue<T> value,
phi::funcs::LstmMetaGrad<T> grad, phi::funcs::LstmMetaGrad<T> grad,
......
...@@ -46,7 +46,7 @@ struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, T> { ...@@ -46,7 +46,7 @@ struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, T> {
frame_size * 3); frame_size * 3);
} }
detail::forward_reset_output( detail::forward_reset_output<paddle::platform::CPUDeviceContext>(
phi::funcs::detail::forward::gru_resetOutput<T>(), phi::funcs::detail::forward::gru_resetOutput<T>(),
value, value,
frame_size, frame_size,
...@@ -71,7 +71,7 @@ struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, T> { ...@@ -71,7 +71,7 @@ struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, T> {
frame_size * 3); frame_size * 3);
} }
detail::forward_final_output( detail::forward_final_output<paddle::platform::CPUDeviceContext>(
phi::funcs::detail::forward::gru_finalOutput<T>(), phi::funcs::detail::forward::gru_finalOutput<T>(),
value, value,
frame_size, frame_size,
...@@ -233,6 +233,59 @@ struct GRUUnitFunctorV2<paddle::platform::CPUDeviceContext, T> { ...@@ -233,6 +233,59 @@ struct GRUUnitFunctorV2<paddle::platform::CPUDeviceContext, T> {
} }
}; };
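// CPUContext specialization added alongside the CPUDeviceContext version
// above; it mirrors that implementation but passes the context through the
// Context-templated detail helpers.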
template <typename T>
struct GRUUnitFunctorV2<CPUContext, T> {
static void compute(const CPUContext &context,
GRUMetaValue<T> value,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate) {
#if !defined(__NVCC__) && !defined(__HIPCC__)
auto blas = phi::funcs::GetBlas<CPUContext, T>(context);
if (value.prev_out_value) {
blas.GEMM(CblasNoTrans,
CblasTrans,
batch_size,
frame_size,
frame_size,
1,
value.prev_out_value,
value.state_weight,
0,
value.reset_output_value);
}
detail::forward_reset_output(
phi::funcs::detail::forward::gru_resetOutput<T>(),
value,
frame_size,
batch_size,
active_gate,
false,
&context);
T *cell_state_value = value.gate_value + 2 * frame_size;
T *reset_output_value = value.reset_output_value;
for (int b = 0; b < batch_size; ++b) {
blas.VADD(
frame_size, cell_state_value, reset_output_value, cell_state_value);
cell_state_value += frame_size * 3;
reset_output_value += frame_size;
}
detail::forward_final_output(
phi::funcs::detail::forward::gru_finalOutput<T>(),
value,
frame_size,
batch_size,
active_node,
true,
false,
&context);
#endif
}
};
template <typename T> template <typename T>
struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, T> { struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, T> {
static void compute(const paddle::platform::CPUDeviceContext &context, static void compute(const paddle::platform::CPUDeviceContext &context,
...@@ -358,6 +411,130 @@ struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, T> { ...@@ -358,6 +411,130 @@ struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, T> {
} }
}; };
template <typename T>
struct GRUUnitGradFunctorV2<CPUContext, T> {
static void compute(const CPUContext &context,
GRUMetaValue<T> value,
GRUMetaGrad<T> grad,
int frame_size,
int batch_size,
const phi::funcs::detail::ActivationType active_node,
const phi::funcs::detail::ActivationType active_gate) {
#if !defined(__NVCC__) && !defined(__HIPCC__)
// calculate grad_update_gate, grad_frame_state,
// grad_reset_output, grad_reset_gate
detail::cpu_gru_backward(context,
phi::funcs::detail::backward::gru<T>(),
value,
grad,
frame_size,
batch_size,
active_node,
active_gate);
auto blas = phi::funcs::GetBlas<CPUContext, T>(context);
if (grad.prev_out_grad && value.prev_out_value) {
// update prev_out_grad
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
grad.gate_grad,
frame_size * 3,
value.gate_weight,
frame_size,
1,
grad.prev_out_grad,
frame_size);
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
grad.gate_grad + frame_size,
frame_size * 3,
value.gate_weight + frame_size * frame_size,
frame_size,
1,
grad.prev_out_grad,
frame_size);
blas.GEMM(false,
false,
batch_size,
frame_size,
frame_size,
1,
grad.reset_output_grad,
frame_size,
value.state_weight,
frame_size,
1,
grad.prev_out_grad,
frame_size);
// update weight_hh_grad
if (grad.gate_weight_grad) {
// reset gate
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
grad.gate_grad,
frame_size * 3,
value.prev_out_value,
frame_size,
1,
grad.gate_weight_grad,
frame_size);
// update gate
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
grad.gate_grad + frame_size,
frame_size * 3,
value.prev_out_value,
frame_size,
1,
grad.gate_weight_grad + frame_size * frame_size,
frame_size);
// cell state
blas.GEMM(true,
false,
frame_size,
frame_size,
batch_size,
1,
grad.reset_output_grad,
frame_size,
value.prev_out_value,
frame_size,
1,
grad.state_weight_grad,
frame_size);
}
}
// update bias_hh_grad
T *gate_grad = grad.gate_grad;
T *bias_hh_grad = grad.bias_hh_grad;
T *state_bias_grad = grad.bias_hh_grad + 2 * frame_size;
T *reset_output_grad = grad.reset_output_grad;
for (int b = 0; b < batch_size; ++b) {
blas.VADD(2 * frame_size, bias_hh_grad, gate_grad, bias_hh_grad);
blas.VADD(
frame_size, state_bias_grad, reset_output_grad, state_bias_grad);
gate_grad += 3 * frame_size;
reset_output_grad += frame_size;
}
#endif
}
};
template struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, float>; template struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, float>;
template struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, double>; template struct GRUUnitFunctor<paddle::platform::CPUDeviceContext, double>;
template struct GRUUnitGradFunctor<paddle::platform::CPUDeviceContext, float>; template struct GRUUnitGradFunctor<paddle::platform::CPUDeviceContext, float>;
...@@ -369,5 +546,10 @@ template struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, float>; ...@@ -369,5 +546,10 @@ template struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, float>;
template struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext, template struct GRUUnitGradFunctorV2<paddle::platform::CPUDeviceContext,
double>; double>;
template struct GRUUnitFunctorV2<CPUContext, float>;
template struct GRUUnitFunctorV2<CPUContext, double>;
template struct GRUUnitGradFunctorV2<CPUContext, float>;
template struct GRUUnitGradFunctorV2<CPUContext, double>;
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/kernels/funcs/lstm_compute.h" #include "paddle/phi/kernels/funcs/lstm_compute.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h" #include "paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h"
#include "paddle/phi/kernels/funcs/detail/lstm_kernel.h" #include "paddle/phi/kernels/funcs/detail/lstm_kernel.h"
...@@ -51,6 +53,38 @@ struct LstmUnitFunctor<paddle::platform::CPUDeviceContext, T> { ...@@ -51,6 +53,38 @@ struct LstmUnitFunctor<paddle::platform::CPUDeviceContext, T> {
} }
}; };
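// CPUContext specialization of the LSTM unit functors, added alongside the
// existing CPUDeviceContext specialization so the phi CPU RNN kernels can use
// phi::CPUContext directly.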
template <class T>
struct LstmUnitFunctor<CPUContext, T> {
static void compute(const CPUContext& context,
LstmMetaValue<T> value,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(context,
phi::funcs::detail::forward::lstm<T>(),
value,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
}
}
};
template <class T> template <class T>
struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> { struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
static void compute(const paddle::platform::CPUDeviceContext& context, static void compute(const paddle::platform::CPUDeviceContext& context,
...@@ -94,10 +128,58 @@ struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> { ...@@ -94,10 +128,58 @@ struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
} }
}; };
template <class T>
struct LstmUnitGradFunctor<CPUContext, T> {
static void compute(const CPUContext& context,
LstmMetaValue<T> value,
LstmMetaGrad<T> grad,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(context,
phi::funcs::detail::backward::lstm<T>(),
value,
grad,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
grad.gate_grad += frame_size * 4;
grad.state_grad += frame_size;
grad.state_active_grad += frame_size;
grad.output_grad += frame_size;
if (grad.prev_state_grad) {
grad.prev_state_grad += frame_size;
}
}
}
};
template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, float>; template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, float>;
template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, double>; template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, double>;
template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, float>; template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, float>;
template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, double>; template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, double>;
template class LstmUnitFunctor<CPUContext, float>;
template class LstmUnitFunctor<CPUContext, double>;
template class LstmUnitGradFunctor<CPUContext, float>;
template class LstmUnitGradFunctor<CPUContext, double>;
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace phi {
#ifdef PADDLE_WITH_HIP
using gpuRNNMode_t = miopenRNNMode_t;
using gpuDnnHandle_t = miopenHandle_t;
using gpuDnnDataType_t = miopenDataType_t;
#else
using gpuRNNMode_t = cudnnRNNMode_t;
using gpuDnnHandle_t = cudnnHandle_t;
using gpuDnnDataType_t = cudnnDataType_t;
#endif
class RNNDescriptors {
public:
RNNDescriptors(int seq_length,
int batch_size,
int input_size,
int hidden_size,
int num_layers,
float dropout_prob,
int seed,
int weight_numel,
gpuRNNMode_t mode,
bool is_bidirec,
bool is_test)
: seq_length_(seq_length),
batch_size_(batch_size),
input_size_(input_size),
hidden_size_(hidden_size),
num_layers_(num_layers),
dropout_prob_(dropout_prob),
seed_(seed),
weight_numel_(weight_numel),
mode_(mode),
is_bidirec_(is_bidirec),
is_test_(is_test) {}
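// Create builds all cuDNN/MIOpen descriptors needed by one RNN call: per-step
// x/y tensor descriptors, the hx/cx/hy/cy state descriptors, the dropout
// descriptor (allocating dropout_state on first use in training), the RNN and
// weight descriptors, and finally queries the workspace and reserve sizes.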
template <typename T>
void Create(const gpuDnnHandle_t &handle,
const Place &place,
const std::vector<int> &sequence_length,
size_t *workspace_size,
size_t *reserve_size,
DenseTensor *dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
gpuDnnDataType_t cudnn_type = paddle::platform::CudnnDataType<T>::type;
// ------------------- cudnn x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
std::vector<int> strides_x = {input_size_, 1, 1};
std::vector<int> dims_y = {batch_size_, hidden_size_ * numDirections, 1};
std::vector<int> strides_y = {hidden_size_ * numDirections, 1, 1};
for (int i = 0; i < seq_length_; ++i) {
x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x));
y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y));
}
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
if (!sequence_length.empty()) {
x_seq_desc_.descriptor<T>(
seq_length_, batch_size_, input_size_, true, sequence_length);
y_seq_desc_.descriptor<T>(seq_length_,
batch_size_,
hidden_size_ * numDirections,
true,
sequence_length);
}
#endif
// ------------------- cudnn hx, hy, cx, cy descriptors----------
std::vector<int> dims_hx = {
num_layers_ * numDirections, batch_size_, hidden_size_};
std::vector<int> strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1};
init_h_desc_.descriptor<T>(dims_hx, strides_hx);
init_c_desc_.descriptor<T>(dims_hx, strides_hx);
last_h_desc_.descriptor<T>(dims_hx, strides_hx);
last_c_desc_.descriptor<T>(dims_hx, strides_hx);
// ------------------- cudnn dropout descriptors ---------------------
size_t state_size;
bool is_initialized = dropout_state->IsInitialized();
if (!is_test_ && !is_initialized) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenDropoutGetStatesSize(handle,
&state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
#else
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnDropoutGetStatesSize(handle,
&state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
#endif
}
dropout_desc_.descriptor(handle,
place,
is_initialized,
dropout_prob_,
is_test_ ? nullptr : dropout_state,
seed_,
state_size);
// ------------------- cudnn rnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenSetRNNDescriptor_V2(
rnn_desc_.desc(),
hidden_size_,
num_layers_,
dropout_desc_.desc(),
miopenRNNlinear,
is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection,
mode_,
miopenRNNwithBias,
miopenRNNdefault,
cudnn_type));
#elif CUDNN_VERSION >= 6000
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnSetRNNDescriptor_v6(
handle,
rnn_desc_.desc(),
hidden_size_,
num_layers_,
dropout_desc_.desc(),
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL,
mode_,
CUDNN_RNN_ALGO_STANDARD,
cudnn_type));
#else
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_.desc(),
hidden_size_,
num_layers_,
dropout_desc_.desc(),
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL,
mode_,
cudnn_type));
#endif
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
if (!sequence_length.empty()) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnSetRNNPaddingMode(
rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
}
#endif
// ------------------- cudnn weights_size ---------------------
size_t weights_size_;
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenGetRNNParamsSize(
handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type));
#else
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type));
#endif
PADDLE_ENFORCE_EQ(
weights_size_,
sizeof(T) * weight_numel_,
phi::errors::InvalidArgument(
"The cudnn rnn and setting weight size should be same."));
// ------------------- cudnn weight descriptors ---------------------
auto layout = paddle::platform::DataLayout::kNCHW;
int dim_tmp = weights_size_ / sizeof(T);
std::vector<int> dim_w = {dim_tmp, 1, 1};
weight_desc_.descriptor<T>(layout, dim_w);
// ------------------- cudnn workspace, reserve size ---------------------
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenGetRNNWorkspaceSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenGetRNNTrainingReserveSize(
handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnGetRNNWorkspaceSize(handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
workspace_size));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnGetRNNTrainingReserveSize(
handle,
rnn_desc_.desc(),
seq_length_,
x_descs_.data(),
reserve_size));
#endif
}
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t *x_descs() { return x_descs_.data(); }
miopenTensorDescriptor_t *y_descs() { return y_descs_.data(); }
miopenTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
miopenTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
miopenTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
miopenTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); }
miopenRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); }
miopenDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); }
miopenTensorDescriptor_t weight_desc() { return weight_desc_.desc(); }
#else
cudnnTensorDescriptor_t *x_descs() { return x_descs_.data(); }
cudnnTensorDescriptor_t *y_descs() { return y_descs_.data(); }
#if CUDNN_VERSION >= 7201
cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); }
cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); }
#endif
cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
cudnnTensorDescriptor_t last_c_desc() { return last_c_desc_.desc(); }
cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); }
cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); }
cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); }
#endif
private:
int seq_length_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
float dropout_prob_;
int seed_;
int weight_numel_;
gpuRNNMode_t mode_;
bool is_bidirec_;
bool is_test_;
#ifdef PADDLE_WITH_HIP
std::vector<miopenTensorDescriptor_t> x_descs_;
std::vector<miopenTensorDescriptor_t> y_descs_;
#else
std::vector<cudnnTensorDescriptor_t> x_descs_;
std::vector<cudnnTensorDescriptor_t> y_descs_;
#endif
paddle::platform::ScopedTensorDescriptor x_desc_;
paddle::platform::ScopedTensorDescriptor y_desc_;
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
paddle::platform::ScopedRNNTensorDescriptor x_seq_desc_;
paddle::platform::ScopedRNNTensorDescriptor y_seq_desc_;
#endif
paddle::platform::ScopedTensorDescriptor init_h_desc_;
paddle::platform::ScopedTensorDescriptor init_c_desc_;
paddle::platform::ScopedTensorDescriptor last_h_desc_;
paddle::platform::ScopedTensorDescriptor last_c_desc_;
paddle::platform::ScopedDropoutDescriptor dropout_desc_;
paddle::platform::ScopedFilterDescriptor weight_desc_;
paddle::platform::ScopedRNNDescriptor rnn_desc_;
};
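// Returns true when the tensors in weight_list occupy one contiguous block of
// memory, i.e. each tensor's data ends exactly where the next one begins.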
template <typename T, typename Type>
bool IsContinuous(const Type &weight_list) {
bool continuous = true;
for (size_t i = 0; i < weight_list.size() - 1; ++i) {
auto *in_data = weight_list[i]->template data<T>();
auto *in_after_data = weight_list[i + 1]->template data<T>();
auto in_size = weight_list[i]->numel();
bool temp = in_data + in_size == in_after_data;
continuous = continuous && temp;
}
return continuous;
}
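// Copies every tensor in weight_list back-to-back into the single flattened
// `weight` tensor on the given stream.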
template <typename T>
void WeightToTensor(const Place &place,
gpuStream_t stream,
const std::vector<const DenseTensor *> &weight_list,
DenseTensor *weight) {
auto weight_data = weight->data<T>();
int weight_offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
const T *in_data = weight_list[i]->data<T>();
auto in_size = weight_list[i]->numel();
paddle::memory::Copy(weight->place(),
weight_data + weight_offset,
weight_list[i]->place(),
in_data,
in_size * sizeof(T),
stream);
weight_offset += in_size;
}
}
#ifdef PADDLE_WITH_HIP
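// Copies a list of already ordered tensors into `weight_whole` starting at
// `offset`; used by the MIOPEN weight-permutation helpers.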
template <typename T>
void WeightListToTensor(const Place &place,
gpuStream_t stream,
const std::vector<DenseTensor> &tensor_list,
DenseTensor *weight_whole,
const size_t offset = 0UL) {
size_t weight_offset = offset;
auto weight_data = weight_whole->data<T>();
for (size_t i = 0; i < tensor_list.size(); ++i) {
const T *in_data = tensor_list[i].data<T>();
auto in_size = tensor_list[i].numel();
paddle::memory::Copy(weight_whole->place(),
weight_data + weight_offset,
tensor_list[i].place(),
in_data,
in_size * sizeof(T),
stream);
weight_offset += in_size;
}
}
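// Flattens weight_list into weight_whole while permuting the gate weights into
// the layout MIOPEN expects: the four LSTM gate chunks are written in the order
// {0, 1, 3, 2} and the three GRU gate chunks in the order {1, 0, 2}. For
// bidirectional RNNs the second and third entries of every group of four
// weights are swapped first.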
template <typename T>
void WeightToPermutedTensor(const Place &place,
gpuStream_t stream,
std::vector<const DenseTensor *> *weight_list,
DenseTensor *weight_whole,
const gpuRNNMode_t rnn_mode,
const bool is_bidirec) {
if (is_bidirec) {
for (size_t i = 0; i < weight_list->size(); i += 4) {
auto tmp = (*weight_list)[i + 1];
(*weight_list)[i + 1] = (*weight_list)[i + 2];
(*weight_list)[i + 2] = tmp;
}
}
size_t weight_offset = 0;
for (size_t i = 0; i < weight_list->size(); ++i) {
if (rnn_mode == miopenLSTM) {
std::vector<DenseTensor> split_tensor = (*weight_list)[i]->Chunk(4, 0);
WeightListToTensor<T>(
place,
stream,
{split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
weight_whole,
weight_offset);
} else if (rnn_mode == miopenGRU) {
std::vector<DenseTensor> split_tensor = (*weight_list)[i]->Chunk(3, 0);
WeightListToTensor<T>(place,
stream,
{split_tensor[1], split_tensor[0], split_tensor[2]},
weight_whole,
weight_offset);
} else {
WeightListToTensor<T>(
place, stream, {*(*weight_list)[i]}, weight_whole, weight_offset);
}
weight_offset += (*weight_list)[i]->numel();
}
}
#endif
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/rnn_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/gpu/rnn_functor.h"
#include "paddle/fluid/operators/utils.h"
namespace phi {
#ifdef PADDLE_WITH_HIP
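// Inverse of WeightToPermutedTensor: scatters the flattened, MIOPEN-ordered
// weight gradient back into weight_grad_list, undoing the gate permutation and
// the bidirectional swap.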
template <typename T>
void TensorToPermutedWeight(const Place &place,
gpuStream_t stream,
const DenseTensor &tensor,
std::vector<DenseTensor *> *weight_grad_list,
const gpuRNNMode_t rnn_mode,
bool is_bidirec) {
if (is_bidirec) {
for (size_t i = 0; i < weight_grad_list->size(); i += 4) {
auto tmp = (*weight_grad_list)[i + 1];
(*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2];
(*weight_grad_list)[i + 2] = tmp;
}
}
size_t weight_offset = 0;
for (size_t i = 0; i < weight_grad_list->size(); ++i) {
auto numel_size = (*weight_grad_list)[i]->numel();
DenseTensor temp;
temp.Resize({numel_size});
temp.ShareDataWith(tensor.Slice(weight_offset, weight_offset + numel_size));
if (rnn_mode == miopenLSTM) {
std::vector<DenseTensor> split_tensor = temp.Chunk(4, 0);
WeightListToTensor<T>(
place,
stream,
{split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
(*weight_grad_list)[i]);
} else if (rnn_mode == miopenGRU) {
std::vector<DenseTensor> split_tensor = temp.Chunk(3, 0);
WeightListToTensor<T>(place,
stream,
{split_tensor[1], split_tensor[0], split_tensor[2]},
(*weight_grad_list)[i]);
} else {
WeightListToTensor<T>(place, stream, {temp}, (*weight_grad_list)[i]);
}
weight_offset += numel_size;
}
if (is_bidirec) {
for (size_t i = 0; i < weight_grad_list->size(); i += 4) {
auto tmp = (*weight_grad_list)[i + 1];
(*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2];
(*weight_grad_list)[i + 2] = tmp;
}
}
}
#endif
template <typename T, typename Context>
void RnnGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const std::vector<const DenseTensor *> &pre_state,
const std::vector<const DenseTensor *> &weight_list,
paddle::optional<const DenseTensor &> sequence_length,
const DenseTensor &out,
const DenseTensor &dropout_state,
const DenseTensor &reserve,
const DenseTensor &out_grad,
const std::vector<const DenseTensor *> &state_grad,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string &mode,
int seed,
bool is_test,
DenseTensor *x_grad,
std::vector<DenseTensor *> pre_state_grad,
std::vector<DenseTensor *> weight_grad_list) {
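  // Map the string `mode` attribute onto the backend-specific RNN mode enum.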
#ifdef PADDLE_WITH_HIP
miopenRNNMode_t rnn_mode = miopenLSTM;
if (mode == "LSTM")
rnn_mode = miopenLSTM;
else if (mode == "GRU")
rnn_mode = miopenGRU;
else if (mode == "RNN_RELU")
rnn_mode = miopenRNNRELU;
else if (mode == "RNN_TANH")
rnn_mode = miopenRNNTANH;
#else
cudnnRNNMode_t rnn_mode = CUDNN_LSTM;
if (mode == "LSTM")
rnn_mode = CUDNN_LSTM;
else if (mode == "GRU")
rnn_mode = CUDNN_GRU;
else if (mode == "RNN_RELU")
rnn_mode = CUDNN_RNN_RELU;
else if (mode == "RNN_TANH")
rnn_mode = CUDNN_RNN_TANH;
#endif
else
PADDLE_THROW(phi::errors::InvalidArgument(
"rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: "
"%s.",
mode));
auto handle = dev_ctx.cudnn_handle();
auto place = dev_ctx.GetPlace();
auto weight_numel = std::accumulate(
weight_list.begin(),
weight_list.end(),
0,
[](int64_t num, const DenseTensor *t) { return num + t->numel(); });
bool continuous =
IsContinuous<T, std::vector<const DenseTensor *>>(weight_list);
auto stream = dev_ctx.stream();
DenseTensor weight_whole;
T *weight_data = nullptr;
#ifdef PADDLE_WITH_HIP
// Need to permute weight, set continuous to false
continuous = false;
#endif
if (!continuous) {
weight_whole.Resize({weight_numel});
dev_ctx.template Alloc<T>(&weight_whole);
#ifdef PADDLE_WITH_HIP
    // MIOPEN needs to permute the weight for miopenLSTM or miopenGRU
std::vector<const DenseTensor *> weight_list_tmp = weight_list;
WeightToPermutedTensor<T>(
place, stream, &weight_list_tmp, &weight_whole, rnn_mode, is_bidirec);
#else
WeightToTensor<T>(place, stream, weight_list, &weight_whole);
#endif
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(weight_list[0]->data<T>());
}
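  // Accumulate all weight gradients into one zero-initialized flat buffer of
  // length weight_numel; on CUDA the per-parameter grad tensors view slices of
  // this buffer, while on MIOPEN they are filled by a later permutation step.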
DenseTensor weight_grad = Full<T>(dev_ctx, {weight_numel}, 0);
T *weight_grad_data = weight_grad.data<T>();
#ifdef PADDLE_WITH_HIP
  // MIOPEN needs to permute weight_grad_list, so do not share data with
  // weight_grad.
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
dev_ctx.template Alloc<T>(weight_grad_list[i]);
}
#else
int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
auto dim = weight_grad_list[i]->dims();
weight_grad_list[i]
->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
#endif
DenseTensor input_grad_value;
if (!x_grad) {
x_grad = &input_grad_value;
x_grad->Resize(x.dims());
}
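  // Collect raw pointers to the initial hidden/cell states and the gradients of
  // the last states; the cell-state pointers stay null for non-LSTM modes.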
auto *init_h_data = pre_state[0]->data<T>();
// auto *last_h_data = state[0]->data<T>();
auto *last_h_grad_data = state_grad[0]->data<T>();
const T *init_c_data = nullptr;
// const T *last_c_data = nullptr;
const T *last_c_grad_data = nullptr;
T *init_h_grad_data = pre_state_grad.size() != 0 && pre_state_grad[0]
? dev_ctx.template Alloc<T>(pre_state_grad[0])
: nullptr;
T *init_c_grad_data = nullptr;
#ifdef PADDLE_WITH_HIP
if (rnn_mode == miopenLSTM) {
#else
if (rnn_mode == CUDNN_LSTM) {
#endif
init_c_data = pre_state[1]->data<T>();
// last_c_data = state[1]->data<T>();
last_c_grad_data = state_grad[1]->data<T>();
init_c_grad_data = pre_state_grad.size() >= 2 && pre_state_grad[1]
? dev_ctx.template Alloc<T>(pre_state_grad[1])
: nullptr;
}
auto *out_data = out.data<T>();
auto *out_grad_data = out_grad.data<T>();
  // x_grad may be absent; only allocate its data when it is requested.
T *x_grad_data = nullptr;
if (x_grad) {
x_grad_data = dev_ctx.template Alloc<T>(x_grad);
}
bool has_seq_length = sequence_length.is_initialized();
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(
has_seq_length,
false,
phi::errors::InvalidArgument("ROCm do not support SequenceLength yet."));
#endif
std::vector<int> SequenceLength;
if (has_seq_length) {
SequenceLength =
paddle::operators::GetDataFromTensor<int>(sequence_length.get_ptr());
}
auto input_dims = x.dims();
int seq_length = input_dims[0];
int batch_size = input_dims[1];
int input_size_local = input_dims[2];
size_t workspace_size;
size_t reserve_size;
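  // Rebuild the cuDNN/MIOPEN descriptors with the same configuration as the
  // forward pass so that the saved reserve buffer can be reused here.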
RNNDescriptors rnn(seq_length,
batch_size,
input_size_local,
hidden_size,
num_layers,
dropout_prob,
seed,
weight_numel,
rnn_mode,
is_bidirec,
is_test);
rnn.Create<T>(handle,
dev_ctx.GetPlace(),
SequenceLength,
&workspace_size,
&reserve_size,
const_cast<DenseTensor *>(&dropout_state));
DenseTensor workspace_data_ =
Empty<uint8_t>(dev_ctx, {static_cast<int64_t>(workspace_size)});
const uint8_t *reserve_data = reserve.data<uint8_t>();
if (!has_seq_length) {
if (x_grad) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenRNNBackwardData(
handle,
rnn.rnn_desc(),
seq_length,
rnn.y_descs(),
out_data,
rnn.y_descs(),
out_grad_data,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_descs(),
x_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#else
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNBackwardData(
handle,
rnn.rnn_desc(),
seq_length,
rnn.y_descs(),
out_data,
rnn.y_descs(),
out_grad_data,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_descs(),
x_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#endif
}
if (!weight_grad_list.empty()) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenRNNBackwardWeights(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x.data<T>(),
rnn.init_h_desc(),
init_h_data,
rnn.y_descs(),
out.data<T>(),
rnn.weight_desc(),
weight_grad_data,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
// permute weight grad list from weight grad tensor
TensorToPermutedWeight<T>(
place, stream, weight_grad, &weight_grad_list, rnn_mode, is_bidirec);
#else
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNBackwardWeights(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x.data<T>(),
rnn.init_h_desc(),
init_h_data,
rnn.y_descs(),
out.data<T>(),
workspace_data_.data<uint8_t>(),
workspace_size,
rnn.weight_desc(),
weight_grad_data,
const_cast<uint8_t *>(reserve_data),
reserve_size));
#endif
}
} else {
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
if (x_grad) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNBackwardDataEx(
handle,
rnn.rnn_desc(),
rnn.y_seq_desc(),
out_data,
rnn.y_seq_desc(),
out_grad_data,
nullptr,
nullptr,
rnn.last_h_desc(),
last_h_grad_data,
rnn.last_c_desc(),
last_c_grad_data,
rnn.weight_desc(),
weight_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.x_seq_desc(),
x_grad_data,
rnn.init_h_desc(),
init_h_grad_data,
rnn.init_c_desc(),
init_c_grad_data,
nullptr,
nullptr,
workspace_data_.data<uint8_t>(),
workspace_size,
const_cast<uint8_t *>(reserve_data),
reserve_size));
}
if (!weight_grad_list.empty()) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNBackwardWeightsEx(
handle,
rnn.rnn_desc(),
rnn.x_seq_desc(),
x.data<T>(),
rnn.init_h_desc(),
init_h_data,
rnn.y_seq_desc(),
out.data<T>(),
workspace_data_.data<uint8_t>(),
workspace_size,
rnn.weight_desc(),
weight_grad_data,
const_cast<uint8_t *>(reserve_data),
reserve_size));
}
#else
PADDLE_THROW(phi::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"));
#endif
}
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
PD_REGISTER_KERNEL(rnn_grad, GPU, ALL_LAYOUT, phi::RnnGradKernel, float) {}
#else
PD_REGISTER_KERNEL(
rnn_grad, GPU, ALL_LAYOUT, phi::RnnGradKernel, float, double) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/rnn_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/gpu/rnn_functor.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/utils.h"
namespace phi {
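// Inference-only forward pass: dispatches to the unpadded or the padded
// cuDNN/MIOPEN inference API depending on whether SequenceLength is provided.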
template <typename T>
void RNNInferece(bool has_seq_length,
const gpuDnnHandle_t &handle,
int seq_length,
RNNDescriptors *rnn,
const T *x_data,
const T *init_h_data,
const T *init_c_data,
const T *w_data,
T *out_data,
T *last_h_data,
T *last_c_data,
DenseTensor *workspace_data,
size_t workspace_size) {
if (!has_seq_length) {
// for inference
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenRNNForwardInference(
handle,
rnn->rnn_desc(),
seq_length,
rnn->x_descs(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_descs(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
workspace_data->data<uint8_t>(),
workspace_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNForwardInference(
handle,
rnn->rnn_desc(),
seq_length,
rnn->x_descs(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_descs(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
workspace_data->data<uint8_t>(),
workspace_size));
#endif
} else {
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNForwardInferenceEx(
handle,
rnn->rnn_desc(),
rnn->x_seq_desc(),
x_data,
rnn->init_h_desc(),
init_h_data,
rnn->init_c_desc(),
init_c_data,
rnn->weight_desc(),
w_data,
rnn->y_seq_desc(),
out_data,
rnn->last_h_desc(),
last_h_data,
rnn->last_c_desc(),
last_c_data,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
workspace_data->data<uint8_t>(),
workspace_size));
#else
    // The cuDNN version has to be >= 7.2.1.
PADDLE_THROW(phi::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
template <typename T, typename Context>
void RnnKernel(const Context &dev_ctx,
const DenseTensor &x,
const std::vector<const DenseTensor *> &pre_state,
const std::vector<const DenseTensor *> &weight_list,
paddle::optional<const DenseTensor &> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string &mode,
int seed,
bool is_test,
DenseTensor *out,
DenseTensor *dropout_state,
std::vector<DenseTensor *> state,
DenseTensor *reserve) {
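  // Translate the string `mode` attribute into the backend-specific RNN mode enum.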
#ifdef PADDLE_WITH_HIP
gpuRNNMode_t rnn_mode = miopenLSTM;
if (mode == "LSTM")
rnn_mode = miopenLSTM;
else if (mode == "GRU")
rnn_mode = miopenGRU;
else if (mode == "RNN_RELU")
rnn_mode = miopenRNNRELU;
else if (mode == "RNN_TANH")
rnn_mode = miopenRNNTANH;
#else
gpuRNNMode_t rnn_mode = CUDNN_LSTM;
if (mode == "LSTM")
rnn_mode = CUDNN_LSTM;
else if (mode == "GRU")
rnn_mode = CUDNN_GRU;
else if (mode == "RNN_RELU")
rnn_mode = CUDNN_RNN_RELU;
else if (mode == "RNN_TANH")
rnn_mode = CUDNN_RNN_TANH;
#endif
else
PADDLE_THROW(phi::errors::InvalidArgument(
"rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: "
"%s.",
mode));
if (!is_test) {
int device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed == 0) {
      // If `manual_seed` was called in Python and the inner seed is not
      // specified (equals 0), use the seed produced by the global generator.
seed = static_cast<int>(gen_cuda->Random64());
} else if (seed == 0) {
// use random generated seed
std::random_device rd;
seed = rd();
    }  // else use the seed specified via the `seed` attribute
}
const T *x_data = x.data<T>();
const T *init_h_data = pre_state[0]->data<T>();
const T *init_c_data = nullptr;
T *out_data = dev_ctx.template Alloc<T>(out);
T *last_h_data = dev_ctx.template Alloc<T>(state[0]);
T *last_c_data = nullptr;
#ifdef PADDLE_WITH_HIP
if (rnn_mode == miopenLSTM) {
#else
if (rnn_mode == CUDNN_LSTM) {
#endif
init_c_data = pre_state[1]->data<T>();
last_c_data = dev_ctx.template Alloc<T>(state[1]);
}
bool has_seq_length = sequence_length.is_initialized();
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(
has_seq_length,
false,
phi::errors::InvalidArgument("ROCm do not support SequenceLength yet."));
#endif
std::vector<int> SequenceLength;
if (has_seq_length) {
SequenceLength =
paddle::operators::GetDataFromTensor<int>(sequence_length.get_ptr());
}
auto handle = dev_ctx.cudnn_handle();
int seq_length = x.dims()[0];
int batch_size = x.dims()[1];
int input_size_local = x.dims()[2];
size_t workspace_size;
size_t reserve_size;
DenseTensor weight_whole;
T *w_data = nullptr;
auto place = dev_ctx.GetPlace();
auto stream = dev_ctx.stream();
auto weight_numel = std::accumulate(
weight_list.begin(),
weight_list.end(),
0,
[](int64_t num, const DenseTensor *t) { return num + t->numel(); });
bool continuous =
IsContinuous<T, std::vector<const DenseTensor *>>(weight_list);
#ifdef PADDLE_WITH_HIP
// Need to permute weight, set continuous to false
continuous = false;
#endif
if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not continuous, "
"less efficient calculation will be called. Please call "
"flatten_parameters() to make the input memory continuous.";
weight_whole.Resize({weight_numel});
dev_ctx.template Alloc<T>(&weight_whole);
#ifdef PADDLE_WITH_HIP
    // MIOPEN needs to permute the weight for miopenLSTM or miopenGRU
std::vector<const DenseTensor *> weight_list_tmp = weight_list;
WeightToPermutedTensor<T>(
place, stream, &weight_list_tmp, &weight_whole, rnn_mode, is_bidirec);
#else
WeightToTensor<T>(place, stream, weight_list, &weight_whole);
#endif
w_data = weight_whole.data<T>();
#ifndef PADDLE_WITH_HIP
    // MIOPEN needs a permuted weight layout, so do not share the flattened
    // buffer back with the weight list in that case.
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
size_t len = weight_list[i]->numel();
auto dim = weight_list[i]->dims();
const_cast<DenseTensor *>(weight_list[i])
->ShareDataWith(
weight_whole.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
}
#endif
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
RNNDescriptors rnn(seq_length,
batch_size,
input_size_local,
hidden_size,
num_layers,
dropout_prob,
seed,
weight_numel,
rnn_mode,
is_bidirec,
is_test);
rnn.Create<T>(handle,
dev_ctx.GetPlace(),
SequenceLength,
&workspace_size,
&reserve_size,
dropout_state);
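  // Allocate the temporary workspace and the reserve buffer; the reserve keeps
  // the intermediate results that cuDNN/MIOPEN needs again in the backward pass.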
DenseTensor workspace_data_ =
Empty<uint8_t>(dev_ctx, {static_cast<int64_t>(workspace_size)});
reserve->Resize({static_cast<int64_t>(reserve_size)});
auto *reserve_data = dev_ctx.template Alloc<uint8_t>(reserve);
if (is_test) {
RNNInferece(has_seq_length,
handle,
seq_length,
&rnn,
x_data,
init_h_data,
init_c_data,
w_data,
out_data,
last_h_data,
last_c_data,
&workspace_data_,
workspace_size);
} else {
if (!has_seq_length) {
// for train
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenRNNForwardTraining(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_descs(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNForwardTraining(
handle,
rnn.rnn_desc(),
seq_length,
rnn.x_descs(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_descs(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#endif
} else {
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnRNNForwardTrainingEx(
handle,
rnn.rnn_desc(),
rnn.x_seq_desc(),
x_data,
rnn.init_h_desc(),
init_h_data,
rnn.init_c_desc(),
init_c_data,
rnn.weight_desc(),
w_data,
rnn.y_seq_desc(),
out_data,
rnn.last_h_desc(),
last_h_data,
rnn.last_c_desc(),
last_c_data,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
workspace_data_.data<uint8_t>(),
workspace_size,
reserve_data,
reserve_size));
#else
PADDLE_THROW(phi::errors::Unavailable(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"));
#endif
}
}
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float) {}
#else
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float, double) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void RnnGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& pre_state,
const std::vector<const DenseTensor*>& weight_list,
paddle::optional<const DenseTensor&> sequence_length,
const DenseTensor& out,
const DenseTensor& dropout_state,
const DenseTensor& reserve,
const DenseTensor& out_grad,
const std::vector<const DenseTensor*>& state_grad,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
DenseTensor* x_grad,
std::vector<DenseTensor*> pre_state_grad,
std::vector<DenseTensor*> weight_grad_list);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void RnnKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& pre_state,
const std::vector<const DenseTensor*>& weight_list,
paddle::optional<const DenseTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
DenseTensor* out,
DenseTensor* dropout_state,
std::vector<DenseTensor*> state,
DenseTensor* reserve);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RnnOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("rnn",
{"Input", "PreState", "WeightList", "SequenceLength"},
{"dropout_prob",
"is_bidirec",
"input_size",
"hidden_size",
"num_layers",
"mode",
"seed",
"is_test"},
{"Out", "DropoutState", "State", "Reserve"});
}
KernelSignature RnnGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("rnn_grad",
{"Input",
"PreState",
"WeightList",
"SequenceLength",
"Out",
"DropoutState",
"Reserve",
GradVarName("Out"),
GradVarName("State")},
{"dropout_prob",
"is_bidirec",
"input_size",
"hidden_size",
"num_layers",
"mode",
"seed",
"is_test"},
{GradVarName("Input"),
GradVarName("PreState"),
GradVarName("WeightList")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(rnn, phi::RnnOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(rnn_grad, phi::RnnGradOpArgumentMapping);