Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
159be8cc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
159be8cc
编写于
10月 23, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize fusion gru kernel at size 8
上级
83dc6898
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
123 addition
and
55 deletion
+123
-55
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+117
-55
python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
+6
-0
未找到文件。
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
159be8cc
...
@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
...
@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return
nullptr
;
return
nullptr
;
}
}
template
<
jit
::
cpu_isa_t
isa
>
static
std
::
unique_ptr
<
AVXAct
>
GetAVXAct
(
const
std
::
string
&
type
)
{
if
(
type
==
"sigmoid"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kSigmoid
,
isa
>
());
}
else
if
(
type
==
"relu"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kRelu
,
isa
>
());
}
else
if
(
type
==
"tanh"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kTanh
,
isa
>
());
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kIdentity
,
isa
>
());
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
/* LSTM JitKernel */
/* LSTM JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
...
@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif
#endif
};
};
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \
const std::string& act_cell, int d) \
: LSTMKernel<float>() { \
: LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
if (type == "sigmoid") { \
avx_act_cand_ = GetAVXAct<isa>(act_cand); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} else if (type == "relu") { \
} \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
template <> \
} else if (type == "tanh") { \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
float* gates, const float* ct_1, float* ct, float* ht, \
} else if (type == "identity" || type == "") { \
const float* wp_data, float* checked) const { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
/* gates: W_ch, W_ih, W_fh, W_oh */
\
} \
__m256 c, i, f, o; \
PADDLE_THROW("Not support type: %s", type); \
c = _mm256_loadu_ps(gates); \
}; \
i = _mm256_loadu_ps(gates + 8); \
avx_act_gate_ = GetAVXAct(act_gate); \
f = _mm256_loadu_ps(gates + 16); \
avx_act_cand_ = GetAVXAct(act_cand); \
o = _mm256_loadu_ps(gates + 24); \
avx_act_cell_ = GetAVXAct(act_cell); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
} \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
template <> \
i = _mm256_loadu_ps(ct_1); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
float* gates, const float* ct_1, float* ct, float* ht, \
f = _mm256_add_ps(c, f); \
const float* wp_data, float* checked) const { \
_mm256_storeu_ps(ct, f); \
/* gates: W_ch, W_ih, W_fh, W_oh */
\
/* H_t = act_cell(C_t) * ogated */
\
__m256 c, i, f, o; \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
c = _mm256_loadu_ps(gates); \
_mm256_storeu_ps(ht, o); \
i = _mm256_loadu_ps(gates + 8); \
} \
f = _mm256_loadu_ps(gates + 16); \
template <> \
o = _mm256_loadu_ps(gates + 24); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
float* gates, float* ct, float* ht, const float* wp_data) const { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
__m256 c, i, o; \
i = _mm256_loadu_ps(ct_1); \
c = _mm256_loadu_ps(gates); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_add_ps(c, f); \
o = _mm256_loadu_ps(gates + 24); \
_mm256_storeu_ps(ct, f); \
/* C_t = igated * cgated*/
\
/* H_t = act_cell(C_t) * ogated */
\
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ct, c); \
_mm256_storeu_ps(ht, o); \
/* H_t = act_cell(C_t) * ogated */
\
} \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
template <> \
_mm256_storeu_ps(ht, o); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/
\
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
}
// TODO(TJ): optimize keq16
// TODO(TJ): optimize keq16
...
@@ -375,6 +378,7 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -375,6 +378,7 @@ class GRUKernelImpl : public GRUKernel<T> {
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
);
}
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
(
gates
,
gates
);
act_gate_d2_
->
Compute
(
gates
,
gates
);
...
@@ -394,8 +398,65 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -394,8 +398,65 @@ class GRUKernelImpl : public GRUKernel<T> {
int
d_
,
d2_
;
int
d_
,
d2_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d2_
,
act_gate_d_
,
act_state_d_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d2_
,
act_gate_d_
,
act_state_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
#ifdef __AVX__
std
::
unique_ptr
<
const
AVXAct
>
avx_act_gate_
,
avx_act_state_
;
#endif
};
};
#define INTRI8_FLOAT(isa) \
template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
const std::string& act_gate, const std::string& act_state, int d) \
: GRUKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_state_ = GetAVXAct<isa>(act_state); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */
\
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */
\
__m256 r, ht0; \
r = _mm256_loadu_ps(gates + 8); \
ht0 = _mm256_loadu_ps(ht_1); \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
_mm256_storeu_ps(ht, r); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */
\
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
...
@@ -412,6 +473,7 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -412,6 +473,7 @@ class GRUKernelImpl : public GRUKernel<T> {
REGISTER_JITKERNEL_ARGS
(
gru
,
GRUKernel
,
JITKERNEL_DECLARE_GRU
,
REGISTER_JITKERNEL_ARGS
(
gru
,
GRUKernel
,
JITKERNEL_DECLARE_GRU
,
JITKERNEL_KEY_GRU
,
JITKERNEL_NEW_GRU_IMPL
);
JITKERNEL_KEY_GRU
,
JITKERNEL_NEW_GRU_IMPL
);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
#undef JITKERNEL_DECLARE_GRU
...
...
python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
浏览文件 @
159be8cc
...
@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
...
@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self
.
D
=
8
self
.
D
=
8
class
TestFusionGRUOpMD3
(
TestFusionGRUOp
):
def
set_confs
(
self
):
self
.
M
=
17
self
.
D
=
15
class
TestFusionGRUOpBS1
(
TestFusionGRUOp
):
class
TestFusionGRUOpBS1
(
TestFusionGRUOp
):
def
set_confs
(
self
):
def
set_confs
(
self
):
self
.
lod
=
[[
3
]]
self
.
lod
=
[[
3
]]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录