Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f65ddff8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f65ddff8
编写于
11月 15, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
unify act jitcode of relu, exp, sigmoid and tanh
上级
6a159071
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
153 addition
and
159 deletion
+153
-159
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+83
-80
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+54
-67
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+4
-3
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+12
-9
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
f65ddff8
...
@@ -118,40 +118,6 @@ void VXXJitCode::generate() {
...
@@ -118,40 +118,6 @@ void VXXJitCode::generate() {
ret
();
ret
();
}
}
bool
ReluJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
void
ReluJitCode
::
generate
()
{
int
offset
=
0
;
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
ret
();
}
#define ALIGN32 __attribute__((aligned(32)))
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define EXP_LOW -88.3762626647949f
...
@@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = {
...
@@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = {
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
bool
VExpJitCode
::
init
(
int
d
)
{
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
bool
ok
=
MayIUse
(
avx
);
if
(
type
==
operand_type
::
relu
)
{
return
ok
;
}
else
{
return
ok
&&
d
==
8
;
// only 8 yet
}
}
}
void
VExpJitCode
::
exp_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
void
VActJitCode
::
relu_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
ymm_t
&
ymm_zero
)
{
// use reg rax and ymm 2~5
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
reg64_t
reg_ptr_global
=
rax
;
}
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
void
VActJitCode
::
exp_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
ymm_t
ymm_mask
=
ymm_t
(
4
);
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
ymm_t
ymm_tmp
=
ymm_t
(
5
);
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
// check all idx can not equal
ymm_t
ymm_fx
=
ymm_t
(
fx_idx
);
ymm_t
ymm_fy
=
ymm_t
(
fy_idx
);
ymm_t
ymm_mask
=
ymm_t
(
mask_idx
);
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
...
@@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
...
@@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
void
VExpJitCode
::
generate
()
{
void
VActJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
offset
=
0
;
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
// y = 1 / (1 + e^-x)
exp_ymm
(
ymm_src
,
ymm_dst
);
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
ret
();
}
bool
VSigmoidJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
void
VSigmoidJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
// use ymm2
reg64_t
reg_ptr_global
=
rax
;
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_tmp
=
ymm_t
(
2
);
push
(
reg_ptr_global
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
...
@@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
...
@@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vxorps
(
ymm_tmp
,
ymm_tmp
,
ymm_tmp
);
vxorps
(
ymm_tmp
,
ymm_tmp
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_tmp
,
ymm_src
);
vsubps
(
ymm_src
,
ymm_tmp
,
ymm_src
);
exp_ymm
(
ymm_
src
,
ymm_dst
);
exp_ymm
(
ymm_
dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
void
VSigmoidJitCode
::
generate
()
{
void
VActJitCode
::
tanh_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
offset
=
0
;
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
sigmoid_ymm
(
ymm_src
,
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
ret
();
}
bool
VTanhJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
void
VTanhJitCode
::
vtanh_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
// y = 2 / (1 + e^(-2x)) - 1
// y = 2 / (1 + e^(-2x)) - 1
// use ymm2, ymm3
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
ymm_t
ymm_zero
=
ymm_t
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_tmp
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
push
(
reg_ptr_global
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
vsubps
(
ymm_tmp
,
ymm_zero
,
ymm_tmp
);
vsubps
(
ymm_tmp
,
ymm_zero
,
ymm_tmp
);
vmulps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmulps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
exp_ymm
(
ymm_
src
,
ymm_dst
);
exp_ymm
(
ymm_
dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
...
@@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
...
@@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
void
VTanhJitCode
::
generate
()
{
void
VActJitCode
::
generate
()
{
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vtanh_ymm
(
ymm_src
,
ymm_dst
);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_ymm
(
ymm_dst
,
ymm_src
,
ymm_zero
);
break
;
case
operand_type
::
exp
:
exp_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
if
(
type_
!=
operand_type
::
relu
)
{
// TODO(TJ): remove me
ret
();
return
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
ret
();
ret
();
}
}
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
f65ddff8
...
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
...
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
typedef
enum
{
mul
=
0
,
add
,
sub
,
relu
,
exp
,
sigmoid
,
tanh
,
identity
}
operand_type
;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
class
VXXJitCode
:
public
JitCode
{
...
@@ -85,87 +94,65 @@ class VXXJitCode : public JitCode {
...
@@ -85,87 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
3
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
};
class
Relu
JitCode
:
public
JitCode
{
class
VAct
JitCode
:
public
JitCode
{
public:
public:
DECLARE_JIT_CODE
(
ReluJitCode
);
const
char
*
name
()
const
override
{
explicit
ReluJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
std
::
string
base
=
"VActJitCode"
;
void
*
code_ptr
=
nullptr
)
switch
(
type_
)
{
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
case
operand_type
::
relu
:
static
bool
init
(
int
d
);
base
+=
"_Relu"
;
void
generate
()
override
;
break
;
case
operand_type
::
exp
:
private:
base
+=
"_Exp"
;
int
num_
;
break
;
reg64_t
param1
{
abi_param1
};
case
operand_type
::
sigmoid
:
reg64_t
param2
{
abi_param2
};
base
+=
"_Sigmoid"
;
break
;
xmm_t
xmm_zero
=
xmm_t
(
0
);
case
operand_type
::
tanh
:
xmm_t
xmm_src
=
xmm_t
(
1
);
base
+=
"_Tanh"
;
xmm_t
xmm_dst
=
xmm_t
(
1
);
break
;
case
operand_type
::
identity
:
ymm_t
ymm_zero
=
ymm_t
(
0
);
base
+=
"_Identity"
;
ymm_t
ymm_src
=
ymm_t
(
1
);
break
;
ymm_t
ymm_dst
=
ymm_t
(
1
);
default:
};
break
;
}
return
base
.
c_str
();
}
class
VExpJitCode
:
public
JitCode
{
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
public:
DECLARE_JIT_CODE
(
VExpJitCode
);
explicit
VExpJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
,
type_
(
type
)
{}
static
bool
init
(
int
d
);
static
bool
init
(
int
d
,
operand_type
type
);
void
generate
()
override
;
void
generate
()
override
;
protected:
protected:
// compute exp with ymm
// compute relu with ymm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
void
relu_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
zero
);
private:
// compute exp with ymm
int
num_
;
void
exp_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
reg64_t
param1
{
abi_param1
};
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
reg64_t
param2
{
abi_param2
};
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
VSigmoidJitCode
:
public
VExpJitCode
{
public:
DECLARE_JIT_CODE
(
VSigmoidJitCode
);
explicit
VSigmoidJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VExpJitCode
(
d
,
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
// compute sigmoid with ymm
// compute sigmoid with ymm
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
VTanhJitCode
:
public
VExpJitCode
{
public:
DECLARE_JIT_CODE
(
VTanhJitCode
);
explicit
VTanhJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VExpJitCode
(
d
,
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
// compute sigmoid with ymm
// compute tanh with ymm
void
vtanh_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
void
tanh_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
pr
ivate
:
pr
otected
:
int
num_
;
int
num_
;
operand_type
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_src
=
xmm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
};
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
f65ddff8
...
@@ -352,7 +352,8 @@ class VReluKernelImpl : public VReluKernel<T> {
...
@@ -352,7 +352,8 @@ class VReluKernelImpl : public VReluKernel<T> {
size_t
sz
=
96
/* init size */
+
size_t
sz
=
96
/* init size */
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions */
*
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
return
;
}
}
...
@@ -366,14 +367,14 @@ class VReluKernelImpl : public VReluKernel<T> {
...
@@ -366,14 +367,14 @@ class VReluKernelImpl : public VReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
private:
private:
std
::
unique_ptr
<
gen
::
Relu
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
VAct
JitCode
>
jitcode_
{
nullptr
};
#endif
#endif
};
};
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
template
<
>
template
<
>
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
ReluJitCode
::
init
(
d
);
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
relu
);
}
}
#endif
#endif
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
f65ddff8
...
@@ -116,7 +116,8 @@ class VExpKernelImpl : public VExpKernel<T> {
...
@@ -116,7 +116,8 @@ class VExpKernelImpl : public VExpKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VExpJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
exp
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
return
;
}
}
...
@@ -135,14 +136,14 @@ class VExpKernelImpl : public VExpKernel<T> {
...
@@ -135,14 +136,14 @@ class VExpKernelImpl : public VExpKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
private:
private:
std
::
unique_ptr
<
gen
::
V
Exp
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
Act
JitCode
>
jitcode_
{
nullptr
};
#endif
#endif
};
};
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
template
<
>
template
<
>
bool
VExpKernelImpl
<
float
>::
useJIT
(
int
d
)
{
bool
VExpKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
ExpJitCode
::
init
(
d
);
return
gen
::
V
ActJitCode
::
init
(
d
,
gen
::
operand_type
::
exp
);
}
}
#endif
#endif
...
@@ -169,7 +170,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -169,7 +170,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VSigmoidJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
sigmoid
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
return
;
}
}
...
@@ -190,14 +192,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -190,14 +192,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
private:
private:
std
::
unique_ptr
<
gen
::
V
Sigmoid
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
Act
JitCode
>
jitcode_
{
nullptr
};
#endif
#endif
};
};
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
template
<
>
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useJIT
(
int
d
)
{
bool
VSigmoidKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
SigmoidJitCode
::
init
(
d
);
return
gen
::
V
ActJitCode
::
init
(
d
,
gen
::
operand_type
::
sigmoi
d
);
}
}
#endif
#endif
...
@@ -223,7 +225,8 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -223,7 +225,8 @@ class VTanhKernelImpl : public VTanhKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VTanhJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
tanh
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
return
;
}
}
...
@@ -244,14 +247,14 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -244,14 +247,14 @@ class VTanhKernelImpl : public VTanhKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
private:
private:
std
::
unique_ptr
<
gen
::
V
Tanh
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
Act
JitCode
>
jitcode_
{
nullptr
};
#endif
#endif
};
};
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
template
<
>
template
<
>
bool
VTanhKernelImpl
<
float
>::
useJIT
(
int
d
)
{
bool
VTanhKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
TanhJitCode
::
init
(
d
);
return
gen
::
V
ActJitCode
::
init
(
d
,
gen
::
operand_type
::
tanh
);
}
}
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录