vbroadcast.cc 2.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jit/gen/vbroadcast.h"
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

void VBroadcastJitCode::genCode() {
  preCode();
  constexpr int block = YMM_FLOAT_BLOCK;
  constexpr int max_num_regs = 16;
  const int num_block = w_ / block;
  const int num_groups = num_block / max_num_regs;
  const size_t block_size = sizeof(float) * block;
  std::vector<int> groups(num_groups, max_num_regs);
  int rest_num_regs = num_block % max_num_regs;
  if (rest_num_regs > 0) {
    groups.push_back(rest_num_regs);
  }

  // protect param_h
  const size_t width_in_byte = sizeof(float) * w_;
  mov(reg_height, param_h);
  int acc_num_regs = 0;
  for (int num_regs : groups) {
    mov(reg_ptr_src_i, param_src);
    add(reg_ptr_src_i, acc_num_regs * block_size);
    size_t w_offset = 0;
    for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
      vmovups(ymm_t(reg_i), ptr[reg_ptr_src_i + w_offset]);
      w_offset += block_size;
    }

    Label l_next_h;
    xor_(reg_h_i, reg_h_i);
    mov(reg_ptr_dst_i, param_dst);
    add(reg_ptr_dst_i, acc_num_regs * block_size);
    L(l_next_h);
    {
      w_offset = 0;
      for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
        vmovups(ptr[reg_ptr_dst_i + w_offset], ymm_t(reg_i));
        w_offset += block_size;
      }
      add(reg_ptr_dst_i, width_in_byte);
      inc(reg_h_i);
      cmp(reg_h_i, reg_height);
      jl(l_next_h, T_NEAR);
    }  // end of l_next_h
    acc_num_regs += num_regs;
  }  // end of groups
  postCode();
}

// Builds VBroadcastJitCode kernels for a given row width w (counted in floats).
class VBroadcastCreator : public JitCodeCreator<int64_t> {
 public:
  // The generated kernel uses AVX ymm registers and copies whole
  // YMM_FLOAT_BLOCK-wide blocks, so w must be a positive multiple of
  // YMM_FLOAT_BLOCK. Checking w > 0 here keeps UseMe consistent with the
  // PADDLE_ENFORCE_GT(w, 0) precondition in CreateJitCode. Without it,
  // w == 0 (and negative multiples) would pass the modulo test below.
  bool UseMe(const int64_t& w) const override {
    return platform::MayIUse(platform::avx) && w > 0 &&
           w % YMM_FLOAT_BLOCK == 0;
  }
  // Size in bytes of the code buffer: a fixed 96 bytes plus 16 * 8 bytes for
  // each YMM_FLOAT_BLOCK-wide block of the row.
  size_t CodeSize(const int64_t& w) const override {
    return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8;
  }
  // Generates the kernel for width w. Requires w > 0.
  std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override {
    PADDLE_ENFORCE_GT(w, 0);
    return make_unique<VBroadcastJitCode>(w, CodeSize(w));
  }
};

}  // namespace gen
}  // namespace jit
}  // namespace operators
}  // namespace paddle

// Short alias so the creator type in the registration below stays readable.
namespace gen = paddle::operators::jit::gen;

// Registers VBroadcastCreator as the JIT code generator for the kVBroadcast
// kernel type. This line sits outside all namespaces, at global scope.
REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator);