提交 8f913295 编写于 作者: T tensor-tang

fuse fc in lstm

上级 ddb05dff
......@@ -14,29 +14,37 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
DECLARE_int32(paddle_num_threads);
namespace paddle {
namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"),
"Output(XX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchGate) of LSTM should not be null.");
"Output(BatchedGate) of LSTM should not be null.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
......@@ -49,15 +57,24 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"should be the same.");
}
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "The rank of Input(Weight) should be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
"The first dimension of Input(Weight) "
auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2.");
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 4;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) "
PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
"The second dimension of Input(WeightH) "
"should be 4 * %d.",
frame_size);
......@@ -66,36 +83,35 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
frame_size);
} else {
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
"Do not support peephole yet.");
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection",
frame_size);
}
framework::DDim out_dims({in_dims[0], frame_size});
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
}
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionLSTMOpMaker::Make() {
AddInput("Input",
AddInput("X",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
......@@ -130,7 +146,12 @@ void FusionLSTMOpMaker::Make() {
AddOutput("Cell",
"(LoDTensor) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("BatchGate",
AddOutput("XX",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddOutput("BatchedGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape as the reorganized input, which "
......@@ -219,80 +240,102 @@ inline void ReorderInitState(const DeviceContext& ctx,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
// TODO(TJ): check mem copy perf
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
// TODO(TJ): can move to math::details
template <typename DeviceContext, typename T>
inline void SimpleFC(const math::BlasT<DeviceContext, T>& blas, const int M,
const int N, const int K, const T* A, const T* B, T* C,
const T* bias_data = NULL) {
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), A, B,
static_cast<T>(0), C);
if (bias_data) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), bias_data, C + i * N);
}
}
}
template <typename DeviceContext, typename T>
class LSTMKernel : public framework::OpKernel<T> {
class FuisonLSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight");
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX"); // x*4D
auto* wh = ctx.Input<Tensor>("WeightH"); // D*4D
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* cell_t0 = ctx.Input<Tensor>("C0");
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
// the result after x*Wx (size: sum_words*4D) or batched_x (size:
// sum_words*x)
auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<LoDTensor>("Cell");
bool is_reverse = ctx.Attr<bool>("is_reverse");
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
cell_out->mutable_data<T>(ctx.GetPlace());
bool is_reverse = ctx.Attr<bool>("is_reverse");
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& device_ctx = ctx.template device_context<DeviceContext>();
to_batch(device_ctx, *input, batch_gate, true, is_reverse);
auto in_dims = input->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
framework::DDim dims({in_dims[0], frame_size});
if (bias) {
Tensor b = *bias;
b.Resize({bias->numel(), 1});
Tensor gate_bias = b.Slice(0, 4 * frame_size);
math::RowwiseAdd<DeviceContext, T> add_bias;
add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
// TODO(TJ): op test these two cases
if (x_dims[1] > wx_dims[1]) {
SimpleFC<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1], x_data,
wx_data, xx_data, bias->data<T>());
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
SimpleFC<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
bias->data<T>());
}
int frame_size = static_cast<int>(wx_dims[1] / 4);
framework::DDim out_dims({x_dims[0], frame_size});
math::LstmMetaValue<T> lstm_value;
if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.check_og = lstm_value.check_fg + frame_size;
} else {
// no peephole
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
}
lstm_value.prev_state_value = nullptr;
Tensor ordered_c0;
framework::Vector<size_t> order(batch_gate->lod()[2]);
framework::Vector<size_t> order(batched_gate->lod()[2]);
if (cell_t0) {
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<DeviceContext, T>(device_ctx, *cell_t0, order,
&ordered_c0, true);
ReorderInitState<DeviceContext, T>(dev_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prev_state_value = ordered_c0.data<T>();
}
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell;
auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_pre_act->mutable_data<T>(out_dims, ctx.GetPlace());
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto batch_starts = batched_gate->lod()[0];
size_t max_seq_len = batch_starts.size() - 1;
auto gate_act = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
auto cell_act = math::detail::GetActivationType(
......@@ -300,12 +343,11 @@ class LSTMKernel : public framework::OpKernel<T> {
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
for (size_t n = 0; n < num_batch; n++) {
for (size_t n = 0; n < max_seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor out_t = batch_hidden.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);
......@@ -316,9 +358,11 @@ class LSTMKernel : public framework::OpKernel<T> {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
// TODO(TJ): use gemm directly
blas.MatMul(pre_hidden_t, false, *wh, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
} else if (hidden_t0) {
// TODO(TJ): move h0 outside for
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
......@@ -327,10 +371,11 @@ class LSTMKernel : public framework::OpKernel<T> {
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor ordered_h0;
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
ReorderInitState<DeviceContext, T>(dev_ctx, *hidden_t0, order,
&ordered_h0, true);
blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
// TODO(TJ): use gemm directly
blas.MatMul(ordered_h0, false, *wh, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
lstm_value.gate_value = gate_t.data<T>();
......@@ -338,19 +383,19 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.state_value = cell_t.data<T>();
lstm_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<DeviceContext, T>::compute(
device_ctx, lstm_value, frame_size, cur_batch_size, gate_act,
cell_act, cand_act);
dev_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act,
cand_act);
lstm_value.prev_state_value = lstm_value.state_value;
}
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden.set_lod(batch_gate->lod());
batch_hidden.set_lod(batched_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
to_seq(device_ctx, batch_hidden, hidden_out);
to_seq(dev_ctx, batch_hidden, hidden_out);
batch_cell.set_lod(batch_gate->lod());
batch_cell.set_lod(batched_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(device_ctx, batch_cell, cell_out);
to_seq(dev_ctx, batch_cell, cell_out);
}
};
......@@ -358,9 +403,10 @@ class LSTMKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lstm, ops::LSTMOp, ops::LSTMOpMaker,
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
fusion_lstm, ops::LSTMKernel<paddle::platform::CPUDeviceContext, float>,
ops::LSTMKernel<paddle::platform::CPUDeviceContext, double>);
fusion_lstm,
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -15,10 +15,6 @@ limitations under the License. */
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
// #include "paddle/fluid/operators/math/blas.h"
// #include "paddle/fluid/operators/math/detail/activation_functions.h"
// #include "paddle/fluid/operators/math/lstm_compute.h"
// #include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册