提交 6f78fd7d 编写于 作者: T tensor-tang

fuse fc in gru

上级 300180cc
......@@ -15,8 +15,11 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
......@@ -25,47 +28,69 @@ namespace paddle {
namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(%s) of GRUOp should not be null.", "BatchGate");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
"Output(%s) of GRUOp should not be null.",
"BatchResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(%s) of GRUOp should not be null.", "BatchHidden");
"Output(BatchResetHiddenPrev) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1];
int frame_size = weight_dims[0];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
PADDLE_ENFORCE_EQ(
weight_dims[1], frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
"Output(Hidden) of GRU should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2.");
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 3;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
"The second dimension of Input(WeightH) "
"should be 3 * %d.",
frame_size);
if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
}
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
ctx->SetOutputDim("BatchGate", input_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
ctx->ShareLoD("Input", "Hidden");
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
ctx->ShareLoD("X", "Hidden");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
}
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
......@@ -76,53 +101,38 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
}
void FusionGRUOpMaker::Make() {
AddInput("Input",
"(LoDTensor) The first input is a LodTensor, which supports "
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 3D), where, T is the "
"total time steps in this mini-batch, D is the hidden size.");
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("H0",
"(Tensor, optional) The initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.")
.AsDispensable();
AddInput(
"Weight",
"(Tensor) The learnable hidden-hidden weight matrix with shape "
"(D x 3D), where D is the hidden size. The elements continuous in "
"memory can be divided into two parts. The first part are weights of "
"the update gate and reset gate with shape (D x 2D), and the second "
"part are weights of output candidate with shape (D x D).");
AddInput("WeightX",
"(Tensor) The FC weight with shape (M x 3D),"
"where M is the dim size of x, D is the hidden size. ");
AddInput("WeightH",
"(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
AddInput("Bias",
"(Tensor, optional) Bias vector with shape (1 x 3D) concating "
"bias of the update gate, reset gate and output candidate.")
"(Tensor, optional) (1 x 3D)."
"Almost same as GRUOp."
"Note: if have FC bias it should be added on this bias.")
.AsDispensable();
AddOutput("BatchGate",
"(LoDTensor) To compute with batches, sequence data will be "
"reorganized into several successive batches each containing "
"data from the same time step. The LoDTensor BatchGate contains "
"the update gate, reset gate and output candidate values "
"organized in batches. The LoD size is 2. The first LoD contains "
"the batch offsets and the second LoD contains the indexes in "
"the raw sequence data.")
AddOutput("XX",
"(LoDTensor) the result after X * WeightX (size is T x 4D)"
" or batched_X (size is T x M), this will be automatically chosen,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input.")
.AsIntermediate();
AddOutput(
"BatchResetHiddenPrev",
"(LoDTensor) The reseted hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
.AsIntermediate();
AddOutput(
"BatchHidden",
"(LoDTensor) The hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
.AsIntermediate();
AddOutput(
"Hidden",
"(LoDTensor) the hidden state LoDTensor organized in sequences. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.");
AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
AddAttr<std::string>("activation",
"(string, default tanh) "
"The activation type used for output candidate {h}_t.")
......@@ -156,52 +166,71 @@ inline void ReorderInitState(const DeviceContext& ctx,
template <typename DeviceContext, typename T>
class FusionGRUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X");
auto* h = context.Input<LoDTensor>("H");
auto* h0 = context.Input<Tensor>("H0");
auto* x_weight = context.Input<Tensor>("XWeight"); // x_dim*3D
auto* gate_weight = context.Input<Tensor>("HWeight"); // D*3D
auto* bias = context.Input<Tensor>("Bias"); // 1*3D
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* h0 = ctx.Input<Tensor>("H0");
auto hidden_dims = hidden->dims();
auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* batch_reset_hidden_prev =
ctx.Output<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
bool is_reverse = ctx.Attr<bool>("is_reverse");
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>();
to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
batch_hidden->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
if (bias) {
math::RowwiseAdd<DeviceContext, T> add_bias;
add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
if (x_dims[1] > wx_dims[1]) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
x_data, wx_data, xx_data,
bias ? bias->data<T>() : NULL);
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_gate->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
bias ? bias->data<T>() : NULL);
}
int frame_size = hidden_dims[1];
int frame_size = static_cast<int>(wx_dims[1] / 3);
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.gate_weight = const_cast<T*>(wh_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
const_cast<T*>(wh_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Vector<size_t> order(batch_gate->lod()[2]);
framework::Vector<size_t> order(batched_gate->lod()[2]);
if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<DeviceContext, T>(
context.template device_context<DeviceContext>(), *h0, order,
&ordered_h0, true);
ctx.template device_context<DeviceContext>(), *h0, order, &ordered_h0,
true);
gru_value.prev_out_value = ordered_h0.data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
auto batch_starts = batched_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
auto active_node = math::detail::GetActivationType(
context.Attr<std::string>("activation"));
auto active_node =
math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation"));
ctx.Attr<std::string>("gate_activation"));
#ifdef PADDLE_WITH_MKLML
// use MKL packed to speedup GEMM
......@@ -226,7 +255,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
......@@ -269,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
......@@ -287,8 +316,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
}
#endif
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden);
batch_hidden->set_lod(batched_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden_out);
}
};
......@@ -300,4 +329,4 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册