提交 6f78fd7d 编写于 作者: T tensor-tang

fuse fc in gru

上级 300180cc
...@@ -15,8 +15,11 @@ limitations under the License. */ ...@@ -15,8 +15,11 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h" #include "paddle/fluid/operators/fusion_gru_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
...@@ -25,47 +28,69 @@ namespace paddle { ...@@ -25,47 +28,69 @@ namespace paddle {
namespace operators { namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
"Input(%s) of GRUOp should not be null.", "Input"); PADDLE_ENFORCE(ctx->HasInput("WeightX"),
PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(WeightX) of GRU should not be null.");
"Input(%s) of GRUOp should not be null.", "Weight"); PADDLE_ENFORCE(ctx->HasInput("WeightH"),
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), "Input(WeightH) of GRU should not be null.");
"Output(%s) of GRUOp should not be null.", "BatchGate");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"), PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
"Output(%s) of GRUOp should not be null.", "Output(BatchResetHiddenPrev) of GRU should not be null.");
"BatchResetHiddenPrev"); PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"), "Output(BatchedHidden) of GRU should not be null.");
"Output(%s) of GRUOp should not be null.", "BatchHidden");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUOp should not be null.", "Hidden"); "Output(Hidden) of GRU should not be null.");
auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight"); auto x_dims = ctx->GetInputDim("X");
int input_size = input_dims[1]; PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
int frame_size = weight_dims[0];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3, auto wx_dims = ctx->GetInputDim("WeightX");
"The input_size must be 3 times of frame_size in GRUOp."); PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
PADDLE_ENFORCE_EQ( "The rank of Input(WeightX) should be 2.");
weight_dims[1], frame_size * 3, PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); "The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 3;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
"The second dimension of Input(WeightH) "
"should be 3 * %d.",
frame_size);
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0"); auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size."); "The width of H0 must be equal to frame_size.");
} }
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0]; PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
int bias_width = bias_dims[1]; PADDLE_ENFORCE_EQ(b_dims[0], 1,
PADDLE_ENFORCE_EQ(bias_height, 1, "The first dimension of Input(Bias) should be 1.");
"The shape of Bias must be [1, frame_size * 3]."); PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3]."); "The shape of Bias must be [1, frame_size * 3].");
} }
ctx->SetOutputDim("BatchGate", input_dims); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size}); ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->ShareLoD("Input", "Hidden"); ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
ctx->ShareLoD("X", "Hidden");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
} }
framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
...@@ -76,53 +101,38 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType( ...@@ -76,53 +101,38 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
} }
void FusionGRUOpMaker::Make() { void FusionGRUOpMaker::Make() {
AddInput("Input", AddInput("X",
"(LoDTensor) The first input is a LodTensor, which supports " "(LoDTensor) the input is a LoDTensor, which supports "
"variable-time length input sequence. The underlying tensor in " "variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 3D), where, T is the " "this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, D is the hidden size."); "total time steps in this mini-batch, M is the dim size of x.");
AddInput("H0", AddInput("H0",
"(Tensor, optional) The initial hidden state is an optional " "(Tensor, optional) The initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the " "input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.") "batch size, D is the hidden size.")
.AsDispensable(); .AsDispensable();
AddInput( AddInput("WeightX",
"Weight", "(Tensor) The FC weight with shape (M x 3D), "
"(Tensor) The learnable hidden-hidden weight matrix with shape " "where M is the dim size of x, D is the hidden size. ");
"(D x 3D), where D is the hidden size. The elements continuous in " AddInput("WeightH",
"memory can be divided into two parts. The first part are weights of " "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
"the update gate and reset gate with shape (D x 2D), and the second "
"part are weights of output candidate with shape (D x D).");
AddInput("Bias", AddInput("Bias",
"(Tensor, optional) Bias vector with shape (1 x 3D) concating " "(Tensor, optional) (1 x 3D). "
"bias of the update gate, reset gate and output candidate.") "Almost same as GRUOp. "
"Note: if there is an FC bias, it should be added to this bias.")
.AsDispensable(); .AsDispensable();
AddOutput("BatchGate", AddOutput("XX",
"(LoDTensor) To compute with batches, sequence data will be " "(LoDTensor) the result after X * WeightX (size is T x 3D)"
"reorganized into several successive batches each containing " " or batched_X (size is T x M), this will be automatically chosen,"
"data from the same time step. The LoDTensor BatchGate contains " " where T is the total time steps in this mini-batch,"
"the update gate, reset gate and output candidate values " " D is the hidden size, M is the dim size of x input.")
"organized in batches. The LoD size is 2. The first LoD contains "
"the batch offsets and the second LoD contains the indexes in "
"the raw sequence data.")
.AsIntermediate(); .AsIntermediate();
AddOutput( AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
"BatchResetHiddenPrev", AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x D) Same as GRUOp.")
"(LoDTensor) The reseted hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
.AsIntermediate(); .AsIntermediate();
AddOutput( AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
"BatchHidden",
"(LoDTensor) The hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
.AsIntermediate(); .AsIntermediate();
AddOutput( AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
"Hidden",
"(LoDTensor) the hidden state LoDTensor organized in sequences. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.");
AddAttr<std::string>("activation", AddAttr<std::string>("activation",
"(string, default tanh) " "(string, default tanh) "
"The activation type used for output candidate {h}_t.") "The activation type used for output candidate {h}_t.")
...@@ -156,52 +166,71 @@ inline void ReorderInitState(const DeviceContext& ctx, ...@@ -156,52 +166,71 @@ inline void ReorderInitState(const DeviceContext& ctx,
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FusionGRUKernel : public framework::OpKernel<T> { class FusionGRUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = context.Input<LoDTensor>("X"); auto* x = ctx.Input<LoDTensor>("X");
auto* h = context.Input<LoDTensor>("H"); auto* wx = ctx.Input<Tensor>("WeightX");
auto* h0 = context.Input<Tensor>("H0"); auto* wh = ctx.Input<Tensor>("WeightH");
auto* x_weight = context.Input<Tensor>("XWeight"); // x_dim*3D auto* bias = ctx.Input<Tensor>("Bias");
auto* gate_weight = context.Input<Tensor>("HWeight"); // D*3D auto* h0 = ctx.Input<Tensor>("H0");
auto* bias = context.Input<Tensor>("Bias"); // 1*3D
auto hidden_dims = hidden->dims(); auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* batch_reset_hidden_prev =
ctx.Output<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
bool is_reverse = ctx.Attr<bool>("is_reverse");
bool is_reverse = context.Attr<bool>("is_reverse"); T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>(); batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
to_batch(dev_ctx, *input, batch_gate, true, is_reverse); batch_hidden->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
if (bias) { const T* x_data = x->data<T>();
math::RowwiseAdd<DeviceContext, T> add_bias; const T* wx_data = wx->data<T>();
add_bias(dev_ctx, *batch_gate, *bias, batch_gate); const T* wh_data = wh->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
if (x_dims[1] > wx_dims[1]) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
x_data, wx_data, xx_data,
bias ? bias->data<T>() : NULL);
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_gate->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
bias ? bias->data<T>() : NULL);
} }
int frame_size = hidden_dims[1]; int frame_size = static_cast<int>(wx_dims[1] / 3);
math::GRUMetaValue<T> gru_value; math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data); gru_value.gate_weight = const_cast<T*>(wh_data);
gru_value.state_weight = gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size); const_cast<T*>(wh_data + 2 * frame_size * frame_size);
Tensor ordered_h0; Tensor ordered_h0;
framework::Vector<size_t> order(batch_gate->lod()[2]); framework::Vector<size_t> order(batched_gate->lod()[2]);
if (h0) { if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<DeviceContext, T>( ReorderInitState<DeviceContext, T>(
context.template device_context<DeviceContext>(), *h0, order, ctx.template device_context<DeviceContext>(), *h0, order, &ordered_h0,
&ordered_h0, true); true);
gru_value.prev_out_value = ordered_h0.data<T>(); gru_value.prev_out_value = ordered_h0.data<T>();
} else { } else {
gru_value.prev_out_value = nullptr; gru_value.prev_out_value = nullptr;
} }
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batched_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1; size_t seq_len = batch_starts.size() - 1;
auto active_node = math::detail::GetActivationType( auto active_node =
context.Attr<std::string>("activation")); math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType( auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation")); ctx.Attr<std::string>("gate_activation"));
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
// use MKL packed to speedup GEMM // use MKL packed to speedup GEMM
...@@ -226,7 +255,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -226,7 +255,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
int bend = static_cast<int>(batch_starts[n + 1]); int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend); batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend); Tensor hidden_t = batch_hidden->Slice(bstart, bend);
...@@ -269,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -269,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
int bend = static_cast<int>(batch_starts[n + 1]); int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend); batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend); Tensor hidden_t = batch_hidden->Slice(bstart, bend);
...@@ -287,8 +316,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -287,8 +316,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
#endif #endif
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq; math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batch_gate->lod()); batch_hidden->set_lod(batched_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden); to_seq(dev_ctx, *batch_hidden, hidden_out);
} }
}; };
...@@ -300,4 +329,4 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker, ...@@ -300,4 +329,4 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>, fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
ops::GRUKernel<paddle::platform::CPUDeviceContext, double>); ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册