Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
极客李华
Paddle
提交
22125eba
P
Paddle
项目概览
极客李华
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
22125eba
编写于
11月 09, 2018
作者:
T
tensor-tang
提交者:
GitHub
11月 09, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14321 from tensor-tang/fea/jit/vscal
Fea jitcode vscal vaddbias
上级
f1046d7e
5e64244f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
193 addition
and
161 deletion
+193
-161
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+33
-11
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+24
-11
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+4
-3
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+113
-121
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+12
-9
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+7
-6
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
22125eba
...
...
@@ -24,21 +24,30 @@ namespace gen {
using
namespace
platform
::
jit
;
// NOLINT
bool
V
VVJitCode
::
init
(
int
d
)
{
bool
V
XXJitCode
::
init
(
int
d
,
int
scalar_index
)
{
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return
MayIUse
(
avx
);
return
MayIUse
(
avx
)
&&
scalar_index
>=
0
&&
scalar_index
<=
2
;
}
void
V
VV
JitCode
::
generate
()
{
void
V
XX
JitCode
::
generate
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
int
offset
=
0
;
if
(
with_relu_
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
if
(
scalar_index_
==
1
)
{
vbroadcastss
(
ymm_src1
,
ptr
[
param1
]);
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
else
if
(
type_
==
operand_type
::
add
)
{
...
...
@@ -52,8 +61,12 @@ void VVVJitCode::generate() {
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
}
else
if
(
type_
==
operand_type
::
add
)
{
...
...
@@ -67,8 +80,12 @@ void VVVJitCode::generate() {
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovq
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovq
(
xmm_src2
,
ptr
[
param2
+
offset
]);
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
}
else
if
(
type_
==
operand_type
::
add
)
{
...
...
@@ -82,8 +99,12 @@ void VVVJitCode::generate() {
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovss
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovss
(
xmm_src2
,
ptr
[
param2
+
offset
]);
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulss
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
}
else
if
(
type_
==
operand_type
::
add
)
{
...
...
@@ -96,6 +117,7 @@ void VVVJitCode::generate() {
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
22125eba
...
...
@@ -29,33 +29,46 @@ using ymm_t = const Xbyak::Ymm;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
// function: vec = Operand(vec, vec) (maybe with relu)
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
class
VVVJitCode
:
public
JitCode
{
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"VVVJitCode"
;
std
::
string
base
=
"VXXJitCode"
;
if
(
scalar_index_
==
1
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
if
(
type_
==
operand_type
::
mul
)
{
base
+=
"_Mul"
;
}
else
if
(
type_
==
operand_type
::
add
)
{
base
+=
"_Add"
;
}
base
+=
(
with_relu_
?
"_relu"
:
""
);
if
(
scalar_index_
==
2
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
base
+=
(
with_relu_
?
"_Relu"
:
""
);
return
base
.
c_str
();
}
explicit
VVVJitCode
(
int
d
,
operand_type
type
,
bool
with_relu
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
explicit
VXXJitCode
(
int
d
,
operand_type
type
,
int
scalar_index
,
bool
with_relu
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{}
static
bool
init
(
int
d
);
static
bool
init
(
int
d
,
int
scalar_index
=
0
);
void
generate
()
override
;
private:
int
num_
;
operand_type
type_
;
int
scalar_index_
;
bool
with_relu_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
...
...
@@ -63,13 +76,13 @@ class VVVJitCode : public JitCode {
xmm_t
xmm_src1
=
xmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
xmm_t
xmm_zero
=
xmm_t
(
2
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
xmm_t
xmm_zero
=
xmm_t
(
3
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
}
// namespace gen
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
22125eba
...
...
@@ -83,14 +83,15 @@ class VAddReluKernel : public Kernel {
template
<
typename
T
>
class
VScalKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
T
a
,
const
T
*
x
,
T
*
y
)
const
=
0
;
v
irtual
void
Compute
(
const
T
a
,
T
*
x
)
const
=
0
;
// y = a.*x
v
oid
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
;
};
template
<
typename
T
>
class
VAddBiasKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
T
a
,
const
T
*
x
,
T
*
y
)
const
=
0
;
// y = a.+x
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
22125eba
...
...
@@ -57,6 +57,20 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
}
}
template
<
typename
T
>
void
VScalRefer
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
0
]
*
x
[
i
];
}
}
template
<
typename
T
>
void
VAddBiasRefer
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
0
]
+
x
[
i
];
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
...
...
@@ -83,6 +97,28 @@ template <>
void
VAddMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
}
template
<
typename
T
>
void
VScalMKL
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
);
template
<
>
void
VScalMKL
<
float
>
(
const
float
*
a
,
const
float
*
x
,
float
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_sscal
(
n
,
*
a
,
y
,
1
);
}
else
{
VScalRefer
<
float
>
(
a
,
x
,
y
,
n
);
}
}
template
<
>
void
VScalMKL
<
double
>
(
const
double
*
a
,
const
double
*
x
,
double
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_dscal
(
n
,
*
a
,
y
,
1
);
}
else
{
VScalRefer
<
double
>
(
a
,
x
,
y
,
n
);
}
}
#endif
#define DECLARE_STATIC_FUNC \
...
...
@@ -102,7 +138,7 @@ class VMulKernelImpl : public VMulKernel<T> {
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
V
VVJitCode
(
d
,
gen
::
operand_type
::
mul
,
false
,
jitcode_
.
reset
(
new
gen
::
V
XXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
...
...
@@ -121,14 +157,14 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
V
VV
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
XX
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VMulKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
VV
JitCode
::
init
(
d
);
return
gen
::
V
XX
JitCode
::
init
(
d
);
}
#endif
...
...
@@ -153,7 +189,7 @@ class VAddKernelImpl : public VAddKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
V
VVJitCode
(
d
,
gen
::
operand_type
::
add
,
false
,
jitcode_
.
reset
(
new
gen
::
V
XXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
...
...
@@ -171,14 +207,14 @@ class VAddKernelImpl : public VAddKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
V
VV
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
XX
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
VV
JitCode
::
init
(
d
);
return
gen
::
V
XX
JitCode
::
init
(
d
);
}
#endif
...
...
@@ -203,7 +239,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
V
VVJitCode
(
d
,
gen
::
operand_type
::
add
,
true
,
jitcode_
.
reset
(
new
gen
::
V
XXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
...
...
@@ -215,148 +251,106 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
V
VV
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
V
XX
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
V
VV
JitCode
::
init
(
d
);
return
gen
::
V
XX
JitCode
::
init
(
d
);
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
/* VSCAL JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
/* VScal JitKernel */
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
Compute
(
const
T
a
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
y
[
i
]
=
a
*
x
[
i
];
}
}
void
Compute
(
const
T
a
,
T
*
x
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
x
[
i
]
=
a
*
x
[
i
];
DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
}
};
#endif
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
const { \
platform::dynload::cblas_sscal(this->num_, a, x, 1); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
const { \
platform::dynload::cblas_dscal(this->num_, a, x, 1); \
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
if
(
useMKL
(
d
))
{
this
->
Compute
=
VScalMKL
<
T
>
;
return
;
}
#endif
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(x, tmp); \
this
->
Compute
=
VScalRefer
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
INTRI8_INPLACE_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI8_INPLACE_FLOAT
(
jit
::
avx2
);
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
INTRI8_INPLACE_FLOAT
(
jit
::
avx512f
);
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VScalKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
,
1
);
}
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI8_INPLACE_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VScalKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VScalKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VAddBias JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
Compute
(
const
T
a
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
y
[
i
]
=
x
[
i
]
+
a
;
DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp); \
}
#endif
#define INTRI16_FLOAT(isa) \
template <> \
void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \
tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
this
->
Compute
=
VAddBiasRefer
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
INTRI16_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI16_FLOAT
(
jit
::
avx2
);
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
INTRI16_FLOAT
(
jit
::
avx512f
);
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddBiasKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
,
1
);
}
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
/* VRelu JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -467,8 +461,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{}
};
REGISTER_JITKERNEL_DEPRECATED
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
22125eba
...
...
@@ -409,10 +409,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_
=
KernelPool
::
Instance
().
template
Get
<
VAddBiasKernel
<
T
>
>
(
d
);
}
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{
vscal_
->
Compute
(
static_cast
<
T
>
(
2
),
x
,
y
);
const
T
a
=
static_cast
<
T
>
(
2
),
b
=
static_cast
<
T
>
(
-
1
);
vscal_
->
Compute
(
&
a
,
x
,
y
,
this
->
num_
);
vsigmoid_
->
Compute
(
y
,
y
);
vscal_
->
Compute
(
static_cast
<
T
>
(
2
),
y
);
vaddbias_
->
Compute
(
static_cast
<
T
>
(
-
1
),
y
,
y
);
vscal_
->
Compute
(
&
a
,
y
,
y
,
this
->
num_
);
vaddbias_
->
Compute
(
&
b
,
y
,
y
,
this
->
num_
);
}
private:
...
...
@@ -472,10 +473,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(
2.f, y);
\
vaddbias_->Compute(
-1.f, y, y);
\
vscal_->Compute(
&a, y, y, this->num_);
\
vaddbias_->Compute(
&b, y, y, this->num_);
\
}
#define INTRI_GT16_FLOAT(isa, expisa) \
...
...
@@ -502,10 +504,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
} \
x += this->end_; \
y += this->end_; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(
2.f, y);
\
vaddbias_->Compute(
-1.f, y, y);
\
vscal_->Compute(
&a, y, y, this->num_);
\
vaddbias_->Compute(
&b, y, y, this->num_);
\
}
#ifdef __AVX__
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
22125eba
...
...
@@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) {
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
a
,
x_data
,
ztgt_data
);
ker
->
Compute
(
&
a
,
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -281,10 +281,11 @@ void vtanh_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VAddBiasKernel
<
float
>>&
vaddbias
,
const
int
n
,
const
float
*
x
,
float
*
y
)
{
vscal
->
Compute
(
2.
f
,
x
,
y
);
const
float
a
=
2.
f
,
b
=
-
1.
f
;
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vsigmoid
->
Compute
(
y
,
y
);
vscal
->
Compute
(
2.
f
,
y
);
vaddbias
->
Compute
(
-
1.
f
,
y
,
y
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
}
TEST
(
JitKernel
,
vtanh
)
{
...
...
@@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
a
,
x_data
,
ztgt_data
);
ker
->
Compute
(
&
a
,
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgts1
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
a
,
y_data
);
ker
->
Compute
(
&
a
,
y_data
,
y_data
,
d
);
}
auto
ttgte1
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录