Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f65ddff8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f65ddff8
编写于
11月 15, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
unify act jitcode of relu, exp, sigmoid and tanh
上级
6a159071
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
153 addition
and
159 deletion
+153
-159
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+83
-80
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+54
-67
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+4
-3
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+12
-9
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
f65ddff8
...
...
@@ -118,40 +118,6 @@ void VXXJitCode::generate() {
ret
();
}
bool
ReluJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
void
ReluJitCode
::
generate
()
{
int
offset
=
0
;
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
ret
();
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
...
...
@@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = {
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
bool
VExpJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
bool
ok
=
MayIUse
(
avx
);
if
(
type
==
operand_type
::
relu
)
{
return
ok
;
}
else
{
return
ok
&&
d
==
8
;
// only 8 yet
}
}
void
VExpJitCode
::
exp_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
// use reg rax and ymm 2~5
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
void
VActJitCode
::
relu_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
ymm_t
&
ymm_zero
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
}
void
VActJitCode
::
exp_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
// check all idx can not equal
ymm_t
ymm_fx
=
ymm_t
(
fx_idx
);
ymm_t
ymm_fy
=
ymm_t
(
fy_idx
);
ymm_t
ymm_mask
=
ymm_t
(
mask_idx
);
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
...
...
@@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
pop
(
reg_ptr_global
);
}
void
VExpJitCode
::
generate
()
{
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
exp_ymm
(
ymm_src
,
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
ret
();
}
bool
VSigmoidJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
void
VSigmoidJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
// use ymm2
void
VActJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 1 / (1 + e^-x)
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_tmp
=
ymm_t
(
2
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
...
...
@@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vxorps
(
ymm_tmp
,
ymm_tmp
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_tmp
,
ymm_src
);
exp_ymm
(
ymm_
src
,
ymm_dst
);
exp_ymm
(
ymm_
dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
pop
(
reg_ptr_global
);
}
void
VSigmoidJitCode
::
generate
()
{
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
sigmoid_ymm
(
ymm_src
,
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
ret
();
}
bool
VTanhJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
void
VTanhJitCode
::
vtanh_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
void
VActJitCode
::
tanh_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 2 / (1 + e^(-2x)) - 1
// use ymm2, ymm3
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
ymm_t
ymm_zero
=
ymm_t
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_tmp
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
vsubps
(
ymm_tmp
,
ymm_zero
,
ymm_tmp
);
vmulps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
exp_ymm
(
ymm_
src
,
ymm_dst
);
exp_ymm
(
ymm_
dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
...
...
@@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
pop
(
reg_ptr_global
);
}
void
VTanhJitCode
::
generate
()
{
void
VActJitCode
::
generate
()
{
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vtanh_ymm
(
ymm_src
,
ymm_dst
);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_ymm
(
ymm_dst
,
ymm_src
,
ymm_zero
);
break
;
case
operand_type
::
exp
:
exp_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
if
(
type_
!=
operand_type
::
relu
)
{
// TODO(TJ): remove me
ret
();
return
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
ret
();
}
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
f65ddff8
...
...
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
typedef
enum
{
mul
=
0
,
add
,
sub
,
relu
,
exp
,
sigmoid
,
tanh
,
identity
}
operand_type
;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
...
...
@@ -85,87 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
class
Relu
JitCode
:
public
JitCode
{
class
VAct
JitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
ReluJitCode
);
explicit
ReluJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_zero
=
xmm_t
(
0
);
xmm_t
xmm_src
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_zero
=
ymm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
const
char
*
name
()
const
override
{
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
return
base
.
c_str
();
}
class
VExpJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
VExpJitCode
);
explicit
VExpJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
,
type_
(
type
)
{}
static
bool
init
(
int
d
,
operand_type
type
);
void
generate
()
override
;
protected:
// compute exp with ymm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
// compute relu with ymm
void
relu_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
zero
);
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
VSigmoidJitCode
:
public
VExpJitCode
{
public:
DECLARE_JIT_CODE
(
VSigmoidJitCode
);
explicit
VSigmoidJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VExpJitCode
(
d
,
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
// compute exp with ymm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute sigmoid with ymm
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
VTanhJitCode
:
public
VExpJitCode
{
public:
DECLARE_JIT_CODE
(
VTanhJitCode
);
explicit
VTanhJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VExpJitCode
(
d
,
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute sigmoid with ymm
void
vtanh_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
// compute tanh with ymm
void
tanh_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
pr
ivate
:
pr
otected
:
int
num_
;
operand_type
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_src
=
xmm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
f65ddff8
...
...
@@ -352,7 +352,8 @@ class VReluKernelImpl : public VReluKernel<T> {
size_t
sz
=
96
/* init size */
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
...
...
@@ -366,14 +367,14 @@ class VReluKernelImpl : public VReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
Relu
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
VAct
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
ReluJitCode
::
init
(
d
);
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
relu
);
}
#endif
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
f65ddff8
...
...
@@ -116,7 +116,8 @@ class VExpKernelImpl : public VExpKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VExpJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
exp
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
...
...
@@ -135,14 +136,14 @@ class VExpKernelImpl : public VExpKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
V
Exp
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
Act
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VExpKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
ExpJitCode
::
init
(
d
);
return
gen
::
V
ActJitCode
::
init
(
d
,
gen
::
operand_type
::
exp
);
}
#endif
...
...
@@ -169,7 +170,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VSigmoidJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
sigmoid
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
...
...
@@ -190,14 +192,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
V
Sigmoid
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
Act
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
SigmoidJitCode
::
init
(
d
);
return
gen
::
V
ActJitCode
::
init
(
d
,
gen
::
operand_type
::
sigmoi
d
);
}
#endif
...
...
@@ -223,7 +225,8 @@ class VTanhKernelImpl : public VTanhKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VTanhJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
tanh
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
...
...
@@ -244,14 +247,14 @@ class VTanhKernelImpl : public VTanhKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
V
Tanh
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
Act
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VTanhKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
TanhJitCode
::
init
(
d
);
return
gen
::
V
ActJitCode
::
init
(
d
,
gen
::
operand_type
::
tanh
);
}
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录