提交 0c5ed5f6 编写于 作者: T tensor-tang

enable peephole jitcode

test=develop
上级 e3b61cf5
......@@ -221,10 +221,14 @@ void LSTMJitCode::generate() {
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
......@@ -235,13 +239,27 @@ void LSTMJitCode::generate() {
act<ymm_t>(ymm_c, ymm_src, act_cand_);
// i
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
if (!compute_c1h1_ && use_peephole_) {
ymm_t ymm_wp = ymm_t(2);
ymm_t ymm_ct_1 = ymm_t(3);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset]);
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
vmulps(ymm_wp, ymm_ct_1, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
}
act<ymm_t>(ymm_i, ymm_src, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// f
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
act<ymm_t>(ymm_f, ymm_src, act_gate_);
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
if (use_peephole_) {
ymm_t ymm_wp = ymm_t(3);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]);
vmulps(ymm_wp, ymm_i, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
}
act<ymm_t>(ymm_f, ymm_src, act_gate_);
vmulps(ymm_f, ymm_f, ymm_i);
vaddps(ymm_f, ymm_f, ymm_c);
}
......@@ -250,8 +268,14 @@ void LSTMJitCode::generate() {
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
ymm_t ymm_tmp = ymm_i;
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
if (use_peephole_) {
ymm_t ymm_wp = ymm_t(2);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]);
vmulps(ymm_wp, ymm_ct, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
}
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
act<ymm_t>(ymm_o, ymm_src, act_gate_);
vmulps(ymm_o, ymm_tmp, ymm_o);
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht
......
......@@ -108,7 +108,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
#ifdef PADDLE_WITH_XBYAK
template <>
bool PeepholeKernelImpl<float>::useJIT(int d) {
return false; // peephole jitcode not ready yet
return gen::LSTMJitCode::init(d);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册