/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {

using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;

typedef enum {
  mul = 0,
  add,
  sub,
  relu,
  exp,
  sigmoid,
  tanh,
  identity
} operand_type;

// Generates a function computing vec = Op(vec or scalar, vec or scalar),
// optionally fused with relu.
class VXXJitCode : public JitCode {
 public:
  const char* name() const override {
    std::string base = "VXXJitCode";
    if (scalar_index_ == 1) {
      base += "_Scalar";
    } else {
      base += "_Vec";
    }
    if (type_ == operand_type::mul) {
      base += "_Mul";
    } else if (type_ == operand_type::add) {
      base += "_Add";
    }
    if (scalar_index_ == 2) {
      base += "_Scalar";
    } else {
      base += "_Vec";
    }
    base += (with_relu_ ? "_Relu" : "");
    // Cache the composed name; returning base.c_str() directly would hand
    // out a pointer into a destroyed temporary.
    name_ = base;
    return name_.c_str();
  }
  explicit VXXJitCode(int d, operand_type type, int scalar_index,
                      bool with_relu, size_t code_size = 256 * 1024,
                      void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr),
        num_(d),
        type_(type),
        scalar_index_(scalar_index),
        with_relu_(with_relu) {}
  static bool init(int d, int scalar_index = 0);
  void generate() override;

 private:
  int num_;
  operand_type type_;
  int scalar_index_;
  bool with_relu_;
  mutable std::string name_;  // backing storage for name()
  reg64_t param1{abi_param1};
  reg64_t param2{abi_param2};
  reg64_t param3{abi_param3};

  xmm_t xmm_src1 = xmm_t(0);
  xmm_t xmm_src2 = xmm_t(1);
  xmm_t xmm_dst = xmm_t(2);
  xmm_t xmm_zero = xmm_t(3);

  ymm_t ymm_src1 = ymm_t(0);
  ymm_t ymm_src2 = ymm_t(1);
  ymm_t ymm_dst = ymm_t(2);
  ymm_t ymm_zero = ymm_t(3);
};
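
// Usage sketch (illustration only, not part of the original header). It
// assumes JitCode ultimately derives from Xbyak::CodeGenerator, whose
// getCode<F>() reinterprets the emitted buffer as a function pointer, and
// that the generated kernel reads its operands from abi_param1..3 as
// (x, y, out):
//
//   const int d = 8;
//   if (VXXJitCode::init(d)) {
//     VXXJitCode jit(d, operand_type::mul, /*scalar_index=*/0,
//                    /*with_relu=*/false);
//     jit.generate();
//     auto vmul = jit.getCode<void (*)(const float*, const float*, float*)>();
//     float x[8], y[8], z[8];
//     // ... fill x and y ...
//     vmul(x, y, z);  // z[i] = x[i] * y[i] for i in [0, d)
//   }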

class VActJitCode : public JitCode {
 public:
  const char* name() const override {
    std::string base = "VActJitCode";
    switch (type_) {
      case operand_type::relu:
        base += "_Relu";
        break;
      case operand_type::exp:
        base += "_Exp";
        break;
      case operand_type::sigmoid:
        base += "_Sigmoid";
        break;
      case operand_type::tanh:
        base += "_Tanh";
        break;
      case operand_type::identity:
        base += "_Identity";
        break;
      default:
        break;
    }
    name_ = base;  // cache so the returned pointer does not dangle
    return name_.c_str();
  }

  explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
                       void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr), num_(d), type_(type) {}
  static bool init(int d, operand_type type);
  void generate() override;

 protected:
  // relu on one ymm: dst = max(src, 0); `zero` is expected to hold 0.0f
  // in every lane
  void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
                const Xbyak::Ymm& zero);

  // exp on one ymm: dst = exp(src); fx/fy/mask/tmp select the scratch ymm
  // registers used during the computation
  void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
               int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);

  // sigmoid on one ymm: dst = 1 / (1 + exp(-src))
  // (same scratch registers as exp_ymm)
  void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
                   int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);

  // tanh on one ymm: dst = tanh(src); note tanh(x) = 2 * sigmoid(2x) - 1
  // (same scratch registers as exp_ymm)
  void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
                int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);

 protected:
  int num_;
  operand_type type_;
  mutable std::string name_;  // backing storage for name()
  reg64_t param1{abi_param1};
  reg64_t param2{abi_param2};

  xmm_t xmm_src = xmm_t(0);
  ymm_t ymm_src = ymm_t(0);

  xmm_t xmm_dst = xmm_t(1);
  ymm_t ymm_dst = ymm_t(1);
};
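
// Usage sketch for the activation code generator (illustration only; same
// getCode<F>() assumption as above; the kernel reads d floats through
// abi_param1 and writes d floats through abi_param2):
//
//   const int d = 8;
//   if (VActJitCode::init(d, operand_type::exp)) {
//     VActJitCode jit(d, operand_type::exp);
//     jit.generate();
//     auto vexp = jit.getCode<void (*)(const float*, float*)>();
//     float x[8], y[8];
//     // ... fill x ...
//     vexp(x, y);  // y[i] = exp(x[i]) for i in [0, d)
//   }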

#ifdef PADDLE_WITH_MKLDNN
struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator {
  explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
      : Xbyak::CodeGenerator(code_size) {
    // Arguments arrive in the System V x86-64 registers:
    //   RDI is ptr x_input
    //   RSI is ptr y_input
    //   RDX is ptr output
    //   RCX is height
    //   r8  is width
    // y_input is a single block of 16 floats that gets multiplied into
    // every 16-float block of x_input.

    push(rbx);  // rbx (callee-saved) is used as the width counter

    xor_(rax, rax);           // rax: running byte offset into x / output
    xor_(r10, r10);           // r10: height counter
    vmovups(zmm3, ptr[rsi]);  // load the 16-float y block once

    L("h_loop");
    xor_(rbx, rbx);  // reset the width counter for each row
    L("w_loop");
    vmovups(zmm2, ptr[rdi + rax]);  // load 16 floats of x
    vmulps(zmm1, zmm2, zmm3);       // multiply by the y block
    vmovups(ptr[rdx + rax], zmm1);  // store 16 floats of output
    add(rax, 64);                   // advance one block (16 * 4 bytes)
    inc(rbx);
    cmp(r8, rbx);
    jnz("w_loop");
    inc(r10);
    cmp(r10, rcx);
    jnz("h_loop");

    pop(rbx);
    ret();
  }
};
#endif
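
// Usage sketch (illustration only). The constructor above emits the code
// directly, so no generate() call is needed; Xbyak::CodeGenerator's
// getCode<F>() yields the entry point. Per the register comments, a
// plausible signature is
//   void (*)(const float* x, const float* y, float* z, size_t h, size_t w);
// size_t is chosen here because the loop compares the full 64-bit rcx/r8
// registers. Requires AVX-512, since the loop uses zmm registers.
// Given hypothetical buffers x, y, z and dimensions h, w:
//
//   EltwiseMulnChw16cNC jit;  // only available with PADDLE_WITH_MKLDNN
//   auto fn = jit.getCode<
//       void (*)(const float*, const float*, float*, size_t, size_t)>();
//   fn(x, y, z, h, w);  // each 16-float block of z = block of x * y block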

}  // namespace gen
}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle