diff --git a/lite/backends/x86/jit/gen/matmul.cc b/lite/backends/x86/jit/gen/matmul.cc index 103b9101bab4a90f651b1af2fbf229933905990e..2c75f6dd5dc4bbf12513d10ef0a4e02e709135fd 100644 --- a/lite/backends/x86/jit/gen/matmul.cc +++ b/lite/backends/x86/jit/gen/matmul.cc @@ -39,7 +39,12 @@ void MatMulJitCode::genCode() { size_t wgt_offset = 0; for (size_t g = 0; g < groups.size(); ++g) { size_t x_offset = 0; + size_t wgt_offset_tmp = 0; + for (int i = 0; i < g; ++i) { + wgt_offset_tmp += groups[i] * block_len; + } for (int k = 0; k < k_; ++k) { + wgt_offset = wgt_offset_tmp; vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]); // clean if (k == 0) { @@ -48,7 +53,8 @@ void MatMulJitCode::genCode() { } } for (int i = 0; i < groups[g]; ++i) { - vmovups(zmm_t(w_reg_idx), ptr[reg_ptr_wgt + wgt_offset]); + vmovups(zmm_t(w_reg_idx), + ptr[reg_ptr_wgt + wgt_offset + k * n_ * sizeof(float)]); vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx)); wgt_offset += block_len; }