提交 6a7f83d4 编写于 作者: T tensor-tang

enable gru jitcode and refine act and lstm jitcode

test=develop
上级 686eaf20
...@@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) { ...@@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
} }
void VActJitCode::generate() { void VActJitCode::generate() {
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_zero = ymm_t(2);
if (type_ == operand_type::relu) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
int offset = 0; int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
switch (type_) { act<ymm_t>(ymm_dst, ymm_src, type_);
case operand_type::relu:
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
break;
case operand_type::exp:
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::identity:
break;
default:
break;
}
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
...@@ -182,22 +160,7 @@ void VActJitCode::generate() { ...@@ -182,22 +160,7 @@ void VActJitCode::generate() {
block = 1; block = 1;
vmovss(xmm_src, ptr[param1 + offset]); vmovss(xmm_src, ptr[param1 + offset]);
} }
switch (type_) { act<xmm_t>(xmm_dst, xmm_src, type_);
case operand_type::relu:
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
if (rest >= 4) { if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst); vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) { } else if (rest >= 2) {
...@@ -233,52 +196,64 @@ void LSTMJitCode::generate() { ...@@ -233,52 +196,64 @@ void LSTMJitCode::generate() {
int offset = 0; int offset = 0;
int d = num_ * sizeof(float); int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* C_t = C_t-1 * fgated + cand_gated * igated*/ /* gates: W_ch, W_ih, W_fh, W_oh */
// c ymm_t ymm_c = ymm_t(0);
vmovups(ymm_src, ptr[reg_ptr_gates + offset]); ymm_t ymm_i = ymm_t(1);
act<ymm_t>(ymm_c, ymm_src, act_cand_); ymm_t ymm_f = ymm_t(2);
// i ymm_t ymm_o = ymm_t(3);
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]); ymm_t ymm_ct_1 = ymm_t(4);
if (!compute_c1h1_ && use_peephole_) { ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp = ymm_t(2); ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_ct_1 = ymm_t(3); ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset]); vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]); vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
vmulps(ymm_wp, ymm_ct_1, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
} }
act<ymm_t>(ymm_i, ymm_src, act_gate_); if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i); vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) { if (!compute_c1h1_) {
// f // act_gate(f) or act_gate(ct_1 * wp1 + f)
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
if (use_peephole_) { if (use_peephole_) {
ymm_t ymm_wp = ymm_t(3); vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]); vaddps(ymm_f, ymm_f, ymm_wp1);
vmulps(ymm_wp, ymm_i, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
} }
act<ymm_t>(ymm_f, ymm_src, act_gate_); act<ymm_t>(ymm_f, ymm_f, act_gate_);
vmulps(ymm_f, ymm_f, ymm_i); // ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c); vaddps(ymm_f, ymm_f, ymm_c);
} }
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
ymm_t ymm_tmp = ymm_i; ymm_t ymm_tmp = ymm_i;
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
if (use_peephole_) {
ymm_t ymm_wp = ymm_t(2);
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]);
vmulps(ymm_wp, ymm_ct, ymm_wp);
vaddps(ymm_src, ymm_src, ymm_wp);
}
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_); act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
act<ymm_t>(ymm_o, ymm_src, act_gate_); // act_gate(o) or act_gate(ct * wp2 + o)
vmulps(ymm_o, ymm_tmp, ymm_o); if (use_peephole_) {
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
...@@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } ...@@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void GRUJitCode::generate() { void GRUJitCode::generate() {
reg64_t reg_ptr_gates = rax; reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9; reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ct = r10; reg64_t reg_ptr_ht = r10;
reg64_t reg_ptr_ht = r11; mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); ymm_t ymm_one = ymm_t(0);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret(); ret();
} }
......
...@@ -169,31 +169,34 @@ class VActJitCode : public JitCode { ...@@ -169,31 +169,34 @@ class VActJitCode : public JitCode {
protected: protected:
// compute relu with ymm, xmm // compute relu with ymm, xmm
template <typename JMM> template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
JMM zero = JMM(zero_idx);
vxorps(zero, zero, zero);
vmaxps(dst, src, zero); vmaxps(dst, src, zero);
} }
// compute exp with ymm, xmm // compute exp with ymm, xmm
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int mask_idx = 4, int tmp_idx = 5) { int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
using namespace platform::jit; // NOLINT using namespace platform::jit; // NOLINT
assert(src.getIdx() != dst.getIdx()); // TODO(TJ): use enforce
// check all idx can not equal // check all idx can not equal
JMM jmm_src = JMM(src_idx);
JMM jmm_fx = JMM(fx_idx); JMM jmm_fx = JMM(fx_idx);
JMM jmm_fy = JMM(fy_idx); JMM jmm_fy = JMM(fy_idx);
JMM jmm_mask = JMM(mask_idx); JMM jmm_mask = JMM(mask_idx);
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(src, src, jmm_tmp); vminps(jmm_src, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(src, src, jmm_tmp); vmaxps(jmm_src, jmm_src, jmm_tmp);
// express exp(x) as exp(g + n*log(2)) // express exp(x) as exp(g + n*log(2))
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(jmm_fx, src, jmm_tmp); vmulps(jmm_fx, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(jmm_fx, jmm_fx, jmm_tmp); vaddps(jmm_fx, jmm_fx, jmm_tmp);
vroundps(jmm_fy, jmm_fx, 0x01); vroundps(jmm_fy, jmm_fx, 0x01);
...@@ -207,21 +210,21 @@ class VActJitCode : public JitCode { ...@@ -207,21 +210,21 @@ class VActJitCode : public JitCode {
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
JMM ymm_z = JMM(jmm_mask.getIdx()); JMM ymm_z = JMM(jmm_mask.getIdx());
vmulps(ymm_z, jmm_fx, jmm_tmp); vmulps(ymm_z, jmm_fx, jmm_tmp);
vsubps(src, src, jmm_fy); vsubps(jmm_src, jmm_src, jmm_fy);
vsubps(src, src, ymm_z); vsubps(jmm_src, jmm_src, ymm_z);
vmulps(ymm_z, src, src); vmulps(ymm_z, jmm_src, jmm_src);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(dst, src, jmm_tmp); vmulps(dst, jmm_src, jmm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (YMM_FLOAT_BLOCK * sizeof(float))) { i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4 vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, src); vmulps(dst, dst, jmm_src);
} }
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, ymm_z); vmulps(dst, dst, ymm_z);
vaddps(dst, dst, src); vaddps(dst, dst, jmm_src);
vmovaps(jmm_tmp, ptr[reg_ptr_global]); vmovaps(jmm_tmp, ptr[reg_ptr_global]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
// build 2^n // build 2^n
...@@ -258,20 +261,23 @@ class VActJitCode : public JitCode { ...@@ -258,20 +261,23 @@ class VActJitCode : public JitCode {
// compute sigmoid with ymm, xmm // compute sigmoid with ymm, xmm
template <typename JMM> template <typename JMM>
void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) { int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
int tmp_idx = 15) {
// y = 1 / (1 + e^-x) // y = 1 / (1 + e^-x)
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
JMM jmm_src = JMM(src_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(src, src, jmm_tmp); vminps(jmm_src, jmm_src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(src, src, jmm_tmp); vmaxps(jmm_src, jmm_src, jmm_tmp);
vxorps(jmm_tmp, jmm_tmp, jmm_tmp); vxorps(jmm_tmp, jmm_tmp, jmm_tmp);
vsubps(src, jmm_tmp, src); vsubps(jmm_src, jmm_tmp, jmm_src);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); exp_jmm<JMM>(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vdivps(dst, jmm_tmp, dst); vdivps(dst, jmm_tmp, dst);
...@@ -280,19 +286,22 @@ class VActJitCode : public JitCode { ...@@ -280,19 +286,22 @@ class VActJitCode : public JitCode {
// compute tanh with ymm, xmm // compute tanh with ymm, xmm
template <typename JMM> template <typename JMM>
void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int mask_idx = 4, int tmp_idx = 5) { int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
int tmp_idx = 15) {
// y = 2 / (1 + e^(-2x)) - 1 // y = 2 / (1 + e^(-2x)) - 1
JMM jmm_src = JMM(src_idx);
JMM jmm_tmp = JMM(tmp_idx); JMM jmm_tmp = JMM(tmp_idx);
JMM jmm_zero = JMM(mask_idx); JMM jmm_zero = JMM(mask_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
vmovaps(jmm_src, src);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(jmm_zero, jmm_zero, jmm_zero); vxorps(jmm_zero, jmm_zero, jmm_zero);
vsubps(jmm_tmp, jmm_zero, jmm_tmp); vsubps(jmm_tmp, jmm_zero, jmm_tmp);
vmulps(src, src, jmm_tmp); vmulps(jmm_src, jmm_src, jmm_tmp);
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); exp_jmm<JMM>(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(dst, dst, jmm_tmp); vaddps(dst, dst, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
...@@ -304,23 +313,19 @@ class VActJitCode : public JitCode { ...@@ -304,23 +313,19 @@ class VActJitCode : public JitCode {
template <typename JMM> template <typename JMM>
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
// use 15 // use 11~15
JMM zero = JMM(15);
if (type_ == operand_type::relu) {
vxorps(zero, zero, zero);
}
switch (type) { switch (type) {
case operand_type::relu: case operand_type::relu:
relu_jmm<JMM>(dst, src, zero); relu_jmm<JMM>(dst, src, 15);
break; break;
case operand_type::exp: case operand_type::exp:
exp_jmm<JMM>(dst, src, 2, 3, 4, 5); exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::sigmoid: case operand_type::sigmoid:
sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5); sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::tanh: case operand_type::tanh:
tanh_jmm<JMM>(dst, src, 2, 3, 4, 5); tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::identity: case operand_type::identity:
break; break;
...@@ -414,15 +419,6 @@ class LSTMJitCode : public VActJitCode { ...@@ -414,15 +419,6 @@ class LSTMJitCode : public VActJitCode {
operand_type act_cand_; operand_type act_cand_;
operand_type act_cell_; operand_type act_cell_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
xmm_t xmm_src = xmm_t(0);
xmm_t xmm_c = xmm_t(1);
xmm_t xmm_i = xmm_t(6);
xmm_t xmm_f = xmm_t(7);
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_c = ymm_t(1); // 2~5 for act
ymm_t ymm_i = ymm_t(6);
ymm_t ymm_f = ymm_t(7);
}; };
class GRUJitCode : public VActJitCode { class GRUJitCode : public VActJitCode {
...@@ -492,16 +488,6 @@ class GRUJitCode : public VActJitCode { ...@@ -492,16 +488,6 @@ class GRUJitCode : public VActJitCode {
operand_type act_gate_; operand_type act_gate_;
operand_type act_cand_; operand_type act_cand_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
xmm_t xmm_src = xmm_t(0);
xmm_t xmm_c = xmm_t(1);
xmm_t xmm_i = xmm_t(6);
xmm_t xmm_f = xmm_t(7);
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_c = ymm_t(1);
ymm_t ymm_i = ymm_t(6);
ymm_t ymm_f = ymm_t(7);
}; };
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
...@@ -206,7 +206,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { ...@@ -206,7 +206,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
T* ht = reinterpret_cast<T*>(step->ht); T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1); const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc<T>(attr->act_gate); auto act_gate = getActFunc<T>(attr->act_gate);
act_gate(gates, gates, attr->d * 2); act_gate(gates + attr->d, gates + attr->d, attr->d);
VMul(ht_1, gates + attr->d, ht, attr->d); VMul(ht_1, gates + attr->d, ht, attr->d);
} }
...@@ -215,9 +215,11 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { ...@@ -215,9 +215,11 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates); T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht); T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1); const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand); auto act_cand = getActFunc<T>(attr->act_cand);
int d = attr->d; int d = attr->d;
T* y = gates + d * 2; T* y = gates + d * 2;
act_gate(gates, gates, d);
act_cand(y, y, d); act_cand(y, y, d);
// out = zt*ht~ + (1-zt)*ht_1 // out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
......
...@@ -177,7 +177,7 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -177,7 +177,7 @@ class GRUKernelImpl : public GRUKernel<T> {
explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel<T>() { explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) { if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8;
jitcode0_.reset(new gen::GRUJitCode(0, attr, sz > 4096 ? sz : 4096)); jitcode0_.reset(new gen::GRUJitCode(0, attr, sz > 4096 ? sz : 4096));
this->ComputeH1 = this->ComputeH1 =
jitcode0_->getCode<void (*)(gru_t*, const gru_attr_t*)>(); jitcode0_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
...@@ -188,7 +188,7 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -188,7 +188,7 @@ class GRUKernelImpl : public GRUKernel<T> {
jitcode2_.reset(new gen::GRUJitCode(2, attr, sz > 4096 ? sz : 4096)); jitcode2_.reset(new gen::GRUJitCode(2, attr, sz > 4096 ? sz : 4096));
this->ComputeHtPart2 = this->ComputeHtPart2 =
jitcode1_->getCode<void (*)(gru_t*, const gru_attr_t*)>(); jitcode2_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
return; return;
} }
#endif #endif
...@@ -207,7 +207,7 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -207,7 +207,7 @@ class GRUKernelImpl : public GRUKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool GRUKernelImpl<float>::useJIT(int d) { bool GRUKernelImpl<float>::useJIT(int d) {
return false; // jitcode not ready yet return gen::GRUJitCode::init(d);
} }
#endif #endif
......
...@@ -714,6 +714,8 @@ TEST(JitKernel, pool) { ...@@ -714,6 +714,8 @@ TEST(JitKernel, pool) {
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false); jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
// make a call whose result is unused, to avoid unknown flag 'use_pinned_memory' on Mac
paddle::platform::jit::MayIUse(paddle::platform::jit::avx);
const auto& plstm1 = const auto& plstm1 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr); .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册