avx_mathfun.h 25.7 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
   AVX implementation of sin, cos, sincos, exp and log

   Based on "sse_mathfun.h", by Julien Pommier
   http://gruntthepeon.free.fr/ssemath/

   Copyright (C) 2012 Giovanni Garberoglio
   Interdisciplinary Laboratory for Computational Science (LISC)
   Fondazione Bruno Kessler and University of Trento
   via Sommarive, 18
   I-38123 Trento (Italy)

  This software is provided 'as-is', without any express or implied
  warranty.  In no event will the authors be held liable for any damages
  arising from the use of this software.

  Permission is granted to anyone to use this software for any purpose,
  including commercial applications, and to alter it and redistribute it
  freely, subject to the following restrictions:

  1. The origin of this software must not be misrepresented; you must not
     claim that you wrote the original software. If you use this software
     in a product, an acknowledgment in the product documentation would be
     appreciated but is not required.
  2. Altered source versions must be plainly marked as such, and must not be
     misrepresented as being the original software.
  3. This notice may not be removed or altered from any source distribution.

  (this is the zlib license)
*/
44
#pragma once
45
#include "lite/backends/x86/cpu_info.h"
Y
Yan Chunwei 已提交
46

47 48
namespace paddle {
namespace lite {
Y
Yan Chunwei 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
/* __m128 is ugly to write */
typedef __m256 v8sf;   // vector of 8 float (avx)
typedef __m256i v8si;  // vector of 8 int   (avx)
typedef __m128i v4si;  // vector of 8 int   (avx)

#define _PI32AVX_CONST(Name, Val)                                 \
  static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = { \
      Val, Val, Val, Val}

_PI32AVX_CONST(1, 1);
_PI32AVX_CONST(inv1, ~1);
_PI32AVX_CONST(2, 2);
_PI32AVX_CONST(4, 4);

/* declare some AVX constants -- why can't I figure a better way to do that? */
#define _PS256_CONST(Name, Val)                                   \
  static const ALIGN32_BEG float _ps256_##Name[8] ALIGN32_END = { \
      Val, Val, Val, Val, Val, Val, Val, Val}
#define _PI32_CONST256(Name, Val)                                  \
  static const ALIGN32_BEG int _pi32_256_##Name[8] ALIGN32_END = { \
      Val, Val, Val, Val, Val, Val, Val, Val}
#define _PS256_CONST_TYPE(Name, Type, Val)                       \
  static const ALIGN32_BEG Type _ps256_##Name[8] ALIGN32_END = { \
      Val, Val, Val, Val, Val, Val, Val, Val}

_PS256_CONST(1, 1.0f);
_PS256_CONST(0p5, 0.5f);
/* the smallest non denormalized float number */
_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);

_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);

_PI32_CONST256(0, 0);
_PI32_CONST256(1, 1);
_PI32_CONST256(inv1, ~1);
_PI32_CONST256(2, 2);
_PI32_CONST256(4, 4);
_PI32_CONST256(0x7f, 0x7f);

_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
_PS256_CONST(cephes_log_p1, -1.1514610310E-1);
_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
_PS256_CONST(cephes_log_p3, -1.2420140846E-1);
_PS256_CONST(cephes_log_p4, +1.4249322787E-1);
_PS256_CONST(cephes_log_p5, -1.6668057665E-1);
_PS256_CONST(cephes_log_p6, +2.0000714765E-1);
_PS256_CONST(cephes_log_p7, -2.4999993993E-1);
_PS256_CONST(cephes_log_p8, +3.3333331174E-1);
_PS256_CONST(cephes_log_q1, -2.12194440e-4);
_PS256_CONST(cephes_log_q2, 0.693359375);

#ifndef __AVX2__

typedef union imm_xmm_union {
  v8si imm;
  v4si xmm[2];
} imm_xmm_union;

#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_)  \
  {                                          \
    imm_xmm_union ALIGN32_BEG u ALIGN32_END; \
    u.imm = imm_;                            \
    xmm0_ = u.xmm[0];                        \
    xmm1_ = u.xmm[1];                        \
  }

#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_)  \
  {                                          \
    imm_xmm_union ALIGN32_BEG u ALIGN32_END; \
    u.xmm[0] = xmm0_;                        \
    u.xmm[1] = xmm1_;                        \
    imm_ = u.imm;                            \
  }

#define AVX2_BITOP_USING_SSE2(fn)                        \
  static inline v8si avx2_mm256_##fn(v8si x, int a) {    \
    /* use SSE2 instruction to perform the bitop AVX2 */ \
    v4si x1, x2;                                         \
    v8si ret;                                            \
    COPY_IMM_TO_XMM(x, x1, x2);                          \
    x1 = _mm_##fn(x1, a);                                \
    x2 = _mm_##fn(x2, a);                                \
    COPY_XMM_TO_IMM(x1, x2, ret);                        \
    return (ret);                                        \
  }

139
// #warning "Using SSE2 to perform AVX2 bitshift ops"
Y
Yan Chunwei 已提交
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
AVX2_BITOP_USING_SSE2(slli_epi32)
AVX2_BITOP_USING_SSE2(srli_epi32)

#define AVX2_INTOP_USING_SSE2(fn)                                     \
  static inline v8si avx2_mm256_##fn(v8si x, v8si y) {                \
    /* use SSE2 instructions to perform the AVX2 integer operation */ \
    v4si x1, x2;                                                      \
    v4si y1, y2;                                                      \
    v8si ret;                                                         \
    COPY_IMM_TO_XMM(x, x1, x2);                                       \
    COPY_IMM_TO_XMM(y, y1, y2);                                       \
    x1 = _mm_##fn(x1, y1);                                            \
    x2 = _mm_##fn(x2, y2);                                            \
    COPY_XMM_TO_IMM(x1, x2, ret);                                     \
    return (ret);                                                     \
  }

157
// #warning "Using SSE2 to perform AVX2 integer ops"
Y
Yan Chunwei 已提交
158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
AVX2_INTOP_USING_SSE2(and_si128)
AVX2_INTOP_USING_SSE2(andnot_si128)
AVX2_INTOP_USING_SSE2(cmpeq_epi32)
AVX2_INTOP_USING_SSE2(sub_epi32)
AVX2_INTOP_USING_SSE2(add_epi32)
#define avx2_mm256_and_si256 avx2_mm256_and_si128
#define avx2_mm256_andnot_si256 avx2_mm256_andnot_si128
#else
#define avx2_mm256_slli_epi32 _mm256_slli_epi32
#define avx2_mm256_srli_epi32 _mm256_srli_epi32
#define avx2_mm256_and_si256 _mm256_and_si256
#define avx2_mm256_andnot_si256 _mm256_andnot_si256
#define avx2_mm256_cmpeq_epi32 _mm256_cmpeq_epi32
#define avx2_mm256_sub_epi32 _mm256_sub_epi32
#define avx2_mm256_add_epi32 _mm256_add_epi32
#endif /* __AVX2__ */

/* natural logarithm computed for 8 simultaneous float
   return NaN for x <= 0
*/
v8sf log256_ps(v8sf x) {
  v8si imm0;
180
  v8sf one = *(v8sf *)_ps256_1;  // NOLINT
Y
Yan Chunwei 已提交
181 182 183 184

  // v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
  v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);

185 186
  x = _mm256_max_ps(x, *(v8sf *)_ps256_min_norm_pos);  // NOLINT
  /* cut off denormalized stuff */                     // NOLINT
Y
Yan Chunwei 已提交
187 188 189 190 191

  // can be done with AVX2
  imm0 = avx2_mm256_srli_epi32(_mm256_castps_si256(x), 23);

  /* keep only the fractional part */
192 193
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_mant_mask);  // NOLINT
  x = _mm256_or_ps(x, *(v8sf *)_ps256_0p5);             // NOLINT
Y
Yan Chunwei 已提交
194 195

  // this is again another AVX2 instruction
196
  imm0 = avx2_mm256_sub_epi32(imm0, *(v8si *)_pi32_256_0x7f);  // NOLINT
Y
Yan Chunwei 已提交
197 198 199 200 201 202 203 204 205 206 207
  v8sf e = _mm256_cvtepi32_ps(imm0);

  e = _mm256_add_ps(e, one);

  /* part2:
     if( x < SQRTHF ) {
       e -= 1;
       x = x + x - 1.0;
     } else { x = x - 1.0; }
  */
  // v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
208 209
  v8sf mask =
      _mm256_cmp_ps(x, *(v8sf *)_ps256_cephes_SQRTHF, _CMP_LT_OS);  // NOLINT
Y
Yan Chunwei 已提交
210 211 212 213 214 215 216
  v8sf tmp = _mm256_and_ps(x, mask);
  x = _mm256_sub_ps(x, one);
  e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
  x = _mm256_add_ps(x, tmp);

  v8sf z = _mm256_mul_ps(x, x);

217
  v8sf y = *(v8sf *)_ps256_cephes_log_p0;  // NOLINT
Y
Yan Chunwei 已提交
218
  y = _mm256_mul_ps(y, x);
219
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p1);  // NOLINT
Y
Yan Chunwei 已提交
220
  y = _mm256_mul_ps(y, x);
221
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p2);  // NOLINT
Y
Yan Chunwei 已提交
222
  y = _mm256_mul_ps(y, x);
223
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p3);  // NOLINT
Y
Yan Chunwei 已提交
224
  y = _mm256_mul_ps(y, x);
225
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p4);  // NOLINT
Y
Yan Chunwei 已提交
226
  y = _mm256_mul_ps(y, x);
227
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p5);  // NOLINT
Y
Yan Chunwei 已提交
228
  y = _mm256_mul_ps(y, x);
229
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p6);  // NOLINT
Y
Yan Chunwei 已提交
230
  y = _mm256_mul_ps(y, x);
231
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p7);  // NOLINT
Y
Yan Chunwei 已提交
232
  y = _mm256_mul_ps(y, x);
233
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p8);  // NOLINT
Y
Yan Chunwei 已提交
234 235 236 237
  y = _mm256_mul_ps(y, x);

  y = _mm256_mul_ps(y, z);

238
  tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q1);  // NOLINT
Y
Yan Chunwei 已提交
239 240
  y = _mm256_add_ps(y, tmp);

241
  tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);  // NOLINT
Y
Yan Chunwei 已提交
242 243
  y = _mm256_sub_ps(y, tmp);

244
  tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q2);  // NOLINT
Y
Yan Chunwei 已提交
245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
  x = _mm256_add_ps(x, y);
  x = _mm256_add_ps(x, tmp);
  x = _mm256_or_ps(x, invalid_mask);  // negative arg will be NAN
  return x;
}

_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);

_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);

_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);

v8sf exp256_ps(v8sf x) {
  v8sf tmp = _mm256_setzero_ps(), fx;
  v8si imm0;
268
  v8sf one = *(v8sf *)_ps256_1;  // NOLINT
Y
Yan Chunwei 已提交
269

270 271
  x = _mm256_min_ps(x, *(v8sf *)_ps256_exp_hi);  // NOLINT
  x = _mm256_max_ps(x, *(v8sf *)_ps256_exp_lo);  // NOLINT
Y
Yan Chunwei 已提交
272 273

  /* express exp(x) as exp(g + n*log(2)) */
274 275
  fx = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_LOG2EF);  // NOLINT
  fx = _mm256_add_ps(fx, *(v8sf *)_ps256_0p5);           // NOLINT
Y
Yan Chunwei 已提交
276 277 278 279 280 281 282 283 284 285 286 287 288

  /* how to perform a floorf with SSE: just below */
  // imm0 = _mm256_cvttps_epi32(fx);
  // tmp  = _mm256_cvtepi32_ps(imm0);

  tmp = _mm256_floor_ps(fx);

  /* if greater, substract 1 */
  // v8sf mask = _mm256_cmpgt_ps(tmp, fx);
  v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
  mask = _mm256_and_ps(mask, one);
  fx = _mm256_sub_ps(tmp, mask);

289 290
  tmp = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C1);     // NOLINT
  v8sf z = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C2);  // NOLINT
Y
Yan Chunwei 已提交
291 292 293 294 295
  x = _mm256_sub_ps(x, tmp);
  x = _mm256_sub_ps(x, z);

  z = _mm256_mul_ps(x, x);

296
  v8sf y = *(v8sf *)_ps256_cephes_exp_p0;  // NOLINT
Y
Yan Chunwei 已提交
297
  y = _mm256_mul_ps(y, x);
298
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p1);  // NOLINT
Y
Yan Chunwei 已提交
299
  y = _mm256_mul_ps(y, x);
300
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p2);  // NOLINT
Y
Yan Chunwei 已提交
301
  y = _mm256_mul_ps(y, x);
302
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p3);  // NOLINT
Y
Yan Chunwei 已提交
303
  y = _mm256_mul_ps(y, x);
304
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p4);  // NOLINT
Y
Yan Chunwei 已提交
305
  y = _mm256_mul_ps(y, x);
