Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
fa78fc60
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fa78fc60
编写于
12月 18, 2018
作者:
W
wopeizl
提交者:
GitHub
12月 18, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14907 from wopeizl/windows/avx
add avx support for windows
上级
addded48
41456e17
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
777 addition
and
43 deletion
+777
-43
CMakeLists.txt
CMakeLists.txt
+0
-2
paddle/fluid/operators/math/cpu_vec.h
paddle/fluid/operators/math/cpu_vec.h
+0
-3
paddle/fluid/operators/math/detail/activation_functions.h
paddle/fluid/operators/math/detail/activation_functions.h
+1
-5
paddle/fluid/operators/math/detail/avx_functions.cc
paddle/fluid/operators/math/detail/avx_functions.cc
+1
-3
paddle/fluid/operators/math/detail/avx_mathfun.h
paddle/fluid/operators/math/detail/avx_mathfun.h
+731
-0
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+20
-19
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+0
-1
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+2
-5
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
+2
-5
paddle/fluid/platform/cpu_info.h
paddle/fluid/platform/cpu_info.h
+20
-0
未找到文件。
CMakeLists.txt
浏览文件 @
fa78fc60
...
...
@@ -131,8 +131,6 @@ if (APPLE OR WIN32)
endif
()
if
(
WIN32
)
set
(
WITH_AVX OFF CACHE STRING
"Disable AVX when compiling for Windows"
FORCE
)
set
(
WITH_DSO OFF CACHE STRING
"Disable DSO when compiling for Windows"
FORCE
)
set
(
WITH_MKL OFF CACHE STRING
...
...
paddle/fluid/operators/math/cpu_vec.h
浏览文件 @
fa78fc60
...
...
@@ -18,9 +18,6 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
...
...
paddle/fluid/operators/math/detail/activation_functions.h
浏览文件 @
fa78fc60
...
...
@@ -15,14 +15,10 @@ limitations under the License. */
#pragma once
#include <math.h>
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
...
...
paddle/fluid/operators/math/detail/avx_functions.cc
浏览文件 @
fa78fc60
...
...
@@ -14,10 +14,8 @@ limitations under the License. */
#ifdef __AVX__
#include <immintrin.h>
#include "paddle/fluid/operators/math/detail/activation_functions.h"
// TODO(qingqing) refine this dependence
#include "paddle/legacy/cuda/src/avx_mathfun.h"
#include "paddle/fluid/operators/math/detail/avx_mathfun.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/math/detail/avx_mathfun.h
0 → 100644
浏览文件 @
fa78fc60
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
AVX implementation of sin, cos, sincos, exp and log
Based on "sse_mathfun.h", by Julien Pommier
http://gruntthepeon.free.fr/ssemath/
Copyright (C) 2012 Giovanni Garberoglio
Interdisciplinary Laboratory for Computational Science (LISC)
Fondazione Bruno Kessler and University of Trento
via Sommarive, 18
I-38123 Trento (Italy)
This software is provided 'as-is', without any express or implied
warranty. In no event will the authors be held liable for any damages
arising from the use of this software.
Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:
1. The origin of this software must not be misrepresented; you must not
claim that you wrote the original software. If you use this software
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.
2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original software.
3. This notice may not be removed or altered from any source distribution.
(this is the zlib license)
*/
#include "paddle/fluid/platform/cpu_info.h"
/* __m128 is ugly to write */
typedef
__m256
v8sf
;
// vector of 8 float (avx)
typedef
__m256i
v8si
;
// vector of 8 int (avx)
typedef
__m128i
v4si
;
// vector of 8 int (avx)
#define _PI32AVX_CONST(Name, Val) \
static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = {Val, Val, \
Val, Val}
_PI32AVX_CONST
(
1
,
1
);
_PI32AVX_CONST
(
inv1
,
~
1
);
_PI32AVX_CONST
(
2
,
2
);
_PI32AVX_CONST
(
4
,
4
);
/* declare some AVX constants -- why can't I figure a better way to do that? */
#define _PS256_CONST(Name, Val) \
static const ALIGN32_BEG float _ps256_##Name[8] ALIGN32_END = { \
Val, Val, Val, Val, Val, Val, Val, Val}
#define _PI32_CONST256(Name, Val) \
static const ALIGN32_BEG int _pi32_256_##Name[8] ALIGN32_END = { \
Val, Val, Val, Val, Val, Val, Val, Val}
#define _PS256_CONST_TYPE(Name, Type, Val) \
static const ALIGN32_BEG Type _ps256_##Name[8] ALIGN32_END = { \
Val, Val, Val, Val, Val, Val, Val, Val}
_PS256_CONST
(
1
,
1
.
0
f
);
_PS256_CONST
(
0
p5
,
0
.
5
f
);
/* the smallest non denormalized float number */
_PS256_CONST_TYPE
(
min_norm_pos
,
int
,
0x00800000
);
_PS256_CONST_TYPE
(
mant_mask
,
int
,
0x7f800000
);
_PS256_CONST_TYPE
(
inv_mant_mask
,
int
,
~
0x7f800000
);
_PS256_CONST_TYPE
(
sign_mask
,
int
,
(
int
)
0x80000000
);
_PS256_CONST_TYPE
(
inv_sign_mask
,
int
,
~
0x80000000
);
_PI32_CONST256
(
0
,
0
);
_PI32_CONST256
(
1
,
1
);
_PI32_CONST256
(
inv1
,
~
1
);
_PI32_CONST256
(
2
,
2
);
_PI32_CONST256
(
4
,
4
);
_PI32_CONST256
(
0x7f
,
0x7f
);
_PS256_CONST
(
cephes_SQRTHF
,
0
.
707106781186547524
);
_PS256_CONST
(
cephes_log_p0
,
7.0376836292E-2
);
_PS256_CONST
(
cephes_log_p1
,
-
1.1514610310E-1
);
_PS256_CONST
(
cephes_log_p2
,
1.1676998740E-1
);
_PS256_CONST
(
cephes_log_p3
,
-
1.2420140846E-1
);
_PS256_CONST
(
cephes_log_p4
,
+
1.4249322787E-1
);
_PS256_CONST
(
cephes_log_p5
,
-
1.6668057665E-1
);
_PS256_CONST
(
cephes_log_p6
,
+
2.0000714765E-1
);
_PS256_CONST
(
cephes_log_p7
,
-
2.4999993993E-1
);
_PS256_CONST
(
cephes_log_p8
,
+
3.3333331174E-1
);
_PS256_CONST
(
cephes_log_q1
,
-
2.12194440e-4
);
_PS256_CONST
(
cephes_log_q2
,
0
.
693359375
);
#ifndef __AVX2__
typedef
union
imm_xmm_union
{
v8si
imm
;
v4si
xmm
[
2
];
}
imm_xmm_union
;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union ALIGN32_BEG u ALIGN32_END; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union ALIGN32_BEG u ALIGN32_END; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline v8si avx2_mm256_##fn(v8si x, int a) { \
/* use SSE2 instruction to perform the bitop AVX2 */
\
v4si x1, x2; \
v8si ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, a); \
x2 = _mm_##fn(x2, a); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return (ret); \
}
//#warning "Using SSE2 to perform AVX2 bitshift ops"
AVX2_BITOP_USING_SSE2
(
slli_epi32
)
AVX2_BITOP_USING_SSE2
(
srli_epi32
)
#define AVX2_INTOP_USING_SSE2(fn) \
static inline v8si avx2_mm256_##fn(v8si x, v8si y) { \
/* use SSE2 instructions to perform the AVX2 integer operation */
\
v4si x1, x2; \
v4si y1, y2; \
v8si ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return (ret); \
}
//#warning "Using SSE2 to perform AVX2 integer ops"
AVX2_INTOP_USING_SSE2
(
and_si128
)
AVX2_INTOP_USING_SSE2
(
andnot_si128
)
AVX2_INTOP_USING_SSE2
(
cmpeq_epi32
)
AVX2_INTOP_USING_SSE2
(
sub_epi32
)
AVX2_INTOP_USING_SSE2
(
add_epi32
)
#define avx2_mm256_and_si256 avx2_mm256_and_si128
#define avx2_mm256_andnot_si256 avx2_mm256_andnot_si128
#else
#define avx2_mm256_slli_epi32 _mm256_slli_epi32
#define avx2_mm256_srli_epi32 _mm256_srli_epi32
#define avx2_mm256_and_si256 _mm256_and_si256
#define avx2_mm256_andnot_si256 _mm256_andnot_si256
#define avx2_mm256_cmpeq_epi32 _mm256_cmpeq_epi32
#define avx2_mm256_sub_epi32 _mm256_sub_epi32
#define avx2_mm256_add_epi32 _mm256_add_epi32
#endif
/* __AVX2__ */
/* natural logarithm computed for 8 simultaneous float
return NaN for x <= 0
*/
v8sf
log256_ps
(
v8sf
x
)
{
v8si
imm0
;
v8sf
one
=
*
(
v8sf
*
)
_ps256_1
;
// v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
v8sf
invalid_mask
=
_mm256_cmp_ps
(
x
,
_mm256_setzero_ps
(),
_CMP_LE_OS
);
x
=
_mm256_max_ps
(
x
,
*
(
v8sf
*
)
_ps256_min_norm_pos
);
/* cut off denormalized stuff */
// can be done with AVX2
imm0
=
avx2_mm256_srli_epi32
(
_mm256_castps_si256
(
x
),
23
);
/* keep only the fractional part */
x
=
_mm256_and_ps
(
x
,
*
(
v8sf
*
)
_ps256_inv_mant_mask
);
x
=
_mm256_or_ps
(
x
,
*
(
v8sf
*
)
_ps256_0p5
);
// this is again another AVX2 instruction
imm0
=
avx2_mm256_sub_epi32
(
imm0
,
*
(
v8si
*
)
_pi32_256_0x7f
);
v8sf
e
=
_mm256_cvtepi32_ps
(
imm0
);
e
=
_mm256_add_ps
(
e
,
one
);
/* part2:
if( x < SQRTHF ) {
e -= 1;
x = x + x - 1.0;
} else { x = x - 1.0; }
*/
// v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
v8sf
mask
=
_mm256_cmp_ps
(
x
,
*
(
v8sf
*
)
_ps256_cephes_SQRTHF
,
_CMP_LT_OS
);
v8sf
tmp
=
_mm256_and_ps
(
x
,
mask
);
x
=
_mm256_sub_ps
(
x
,
one
);
e
=
_mm256_sub_ps
(
e
,
_mm256_and_ps
(
one
,
mask
));
x
=
_mm256_add_ps
(
x
,
tmp
);
v8sf
z
=
_mm256_mul_ps
(
x
,
x
);
v8sf
y
=
*
(
v8sf
*
)
_ps256_cephes_log_p0
;
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p1
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p2
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p3
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p4
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p5
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p6
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p7
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_log_p8
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_mul_ps
(
y
,
z
);
tmp
=
_mm256_mul_ps
(
e
,
*
(
v8sf
*
)
_ps256_cephes_log_q1
);
y
=
_mm256_add_ps
(
y
,
tmp
);
tmp
=
_mm256_mul_ps
(
z
,
*
(
v8sf
*
)
_ps256_0p5
);
y
=
_mm256_sub_ps
(
y
,
tmp
);
tmp
=
_mm256_mul_ps
(
e
,
*
(
v8sf
*
)
_ps256_cephes_log_q2
);
x
=
_mm256_add_ps
(
x
,
y
);
x
=
_mm256_add_ps
(
x
,
tmp
);
x
=
_mm256_or_ps
(
x
,
invalid_mask
);
// negative arg will be NAN
return
x
;
}
_PS256_CONST
(
exp_hi
,
88
.
3762626647949
f
);
_PS256_CONST
(
exp_lo
,
-
88
.
3762626647949
f
);
_PS256_CONST
(
cephes_LOG2EF
,
1
.
44269504088896341
);
_PS256_CONST
(
cephes_exp_C1
,
0
.
693359375
);
_PS256_CONST
(
cephes_exp_C2
,
-
2.12194440e-4
);
_PS256_CONST
(
cephes_exp_p0
,
1.9875691500E-4
);
_PS256_CONST
(
cephes_exp_p1
,
1.3981999507E-3
);
_PS256_CONST
(
cephes_exp_p2
,
8.3334519073E-3
);
_PS256_CONST
(
cephes_exp_p3
,
4.1665795894E-2
);
_PS256_CONST
(
cephes_exp_p4
,
1.6666665459E-1
);
_PS256_CONST
(
cephes_exp_p5
,
5.0000001201E-1
);
v8sf
exp256_ps
(
v8sf
x
)
{
v8sf
tmp
=
_mm256_setzero_ps
(),
fx
;
v8si
imm0
;
v8sf
one
=
*
(
v8sf
*
)
_ps256_1
;
x
=
_mm256_min_ps
(
x
,
*
(
v8sf
*
)
_ps256_exp_hi
);
x
=
_mm256_max_ps
(
x
,
*
(
v8sf
*
)
_ps256_exp_lo
);
/* express exp(x) as exp(g + n*log(2)) */
fx
=
_mm256_mul_ps
(
x
,
*
(
v8sf
*
)
_ps256_cephes_LOG2EF
);
fx
=
_mm256_add_ps
(
fx
,
*
(
v8sf
*
)
_ps256_0p5
);
/* how to perform a floorf with SSE: just below */
// imm0 = _mm256_cvttps_epi32(fx);
// tmp = _mm256_cvtepi32_ps(imm0);
tmp
=
_mm256_floor_ps
(
fx
);
/* if greater, substract 1 */
// v8sf mask = _mm256_cmpgt_ps(tmp, fx);
v8sf
mask
=
_mm256_cmp_ps
(
tmp
,
fx
,
_CMP_GT_OS
);
mask
=
_mm256_and_ps
(
mask
,
one
);
fx
=
_mm256_sub_ps
(
tmp
,
mask
);
tmp
=
_mm256_mul_ps
(
fx
,
*
(
v8sf
*
)
_ps256_cephes_exp_C1
);
v8sf
z
=
_mm256_mul_ps
(
fx
,
*
(
v8sf
*
)
_ps256_cephes_exp_C2
);
x
=
_mm256_sub_ps
(
x
,
tmp
);
x
=
_mm256_sub_ps
(
x
,
z
);
z
=
_mm256_mul_ps
(
x
,
x
);
v8sf
y
=
*
(
v8sf
*
)
_ps256_cephes_exp_p0
;
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_exp_p1
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_exp_p2
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_exp_p3
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_exp_p4
);
y
=
_mm256_mul_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_cephes_exp_p5
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_add_ps
(
y
,
x
);
y
=
_mm256_add_ps
(
y
,
one
);
/* build 2^n */
imm0
=
_mm256_cvttps_epi32
(
fx
);
// another two AVX2 instructions
imm0
=
avx2_mm256_add_epi32
(
imm0
,
*
(
v8si
*
)
_pi32_256_0x7f
);
imm0
=
avx2_mm256_slli_epi32
(
imm0
,
23
);
v8sf
pow2n
=
_mm256_castsi256_ps
(
imm0
);
y
=
_mm256_mul_ps
(
y
,
pow2n
);
return
y
;
}
_PS256_CONST
(
minus_cephes_DP1
,
-
0
.
78515625
);
_PS256_CONST
(
minus_cephes_DP2
,
-
2.4187564849853515625e-4
);
_PS256_CONST
(
minus_cephes_DP3
,
-
3.77489497744594108e-8
);
_PS256_CONST
(
sincof_p0
,
-
1.9515295891E-4
);
_PS256_CONST
(
sincof_p1
,
8.3321608736E-3
);
_PS256_CONST
(
sincof_p2
,
-
1.6666654611E-1
);
_PS256_CONST
(
coscof_p0
,
2.443315711809948E-005
);
_PS256_CONST
(
coscof_p1
,
-
1.388731625493765E-003
);
_PS256_CONST
(
coscof_p2
,
4.166664568298827E-002
);
_PS256_CONST
(
cephes_FOPI
,
1
.
27323954473516
);
// 4 / M_PI
/* evaluation of 8 sines at onces using AVX intrisics
The code is the exact rewriting of the cephes sinf function.
Precision is excellent as long as x < 8192 (I did not bother to
take into account the special handling they have for greater values
-- it does not return garbage for arguments over 8192, though, but
the extra precision is missing).
Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
surprising but correct result.
*/
v8sf
sin256_ps
(
v8sf
x
)
{
// any x
v8sf
xmm1
,
xmm2
=
_mm256_setzero_ps
(),
xmm3
,
sign_bit
,
y
;
v8si
imm0
,
imm2
;
#ifndef __AVX2__
v4si
imm0_1
,
imm0_2
;
v4si
imm2_1
,
imm2_2
;
#endif
sign_bit
=
x
;
/* take the absolute value */
x
=
_mm256_and_ps
(
x
,
*
(
v8sf
*
)
_ps256_inv_sign_mask
);
/* extract the sign bit (upper one) */
sign_bit
=
_mm256_and_ps
(
sign_bit
,
*
(
v8sf
*
)
_ps256_sign_mask
);
/* scale by 4/Pi */
y
=
_mm256_mul_ps
(
x
,
*
(
v8sf
*
)
_ps256_cephes_FOPI
);
/*
Here we start a series of integer operations, which are in the
realm of AVX2.
If we don't have AVX, let's perform them using SSE2 directives
*/
#ifdef __AVX2__
/* store the integer part of y in mm0 */
imm2
=
_mm256_cvttps_epi32
(
y
);
/* j=(j+1) & (~1) (see the cephes sources) */
// another two AVX2 instruction
imm2
=
avx2_mm256_add_epi32
(
imm2
,
*
(
v8si
*
)
_pi32_256_1
);
imm2
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_inv1
);
y
=
_mm256_cvtepi32_ps
(
imm2
);
/* get the swap sign flag */
imm0
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_4
);
imm0
=
avx2_mm256_slli_epi32
(
imm0
,
29
);
/* get the polynom selection mask
there is one polynom for 0 <= x <= Pi/4
and another one for Pi/4<x<=Pi/2
Both branches will be computed.
*/
imm2
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_2
);
imm2
=
avx2_mm256_cmpeq_epi32
(
imm2
,
*
(
v8si
*
)
_pi32_256_0
);
#else
/* we use SSE2 routines to perform the integer ops */
COPY_IMM_TO_XMM
(
_mm256_cvttps_epi32
(
y
),
imm2_1
,
imm2_2
);
imm2_1
=
_mm_add_epi32
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_1
);
imm2_2
=
_mm_add_epi32
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_1
);
imm2_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_inv1
);
imm2_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_inv1
);
COPY_XMM_TO_IMM
(
imm2_1
,
imm2_2
,
imm2
);
y
=
_mm256_cvtepi32_ps
(
imm2
);
imm0_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_4
);
imm0_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_4
);
imm0_1
=
_mm_slli_epi32
(
imm0_1
,
29
);
imm0_2
=
_mm_slli_epi32
(
imm0_2
,
29
);
COPY_XMM_TO_IMM
(
imm0_1
,
imm0_2
,
imm0
);
imm2_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_2
);
imm2_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_2
);
imm2_1
=
_mm_cmpeq_epi32
(
imm2_1
,
_mm_setzero_si128
());
imm2_2
=
_mm_cmpeq_epi32
(
imm2_2
,
_mm_setzero_si128
());
COPY_XMM_TO_IMM
(
imm2_1
,
imm2_2
,
imm2
);
#endif
v8sf
swap_sign_bit
=
_mm256_castsi256_ps
(
imm0
);
v8sf
poly_mask
=
_mm256_castsi256_ps
(
imm2
);
sign_bit
=
_mm256_xor_ps
(
sign_bit
,
swap_sign_bit
);
/* The magic pass: "Extended precision modular arithmetic"
x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP1
;
xmm2
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP2
;
xmm3
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP3
;
xmm1
=
_mm256_mul_ps
(
y
,
xmm1
);
xmm2
=
_mm256_mul_ps
(
y
,
xmm2
);
xmm3
=
_mm256_mul_ps
(
y
,
xmm3
);
x
=
_mm256_add_ps
(
x
,
xmm1
);
x
=
_mm256_add_ps
(
x
,
xmm2
);
x
=
_mm256_add_ps
(
x
,
xmm3
);
/* Evaluate the first polynom (0 <= x <= Pi/4) */
y
=
*
(
v8sf
*
)
_ps256_coscof_p0
;
v8sf
z
=
_mm256_mul_ps
(
x
,
x
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_coscof_p1
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_coscof_p2
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_mul_ps
(
y
,
z
);
v8sf
tmp
=
_mm256_mul_ps
(
z
,
*
(
v8sf
*
)
_ps256_0p5
);
y
=
_mm256_sub_ps
(
y
,
tmp
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_1
);
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
v8sf
y2
=
*
(
v8sf
*
)
_ps256_sincof_p0
;
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_add_ps
(
y2
,
*
(
v8sf
*
)
_ps256_sincof_p1
);
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_add_ps
(
y2
,
*
(
v8sf
*
)
_ps256_sincof_p2
);
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_mul_ps
(
y2
,
x
);
y2
=
_mm256_add_ps
(
y2
,
x
);
/* select the correct result from the two polynoms */
xmm3
=
poly_mask
;
y2
=
_mm256_and_ps
(
xmm3
,
y2
);
//, xmm3);
y
=
_mm256_andnot_ps
(
xmm3
,
y
);
y
=
_mm256_add_ps
(
y
,
y2
);
/* update the sign */
y
=
_mm256_xor_ps
(
y
,
sign_bit
);
return
y
;
}
/* almost the same as sin_ps */
v8sf
cos256_ps
(
v8sf
x
)
{
// any x
v8sf
xmm1
,
xmm2
=
_mm256_setzero_ps
(),
xmm3
,
y
;
v8si
imm0
,
imm2
;
#ifndef __AVX2__
v4si
imm0_1
,
imm0_2
;
v4si
imm2_1
,
imm2_2
;
#endif
/* take the absolute value */
x
=
_mm256_and_ps
(
x
,
*
(
v8sf
*
)
_ps256_inv_sign_mask
);
/* scale by 4/Pi */
y
=
_mm256_mul_ps
(
x
,
*
(
v8sf
*
)
_ps256_cephes_FOPI
);
#ifdef __AVX2__
/* store the integer part of y in mm0 */
imm2
=
_mm256_cvttps_epi32
(
y
);
/* j=(j+1) & (~1) (see the cephes sources) */
imm2
=
avx2_mm256_add_epi32
(
imm2
,
*
(
v8si
*
)
_pi32_256_1
);
imm2
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_inv1
);
y
=
_mm256_cvtepi32_ps
(
imm2
);
imm2
=
avx2_mm256_sub_epi32
(
imm2
,
*
(
v8si
*
)
_pi32_256_2
);
/* get the swap sign flag */
imm0
=
avx2_mm256_andnot_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_4
);
imm0
=
avx2_mm256_slli_epi32
(
imm0
,
29
);
/* get the polynom selection mask */
imm2
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_2
);
imm2
=
avx2_mm256_cmpeq_epi32
(
imm2
,
*
(
v8si
*
)
_pi32_256_0
);
#else
/* we use SSE2 routines to perform the integer ops */
COPY_IMM_TO_XMM
(
_mm256_cvttps_epi32
(
y
),
imm2_1
,
imm2_2
);
imm2_1
=
_mm_add_epi32
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_1
);
imm2_2
=
_mm_add_epi32
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_1
);
imm2_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_inv1
);
imm2_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_inv1
);
COPY_XMM_TO_IMM
(
imm2_1
,
imm2_2
,
imm2
);
y
=
_mm256_cvtepi32_ps
(
imm2
);
imm2_1
=
_mm_sub_epi32
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_2
);
imm2_2
=
_mm_sub_epi32
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_2
);
imm0_1
=
_mm_andnot_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_4
);
imm0_2
=
_mm_andnot_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_4
);
imm0_1
=
_mm_slli_epi32
(
imm0_1
,
29
);
imm0_2
=
_mm_slli_epi32
(
imm0_2
,
29
);
COPY_XMM_TO_IMM
(
imm0_1
,
imm0_2
,
imm0
);
imm2_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_2
);
imm2_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_2
);
imm2_1
=
_mm_cmpeq_epi32
(
imm2_1
,
_mm_setzero_si128
());
imm2_2
=
_mm_cmpeq_epi32
(
imm2_2
,
_mm_setzero_si128
());
COPY_XMM_TO_IMM
(
imm2_1
,
imm2_2
,
imm2
);
#endif
v8sf
sign_bit
=
_mm256_castsi256_ps
(
imm0
);
v8sf
poly_mask
=
_mm256_castsi256_ps
(
imm2
);
/* The magic pass: "Extended precision modular arithmetic"
x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP1
;
xmm2
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP2
;
xmm3
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP3
;
xmm1
=
_mm256_mul_ps
(
y
,
xmm1
);
xmm2
=
_mm256_mul_ps
(
y
,
xmm2
);
xmm3
=
_mm256_mul_ps
(
y
,
xmm3
);
x
=
_mm256_add_ps
(
x
,
xmm1
);
x
=
_mm256_add_ps
(
x
,
xmm2
);
x
=
_mm256_add_ps
(
x
,
xmm3
);
/* Evaluate the first polynom (0 <= x <= Pi/4) */
y
=
*
(
v8sf
*
)
_ps256_coscof_p0
;
v8sf
z
=
_mm256_mul_ps
(
x
,
x
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_coscof_p1
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_coscof_p2
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_mul_ps
(
y
,
z
);
v8sf
tmp
=
_mm256_mul_ps
(
z
,
*
(
v8sf
*
)
_ps256_0p5
);
y
=
_mm256_sub_ps
(
y
,
tmp
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_1
);
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
v8sf
y2
=
*
(
v8sf
*
)
_ps256_sincof_p0
;
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_add_ps
(
y2
,
*
(
v8sf
*
)
_ps256_sincof_p1
);
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_add_ps
(
y2
,
*
(
v8sf
*
)
_ps256_sincof_p2
);
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_mul_ps
(
y2
,
x
);
y2
=
_mm256_add_ps
(
y2
,
x
);
/* select the correct result from the two polynoms */
xmm3
=
poly_mask
;
y2
=
_mm256_and_ps
(
xmm3
,
y2
);
//, xmm3);
y
=
_mm256_andnot_ps
(
xmm3
,
y
);
y
=
_mm256_add_ps
(
y
,
y2
);
/* update the sign */
y
=
_mm256_xor_ps
(
y
,
sign_bit
);
return
y
;
}
/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could
replace both of them..
it is almost as fast, and gives you a free cosine with your sine */
void
sincos256_ps
(
v8sf
x
,
v8sf
*
s
,
v8sf
*
c
)
{
v8sf
xmm1
,
xmm2
,
xmm3
=
_mm256_setzero_ps
(),
sign_bit_sin
,
y
;
v8si
imm0
,
imm2
,
imm4
;
#ifndef __AVX2__
v4si
imm0_1
,
imm0_2
;
v4si
imm2_1
,
imm2_2
;
v4si
imm4_1
,
imm4_2
;
#endif
sign_bit_sin
=
x
;
/* take the absolute value */
x
=
_mm256_and_ps
(
x
,
*
(
v8sf
*
)
_ps256_inv_sign_mask
);
/* extract the sign bit (upper one) */
sign_bit_sin
=
_mm256_and_ps
(
sign_bit_sin
,
*
(
v8sf
*
)
_ps256_sign_mask
);
/* scale by 4/Pi */
y
=
_mm256_mul_ps
(
x
,
*
(
v8sf
*
)
_ps256_cephes_FOPI
);
#ifdef __AVX2__
/* store the integer part of y in imm2 */
imm2
=
_mm256_cvttps_epi32
(
y
);
/* j=(j+1) & (~1) (see the cephes sources) */
imm2
=
avx2_mm256_add_epi32
(
imm2
,
*
(
v8si
*
)
_pi32_256_1
);
imm2
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_inv1
);
y
=
_mm256_cvtepi32_ps
(
imm2
);
imm4
=
imm2
;
/* get the swap sign flag for the sine */
imm0
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_4
);
imm0
=
avx2_mm256_slli_epi32
(
imm0
,
29
);
// v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
/* get the polynom selection mask for the sine*/
imm2
=
avx2_mm256_and_si256
(
imm2
,
*
(
v8si
*
)
_pi32_256_2
);
imm2
=
avx2_mm256_cmpeq_epi32
(
imm2
,
*
(
v8si
*
)
_pi32_256_0
);
// v8sf poly_mask = _mm256_castsi256_ps(imm2);
#else
/* we use SSE2 routines to perform the integer ops */
COPY_IMM_TO_XMM
(
_mm256_cvttps_epi32
(
y
),
imm2_1
,
imm2_2
);
imm2_1
=
_mm_add_epi32
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_1
);
imm2_2
=
_mm_add_epi32
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_1
);
imm2_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_inv1
);
imm2_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_inv1
);
COPY_XMM_TO_IMM
(
imm2_1
,
imm2_2
,
imm2
);
y
=
_mm256_cvtepi32_ps
(
imm2
);
imm4_1
=
imm2_1
;
imm4_2
=
imm2_2
;
imm0_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_4
);
imm0_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_4
);
imm0_1
=
_mm_slli_epi32
(
imm0_1
,
29
);
imm0_2
=
_mm_slli_epi32
(
imm0_2
,
29
);
COPY_XMM_TO_IMM
(
imm0_1
,
imm0_2
,
imm0
);
imm2_1
=
_mm_and_si128
(
imm2_1
,
*
(
v4si
*
)
_pi32avx_2
);
imm2_2
=
_mm_and_si128
(
imm2_2
,
*
(
v4si
*
)
_pi32avx_2
);
imm2_1
=
_mm_cmpeq_epi32
(
imm2_1
,
_mm_setzero_si128
());
imm2_2
=
_mm_cmpeq_epi32
(
imm2_2
,
_mm_setzero_si128
());
COPY_XMM_TO_IMM
(
imm2_1
,
imm2_2
,
imm2
);
#endif
v8sf
swap_sign_bit_sin
=
_mm256_castsi256_ps
(
imm0
);
v8sf
poly_mask
=
_mm256_castsi256_ps
(
imm2
);
/* The magic pass: "Extended precision modular arithmetic"
x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP1
;
xmm2
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP2
;
xmm3
=
*
(
v8sf
*
)
_ps256_minus_cephes_DP3
;
xmm1
=
_mm256_mul_ps
(
y
,
xmm1
);
xmm2
=
_mm256_mul_ps
(
y
,
xmm2
);
xmm3
=
_mm256_mul_ps
(
y
,
xmm3
);
x
=
_mm256_add_ps
(
x
,
xmm1
);
x
=
_mm256_add_ps
(
x
,
xmm2
);
x
=
_mm256_add_ps
(
x
,
xmm3
);
#ifdef __AVX2__
imm4
=
avx2_mm256_sub_epi32
(
imm4
,
*
(
v8si
*
)
_pi32_256_2
);
imm4
=
avx2_mm256_andnot_si256
(
imm4
,
*
(
v8si
*
)
_pi32_256_4
);
imm4
=
avx2_mm256_slli_epi32
(
imm4
,
29
);
#else
imm4_1
=
_mm_sub_epi32
(
imm4_1
,
*
(
v4si
*
)
_pi32avx_2
);
imm4_2
=
_mm_sub_epi32
(
imm4_2
,
*
(
v4si
*
)
_pi32avx_2
);
imm4_1
=
_mm_andnot_si128
(
imm4_1
,
*
(
v4si
*
)
_pi32avx_4
);
imm4_2
=
_mm_andnot_si128
(
imm4_2
,
*
(
v4si
*
)
_pi32avx_4
);
imm4_1
=
_mm_slli_epi32
(
imm4_1
,
29
);
imm4_2
=
_mm_slli_epi32
(
imm4_2
,
29
);
COPY_XMM_TO_IMM
(
imm4_1
,
imm4_2
,
imm4
);
#endif
v8sf
sign_bit_cos
=
_mm256_castsi256_ps
(
imm4
);
sign_bit_sin
=
_mm256_xor_ps
(
sign_bit_sin
,
swap_sign_bit_sin
);
/* Evaluate the first polynom (0 <= x <= Pi/4) */
v8sf
z
=
_mm256_mul_ps
(
x
,
x
);
y
=
*
(
v8sf
*
)
_ps256_coscof_p0
;
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_coscof_p1
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_coscof_p2
);
y
=
_mm256_mul_ps
(
y
,
z
);
y
=
_mm256_mul_ps
(
y
,
z
);
v8sf
tmp
=
_mm256_mul_ps
(
z
,
*
(
v8sf
*
)
_ps256_0p5
);
y
=
_mm256_sub_ps
(
y
,
tmp
);
y
=
_mm256_add_ps
(
y
,
*
(
v8sf
*
)
_ps256_1
);
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
v8sf
y2
=
*
(
v8sf
*
)
_ps256_sincof_p0
;
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_add_ps
(
y2
,
*
(
v8sf
*
)
_ps256_sincof_p1
);
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_add_ps
(
y2
,
*
(
v8sf
*
)
_ps256_sincof_p2
);
y2
=
_mm256_mul_ps
(
y2
,
z
);
y2
=
_mm256_mul_ps
(
y2
,
x
);
y2
=
_mm256_add_ps
(
y2
,
x
);
/* select the correct result from the two polynoms */
xmm3
=
poly_mask
;
v8sf
ysin2
=
_mm256_and_ps
(
xmm3
,
y2
);
v8sf
ysin1
=
_mm256_andnot_ps
(
xmm3
,
y
);
y2
=
_mm256_sub_ps
(
y2
,
ysin2
);
y
=
_mm256_sub_ps
(
y
,
ysin1
);
xmm1
=
_mm256_add_ps
(
ysin1
,
ysin2
);
xmm2
=
_mm256_add_ps
(
y
,
y2
);
/* update the sign */
*
s
=
_mm256_xor_ps
(
xmm1
,
sign_bit_sin
);
*
c
=
_mm256_xor_ps
(
xmm2
,
sign_bit_cos
);
}
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
fa78fc60
...
...
@@ -113,26 +113,27 @@ void VXXJitCode::generate() {
ret
();
}
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
const
float
ALIGN32_BEG
exp_float_consts
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
const
int
ALIGN32_BEG
exp_int_0x7f
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
0x7f
)};
int
ALIGN32_BEG
g_tmp_mem
[
16
]
ALIGN32_END
=
{
0
};
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
fa78fc60
...
...
@@ -47,7 +47,6 @@ extern const float exp_float_consts[];
extern
const
int
exp_int_0x7f
[];
extern
int
g_tmp_mem
[];
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
...
...
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
fa78fc60
...
...
@@ -16,9 +16,6 @@ limitations under the License. */
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -133,8 +130,8 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
/* AVX instructions.*/
\
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); \
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); \
__m128i lo_mask = _mm256_extractf128_si256(
(__m256i)mask, 0);
\
__m128i hi_mask = _mm256_extractf128_si256(
(__m256i)mask, 1);
\
__m128i lo_mask = _mm256_extractf128_si256(
*(__m256i*)&mask, 0);
\
__m128i hi_mask = _mm256_extractf128_si256(
*(__m256i*)&mask, 1);
\
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); \
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); \
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); \
...
...
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
浏览文件 @
fa78fc60
...
...
@@ -13,9 +13,6 @@ limitations under the License. */
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -121,7 +118,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)x + j); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
(__m256)mask_vec);
\
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec);
\
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
...
...
@@ -145,7 +142,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
j = offset + this->num_ - block; \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_mul_ps(tmp, tmp); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
(__m256)mask_vec);
\
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec);
\
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
...
...
paddle/fluid/platform/cpu_info.h
浏览文件 @
fa78fc60
...
...
@@ -16,6 +16,26 @@ limitations under the License. */
#include <stddef.h>
#ifdef _WIN32
#if defined(__AVX2__)
#include <immintrin.h> //avx2
#elif defined(__AVX__)
#include <intrin.h> //avx
#endif // AVX
#else // WIN32
#ifdef __AVX__
#include <immintrin.h>
#endif
#endif // WIN32
#if defined(_WIN32)
#define ALIGN32_BEG __declspec(align(32))
#define ALIGN32_END
#else
#define ALIGN32_BEG
#define ALIGN32_END __attribute__((aligned(32)))
#endif // _WIN32
namespace
paddle
{
namespace
platform
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录