From ad86739879a9d58c5532e91a76afeef6c35a2d14 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Thu, 31 Oct 2019 23:45:20 +0800
Subject: [PATCH] fix repeated_fc_fuse_pass and jit::matmul bug test=develop test=release/1.6 (#20948)

- fix jit::matmul bug input x, shape(m, k), weight, shape(k, n)
---
 paddle/fluid/operators/jit/gen/matmul.cc               | 8 +++++++-
 .../tests/unittests/test_fusion_repeated_fc_relu_op.py | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/jit/gen/matmul.cc b/paddle/fluid/operators/jit/gen/matmul.cc
index d9955c8cc6..9e9ee8df55 100644
--- a/paddle/fluid/operators/jit/gen/matmul.cc
+++ b/paddle/fluid/operators/jit/gen/matmul.cc
@@ -40,7 +40,12 @@ void MatMulJitCode::genCode() {
   size_t wgt_offset = 0;
   for (size_t g = 0; g < groups.size(); ++g) {
     size_t x_offset = 0;
+    size_t wgt_offset_tmp = 0;
+    for (int i = 0; i < g; ++i) {
+      wgt_offset_tmp += groups[i] * block_len;
+    }
     for (int k = 0; k < k_; ++k) {
+      wgt_offset = wgt_offset_tmp;
       vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
       // clean
       if (k == 0) {
@@ -49,7 +54,8 @@ void MatMulJitCode::genCode() {
         }
       }
       for (int i = 0; i < groups[g]; ++i) {
-        vmovups(zmm_t(w_reg_idx), ptr[reg_ptr_wgt + wgt_offset]);
+        vmovups(zmm_t(w_reg_idx),
+                ptr[reg_ptr_wgt + wgt_offset + k * n_ * sizeof(float)]);
         vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
         wgt_offset += block_len;
       }
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py b/python/paddle/fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py
index d21368fbf8..aa24408034 100644
--- a/python/paddle/fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py
@@ -78,7 +78,7 @@ class TestFusionRepeatedFCReluOp(OpTest):
 class TestFusionRepeatedFCReluOpBS1(TestFusionRepeatedFCReluOp):
     def set_conf(self):
         self.bs = 1
-        self.oc = [4, 2, 7, 5]
+        self.oc = [4, 2, 7, 5, 512, 1024]
 
 
 if __name__ == '__main__':
-- 
GitLab