PaddlePaddle / Paddle
Commit 7aa3aff3 (unverified)
Authored on Nov 20, 2018 by tensor-tang; committed via GitHub on Nov 20, 2018

Merge pull request #14465 from tensor-tang/fea/jit/exp

jitcode act support all size

Parents: 1b894e49, a19b3225

Showing 4 changed files with 295 additions and 286 deletions (+295 −286)
paddle/fluid/operators/math/jit_code.cc        +107 −268
paddle/fluid/operators/math/jit_code.h         +178 −12
paddle/fluid/operators/math/jit_kernel.h       +1 −0
paddle/fluid/operators/math/jit_kernel_test.cc +9 −6
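The diff reworks the VXX and activation JIT emitters so that the tail elements (num_ % YMM_FLOAT_BLOCK) are handled for any input size: the hard-coded rest >= 4 / rest >= 2 / rest > 0 branches are replaced by a single loop that consumes 4, 2 or 1 floats per iteration. A minimal C++ sketch of that block-selection logic (illustrative only, not part of the patch; num and offset stand for the emitter's num_ member and running byte offset):

// Tail handling introduced by this commit: walk an arbitrary remainder 4/2/1 floats at a time.
int offset = (num / 8) * 8 * static_cast<int>(sizeof(float));  // bytes covered by the ymm main loop
int rest = num % 8;                                            // 8 == YMM_FLOAT_BLOCK
while (rest > 0) {
  int block;
  if (rest >= 4) {
    block = 4;   // 128-bit xmm load/store (vmovups)
  } else if (rest >= 2) {
    block = 2;   // 64-bit load/store (vmovq)
  } else {
    block = 1;   // 32-bit load/store (vmovss)
  }
  // ... emit the operation for `block` floats at `offset` ...
  offset += static_cast<int>(sizeof(float)) * block;
  rest -= block;
}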
paddle/fluid/operators/math/jit_code.cc
...
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/operators/math/jit_kernel.h"  // TODO(TJ): remove me

namespace paddle {
namespace operators {
...
@@ -60,257 +59,83 @@ void VXXJitCode::generate() {
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
  }
  int rest = num_ % YMM_FLOAT_BLOCK;
  if (rest >= 4) {
    if (scalar_index_ != 1) {
      vmovups(xmm_src1, ptr[param1 + offset]);
    }
    if (scalar_index_ != 2) {
      vmovups(xmm_src2, ptr[param2 + offset]);
    }
    if (type_ == operand_type::mul) {
      vmulps(xmm_dst, xmm_src1, xmm_src2);
    } else if (type_ == operand_type::add) {
      vaddps(xmm_dst, xmm_src1, xmm_src2);
    }
    if (with_relu_) {
      vmaxps(xmm_dst, xmm_zero, xmm_dst);
    }
    vmovups(ptr[param3 + offset], xmm_dst);
    offset += sizeof(float) * 4;
    rest -= 4;
  }
  if (rest >= 2) {
    if (scalar_index_ != 1) {
      vmovups(xmm_src1, ptr[param1 + offset]);
  while (rest > 0) {
    int block = XMM_FLOAT_BLOCK;
    if (rest >= 4) {
      block = 4;
      if (scalar_index_ != 1) {
        vmovups(xmm_src1, ptr[param1 + offset]);
      }
      if (scalar_index_ != 2) {
        vmovups(xmm_src2, ptr[param2 + offset]);
      }
    } else if (rest >= 2) {
      block = 2;
      if (scalar_index_ != 1) {
        vmovq(xmm_src1, ptr[param1 + offset]);
      }
      if (scalar_index_ != 2) {
        vmovq(xmm_src2, ptr[param2 + offset]);
      }
    } else {
      block = 1;
      if (scalar_index_ != 1) {
        vmovss(xmm_src1, ptr[param1 + offset]);
      }
      if (scalar_index_ != 2) {
        vmovss(xmm_src2, ptr[param2 + offset]);
      }
    }
    if (scalar_index_ != 2) {
      vmovups(xmm_src2, ptr[param2 + offset]);
    }
    if (type_ == operand_type::mul) {
      vmulps(xmm_dst, xmm_src1, xmm_src2);
    } else if (type_ == operand_type::add) {
      vaddps(xmm_dst, xmm_src1, xmm_src2);
    switch (type_) {
      case operand_type::mul:
        vmulps(xmm_dst, xmm_src1, xmm_src2);
        break;
      case operand_type::add:
        vaddps(xmm_dst, xmm_src1, xmm_src2);
        break;
      default:
        break;
    }
    if (with_relu_) {
      vmaxps(xmm_dst, xmm_zero, xmm_dst);
    }
    vmovq(ptr[param3 + offset], xmm_dst);
    offset += sizeof(float) * 2;
    rest -= 2;
  }
  if (rest > 0) {
    if (scalar_index_ != 1) {
      vmovups(xmm_src1, ptr[param1 + offset]);
    }
    if (scalar_index_ != 2) {
      vmovups(xmm_src2, ptr[param2 + offset]);
    }
    if (type_ == operand_type::mul) {
      vmulss(xmm_dst, xmm_src1, xmm_src2);
    } else if (type_ == operand_type::add) {
      vaddss(xmm_dst, xmm_src1, xmm_src2);
    if (rest >= 4) {
      vmovups(ptr[param3 + offset], xmm_dst);
    } else if (rest >= 2) {
      vmovq(ptr[param3 + offset], xmm_dst);
    } else {
      vmovss(ptr[param3 + offset], xmm_dst);
    }
    if (with_relu_) {
      vmaxps(xmm_dst, xmm_zero, xmm_dst);
    }
    vmovss(ptr[param3 + offset], xmm_dst);
    offset += sizeof(float) * block;
    rest -= block;
  }
  ret();
}
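For reference, the kernel that VXXJitCode emits corresponds to the following scalar loop (a sketch written for this note, not code from the repository; scalar_index marks which input is a broadcast scalar, matching scalar_index_ above):

// Scalar semantics of the generated VXX kernel: z = x op y, optionally followed by relu.
void vxx_ref(int n, const float* x, const float* y, float* z,
             bool is_mul, bool with_relu, int scalar_index) {
  for (int i = 0; i < n; ++i) {
    float a = (scalar_index == 1) ? x[0] : x[i];
    float b = (scalar_index == 2) ? y[0] : y[i];
    float v = is_mul ? a * b : a + b;
    z[i] = (with_relu && v < 0.f) ? 0.f : v;
  }
}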
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
static const float exp_float_consts[] ALIGN32 = {
    REPEAT_8TIMES(1.f),
    REPEAT_8TIMES(2.f),
    REPEAT_8TIMES(0.5f),
    REPEAT_8TIMES(EXP_HIG),
    REPEAT_8TIMES(EXP_LOW),
    REPEAT_8TIMES(CEPHES_LOG2EF),
    REPEAT_8TIMES(CEPHES_EXP_C1),
    REPEAT_8TIMES(CEPHES_EXP_C2),
    REPEAT_8TIMES(CEPHES_EXP_P0),
    REPEAT_8TIMES(CEPHES_EXP_P1),
    REPEAT_8TIMES(CEPHES_EXP_P2),
    REPEAT_8TIMES(CEPHES_EXP_P3),
    REPEAT_8TIMES(CEPHES_EXP_P4),
    REPEAT_8TIMES(CEPHES_EXP_P5),
    REPEAT_8TIMES(EXP_MAX_INPUT),
    REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
    REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0};
const float exp_float_consts[] ALIGN32 = {
    REPEAT_8TIMES(1.f),
    REPEAT_8TIMES(2.f),
    REPEAT_8TIMES(0.5f),
    REPEAT_8TIMES(EXP_HIG),
    REPEAT_8TIMES(EXP_LOW),
    REPEAT_8TIMES(CEPHES_LOG2EF),
    REPEAT_8TIMES(CEPHES_EXP_C1),
    REPEAT_8TIMES(CEPHES_EXP_C2),
    REPEAT_8TIMES(CEPHES_EXP_P0),
    REPEAT_8TIMES(CEPHES_EXP_P1),
    REPEAT_8TIMES(CEPHES_EXP_P2),
    REPEAT_8TIMES(CEPHES_EXP_P3),
    REPEAT_8TIMES(CEPHES_EXP_P4),
    REPEAT_8TIMES(CEPHES_EXP_P5),
    REPEAT_8TIMES(EXP_MAX_INPUT),
    REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
    REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
int g_tmp_mem[16] ALIGN32 = {0};
bool VActJitCode::init(int d, operand_type type) {
  bool ok = MayIUse(avx);
  if (type == operand_type::relu) {
    return ok;
  } else if (type == operand_type::exp) {
    // exp is slower than mkl when d >= 256
    return ok && d % 8 == 0 && d < 256;
  } else {
    // TODO(TJ): support more
    return ok && d % 8 == 0;
  }
}
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
  vmaxps(ymm_dst, ymm_zero, ymm_src);
}
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
                          int fy_idx, int mask_idx, int tmp_idx) {
  assert(ymm_src.getIdx() != ymm_dst.getIdx());  // TODO(TJ): use enfore
  // check all idx can not equal
  ymm_t ymm_fx = ymm_t(fx_idx);
  ymm_t ymm_fy = ymm_t(fy_idx);
  ymm_t ymm_mask = ymm_t(mask_idx);
  ymm_t ymm_tmp = ymm_t(tmp_idx);
  reg64_t reg_ptr_global = rax;
  push(reg_ptr_global);
  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
  vminps(ymm_src, ymm_src, ymm_tmp);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
  vmaxps(ymm_src, ymm_src, ymm_tmp);
  // express exp(x) as exp(g + n*log(2))
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
  vmulps(ymm_fx, ymm_src, ymm_tmp);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
  vaddps(ymm_fx, ymm_fx, ymm_tmp);
  vroundps(ymm_fy, ymm_fx, 0x01);
  // if greater, substract 1
  vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
  vmovaps(ymm_tmp, ptr[reg_ptr_global]);
  vandps(ymm_mask, ymm_mask, ymm_tmp);
  vsubps(ymm_fx, ymm_fy, ymm_mask);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
  vmulps(ymm_fy, ymm_fx, ymm_tmp);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
  ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
  vmulps(ymm_z, ymm_fx, ymm_tmp);
  vsubps(ymm_src, ymm_src, ymm_fy);
  vsubps(ymm_src, ymm_src, ymm_z);
  vmulps(ymm_z, ymm_src, ymm_src);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
  vmulps(ymm_dst, ymm_src, ymm_tmp);
  for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
       i += (YMM_FLOAT_BLOCK * sizeof(float))) {
    vmovaps(ymm_tmp, ptr[reg_ptr_global + i]);  // P1~P4
    vaddps(ymm_dst, ymm_dst, ymm_tmp);
    vmulps(ymm_dst, ymm_dst, ymm_src);
  }
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
  vaddps(ymm_dst, ymm_dst, ymm_tmp);
  vmulps(ymm_dst, ymm_dst, ymm_z);
  vaddps(ymm_dst, ymm_dst, ymm_src);
  vmovaps(ymm_tmp, ptr[reg_ptr_global]);
  vaddps(ymm_dst, ymm_dst, ymm_tmp);
  // build 2^n
  ymm_t ymm_int = ymm_fx;
  vcvttps2dq(ymm_int, ymm_fx);
  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
  vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
  if (MayIUse(avx2)) {
    vpaddd(ymm_int, ymm_int, ymm_tmp);
    vpslld(ymm_int, ymm_int, 23);
  } else if (MayIUse(avx)) {
    xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
    xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
    reg64_t reg_ptr_tmp = reg_ptr_global;
    mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
    vmovdqa(ptr[reg_ptr_tmp], ymm_int);
    vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
    vpaddd(xtmp1, xtmp1, xtmp2);
    vpslld(xtmp1, xtmp1, 23);
    vmovdqa(ptr[reg_ptr_tmp], xtmp1);
    // next 128bits
    vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
    vmovdqa(xtmp2,
            ptr[reg_ptr_tmp + (YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
    vpaddd(xtmp1, xtmp1, xtmp2);
    vpslld(xtmp1, xtmp1, 23);
    vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
    // load out
    vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
  }
  vmulps(ymm_dst, ymm_dst, ymm_int);
  pop(reg_ptr_global);
}
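exp_ymm implements the classic Cephes-style vectorized exponential: clamp x to [EXP_LOW, EXP_HIG], split x = g + n*ln(2) with n = round(x*log2(e)), evaluate a degree-5 polynomial in g, then multiply by 2^n built from the float exponent bits ((n + 0x7f) << 23). A scalar sketch of the same recipe, using the constants defined above (illustrative only; it mirrors the instruction sequence but is not part of the patch):

#include <algorithm>
#include <cmath>

// Scalar version of the Cephes exp approximation used by exp_ymm.
static float exp_approx(float x) {
  x = std::min(x, 88.3762626647949f);    // EXP_HIG
  x = std::max(x, -88.3762626647949f);   // EXP_LOW
  // n = round(x * log2(e)), implemented as floor(x * LOG2EF + 0.5)
  float n = std::floor(x * 1.44269504088896341f + 0.5f);
  // g = x - n * ln(2), with ln(2) split into C1 + C2 for accuracy
  float g = x - n * 0.693359375f - n * (-2.12194440e-4f);
  float z = g * g;
  float p = 1.9875691500E-4f;            // CEPHES_EXP_P0
  p = p * g + 1.3981999507E-3f;          // P1
  p = p * g + 8.3334519073E-3f;          // P2
  p = p * g + 4.1665795894E-2f;          // P3
  p = p * g + 1.6666665459E-1f;          // P4
  p = p * g + 5.0000001201E-1f;          // P5
  float y = p * z + g + 1.f;             // exp(g) ≈ 1 + g + g^2 * P(g)
  return std::ldexp(y, static_cast<int>(n));  // scale by 2^n
}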
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
                              int fy_idx, int mask_idx, int tmp_idx) {
  // y = 1 / (1 + e^-x)
  ymm_t ymm_tmp = ymm_t(tmp_idx);
  reg64_t reg_ptr_global = rax;
  push(reg_ptr_global);
  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
  vminps(ymm_src, ymm_src, ymm_tmp);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
  vmaxps(ymm_src, ymm_src, ymm_tmp);
  vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
  vsubps(ymm_src, ymm_tmp, ymm_src);
  exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
  vaddps(ymm_dst, ymm_dst, ymm_tmp);
  vdivps(ymm_dst, ymm_tmp, ymm_dst);
  pop(reg_ptr_global);
}
void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
                           int fy_idx, int mask_idx, int tmp_idx) {
  // y = 2 / (1 + e^(-2x)) - 1
  ymm_t ymm_tmp = ymm_t(tmp_idx);
  ymm_t ymm_zero = ymm_t(mask_idx);
  reg64_t reg_ptr_global = rax;
  push(reg_ptr_global);
  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
  vxorps(ymm_zero, ymm_zero, ymm_zero);
  vsubps(ymm_tmp, ymm_zero, ymm_tmp);
  vmulps(ymm_src, ymm_src, ymm_tmp);
  exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
  vaddps(ymm_dst, ymm_dst, ymm_tmp);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
  vdivps(ymm_dst, ymm_tmp, ymm_dst);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
  vsubps(ymm_dst, ymm_dst, ymm_tmp);
  pop(reg_ptr_global);
  // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
  return MayIUse(avx);
}
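Both activations reuse the exp emitter: sigmoid first clamps the input to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] and computes 1 / (1 + e^-x); tanh is rewritten as 2 / (1 + e^(-2x)) - 1 so the same exp path can be shared. The scalar equivalents (a sketch for orientation, using std::exp in place of the JIT exp):

#include <algorithm>
#include <cmath>

// Scalar references for the sigmoid/tanh recipes emitted above.
static float sigmoid_ref(float x) {
  x = std::min(x, 13.0f);    // SIGMOID_THRESHOLD_MAX
  x = std::max(x, -40.0f);   // SIGMOID_THRESHOLD_MIN
  return 1.f / (1.f + std::exp(-x));
}

static float tanh_ref(float x) {
  return 2.f / (1.f + std::exp(-2.f * x)) - 1.f;
}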
void VActJitCode::generate() {
...
@@ -324,16 +149,16 @@ void VActJitCode::generate() {
    vmovups(ymm_src, ptr[param1 + offset]);
    switch (type_) {
      case operand_type::relu:
        relu_ymm(ymm_dst, ymm_src, ymm_zero);
        relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
        break;
      case operand_type::exp:
        exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
        exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
        break;
      case operand_type::sigmoid:
        sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
        sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
        break;
      case operand_type::tanh:
        tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
        tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
        break;
      case operand_type::identity:
        break;
...
@@ -343,30 +168,44 @@ void VActJitCode::generate() {
    vmovups(ptr[param2 + offset], ymm_dst);
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
  }
  if (type_ != operand_type::relu) {
    // TODO(TJ): remove me
    ret();
    return;
  }
  int rest = num_ % YMM_FLOAT_BLOCK;
  if (rest >= 4) {
    vmovups(xmm_src, ptr[param1 + offset]);
    vmaxps(xmm_dst, xmm_zero, xmm_src);
    vmovups(ptr[param2 + offset], xmm_dst);
    offset += sizeof(float) * 4;
    rest -= 4;
  }
  if (rest >= 2) {
    vmovups(xmm_src, ptr[param1 + offset]);
    vmaxps(xmm_dst, xmm_zero, xmm_src);
    vmovq(ptr[param2 + offset], xmm_dst);
    offset += sizeof(float) * 2;
    rest -= 2;
  }
  if (rest > 0) {
    vmovups(xmm_src, ptr[param1 + offset]);
    vmaxps(xmm_dst, xmm_zero, xmm_src);
    vmovss(ptr[param2 + offset], xmm_dst);
  while (rest > 0) {
    int block = XMM_FLOAT_BLOCK;
    if (rest >= 4) {
      block = 4;
      vmovups(xmm_src, ptr[param1 + offset]);
    } else if (rest >= 2) {
      block = 2;
      vmovq(xmm_src, ptr[param1 + offset]);
    } else {
      block = 1;
      vmovss(xmm_src, ptr[param1 + offset]);
    }
    switch (type_) {
      case operand_type::relu:
        relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
        break;
      case operand_type::exp:
        exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
        break;
      case operand_type::sigmoid:
        sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
        break;
      case operand_type::tanh:
        tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
        break;
      default:
        break;
    }
    if (rest >= 4) {
      vmovups(ptr[param2 + offset], xmm_dst);
    } else if (rest >= 2) {
      vmovq(ptr[param2 + offset], xmm_dst);
    } else {
      vmovss(ptr[param2 + offset], xmm_dst);
    }
    offset += sizeof(float) * block;
    rest -= block;
  }
  ret();
}
...
paddle/fluid/operators/math/jit_code.h
...
@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace math {
...
@@ -40,6 +42,51 @@ typedef enum {
  identity
} operand_type;

extern const float exp_float_consts[];
extern const int exp_int_0x7f[];
extern int g_tmp_mem[];

// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
 public:
...
@@ -127,21 +174,140 @@ class VActJitCode : public JitCode {
  void generate() override;

 protected:
  // compute relu with ymm
  void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
                const Xbyak::Ymm& zero);
  // compute relu with ymm, xmm
  template <typename JMM>
  void relu_jmm(JMM& dst, JMM& src, JMM& zero) {  // NOLINT
    vmaxps(dst, src, zero);
  }
  // compute exp with ymm
  void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
               int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
  // compute exp with ymm, xmm
  template <typename JMM>
  void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3,  // NOLINT
               int mask_idx = 4, int tmp_idx = 5) {
    using namespace platform::jit;  // NOLINT
    assert(src.getIdx() != dst.getIdx());  // TODO(TJ): use enfore
    // check all idx can not equal
    JMM jmm_fx = JMM(fx_idx);
    JMM jmm_fy = JMM(fy_idx);
    JMM jmm_mask = JMM(mask_idx);
    JMM jmm_tmp = JMM(tmp_idx);
    reg64_t reg_ptr_global = rax;
    push(reg_ptr_global);
    mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
    vminps(src, src, jmm_tmp);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
    vmaxps(src, src, jmm_tmp);
    // express exp(x) as exp(g + n*log(2))
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
    vmulps(jmm_fx, src, jmm_tmp);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
    vaddps(jmm_fx, jmm_fx, jmm_tmp);
    vroundps(jmm_fy, jmm_fx, 0x01);
    // if greater, substract 1
    vcmpgtps(jmm_mask, jmm_fy, jmm_fx);
    vmovaps(jmm_tmp, ptr[reg_ptr_global]);
    vandps(jmm_mask, jmm_mask, jmm_tmp);
    vsubps(jmm_fx, jmm_fy, jmm_mask);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
    vmulps(jmm_fy, jmm_fx, jmm_tmp);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
    JMM ymm_z = JMM(jmm_mask.getIdx());
    vmulps(ymm_z, jmm_fx, jmm_tmp);
    vsubps(src, src, jmm_fy);
    vsubps(src, src, ymm_z);
    vmulps(ymm_z, src, src);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
    vmulps(dst, src, jmm_tmp);
    for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
         i += (YMM_FLOAT_BLOCK * sizeof(float))) {
      vmovaps(jmm_tmp, ptr[reg_ptr_global + i]);  // P1~P4
      vaddps(dst, dst, jmm_tmp);
      vmulps(dst, dst, src);
    }
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
    vaddps(dst, dst, jmm_tmp);
    vmulps(dst, dst, ymm_z);
    vaddps(dst, dst, src);
    vmovaps(jmm_tmp, ptr[reg_ptr_global]);
    vaddps(dst, dst, jmm_tmp);
    // build 2^n
    JMM ymm_int = jmm_fx;
    vcvttps2dq(ymm_int, jmm_fx);
    mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
    vmovdqa(jmm_tmp, ptr[reg_ptr_global]);
    if (MayIUse(avx2) || std::is_same<JMM, xmm_t>::value) {
      vpaddd(ymm_int, ymm_int, jmm_tmp);
      vpslld(ymm_int, ymm_int, 23);
    } else if (MayIUse(avx)) {
      xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
      xmm_t xtmp2 = xmm_t(jmm_tmp.getIdx());
      reg64_t reg_ptr_tmp = reg_ptr_global;
      mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
      vmovdqa(ptr[reg_ptr_tmp], ymm_int);
      vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], jmm_tmp);
      vpaddd(xtmp1, xtmp1, xtmp2);
      vpslld(xtmp1, xtmp1, 23);
      vmovdqa(ptr[reg_ptr_tmp], xtmp1);
      // next 128bits
      vmovdqa(xtmp1, ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)]);
      vmovdqa(xtmp2,
              ptr[reg_ptr_tmp + (YMM_FLOAT_BLOCK + XMM_FLOAT_BLOCK) * sizeof(float)]);
      vpaddd(xtmp1, xtmp1, xtmp2);
      vpslld(xtmp1, xtmp1, 23);
      vmovdqa(ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)], xtmp1);
      // load out
      vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
    }
    vmulps(dst, dst, ymm_int);
    pop(reg_ptr_global);
  }
  // compute sigmoid with ymm
  void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
                   int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
  // compute sigmoid with ymm, xmm
  template <typename JMM>
  void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2,  // NOLINT
                   int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) {
    // y = 1 / (1 + e^-x)
    JMM jmm_tmp = JMM(tmp_idx);
    reg64_t reg_ptr_global = rax;
    push(reg_ptr_global);
    mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
    vminps(src, src, jmm_tmp);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
    vmaxps(src, src, jmm_tmp);
    vxorps(jmm_tmp, jmm_tmp, jmm_tmp);
    vsubps(src, jmm_tmp, src);
    exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
    vaddps(dst, dst, jmm_tmp);
    vdivps(dst, jmm_tmp, dst);
    pop(reg_ptr_global);
  }
  // compute tanh with ymm
  void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
                int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
  // compute tanh with ymm, xmm
  template <typename JMM>
  void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3,  // NOLINT
                int mask_idx = 4, int tmp_idx = 5) {
    // y = 2 / (1 + e^(-2x)) - 1
    JMM jmm_tmp = JMM(tmp_idx);
    JMM jmm_zero = JMM(mask_idx);
    reg64_t reg_ptr_global = rax;
    push(reg_ptr_global);
    mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
    vxorps(jmm_zero, jmm_zero, jmm_zero);
    vsubps(jmm_tmp, jmm_zero, jmm_tmp);
    vmulps(src, src, jmm_tmp);
    exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
    vaddps(dst, dst, jmm_tmp);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
    vdivps(dst, jmm_tmp, dst);
    vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
    vsubps(dst, dst, jmm_tmp);
    pop(reg_ptr_global);
  }
 protected:
  int num_;
...
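The new *_jmm helpers are templated on the register type JMM, so a single emitter body serves both the 256-bit ymm main loop and the 128-bit xmm tail. The only divergence is inside exp_jmm, where the 2^n construction can use vpaddd/vpslld directly whenever AVX2 is available or JMM is xmm_t (128-bit VEX integer ops already exist under plain AVX), hence the `MayIUse(avx2) || std::is_same<JMM, xmm_t>::value` branch. The instantiations used by VActJitCode::generate() in the jit_code.cc diff above look like this:

// One template body, two instantiations (taken from the generate() hunks above):
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);   // full 8-float blocks
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);   // 4/2/1-float tail
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);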
paddle/fluid/operators/math/jit_kernel.h
...
@@ -26,6 +26,7 @@ namespace operators {
namespace math {
namespace jitkernel {

// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
...
paddle/fluid/operators/math/jit_kernel_test.cc
...
@@ -33,6 +33,9 @@ limitations under the License. */
constexpr int repeat = 20000;

// TODO(TJ): benchmark and test should be seperated,
// benchmark should verify more sizes

inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
...
@@ -66,7 +69,7 @@ void vrelu_intri8(const int n, const float* x, float* y) {
TEST(JitKernel, vrelu) {
  namespace jit = paddle::operators::math::jitkernel;
  for (int d : {7, 8, 15, 16, 30, 256, 512}) {
  for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) {
    std::vector<float> x(d);
    std::vector<float> zref(d), ztgt(d);
    RandomVec<float>(d, x.data(), -10.f, 1.f);
...
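The size lists in these tests now also include values smaller than one ymm block and other sizes that are not multiples of 8 (e.g. 1–6, 12, 20), so the new 4/2/1-float tail path gets exercised. Each case still compares the JIT result against a scalar reference, roughly along these lines (a sketch only; the actual helper names and tolerances are the ones already defined in jit_kernel_test.cc):

// Hypothetical shape of the per-size check; ComputeJit/ComputeRef stand in for the
// test file's kernel call and scalar reference helpers.
std::vector<float> ztgt(d), zref(d);
ComputeJit(d, x.data(), ztgt.data());
ComputeRef(d, x.data(), zref.data());
for (int i = 0; i < d; ++i) {
  EXPECT_NEAR(ztgt[i], zref[i], 1e-3f);
}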
@@ -156,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST(JitKernel, vexp) {
  namespace jit = paddle::operators::math::jitkernel;
  for (int d : {7, 8, 15, 16, 30, 128, 256}) {
  for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) {
    std::vector<float> x(d);
    std::vector<float> zref(d), ztgt(d);
    RandomVec<float>(d, x.data(), -2.f, 2.f);
...
@@ -231,7 +234,7 @@ void vsigmoid_better(
TEST(JitKernel, vsigmoid) {
  namespace jit = paddle::operators::math::jitkernel;
  for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
  for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
    std::vector<float> x(d);
    std::vector<float> zref(d), ztgt(d);
    RandomVec<float>(d, x.data(), -2.f, 2.f);
...
@@ -295,7 +298,7 @@ void vtanh_better(
TEST(JitKernel, vtanh) {
  namespace jit = paddle::operators::math::jitkernel;
  for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
  for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
    std::vector<float> x(d);
    std::vector<float> zref(d), ztgt(d);
    RandomVec<float>(d, x.data(), -2.f, 2.f);
...
@@ -386,7 +389,7 @@ void lstm_ctht_better(
TEST(JitKernel, lstm) {
  namespace jit = paddle::operators::math::jitkernel;
  for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) {
  for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) {
    int d4 = d * 4;
    int d3 = d * 3;
    std::vector<float> x(d4), xref(d4);
...
@@ -759,7 +762,7 @@ TEST(JitKernel, vaddrelu) {
  float* zref_data = zref.data();
  auto trefs = GetCurrentUS();
  for (int i = 0; i < repeat; ++i) {
    vadd_ref(d, x_data, y_data, zref_data);
    vaddrelu_ref(d, x_data, y_data, zref_data);
  }
  auto trefe = GetCurrentUS();
  auto tmkls = GetCurrentUS();
...