Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bb09e310
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bb09e310
编写于
11月 06, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vadd jitcode
test=develop
上级
d55481cf
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
135 addition
and
65 deletion
+135
-65
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+36
-0
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+24
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+1
-1
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+64
-54
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+5
-5
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+5
-5
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
bb09e310
...
...
@@ -66,6 +66,42 @@ void VMulJitCode::generate() {
ret
();
}
bool
VAddJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
void
VAddJitCode
::
generate
()
{
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovups
(
ptr
[
param3
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovq
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovq
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovq
(
ptr
[
param3
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovss
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovss
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vaddss
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovss
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
bb09e310
...
...
@@ -53,6 +53,30 @@ class VMulJitCode : public JitCode {
ymm_t
ymm_dst
=
ymm_t
(
2
);
};
class
VAddJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
VAddJitCode
);
explicit
VAddJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
xmm_t
xmm_src1
=
xmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
bb09e310
...
...
@@ -71,7 +71,7 @@ class VMulKernel : public Kernel {
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
public:
v
irtual
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
v
oid
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
bb09e310
...
...
@@ -39,6 +39,13 @@ void VMulRefer(const T* x, const T* y, T* z, int n) {
}
}
template
<
typename
T
>
void
VAddRefer
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
...
...
@@ -47,22 +54,38 @@ template <>
void
VMulMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VMulMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdMul
(
n
,
x
,
y
,
z
);
}
template
<
typename
T
>
void
VAddMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
>
void
VAddMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsAdd
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VAddMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
}
#endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
static
inline
std
::
string
name
(
int
d
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
...
...
@@ -100,63 +123,51 @@ bool VMulKernelImpl<double>::useMKL(int d) {
return
true
;
}
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
/* VADD JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
/* VAdd JitKernel */
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VAddJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VAddMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
VAddRefer
<
T
>
;
}
private:
std
::
unique_ptr
<
gen
::
VAddJitCode
>
jitcode_
{
nullptr
};
};
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VAddKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
platform::dynload::vsAdd(this->num_, x, y, z); \
}
template
<
>
bool
VAddKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VAddJitCode
::
init
(
d
);
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VAddKernelImpl<double, isa, block>::Compute( \
const double* x, const double* y, double* z) const { \
platform::dynload::vdAdd(this->num_, x, y, z); \
}
template
<
>
bool
VAddKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
template
<
>
bool
VAddKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#define INTRI8_FLOAT(isa) \
template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef DECLARE_STATIC_FUNC
#undef INTRI8_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
/* VSCAL JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -480,7 +491,6 @@ INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL_DEPRECATED
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vrelu
,
VReluKernel
);
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
bb09e310
...
...
@@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
...
...
@@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
/* get fgated and igated*/
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
...
...
@@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
bb09e310
...
...
@@ -371,7 +371,7 @@ void lstm_ctht_better(
vtanh_d
->
Compute
(
gates
,
gates
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
(
ct
,
gates
+
d2
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
...
...
@@ -695,7 +695,7 @@ TEST(JitKernel, vadd) {
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -723,8 +723,8 @@ void vaddrelu_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
vadd
->
Compute
(
x
,
y
,
z
);
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vrelu
->
Compute
(
z
,
z
);
}
...
...
@@ -752,7 +752,7 @@ TEST(JitKernel, vaddrelu) {
auto
trefe
=
GetCurrentUS
();
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vaddrelu_better
(
vadd
,
vrelu
,
x_data
,
y_data
,
zref_data
);
vaddrelu_better
(
vadd
,
vrelu
,
x_data
,
y_data
,
zref_data
,
d
);
}
auto
tmkle
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录