306
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p5);  // NOLINT
Y
Yan Chunwei 已提交
307 308 309 310 311 312 313
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, x);
  y = _mm256_add_ps(y, one);

  /* build 2^n */
  imm0 = _mm256_cvttps_epi32(fx);
  // another two AVX2 instructions
314
  imm0 = avx2_mm256_add_epi32(imm0, *(v8si *)_pi32_256_0x7f);  // NOLINT
Y
Yan Chunwei 已提交
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354
  imm0 = avx2_mm256_slli_epi32(imm0, 23);
  v8sf pow2n = _mm256_castsi256_ps(imm0);
  y = _mm256_mul_ps(y, pow2n);
  return y;
}

_PS256_CONST(minus_cephes_DP1, -0.78515625);
_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
_PS256_CONST(sincof_p0, -1.9515295891E-4);
_PS256_CONST(sincof_p1, 8.3321608736E-3);
_PS256_CONST(sincof_p2, -1.6666654611E-1);
_PS256_CONST(coscof_p0, 2.443315711809948E-005);
_PS256_CONST(coscof_p1, -1.388731625493765E-003);
_PS256_CONST(coscof_p2, 4.166664568298827E-002);
_PS256_CONST(cephes_FOPI, 1.27323954473516);  // 4 / M_PI

/* evaluation of 8 sines at onces using AVX intrisics

   The code is the exact rewriting of the cephes sinf function.
   Precision is excellent as long as x < 8192 (I did not bother to
   take into account the special handling they have for greater values
   -- it does not return garbage for arguments over 8192, though, but
   the extra precision is missing).

   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
   surprising but correct result.

*/
v8sf sin256_ps(v8sf x) {  // any x
  v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
  v8si imm0, imm2;

#ifndef __AVX2__
  v4si imm0_1, imm0_2;
  v4si imm2_1, imm2_2;
#endif

  sign_bit = x;
  /* take the absolute value */
355
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask);  // NOLINT
Y
Yan Chunwei 已提交
356
  /* extract the sign bit (upper one) */
357
  sign_bit = _mm256_and_ps(sign_bit, *(v8sf *)_ps256_sign_mask);  // NOLINT
Y
Yan Chunwei 已提交
358 359

  /* scale by 4/Pi */
360
  y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI);  // NOLINT
Y
Yan Chunwei 已提交
361 362 363 364 365 366 367 368 369 370 371 372

/*
  Here we start a series of integer operations, which are in the
  realm of AVX2.
  If we don't have AVX, let's perform them using SSE2 directives
*/

#ifdef __AVX2__
  /* store the integer part of y in mm0 */
  imm2 = _mm256_cvttps_epi32(y);
  /* j=(j+1) & (~1) (see the cephes sources) */
  // another two AVX2 instruction
373 374
  imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1);     // NOLINT
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1);  // NOLINT
Y
Yan Chunwei 已提交
375 376 377
  y = _mm256_cvtepi32_ps(imm2);

  /* get the swap sign flag */
378
  imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4);  // NOLINT
Y
Yan Chunwei 已提交
379 380 381 382 383 384 385
  imm0 = avx2_mm256_slli_epi32(imm0, 29);
  /* get the polynom selection mask
     there is one polynom for 0 <= x <= Pi/4
     and another one for Pi/4<x<=Pi/2

     Both branches will be computed.
  */
386 387
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2);    // NOLINT
  imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0);  // NOLINT
Y
Yan Chunwei 已提交
388 389 390 391
#else
  /* we use SSE2 routines to perform the integer ops */
  COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

392 393
  imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1);  // NOLINT
  imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1);  // NOLINT
Y
Yan Chunwei 已提交
394

395 396
  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1);  // NOLINT
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1);  // NOLINT
Y
Yan Chunwei 已提交
397 398 399 400

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
  y = _mm256_cvtepi32_ps(imm2);

401 402
  imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4);  // NOLINT
  imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4);  // NOLINT
Y
Yan Chunwei 已提交
403 404 405 406 407 408

  imm0_1 = _mm_slli_epi32(imm0_1, 29);
  imm0_2 = _mm_slli_epi32(imm0_2, 29);

  COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

409 410
  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2);  // NOLINT
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2);  // NOLINT
Y
Yan Chunwei 已提交
411 412 413 414 415 416 417 418 419 420 421 422 423

  imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
  imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif

  v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);
  sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
