提交 d7ac1cc8 编写于 作者: T tensor-tang

refine seq when bs is large

上级 9dd5a177
......@@ -285,18 +285,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht)
#define COMPUTE_CtHt_WITHOUT_H0C0(gates, ct, ht) \
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
/* C_t = igated * cgated*/ \
blas.VMUL(D, gates, gates + D, ct); \
/* get outgated*/ \
if (use_peepholes) { \
/* put W_oc * C_t on igated */ \
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
} \
act_gate(D, gates + D3, gates + D3); \
#define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/ \
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, ct)
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */ \
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
......@@ -354,24 +359,38 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
h_out_data = h_out_data + gate_offset; \
c_out_data = c_out_data + gate_offset
#define PROCESS_H0C0 \
int bid = is_reverse ? N - 1 - i : i; \
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
const T* prev_c_data = nullptr; \
const T* prev_h_data = nullptr; \
int tstart = 0; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_WITHOUT_H0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
#define PROCESS_H0C0_DEFINES \
int bid = is_reverse ? N - 1 - i : i; \
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
const T* prev_c_data = nullptr; \
const T* prev_h_data = nullptr; \
int tstart = 0
#define PROCESS_H0C0_PEEPHOLE \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
#define PROCESS_H0C0 \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
if (use_peepholes) {
for (int i = 0; i < N; ++i) {
PROCESS_H0C0;
PROCESS_H0C0_PEEPHOLE
for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data);
COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
......@@ -380,7 +399,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
} else {
for (int i = 0; i < N; ++i) {
PROCESS_H0C0;
PROCESS_H0C0
for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data);
COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
......@@ -388,6 +407,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
}
}
#undef PROCESS_H0C0_DEFINES
#undef PROCESS_H0C0_PEEPHOLE
#undef PROCESS_H0C0
#undef MOVE_ONE_STEP
}
......@@ -460,7 +481,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_h_out_data = batched_h_out_data;
T* cur_c_out_data = batched_c_out_data;
for (int i = 0; i < max_bs; ++i) {
COMPUTE_CtHt_WITHOUT_H0C0(cur_in_data, cur_c_out_data, cur_h_out_data);
GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
if (use_peepholes) {
blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
}
act_gate(D, cur_in_data + D3, cur_in_data + D3);
GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
cur_in_data += D4;
cur_c_out_data += D;
cur_h_out_data += D;
......@@ -541,7 +568,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
#undef COMPUTE_CtHt_PEEPHOLE
#undef COMPUTE_CtHt
#undef COMPUTE_CtHt_WITHOUT_H0C0
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
......
......@@ -183,7 +183,7 @@ class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp):
self.is_reverse = True
class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp):
class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
def set_conf(self):
self.use_peepholes = True
self.lod = [[2]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册