/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath>  // for exp
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"

#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif

#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif

#ifdef __AVX__
#include <immintrin.h>
#endif

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace jit = platform::jit;

// TODO(TJ): move the reference implementations into one file
template <typename T>
void VExpRefer(const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = std::exp(x[i]);
  }
}

#ifdef PADDLE_WITH_MKLML
template <typename T>
void VExpMKL(const T* x, T* y, int n);

template <>
void VExpMKL<float>(const float* x, float* y, int n) {
  platform::dynload::vsExp(n, x, y);
}

template <>
void VExpMKL<double>(const double* x, double* y, int n) {
  platform::dynload::vdExp(n, x, y);
}
#endif

/* VExp JitKernel */
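// Selection order in the constructor below: prefer a runtime-generated
// xbyak kernel when the build has it and the size qualifies, fall back to
// MKL's vector exp next, and use the portable VExpRefer loop otherwise.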
template <typename T>
class VExpKernelImpl : public VExpKernel<T> {
 public:
  JITKERNEL_DECLARE_STATIC_FUNC;
  explicit VExpKernelImpl(int d) : VExpKernel<T>() {
    this->num_ = d;  // TODO(TJ): remove once ComputeDeprecated is gone
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(d)) {
      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;  // rough code-size estimate; TODO(TJ): compute it exactly
      jitcode_.reset(new gen::VExpJitCode(d, sz > 4096 ? sz : 4096));
      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
      return;
    }
#endif
#ifdef PADDLE_WITH_MKLML
    if (useMKL(d)) {
      this->Compute = VExpMKL<T>;
      return;
    }
#endif
    this->Compute = VExpRefer<T>;
  }
  void ComputeDeprecated(const T* x, T* y) const override {
    VExpRefer(x, y, this->num_);
  }
#ifdef PADDLE_WITH_XBYAK

 private:
  std::unique_ptr<gen::VExpJitCode> jitcode_{nullptr};
#endif
};

#ifdef PADDLE_WITH_XBYAK
template <>
bool VExpKernelImpl<float>::useJIT(int d) {
  return gen::VExpJitCode::init(d);
}
#endif

#ifdef PADDLE_WITH_MKLML
template <>
bool VExpKernelImpl<float>::useMKL(int d) {
  return d > 512;
}

template <>
bool VExpKernelImpl<double>::useMKL(int d) {
  return true;
}
#endif

REGISTER_JITKERNEL(vexp, VExpKernel);
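
// A usage sketch, assuming the KernelPool interface declared in
// jit_kernel.h (the buffers x, y and the length n are hypothetical):
//   auto vexp = KernelPool::Instance().Get<VExpKernel<float>>(n);
//   vexp->Compute(x, y, n);  // y[i] = exp(x[i]) for i in [0, n)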

namespace detail {

#define ALIGN32 __attribute__((aligned(32)))

#define _PS256_CONST(Name, Val)                                      \
  static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
                                                 Val, Val, Val, Val}

#define _PI256_CONST(Name, Val)                                    \
  static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
                                               Val, Val, Val, Val}

_PI256_CONST(0x7f, 0x7f);
_PS256_CONST(one, 1.f);
_PS256_CONST(0p5, 0.5f);
_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
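
// These follow the classic Cephes expf: cephes_LOG2EF is log2(e);
// cephes_exp_C1 + cephes_exp_C2 split ln(2) into a high and a low part so
// the range reduction below loses less precision; exp_hi/exp_lo clamp the
// input to roughly the range where expf neither overflows nor flushes to
// zero; and cephes_exp_p0..p5 are the coefficients of the degree-5
// polynomial that approximates exp on the reduced interval.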

typedef union imm_xmm_union {
  __m256i imm;
  __m128i xmm[2];
} imm_xmm_union;
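
// Plain AVX has no 256-bit integer arithmetic, so this union and the two
// COPY_* macros below split a __m256i into two __m128i halves and let SSE2
// instructions stand in for the missing AVX2 ones.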

#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
  {                                         \
    imm_xmm_union u ALIGN32;                \
    u.imm = imm_;                           \
    xmm0_ = u.xmm[0];                       \
    xmm1_ = u.xmm[1];                       \
  }

#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
  {                                         \
    imm_xmm_union u ALIGN32;                \
    u.xmm[0] = xmm0_;                       \
    u.xmm[1] = xmm1_;                       \
    imm_ = u.imm;                           \
  }

#define AVX2_BITOP_USING_SSE2(fn)                           \
  static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
    /* use SSE2 to emulate the AVX2 bitop */                \
    __m128i x1, x2;                                         \
    __m256i ret;                                            \
    COPY_IMM_TO_XMM(x, x1, x2);                             \
    x1 = _mm_##fn(x1, y);                                   \
    x2 = _mm_##fn(x2, y);                                   \
    COPY_XMM_TO_IMM(x1, x2, ret);                           \
    return ret;                                             \
  }

#define AVX2_INTOP_USING_SSE2(fn)                                    \
  static inline __m256i avx2_mm256_##fn(__m256i x, __m256i y) {      \
    /* use SSE2 to perform the AVX2 integer operation */             \
    __m128i x1, x2;                                                  \
    __m128i y1, y2;                                                  \
    __m256i ret;                                                     \
    COPY_IMM_TO_XMM(x, x1, x2);                                      \
    COPY_IMM_TO_XMM(y, y1, y2);                                      \
    x1 = _mm_##fn(x1, y1);                                           \
    x2 = _mm_##fn(x2, y2);                                           \
    COPY_XMM_TO_IMM(x1, x2, ret);                                    \
    return ret;                                                      \
  }

AVX2_BITOP_USING_SSE2(slli_epi32);
AVX2_INTOP_USING_SSE2(add_epi32);

#define AVXEXP_BASE                                                            \
  __m256 tmp = _mm256_setzero_ps(), fx;                                        \
  __m256 one = *reinterpret_cast<const __m256*>(_ps256_one);                   \
  __m256i imm0;                                                                \
  x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));       \
  x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));       \
  /* express exp(x) as exp(g + n*log(2)) */                                    \
  fx = _mm256_mul_ps(x,                                                        \
                     *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));  \
  fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));        \
  tmp = _mm256_floor_ps(fx);                                                   \
  /* if greater, subtract 1 */                                                 \
  __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);                            \
  mask = _mm256_and_ps(mask, one);                                             \
  fx = _mm256_sub_ps(tmp, mask);                                               \
  tmp = _mm256_mul_ps(fx,                                                      \
                      *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
  __m256 z = _mm256_mul_ps(                                                    \
      fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));             \
  x = _mm256_sub_ps(x, tmp);                                                   \
  x = _mm256_sub_ps(x, z);                                                     \
  z = _mm256_mul_ps(x, x);                                                     \
  __m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);           \
  y = _mm256_mul_ps(y, x);                                                     \
  y = _mm256_add_ps(y,                                                         \
                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));   \
  y = _mm256_mul_ps(y, x);                                                     \
  y = _mm256_add_ps(y,                                                         \
                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));   \
  y = _mm256_mul_ps(y, x);                                                     \
  y = _mm256_add_ps(y,                                                         \
                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));   \
  y = _mm256_mul_ps(y, x);                                                     \
  y = _mm256_add_ps(y,                                                         \
                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));   \
  y = _mm256_mul_ps(y, x);                                                     \
  y = _mm256_add_ps(y,                                                         \
                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));   \
  y = _mm256_mul_ps(y, z);                                                     \
  y = _mm256_add_ps(y, x);                                                     \
  y = _mm256_add_ps(y, one);                                                   \
  /* build 2^n */                                                              \
  imm0 = _mm256_cvttps_epi32(fx)
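
// AVXEXP_BASE is the range reduction plus polynomial evaluation shared by
// ExpAVX and ExpAVX2. A scalar sketch of what the intrinsics compute:
//   n = floor(x * log2(e) + 0.5)            // integer exponent
//   g = x - n * exp_C1 - n * exp_C2         // g = x - n * ln(2)
//   exp(g) ~= 1 + g + g^2 * (p5 + g*(p4 + g*(p3 + g*(p2 + g*(p1 + g*p0)))))
//   exp(x) = exp(g) * 2^n
// The macro leaves n in imm0; each caller then adds the IEEE-754 exponent
// bias (0x7f) and shifts it to bit 23 to build 2^n in float form.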

__m256 ExpAVX(__m256 x) {
  AVXEXP_BASE;
  // two AVX2 instructions, emulated using SSE2
  imm0 = avx2_mm256_add_epi32(imm0,
                              *reinterpret_cast<const __m256i*>(_pi256_0x7f));
  imm0 = avx2_mm256_slli_epi32(imm0, 23);
  __m256 pow2n = _mm256_castsi256_ps(imm0);
  y = _mm256_mul_ps(y, pow2n);
  return y;
}

#ifdef __AVX2__
__m256 ExpAVX2(__m256 x) {
  AVXEXP_BASE;
  // two AVX2 instructions
  imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
  imm0 = _mm256_slli_epi32(imm0, 23);
  __m256 pow2n = _mm256_castsi256_ps(imm0);
  y = _mm256_mul_ps(y, pow2n);
  return y;
}
#endif

}  // namespace detail

/* VSigmoid JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
 public:
  explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
    this->num_ = d;
    vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d);
  }
  void ComputeDeprecated(const T* x, T* y) const override {
    const T min = SIGMOID_THRESHOLD_MIN;
    const T max = SIGMOID_THRESHOLD_MAX;
    for (int i = 0; i < this->num_; ++i) {
      y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
      y[i] = static_cast<T>(0) - y[i];
    }
    vexp_->ComputeDeprecated(y, y);
    for (int i = 0; i < this->num_; ++i) {
      y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
    }
  }

 private:
  std::shared_ptr<const VExpKernel<T>> vexp_;
};
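
// ComputeDeprecated above and the INTRI_* macros below all evaluate
//   y = 1 / (1 + exp(-clip(x, SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX)))
// clipping first so that the exp call stays within a safe input range.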

#define INTRI_SIGMOID(tmp, min, max, expisa)      \
  tmp = _mm256_max_ps(tmp, min);                  \
  tmp = _mm256_min_ps(tmp, max);                  \
  tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
  tmp = expisa(tmp);                              \
  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)

#define INTRI8_FLOAT(isa, expisa)                               \
  template <>                                                   \
  void VSigmoidKernelImpl<float, isa, kEQ8>::ComputeDeprecated( \
      const float* x, float* y) const {                         \
    /* TODO(TJ): try to use static const */                     \
    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);         \
    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);         \
    __m256 tmp = _mm256_loadu_ps(x);                            \
    INTRI_SIGMOID(tmp, min, max, expisa);                       \
    _mm256_storeu_ps(y, tmp);                                   \
  }

#define INTRI16_FLOAT(isa, expisa)                               \
  template <>                                                    \
  void VSigmoidKernelImpl<float, isa, kEQ16>::ComputeDeprecated( \
      const float* x, float* y) const {                          \
    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);          \
    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);          \
    __m256 tmp0 = _mm256_loadu_ps(x);                            \
    __m256 tmp1 = _mm256_loadu_ps(x + 8);                        \
    INTRI_SIGMOID(tmp0, min, max, expisa);                       \
    INTRI_SIGMOID(tmp1, min, max, expisa);                       \
    _mm256_storeu_ps(y, tmp0);                                   \
    _mm256_storeu_ps(y + 8, tmp1);                               \
  }

#define INTRI_GT8LT16_FLOAT(isa, expisa)                                     \
  template <>                                                                \
  VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d)        \
      : VSigmoidKernel<float>() {                                            \
    this->num_ = d;                                                          \
    this->end_ = AVX_FLOAT_BLOCK;                                            \
    this->rest_ = d - this->end_;                                            \
    vexp_ =                                                                  \
        KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
  }                                                                          \
  template <>                                                                \
  void VSigmoidKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated(          \
      const float* x, float* y) const {                                      \
    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);                      \
    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);                      \
    __m256 tmp = _mm256_loadu_ps(x);                                         \
    INTRI_SIGMOID(tmp, min, max, expisa);                                    \
    _mm256_storeu_ps(y, tmp);                                                \
    const float min_ = SIGMOID_THRESHOLD_MIN;                                \
    const float max_ = SIGMOID_THRESHOLD_MAX;                                \
    for (int i = this->end_; i < this->num_; ++i) {                          \
      y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]);           \
      y[i] = 0.f - y[i];                                                     \
    }                                                                        \
    vexp_->ComputeDeprecated(y + this->end_, y + this->end_);                \
    for (int i = this->end_; i < this->num_; ++i) {                          \
      y[i] = 1.f / (1.f + y[i]);                                             \
    }                                                                        \
  }

#define INTRI_GT16_FLOAT(isa, expisa)                                        \
  template <>                                                                \
  VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d)           \
      : VSigmoidKernel<float>() {                                            \
    this->num_ = d;                                                          \
    this->rest_ = d % AVX_FLOAT_BLOCK;                                       \
    this->end_ = d - this->rest_;                                            \
    vexp_ =                                                                  \
        KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
  }                                                                          \
  template <>                                                                \
  void VSigmoidKernelImpl<float, isa, kGT16>::ComputeDeprecated(             \
      const float* x, float* y) const {                                      \
    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);                      \
    __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);                      \
    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {                  \
      __m256 tmp = _mm256_loadu_ps(x + i);                                   \
      INTRI_SIGMOID(tmp, min, max, expisa);                                  \
      _mm256_storeu_ps(y + i, tmp);                                          \
    }                                                                        \
    const float min_ = SIGMOID_THRESHOLD_MIN;                                \
    const float max_ = SIGMOID_THRESHOLD_MAX;                                \
    for (int i = this->end_; i < this->num_; ++i) {                          \
      y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]);           \
      y[i] = 0.f - y[i];                                                     \
    }                                                                        \
    vexp_->ComputeDeprecated(y + this->end_, y + this->end_);                \
    for (int i = this->end_; i < this->num_; ++i) {                          \
      y[i] = 1.f / (1.f + y[i]);                                             \
    }                                                                        \
  }

#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
394 395 396
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx2 at gt8lt16 and gt16
#endif

#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_SIGMOID

REGISTER_JITKERNEL_DEPRECATED(vsigmoid, VSigmoidKernel);

/* VTanh JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VTanhKernelImpl : public VTanhKernel<T> {
 public:
  explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
    this->num_ = d;
    vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
  }
  void ComputeDeprecated(const T* x, T* y) const override {
    const T a = static_cast<T>(2), b = static_cast<T>(-1);
    vscal_->Compute(&a, x, y, this->num_);
    vsigmoid_->ComputeDeprecated(y, y);
    vscal_->Compute(&a, y, y, this->num_);
    vaddbias_->Compute(&b, y, y, this->num_);
  }

 private:
  std::shared_ptr<const VScalKernel<T>> vscal_;
  std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
  std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
};
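
// ComputeDeprecated above relies on the identity tanh(x) = 2*sigmoid(2x) - 1:
// scale by a = 2, take the sigmoid, scale by 2 again, then add b = -1. The
// INTRI_VTANH macro below computes the equivalent closed form
// 2 / (1 + exp(-2x)) - 1 directly with intrinsics.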

#define INTRI_VTANH(tmp, expisa)                           \
  tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp);         \
  tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
  tmp = expisa(tmp);                                       \
  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);          \
  tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp);          \
  tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))

#define INTRI8_FLOAT(isa, expisa)                                             \
  template <>                                                                 \
  void VTanhKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x,   \
                                                            float* y) const { \
    __m256 tmp = _mm256_loadu_ps(x);                                          \
    INTRI_VTANH(tmp, expisa);                                                 \
    _mm256_storeu_ps(y, tmp);                                                 \
  }

#define INTRI16_FLOAT(isa, expisa)                                             \
  template <>                                                                  \
  void VTanhKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x,   \
                                                             float* y) const { \
    __m256 tmp0 = _mm256_loadu_ps(x);                                          \
    __m256 tmp1 = _mm256_loadu_ps(x + 8);                                      \
    INTRI_VTANH(tmp0, expisa);                                                 \
    INTRI_VTANH(tmp1, expisa);                                                 \
    _mm256_storeu_ps(y, tmp0);                                                 \
    _mm256_storeu_ps(y + 8, tmp1);                                             \
  }

#define INTRI_GT8LT16_FLOAT(isa, expisa)                                      \
  template <>                                                                 \
  VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d)               \
      : VTanhKernel<float>() {                                                \
    this->num_ = d;                                                           \
    this->end_ = AVX_FLOAT_BLOCK;                                             \
    this->rest_ = d - this->end_;                                             \
    vscal_ =                                                                  \
        KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>(   \
        this->rest_);                                                         \
    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>(   \
        this->rest_);                                                         \
  }                                                                           \
  template <>                                                                 \
  void VTanhKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated(              \
      const float* x, float* y) const {                                       \
    __m256 tmp = _mm256_loadu_ps(x);                                          \
    INTRI_VTANH(tmp, expisa);                                                 \
    _mm256_storeu_ps(y, tmp);                                                 \
    x += AVX_FLOAT_BLOCK;                                                     \
    y += AVX_FLOAT_BLOCK;                                                     \
    const float a = 2.f, b = -1.f;                                            \
    vscal_->Compute(&a, x, y, this->rest_);                                   \
    vsigmoid_->ComputeDeprecated(y, y);                                       \
    vscal_->Compute(&a, y, y, this->rest_);                                   \
    vaddbias_->Compute(&b, y, y, this->rest_);                                \
  }

#define INTRI_GT16_FLOAT(isa, expisa)                                          \
  template <>                                                                  \
  VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d)                   \
      : VTanhKernel<float>() {                                                 \
    this->num_ = d;                                                            \
    this->rest_ = d % AVX_FLOAT_BLOCK;                                         \
    this->end_ = d - this->rest_;                                              \
    vscal_ =                                                                   \
        KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_);  \
    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>(    \
        this->rest_);                                                          \
    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>(    \
        this->rest_);                                                          \
  }                                                                            \
  template <>                                                                  \
  void VTanhKernelImpl<float, isa, kGT16>::ComputeDeprecated(const float* x,   \
                                                             float* y) const { \
    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {                    \
      __m256 tmp = _mm256_loadu_ps(x + i);                                     \
      INTRI_VTANH(tmp, expisa);                                                \
      _mm256_storeu_ps(y + i, tmp);                                            \
    }                                                                          \
    x += this->end_;                                                           \
    y += this->end_;                                                           \
    const float a = 2.f, b = -1.f;                                             \
    vscal_->Compute(&a, x, y, this->rest_);                                    \
    vsigmoid_->ComputeDeprecated(y, y);                                        \
    vscal_->Compute(&a, y, y, this->rest_);                                    \
    vaddbias_->Compute(&b, y, y, this->rest_);                                 \
  }

#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif

#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_VTANH

REGISTER_JITKERNEL_DEPRECATED(vtanh, VTanhKernel);

#undef JITKERNEL_NEW_ACT_IMPL

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle