未验证 提交 9f2ccf5b 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13237 from tensor-tang/refine/op/peephole

refine fusion lstm/peephole and fusion gru
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
......
...@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"Input(WeightX) of GRU should not be null."); "Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null."); "Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null."); "Output(Hidden) of GRU should not be null.");
...@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
} }
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Hidden");
int xx_width; int xx_width;
if (ctx->Attrs().Get<bool>("use_seq")) { if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
} }
ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX"); ctx->ShareLoD("X", "XX");
......
...@@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null."); "Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null."); "Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
...@@ -88,28 +78,36 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -88,28 +78,36 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1."); "The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(
auto use_peepholes = ctx->Attrs().Get<bool>("use_peepholes"); b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size, "The second dimension of Input(Bias) should be "
"The second dimension of Input(Bias) should be " "7 * %d if enable peepholes connection or"
"7 * %d if enable peepholes connection or" "4 * %d if disable peepholes",
"4 * %d if disable peepholes", frame_size, frame_size);
frame_size, frame_size);
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims); ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims);
ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell"); ctx->ShareLoD("X", "Cell");
int xx_width; int xx_width;
if (ctx->Attrs().Get<bool>("use_seq")) { if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims);
} }
ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX"); ctx->ShareLoD("X", "XX");
...@@ -232,18 +230,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -232,18 +230,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cand = act_functor(act_cand_str); \ act_cand = act_functor(act_cand_str); \
} }
#define INIT_BASE_INPUT_OUTPUT \ #define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \ auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \ auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \ auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \ auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \ bool is_reverse = ctx.Attr<bool>("is_reverse"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); bool use_peepholes = ctx.Attr<bool>("use_peepholes");
#define INIT_BASE_SIZES \ #define INIT_BASE_SIZES \
auto x_dims = x->dims(); /* T x M*/ \ auto x_dims = x->dims(); /* T x M*/ \
...@@ -254,172 +252,183 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -254,172 +252,183 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D3 = D * 3; \ const int D3 = D * 3; \
const int D4 = wh_dims[1]; const int D4 = wh_dims[1];
#define INIT_BASE_INPUT_DATAS \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/ \
Tensor checked_cell; \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
}
/// Compute LSTM
#define GEMM_WH_ADDON(bs, prev, out) \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
wh_data, D4, static_cast<T>(1), out, D4)
// gates: W_ch, W_ih, W_fh, W_oh
#define GET_Ct(ct_1, gates, ct) \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, gates + D); \
blas.VMUL(D, ct_1, gates + D2, gates + D2); \
blas.VADD(D, gates + D, gates + D2, ct)
#define GET_Ht(ct, gates, ht) \
/* H_t = act_cell(C_t) * ogated */ \
act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht)
#define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/ \
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, ct)
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */ \
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
act_gate(D3, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \
/* get fgated and igated*/ \
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
act_gate(D2, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
/* get ogated*/ \
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES INIT_BASE_SIZES
INIT_VEC_FUNC INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto x_lod = x->lod(); auto x_lod = x->lod();
const int total_T = x_dims[0]; const int total_T = x_dims[0];
const int N = x_lod[0].size() - 1; // batch size const int N = x_lod[0].size() - 1;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* c0_data = c0 ? c0->data<T>() : nullptr; const T* c0_data = c0 ? c0->data<T>() : nullptr;
const T* bias_data = bias->data<T>(); T* xx_data = xx->mutable_data<T>(place);
const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc T* h_out_data = hidden_out->mutable_data<T>(place);
const T* wx_data = wx->data<T>(); T* c_out_data = cell_out->mutable_data<T>(place);
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
// use local variable
framework::DDim check_dims({3, D});
Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto checked_cell_data =
checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data, math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>()); xx_data, bias->data<T>());
int xx_offset = D4; int xx_offset = D4;
int gate_offset = D; int gate_offset = D;
if (is_reverse) { if (is_reverse) {
const int offset = (total_T - 1) * D; const int offset = (total_T - 1) * D;
xx_data = xx_data + offset * 4; xx_data = xx_data + offset * 4;
hidden_out_data = hidden_out_data + offset; h_out_data = h_out_data + offset;
cell_out_data = cell_out_data + offset; c_out_data = c_out_data + offset;
xx_offset = -D4; xx_offset = -D4;
gate_offset = -D; gate_offset = -D;
} }
auto move_step = [&]() { #define MOVE_ONE_STEP \
xx_data = xx_data + xx_offset; prev_h_data = h_out_data; \
hidden_out_data = hidden_out_data + gate_offset; prev_c_data = c_out_data; \
cell_out_data = cell_out_data + gate_offset; xx_data = xx_data + xx_offset; \
}; h_out_data = h_out_data + gate_offset; \
c_out_data = c_out_data + gate_offset
for (int i = 0; i < N; ++i) {
int bid = is_reverse ? N - 1 - i : i; #define PROCESS_H0C0_DEFINES \
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; int bid = is_reverse ? N - 1 - i : i; \
const T* prev_c_data = nullptr; int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
const T* prev_h_data = nullptr; const T* prev_c_data = nullptr; \
const T* prev_h_data = nullptr; \
int tstart = 0; int tstart = 0
if (h0_data) {
prev_h_data = h0_data + bid * D; #define PROCESS_H0C0_PEEPHOLE \
prev_c_data = c0_data + bid * D; PROCESS_H0C0_DEFINES; \
} else { if (h0_data) { \
// If step == 0 and there is no initialized hidden state, that is to say prev_h_data = h0_data + bid * D; \
// the H0 is zeros. Then W_h * H_t-1 can be skipped prev_c_data = c0_data + bid * D; \
} else { \
// ~C_t COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
act_cand(D, xx_data, xx_data); MOVE_ONE_STEP; \
if (use_peepholes) { tstart = 1; \
// I_t, F_t }
act_gate(D2, xx_data + D, xx_data + D);
} else {
// I_t, F_t, O_t
act_gate(D3, xx_data + D, xx_data + D);
}
// C_t = I_t * ~C_t
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
if (use_peepholes) {
// + W_oc * C_t for peephole connection
blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
// O_t
act_gate(D, xx_data + D3, xx_data + D3);
}
// hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2);
// H_t = O_t * act_state(C_t)
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_h_data = hidden_out_data;
prev_c_data = cell_out_data;
tstart = 1;
move_step();
}
for (int step = tstart; step < seq_len; ++step) {
// + W_h * H_t-1
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4);
// ~C_t #define PROCESS_H0C0 \
act_cand(D, xx_data, xx_data); PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
if (use_peepholes) { if (use_peepholes) {
// + W_ic|W_fc * C_t-1 for peephole connection for (int i = 0; i < N; ++i) {
blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); PROCESS_H0C0_PEEPHOLE
blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); for (int step = tstart; step < seq_len; ++step) {
blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D); GEMM_WH_ADDON(1, prev_h_data, xx_data);
// I_t, F_t COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
act_gate(D2, xx_data + D, xx_data + D); MOVE_ONE_STEP;
} else {
// I_t, F_t, O_t
act_gate(D3, xx_data + D, xx_data + D);
} }
}
// F_t * C_t-1 } else {
blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2); for (int i = 0; i < N; ++i) {
// I_t * ~C_t PROCESS_H0C0
blas.VMUL(D, xx_data, xx_data + D, xx_data + D); for (int step = tstart; step < seq_len; ++step) {
// C_t = F_t * C_t-1 + I_t * ~C_t GEMM_WH_ADDON(1, prev_h_data, xx_data);
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
MOVE_ONE_STEP;
if (use_peepholes) {
// + W_oc * C_t for peephole connection
blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
// O_t
act_gate(D, xx_data + D3, xx_data + D3);
} }
}
// hidden out= act_state(cellout) * outgate }
act_cell(D, cell_out_data, xx_data + D2); #undef PROCESS_H0C0_DEFINES
// H_t = O_t * act_state(C_t) #undef PROCESS_H0C0_PEEPHOLE
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); #undef PROCESS_H0C0
#undef MOVE_ONE_STEP
// prev
prev_h_data = hidden_out_data;
prev_c_data = cell_out_data;
move_step();
} // for each step in batch
} // for each batch
} }
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext; using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT INIT_BASE_INPUT_OUTPUT
if (x->lod()[0].size() == 2) { // batch size == 1 if (x->lod()[0].size() == 2) {
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_BASE_SIZES INIT_BASE_SIZES
INIT_VEC_FUNC INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0"); auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput"); auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell"); auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden"); auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* bias_data = bias->data<T>();
const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc
auto place = ctx.GetPlace();
T* xx_data = xx->mutable_data<T>(place); T* xx_data = xx->mutable_data<T>(place);
T* batched_input_data = batched_input->mutable_data<T>(place); T* batched_input_data = batched_input->mutable_data<T>(place);
T* batched_c_out_data = batched_c_out->mutable_data<T>(place); T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
...@@ -427,12 +436,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -427,12 +436,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
hidden_out->mutable_data<T>(place); hidden_out->mutable_data<T>(place);
cell_out->mutable_data<T>(place); cell_out->mutable_data<T>(place);
// use local variable
framework::DDim check_dims({3, D});
Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto checked_cell_data =
checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
...@@ -454,27 +457,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -454,27 +457,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_h0->Resize({max_bs, D}); reordered_h0->Resize({max_bs, D});
reordered_c0->Resize({max_bs, D}); reordered_c0->Resize({max_bs, D});
T* prev_batch_h_data = nullptr;
T* prev_batch_c_data = nullptr;
T* cur_batch_in_data = batched_input_data;
T* cur_batch_h_out_data = batched_h_out_data;
T* cur_batch_c_out_data = batched_c_out_data;
auto move_step = [&](int bs) {
cur_batch_in_data += bs * D4;
cur_batch_c_out_data += bs * D;
cur_batch_h_out_data += bs * D;
};
int tstart = 0; int tstart = 0;
T* prev_h_data = nullptr;
T* prev_c_data = nullptr;
if (h0) { if (h0) {
// reorder h0, c0 // reorder h0, c0
T* reordered_h0_data = reordered_h0->mutable_data<T>(place); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
T* reordered_c0_data = reordered_c0->mutable_data<T>(place); T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
const T* h0_data = h0->data<T>(); const T* h0_data = h0->data<T>();
const T* c0_data = c0->data<T>(); const T* c0_data = c0->data<T>();
prev_batch_h_data = reordered_h0_data; prev_h_data = reordered_h0_data;
prev_batch_c_data = reordered_c0_data; prev_c_data = reordered_c0_data;
size_t sz = sizeof(T) * D; size_t sz = sizeof(T) * D;
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
...@@ -483,123 +476,80 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -483,123 +476,80 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0_data += D; reordered_c0_data += D;
} }
} else { } else {
// Compute with no H0/C0 // compute without h0, c0
T* cur_in_data = cur_batch_in_data; T* cur_in_data = batched_input_data;
T* cur_c_out_data = cur_batch_c_out_data; T* cur_h_out_data = batched_h_out_data;
T* cur_h_out_data = cur_batch_h_out_data; T* cur_c_out_data = batched_c_out_data;
for (int i = 0; i < max_bs; ++i) {
// If step == 0 and there is no initialized hidden state, that is to say GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
// the H0 is zeros. Then W_h * H_t-1 can be skiped
for (int i = 0; i < max_bs; ++i) { // iterate each data in 1st batch
// ~C_t
act_cand(D, cur_in_data, cur_in_data);
if (use_peepholes) {
// I_t, F_t
act_gate(D2, cur_in_data + D, cur_in_data + D);
} else {
// I_t, F_t, O_t
act_gate(D3, cur_in_data + D, cur_in_data + D);
}
// C_t = I_t * ~C_t
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data);
if (use_peepholes) { if (use_peepholes) {
// + W_oc * C_t for peephole connection blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
cur_in_data + D3);
// O_t
act_gate(D, cur_in_data + D3, cur_in_data + D3);
} }
act_gate(D, cur_in_data + D3, cur_in_data + D3);
// hidden out= act_state(cellout) * outgate GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
act_cell(D, cur_c_out_data, cur_in_data + D2);
// H_t = O_t * act_state(C_t)
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
// move to next data in the same batch
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
cur_h_out_data += D; cur_h_out_data += D;
} }
// move to data for next timestep
prev_batch_h_data = cur_batch_h_out_data;
prev_batch_c_data = cur_batch_c_out_data;
move_step(max_bs);
tstart = 1; tstart = 1;
prev_h_data = batched_h_out_data;
prev_c_data = batched_c_out_data;
} }
const auto& batch_starts = batched_lod[0]; const auto& batch_starts = batched_lod[0];
const int max_seq_len = batch_starts.size() - 1; const int max_seq_len = batch_starts.size() - 1;
for (int step = tstart; step < max_seq_len; ++step) { const int offset = tstart * max_bs * D;
const int cur_bs = batch_starts[step + 1] - batch_starts[step]; batched_input_data = batched_input_data + offset * 4;
// + W_h * H_t-1 batched_h_out_data = batched_h_out_data + offset;
blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast<T>(1), batched_c_out_data = batched_c_out_data + offset;
prev_batch_h_data, D, wh_data, D4, static_cast<T>(1),
cur_batch_in_data, D4); #define DEFINE_CUR \
T* cur_in_data = batched_input_data; \
T* cur_in_data = cur_batch_in_data; T* cur_prev_c_data = prev_c_data; \
T* cur_c_out_data = cur_batch_c_out_data; T* cur_c_out_data = batched_c_out_data; \
T* cur_h_out_data = cur_batch_h_out_data; T* cur_h_out_data = batched_h_out_data
T* prev_c_data = prev_batch_c_data; // NULL if no C0 in step0
T* prev_h_data = prev_batch_h_data; // NULL if no H0 in step0 #define MOVE_ONE_BATCH \
auto next_data_in_batch = [&]() { cur_in_data += D4; \
cur_in_data += D4; cur_prev_c_data += D; \
cur_c_out_data += D; cur_c_out_data += D; \
cur_h_out_data += D; cur_h_out_data += D
prev_c_data = prev_c_data ? prev_c_data + D : nullptr;
prev_h_data = prev_h_data ? prev_h_data + D : nullptr; #define MOVE_ONE_STEP \
}; prev_c_data = batched_c_out_data; \
prev_h_data = batched_h_out_data; \
for (int i = 0; i < cur_bs; ++i) { // iterate each data in same batch batched_c_out_data = cur_c_out_data; \
// ~C_t batched_h_out_data = cur_h_out_data; \
act_cand(D, cur_in_data, cur_in_data); batched_input_data = cur_in_data
if (use_peepholes) { if (use_peepholes) {
// + W_ic|W_fc * C_t-1 for peephole connection for (int step = tstart; step < max_seq_len; ++step) {
blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); const int cur_bs = batch_starts[step + 1] - batch_starts[step];
blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D); DEFINE_CUR;
// I_t, F_t for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_in_data + D, cur_in_data + D); COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data,
} else { cur_h_out_data);
// I_t, F_t, O_t MOVE_ONE_BATCH;
act_gate(D3, cur_in_data + D, cur_in_data + D);
} }
MOVE_ONE_STEP;
// F_t * C_t-1 }
blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2); } else {
// I_t * ~C_t for (int step = tstart; step < max_seq_len; ++step) {
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D); const int cur_bs = batch_starts[step + 1] - batch_starts[step];
// C_t = F_t * C_t-1 + I_t * ~C_t GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data); DEFINE_CUR;
for (int i = 0; i < cur_bs; ++i) {
if (use_peepholes) { COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
// + W_oc * C_t for peephole connection cur_h_out_data);
blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); MOVE_ONE_BATCH;
blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
cur_in_data + D3);
// O_t
act_gate(D, cur_in_data + D3, cur_in_data + D3);
} }
MOVE_ONE_STEP;
// hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2);
// H_t = O_t * act_state(C_t)
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
// move to next data in same batch
next_data_in_batch();
} }
// move to data for next timestep
prev_batch_h_data = cur_batch_h_out_data;
prev_batch_c_data = cur_batch_c_out_data;
move_step(cur_bs);
} }
#undef MOVE_ONE_STEP
#undef MOVE_ONE_BATCH
#undef DEFINE_CUR
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq; math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batched_h_out->set_lod(batched_lod); batched_h_out->set_lod(batched_lod);
...@@ -615,6 +565,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -615,6 +565,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
BatchCompute(ctx); BatchCompute(ctx);
} }
} }
#undef COMPUTE_CtHt_PEEPHOLE
#undef COMPUTE_CtHt
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
#undef INIT_BASE_INPUT_DATAS
#undef INIT_BASE_SIZES #undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT #undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC #undef INIT_VEC_FUNC
......
...@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest): ...@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
self.M = 8 self.M = 8
self.D = 16 self.D = 16
self.has_initial_state = False self.has_initial_state = False
self.use_peepholes = False
self.is_reverse = False self.is_reverse = False
self.act_gate = 'sigmoid' self.act_gate = 'sigmoid'
self.act_cell = 'tanh' self.act_cell = 'tanh'
self.act_cand = 'tanh' self.act_cand = 'tanh'
self.use_peepholes = False
self.use_seq = False
self.set_conf() self.set_conf()
T = sum(self.lod[0]) T = sum(self.lod[0])
...@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest): ...@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
} }
self.attrs = { self.attrs = {
'use_peepholes': self.use_peepholes, 'use_peepholes': self.use_peepholes,
'use_seq': self.use_seq,
'is_reverse': self.is_reverse, 'is_reverse': self.is_reverse,
'gate_activation': self.act_gate, 'gate_activation': self.act_gate,
'cell_activation': self.act_cell, 'cell_activation': self.act_cell,
...@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp): ...@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
self.is_reverse = True self.is_reverse = True
class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp):
def set_conf(self): def set_conf(self):
self.use_peepholes = True self.use_peepholes = True
self.lod = [[3]]
self.D = 16
class TestFusionLSTMOpSeqInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True
class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.is_reverse = True
class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True self.has_initial_state = True
self.is_reverse = True self.is_reverse = True
class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp): class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
def set_conf(self): def set_conf(self):
self.use_seq = True
self.use_peepholes = True self.use_peepholes = True
self.lod = [[2]]
self.D = 8
class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.has_initial_state = True
class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.is_reverse = True
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册