matmul.cc 4.9 KB
Newer Older
1
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

15
#include "paddle/phi/kernels/funcs/jit/gen/matmul.h"
W
wanghuancoder 已提交
16

17
#include <cstddef>  // offsetof
W
wanghuancoder 已提交
18

19
#include "paddle/phi/backends/cpu/cpu_info.h"
20
#include "paddle/phi/kernels/funcs/jit/registry.h"
21

22
namespace phi {
23 24 25 26 27 28 29
namespace jit {
namespace gen {

void MatMulJitCode::genCode() {
  preCode();
  int block, rest;
  const auto groups = packed_groups(n_, k_, &block, &rest);
G
GaoWei8 已提交
30
  PADDLE_ENFORCE_GT(
31 32
      groups.front(),
      0,
33 34 35
      phi::errors::InvalidArgument("The number of rest registers should "
                                   "be larger than 0. But it is %d.",
                                   groups.front()));
36 37 38 39 40 41 42 43 44 45 46

  const int block_len = sizeof(float) * block;
  const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
  const int w_reg_idx = x_reg_idx - 1;
  // from packed mov(reg_ptr_wgt, ptr[param_attr + offsetof(matmul_attr_t,
  // packed_weight)]);
  mov(reg_ptr_wgt, param_y);
  size_t z_offset = 0;
  size_t wgt_offset = 0;
  for (size_t g = 0; g < groups.size(); ++g) {
    size_t x_offset = 0;
W
Wilber 已提交
47
    size_t wgt_offset_tmp = 0;
48
    for (size_t i = 0; i < g; ++i) {
W
Wilber 已提交
49 50
      wgt_offset_tmp += groups[i] * block_len;
    }
51
    for (int k = 0; k < k_; ++k) {
W
Wilber 已提交
52
      wgt_offset = wgt_offset_tmp;
53 54 55 56 57 58 59 60
      vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
      // clean
      if (k == 0) {
        for (int i = 0; i < groups[g]; ++i) {
          vxorps(zmm_t(i), zmm_t(i), zmm_t(i));
        }
      }
      for (int i = 0; i < groups[g]; ++i) {
W
Wilber 已提交
61 62
        vmovups(zmm_t(w_reg_idx),
                ptr[reg_ptr_wgt + wgt_offset + k * n_ * sizeof(float)]);
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
        vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
        wgt_offset += block_len;
      }
      // last one, save
      if (k == k_ - 1) {
        for (int i = 0; i < groups[g]; ++i) {
          // only rest save should be careful
          if (rest != 0 && g == groups.size() - 1 && i == groups[g] - 1) {
            break;
          }
          vmovups(ptr[param_z + z_offset + i * block_len], zmm_t(i));
        }
      }
      x_offset += sizeof(float);
    }
    z_offset += block_len * groups[g];
  }

  if (rest != 0) {
    // below should refine with mask
    int reg_idx = groups.back() - 1;
    z_offset = (n_ - rest) * sizeof(float);
    int inner_block = 8;
    while (rest > 0) {
      if (rest >= 8) {
        inner_block = 8;
        vmovups(ptr[param_z + z_offset], ymm_t(reg_idx));
        // shift zmm of inner_block, change reg_idx if update
      } else if (rest >= 4) {
        inner_block = 4;
        vmovups(ptr[param_z + z_offset], xmm_t(reg_idx));
      } else if (rest >= 2) {
        inner_block = 2;
        vmovq(ptr[param_z + z_offset], xmm_t(reg_idx));
      } else {
        inner_block = 1;
        vmovss(ptr[param_z + z_offset], xmm_t(reg_idx));
      }
      z_offset += inner_block * sizeof(float);
      rest -= inner_block;
    }
  }

  postCode();
}

// Factory that decides whether the MatMul JIT kernel applies to a given
// attribute set and builds a MatMulJitCode for it.
class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
 public:
  // The generated code handles only a single row of x (m == 1). It also
  // requires AVX-512F, n a multiple of ZMM_FLOAT_BLOCK, and k below 512.
  bool CanBeUsed(const matmul_attr_t& attr) const override {
    return attr.m == 1 &&
           phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f) &&
           attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
  }
  // Size reserved for the generated code (passed to the MatMulJitCode
  // constructor): a fixed 96 plus 4 * 8 per (k row, vector block) pair, with
  // one extra block per row. The product is formed in size_t so that large
  // n * k values cannot overflow intermediate int arithmetic.
  size_t CodeSize(const matmul_attr_t& attr) const override {
    int block = YMM_FLOAT_BLOCK;
    if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) {
      block = ZMM_FLOAT_BLOCK;
    }
    return 96 + 4 * static_cast<size_t>(attr.k) *
                    static_cast<size_t>(attr.n / block + 1) * 8;
  }
  // Validates that all matrix dimensions are positive, then builds the kernel.
  // Shapes: x is m x k, y is k x n, z is m x n (matching the addressing in
  // MatMulJitCode::genCode).
  std::unique_ptr<GenBase> CreateJitCode(
      const matmul_attr_t& attr) const override {
    PADDLE_ENFORCE_GT(
        attr.m,
        0,
        phi::errors::InvalidArgument(
            "The attribute m (first matrix's row) of MatMul should "
            "be larger than 0. But it is %d.",
            attr.m));
    PADDLE_ENFORCE_GT(
        attr.n,
        0,
        phi::errors::InvalidArgument(
            "The attribute n (second matrix's col) of MatMul should "
            "be larger than 0. But it is %d.",
            attr.n));
    PADDLE_ENFORCE_GT(
        attr.k,
        0,
        phi::errors::InvalidArgument(
            "The attribute k (first matrix's col) of MatMul should "
            "be larger than 0. But it is %d.",
            attr.k));
    return make_unique<MatMulJitCode>(attr, CodeSize(attr));
  }
};

}  // namespace gen
}  // namespace jit
152
}  // namespace phi
153

154
namespace gen = phi::jit::gen;
155 156

REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator);