Commit c187a7c6
BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Authored Dec 19, 2018 by tensor-tang
Parent: 83d075aa

add more impls of lstm and gru and fix build on win

test=develop
Showing 4 changed files with 167 additions and 25 deletions (+167 -25)
paddle/fluid/operators/jit/kernel_base.h             +2    -2
paddle/fluid/operators/jit/more/mix/CMakeLists.txt   +5    -0
paddle/fluid/operators/jit/more/mix/mix.cc           +145  -18
paddle/fluid/operators/jit/more/mix/mix.h            +15   -5
paddle/fluid/operators/jit/kernel_base.h

```diff
@@ -147,12 +147,12 @@ template <typename KernelTuples>
 class KernelImpl : public Kernel {
+  // TODO(TJ): rename KernelImpl to KernelMore which seems only used in more
+  // and add name interface for more implements easy for debug
  public:
   using T = typename KernelTuples::data_type;
   using Func = typename KernelTuples::func_type;
   using Attr = typename KernelTuples::attr_type;
-
- public:
   virtual Func GetFunc() const { return func; }
   // TODO(TJ): const &attr
   virtual bool UseMe(Attr attr) const = 0;
 
  protected:
```
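For orientation: the GetFunc/UseMe pair is the hook the jit layer's kernel lookup keys off. Each "more" implementation reports via UseMe whether it wants to handle a given attribute, and the winning implementation's function pointer comes back through GetFunc. A minimal sketch of such a selection loop (PickKernel, candidates, and fallback are illustrative names, not Paddle's actual lookup API):

```cpp
#include <vector>

// Hedged sketch of UseMe-based dispatch over KernelImpl candidates.
// PickKernel / candidates / fallback are hypothetical names.
template <typename KernelTuples>
typename KernelTuples::func_type PickKernel(
    const std::vector<KernelImpl<KernelTuples>*>& candidates,
    typename KernelTuples::attr_type attr,
    typename KernelTuples::func_type fallback) {
  for (auto* k : candidates) {
    if (k->UseMe(attr)) {  // first implementation that accepts attr wins
      return k->GetFunc();
    }
  }
  return fallback;  // e.g. a reference implementation
}
```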
paddle/fluid/operators/jit/more/mix/CMakeLists.txt

```diff
@@ -7,3 +7,8 @@ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE)
 USE_JITKERNEL_MORE(vsigmoid, mix)
 USE_JITKERNEL_MORE(vtanh, mix)
+USE_JITKERNEL_MORE(lstmctht, mix)
+USE_JITKERNEL_MORE(lstmc1h1, mix)
+USE_JITKERNEL_MORE(gruh1, mix)
+USE_JITKERNEL_MORE(gruhtpart1, mix)
+USE_JITKERNEL_MORE(gruhtpart2, mix)
```
paddle/fluid/operators/jit/more/mix/mix.cc

```diff
@@ -23,7 +23,6 @@ namespace jit {
 namespace more {
 namespace mix {
 
-template <typename T>
 void VSigmoid(const T* x, T* y, int n) {
   const float min = SIGMOID_THRESHOLD_MIN;
   const float max = SIGMOID_THRESHOLD_MAX;
```

```diff
@@ -38,7 +37,6 @@ void VSigmoid(const T* x, T* y, int n) {
   }
 }
 
-template <typename T>
 void VTanh(const T* x, T* y, int n) {
   const T a = 2, b = -1;
   auto compute_scal = Get<vscal, AXYNTuples<T>, platform::CPUPlace>(n);
```
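An aside on the constants in VTanh: a = 2 and b = -1 encode the standard identity relating tanh to the logistic sigmoid, which is why the kernel can be composed from vscal, VSigmoid, and an add-bias step:

```latex
\[
  \tanh(x) \;=\; 2\,\sigma(2x) - 1,
  \qquad \sigma(t) = \frac{1}{1 + e^{-t}}
\]
```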
```diff
@@ -50,26 +48,151 @@ void VTanh(const T* x, T* y, int n) {
   compute_addbias(&b, y, y, n);
 }
 
-template <>
-bool VSigmoidKernel<float>::UseMe(int d) const {
-  return true;
-}
+void (*getActFunc(KernelType type, int d))(const T*, T*, int) {  // NOLINT
+  if (type == vsigmoid) {
+    return Get<vsigmoid, XYNTuples<T>, platform::CPUPlace>(d);
+  } else if (type == vrelu) {
+    return Get<vrelu, XYNTuples<T>, platform::CPUPlace>(d);
+  } else if (type == vtanh) {
+    return Get<vtanh, XYNTuples<T>, platform::CPUPlace>(d);
+  } else if (type == videntity) {
+    return Get<videntity, XYNTuples<T>, platform::CPUPlace>(d);
+  }
+  PADDLE_THROW("Not support type: %s", type);
+  return nullptr;
+}
 
-template <>
-bool VTanhKernel<float>::UseMe(int d) const {
-  return true;
-}
+void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
+  T* ct = reinterpret_cast<T*>(step->ct);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* wp = reinterpret_cast<const T*>(step->wp);
+  T* checked = reinterpret_cast<T*>(step->checked);
+  const int d = attr->d;
+  const int d2 = d * 2;
+  const int d3 = d * 3;
+  auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto vadd_d = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto vadd_d2 = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d2);
+  auto act_gate_d = getActFunc(attr->act_gate, d);
+  auto act_gate_d2 = getActFunc(attr->act_gate, d2);
+  auto act_gate_d3 = getActFunc(attr->act_gate, d2);
+  auto act_cand_d = getActFunc(attr->act_cand, d);
+  auto act_cell_d = getActFunc(attr->act_cell, d);
+
+  if (attr->use_peephole) {
+    vmul_d(wp, ct_1, checked, d);
+    vmul_d(wp + d, ct_1, checked + d, d);
+    vadd_d2(checked, gates + d, gates + d, d2);
+    act_gate_d2(gates + d, gates + d, d2);
+  } else {
+    act_gate_d3(gates + d, gates + d, d3);
+  }
+
+  // C_t = C_t-1 * fgated + cand_gated * igated
+  act_cand_d(gates, gates, d);
+  vmul_d(gates, gates + d, gates + d, d);
+  vmul_d(ct_1, gates + d2, gates + d2, d);
+  vadd_d(gates + d, gates + d2, ct, d);
+
+  if (attr->use_peephole) {
+    // get ogated
+    vmul_d(wp + d2, ct, gates + d, d);
+    vadd_d(gates + d, gates + d3, gates + d3, d);
+    act_gate_d(gates + d3, gates + d3, d);
+  }
+  // H_t = act_cell(C_t) * ogated
+  act_cell_d(ct, gates + d2, d);
+  vmul_d(gates + d2, gates + d3, ht, d);
+}
 
-#define AWALYS_USE_ME_WITH_DOUBLE(func)           \
-  template <>                                     \
-  bool func##Kernel<double>::UseMe(int d) const { \
-    return true;                                  \
-  }
+void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ct = reinterpret_cast<T*>(step->ct);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  int d = attr->d;
+  int d2 = d * 2;
+  int d3 = d * 3;
+  auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto vadd_d = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto act_gate_d = getActFunc(attr->act_gate, d);
+  auto act_cand_d = getActFunc(attr->act_cand, d);
+  auto act_cell_d = getActFunc(attr->act_cell, d);
+  /* C_t = igated * cgated*/
+  act_gate_d(gates + d, gates + d, d);
+  act_cand_d(gates, gates, d);
+  vmul_d(gates, gates + d, ct, d);
+
+  if (attr->use_peephole) {
+    // get outgated, put W_oc * C_t on igated
+    const T* wp = reinterpret_cast<const T*>(step->wp);
+    vmul_d(wp + d2, ct, gates + d, d);
+    vadd_d(gates + d, gates + d3, gates + d3, d);
+  }
+  /* H_t = act_cell(C_t) * ogated */
+  act_gate_d(gates + d3, gates + d3, d);
+  act_cell_d(ct, gates + d2, d);
+  vmul_d(gates + d2, gates + d3, ht, d);
+}
+
+// compute h1 without h0
+void GRUH1(gru_t* step, const gru_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  int d = attr->d;
+  int d2 = d * 2;
+  auto act_gate = getActFunc(attr->act_gate, d);
+  auto act_cand = getActFunc(attr->act_cand, d);
+  auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
+  act_gate(gates, gates, d);
+  act_cand(gates + d2, gates + d2, d);
+  vmul_d(gates, gates + d2, ht, d);
+}
+
+// compute the first part of GRU: ht = act_gate(r) * ht_1
+void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
+  // W: {W_update, W_reset; W_state}
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
+  auto act_gate = getActFunc(attr->act_gate, attr->d);
+  auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
+  act_gate(gates + attr->d, gates + attr->d, attr->d);
+  vmul_d(ht_1, gates + attr->d, ht, attr->d);
+}
+
+// compute the second part of GRU:
+// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
+void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
+  int d = attr->d;
+  auto act_gate = getActFunc(attr->act_gate, d);
+  auto act_cand = getActFunc(attr->act_cand, d);
+  T* y = gates + d * 2;
+  act_gate(gates, gates, d);
+  act_cand(y, y, d);
+  // out = zt*ht~ + (1-zt)*ht_1
+  for (int i = 0; i < d; ++i) {
+    ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
+  }
+}
+
+// TODO(TJ): tuning me
+bool VSigmoidKernel::UseMe(int d) const { return true; }
+
+bool VTanhKernel::UseMe(int d) const { return true; }
+
+bool LSTMCtHtKernel::UseMe(lstm_attr_t attr) const { return true; }
+
+bool LSTMC1H1Kernel::UseMe(lstm_attr_t attr) const { return true; }
+
+bool GRUH1Kernel::UseMe(gru_attr_t attr) const { return true; }
 
-AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
-AWALYS_USE_ME_WITH_DOUBLE(VTanh);
+bool GRUHtPart1Kernel::UseMe(gru_attr_t attr) const { return true; }
 
-#undef AWALYS_USE_ME_WITH_DOUBLE
+bool GRUHtPart2Kernel::UseMe(gru_attr_t attr) const { return true; }
 
 }  // namespace mix
 }  // namespace more
```
```diff
@@ -79,11 +202,15 @@ AWALYS_USE_ME_WITH_DOUBLE(VTanh);
 
 namespace mix = paddle::operators::jit::more::mix;
 
-#define REGISTER_MORE_KERNEL(key, func)                       \
-  REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel<float>, \
-                          mix::func##Kernel<double>)
+#define REGISTER_MORE_KERNEL(key, func) \
+  REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
 
 REGISTER_MORE_KERNEL(vsigmoid, VSigmoid);
 REGISTER_MORE_KERNEL(vtanh, VTanh);
+REGISTER_MORE_KERNEL(lstmctht, LSTMCtHt);
+REGISTER_MORE_KERNEL(lstmc1h1, LSTMC1H1);
+REGISTER_MORE_KERNEL(gruh1, GRUH1);
+REGISTER_MORE_KERNEL(gruhtpart1, GRUHtPart1);
+REGISTER_MORE_KERNEL(gruhtpart2, GRUHtPart2);
 
 #undef REGISTER_MORE_KERNEL
```
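For reference, the step equations behind the comments in mix.cc above, with \odot denoting elementwise multiplication:

```latex
% LSTM step, as in LSTMCtHt (no peephole):
\[
  c_t = \mathrm{act\_cand}(g_c) \odot \mathrm{act\_gate}(g_i)
      + c_{t-1} \odot \mathrm{act\_gate}(g_f),
  \qquad
  h_t = \mathrm{act\_cell}(c_t) \odot \mathrm{act\_gate}(g_o)
\]
% GRU update, as in GRUHtPart2 ("out = zt*ht~ + (1-zt)*ht_1"):
\[
  h_t = z_t \odot \tilde{h}_t + (1 - z_t) \odot h_{t-1}
\]
```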
paddle/fluid/operators/jit/more/mix/mix.h

```diff
@@ -22,18 +22,21 @@ namespace operators {
 namespace jit {
 namespace more {
 namespace mix {
+using T = float;
 
-template <typename T>
 void VSigmoid(const T* x, T* y, int n);
-
-template <typename T>
 void VTanh(const T* x, T* y, int n);
 
+void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
+void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);
+void GRUH1(gru_t* step, const gru_attr_t* attr);
+void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
+void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
+
 #define DECLARE_MORE_KERNEL(name, tuples)                      \
-  template <typename T>                                        \
   class name##Kernel : public KernelImpl<tuples<T>> {          \
    public:                                                     \
-    name##Kernel() { this->func = name<T>; }                   \
+    name##Kernel() { this->func = name; }                      \
     bool UseMe(typename tuples<T>::attr_type) const override;  \
   }
```

```diff
@@ -41,6 +44,13 @@ void VTanh(const T* x, T* y, int n);
 DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
 DECLARE_MORE_KERNEL(VTanh, XYNTuples);
 
+DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
+DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
+
+DECLARE_MORE_KERNEL(GRUH1, GRUTuples);
+DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples);
+DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples);
+
 #undef DECLARE_MORE_KERNEL
 
 }  // namespace mix
```
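Taken together, downstream code reaches these kernels through the same Get<> lookup mix.cc uses internally. A hedged usage sketch (not code from the commit; it assumes lstm_attr_t and lstm_t expose exactly the fields LSTMCtHt reads above, and that the jit helper header providing Get is included):

```cpp
#include <vector>

// Hedged sketch: fetch the new lstmctht kernel the same way mix.cc
// fetches vmul/vadd, then run one LSTM step on scratch buffers.
void RunOneLstmStep(int d) {
  namespace jit = paddle::operators::jit;

  jit::lstm_attr_t attr;          // assumed fields, matching their use above
  attr.d = d;
  attr.act_gate = jit::vsigmoid;  // resolved per size by getActFunc
  attr.act_cand = jit::vtanh;
  attr.act_cell = jit::vtanh;
  attr.use_peephole = false;

  // gates holds {cand, i, f, o}, each of width d, as LSTMCtHt indexes them.
  std::vector<float> gates(4 * d), ct_1(d), ct(d), ht(d);
  jit::lstm_t step;
  step.gates = gates.data();
  step.ct_1 = ct_1.data();
  step.ct = ct.data();
  step.ht = ht.data();

  // Same lookup pattern as the Get<vmul, ...> calls inside mix.cc.
  auto lstm = jit::Get<jit::lstmctht, jit::LSTMTuples<float>,
                       paddle::platform::CPUPlace>(attr);
  lstm(&step, &attr);  // one step: writes ct and ht
}
```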