Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
facfecbd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
facfecbd
编写于
12月 20, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comment: reuse time function and change to upper case
test=develop
上级
f5532877
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
71 addition
and
77 deletion
+71
-77
paddle/fluid/operators/jit/README.md
paddle/fluid/operators/jit/README.md
+1
-1
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+6
-12
paddle/fluid/operators/jit/gen/act.h
paddle/fluid/operators/jit/gen/act.h
+23
-23
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+4
-4
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+9
-9
paddle/fluid/operators/jit/gen/gru.h
paddle/fluid/operators/jit/gen/gru.h
+10
-10
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+8
-8
paddle/fluid/operators/jit/gen/lstm.h
paddle/fluid/operators/jit/gen/lstm.h
+10
-10
未找到文件。
paddle/fluid/operators/jit/README.md
浏览文件 @
facfecbd
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
结合函数模板和JIT生成需要的kernel函数。
结合函数模板和JIT生成需要的kernel函数。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的
`UseMe`
函数负责什么条件下可以被调用。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的
`UseMe`
函数负责什么条件下可以被调用。
这里实现的函数可以非常细粒度的函数方法,比如Vector
mul
, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
这里实现的函数可以非常细粒度的函数方法,比如Vector
MUL
, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
目前仅支持CPU上的高性能计算。
目前仅支持CPU上的高性能计算。
## 目录结构
## 目录结构
...
...
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
facfecbd
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/port.h"
...
@@ -26,17 +27,10 @@ DEFINE_int32(burning, 10, "Burning times.");
...
@@ -26,17 +27,10 @@ DEFINE_int32(burning, 10, "Burning times.");
DEFINE_int32
(
repeat
,
3000
,
"Repeat times."
);
DEFINE_int32
(
repeat
,
3000
,
"Repeat times."
);
DEFINE_int32
(
max_size
,
1000
,
"The Max size would be tested."
);
DEFINE_int32
(
max_size
,
1000
,
"The Max size would be tested."
);
inline
double
GetCurrentUS
()
{
struct
timeval
time
;
gettimeofday
(
&
time
,
NULL
);
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
}
template
<
typename
T
>
template
<
typename
T
>
void
RandomVec
(
const
int
n
,
T
*
a
,
const
T
lower
=
static_cast
<
T
>
(
-
20.
f
),
void
RandomVec
(
const
int
n
,
T
*
a
,
const
T
lower
=
static_cast
<
T
>
(
-
20.
f
),
const
T
upper
=
static_cast
<
T
>
(
20.
f
))
{
const
T
upper
=
static_cast
<
T
>
(
20.
f
),
unsigned
int
seed
=
100
)
{
static
unsigned
int
seed
=
100
;
std
::
mt19937
rng
(
seed
);
std
::
mt19937
rng
(
seed
++
);
std
::
uniform_real_distribution
<
double
>
uniform_dist
(
0
,
1
);
std
::
uniform_real_distribution
<
double
>
uniform_dist
(
0
,
1
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
a
[
i
]
=
static_cast
<
T
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
);
a
[
i
]
=
static_cast
<
T
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
);
...
@@ -58,12 +52,12 @@ struct BenchFunc {
...
@@ -58,12 +52,12 @@ struct BenchFunc {
for
(
int
i
=
0
;
i
<
FLAGS_burning
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FLAGS_burning
;
++
i
)
{
tgt
(
args
...);
tgt
(
args
...);
}
}
auto
start
=
GetCurrentUS
()
;
auto
start
=
paddle
::
platform
::
PosixInNsec
()
/
1e-3
;
for
(
int
i
=
0
;
i
<
FLAGS_repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FLAGS_repeat
;
++
i
)
{
tgt
(
args
...);
tgt
(
args
...);
}
}
auto
end
=
GetCurrentUS
()
;
auto
end
=
paddle
::
platform
::
PosixInNsec
()
/
1e-3
;
return
(
end
-
start
)
/
FLAGS_repeat
;
return
static_cast
<
double
>
(
end
-
start
)
/
FLAGS_repeat
;
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/act.h
浏览文件 @
facfecbd
...
@@ -67,7 +67,7 @@ class VActFunc : public JitCode {
...
@@ -67,7 +67,7 @@ class VActFunc : public JitCode {
virtual
void
genCode
()
=
0
;
virtual
void
genCode
()
=
0
;
protected:
protected:
// compute
relu
with ymm, xmm
// compute
RELU
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
=
15
)
{
// NOLINT
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
=
15
)
{
// NOLINT
JMM
zero
=
JMM
(
zero_idx
);
JMM
zero
=
JMM
(
zero_idx
);
...
@@ -75,7 +75,7 @@ class VActFunc : public JitCode {
...
@@ -75,7 +75,7 @@ class VActFunc : public JitCode {
vmaxps
(
dst
,
src
,
zero
);
vmaxps
(
dst
,
src
,
zero
);
}
}
// compute
exp
with ymm, xmm
// compute
EXP
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
...
@@ -159,7 +159,7 @@ class VActFunc : public JitCode {
...
@@ -159,7 +159,7 @@ class VActFunc : public JitCode {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
// compute
sigmoid
with ymm, xmm
// compute
SIGMOID
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
...
@@ -184,7 +184,7 @@ class VActFunc : public JitCode {
...
@@ -184,7 +184,7 @@ class VActFunc : public JitCode {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
// compute
tanh
with ymm, xmm
// compute
TANH
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
...
@@ -211,7 +211,7 @@ class VActFunc : public JitCode {
...
@@ -211,7 +211,7 @@ class VActFunc : public JitCode {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
// compute
identity
with ymm, xmm
// compute
IDENTITY
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
identity_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
)
{
// NOLINT
void
identity_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
)
{
// NOLINT
JMM
zero
=
JMM
(
zero_idx
);
JMM
zero
=
JMM
(
zero_idx
);
...
@@ -225,19 +225,19 @@ class VActFunc : public JitCode {
...
@@ -225,19 +225,19 @@ class VActFunc : public JitCode {
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
// use 11~15
// use 11~15
switch
(
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
case
operand_type
::
RELU
:
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
break
;
case
operand_type
::
exp
:
case
operand_type
::
EXP
:
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
break
;
case
operand_type
::
sigmoid
:
case
operand_type
::
SIGMOID
:
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
break
;
case
operand_type
::
tanh
:
case
operand_type
::
TANH
:
tanh_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
tanh_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
break
;
case
operand_type
::
identity
:
case
operand_type
::
IDENTITY
:
identity_jmm
<
JMM
>
(
dst
,
src
,
15
);
identity_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
break
;
default:
default:
...
@@ -252,9 +252,9 @@ class VActJitCode : public VActFunc {
...
@@ -252,9 +252,9 @@ class VActJitCode : public VActFunc {
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
,
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
VActFunc
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
:
VActFunc
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
if
(
!
(
type_
==
operand_type
::
relu
||
type_
==
operand_type
::
exp
||
if
(
!
(
type_
==
operand_type
::
RELU
||
type_
==
operand_type
::
EXP
||
type_
==
operand_type
::
sigmoid
||
type_
==
operand_type
::
tanh
||
type_
==
operand_type
::
SIGMOID
||
type_
==
operand_type
::
TANH
||
type_
==
operand_type
::
identity
))
{
type_
==
operand_type
::
IDENTITY
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
}
this
->
genCode
();
this
->
genCode
();
...
@@ -263,19 +263,19 @@ class VActJitCode : public VActFunc {
...
@@ -263,19 +263,19 @@ class VActJitCode : public VActFunc {
const
char
*
name
()
const
override
{
const
char
*
name
()
const
override
{
std
::
string
base
=
"VActJitCode"
;
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
switch
(
type_
)
{
case
operand_type
::
relu
:
case
operand_type
::
RELU
:
base
+=
"_Relu"
;
base
+=
"_Relu"
;
break
;
break
;
case
operand_type
::
exp
:
case
operand_type
::
EXP
:
base
+=
"_Exp"
;
base
+=
"_Exp"
;
break
;
break
;
case
operand_type
::
sigmoid
:
case
operand_type
::
SIGMOID
:
base
+=
"_Sigmoid"
;
base
+=
"_Sigmoid"
;
break
;
break
;
case
operand_type
::
tanh
:
case
operand_type
::
TANH
:
base
+=
"_Tanh"
;
base
+=
"_Tanh"
;
break
;
break
;
case
operand_type
::
identity
:
case
operand_type
::
IDENTITY
:
base
+=
"_Identity"
;
base
+=
"_Identity"
;
break
;
break
;
default:
default:
...
@@ -305,11 +305,11 @@ class VActJitCode : public VActFunc {
...
@@ -305,11 +305,11 @@ class VActJitCode : public VActFunc {
: VActJitCode(d, op_type, code_size, code_ptr) {} \
: VActJitCode(d, op_type, code_size, code_ptr) {} \
};
};
DECLARE_ACT_JITCODE
(
VRelu
,
operand_type
::
relu
);
DECLARE_ACT_JITCODE
(
VRelu
,
operand_type
::
RELU
);
DECLARE_ACT_JITCODE
(
VIdentity
,
operand_type
::
identity
);
DECLARE_ACT_JITCODE
(
VIdentity
,
operand_type
::
IDENTITY
);
DECLARE_ACT_JITCODE
(
VExp
,
operand_type
::
exp
);
DECLARE_ACT_JITCODE
(
VExp
,
operand_type
::
EXP
);
DECLARE_ACT_JITCODE
(
VSigmoid
,
operand_type
::
sigmoid
);
DECLARE_ACT_JITCODE
(
VSigmoid
,
operand_type
::
SIGMOID
);
DECLARE_ACT_JITCODE
(
VTanh
,
operand_type
::
tanh
);
DECLARE_ACT_JITCODE
(
VTanh
,
operand_type
::
TANH
);
#undef DECLARE_ACT_JITCODE
#undef DECLARE_ACT_JITCODE
...
...
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
facfecbd
...
@@ -39,9 +39,9 @@ void VXXJitCode::genCode() {
...
@@ -39,9 +39,9 @@ void VXXJitCode::genCode() {
if
(
scalar_index_
!=
2
)
{
if
(
scalar_index_
!=
2
)
{
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
}
}
if
(
type_
==
operand_type
::
mul
)
{
if
(
type_
==
operand_type
::
MUL
)
{
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
else
if
(
type_
==
operand_type
::
add
)
{
}
else
if
(
type_
==
operand_type
::
ADD
)
{
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
}
if
(
with_relu_
)
{
if
(
with_relu_
)
{
...
@@ -79,10 +79,10 @@ void VXXJitCode::genCode() {
...
@@ -79,10 +79,10 @@ void VXXJitCode::genCode() {
}
}
}
}
switch
(
type_
)
{
switch
(
type_
)
{
case
operand_type
::
mul
:
case
operand_type
::
MUL
:
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
break
;
case
operand_type
::
add
:
case
operand_type
::
ADD
:
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
break
;
default:
default:
...
...
paddle/fluid/operators/jit/gen/blas.h
浏览文件 @
facfecbd
...
@@ -34,7 +34,7 @@ class VXXJitCode : public JitCode {
...
@@ -34,7 +34,7 @@ class VXXJitCode : public JitCode {
type_
(
type
),
type_
(
type
),
scalar_index_
(
scalar_index
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{
with_relu_
(
with_relu
)
{
if
(
!
(
type_
==
operand_type
::
mul
||
type_
==
operand_type
::
add
))
{
if
(
!
(
type_
==
operand_type
::
MUL
||
type_
==
operand_type
::
ADD
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
}
this
->
genCode
();
this
->
genCode
();
...
@@ -47,9 +47,9 @@ class VXXJitCode : public JitCode {
...
@@ -47,9 +47,9 @@ class VXXJitCode : public JitCode {
}
else
{
}
else
{
base
+=
"_Vec"
;
base
+=
"_Vec"
;
}
}
if
(
type_
==
operand_type
::
mul
)
{
if
(
type_
==
operand_type
::
MUL
)
{
base
+=
"_Mul"
;
base
+=
"_Mul"
;
}
else
if
(
type_
==
operand_type
::
add
)
{
}
else
if
(
type_
==
operand_type
::
ADD
)
{
base
+=
"_Add"
;
base
+=
"_Add"
;
}
}
if
(
scalar_index_
==
2
)
{
if
(
scalar_index_
==
2
)
{
...
@@ -90,12 +90,12 @@ class VXXJitCode : public JitCode {
...
@@ -90,12 +90,12 @@ class VXXJitCode : public JitCode {
} \
} \
};
};
DECLARE_BLAS_JITCODE
(
VMul
,
operand_type
::
mul
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VMul
,
operand_type
::
MUL
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VAdd
,
operand_type
::
add
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VAdd
,
operand_type
::
ADD
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VSub
,
operand_type
::
sub
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VSub
,
operand_type
::
SUB
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VAddRelu
,
operand_type
::
add
,
0
,
true
);
DECLARE_BLAS_JITCODE
(
VAddRelu
,
operand_type
::
ADD
,
0
,
true
);
DECLARE_BLAS_JITCODE
(
VScal
,
operand_type
::
mul
,
1
,
false
);
DECLARE_BLAS_JITCODE
(
VScal
,
operand_type
::
MUL
,
1
,
false
);
DECLARE_BLAS_JITCODE
(
VAddBias
,
operand_type
::
add
,
1
,
false
);
DECLARE_BLAS_JITCODE
(
VAddBias
,
operand_type
::
ADD
,
1
,
false
);
#undef DECLARE_BLAS_JITCODE
#undef DECLARE_BLAS_JITCODE
...
...
paddle/fluid/operators/jit/gen/gru.h
浏览文件 @
facfecbd
...
@@ -31,17 +31,17 @@ class GRUJitCode : public VActFunc {
...
@@ -31,17 +31,17 @@ class GRUJitCode : public VActFunc {
:
VActFunc
(
code_size
,
code_ptr
),
id_
(
id
),
num_
(
attr
.
d
)
{
:
VActFunc
(
code_size
,
code_ptr
),
id_
(
id
),
num_
(
attr
.
d
)
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
if
(
type
==
KernelType
::
vsigmoid
)
{
if
(
type
==
KernelType
::
vsigmoid
)
{
return
operand_type
::
sigmoid
;
return
operand_type
::
SIGMOID
;
}
else
if
(
type
==
KernelType
::
vrelu
)
{
}
else
if
(
type
==
KernelType
::
vrelu
)
{
return
operand_type
::
relu
;
return
operand_type
::
RELU
;
}
else
if
(
type
==
KernelType
::
vtanh
)
{
}
else
if
(
type
==
KernelType
::
vtanh
)
{
return
operand_type
::
tanh
;
return
operand_type
::
TANH
;
}
else
if
(
type
==
KernelType
::
videntity
)
{
}
else
if
(
type
==
KernelType
::
videntity
)
{
return
operand_type
::
identity
;
return
operand_type
::
IDENTITY
;
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Do not support this jit::KernelType: "
<<
type
;
LOG
(
FATAL
)
<<
"Do not support this jit::KernelType: "
<<
type
;
}
}
return
operand_type
::
identity
;
return
operand_type
::
IDENTITY
;
};
};
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
...
@@ -60,19 +60,19 @@ class GRUJitCode : public VActFunc {
...
@@ -60,19 +60,19 @@ class GRUJitCode : public VActFunc {
}
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
case
operand_type
::
RELU
:
base
+=
"_Relu"
;
base
+=
"_Relu"
;
break
;
break
;
case
operand_type
::
exp
:
case
operand_type
::
EXP
:
base
+=
"_Exp"
;
base
+=
"_Exp"
;
break
;
break
;
case
operand_type
::
sigmoid
:
case
operand_type
::
SIGMOID
:
base
+=
"_Sigmoid"
;
base
+=
"_Sigmoid"
;
break
;
break
;
case
operand_type
::
tanh
:
case
operand_type
::
TANH
:
base
+=
"_Tanh"
;
base
+=
"_Tanh"
;
break
;
break
;
case
operand_type
::
identity
:
case
operand_type
::
IDENTITY
:
base
+=
"_Identity"
;
base
+=
"_Identity"
;
break
;
break
;
default:
default:
...
...
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
facfecbd
...
@@ -46,14 +46,14 @@ using zmm_t = const Xbyak::Zmm;
...
@@ -46,14 +46,14 @@ using zmm_t = const Xbyak::Zmm;
using
Label
=
Xbyak
::
Label
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
typedef
enum
{
mul
=
0
,
MUL
=
0
,
add
,
ADD
,
sub
,
SUB
,
relu
,
RELU
,
exp
,
EXP
,
sigmoid
,
SIGMOID
,
tanh
,
TANH
,
identity
IDENTITY
}
operand_type
;
}
operand_type
;
#define DECLARE_JIT_CODE(codename) \
#define DECLARE_JIT_CODE(codename) \
...
...
paddle/fluid/operators/jit/gen/lstm.h
浏览文件 @
facfecbd
...
@@ -34,17 +34,17 @@ class LSTMJitCode : public VActFunc {
...
@@ -34,17 +34,17 @@ class LSTMJitCode : public VActFunc {
use_peephole_
(
attr
.
use_peephole
)
{
use_peephole_
(
attr
.
use_peephole
)
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
if
(
type
==
KernelType
::
vsigmoid
)
{
if
(
type
==
KernelType
::
vsigmoid
)
{
return
operand_type
::
sigmoid
;
return
operand_type
::
SIGMOID
;
}
else
if
(
type
==
KernelType
::
vrelu
)
{
}
else
if
(
type
==
KernelType
::
vrelu
)
{
return
operand_type
::
relu
;
return
operand_type
::
RELU
;
}
else
if
(
type
==
KernelType
::
vtanh
)
{
}
else
if
(
type
==
KernelType
::
vtanh
)
{
return
operand_type
::
tanh
;
return
operand_type
::
TANH
;
}
else
if
(
type
==
KernelType
::
videntity
)
{
}
else
if
(
type
==
KernelType
::
videntity
)
{
return
operand_type
::
identity
;
return
operand_type
::
IDENTITY
;
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Do not support this jit::KernelType: "
<<
type
;
LOG
(
FATAL
)
<<
"Do not support this jit::KernelType: "
<<
type
;
}
}
return
operand_type
::
identity
;
return
operand_type
::
IDENTITY
;
};
};
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
...
@@ -63,19 +63,19 @@ class LSTMJitCode : public VActFunc {
...
@@ -63,19 +63,19 @@ class LSTMJitCode : public VActFunc {
}
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
case
operand_type
::
RELU
:
base
+=
"_Relu"
;
base
+=
"_Relu"
;
break
;
break
;
case
operand_type
::
exp
:
case
operand_type
::
EXP
:
base
+=
"_Exp"
;
base
+=
"_Exp"
;
break
;
break
;
case
operand_type
::
sigmoid
:
case
operand_type
::
SIGMOID
:
base
+=
"_Sigmoid"
;
base
+=
"_Sigmoid"
;
break
;
break
;
case
operand_type
::
tanh
:
case
operand_type
::
TANH
:
base
+=
"_Tanh"
;
base
+=
"_Tanh"
;
break
;
break
;
case
operand_type
::
identity
:
case
operand_type
::
IDENTITY
:
base
+=
"_Identity"
;
base
+=
"_Identity"
;
break
;
break
;
default:
default:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录