未验证 提交 b4897600 编写于 作者: W Wilber 提交者: GitHub

fix jit_matmul bug test=develop (#20886)

* fix jit_matmul bug 

* update jit matmul and add test
上级 3255fe69
...@@ -40,7 +40,12 @@ void MatMulJitCode::genCode() { ...@@ -40,7 +40,12 @@ void MatMulJitCode::genCode() {
size_t wgt_offset = 0; size_t wgt_offset = 0;
for (size_t g = 0; g < groups.size(); ++g) { for (size_t g = 0; g < groups.size(); ++g) {
size_t x_offset = 0; size_t x_offset = 0;
size_t wgt_offset_tmp = 0;
for (int i = 0; i < g; ++i) {
wgt_offset_tmp += groups[i] * block_len;
}
for (int k = 0; k < k_; ++k) { for (int k = 0; k < k_; ++k) {
wgt_offset = wgt_offset_tmp;
vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]); vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
// clean // clean
if (k == 0) { if (k == 0) {
...@@ -49,7 +54,8 @@ void MatMulJitCode::genCode() { ...@@ -49,7 +54,8 @@ void MatMulJitCode::genCode() {
} }
} }
for (int i = 0; i < groups[g]; ++i) { for (int i = 0; i < groups[g]; ++i) {
vmovups(zmm_t(w_reg_idx), ptr[reg_ptr_wgt + wgt_offset]); vmovups(zmm_t(w_reg_idx),
ptr[reg_ptr_wgt + wgt_offset + k * n_ * sizeof(float)]);
vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx)); vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
wgt_offset += block_len; wgt_offset += block_len;
} }
......
...@@ -78,7 +78,7 @@ class TestFusionRepeatedFCReluOp(OpTest): ...@@ -78,7 +78,7 @@ class TestFusionRepeatedFCReluOp(OpTest):
class TestFusionRepeatedFCReluOpBS1(TestFusionRepeatedFCReluOp): class TestFusionRepeatedFCReluOpBS1(TestFusionRepeatedFCReluOp):
def set_conf(self): def set_conf(self):
self.bs = 1 self.bs = 1
self.oc = [4, 2, 7, 5] self.oc = [4, 2, 7, 5, 512, 1024]
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册