Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
83d075aa
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
83d075aa
编写于
12月 19, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix lstm and gru jitcode
test=develop
上级
20392be0
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
60 addition
and
52 deletion
+60
-52
paddle/fluid/operators/jit/gen/act.h
paddle/fluid/operators/jit/gen/act.h
+45
-37
paddle/fluid/operators/jit/gen/gru.h
paddle/fluid/operators/jit/gen/gru.h
+2
-5
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+3
-1
paddle/fluid/operators/jit/gen/lstm.h
paddle/fluid/operators/jit/gen/lstm.h
+5
-6
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+5
-3
未找到文件。
paddle/fluid/operators/jit/gen/act.h
浏览文件 @
83d075aa
...
@@ -59,43 +59,12 @@ extern int g_tmp_mem[];
...
@@ -59,43 +59,12 @@ extern int g_tmp_mem[];
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
class
VAct
JitCode
:
public
JitCode
{
class
VAct
Func
:
public
JitCode
{
public:
public:
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
,
explicit
VActFunc
(
size_t
code_size
,
void
*
code_ptr
)
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
)
{}
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
virtual
const
char
*
name
()
const
=
0
;
if
(
!
(
type_
==
operand_type
::
relu
||
type_
==
operand_type
::
exp
||
virtual
void
genCode
()
=
0
;
type_
==
operand_type
::
sigmoid
||
type_
==
operand_type
::
tanh
||
type_
==
operand_type
::
identity
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
this
->
genCode
();
}
const
char
*
name
()
const
override
{
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
return
base
.
c_str
();
}
void
genCode
()
override
;
protected:
protected:
// compute relu with ymm, xmm
// compute relu with ymm, xmm
...
@@ -272,10 +241,49 @@ class VActJitCode : public JitCode {
...
@@ -272,10 +241,49 @@ class VActJitCode : public JitCode {
identity_jmm
<
JMM
>
(
dst
,
src
,
15
);
identity_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
break
;
default:
default:
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type
;
break
;
}
}
};
class
VActJitCode
:
public
VActFunc
{
public:
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
VActFunc
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
if
(
!
(
type_
==
operand_type
::
relu
||
type_
==
operand_type
::
exp
||
type_
==
operand_type
::
sigmoid
||
type_
==
operand_type
::
tanh
||
type_
==
operand_type
::
identity
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
this
->
genCode
();
}
const
char
*
name
()
const
override
{
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
break
;
}
}
return
base
.
c_str
();
}
}
void
genCode
()
override
;
protected:
protected:
int
num_
;
int
num_
;
...
...
paddle/fluid/operators/jit/gen/gru.h
浏览文件 @
83d075aa
...
@@ -24,13 +24,11 @@ namespace operators {
...
@@ -24,13 +24,11 @@ namespace operators {
namespace
jit
{
namespace
jit
{
namespace
gen
{
namespace
gen
{
class
GRUJitCode
:
public
VAct
JitCode
{
class
GRUJitCode
:
public
VAct
Func
{
public:
public:
explicit
GRUJitCode
(
int
id
,
const
gru_attr_t
&
attr
,
size_t
code_size
,
explicit
GRUJitCode
(
int
id
,
const
gru_attr_t
&
attr
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
:
VActFunc
(
code_size
,
code_ptr
),
id_
(
id
),
num_
(
attr
.
d
)
{
code_ptr
),
id_
(
id
)
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
if
(
type
==
KernelType
::
vsigmoid
)
{
if
(
type
==
KernelType
::
vsigmoid
)
{
return
operand_type
::
sigmoid
;
return
operand_type
::
sigmoid
;
...
@@ -45,7 +43,6 @@ class GRUJitCode : public VActJitCode {
...
@@ -45,7 +43,6 @@ class GRUJitCode : public VActJitCode {
}
}
return
operand_type
::
identity
;
return
operand_type
::
identity
;
};
};
num_
=
attr
.
d
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
...
...
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
83d075aa
...
@@ -62,7 +62,9 @@ typedef enum {
...
@@ -62,7 +62,9 @@ typedef enum {
class
JitCode
:
public
GenBase
,
public
Xbyak
::
CodeGenerator
{
class
JitCode
:
public
GenBase
,
public
Xbyak
::
CodeGenerator
{
public:
public:
explicit
JitCode
(
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
explicit
JitCode
(
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
((
code_size
<
4096
?
4096
:
code_size
),
code_ptr
)
{}
:
Xbyak
::
CodeGenerator
(
(
code_size
%
4096
!=
0
?
(
code_size
/
4096
+
1
)
*
4096
:
code_size
),
code_ptr
)
{}
virtual
const
char
*
name
()
const
=
0
;
virtual
const
char
*
name
()
const
=
0
;
virtual
void
genCode
()
=
0
;
virtual
void
genCode
()
=
0
;
...
...
paddle/fluid/operators/jit/gen/lstm.h
浏览文件 @
83d075aa
...
@@ -24,13 +24,14 @@ namespace operators {
...
@@ -24,13 +24,14 @@ namespace operators {
namespace
jit
{
namespace
jit
{
namespace
gen
{
namespace
gen
{
class
LSTMJitCode
:
public
VAct
JitCode
{
class
LSTMJitCode
:
public
VAct
Func
{
public:
public:
explicit
LSTMJitCode
(
bool
compute_c1h1
,
const
lstm_attr_t
&
attr
,
explicit
LSTMJitCode
(
bool
compute_c1h1
,
const
lstm_attr_t
&
attr
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
:
VActFunc
(
code_size
,
code_ptr
),
code_ptr
),
num_
(
attr
.
d
),
compute_c1h1_
(
compute_c1h1
)
{
compute_c1h1_
(
compute_c1h1
),
use_peephole_
(
attr
.
use_peephole
)
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
if
(
type
==
KernelType
::
vsigmoid
)
{
if
(
type
==
KernelType
::
vsigmoid
)
{
return
operand_type
::
sigmoid
;
return
operand_type
::
sigmoid
;
...
@@ -45,8 +46,6 @@ class LSTMJitCode : public VActJitCode {
...
@@ -45,8 +46,6 @@ class LSTMJitCode : public VActJitCode {
}
}
return
operand_type
::
identity
;
return
operand_type
::
identity
;
};
};
num_
=
attr
.
d
;
use_peephole_
=
attr
.
use_peephole
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cell_
=
typeExchange
(
attr
.
act_cell
);
act_cell_
=
typeExchange
(
attr
.
act_cell
);
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
83d075aa
...
@@ -80,7 +80,7 @@ struct rnn_attr_s {
...
@@ -80,7 +80,7 @@ struct rnn_attr_s {
int
d
;
int
d
;
KernelType
act_gate
,
act_cand
;
KernelType
act_gate
,
act_cand
;
rnn_attr_s
()
=
default
;
rnn_attr_s
()
=
default
;
rnn_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
)
explicit
rnn_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
)
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
)
{}
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
)
{}
};
};
...
@@ -88,7 +88,7 @@ struct lstm_attr_s : public rnn_attr_s {
...
@@ -88,7 +88,7 @@ struct lstm_attr_s : public rnn_attr_s {
bool
use_peephole
;
bool
use_peephole
;
KernelType
act_cell
;
KernelType
act_cell
;
lstm_attr_s
()
=
default
;
lstm_attr_s
()
=
default
;
lstm_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
,
explicit
lstm_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
,
KernelType
_act_cell
,
bool
_use_peephole
=
false
)
KernelType
_act_cell
,
bool
_use_peephole
=
false
)
:
rnn_attr_s
(
_d
,
_act_gate
,
_act_cand
),
:
rnn_attr_s
(
_d
,
_act_gate
,
_act_cand
),
use_peephole
(
_use_peephole
),
use_peephole
(
_use_peephole
),
...
@@ -145,6 +145,8 @@ class Kernel {
...
@@ -145,6 +145,8 @@ class Kernel {
template
<
typename
KernelTuples
>
template
<
typename
KernelTuples
>
class
KernelImpl
:
public
Kernel
{
class
KernelImpl
:
public
Kernel
{
// TODO(TJ): rename KernelImpl to KernelMore which seems only used in more
// and add name interface for more implements easy for debug
using
T
=
typename
KernelTuples
::
data_type
;
using
T
=
typename
KernelTuples
::
data_type
;
using
Func
=
typename
KernelTuples
::
func_type
;
using
Func
=
typename
KernelTuples
::
func_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录