424 425 426
  xmm1 = *(v8sf *)_ps256_minus_cephes_DP1;  // NOLINT
  xmm2 = *(v8sf *)_ps256_minus_cephes_DP2;  // NOLINT
  xmm3 = *(v8sf *)_ps256_minus_cephes_DP3;  // NOLINT
Y
Yan Chunwei 已提交
427 428 429 430 431 432 433 434
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

  /* Evaluate the first polynom  (0 <= x <= Pi/4) */
435
  y = *(v8sf *)_ps256_coscof_p0;  // NOLINT
Y
Yan Chunwei 已提交
436 437 438
  v8sf z = _mm256_mul_ps(x, x);

  y = _mm256_mul_ps(y, z);
439
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1);  // NOLINT
Y
Yan Chunwei 已提交
440
  y = _mm256_mul_ps(y, z);
441
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2);  // NOLINT
Y
Yan Chunwei 已提交
442 443
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
444
  v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);  // NOLINT
Y
Yan Chunwei 已提交
445
  y = _mm256_sub_ps(y, tmp);
446
  y = _mm256_add_ps(y, *(v8sf *)_ps256_1);  // NOLINT
Y
Yan Chunwei 已提交
447 448 449

  /* Evaluate the second polynom  (Pi/4 <= x <= 0) */

450
  v8sf y2 = *(v8sf *)_ps256_sincof_p0;  // NOLINT
Y
Yan Chunwei 已提交
451
  y2 = _mm256_mul_ps(y2, z);
452
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1);  // NOLINT
Y
Yan Chunwei 已提交
453
  y2 = _mm256_mul_ps(y2, z);
454
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2);  // NOLINT
Y
Yan Chunwei 已提交
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynoms */
  xmm3 = poly_mask;
  y2 = _mm256_and_ps(xmm3, y2);  //, xmm3);
  y = _mm256_andnot_ps(xmm3, y);
  y = _mm256_add_ps(y, y2);
  /* update the sign */
  y = _mm256_xor_ps(y, sign_bit);

  return y;
}

/* almost the same as sin_ps */
v8sf cos256_ps(v8sf x) {  // any x
  v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
  v8si imm0, imm2;

#ifndef __AVX2__
  v4si imm0_1, imm0_2;
  v4si imm2_1, imm2_2;
#endif

  /* take the absolute value */
481
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask);  // NOLINT
Y
Yan Chunwei 已提交
482 483

  /* scale by 4/Pi */
484
  y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI);  // NOLINT
Y
Yan Chunwei 已提交
485 486 487 488 489

#ifdef __AVX2__
  /* store the integer part of y in mm0 */
  imm2 = _mm256_cvttps_epi32(y);
  /* j=(j+1) & (~1) (see the cephes sources) */
490 491
  imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1);     // NOLINT
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1);  // NOLINT
Y
Yan Chunwei 已提交
492
  y = _mm256_cvtepi32_ps(imm2);
493
  imm2 = avx2_mm256_sub_epi32(imm2, *(v8si *)_pi32_256_2);  // NOLINT
Y
Yan Chunwei 已提交
494 495

  /* get the swap sign flag */
496
  imm0 = avx2_mm256_andnot_si256(imm2, *(v8si *)_pi32_256_4);  // NOLINT
Y
Yan Chunwei 已提交
497 498
  imm0 = avx2_mm256_slli_epi32(imm0, 29);
  /* get the polynom selection mask */
499 500
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2);    // NOLINT
  imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0);  // NOLINT
Y
Yan Chunwei 已提交
501 502 503 504 505
#else

  /* we use SSE2 routines to perform the integer ops */
  COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

506 507
  imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1);  // NOLINT
  imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1);  // NOLINT
Y
Yan Chunwei 已提交
508

509 510
  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1);  // NOLINT
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1);  // NOLINT
Y
Yan Chunwei 已提交
511 512 513 514

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
  y = _mm256_cvtepi32_ps(imm2);

515 516
  imm2_1 = _mm_sub_epi32(imm2_1, *(v4si *)_pi32avx_2);  // NOLINT
  imm2_2 = _mm_sub_epi32(imm2_2, *(v4si *)_pi32avx_2);  // NOLINT
Y
Yan Chunwei 已提交
517

518 519
  imm0_1 = _mm_andnot_si128(imm2_1, *(v4si *)_pi32avx_4);  // NOLINT
  imm0_2 = _mm_andnot_si128(imm2_2, *(v4si *)_pi32avx_4);  // NOLINT
