// vbroadcast.cc
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jit/gen/vbroadcast.h"
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

void VBroadcastJitCode::genCode() {
  preCode();
  constexpr int block = YMM_FLOAT_BLOCK;
  constexpr int max_num_regs = 16;
  const int num_block = w_ / block;
  const int num_groups = num_block / max_num_regs;
  const size_t block_size = sizeof(float) * block;
  std::vector<int> groups(num_groups, max_num_regs);
  int rest_num_regs = num_block % max_num_regs;
  if (rest_num_regs > 0) {
    groups.push_back(rest_num_regs);
  }

  // protect param_h
  mov(reg_height, param_h);
T
tensor-tang 已提交
41 42 43 44 45
  Label l_next_h;
  xor_(reg_h_i, reg_h_i);
  mov(reg_ptr_dst_i, param_dst);
  L(l_next_h);
  {
46
    mov(reg_ptr_src_i, param_src);
T
tensor-tang 已提交
47 48 49 50 51 52 53
    for (int num_regs : groups) {
      size_t w_offset = 0;
      for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
        vmovups(ymm_t(reg_i), ptr[reg_ptr_src_i + w_offset]);
        w_offset += block_size;
      }
      add(reg_ptr_src_i, num_regs * block_size);
54 55 56 57 58 59

      w_offset = 0;
      for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
        vmovups(ptr[reg_ptr_dst_i + w_offset], ymm_t(reg_i));
        w_offset += block_size;
      }
T
tensor-tang 已提交
60 61 62 63 64 65 66
      add(reg_ptr_dst_i, num_regs * block_size);
    }  // end of groups
    inc(reg_h_i);
    cmp(reg_h_i, reg_height);
    jl(l_next_h, T_NEAR);
  }  // end of l_next_h

67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
  postCode();
}

// Creator for the AVX VBroadcast JIT kernel, keyed by the row width w.
class VBroadcastCreator : public JitCodeCreator<int64_t> {
 public:
  // The generated code copies whole YMM blocks only, so this kernel is used
  // only on AVX-capable CPUs and only when w is an exact multiple of
  // YMM_FLOAT_BLOCK.
  bool UseMe(const int64_t& w) const override {
    if (!platform::MayIUse(platform::avx)) {
      return false;
    }
    return w % YMM_FLOAT_BLOCK == 0;
  }

  // Code-buffer size (bytes) reserved for the kernel: a fixed 96-byte
  // allowance plus 16 * 8 bytes for every YMM block in the row.
  size_t CodeSize(const int64_t& w) const override {
    const int64_t num_ymm_blocks = w / YMM_FLOAT_BLOCK;
    return 96 + num_ymm_blocks * 16 * 8;
  }

  // Builds the kernel for width w; w must be positive.
  std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override {
    PADDLE_ENFORCE_GT(w, 0);
    const size_t code_size = CodeSize(w);
    return make_unique<VBroadcastJitCode>(w, code_size);
  }
};

}  // namespace gen
}  // namespace jit
}  // namespace operators
}  // namespace paddle

// Alias used to shorten the creator's fully qualified name in the
// registration below.
namespace gen = paddle::operators::jit::gen;

// Register VBroadcastCreator as the JIT code generator for the kVBroadcast
// kernel type.
REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator);