Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
0043c42b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0043c42b
编写于
11月 12, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vrelu jitcode
test=develop
上级
9a6e2392
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
245 addition
and
245 deletion
+245
-245
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+33
-0
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+23
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+7
-6
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+42
-99
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+108
-108
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+19
-19
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+13
-13
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
0043c42b
...
@@ -118,6 +118,39 @@ void VXXJitCode::generate() {
...
@@ -118,6 +118,39 @@ void VXXJitCode::generate() {
ret
();
ret
();
}
}
bool
ReluJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
void
ReluJitCode
::
generate
()
{
int
offset
=
0
;
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
ret
();
}
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
0043c42b
...
@@ -85,6 +85,29 @@ class VXXJitCode : public JitCode {
...
@@ -85,6 +85,29 @@ class VXXJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
3
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
};
class
ReluJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
ReluJitCode
);
explicit
ReluJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_zero
=
xmm_t
(
0
);
xmm_t
xmm_src
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_zero
=
ymm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
0043c42b
...
@@ -97,37 +97,38 @@ class VAddBiasKernel : public Kernel {
...
@@ -97,37 +97,38 @@ class VAddBiasKernel : public Kernel {
template
<
typename
T
>
template
<
typename
T
>
class
VActKernel
:
public
Kernel
{
class
VActKernel
:
public
Kernel
{
public:
public:
virtual
void
Compute
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
Deprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VReluKernel
:
public
VActKernel
<
T
>
{
class
VReluKernel
:
public
VActKernel
<
T
>
{
public:
public:
virtual
void
Compute
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VIdentityKernel
:
public
VActKernel
<
T
>
{
class
VIdentityKernel
:
public
VActKernel
<
T
>
{
public:
public:
virtual
void
Compute
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
Deprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VExpKernel
:
public
VActKernel
<
T
>
{
class
VExpKernel
:
public
VActKernel
<
T
>
{
public:
public:
virtual
void
Compute
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
Deprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{
public:
public:
virtual
void
Compute
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
Deprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VTanhKernel
:
public
VActKernel
<
T
>
{
class
VTanhKernel
:
public
VActKernel
<
T
>
{
public:
public:
virtual
void
Compute
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
Deprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
0043c42b
...
@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
...
@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
}
}
}
}
template
<
typename
T
>
void
VReluRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
x
[
i
]
>
0
?
x
[
i
]
:
0
;
}
}
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
...
@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
...
@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
}
}
#endif
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
/* VRelu JitKernel */
/* VRelu JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
public:
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
this
->
num_
=
d
;
}
DECLARE_STATIC_FUNC
;
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
y
[
i
]
=
x
[
i
]
>
0
?
x
[
i
]
:
0
;
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/*init*/
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions*/
*
8
/*everage byte for each instruction*/
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
}
}
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa) \
this
->
Compute
=
VReluRefer
<
T
>
;
template <> \
VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + this->rest_, tmp1); \
}
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
#define INTRI_GT16_FLOAT(isa) \
VReluRefer
(
x
,
y
,
this
->
num_
);
template <> \
VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + i, tmp); \
} \
__m256 tmp = _mm256_loadu_ps(x + this->rest_); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + this->rest_, tmp); \
}
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
private:
INTRI8_FLOAT
(
jit
::
avx
);
std
::
unique_ptr
<
gen
::
ReluJitCode
>
jitcode_
{
nullptr
};
INTRI16_FLOAT
(
jit
::
avx
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx
);
INTRI_GT16_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI16_FLOAT
(
jit
::
avx2
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx2
);
INTRI_GT16_FLOAT
(
jit
::
avx2
);
#endif
#endif
#ifdef __AVX512F__
};
// TODO(TJ): refine avx512
INTRI8_FLOAT
(
jit
::
avx512f
);
#ifdef PADDLE_WITH_XBYAK
INTRI16_FLOAT
(
jit
::
avx512f
);
template
<
>
INTRI_GT8LT16_FLOAT
(
jit
::
avx512f
);
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
INTRI_GT16_FLOAT
(
jit
::
avx512f
);
return
gen
::
ReluJitCode
::
init
(
d
);
}
#endif
#endif
#undef INTRI8_FLOAT
#undef DECLARE_STATIC_FUNC
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
#undef INTRI_GT16_FLOAT
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
/* An empty JitKernel */
/* An empty JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
public:
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
num_
=
d
;
}
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{}
void
Compute
Deprecated
(
const
T
*
x
,
T
*
y
)
const
override
{}
};
};
REGISTER_JITKERNEL_DEPRECATED
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
}
// namespace jitkernel
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
0043c42b
此差异已折叠。
点击以展开。
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
0043c42b
...
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_gate_d3_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
...
@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
...
@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_gate_d2_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* get ogated*/
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
...
@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
}
}
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
act_gate_d_
->
Compute
(
gates
,
gates
);
act_gate_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
);
act_state_d_
->
Compute
Deprecated
(
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
}
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
(
gates
,
gates
);
act_gate_d2_
->
Compute
Deprecated
(
gates
,
gates
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
}
}
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
T
*
y
=
gates
+
d2_
;
T
*
y
=
gates
+
d2_
;
act_state_d_
->
Compute
(
y
,
y
);
act_state_d_
->
Compute
Deprecated
(
y
,
y
);
// out = zt*ht~ + (1-zt)*ht_1
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
d_
;
++
i
)
{
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
0043c42b
...
@@ -92,7 +92,7 @@ TEST(JitKernel, vrelu) {
...
@@ -92,7 +92,7 @@ TEST(JitKernel, vrelu) {
#endif
#endif
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
VLOG
(
30
)
<<
"Vec size "
<<
d
VLOG
(
30
)
<<
"Vec size "
<<
d
...
@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) {
...
@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) {
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
ztgt_data
);
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
@@ -222,7 +222,7 @@ void vsigmoid_better(
...
@@ -222,7 +222,7 @@ void vsigmoid_better(
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
0.
f
-
y
[
i
];
y
[
i
]
=
0.
f
-
y
[
i
];
}
}
vexp
->
Compute
(
y
,
y
);
vexp
->
Compute
Deprecated
(
y
,
y
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
}
}
...
@@ -253,7 +253,7 @@ TEST(JitKernel, vsigmoid) {
...
@@ -253,7 +253,7 @@ TEST(JitKernel, vsigmoid) {
auto
trefe
=
GetCurrentUS
();
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
ztgt_data
);
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
@@ -287,7 +287,7 @@ void vtanh_better(
...
@@ -287,7 +287,7 @@ void vtanh_better(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
float
a
=
2.
f
,
b
=
-
1.
f
;
const
float
a
=
2.
f
,
b
=
-
1.
f
;
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vsigmoid
->
Compute
(
y
,
y
);
vsigmoid
->
Compute
Deprecated
(
y
,
y
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
}
}
...
@@ -321,7 +321,7 @@ TEST(JitKernel, vtanh) {
...
@@ -321,7 +321,7 @@ TEST(JitKernel, vtanh) {
auto
trefe
=
GetCurrentUS
();
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
ztgt_data
);
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
@@ -344,8 +344,8 @@ void lstm_ctht_ref(
...
@@ -344,8 +344,8 @@ void lstm_ctht_ref(
const
std
::
shared_ptr
<
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VExpKernel
<
float
>>&
vexp_1
,
const
paddle
::
operators
::
math
::
jitkernel
::
VExpKernel
<
float
>>&
vexp_1
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
);
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
(
gates
,
gates
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
const
float
*
i
=
gates
+
d
,
*
f
=
gates
+
d
*
2
,
*
o
=
gates
+
d
*
3
;
const
float
*
i
=
gates
+
d
,
*
f
=
gates
+
d
*
2
,
*
o
=
gates
+
d
*
3
;
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
...
@@ -355,7 +355,7 @@ void lstm_ctht_ref(
...
@@ -355,7 +355,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated
// H_t = act_cell(C_t) * ogated
float
tmp
=
ct
[
k
]
*
2
;
float
tmp
=
ct
[
k
]
*
2
;
tmp
=
0.
f
-
((
tmp
<
min
)
?
min
:
((
tmp
>
max
)
?
max
:
tmp
));
tmp
=
0.
f
-
((
tmp
<
min
)
?
min
:
((
tmp
>
max
)
?
max
:
tmp
));
vexp_1
->
Compute
(
&
tmp
,
&
tmp
);
vexp_1
->
Compute
Deprecated
(
&
tmp
,
&
tmp
);
tmp
=
2.
f
/
(
1.
f
+
tmp
)
-
1.
f
;
tmp
=
2.
f
/
(
1.
f
+
tmp
)
-
1.
f
;
ht
[
k
]
=
tmp
*
o
[
k
];
ht
[
k
]
=
tmp
*
o
[
k
];
}
}
...
@@ -373,13 +373,13 @@ void lstm_ctht_better(
...
@@ -373,13 +373,13 @@ void lstm_ctht_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd_d
,
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd_d
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
int
d2
=
d
*
2
;
int
d2
=
d
*
2
;
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
);
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
(
gates
,
gates
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
(
ct
,
gates
+
d2
);
vtanh_d
->
Compute
Deprecated
(
ct
,
gates
+
d2
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
}
}
...
@@ -736,7 +736,7 @@ void vaddrelu_better(
...
@@ -736,7 +736,7 @@ void vaddrelu_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vrelu
->
Compute
(
z
,
z
);
vrelu
->
Compute
Deprecated
(
z
,
z
);
}
}
TEST
(
JitKernel
,
vaddrelu
)
{
TEST
(
JitKernel
,
vaddrelu
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录