提交 d7ac1cc8 编写于 作者: T tensor-tang

Refine seq when the batch size (bs) is large

上级 9dd5a177
...@@ -285,17 +285,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -285,17 +285,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cell(D, ct, gates + D2); \ act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht) blas.VMUL(D, gates + D2, gates + D3, ht)
#define COMPUTE_CtHt_WITHOUT_H0C0(gates, ct, ht) \ #define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/ \
act_gate(D, gates + D, gates + D); \ act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \ act_cand(D, gates, gates); \
/* C_t = igated * cgated*/ \ blas.VMUL(D, gates, gates + D, ct)
blas.VMUL(D, gates, gates + D, ct); \
/* get outgated*/ \ #define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
if (use_peepholes) { \ GET_Ct_NOH0C0(gates, ct); \
/* put W_oc * C_t on igated */ \ act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */ \
blas.VMUL(D, wc_data + D2, ct, gates + D); \ blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \ blas.VADD(D, gates + D, gates + D3, gates + D3); \
} \
act_gate(D, gates + D3, gates + D3); \ act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht) GET_Ht(ct, gates, ht)
...@@ -354,24 +359,38 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -354,24 +359,38 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
h_out_data = h_out_data + gate_offset; \ h_out_data = h_out_data + gate_offset; \
c_out_data = c_out_data + gate_offset c_out_data = c_out_data + gate_offset
#define PROCESS_H0C0 \ #define PROCESS_H0C0_DEFINES \
int bid = is_reverse ? N - 1 - i : i; \ int bid = is_reverse ? N - 1 - i : i; \
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \ int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
const T* prev_c_data = nullptr; \ const T* prev_c_data = nullptr; \
const T* prev_h_data = nullptr; \ const T* prev_h_data = nullptr; \
int tstart = 0; \ int tstart = 0
#define PROCESS_H0C0_PEEPHOLE \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
#define PROCESS_H0C0 \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \ if (h0_data) { \
prev_h_data = h0_data + bid * D; \ prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \ prev_c_data = c0_data + bid * D; \
} else { \ } else { \
COMPUTE_CtHt_WITHOUT_H0C0(xx_data, c_out_data, h_out_data); \ COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \ MOVE_ONE_STEP; \
tstart = 1; \ tstart = 1; \
} }
if (use_peepholes) { if (use_peepholes) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
PROCESS_H0C0; PROCESS_H0C0_PEEPHOLE
for (int step = tstart; step < seq_len; ++step) { for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data); GEMM_WH_ADDON(1, prev_h_data, xx_data);
COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data); COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
...@@ -380,7 +399,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -380,7 +399,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
} }
} else { } else {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
PROCESS_H0C0; PROCESS_H0C0
for (int step = tstart; step < seq_len; ++step) { for (int step = tstart; step < seq_len; ++step) {
GEMM_WH_ADDON(1, prev_h_data, xx_data); GEMM_WH_ADDON(1, prev_h_data, xx_data);
COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data); COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
...@@ -388,6 +407,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -388,6 +407,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
} }
} }
} }
#undef PROCESS_H0C0_DEFINES
#undef PROCESS_H0C0_PEEPHOLE
#undef PROCESS_H0C0 #undef PROCESS_H0C0
#undef MOVE_ONE_STEP #undef MOVE_ONE_STEP
} }
...@@ -460,7 +481,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -460,7 +481,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* cur_h_out_data = batched_h_out_data; T* cur_h_out_data = batched_h_out_data;
T* cur_c_out_data = batched_c_out_data; T* cur_c_out_data = batched_c_out_data;
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
COMPUTE_CtHt_WITHOUT_H0C0(cur_in_data, cur_c_out_data, cur_h_out_data); GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
if (use_peepholes) {
blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
}
act_gate(D, cur_in_data + D3, cur_in_data + D3);
GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
cur_h_out_data += D; cur_h_out_data += D;
...@@ -541,7 +568,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -541,7 +568,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
#undef COMPUTE_CtHt_PEEPHOLE #undef COMPUTE_CtHt_PEEPHOLE
#undef COMPUTE_CtHt #undef COMPUTE_CtHt
#undef COMPUTE_CtHt_WITHOUT_H0C0 #undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht #undef GET_Ht
#undef GET_Ct #undef GET_Ct
#undef GEMM_WH_ADDON #undef GEMM_WH_ADDON
......
...@@ -183,7 +183,7 @@ class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp): ...@@ -183,7 +183,7 @@ class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp):
self.is_reverse = True self.is_reverse = True
class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
def set_conf(self): def set_conf(self):
self.use_peepholes = True self.use_peepholes = True
self.lod = [[2]] self.lod = [[2]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册