From 6a7f83d45df2ff22c49867837c97f0773421ee0c Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 23 Nov 2018 04:11:28 +0000 Subject: [PATCH] enable gru jitcode and refine act and lstm jitcode test=develop --- paddle/fluid/operators/math/jit_code.cc | 183 ++++++++++-------- paddle/fluid/operators/math/jit_code.h | 90 ++++----- .../fluid/operators/math/jit_kernel_refer.h | 4 +- paddle/fluid/operators/math/jit_kernel_rnn.cc | 6 +- .../fluid/operators/math/jit_kernel_test.cc | 2 + 5 files changed, 149 insertions(+), 136 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 95247ce309..52cbdf685d 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) { } void VActJitCode::generate() { - xmm_t xmm_zero = xmm_t(2); - ymm_t ymm_zero = ymm_t(2); - if (type_ == operand_type::relu) { - vxorps(ymm_zero, ymm_zero, ymm_zero); - } int offset = 0; for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { vmovups(ymm_src, ptr[param1 + offset]); - switch (type_) { - case operand_type::relu: - relu_jmm(ymm_dst, ymm_src, ymm_zero); - break; - case operand_type::exp: - exp_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); - break; - case operand_type::sigmoid: - sigmoid_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); - break; - case operand_type::tanh: - tanh_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); - break; - case operand_type::identity: - break; - default: - break; - } + act(ymm_dst, ymm_src, type_); vmovups(ptr[param2 + offset], ymm_dst); offset += sizeof(float) * YMM_FLOAT_BLOCK; } @@ -182,22 +160,7 @@ void VActJitCode::generate() { block = 1; vmovss(xmm_src, ptr[param1 + offset]); } - switch (type_) { - case operand_type::relu: - relu_jmm(xmm_dst, xmm_src, xmm_zero); - break; - case operand_type::exp: - exp_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); - break; - case operand_type::sigmoid: - sigmoid_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); - break; - case operand_type::tanh: - tanh_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); - break; - default: - break; - } + act(xmm_dst, xmm_src, type_); if (rest >= 4) { vmovups(ptr[param2 + offset], xmm_dst); } else if (rest >= 2) { @@ -233,52 +196,64 @@ void LSTMJitCode::generate() { int offset = 0; int d = num_ * sizeof(float); for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - /* C_t = C_t-1 * fgated + cand_gated * igated*/ - // c - vmovups(ymm_src, ptr[reg_ptr_gates + offset]); - act(ymm_c, ymm_src, act_cand_); - // i - vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]); - if (!compute_c1h1_ && use_peephole_) { - ymm_t ymm_wp = ymm_t(2); - ymm_t ymm_ct_1 = ymm_t(3); - vmovups(ymm_wp, ptr[reg_ptr_wp + offset]); + /* gates: W_ch, W_ih, W_fh, W_oh */ + ymm_t ymm_c = ymm_t(0); + ymm_t ymm_i = ymm_t(1); + ymm_t ymm_f = ymm_t(2); + ymm_t ymm_o = ymm_t(3); + ymm_t ymm_ct_1 = ymm_t(4); + ymm_t ymm_wp0 = ymm_t(5); + ymm_t ymm_wp1 = ymm_t(6); + ymm_t ymm_wp2 = ymm_t(7); + vmovups(ymm_c, ptr[reg_ptr_gates + offset]); + vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]); + vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]); + vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]); + if (!compute_c1h1_) { vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]); - vmulps(ymm_wp, ymm_ct_1, ymm_wp); - vaddps(ymm_src, ymm_src, ymm_wp); } - act(ymm_i, ymm_src, act_gate_); + if (use_peephole_) { + vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]); + vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]); + vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]); + } + /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */ + // act_cand(c) + act(ymm_c, ymm_c, act_cand_); + // act_gate(i) or act_gate(ct_1 * wp0 + i) + if (!compute_c1h1_ && use_peephole_) { + vmulps(ymm_wp0, ymm_ct_1, ymm_wp0); + vaddps(ymm_i, ymm_i, ymm_wp0); + } + act(ymm_i, ymm_i, act_gate_); vmulps(ymm_c, ymm_c, ymm_i); if (!compute_c1h1_) { - // f - vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]); - vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]); + // act_gate(f) or act_gate(ct_1 * wp1 + f) if (use_peephole_) { - ymm_t ymm_wp = ymm_t(3); - vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]); - vmulps(ymm_wp, ymm_i, ymm_wp); - vaddps(ymm_src, ymm_src, ymm_wp); + vmulps(ymm_wp1, ymm_ct_1, ymm_wp1); + vaddps(ymm_f, ymm_f, ymm_wp1); } - act(ymm_f, ymm_src, act_gate_); - vmulps(ymm_f, ymm_f, ymm_i); + act(ymm_f, ymm_f, act_gate_); + // ct + vmulps(ymm_f, ymm_f, ymm_ct_1); vaddps(ymm_f, ymm_f, ymm_c); } - /* H_t = act_cell(C_t) * ogated */ + /* H_t = act_cell(C_t) * act_gate(o) */ + // act_cell(C_t) ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; - ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c; ymm_t ymm_tmp = ymm_i; - vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct - vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]); + act(ymm_tmp, ymm_ct, act_cell_); + // act_gate(o) or act_gate(ct * wp2 + o) if (use_peephole_) { - ymm_t ymm_wp = ymm_t(2); - vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]); - vmulps(ymm_wp, ymm_ct, ymm_wp); - vaddps(ymm_src, ymm_src, ymm_wp); + vmulps(ymm_wp2, ymm_ct, ymm_wp2); + vaddps(ymm_o, ymm_o, ymm_wp2); } - act(ymm_tmp, ymm_ct, act_cell_); - act(ymm_o, ymm_src, act_gate_); - vmulps(ymm_o, ymm_tmp, ymm_o); - vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht + act(ymm_o, ymm_o, act_gate_); + // ht + vmulps(ymm_o, ymm_o, ymm_tmp); + // save ct and ht + vmovups(ptr[reg_ptr_ct + offset], ymm_ct); + vmovups(ptr[reg_ptr_ht + offset], ymm_o); offset += sizeof(float) * YMM_FLOAT_BLOCK; } @@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } void GRUJitCode::generate() { reg64_t reg_ptr_gates = rax; - reg64_t reg_ptr_ct_1 = r9; - reg64_t reg_ptr_ct = r10; - reg64_t reg_ptr_ht = r11; - mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); - mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); - mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); - mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); + reg64_t reg_ptr_ht_1 = r9; + reg64_t reg_ptr_ht = r10; + mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]); + mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]); + mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]); + ymm_t ymm_one = ymm_t(0); + + if (id_ == 2) { + reg64_t reg_ptr_tmp = r11; + mov(reg_ptr_tmp, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]); + } + int offset = 0; + int d = num_ * sizeof(float); + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { + ymm_t ymm_u = ymm_t(1); + ymm_t ymm_r = ymm_t(2); + ymm_t ymm_s = ymm_t(3); + ymm_t ymm_ht_1 = ymm_t(4); + // W: {W_update, W_reset; W_state} + if (id_ == 0 || id_ == 2) { + vmovups(ymm_u, ptr[reg_ptr_gates + offset]); + vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]); + } + if (id_ == 1) { + vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]); + } + if (id_ == 1 || id_ == 2) { + vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]); + } + + if (id_ == 0) { + // ht = act_gate(u) * act_cand(s) + act(ymm_u, ymm_u, act_gate_); + act(ymm_s, ymm_s, act_cand_); + vmulps(ymm_s, ymm_s, ymm_u); + vmovups(ptr[reg_ptr_ht + offset], ymm_s); + } else if (id_ == 1) { + // ht = act_gate(r) * ht_1 + act(ymm_r, ymm_r, act_gate_); + vmulps(ymm_r, ymm_r, ymm_ht_1); + vmovups(ptr[reg_ptr_ht + offset], ymm_r); + } else if (id_ == 2) { + // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 + ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx()); + act(ymm_u, ymm_u, act_gate_); + act(ymm_s, ymm_s, act_cand_); + vmulps(ymm_s, ymm_s, ymm_u); + vsubps(ymm_u, ymm_one_inner, ymm_u); + vmulps(ymm_u, ymm_ht_1, ymm_u); + vaddps(ymm_u, ymm_s, ymm_u); + vmovups(ptr[reg_ptr_ht + offset], ymm_u); + } + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } ret(); } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 403cea3991..a921462129 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -169,31 +169,34 @@ class VActJitCode : public JitCode { protected: // compute relu with ymm, xmm template - void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT + void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT + JMM zero = JMM(zero_idx); + vxorps(zero, zero, zero); vmaxps(dst, src, zero); } // compute exp with ymm, xmm template - void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT - int mask_idx = 4, int tmp_idx = 5) { - using namespace platform::jit; // NOLINT - assert(src.getIdx() != dst.getIdx()); // TODO(TJ): use enfore + void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT + int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { + using namespace platform::jit; // NOLINT // check all idx can not equal + JMM jmm_src = JMM(src_idx); JMM jmm_fx = JMM(fx_idx); JMM jmm_fy = JMM(fy_idx); JMM jmm_mask = JMM(mask_idx); JMM jmm_tmp = JMM(tmp_idx); reg64_t reg_ptr_global = rax; push(reg_ptr_global); + vmovaps(jmm_src, src); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); - vminps(src, src, jmm_tmp); + vminps(jmm_src, jmm_src, jmm_tmp); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); - vmaxps(src, src, jmm_tmp); + vmaxps(jmm_src, jmm_src, jmm_tmp); // express exp(x) as exp(g + n*log(2)) vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); - vmulps(jmm_fx, src, jmm_tmp); + vmulps(jmm_fx, jmm_src, jmm_tmp); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); vaddps(jmm_fx, jmm_fx, jmm_tmp); vroundps(jmm_fy, jmm_fx, 0x01); @@ -207,21 +210,21 @@ class VActJitCode : public JitCode { vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); JMM ymm_z = JMM(jmm_mask.getIdx()); vmulps(ymm_z, jmm_fx, jmm_tmp); - vsubps(src, src, jmm_fy); - vsubps(src, src, ymm_z); - vmulps(ymm_z, src, src); + vsubps(jmm_src, jmm_src, jmm_fy); + vsubps(jmm_src, jmm_src, ymm_z); + vmulps(ymm_z, jmm_src, jmm_src); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); - vmulps(dst, src, jmm_tmp); + vmulps(dst, jmm_src, jmm_tmp); for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; i += (YMM_FLOAT_BLOCK * sizeof(float))) { vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4 vaddps(dst, dst, jmm_tmp); - vmulps(dst, dst, src); + vmulps(dst, dst, jmm_src); } vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); vaddps(dst, dst, jmm_tmp); vmulps(dst, dst, ymm_z); - vaddps(dst, dst, src); + vaddps(dst, dst, jmm_src); vmovaps(jmm_tmp, ptr[reg_ptr_global]); vaddps(dst, dst, jmm_tmp); // build 2^n @@ -258,20 +261,23 @@ class VActJitCode : public JitCode { // compute sigmoid with ymm, xmm template - void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT - int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) { + void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT + int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, + int tmp_idx = 15) { // y = 1 / (1 + e^-x) JMM jmm_tmp = JMM(tmp_idx); + JMM jmm_src = JMM(src_idx); reg64_t reg_ptr_global = rax; push(reg_ptr_global); + vmovaps(jmm_src, src); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); - vminps(src, src, jmm_tmp); + vminps(jmm_src, jmm_src, jmm_tmp); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); - vmaxps(src, src, jmm_tmp); + vmaxps(jmm_src, jmm_src, jmm_tmp); vxorps(jmm_tmp, jmm_tmp, jmm_tmp); - vsubps(src, jmm_tmp, src); - exp_jmm(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); + vsubps(jmm_src, jmm_tmp, jmm_src); + exp_jmm(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vaddps(dst, dst, jmm_tmp); vdivps(dst, jmm_tmp, dst); @@ -280,19 +286,22 @@ class VActJitCode : public JitCode { // compute tanh with ymm, xmm template - void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT - int mask_idx = 4, int tmp_idx = 5) { + void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT + int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, + int tmp_idx = 15) { // y = 2 / (1 + e^(-2x)) - 1 + JMM jmm_src = JMM(src_idx); JMM jmm_tmp = JMM(tmp_idx); JMM jmm_zero = JMM(mask_idx); reg64_t reg_ptr_global = rax; push(reg_ptr_global); + vmovaps(jmm_src, src); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vxorps(jmm_zero, jmm_zero, jmm_zero); vsubps(jmm_tmp, jmm_zero, jmm_tmp); - vmulps(src, src, jmm_tmp); - exp_jmm(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); + vmulps(jmm_src, jmm_src, jmm_tmp); + exp_jmm(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vaddps(dst, dst, jmm_tmp); vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); @@ -304,23 +313,19 @@ class VActJitCode : public JitCode { template void act(JMM& dst, JMM& src, operand_type type) { // NOLINT - // use 15 - JMM zero = JMM(15); - if (type_ == operand_type::relu) { - vxorps(zero, zero, zero); - } + // use 11~15 switch (type) { case operand_type::relu: - relu_jmm(dst, src, zero); + relu_jmm(dst, src, 15); break; case operand_type::exp: - exp_jmm(dst, src, 2, 3, 4, 5); + exp_jmm(dst, src, 11, 12, 13, 14, 15); break; case operand_type::sigmoid: - sigmoid_jmm(dst, src, 2, 3, 4, 5); + sigmoid_jmm(dst, src, 11, 12, 13, 14, 15); break; case operand_type::tanh: - tanh_jmm(dst, src, 2, 3, 4, 5); + tanh_jmm(dst, src, 11, 12, 13, 14, 15); break; case operand_type::identity: break; @@ -414,15 +419,6 @@ class LSTMJitCode : public VActJitCode { operand_type act_cand_; operand_type act_cell_; reg64_t param1{abi_param1}; - xmm_t xmm_src = xmm_t(0); - xmm_t xmm_c = xmm_t(1); - xmm_t xmm_i = xmm_t(6); - xmm_t xmm_f = xmm_t(7); - - ymm_t ymm_src = ymm_t(0); - ymm_t ymm_c = ymm_t(1); // 2~5 for act - ymm_t ymm_i = ymm_t(6); - ymm_t ymm_f = ymm_t(7); }; class GRUJitCode : public VActJitCode { @@ -492,16 +488,6 @@ class GRUJitCode : public VActJitCode { operand_type act_gate_; operand_type act_cand_; reg64_t param1{abi_param1}; - - xmm_t xmm_src = xmm_t(0); - xmm_t xmm_c = xmm_t(1); - xmm_t xmm_i = xmm_t(6); - xmm_t xmm_f = xmm_t(7); - - ymm_t ymm_src = ymm_t(0); - ymm_t ymm_c = ymm_t(1); - ymm_t ymm_i = ymm_t(6); - ymm_t ymm_f = ymm_t(7); }; #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/math/jit_kernel_refer.h index 2e1a7f22db..bcb6615df8 100644 --- a/paddle/fluid/operators/math/jit_kernel_refer.h +++ b/paddle/fluid/operators/math/jit_kernel_refer.h @@ -206,7 +206,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { T* ht = reinterpret_cast(step->ht); const T* ht_1 = reinterpret_cast(step->ht_1); auto act_gate = getActFunc(attr->act_gate); - act_gate(gates, gates, attr->d * 2); + act_gate(gates + attr->d, gates + attr->d, attr->d); VMul(ht_1, gates + attr->d, ht, attr->d); } @@ -215,9 +215,11 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { T* gates = reinterpret_cast(step->gates); T* ht = reinterpret_cast(step->ht); const T* ht_1 = reinterpret_cast(step->ht_1); + auto act_gate = getActFunc(attr->act_gate); auto act_cand = getActFunc(attr->act_cand); int d = attr->d; T* y = gates + d * 2; + act_gate(gates, gates, d); act_cand(y, y, d); // out = zt*ht~ + (1-zt)*ht_1 for (int i = 0; i < d; ++i) { diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index 85ea95cfcc..2db3274a45 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -177,7 +177,7 @@ class GRUKernelImpl : public GRUKernel { explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(attr.d)) { - size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change + size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; jitcode0_.reset(new gen::GRUJitCode(0, attr, sz > 4096 ? sz : 4096)); this->ComputeH1 = jitcode0_->getCode(); @@ -188,7 +188,7 @@ class GRUKernelImpl : public GRUKernel { jitcode2_.reset(new gen::GRUJitCode(2, attr, sz > 4096 ? sz : 4096)); this->ComputeHtPart2 = - jitcode1_->getCode(); + jitcode2_->getCode(); return; } #endif @@ -207,7 +207,7 @@ class GRUKernelImpl : public GRUKernel { #ifdef PADDLE_WITH_XBYAK template <> bool GRUKernelImpl::useJIT(int d) { - return false; // jitcode not ready yet + return gen::GRUJitCode::init(d); } #endif diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 1cbe1b5d95..cc8a5d4d86 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -714,6 +714,8 @@ TEST(JitKernel, pool) { std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false); + // empty call it to avoid unknown flag 'use_pinned_memory' on Mac + paddle::platform::jit::MayIUse(paddle::platform::jit::avx); const auto& plstm1 = jit::KernelPool::Instance() .template Get, const jit::lstm_attr_t&>(attr); -- GitLab