Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
30e47bce
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
30e47bce
编写于
11月 27, 2018
作者:
Q
Qiyang Min
提交者:
GitHub
11月 27, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into revert_vlog
上级
be04d99f
3ae6692a
变更
15
展开全部
隐藏空白更改
内联
并排
Showing
15 changed file
with
1100 addition
and
948 deletion
+1100
-948
Dockerfile
Dockerfile
+2
-0
paddle/fluid/operators/fused/fusion_gru_op.cc
paddle/fluid/operators/fused/fusion_gru_op.cc
+41
-26
paddle/fluid/operators/fused/fusion_lstm_op.cc
paddle/fluid/operators/fused/fusion_lstm_op.cc
+46
-27
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+155
-39
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+200
-32
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+7
-19
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+10
-55
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+4
-188
paddle/fluid/operators/math/jit_kernel_impl.h
paddle/fluid/operators/math/jit_kernel_impl.h
+73
-0
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+4
-4
paddle/fluid/operators/math/jit_kernel_refer.h
paddle/fluid/operators/math/jit_kernel_refer.h
+238
-0
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+184
-406
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+81
-146
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+49
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+6
-6
未找到文件。
Dockerfile
浏览文件 @
30e47bce
...
...
@@ -43,6 +43,8 @@ RUN wget -q https://www.python.org/ftp/python/3.7.0/Python-3.7.0.tgz && \
CFLAGS
=
"-Wformat"
./configure
--prefix
=
/usr/local/
--enable-shared
>
/dev/null
&&
\
make
-j8
>
/dev/null
&&
make altinstall
>
/dev/null
RUN
rm
-r
/root/python_build
RUN
apt-get update
&&
\
apt-get
install
-y
--allow-downgrades
patchelf
\
python3 python3-dev python3-pip
\
...
...
paddle/fluid/operators/fused/fusion_gru_op.cc
浏览文件 @
30e47bce
...
...
@@ -183,24 +183,27 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \
const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const auto& ker = math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \
const std::string&, const std::string&>( \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation"), D); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const math::jitkernel::gru_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation")); \
math::jitkernel::gru_t one_step; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \
const math::jitkernel::gru_attr_t&>(attr); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
...
...
@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if
(
h0_data
)
{
prev_hidden_data
=
h0_data
+
bid
*
D
;
}
else
{
ker
->
ComputeH1
(
xx_data
,
hidden_out_data
);
one_step
.
gates
=
xx_data
;
one_step
.
ht
=
hidden_out_data
;
ker
->
ComputeH1
(
&
one_step
,
&
attr
);
prev_hidden_data
=
hidden_out_data
;
tstart
=
1
;
move_step
();
...
...
@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D2
,
D
,
static_cast
<
T
>
(
1
),
prev_hidden_data
,
D
,
wh_data
,
D2
,
static_cast
<
T
>
(
1
),
xx_data
,
D3
);
ker
->
ComputeHtPart1
(
xx_data
,
prev_hidden_data
,
hidden_out_data
);
one_step
.
gates
=
xx_data
;
one_step
.
ht_1
=
prev_hidden_data
;
one_step
.
ht
=
hidden_out_data
;
ker
->
ComputeHtPart1
(
&
one_step
,
&
attr
);
// gemm rt * Ws
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D
,
D
,
static_cast
<
T
>
(
1
),
hidden_out_data
,
D
,
wh_state_data
,
D
,
static_cast
<
T
>
(
1
),
xx_data
+
D2
,
D3
);
ker
->
ComputeHtPart2
(
xx_data
,
prev_hidden_data
,
hidden_out_data
);
ker
->
ComputeHtPart2
(
&
one_step
,
&
attr
);
// save prev
prev_hidden_data
=
hidden_out_data
;
move_step
();
...
...
@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T
*
cur_out_data
=
batched_out_data
;
// W: {W_update, W_reset; W_state}
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
ker
->
ComputeH1
(
cur_in_data
,
cur_out_data
);
one_step
.
gates
=
cur_in_data
;
one_step
.
ht
=
cur_out_data
;
ker
->
ComputeH1
(
&
one_step
,
&
attr
);
// add offset
cur_in_data
+=
D3
;
cur_out_data
+=
D
;
...
...
@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T
*
cur_out_data
=
batched_out_data
;
T
*
cur_prev_hidden_data
=
prev_hidden_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
ker
->
ComputeHtPart1
(
cur_batched_data
,
cur_prev_hidden_data
,
cur_out_data
);
one_step
.
gates
=
cur_batched_data
;
one_step
.
ht_1
=
cur_prev_hidden_data
;
one_step
.
ht
=
cur_out_data
;
ker
->
ComputeHtPart1
(
&
one_step
,
&
attr
);
cur_batched_data
+=
D3
;
cur_prev_hidden_data
+=
D
;
cur_out_data
+=
D
;
...
...
@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data
=
prev_hidden_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
ker
->
ComputeHtPart2
(
cur_batched_data
,
cur_prev_hidden_data
,
cur_out_data
);
one_step
.
gates
=
cur_batched_data
;
one_step
.
ht_1
=
cur_prev_hidden_data
;
one_step
.
ht
=
cur_out_data
;
ker
->
ComputeHtPart2
(
&
one_step
,
&
attr
);
cur_batched_data
+=
D3
;
cur_prev_hidden_data
+=
D
;
cur_out_data
+=
D
;
...
...
paddle/fluid/operators/fused/fusion_lstm_op.cc
浏览文件 @
30e47bce
...
...
@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \
const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/
\
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
const std::string&, const std::string&>( \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/
\
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const math::jitkernel::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
math::jitkernel::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, \
const math::jitkernel::lstm_attr_t&>(attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
...
...
@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
ker
->
ComputeC1H1
(
xx_data
,
c_out_data
,
h_out_data
,
wp_data
);
one_step
.
gates
=
xx_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ht
=
h_out_data
;
ker
->
ComputeC1H1
(
&
one_step
,
&
attr
);
tstart
=
1
;
// move one step
prev_h_data
=
h_out_data
;
...
...
@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
ker
->
ComputeCtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
,
wp_data
,
checked_cell_data
);
one_step
.
gates
=
xx_data
;
one_step
.
ct_1
=
prev_c_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ht
=
h_out_data
;
ker
->
ComputeCtHt
(
&
one_step
,
&
attr
);
// move one step
prev_h_data
=
h_out_data
;
prev_c_data
=
c_out_data
;
...
...
@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
ker
->
ComputeC1H1
(
cur_in_data
,
cur_c_out_data
,
cur_h_out_data
,
wp_data
);
one_step
.
gates
=
cur_in_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ht
=
cur_h_out_data
;
ker
->
ComputeC1H1
(
&
one_step
,
&
attr
);
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
...
...
@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
ker
->
ComputeCtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
,
wp_data
,
checked_cell_data
);
one_step
.
gates
=
cur_in_data
;
one_step
.
ct_1
=
cur_prev_c_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ht
=
cur_h_out_data
;
ker
->
ComputeCtHt
(
&
one_step
,
&
attr
);
// move one batch
cur_in_data
+=
D4
;
cur_prev_c_data
+=
D
;
...
...
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
30e47bce
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
namespace
paddle
{
...
...
@@ -139,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
}
void
VActJitCode
::
generate
()
{
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
ymm_zero
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
break
;
}
act
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
type_
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
...
...
@@ -181,22 +160,7 @@ void VActJitCode::generate() {
block
=
1
;
vmovss
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
xmm_zero
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
default:
break
;
}
act
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
type_
);
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
...
...
@@ -210,6 +174,158 @@ void VActJitCode::generate() {
ret
();
}
bool
LSTMJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
LSTMJitCode
::
generate
()
{
if
(
use_peephole_
)
{
preCode
();
}
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ct_1
=
r9
;
reg64_t
reg_ptr_ct
=
r10
;
reg64_t
reg_ptr_ht
=
r11
;
reg64_t
reg_ptr_wp
=
r12
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
lstm_t
,
gates
)]);
mov
(
reg_ptr_ct_1
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct_1
)]);
mov
(
reg_ptr_ct
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ht
)]);
if
(
use_peephole_
)
{
mov
(
reg_ptr_wp
,
ptr
[
param1
+
offsetof
(
lstm_t
,
wp
)]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t
ymm_c
=
ymm_t
(
0
);
ymm_t
ymm_i
=
ymm_t
(
1
);
ymm_t
ymm_f
=
ymm_t
(
2
);
ymm_t
ymm_o
=
ymm_t
(
3
);
ymm_t
ymm_ct_1
=
ymm_t
(
4
);
ymm_t
ymm_wp0
=
ymm_t
(
5
);
ymm_t
ymm_wp1
=
ymm_t
(
6
);
ymm_t
ymm_wp2
=
ymm_t
(
7
);
vmovups
(
ymm_c
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_i
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
vmovups
(
ymm_f
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
vmovups
(
ymm_o
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
d
]);
if
(
!
compute_c1h1_
)
{
vmovups
(
ymm_ct_1
,
ptr
[
reg_ptr_ct_1
+
offset
]);
}
if
(
use_peephole_
)
{
vmovups
(
ymm_wp0
,
ptr
[
reg_ptr_wp
+
offset
]);
vmovups
(
ymm_wp1
,
ptr
[
reg_ptr_wp
+
offset
+
d
]);
vmovups
(
ymm_wp2
,
ptr
[
reg_ptr_wp
+
offset
+
2
*
d
]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act
<
ymm_t
>
(
ymm_c
,
ymm_c
,
act_cand_
);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if
(
!
compute_c1h1_
&&
use_peephole_
)
{
vmulps
(
ymm_wp0
,
ymm_ct_1
,
ymm_wp0
);
vaddps
(
ymm_i
,
ymm_i
,
ymm_wp0
);
}
act
<
ymm_t
>
(
ymm_i
,
ymm_i
,
act_gate_
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
if
(
!
compute_c1h1_
)
{
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp1
,
ymm_ct_1
,
ymm_wp1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_wp1
);
}
act
<
ymm_t
>
(
ymm_f
,
ymm_f
,
act_gate_
);
// ct
vmulps
(
ymm_f
,
ymm_f
,
ymm_ct_1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_c
);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t
ymm_ct
=
compute_c1h1_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_tmp
=
ymm_i
;
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
// act_gate(o) or act_gate(ct * wp2 + o)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp2
,
ymm_ct
,
ymm_wp2
);
vaddps
(
ymm_o
,
ymm_o
,
ymm_wp2
);
}
act
<
ymm_t
>
(
ymm_o
,
ymm_o
,
act_gate_
);
// ht
vmulps
(
ymm_o
,
ymm_o
,
ymm_tmp
);
// save ct and ht
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
if
(
use_peephole_
)
{
postCode
();
}
else
{
ret
();
}
}
bool
GRUJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
GRUJitCode
::
generate
()
{
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ht_1
=
r9
;
reg64_t
reg_ptr_ht
=
r10
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
gru_t
,
gates
)]);
mov
(
reg_ptr_ht_1
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht_1
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht
)]);
ymm_t
ymm_one
=
ymm_t
(
0
);
if
(
id_
==
2
)
{
reg64_t
reg_ptr_tmp
=
r11
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_one
,
ptr
[
reg_ptr_tmp
+
OFFSET_EXP_ONE
]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
ymm_t
ymm_u
=
ymm_t
(
1
);
ymm_t
ymm_r
=
ymm_t
(
2
);
ymm_t
ymm_s
=
ymm_t
(
3
);
ymm_t
ymm_ht_1
=
ymm_t
(
4
);
// W: {W_update, W_reset; W_state}
if
(
id_
==
0
||
id_
==
2
)
{
vmovups
(
ymm_u
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_s
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
}
if
(
id_
==
1
)
{
vmovups
(
ymm_r
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
}
if
(
id_
==
1
||
id_
==
2
)
{
vmovups
(
ymm_ht_1
,
ptr
[
reg_ptr_ht_1
+
offset
]);
}
if
(
id_
==
0
)
{
// ht = act_gate(u) * act_cand(s)
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_s
);
}
else
if
(
id_
==
1
)
{
// ht = act_gate(r) * ht_1
act
<
ymm_t
>
(
ymm_r
,
ymm_r
,
act_gate_
);
vmulps
(
ymm_r
,
ymm_r
,
ymm_ht_1
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_r
);
}
else
if
(
id_
==
2
)
{
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t
ymm_one_inner
=
ymm_t
(
ymm_one
.
getIdx
());
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vsubps
(
ymm_u
,
ymm_one_inner
,
ymm_u
);
vmulps
(
ymm_u
,
ymm_ht_1
,
ymm_u
);
vaddps
(
ymm_u
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_u
);
}
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
30e47bce
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
...
...
@@ -46,14 +47,6 @@ extern const float exp_float_consts[];
extern
const
int
exp_int_0x7f
[];
extern
int
g_tmp_mem
[];
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
...
...
@@ -176,31 +169,34 @@ class VActJitCode : public JitCode {
protected:
// compute relu with ymm, xmm
template
<
typename
JMM
>
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
JMM
&
zero
)
{
// NOLINT
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
=
15
)
{
// NOLINT
JMM
zero
=
JMM
(
zero_idx
);
vxorps
(
zero
,
zero
,
zero
);
vmaxps
(
dst
,
src
,
zero
);
}
// compute exp with ymm, xmm
template
<
typename
JMM
>
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
// NOLINT
int
mask_idx
=
4
,
int
tmp_idx
=
5
)
{
using
namespace
platform
::
jit
;
// NOLINT
assert
(
src
.
getIdx
()
!=
dst
.
getIdx
());
// TODO(TJ): use enfore
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
using
namespace
platform
::
jit
;
// NOLINT
// check all idx can not equal
JMM
jmm_src
=
JMM
(
src_idx
);
JMM
jmm_fx
=
JMM
(
fx_idx
);
JMM
jmm_fy
=
JMM
(
fy_idx
);
JMM
jmm_mask
=
JMM
(
mask_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
src
,
src
,
jmm_tmp
);
vminps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
src
,
src
,
jmm_tmp
);
vmaxps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
jmm_fx
,
src
,
jmm_tmp
);
vmulps
(
jmm_fx
,
jmm_
src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
jmm_fx
,
jmm_fx
,
jmm_tmp
);
vroundps
(
jmm_fy
,
jmm_fx
,
0x01
);
...
...
@@ -214,21 +210,21 @@ class VActJitCode : public JitCode {
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
JMM
ymm_z
=
JMM
(
jmm_mask
.
getIdx
());
vmulps
(
ymm_z
,
jmm_fx
,
jmm_tmp
);
vsubps
(
src
,
src
,
jmm_fy
);
vsubps
(
src
,
src
,
ymm_z
);
vmulps
(
ymm_z
,
src
,
src
);
vsubps
(
jmm_src
,
jmm_
src
,
jmm_fy
);
vsubps
(
jmm_src
,
jmm_
src
,
ymm_z
);
vmulps
(
ymm_z
,
jmm_src
,
jmm_
src
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
dst
,
src
,
jmm_tmp
);
vmulps
(
dst
,
jmm_
src
,
jmm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmulps
(
dst
,
dst
,
src
);
vmulps
(
dst
,
dst
,
jmm_
src
);
}
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmulps
(
dst
,
dst
,
ymm_z
);
vaddps
(
dst
,
dst
,
src
);
vaddps
(
dst
,
dst
,
jmm_
src
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
// build 2^n
...
...
@@ -265,20 +261,23 @@ class VActJitCode : public JitCode {
// compute sigmoid with ymm, xmm
template
<
typename
JMM
>
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
fx_idx
=
2
,
// NOLINT
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
)
{
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
// y = 1 / (1 + e^-x)
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_src
=
JMM
(
src_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vminps
(
src
,
src
,
jmm_tmp
);
vminps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmaxps
(
src
,
src
,
jmm_tmp
);
vmaxps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
vxorps
(
jmm_tmp
,
jmm_tmp
,
jmm_tmp
);
vsubps
(
src
,
jmm_tmp
,
src
);
exp_jmm
<
JMM
>
(
dst
,
src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vsubps
(
jmm_src
,
jmm_tmp
,
jmm_
src
);
exp_jmm
<
JMM
>
(
dst
,
jmm_src
,
src_idx
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vdivps
(
dst
,
jmm_tmp
,
dst
);
...
...
@@ -287,19 +286,22 @@ class VActJitCode : public JitCode {
// compute tanh with ymm, xmm
template
<
typename
JMM
>
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
// NOLINT
int
mask_idx
=
4
,
int
tmp_idx
=
5
)
{
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
// y = 2 / (1 + e^(-2x)) - 1
JMM
jmm_src
=
JMM
(
src_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_zero
=
JMM
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
jmm_zero
,
jmm_zero
,
jmm_zero
);
vsubps
(
jmm_tmp
,
jmm_zero
,
jmm_tmp
);
vmulps
(
src
,
src
,
jmm_tmp
);
exp_jmm
<
JMM
>
(
dst
,
src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmulps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
exp_jmm
<
JMM
>
(
dst
,
jmm_src
,
src_idx
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
...
...
@@ -309,6 +311,30 @@ class VActJitCode : public JitCode {
pop
(
reg_ptr_global
);
}
template
<
typename
JMM
>
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
// use 11~15
switch
(
type
)
{
case
operand_type
::
relu
:
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
case
operand_type
::
identity
:
break
;
default:
// throw error
break
;
}
}
protected:
int
num_
;
operand_type
type_
;
...
...
@@ -322,6 +348,148 @@ class VActJitCode : public JitCode {
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
LSTMJitCode
:
public
VActJitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"LSTMJitCode"
;
if
(
use_peephole_
)
{
base
+=
"_Peephole"
;
}
if
(
compute_c1h1_
)
{
base
+=
"_C1H1"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
AddTypeStr
(
act_cell_
);
return
base
.
c_str
();
}
explicit
LSTMJitCode
(
bool
compute_c1h1
,
const
lstm_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
compute_c1h1_
(
compute_c1h1
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
use_peephole_
=
attr
.
use_peephole
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cell_
=
typeExchange
(
attr
.
act_cell
);
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
int
num_
;
bool
compute_c1h1_
;
bool
use_peephole_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
operand_type
act_cell_
;
reg64_t
param1
{
abi_param1
};
};
class
GRUJitCode
:
public
VActJitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"GRUJitCode"
;
if
(
id_
==
0
)
{
base
+=
"_H1"
;
}
else
if
(
id_
==
1
)
{
base
+=
"_HtPart1"
;
}
else
if
(
id_
==
2
)
{
base
+=
"_HtPart2"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
return
base
.
c_str
();
}
explicit
GRUJitCode
(
int
id
,
const
gru_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
id_
(
id
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
int
id_
;
int
num_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
reg64_t
param1
{
abi_param1
};
};
#ifdef PADDLE_WITH_MKLDNN
struct
EltwiseMulnChw16cNC
:
public
Xbyak
::
CodeGenerator
{
explicit
EltwiseMulnChw16cNC
(
size_t
code_size
=
256
*
1024
)
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
30e47bce
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"
...
...
@@ -26,14 +27,7 @@ namespace operators {
namespace
math
{
namespace
jitkernel
{
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
// TODO(TJ): remove me
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
class
Kernel
{
...
...
@@ -128,24 +122,18 @@ class VTanhKernel : public VActKernel<T> {};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
public:
virtual
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
,
T
*
checked
=
nullptr
)
const
=
0
;
// compute c1 and h1 without c0 or h0
virtual
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
)
const
=
0
;
void
(
*
ComputeC1H1
)(
lstm_t
*
,
const
lstm_attr_t
*
);
void
(
*
ComputeCtHt
)(
lstm_t
*
,
const
lstm_attr_t
*
);
};
template
<
typename
T
>
class
GRUKernel
:
public
Kernel
{
public:
// compute h1 without h0
v
irtual
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
=
0
;
v
irtual
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
=
0
;
v
irtual
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
=
0
;
v
oid
(
*
ComputeH1
)(
gru_t
*
,
const
gru_attr_t
*
)
;
v
oid
(
*
ComputeHtPart1
)(
gru_t
*
,
const
gru_attr_t
*
)
;
v
oid
(
*
ComputeHtPart2
)(
gru_t
*
,
const
gru_attr_t
*
)
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
30e47bce
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -31,49 +32,6 @@ namespace math {
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
template
<
typename
T
>
void
VMulRefer
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
template
<
typename
T
>
void
VAddRefer
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
}
template
<
typename
T
>
void
VAddReluRefer
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
z
[
i
]
=
z
[
i
]
>
0
?
z
[
i
]
:
0
;
}
}
template
<
typename
T
>
void
VScalRefer
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
0
]
*
x
[
i
];
}
}
template
<
typename
T
>
void
VAddBiasRefer
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
0
]
+
x
[
i
];
}
}
template
<
typename
T
>
void
VReluRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
x
[
i
]
>
0
?
x
[
i
]
:
0
;
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
...
...
@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_sscal
(
n
,
*
a
,
y
,
1
);
}
else
{
VScalRefer
<
float
>
(
a
,
x
,
y
,
n
);
refer
::
VScal
<
float
>
(
a
,
x
,
y
,
n
);
}
}
...
...
@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_dscal
(
n
,
*
a
,
y
,
1
);
}
else
{
VScalRefer
<
double
>
(
a
,
x
,
y
,
n
);
refer
::
VScal
<
double
>
(
a
,
x
,
y
,
n
);
}
}
...
...
@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> {
return
;
}
#endif
this
->
Compute
=
VMulRefer
<
T
>
;
this
->
Compute
=
refer
::
VMul
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> {
return
;
}
#endif
this
->
Compute
=
VAddRefer
<
T
>
;
this
->
Compute
=
refer
::
VAdd
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -280,7 +238,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
return
;
}
#endif
this
->
Compute
=
VAddReluRefer
<
T
>
;
this
->
Compute
=
refer
::
VAddRelu
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -318,7 +276,7 @@ class VScalKernelImpl : public VScalKernel<T> {
return
;
}
#endif
this
->
Compute
=
VScalRefer
<
T
>
;
this
->
Compute
=
refer
::
VScal
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -362,7 +320,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
}
#endif
this
->
Compute
=
VAddBiasRefer
<
T
>
;
this
->
Compute
=
refer
::
VAddBias
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -396,7 +354,7 @@ class VReluKernelImpl : public VReluKernel<T> {
}
#endif
this
->
Compute
=
VReluRefer
<
T
>
;
this
->
Compute
=
refer
::
VRelu
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -412,16 +370,13 @@ bool VReluKernelImpl<float>::useJIT(int d) {
}
#endif
template
<
typename
T
>
inline
void
VIdentityRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{}
/* An empty JitKernel */
template
<
typename
T
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
Compute
=
VIdentityRefer
<
T
>
;
this
->
Compute
=
refer
::
VIdentity
<
T
>
;
}
};
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
30e47bce
...
...
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
...
...
@@ -25,48 +25,12 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
// TODO(TJ): move refer codes to one file
// Refer code only focus on correctness
template
<
typename
T
>
void
VExpRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
}
}
template
<
typename
T
>
void
VSigmoidRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
// y = 1 / (1 + e^-x)
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
T
tmp
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
std
::
exp
(
-
tmp
));
}
}
template
<
typename
T
>
void
VTanhRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
// y = 2 * sigmoid(2x) - 1
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
x
[
i
];
}
VSigmoidRefer
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
y
[
i
]
-
static_cast
<
T
>
(
1
);
}
}
#ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup
template
<
typename
T
>
...
...
@@ -129,7 +93,7 @@ class VExpKernelImpl : public VExpKernel<T> {
return
;
}
#endif
this
->
Compute
=
VExpRefer
<
T
>
;
this
->
Compute
=
refer
::
VExp
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -182,7 +146,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
return
;
}
#endif
this
->
Compute
=
VSigmoidRefer
<
T
>
;
this
->
Compute
=
refer
::
VSigmoid
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -234,7 +198,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
return
;
}
#endif
this
->
Compute
=
VTanhRefer
<
T
>
;
this
->
Compute
=
refer
::
VTanh
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
...
...
@@ -267,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL
(
vsigmoid
,
VSigmoidKernel
);
REGISTER_JITKERNEL
(
vtanh
,
VTanhKernel
);
namespace
detail
{
#ifdef __AVX__
#define ALIGN32 __attribute__((aligned(32)))
#define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
_PI256_CONST
(
0x7f
,
0x7f
);
_PS256_CONST
(
one
,
1.
f
);
_PS256_CONST
(
0
p5
,
0.5
f
);
_PS256_CONST
(
exp_hi
,
88.3762626647949
f
);
_PS256_CONST
(
exp_lo
,
-
88.3762626647949
f
);
_PS256_CONST
(
cephes_LOG2EF
,
1.44269504088896341
);
_PS256_CONST
(
cephes_exp_C1
,
0.693359375
);
_PS256_CONST
(
cephes_exp_C2
,
-
2.12194440e-4
);
_PS256_CONST
(
cephes_exp_p0
,
1.9875691500E-4
);
_PS256_CONST
(
cephes_exp_p1
,
1.3981999507E-3
);
_PS256_CONST
(
cephes_exp_p2
,
8.3334519073E-3
);
_PS256_CONST
(
cephes_exp_p3
,
4.1665795894E-2
);
_PS256_CONST
(
cephes_exp_p4
,
1.6666665459E-1
);
_PS256_CONST
(
cephes_exp_p5
,
5.0000001201E-1
);
typedef
union
imm_xmm_union
{
__m256i
imm
;
__m128i
xmm
[
2
];
}
imm_xmm_union
;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union u ALIGN32; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union u ALIGN32; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
/* use SSE2 to perform the bitop AVX2 */
\
__m128i x1, x2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, y); \
x2 = _mm_##fn(x2, y); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
#define AVX2_INTOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
/* use SSE2 to perform the AVX2 integer operation */
\
__m128i x1, x2; \
__m128i y1, y2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
AVX2_BITOP_USING_SSE2
(
slli_epi32
);
AVX2_INTOP_USING_SSE2
(
add_epi32
);
#define AVXEXP_BASE \
__m256 tmp = _mm256_setzero_ps(), fx; \
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
__m256i imm0; \
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
/* express exp(x) as exp(g + n*log(2)) */
\
fx = _mm256_mul_ps(x, \
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
tmp = _mm256_floor_ps(fx); \
/* if greater, substract 1 */
\
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
mask = _mm256_and_ps(mask, one); \
fx = _mm256_sub_ps(tmp, mask); \
tmp = _mm256_mul_ps(fx, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
__m256 z = _mm256_mul_ps( \
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
x = _mm256_sub_ps(x, tmp); \
x = _mm256_sub_ps(x, z); \
z = _mm256_mul_ps(x, x); \
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
y = _mm256_mul_ps(y, z); \
y = _mm256_add_ps(y, x); \
y = _mm256_add_ps(y, one); \
/* build 2^n */
\
imm0 = _mm256_cvttps_epi32(fx)
__m256
ExpAVX
(
__m256
x
)
{
AVXEXP_BASE
;
// two AVX2 instructions using SSE2
imm0
=
avx2_mm256_add_epi32
(
imm0
,
*
reinterpret_cast
<
const
__m256i
*>
(
_pi256_0x7f
));
imm0
=
avx2_mm256_slli_epi32
(
imm0
,
23
);
__m256
pow2n
=
_mm256_castsi256_ps
(
imm0
);
y
=
_mm256_mul_ps
(
y
,
pow2n
);
return
y
;
}
#endif
#ifdef __AVX2__
__m256
ExpAVX2
(
__m256
x
)
{
AVXEXP_BASE
;
// two AVX2 instructions
imm0
=
_mm256_add_epi32
(
imm0
,
*
reinterpret_cast
<
const
__m256i
*>
(
_pi256_0x7f
));
imm0
=
_mm256_slli_epi32
(
imm0
,
23
);
__m256
pow2n
=
_mm256_castsi256_ps
(
imm0
);
y
=
_mm256_mul_ps
(
y
,
pow2n
);
return
y
;
}
#endif
}
// namespace detail
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_impl.h
0 → 100644
浏览文件 @
30e47bce
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef
struct
{
void
*
gates
;
// gates: W_ch, W_ih, W_fh, W_oh
const
void
*
ct_1
;
void
*
ct
;
void
*
ht
;
/* weight_peephole and checked data are only used in peephole*/
const
void
*
wp
{
nullptr
};
void
*
checked
{
nullptr
};
}
lstm_t
;
typedef
struct
{
void
*
gates
;
// gates: {W_update, W_reset; W_state}
const
void
*
ht_1
;
void
*
ht
;
}
gru_t
;
struct
rnn_attr_s
{
int
d
;
std
::
string
act_gate
,
act_cand
;
rnn_attr_s
()
=
default
;
rnn_attr_s
(
int
_d
,
const
std
::
string
&
_act_gate
,
const
std
::
string
&
_act_cand
)
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
)
{}
};
struct
lstm_attr_s
:
public
rnn_attr_s
{
bool
use_peephole
;
std
::
string
act_cell
;
lstm_attr_s
()
=
default
;
lstm_attr_s
(
int
_d
,
const
std
::
string
&
_act_gate
,
const
std
::
string
&
_act_cand
,
const
std
::
string
&
_act_cell
,
bool
_use_peephole
=
false
)
:
rnn_attr_s
(
_d
,
_act_gate
,
_act_cand
),
use_peephole
(
_use_peephole
),
act_cell
(
_act_cell
)
{}
};
typedef
struct
rnn_attr_s
gru_attr_t
;
typedef
struct
lstm_attr_s
lstm_attr_t
;
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
30e47bce
...
...
@@ -82,10 +82,10 @@ namespace jitkernel {
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float,
JITKERNEL_DECLARE,
\
JITKERNEL_FIND_KEY, JITKERNEL_IMPL);
\
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double,
JITKERNEL_DECLARE,
\
JITKERNEL_FIND_KEY, JITKERNEL_IMPL
)
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float,
marco_declare,
\
macro_find_key, macro_impl);
\
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double,
marco_declare,
\
macro_find_key, macro_impl
)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
...
...
paddle/fluid/operators/math/jit_kernel_refer.h
0 → 100644
浏览文件 @
30e47bce
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
refer
{
/* Refer code only focus on correctness */
template
<
typename
T
>
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
template
<
typename
T
>
void
VAdd
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
}
template
<
typename
T
>
void
VAddRelu
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
z
[
i
]
=
z
[
i
]
>
0
?
z
[
i
]
:
0
;
}
}
template
<
typename
T
>
void
VScal
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
0
]
*
x
[
i
];
}
}
template
<
typename
T
>
void
VAddBias
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
0
]
+
x
[
i
];
}
}
template
<
typename
T
>
void
VRelu
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
x
[
i
]
>
0
?
x
[
i
]
:
0
;
}
}
template
<
typename
T
>
inline
void
VIdentity
(
const
T
*
x
,
T
*
y
,
int
n
)
{}
template
<
typename
T
>
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
}
}
template
<
typename
T
>
void
VSigmoid
(
const
T
*
x
,
T
*
y
,
int
n
)
{
// y = 1 / (1 + e^-x)
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
T
tmp
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
std
::
exp
(
-
tmp
));
}
}
template
<
typename
T
>
void
VTanh
(
const
T
*
x
,
T
*
y
,
int
n
)
{
// y = 2 * sigmoid(2x) - 1
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
x
[
i
];
}
VSigmoid
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
y
[
i
]
-
static_cast
<
T
>
(
1
);
}
}
template
<
typename
T
>
void
(
*
getActFunc
(
const
std
::
string
&
type
))(
const
T
*
,
T
*
,
int
)
{
// NOLINT
if
(
type
==
"sigmoid"
)
{
return
VSigmoid
<
T
>
;
}
else
if
(
type
==
"relu"
)
{
return
VRelu
<
T
>
;
}
else
if
(
type
==
"tanh"
)
{
return
VTanh
<
T
>
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
VIdentity
<
T
>
;
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
// compute ct and ht
template
<
typename
T
>
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
const
T
*
ct_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ct_1
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
T
*
checked
=
reinterpret_cast
<
T
*>
(
step
->
checked
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cell
=
getActFunc
<
T
>
(
attr
->
act_cell
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
// gates: W_ch, W_ih, W_fh, W_oh
if
(
attr
->
use_peephole
)
{
VMul
(
wp
,
ct_1
,
checked
,
d
);
VMul
(
wp
+
d
,
ct_1
,
checked
+
d
,
d
);
VAdd
(
checked
,
gates
+
d
,
gates
+
d
,
d2
);
act_gate
(
gates
+
d
,
gates
+
d
,
d2
);
}
else
{
act_gate
(
gates
+
d
,
gates
+
d
,
d3
);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
VMul
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
VAdd
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get ogated
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
}
// H_t = act_cell(C_t) * ogated
act_cell
(
ct
,
gates
+
d2
,
d
);
VMul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
// compute c1 and h1 without c0 or h0
template
<
typename
T
>
void
LSTMC1H1
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cell
=
getActFunc
<
T
>
(
attr
->
act_cell
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
/* C_t = igated * cgated*/
act_gate
(
gates
+
d
,
gates
+
d
,
d
);
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get outgated, put W_oc * C_t on igated
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
}
/* H_t = act_cell(C_t) * ogated */
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
act_cell
(
ct
,
gates
+
d2
,
d
);
VMul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
// compute h1 without h0
template
<
typename
T
>
void
GRUH1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
act_gate
(
gates
,
gates
,
d
);
act_cand
(
gates
+
d2
,
gates
+
d2
,
d
);
VMul
(
gates
,
gates
+
d2
,
ht
,
d
);
}
// compute the first part of GRU: ht = act_gate(r) * ht_1
template
<
typename
T
>
void
GRUHtPart1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
)
{
// W: {W_update, W_reset; W_state}
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
act_gate
(
gates
+
attr
->
d
,
gates
+
attr
->
d
,
attr
->
d
);
VMul
(
ht_1
,
gates
+
attr
->
d
,
ht
,
attr
->
d
);
}
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
template
<
typename
T
>
void
GRUHtPart2
(
gru_t
*
step
,
const
gru_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
int
d
=
attr
->
d
;
T
*
y
=
gates
+
d
*
2
;
act_gate
(
gates
,
gates
,
d
);
act_cand
(
y
,
y
,
d
);
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
}
}
}
// namespace refer
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
30e47bce
此差异已折叠。
点击以展开。
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
30e47bce
此差异已折叠。
点击以展开。
paddle/scripts/paddle_build.sh
浏览文件 @
30e47bce
...
...
@@ -671,6 +671,55 @@ EOF
${
DOCKERFILE_CUBLAS_DSO
}
${
DOCKERFILE_GPU_ENV
}
ENV NCCL_LAUNCH_MODE PARALLEL
EOF
elif
[
"
$1
"
==
"cp36-cp36m"
]
;
then
cat
>>
${
PADDLE_ROOT
}
/build/Dockerfile
<<
EOF
ADD python/dist/*.whl /
# run paddle version to install python packages first
RUN apt-get update &&
${
NCCL_DEPS
}
RUN apt-get install -y make build-essential libssl-dev zlib1g-dev libbz2-dev
\
libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev
\
xz-utils tk-dev libffi-dev liblzma-dev
RUN mkdir -p /root/python_build/ && wget -q https://www.sqlite.org/2018/sqlite-autoconf-3250300.tar.gz &&
\
tar -zxf sqlite-autoconf-3250300.tar.gz && cd sqlite-autoconf-3250300 &&
\
./configure -prefix=/usr/local && make -j8 && make install && cd ../ && rm sqlite-autoconf-3250300.tar.gz &&
\
wget -q https://www.python.org/ftp/python/3.6.0/Python-3.6.0.tgz &&
\
tar -xzf Python-3.6.0.tgz && cd Python-3.6.0 &&
\
CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null &&
\
make -j8 > /dev/null && make altinstall > /dev/null
RUN apt-get install -y libgtk2.0-dev dmidecode python3-tk &&
\
pip3.6 install opencv-python && pip3.6 install /*.whl; apt-get install -f -y &&
\
apt-get clean -y &&
\
rm -f /*.whl &&
\
${
PADDLE_VERSION
}
&&
\
ldconfig
${
DOCKERFILE_CUDNN_DSO
}
${
DOCKERFILE_CUBLAS_DSO
}
${
DOCKERFILE_GPU_ENV
}
ENV NCCL_LAUNCH_MODE PARALLEL
EOF
elif
[
"
$1
"
==
"cp37-cp37m"
]
;
then
cat
>>
${
PADDLE_ROOT
}
/build/Dockerfile
<<
EOF
ADD python/dist/*.whl /
# run paddle version to install python packages first
RUN apt-get update &&
${
NCCL_DEPS
}
RUN apt-get install -y make build-essential libssl-dev zlib1g-dev libbz2-dev
\
libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev
\
xz-utils tk-dev libffi-dev liblzma-dev
RUN wget -q https://www.python.org/ftp/python/3.7.0/Python-3.7.0.tgz &&
\
tar -xzf Python-3.7.0.tgz && cd Python-3.7.0 &&
\
CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null &&
\
make -j8 > /dev/null && make altinstall > /dev/null
RUN apt-get install -y libgtk2.0-dev dmidecode python3-tk &&
\
pip3.7 install opencv-python && pip3.7 install /*.whl; apt-get install -f -y &&
\
apt-get clean -y &&
\
rm -f /*.whl &&
\
${
PADDLE_VERSION
}
&&
\
ldconfig
${
DOCKERFILE_CUDNN_DSO
}
${
DOCKERFILE_CUBLAS_DSO
}
${
DOCKERFILE_GPU_ENV
}
ENV NCCL_LAUNCH_MODE PARALLEL
EOF
else
cat
>>
${
PADDLE_ROOT
}
/build/Dockerfile
<<
EOF
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
30e47bce
...
...
@@ -93,13 +93,13 @@ if(WITH_DISTRIBUTE)
if
(
NOT APPLE
)
set_tests_properties
(
test_dist_mnist PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_dist_word2vec PROPERTIES TIMEOUT 200
)
py_test_modules
(
test_dist_se_resnext MODULES test_dist_se_resnext
)
set_tests_properties
(
test_dist_se_resnext PROPERTIES TIMEOUT 1000
)
# FIXME(typhoonzero): add this back
#py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
# FIXME(typhoonzero): add these tests back
# py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext
)
# set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
#
py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#
set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
# TODO(typhoonzero): make dist test parallel when fix port management issue
set_tests_properties
(
test_dist_mnist test_dist_word2vec test_dist_
se_resnext test_dist_
ctr test_dist_simnet_bow test_dist_save_load test_dist_text_classification test_dist_mnist_batch_merge PROPERTIES RUN_SERIAL TRUE
)
set_tests_properties
(
test_dist_mnist test_dist_word2vec test_dist_ctr test_dist_simnet_bow test_dist_save_load test_dist_text_classification test_dist_mnist_batch_merge PROPERTIES RUN_SERIAL TRUE
)
endif
(
NOT APPLE
)
py_test_modules
(
test_dist_transpiler MODULES test_dist_transpiler
)
endif
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录