Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bf9302f9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bf9302f9
编写于
12月 13, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add lstm, peephole refer and test
上级
bf951fa7
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
346 addition
and
138 deletion
+346
-138
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+0
-5
paddle/fluid/operators/jit/gen_base.h
paddle/fluid/operators/jit/gen_base.h
+0
-4
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+20
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+2
-2
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+51
-3
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+38
-0
paddle/fluid/operators/jit/kernel_key.h
paddle/fluid/operators/jit/kernel_key.h
+4
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+2
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+3
-0
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+89
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+137
-0
paddle/fluid/operators/math/jit_kernel_impl.h
paddle/fluid/operators/math/jit_kernel_impl.h
+0
-39
paddle/fluid/operators/math/jit_kernel_refer.h
paddle/fluid/operators/math/jit_kernel_refer.h
+0
-85
未找到文件。
paddle/fluid/operators/jit/gen_base.cc
浏览文件 @
bf9302f9
...
...
@@ -23,11 +23,6 @@ namespace paddle {
namespace
operators
{
namespace
jit
{
template
<
>
size_t
JitCodeKey
<
int
>
(
int
d
)
{
return
d
;
}
// refer do not need useme, it would be the last one.
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
if
(
code
)
{
...
...
paddle/fluid/operators/jit/gen_base.h
浏览文件 @
bf9302f9
...
...
@@ -43,10 +43,6 @@ class GenBase : public Kernel {
void
dumpCode
(
const
unsigned
char
*
code
)
const
;
};
// Every JitCode should have a method to get the key from attribution
template
<
typename
Attr
>
size_t
JitCodeKey
(
Attr
attr
);
// Creator is used to creat the jitcode and save in pool.
// Every JitCode should have one creator.
class
GenCreator
{
...
...
paddle/fluid/operators/jit/helper.cc
浏览文件 @
bf9302f9
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
#include <algorithm> // tolower
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
...
...
@@ -36,6 +37,8 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
vexp
);
ONE_CASE
(
vsigmoid
);
ONE_CASE
(
vtanh
);
ONE_CASE
(
lstmctht
);
ONE_CASE
(
lstmc1h1
);
default:
PADDLE_THROW
(
"Not support type: %d"
,
kt
);
return
"NOT JITKernel"
;
...
...
@@ -44,6 +47,23 @@ const char* to_string(KernelType kt) {
}
#undef ONE_CASE
KernelType
to_kerneltype
(
const
std
::
string
&
act
)
{
std
::
string
lower
=
act
;
std
::
transform
(
lower
.
begin
(),
lower
.
end
(),
lower
.
begin
(),
::
tolower
);
if
(
lower
==
"relu"
||
lower
==
"vrelu"
)
{
return
vrelu
;
}
else
if
(
lower
==
"identity"
||
lower
==
"videntity"
||
lower
==
""
)
{
return
videntity
;
}
else
if
(
lower
==
"exp"
||
lower
==
"vexp"
)
{
return
vexp
;
}
else
if
(
lower
==
"sigmoid"
||
lower
==
"vsigmoid"
)
{
return
vsigmoid
;
}
else
if
(
lower
==
"tanh"
||
lower
==
"vtanh"
)
{
return
vtanh
;
}
return
non_kernel
;
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/helper.h
浏览文件 @
bf9302f9
...
...
@@ -14,9 +14,7 @@
#pragma once
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
...
...
@@ -124,6 +122,8 @@ typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
const
char
*
to_string
(
KernelType
kt
);
KernelType
to_kerneltype
(
const
std
::
string
&
act
);
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
bf9302f9
...
...
@@ -20,8 +20,9 @@ namespace operators {
namespace
jit
{
typedef
enum
{
vmul
=
0
,
vadd
=
1
,
non_kernel
=
0
,
vmul
=
1
,
vadd
=
2
,
vaddrelu
,
vsub
,
vscal
,
...
...
@@ -30,7 +31,9 @@ typedef enum {
videntity
,
vexp
,
vsigmoid
,
vtanh
vtanh
,
lstmctht
,
lstmc1h1
}
KernelType
;
template
<
typename
T
>
...
...
@@ -50,6 +53,51 @@ struct XYNTuples {
typedef
void
(
*
func_type
)(
const
T
*
,
T
*
,
int
);
};
typedef
struct
{
void
*
gates
;
// gates: x_ch, x_ih, x_fh, x_oh
const
void
*
ct_1
;
void
*
ct
;
void
*
ht
;
/* weight_peephole and checked data are only used in peephole*/
const
void
*
wp
{
nullptr
};
// W_ic, W_fc, W_oc
void
*
checked
{
nullptr
};
// size: 2 * d
}
lstm_t
;
typedef
struct
{
void
*
gates
;
// gates: {x_update, x_reset; x_state}
const
void
*
ht_1
;
void
*
ht
;
}
gru_t
;
struct
rnn_attr_s
{
int
d
;
KernelType
act_gate
,
act_cand
;
rnn_attr_s
()
=
default
;
rnn_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
)
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
)
{}
};
struct
lstm_attr_s
:
public
rnn_attr_s
{
bool
use_peephole
;
KernelType
act_cell
;
lstm_attr_s
()
=
default
;
lstm_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
,
KernelType
_act_cell
,
bool
_use_peephole
=
false
)
:
rnn_attr_s
(
_d
,
_act_gate
,
_act_cand
),
use_peephole
(
_use_peephole
),
act_cell
(
_act_cell
)
{}
};
typedef
struct
rnn_attr_s
gru_attr_t
;
typedef
struct
lstm_attr_s
lstm_attr_t
;
template
<
typename
T
>
struct
LSTMTuples
{
typedef
T
data_type
;
typedef
lstm_attr_t
attr_type
;
typedef
void
(
*
func_type
)(
lstm_t
*
,
const
lstm_attr_t
*
);
};
// Just for adding to kernel pool without template
class
Kernel
{
public:
...
...
paddle/fluid/operators/jit/kernel_key.cc
0 → 100644
浏览文件 @
bf9302f9
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
template
<
>
size_t
JitCodeKey
<
int
>
(
const
int
&
d
)
{
return
d
;
}
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
constexpr
int
act_type_shift
=
3
;
// suppot 2^3 act types
size_t
key
=
attr
.
d
;
int
gate_key
=
static_cast
<
int
>
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
static_cast
<
int
>
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
static_cast
<
int
>
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
return
(
key
<<
(
1
+
act_type_shift
*
3
))
+
gate_key
+
cand_key
+
cell_key
+
attr
.
use_peephole
;
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_key.h
浏览文件 @
bf9302f9
...
...
@@ -44,6 +44,10 @@ struct KernelKey {
bool
operator
!=
(
const
KernelKey
&
o
)
const
{
return
!
(
*
this
==
o
);
}
};
// Every JitCode should have a method to get the key from attribution
template
<
typename
Attr
>
size_t
JitCodeKey
(
const
Attr
&
attr
);
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
bf9302f9
...
...
@@ -18,3 +18,5 @@ USE_JITKERNEL_REFER(videntity)
USE_JITKERNEL_REFER
(
vexp
)
USE_JITKERNEL_REFER
(
vsigmoid
)
USE_JITKERNEL_REFER
(
vtanh
)
USE_JITKERNEL_REFER
(
lstmctht
)
USE_JITKERNEL_REFER
(
lstmc1h1
)
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
bf9302f9
...
...
@@ -35,4 +35,7 @@ REGISTER_REFER_KERNEL(vexp, VExp);
REGISTER_REFER_KERNEL
(
vsigmoid
,
VSigmoid
);
REGISTER_REFER_KERNEL
(
vtanh
,
VTanh
);
REGISTER_REFER_KERNEL
(
lstmctht
,
LSTMCtHt
);
REGISTER_REFER_KERNEL
(
lstmc1h1
,
LSTMC1H1
);
#undef REGISTER_REFER_KERNEL
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
bf9302f9
...
...
@@ -110,6 +110,91 @@ void VTanh(const T* x, T* y, int n) {
}
}
template
<
typename
T
>
void
(
*
getActFunc
(
KernelType
type
))(
const
T
*
,
T
*
,
int
)
{
// NOLINT
if
(
type
==
vsigmoid
)
{
return
VSigmoid
<
T
>
;
}
else
if
(
type
==
vrelu
)
{
return
VRelu
<
T
>
;
}
else
if
(
type
==
vtanh
)
{
return
VTanh
<
T
>
;
}
else
if
(
type
==
videntity
)
{
return
VIdentity
<
T
>
;
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
// compute ct and ht
template
<
typename
T
>
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
const
T
*
ct_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ct_1
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
T
*
checked
=
reinterpret_cast
<
T
*>
(
step
->
checked
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cell
=
getActFunc
<
T
>
(
attr
->
act_cell
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
// gates: W_ch, W_ih, W_fh, W_oh
if
(
attr
->
use_peephole
)
{
VMul
(
wp
,
ct_1
,
checked
,
d
);
VMul
(
wp
+
d
,
ct_1
,
checked
+
d
,
d
);
VAdd
(
checked
,
gates
+
d
,
gates
+
d
,
d2
);
act_gate
(
gates
+
d
,
gates
+
d
,
d2
);
}
else
{
act_gate
(
gates
+
d
,
gates
+
d
,
d3
);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
VMul
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
VAdd
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get ogated
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
}
// H_t = act_cell(C_t) * ogated
act_cell
(
ct
,
gates
+
d2
,
d
);
VMul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
// compute c1 and h1 without c0 or h0
template
<
typename
T
>
void
LSTMC1H1
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cell
=
getActFunc
<
T
>
(
attr
->
act_cell
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
/* C_t = igated * cgated*/
act_gate
(
gates
+
d
,
gates
+
d
,
d
);
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get outgated, put W_oc * C_t on igated
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
}
/* H_t = act_cell(C_t) * ogated */
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
act_cell
(
ct
,
gates
+
d2
,
d
);
VMul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
...
...
@@ -134,6 +219,10 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VTanh
,
XYNTuples
);
// lstm_t* , const lstm_attr_t*
DECLARE_REFER_KERNEL
(
LSTMCtHt
,
LSTMTuples
);
DECLARE_REFER_KERNEL
(
LSTMC1H1
,
LSTMTuples
);
#undef DECLARE_REFER_KERNEL
}
// namespace refer
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
bf9302f9
...
...
@@ -350,6 +350,143 @@ TEST(JITKernel, vtanh) {
TestXYNKernel
<
jit
::
vtanh
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
template
<
typename
T
,
typename
KernelTuples
>
void
TestLSTMFunc
(
const
typename
KernelTuples
::
func_type
tgt
,
const
std
::
vector
<
T
>&
xsrc
,
const
std
::
vector
<
T
>&
wp
,
const
std
::
vector
<
T
>&
ct_1
,
const
std
::
vector
<
T
>&
ct_ref
,
const
std
::
vector
<
T
>&
ht_ref
,
const
paddle
::
operators
::
jit
::
lstm_attr_t
&
attr
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
ct_ref
.
size
(),
ht_ref
.
size
());
EXPECT_EQ
(
ct_1
.
size
(),
ht_ref
.
size
());
EXPECT_EQ
(
xsrc
.
size
(),
4
*
ht_ref
.
size
());
EXPECT_EQ
(
wp
.
size
(),
3
*
ht_ref
.
size
());
// x could be changed after compute, so copy to save src
int
d
=
ht_ref
.
size
();
std
::
vector
<
T
>
x
(
xsrc
.
size
()),
ct
(
ct_ref
.
size
()),
ht
(
ht_ref
.
size
());
std
::
vector
<
T
>
checked
(
2
*
d
);
std
::
copy
(
xsrc
.
begin
(),
xsrc
.
end
(),
x
.
begin
());
const
T
*
ct_1_data
=
ct_1
.
data
();
const
T
*
wp_data
=
wp
.
data
();
const
T
*
ct_ref_data
=
ct_ref
.
data
();
const
T
*
ht_ref_data
=
ht_ref
.
data
();
T
*
x_data
=
x
.
data
();
T
*
ct_data
=
ct
.
data
();
T
*
ht_data
=
ht
.
data
();
T
*
checked_data
=
checked
.
data
();
paddle
::
operators
::
jit
::
lstm_t
step
;
step
.
gates
=
x_data
;
step
.
ct_1
=
ct_1_data
;
step
.
ct
=
ct_data
;
step
.
ht
=
ht_data
;
if
(
attr
.
use_peephole
)
{
step
.
wp
=
wp_data
;
step
.
checked
=
checked_data
;
}
tgt
(
&
step
,
&
attr
);
ExpectEQ
<
T
>
(
ct_data
,
ct_ref_data
,
d
);
ExpectEQ
<
T
>
(
ht_data
,
ht_ref_data
,
d
);
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestLSTMKernel
()
{
namespace
jit
=
paddle
::
operators
::
jit
;
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
std
::
string
>
all_acts
=
{
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
};
for
(
int
d
:
TestSizes
())
{
for
(
bool
use_peephole
:
{
true
,
false
})
{
for
(
auto
&
act_gate
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
for
(
auto
&
act_cell
:
all_acts
)
{
std
::
string
info
=
act_gate
+
act_cand
+
act_cell
+
(
use_peephole
?
"peephole_"
:
""
)
+
"size_"
+
std
::
to_string
(
d
);
const
jit
::
lstm_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
),
jit
::
to_kerneltype
(
act_cell
),
use_peephole
);
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
LSTMTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
xsrc
(
4
*
d
),
wp
(
3
*
d
),
ct_1
(
d
);
std
::
vector
<
T
>
ct_ref
(
d
),
ht_ref
(
d
),
checked
(
2
*
d
);
RandomVec
<
T
>
(
4
*
d
,
xsrc
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
3
*
d
,
wp
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
d
,
ct_1
.
data
(),
-
2.
f
,
2.
f
);
// x could be changed after compute, so copy to save src
std
::
vector
<
T
>
x
(
xsrc
.
size
());
std
::
copy
(
xsrc
.
begin
(),
xsrc
.
end
(),
x
.
begin
());
const
T
*
ct_1_data
=
ct_1
.
data
();
const
T
*
wp_data
=
wp
.
data
();
T
*
x_data
=
x
.
data
();
T
*
checked_data
=
checked
.
data
();
T
*
ct_ref_data
=
ct_ref
.
data
();
T
*
ht_ref_data
=
ht_ref
.
data
();
jit
::
lstm_t
step
;
step
.
gates
=
x_data
;
step
.
ct_1
=
ct_1_data
;
step
.
ct
=
ct_ref_data
;
step
.
ht
=
ht_ref_data
;
if
(
use_peephole
)
{
step
.
wp
=
wp_data
;
step
.
checked
=
checked_data
;
}
ref
(
&
step
,
&
attr
);
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
jit
::
LSTMTuples
<
T
>
,
PlaceType
>
(
attr
);
if
(
jitcode
)
{
VLOG
(
10
)
<<
"Test Jitcode Kernel "
<<
info
;
TestLSTMFunc
<
T
,
jit
::
LSTMTuples
<
T
>>
(
jitcode
,
xsrc
,
wp
,
ct_1
,
ct_ref
,
ht_ref
,
attr
);
}
// test all impls in more
jit
::
KernelKey
kkey
(
KT
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelImpl
<
jit
::
LSTMTuples
<
T
>>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
VLOG
(
10
)
<<
"Test More Kernel "
<<
info
;
TestLSTMFunc
<
T
,
jit
::
LSTMTuples
<
T
>>
(
more
,
xsrc
,
wp
,
ct_1
,
ct_ref
,
ht_ref
,
attr
);
}
}
}
// Test result from Get function
auto
tgt
=
jit
::
Get
<
KT
,
jit
::
LSTMTuples
<
T
>
,
PlaceType
>
(
attr
);
TestLSTMFunc
<
T
,
jit
::
LSTMTuples
<
T
>>
(
tgt
,
xsrc
,
wp
,
ct_1
,
ct_ref
,
ht_ref
,
attr
);
}
}
}
}
}
}
TEST
(
JITKernel
,
lstmctht
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestLSTMKernel
<
jit
::
lstmctht
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestLSTMKernel
<
jit
::
lstmctht
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
lstmc1h1
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestLSTMKernel
<
jit
::
lstmc1h1
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestLSTMKernel
<
jit
::
lstmc1h1
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
// TODO(TJ): refine the tests template
TEST
(
JITKernel
,
pool
)
{
// TODO(TJ): add some test
}
paddle/fluid/operators/math/jit_kernel_impl.h
浏览文件 @
bf9302f9
...
...
@@ -28,45 +28,6 @@ namespace jitkernel {
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef
struct
{
void
*
gates
;
// gates: W_ch, W_ih, W_fh, W_oh
const
void
*
ct_1
;
void
*
ct
;
void
*
ht
;
/* weight_peephole and checked data are only used in peephole*/
const
void
*
wp
{
nullptr
};
void
*
checked
{
nullptr
};
}
lstm_t
;
typedef
struct
{
void
*
gates
;
// gates: {W_update, W_reset; W_state}
const
void
*
ht_1
;
void
*
ht
;
}
gru_t
;
struct
rnn_attr_s
{
int
d
;
std
::
string
act_gate
,
act_cand
;
rnn_attr_s
()
=
default
;
rnn_attr_s
(
int
_d
,
const
std
::
string
&
_act_gate
,
const
std
::
string
&
_act_cand
)
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
)
{}
};
struct
lstm_attr_s
:
public
rnn_attr_s
{
bool
use_peephole
;
std
::
string
act_cell
;
lstm_attr_s
()
=
default
;
lstm_attr_s
(
int
_d
,
const
std
::
string
&
_act_gate
,
const
std
::
string
&
_act_cand
,
const
std
::
string
&
_act_cell
,
bool
_use_peephole
=
false
)
:
rnn_attr_s
(
_d
,
_act_gate
,
_act_cand
),
use_peephole
(
_use_peephole
),
act_cell
(
_act_cell
)
{}
};
typedef
struct
rnn_attr_s
gru_attr_t
;
typedef
struct
lstm_attr_s
lstm_attr_t
;
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_refer.h
浏览文件 @
bf9302f9
...
...
@@ -24,91 +24,6 @@ namespace math {
namespace
jitkernel
{
namespace
refer
{
template
<
typename
T
>
void
(
*
getActFunc
(
const
std
::
string
&
type
))(
const
T
*
,
T
*
,
int
)
{
// NOLINT
if
(
type
==
"sigmoid"
)
{
return
VSigmoid
<
T
>
;
}
else
if
(
type
==
"relu"
)
{
return
VRelu
<
T
>
;
}
else
if
(
type
==
"tanh"
)
{
return
VTanh
<
T
>
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
VIdentity
<
T
>
;
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
// compute ct and ht
template
<
typename
T
>
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
const
T
*
ct_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ct_1
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
T
*
checked
=
reinterpret_cast
<
T
*>
(
step
->
checked
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cell
=
getActFunc
<
T
>
(
attr
->
act_cell
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
// gates: W_ch, W_ih, W_fh, W_oh
if
(
attr
->
use_peephole
)
{
VMul
(
wp
,
ct_1
,
checked
,
d
);
VMul
(
wp
+
d
,
ct_1
,
checked
+
d
,
d
);
VAdd
(
checked
,
gates
+
d
,
gates
+
d
,
d2
);
act_gate
(
gates
+
d
,
gates
+
d
,
d2
);
}
else
{
act_gate
(
gates
+
d
,
gates
+
d
,
d3
);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
VMul
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
VAdd
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get ogated
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
}
// H_t = act_cell(C_t) * ogated
act_cell
(
ct
,
gates
+
d2
,
d
);
VMul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
// compute c1 and h1 without c0 or h0
template
<
typename
T
>
void
LSTMC1H1
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cell
=
getActFunc
<
T
>
(
attr
->
act_cell
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
/* C_t = igated * cgated*/
act_gate
(
gates
+
d
,
gates
+
d
,
d
);
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get outgated, put W_oc * C_t on igated
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
}
/* H_t = act_cell(C_t) * ogated */
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
act_cell
(
ct
,
gates
+
d2
,
d
);
VMul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
// compute h1 without h0
template
<
typename
T
>
void
GRUH1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录