提交 0c5ed5f6 编写于 作者: T tensor-tang

enable peephole jitcode

test=develop
上级 e3b61cf5
...@@ -221,10 +221,14 @@ void LSTMJitCode::generate() { ...@@ -221,10 +221,14 @@ void LSTMJitCode::generate() {
reg64_t reg_ptr_ct_1 = r9; reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10; reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11; reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0; int offset = 0;
int d = num_ * sizeof(float); int d = num_ * sizeof(float);
...@@ -235,13 +239,27 @@ void LSTMJitCode::generate() { ...@@ -235,13 +239,27 @@ void LSTMJitCode::generate() {
act<ymm_t>(ymm_c, ymm_src, act_cand_); act<ymm_t>(ymm_c, ymm_src, act_cand_);
// i // i
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
if (!compute_c1h1_ && use_peephole_) {
ymm_t ymm_wp = ymm_t(2);
ymm_t ymm_ct_1 = ymm_t(3);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset]);
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
vmulps(ymm_wp, ymm_ct_1, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
}
act<ymm_t>(ymm_i, ymm_src, act_gate_); act<ymm_t>(ymm_i, ymm_src, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i); vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) { if (!compute_c1h1_) {
// f // f
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
act<ymm_t>(ymm_f, ymm_src, act_gate_);
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]); vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
if (use_peephole_) {
ymm_t ymm_wp = ymm_t(3);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]);
vmulps(ymm_wp, ymm_i, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
}
act<ymm_t>(ymm_f, ymm_src, act_gate_);
vmulps(ymm_f, ymm_f, ymm_i); vmulps(ymm_f, ymm_f, ymm_i);
vaddps(ymm_f, ymm_f, ymm_c); vaddps(ymm_f, ymm_f, ymm_c);
} }
...@@ -250,8 +268,14 @@ void LSTMJitCode::generate() { ...@@ -250,8 +268,14 @@ void LSTMJitCode::generate() {
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c; ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
ymm_t ymm_tmp = ymm_i; ymm_t ymm_tmp = ymm_i;
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
if (use_peephole_) {
ymm_t ymm_wp = ymm_t(2);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]);
vmulps(ymm_wp, ymm_ct, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
}
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
act<ymm_t>(ymm_o, ymm_src, act_gate_); act<ymm_t>(ymm_o, ymm_src, act_gate_);
vmulps(ymm_o, ymm_tmp, ymm_o); vmulps(ymm_o, ymm_tmp, ymm_o);
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht
......
...@@ -108,7 +108,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> { ...@@ -108,7 +108,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool PeepholeKernelImpl<float>::useJIT(int d) { bool PeepholeKernelImpl<float>::useJIT(int d) {
return false; // peephole jitcode not ready yet return gen::LSTMJitCode::init(d);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册