Y
Yan Chunwei 已提交
520 521 522 523 524 525

  imm0_1 = _mm_slli_epi32(imm0_1, 29);
  imm0_2 = _mm_slli_epi32(imm0_2, 29);

  COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

526 527
  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2);  // NOLINT
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2);  // NOLINT
Y
Yan Chunwei 已提交
528 529 530 531 532 533 534 535 536 537 538 539

  imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
  imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif

  v8sf sign_bit = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
540 541 542
  xmm1 = *(v8sf *)_ps256_minus_cephes_DP1;  // NOLINT
  xmm2 = *(v8sf *)_ps256_minus_cephes_DP2;  // NOLINT
  xmm3 = *(v8sf *)_ps256_minus_cephes_DP3;  // NOLINT
Y
Yan Chunwei 已提交
543 544 545 546 547 548 549 550
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

  /* Evaluate the first polynom  (0 <= x <= Pi/4) */
551
  y = *(v8sf *)_ps256_coscof_p0;  // NOLINT
Y
Yan Chunwei 已提交
552 553 554
  v8sf z = _mm256_mul_ps(x, x);

  y = _mm256_mul_ps(y, z);
555
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1);  // NOLINT
Y
Yan Chunwei 已提交
556
  y = _mm256_mul_ps(y, z);
557
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2);  // NOLINT
Y
Yan Chunwei 已提交
558 559
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
560
  v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);  // NOLINT
Y
Yan Chunwei 已提交
561
  y = _mm256_sub_ps(y, tmp);
562
  y = _mm256_add_ps(y, *(v8sf *)_ps256_1);  // NOLINT
Y
Yan Chunwei 已提交
563 564 565

  /* Evaluate the second polynom  (Pi/4 <= x <= 0) */

566
  v8sf y2 = *(v8sf *)_ps256_sincof_p0;  // NOLINT
Y
Yan Chunwei 已提交
567
  y2 = _mm256_mul_ps(y2, z);
568
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1);  // NOLINT
Y
Yan Chunwei 已提交
569
  y2 = _mm256_mul_ps(y2, z);
570
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2);  // NOLINT
Y
Yan Chunwei 已提交
571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynoms */
  xmm3 = poly_mask;
  y2 = _mm256_and_ps(xmm3, y2);  //, xmm3);
  y = _mm256_andnot_ps(xmm3, y);
  y = _mm256_add_ps(y, y2);
  /* update the sign */
  y = _mm256_xor_ps(y, sign_bit);

  return y;
}

/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could
   replace both of them..
   it is almost as fast, and gives you a free cosine with your sine */
void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
  v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
  v8si imm0, imm2, imm4;

#ifndef __AVX2__
  v4si imm0_1, imm0_2;
  v4si imm2_1, imm2_2;
  v4si imm4_1, imm4_2;
#endif

  sign_bit_sin = x;
  /* take the absolute value */
601
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask);  // NOLINT
Y
Yan Chunwei 已提交
602
  /* extract the sign bit (upper one) */
603 604
  sign_bit_sin =
      _mm256_and_ps(sign_bit_sin, *(v8sf *)_ps256_sign_mask);  // NOLINT
Y
Yan Chunwei 已提交
605 606

  /* scale by 4/Pi */
607
  y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI);  // NOLINT
Y
Yan Chunwei 已提交
608 609 610 611 612 613

#ifdef __AVX2__
  /* store the integer part of y in imm2 */
  imm2 = _mm256_cvttps_epi32(y);

  /* j=(j+1) & (~1) (see the cephes sources) */
614 615
  imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1);     // NOLINT
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1);  // NOLINT
Y
Yan Chunwei 已提交
616 617 618 619 620

  y = _mm256_cvtepi32_ps(imm2);
  imm4 = imm2;

  /* get the swap sign flag for the sine */
621
  imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4);  // NOLINT
Y
Yan Chunwei 已提交
622 623 624 625
  imm0 = avx2_mm256_slli_epi32(imm0, 29);
  // v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);

  /* get the polynom selection mask for the sine*/
626 627
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2);    // NOLINT
  imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0);  // NOLINT
Y
Yan Chunwei 已提交
628 629 630 631 632
// v8sf poly_mask = _mm256_castsi256_ps(imm2);
#else
  /* we use SSE2 routines to perform the integer ops */
  COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

633 634
  imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1);  // NOLINT
  imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1);  // NOLINT
Y
Yan Chunwei 已提交
635

636 637
  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1);  // NOLINT
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1);  // NOLINT
Y
Yan Chunwei 已提交
638 639 640 641 642 643 644

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
  y = _mm256_cvtepi32_ps(imm2);

  imm4_1 = imm2_1;
  imm4_2 = imm2_2;

645 646
  imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4);  // NOLINT
  imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4);  // NOLINT
Y
Yan Chunwei 已提交
647 648 649 650 651 652

  imm0_1 = _mm_slli_epi32(imm0_1, 29);
  imm0_2 = _mm_slli_epi32(imm0_2, 29);

  COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

653 654
  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2);  // NOLINT
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2);  // NOLINT
Y
Yan Chunwei 已提交
655 656 657 658 659 660 661 662 663 664 665

  imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
  imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif
  v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
666 667 668
  xmm1 = *(v8sf *)_ps256_minus_cephes_DP1;  // NOLINT
  xmm2 = *(v8sf *)_ps256_minus_cephes_DP2;  // NOLINT
  xmm3 = *(v8sf *)_ps256_minus_cephes_DP3;  // NOLINT
Y
Yan Chunwei 已提交
669 670 671 672 673 674 675 676
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

#ifdef __AVX2__
677 678
  imm4 = avx2_mm256_sub_epi32(imm4, *(v8si *)_pi32_256_2);     // NOLINT
  imm4 = avx2_mm256_andnot_si256(imm4, *(v8si *)_pi32_256_4);  // NOLINT
Y
Yan Chunwei 已提交
679 680
  imm4 = avx2_mm256_slli_epi32(imm4, 29);
#else
681 682
  imm4_1 = _mm_sub_epi32(imm4_1, *(v4si *)_pi32avx_2);  // NOLINT
  imm4_2 = _mm_sub_epi32(imm4_2, *(v4si *)_pi32avx_2);  // NOLINT
Y
Yan Chunwei 已提交
683

684 685
  imm4_1 = _mm_andnot_si128(imm4_1, *(v4si *)_pi32avx_4);  // NOLINT
  imm4_2 = _mm_andnot_si128(imm4_2, *(v4si *)_pi32avx_4);  // NOLINT
Y
Yan Chunwei 已提交
686 687 688 689 690 691 692 693 694 695 696 697 698

  imm4_1 = _mm_slli_epi32(imm4_1, 29);
  imm4_2 = _mm_slli_epi32(imm4_2, 29);

  COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4);
#endif

  v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);

  sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);

  /* Evaluate the first polynom  (0 <= x <= Pi/4) */
  v8sf z = _mm256_mul_ps(x, x);
699
  y = *(v8sf *)_ps256_coscof_p0;  // NOLINT
Y
Yan Chunwei 已提交
700 701

  y = _mm256_mul_ps(y, z);
702
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1);  // NOLINT
Y
Yan Chunwei 已提交
703
  y = _mm256_mul_ps(y, z);
704
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2);  // NOLINT
Y
Yan Chunwei 已提交
705 706
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
707
  v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);  // NOLINT
Y
Yan Chunwei 已提交
708
  y = _mm256_sub_ps(y, tmp);
709
  y = _mm256_add_ps(y, *(v8sf *)_ps256_1);  // NOLINT
Y
Yan Chunwei 已提交
710 711 712

  /* Evaluate the second polynom  (Pi/4 <= x <= 0) */

713
  v8sf y2 = *(v8sf *)_ps256_sincof_p0;  // NOLINT
Y
Yan Chunwei 已提交
714
  y2 = _mm256_mul_ps(y2, z);
715
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1);  // NOLINT
Y
Yan Chunwei 已提交
716
  y2 = _mm256_mul_ps(y2, z);
717
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2);  // NOLINT
Y
Yan Chunwei 已提交
718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynoms */
  xmm3 = poly_mask;
  v8sf ysin2 = _mm256_and_ps(xmm3, y2);
  v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
  y2 = _mm256_sub_ps(y2, ysin2);
  y = _mm256_sub_ps(y, ysin1);

  xmm1 = _mm256_add_ps(ysin1, ysin2);
  xmm2 = _mm256_add_ps(y, y2);

  /* update the sign */
  *s = _mm256_xor_ps(xmm1, sign_bit_sin);
  *c = _mm256_xor_ps(xmm2, sign_bit_cos);
}
736 737 738

}  // namespace lite
}  // namespace paddle