//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
   AVX implementation of sin, cos, sincos, exp and log

   Based on "sse_mathfun.h", by Julien Pommier
   http://gruntthepeon.free.fr/ssemath/

   Copyright (C) 2012 Giovanni Garberoglio
   Interdisciplinary Laboratory for Computational Science (LISC)
   Fondazione Bruno Kessler and University of Trento
   via Sommarive, 18
   I-38123 Trento (Italy)

  This software is provided 'as-is', without any express or implied
  warranty.  In no event will the authors be held liable for any damages
  arising from the use of this software.

  Permission is granted to anyone to use this software for any purpose,
  including commercial applications, and to alter it and redistribute it
  freely, subject to the following restrictions:

  1. The origin of this software must not be misrepresented; you must not
     claim that you wrote the original software. If you use this software
     in a product, an acknowledgment in the product documentation would be
     appreciated but is not required.
  2. Altered source versions must be plainly marked as such, and must not be
     misrepresented as being the original software.
  3. This notice may not be removed or altered from any source distribution.

  (this is the zlib license)
*/

#include <immintrin.h>

/* yes I know, the top of this file is quite ugly */
#define ALIGN32_BEG
#define ALIGN32_END __attribute__((aligned(32)))

/* __m128 is ugly to write */
typedef __m256 v8sf;   // vector of 8 float (avx)
typedef __m256i v8si;  // vector of 8 int   (avx)
typedef __m128i v4si;  // vector of 4 int   (sse2)

#define _PI32AVX_CONST(Name, Val)                                 \
  static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = { \
      Val, Val, Val, Val}

_PI32AVX_CONST(1, 1);
_PI32AVX_CONST(inv1, ~1);
_PI32AVX_CONST(2, 2);
_PI32AVX_CONST(4, 4);

/* declare some AVX constants -- why can't I figure a better way to do that? */
#define _PS256_CONST(Name, Val)                                   \
  static const ALIGN32_BEG float _ps256_##Name[8] ALIGN32_END = { \
      Val, Val, Val, Val, Val, Val, Val, Val}
#define _PI32_CONST256(Name, Val)                                  \
  static const ALIGN32_BEG int _pi32_256_##Name[8] ALIGN32_END = { \
      Val, Val, Val, Val, Val, Val, Val, Val}
#define _PS256_CONST_TYPE(Name, Type, Val)                       \
  static const ALIGN32_BEG Type _ps256_##Name[8] ALIGN32_END = { \
      Val, Val, Val, Val, Val, Val, Val, Val}
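
/* As a concrete example of what these macros emit, _PS256_CONST(1, 1.0f)
   just below expands to

     static const float _ps256_1[8] __attribute__((aligned(32))) =
         {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};

   and the vector is then loaded with *(v8sf *)_ps256_1. */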

_PS256_CONST(1, 1.0f);
_PS256_CONST(0p5, 0.5f);
/* the smallest non-denormalized float number */
_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);

_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);

_PI32_CONST256(0, 0);
_PI32_CONST256(1, 1);
_PI32_CONST256(inv1, ~1);
_PI32_CONST256(2, 2);
_PI32_CONST256(4, 4);
_PI32_CONST256(0x7f, 0x7f);

_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
_PS256_CONST(cephes_log_p1, -1.1514610310E-1);
_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
_PS256_CONST(cephes_log_p3, -1.2420140846E-1);
_PS256_CONST(cephes_log_p4, +1.4249322787E-1);
_PS256_CONST(cephes_log_p5, -1.6668057665E-1);
_PS256_CONST(cephes_log_p6, +2.0000714765E-1);
_PS256_CONST(cephes_log_p7, -2.4999993993E-1);
_PS256_CONST(cephes_log_p8, +3.3333331174E-1);
_PS256_CONST(cephes_log_q1, -2.12194440e-4);
_PS256_CONST(cephes_log_q2, 0.693359375);

#ifndef __AVX2__

typedef union imm_xmm_union {
  v8si imm;
  v4si xmm[2];
} imm_xmm_union;

#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_)       \
  {                                               \
    imm_xmm_union u __attribute__((aligned(32))); \
    u.imm = imm_;                                 \
    xmm0_ = u.xmm[0];                             \
    xmm1_ = u.xmm[1];                             \
  }

#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_)       \
  {                                               \
    imm_xmm_union u __attribute__((aligned(32))); \
    u.xmm[0] = xmm0_;                             \
    u.xmm[1] = xmm1_;                             \
    imm_ = u.imm;                                 \
  }
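
/* A minimal usage sketch for these two macros (the names v, lo, hi, r are
   illustrative only): split a 256-bit integer vector into its 128-bit
   halves, run an SSE2 op on each half, then merge the halves back:

     v8si v = _mm256_castps_si256(_mm256_setzero_ps());
     v4si lo, hi;
     v8si r;
     COPY_IMM_TO_XMM(v, lo, hi);
     lo = _mm_add_epi32(lo, *(v4si *)_pi32avx_1);
     hi = _mm_add_epi32(hi, *(v4si *)_pi32avx_1);
     COPY_XMM_TO_IMM(lo, hi, r);  // every 32-bit lane of r is now 1
*/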

#define AVX2_BITOP_USING_SSE2(fn)                        \
  static inline v8si avx2_mm256_##fn(v8si x, int a) {    \
    /* use SSE2 instructions to perform the AVX2 bitop */ \
    v4si x1, x2;                                         \
    v8si ret;                                            \
    COPY_IMM_TO_XMM(x, x1, x2);                          \
    x1 = _mm_##fn(x1, a);                                \
    x2 = _mm_##fn(x2, a);                                \
    COPY_XMM_TO_IMM(x1, x2, ret);                        \
    return (ret);                                        \
  }

//#warning "Using SSE2 to perform AVX2 bitshift ops"
AVX2_BITOP_USING_SSE2(slli_epi32)
AVX2_BITOP_USING_SSE2(srli_epi32)

#define AVX2_INTOP_USING_SSE2(fn)                                     \
  static inline v8si avx2_mm256_##fn(v8si x, v8si y) {                \
    /* use SSE2 instructions to perform the AVX2 integer operation */ \
    v4si x1, x2;                                                      \
    v4si y1, y2;                                                      \
    v8si ret;                                                         \
    COPY_IMM_TO_XMM(x, x1, x2);                                       \
    COPY_IMM_TO_XMM(y, y1, y2);                                       \
    x1 = _mm_##fn(x1, y1);                                            \
    x2 = _mm_##fn(x2, y2);                                            \
    COPY_XMM_TO_IMM(x1, x2, ret);                                     \
    return (ret);                                                     \
  }

//#warning "Using SSE2 to perform AVX2 integer ops"
AVX2_INTOP_USING_SSE2(and_si128)
AVX2_INTOP_USING_SSE2(andnot_si128)
AVX2_INTOP_USING_SSE2(cmpeq_epi32)
AVX2_INTOP_USING_SSE2(sub_epi32)
AVX2_INTOP_USING_SSE2(add_epi32)
#define avx2_mm256_and_si256 avx2_mm256_and_si128
#define avx2_mm256_andnot_si256 avx2_mm256_andnot_si128
#else
#define avx2_mm256_slli_epi32 _mm256_slli_epi32
#define avx2_mm256_srli_epi32 _mm256_srli_epi32
#define avx2_mm256_and_si256 _mm256_and_si256
#define avx2_mm256_andnot_si256 _mm256_andnot_si256
#define avx2_mm256_cmpeq_epi32 _mm256_cmpeq_epi32
#define avx2_mm256_sub_epi32 _mm256_sub_epi32
#define avx2_mm256_add_epi32 _mm256_add_epi32
#endif /* __AVX2__ */
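
/* Illustrative check of the wrappers above (the variable names are
   hypothetical): whichever branch was compiled, avx2_mm256_slli_epi32
   shifts every 32-bit lane, so each lane of `eight` below ends up as 8:

     v8si one = _mm256_set1_epi32(1);
     v8si eight = avx2_mm256_slli_epi32(one, 3);
*/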

/* natural logarithm computed for 8 simultaneous floats;
   returns NaN for x <= 0
*/
v8sf log256_ps(v8sf x) {
  v8si imm0;
  v8sf one = *(v8sf *)_ps256_1;

  // v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
  v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);

  x = _mm256_max_ps(
      x, *(v8sf *)_ps256_min_norm_pos); /* cut off denormalized stuff */

  // can be done with AVX2
  imm0 = avx2_mm256_srli_epi32(_mm256_castps_si256(x), 23);

  /* keep only the fractional part */
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_mant_mask);
  x = _mm256_or_ps(x, *(v8sf *)_ps256_0p5);

  // this is again another AVX2 instruction
  imm0 = avx2_mm256_sub_epi32(imm0, *(v8si *)_pi32_256_0x7f);
  v8sf e = _mm256_cvtepi32_ps(imm0);

  e = _mm256_add_ps(e, one);

  /* part2:
     if( x < SQRTHF ) {
       e -= 1;
       x = x + x - 1.0;
     } else { x = x - 1.0; }
  */
  // v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
  v8sf mask = _mm256_cmp_ps(x, *(v8sf *)_ps256_cephes_SQRTHF, _CMP_LT_OS);
  v8sf tmp = _mm256_and_ps(x, mask);
  x = _mm256_sub_ps(x, one);
  e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
  x = _mm256_add_ps(x, tmp);

  v8sf z = _mm256_mul_ps(x, x);

  v8sf y = *(v8sf *)_ps256_cephes_log_p0;
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p1);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p2);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p3);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p4);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p5);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p6);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p7);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p8);
  y = _mm256_mul_ps(y, x);

  y = _mm256_mul_ps(y, z);

  tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q1);
  y = _mm256_add_ps(y, tmp);

  tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);
  y = _mm256_sub_ps(y, tmp);

  tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q2);
  x = _mm256_add_ps(x, y);
  x = _mm256_add_ps(x, tmp);
  x = _mm256_or_ps(x, invalid_mask);  // negative arg will be NAN
  return x;
}
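
/* Usage sketch for log256_ps (assumed caller code, not part of this header;
   the buffer names are illustrative):

     ALIGN32_BEG float in[8] ALIGN32_END = {0.5f, 1.0f, 2.0f, 4.0f,
                                            8.0f, 16.0f, 32.0f, 64.0f};
     ALIGN32_BEG float out[8] ALIGN32_END;
     _mm256_store_ps(out, log256_ps(_mm256_load_ps(in)));
     // out[1] ~ 0.0f; a zero or negative input lane comes back as NaN
*/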

_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);

_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);

_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);

v8sf exp256_ps(v8sf x) {
  v8sf tmp = _mm256_setzero_ps(), fx;
  v8si imm0;
  v8sf one = *(v8sf *)_ps256_1;

  x = _mm256_min_ps(x, *(v8sf *)_ps256_exp_hi);
  x = _mm256_max_ps(x, *(v8sf *)_ps256_exp_lo);

  /* express exp(x) as exp(g + n*log(2)) */
  fx = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_LOG2EF);
  fx = _mm256_add_ps(fx, *(v8sf *)_ps256_0p5);

  /* how to perform a floorf with SSE: just below */
  // imm0 = _mm256_cvttps_epi32(fx);
  // tmp  = _mm256_cvtepi32_ps(imm0);

  tmp = _mm256_floor_ps(fx);

  /* if greater, subtract 1 */
  // v8sf mask = _mm256_cmpgt_ps(tmp, fx);
  v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
  mask = _mm256_and_ps(mask, one);
  fx = _mm256_sub_ps(tmp, mask);

  tmp = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C1);
  v8sf z = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C2);
  x = _mm256_sub_ps(x, tmp);
  x = _mm256_sub_ps(x, z);

  z = _mm256_mul_ps(x, x);

  v8sf y = *(v8sf *)_ps256_cephes_exp_p0;
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p1);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p2);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p3);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p4);
  y = _mm256_mul_ps(y, x);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p5);
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, x);
  y = _mm256_add_ps(y, one);

  /* build 2^n */
  imm0 = _mm256_cvttps_epi32(fx);
  // another two AVX2 instructions
  imm0 = avx2_mm256_add_epi32(imm0, *(v8si *)_pi32_256_0x7f);
  imm0 = avx2_mm256_slli_epi32(imm0, 23);
  v8sf pow2n = _mm256_castsi256_ps(imm0);
  y = _mm256_mul_ps(y, pow2n);
  return y;
}
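
/* Usage sketch for exp256_ps (assumed caller code): the argument is clamped
   to [exp_lo, exp_hi] above, so huge inputs saturate instead of overflowing
   to infinity:

     v8sf e = exp256_ps(_mm256_set1_ps(1.0f));  // each lane ~ 2.7182817f
*/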

_PS256_CONST(minus_cephes_DP1, -0.78515625);
_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
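/* Note: the three DP constants above sum to -Pi/4, split into pieces of
   decreasing magnitude (Cody-Waite style range reduction); the "magic pass"
   in the functions below relies on this to subtract y*Pi/4 from x with
   extended effective precision. */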
_PS256_CONST(sincof_p0, -1.9515295891E-4);
_PS256_CONST(sincof_p1, 8.3321608736E-3);
_PS256_CONST(sincof_p2, -1.6666654611E-1);
_PS256_CONST(coscof_p0, 2.443315711809948E-005);
_PS256_CONST(coscof_p1, -1.388731625493765E-003);
_PS256_CONST(coscof_p2, 4.166664568298827E-002);
_PS256_CONST(cephes_FOPI, 1.27323954473516);  // 4 / M_PI

/* evaluation of 8 sines at once using AVX intrinsics

   The code is an exact rewriting of the cephes sinf function.
   Precision is excellent as long as x < 8192 (I did not bother to
   take into account the special handling cephes has for greater
   values -- arguments over 8192 do not return garbage, but the
   extra precision is missing).

   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is
   the surprising but correct result.
*/
v8sf sin256_ps(v8sf x) {  // any x
  v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
  v8si imm0, imm2;

#ifndef __AVX2__
  v4si imm0_1, imm0_2;
  v4si imm2_1, imm2_2;
#endif

  sign_bit = x;
  /* take the absolute value */
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask);
  /* extract the sign bit (upper one) */
  sign_bit = _mm256_and_ps(sign_bit, *(v8sf *)_ps256_sign_mask);

  /* scale by 4/Pi */
  y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI);

/*
  Here we start a series of integer operations, which are in the
  realm of AVX2.
  If we don't have AVX2, perform them with SSE2 instructions instead.
*/

#ifdef __AVX2__
  /* store the integer part of y in mm0 */
  imm2 = _mm256_cvttps_epi32(y);
  /* j=(j+1) & (~1) (see the cephes sources) */
  // another two AVX2 instructions
  imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1);
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1);
  y = _mm256_cvtepi32_ps(imm2);

  /* get the swap sign flag */
  imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4);
  imm0 = avx2_mm256_slli_epi32(imm0, 29);
  /* get the polynomial selection mask:
     there is one polynomial for 0 <= x <= Pi/4
     and another one for Pi/4 < x <= Pi/2

     Both branches will be computed.
  */
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2);
  imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0);
#else
  /* we use SSE2 routines to perform the integer ops */
  COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

  imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1);
  imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1);

  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1);
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1);

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
  y = _mm256_cvtepi32_ps(imm2);

  imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4);
  imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4);

  imm0_1 = _mm_slli_epi32(imm0_1, 29);
  imm0_2 = _mm_slli_epi32(imm0_2, 29);

  COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2);
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2);

  imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
  imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif
  v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);
  sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
  xmm1 = *(v8sf *)_ps256_minus_cephes_DP1;
  xmm2 = *(v8sf *)_ps256_minus_cephes_DP2;
  xmm3 = *(v8sf *)_ps256_minus_cephes_DP3;
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

  /* Evaluate the first polynomial (the cosine approximation, 0 <= x <= Pi/4) */
  y = *(v8sf *)_ps256_coscof_p0;
  v8sf z = _mm256_mul_ps(x, x);

  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1);
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2);
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
  v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);
  y = _mm256_sub_ps(y, tmp);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_1);

  /* Evaluate the second polynomial (the sine approximation, 0 <= x <= Pi/4) */

  v8sf y2 = *(v8sf *)_ps256_sincof_p0;
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynomials */
  xmm3 = poly_mask;
  y2 = _mm256_and_ps(xmm3, y2);
  y = _mm256_andnot_ps(xmm3, y);
  y = _mm256_add_ps(y, y2);
  /* update the sign */
  y = _mm256_xor_ps(y, sign_bit);

  return y;
}
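
/* Usage sketch for sin256_ps (assumed caller code), illustrating the
   precision note above:

     v8sf s = sin256_ps(_mm256_set1_ps(3.14159265358979f));
     // each lane has magnitude ~8.74e-8, matching sinf((float)M_PI),
     // because (float)M_PI is not exactly Pi
*/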

/* almost the same as sin256_ps */
v8sf cos256_ps(v8sf x) {  // any x
  v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
  v8si imm0, imm2;

#ifndef __AVX2__
  v4si imm0_1, imm0_2;
  v4si imm2_1, imm2_2;
#endif

  /* take the absolute value */
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask);

  /* scale by 4/Pi */
  y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI);

#ifdef __AVX2__
  /* store the integer part of y in mm0 */
  imm2 = _mm256_cvttps_epi32(y);
  /* j=(j+1) & (~1) (see the cephes sources) */
  imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1);
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1);
  y = _mm256_cvtepi32_ps(imm2);
  imm2 = avx2_mm256_sub_epi32(imm2, *(v8si *)_pi32_256_2);

  /* get the swap sign flag */
  imm0 = avx2_mm256_andnot_si256(imm2, *(v8si *)_pi32_256_4);
  imm0 = avx2_mm256_slli_epi32(imm0, 29);
  /* get the polynomial selection mask */
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2);
  imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0);
#else

  /* we use SSE2 routines to perform the integer ops */
  COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

  imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1);
  imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1);

  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1);
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1);

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
  y = _mm256_cvtepi32_ps(imm2);

  imm2_1 = _mm_sub_epi32(imm2_1, *(v4si *)_pi32avx_2);
  imm2_2 = _mm_sub_epi32(imm2_2, *(v4si *)_pi32avx_2);

  imm0_1 = _mm_andnot_si128(imm2_1, *(v4si *)_pi32avx_4);
  imm0_2 = _mm_andnot_si128(imm2_2, *(v4si *)_pi32avx_4);

  imm0_1 = _mm_slli_epi32(imm0_1, 29);
  imm0_2 = _mm_slli_epi32(imm0_2, 29);

  COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2);
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2);

  imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
  imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif

  v8sf sign_bit = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
  xmm1 = *(v8sf *)_ps256_minus_cephes_DP1;
  xmm2 = *(v8sf *)_ps256_minus_cephes_DP2;
  xmm3 = *(v8sf *)_ps256_minus_cephes_DP3;
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);
  /* Evaluate the first polynomial (the cosine approximation, 0 <= x <= Pi/4) */
  y = *(v8sf *)_ps256_coscof_p0;
  v8sf z = _mm256_mul_ps(x, x);

  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1);
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2);
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
  v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);
  y = _mm256_sub_ps(y, tmp);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_1);

  /* Evaluate the second polynomial (the sine approximation, 0 <= x <= Pi/4) */

  v8sf y2 = *(v8sf *)_ps256_sincof_p0;
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynomials */
  xmm3 = poly_mask;
  y2 = _mm256_and_ps(xmm3, y2);
  y = _mm256_andnot_ps(xmm3, y);
  y = _mm256_add_ps(y, y2);
  /* update the sign */
  y = _mm256_xor_ps(y, sign_bit);

  return y;
}
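
/* Usage sketch for cos256_ps (assumed caller code):

     v8sf c = cos256_ps(_mm256_setzero_ps());  // each lane ~ 1.0f
*/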

/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could
   replace both of them;
   it is almost as fast, and gives you a free cosine with your sine */
void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
  v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
  v8si imm0, imm2, imm4;

#ifndef __AVX2__
  v4si imm0_1, imm0_2;
  v4si imm2_1, imm2_2;
  v4si imm4_1, imm4_2;
#endif

  sign_bit_sin = x;
  /* take the absolute value */
  x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask);
  /* extract the sign bit (upper one) */
  sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf *)_ps256_sign_mask);

  /* scale by 4/Pi */
  y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI);

#ifdef __AVX2__
  /* store the integer part of y in imm2 */
  imm2 = _mm256_cvttps_epi32(y);

  /* j=(j+1) & (~1) (see the cephes sources) */
  imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1);
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1);

  y = _mm256_cvtepi32_ps(imm2);
  imm4 = imm2;

  /* get the swap sign flag for the sine */
  imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4);
  imm0 = avx2_mm256_slli_epi32(imm0, 29);
  // v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);

  /* get the polynomial selection mask for the sine */
  imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2);
  imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0);
// v8sf poly_mask = _mm256_castsi256_ps(imm2);
#else
  /* we use SSE2 routines to perform the integer ops */
  COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

  imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1);
  imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1);

  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1);
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1);

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
  y = _mm256_cvtepi32_ps(imm2);

  imm4_1 = imm2_1;
  imm4_2 = imm2_2;

  imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4);
  imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4);

  imm0_1 = _mm_slli_epi32(imm0_1, 29);
  imm0_2 = _mm_slli_epi32(imm0_2, 29);

  COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

  imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2);
  imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2);

  imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
  imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

  COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif
  v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
  xmm1 = *(v8sf *)_ps256_minus_cephes_DP1;
  xmm2 = *(v8sf *)_ps256_minus_cephes_DP2;
  xmm3 = *(v8sf *)_ps256_minus_cephes_DP3;
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

#ifdef __AVX2__
  imm4 = avx2_mm256_sub_epi32(imm4, *(v8si *)_pi32_256_2);
  imm4 = avx2_mm256_andnot_si256(imm4, *(v8si *)_pi32_256_4);
  imm4 = avx2_mm256_slli_epi32(imm4, 29);
#else
  imm4_1 = _mm_sub_epi32(imm4_1, *(v4si *)_pi32avx_2);
  imm4_2 = _mm_sub_epi32(imm4_2, *(v4si *)_pi32avx_2);

  imm4_1 = _mm_andnot_si128(imm4_1, *(v4si *)_pi32avx_4);
  imm4_2 = _mm_andnot_si128(imm4_2, *(v4si *)_pi32avx_4);

  imm4_1 = _mm_slli_epi32(imm4_1, 29);
  imm4_2 = _mm_slli_epi32(imm4_2, 29);

  COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4);
#endif

  v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);

  sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);
  /* Evaluate the first polynomial (the cosine approximation, 0 <= x <= Pi/4) */
  v8sf z = _mm256_mul_ps(x, x);
  y = *(v8sf *)_ps256_coscof_p0;

  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1);
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2);
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
  v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5);
  y = _mm256_sub_ps(y, tmp);
  y = _mm256_add_ps(y, *(v8sf *)_ps256_1);

  /* Evaluate the second polynomial (the sine approximation, 0 <= x <= Pi/4) */

  v8sf y2 = *(v8sf *)_ps256_sincof_p0;
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynomials */
  xmm3 = poly_mask;
  v8sf ysin2 = _mm256_and_ps(xmm3, y2);
  v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
  y2 = _mm256_sub_ps(y2, ysin2);
  y = _mm256_sub_ps(y, ysin1);

  xmm1 = _mm256_add_ps(ysin1, ysin2);
  xmm2 = _mm256_add_ps(y, y2);

  /* update the sign */
  *s = _mm256_xor_ps(xmm1, sign_bit_sin);
  *c = _mm256_xor_ps(xmm2, sign_bit_cos);
}
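
/* Usage sketch for sincos256_ps (assumed caller code): one call produces
   both results, and s*s + c*c should be ~1.0f in every lane:

     v8sf s, c;
     sincos256_ps(_mm256_set1_ps(0.5f), &s, &c);
*/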