/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jit/gen/matmul.h"

#include <stddef.h>  // offsetof
#include <memory>

#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

void MatMulJitCode::genCode() {
  preCode();
  int block, rest;
  const auto groups = packed_groups(n_, k_, &block, &rest);
G
GaoWei8 已提交
32 33 34 35 36
  PADDLE_ENFORCE_GT(
      groups.front(), 0,
      platform::errors::InvalidArgument("The number of rest registers should "
                                        "be larger than 0. But it is %d.",
                                        groups.front()));
37 38 39 40 41 42 43 44 45 46 47

  const int block_len = sizeof(float) * block;
  const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
  const int w_reg_idx = x_reg_idx - 1;
  // from packed mov(reg_ptr_wgt, ptr[param_attr + offsetof(matmul_attr_t,
  // packed_weight)]);
  mov(reg_ptr_wgt, param_y);
  size_t z_offset = 0;
  size_t wgt_offset = 0;
  for (size_t g = 0; g < groups.size(); ++g) {
    size_t x_offset = 0;
W
Wilber 已提交
48
    size_t wgt_offset_tmp = 0;
49
    for (size_t i = 0; i < g; ++i) {
W
Wilber 已提交
50 51
      wgt_offset_tmp += groups[i] * block_len;
    }
52
    for (int k = 0; k < k_; ++k) {
W
Wilber 已提交
53
      wgt_offset = wgt_offset_tmp;
54 55 56 57 58 59 60 61
      vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
      // clean
      if (k == 0) {
        for (int i = 0; i < groups[g]; ++i) {
          vxorps(zmm_t(i), zmm_t(i), zmm_t(i));
        }
      }
      for (int i = 0; i < groups[g]; ++i) {
W
Wilber 已提交
62 63
        vmovups(zmm_t(w_reg_idx),
                ptr[reg_ptr_wgt + wgt_offset + k * n_ * sizeof(float)]);
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
        vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
        wgt_offset += block_len;
      }
      // last one, save
      if (k == k_ - 1) {
        for (int i = 0; i < groups[g]; ++i) {
          // only rest save should be careful
          if (rest != 0 && g == groups.size() - 1 && i == groups[g] - 1) {
            break;
          }
          vmovups(ptr[param_z + z_offset + i * block_len], zmm_t(i));
        }
      }
      x_offset += sizeof(float);
    }
    z_offset += block_len * groups[g];
  }

  if (rest != 0) {
    // below should refine with mask
    int reg_idx = groups.back() - 1;
    z_offset = (n_ - rest) * sizeof(float);
    int inner_block = 8;
    while (rest > 0) {
      if (rest >= 8) {
        inner_block = 8;
        vmovups(ptr[param_z + z_offset], ymm_t(reg_idx));
        // shift zmm of inner_block, change reg_idx if update
      } else if (rest >= 4) {
        inner_block = 4;
        vmovups(ptr[param_z + z_offset], xmm_t(reg_idx));
      } else if (rest >= 2) {
        inner_block = 2;
        vmovq(ptr[param_z + z_offset], xmm_t(reg_idx));
      } else {
        inner_block = 1;
        vmovss(ptr[param_z + z_offset], xmm_t(reg_idx));
      }
      z_offset += inner_block * sizeof(float);
      rest -= inner_block;
    }
  }

  postCode();
}

// Decides when the generated MatMul kernel applies, sizes its code buffer and
// builds the MatMulJitCode instance.
class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
 public:
  // True only when m == 1, the CPU reports AVX-512F, n is a multiple of
  // ZMM_FLOAT_BLOCK and k is below 512.
  bool CanBeUsed(const matmul_attr_t& attr) const override {
    return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
           attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
  }

  // Code buffer size in bytes: a fixed 96 plus 4 * 8 for every pair of k step
  // and n block (with one extra block per row). The product is formed in
  // size_t so that large n cannot overflow int before the conversion.
  size_t CodeSize(const matmul_attr_t& attr) const override {
    int block = YMM_FLOAT_BLOCK;
    if (platform::MayIUse(platform::avx512f)) {
      block = ZMM_FLOAT_BLOCK;
    }
    const size_t blocks_per_row = static_cast<size_t>(attr.n / block + 1);
    return 96 + 4 * static_cast<size_t>(attr.k) * blocks_per_row * 8;
  }

  // Validates that m, n and k are positive, then creates the generator with a
  // buffer of CodeSize(attr) bytes.
  std::unique_ptr<GenBase> CreateJitCode(
      const matmul_attr_t& attr) const override {
    PADDLE_ENFORCE_GT(
        attr.m, 0, platform::errors::InvalidArgument(
                       "The attribute m (first matrix's row) of MatMul should "
                       "be larger than 0. But it is %d.",
                       attr.m));
    // n is the row stride of the weights in genCode(), i.e. the second
    // matrix's column count; k is the length of the x row.
    PADDLE_ENFORCE_GT(
        attr.n, 0, platform::errors::InvalidArgument(
                       "The attribute n (second matrix's col) of MatMul should "
                       "be larger than 0. But it is %d.",
                       attr.n));
    PADDLE_ENFORCE_GT(
        attr.k, 0, platform::errors::InvalidArgument(
                       "The attribute k (first matrix's col) of MatMul should "
                       "be larger than 0. But it is %d.",
                       attr.k));
    return make_unique<MatMulJitCode>(attr, CodeSize(attr));
  }
};

}  // namespace gen
}  // namespace jit
}  // namespace operators
}  // namespace paddle

// Short alias so the registration below can name the creator concisely.
namespace gen = paddle::operators::jit::gen;

// Associates gen::MatMulCreator with the kMatMul kernel key. The macro is
// defined outside this file (presumably in jit/registry.h — confirm).
REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator);