Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
03e11f3f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
03e11f3f
编写于
11月 08, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vscal jitcode
上级
5b7a9dd7
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
150 addition
and
85 deletion
+150
-85
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+35
-0
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+28
-2
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+1
-2
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+72
-71
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+9
-6
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+5
-4
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
03e11f3f
...
@@ -96,6 +96,41 @@ void VVVJitCode::generate() {
...
@@ -96,6 +96,41 @@ void VVVJitCode::generate() {
}
}
ret
();
ret
();
}
}
bool
VScalJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
void
VScalJitCode
::
generate
()
{
int
offset
=
0
;
vbroadcastss
(
ymm_src1
,
ptr
[
param1
]);
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovups
(
ptr
[
param3
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovq
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovq
(
ptr
[
param3
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovss
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmulss
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovss
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
ret
();
}
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
03e11f3f
...
@@ -29,9 +29,9 @@ using ymm_t = const Xbyak::Ymm;
...
@@ -29,9 +29,9 @@ using ymm_t = const Xbyak::Ymm;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
using
Label
=
Xbyak
::
Label
;
// function: vec = Operand(vec, vec) (maybe with relu)
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
// function: vec = Operand(vec, vec) (maybe with relu)
class
VVVJitCode
:
public
JitCode
{
class
VVVJitCode
:
public
JitCode
{
public:
public:
const
char
*
name
()
const
override
{
const
char
*
name
()
const
override
{
...
@@ -41,7 +41,7 @@ class VVVJitCode : public JitCode {
...
@@ -41,7 +41,7 @@ class VVVJitCode : public JitCode {
}
else
if
(
type_
==
operand_type
::
add
)
{
}
else
if
(
type_
==
operand_type
::
add
)
{
base
+=
"_Add"
;
base
+=
"_Add"
;
}
}
base
+=
(
with_relu_
?
"_
r
elu"
:
""
);
base
+=
(
with_relu_
?
"_
R
elu"
:
""
);
return
base
.
c_str
();
return
base
.
c_str
();
}
}
explicit
VVVJitCode
(
int
d
,
operand_type
type
,
bool
with_relu
,
explicit
VVVJitCode
(
int
d
,
operand_type
type
,
bool
with_relu
,
...
@@ -72,6 +72,32 @@ class VVVJitCode : public JitCode {
...
@@ -72,6 +72,32 @@ class VVVJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
};
};
class
VScalJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
VScalJitCode
);
explicit
VScalJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
xmm_t
xmm_src1
=
xmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
};
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
03e11f3f
...
@@ -83,8 +83,7 @@ class VAddReluKernel : public Kernel {
...
@@ -83,8 +83,7 @@ class VAddReluKernel : public Kernel {
template
<
typename
T
>
template
<
typename
T
>
class
VScalKernel
:
public
Kernel
{
class
VScalKernel
:
public
Kernel
{
public:
public:
virtual
void
Compute
(
const
T
a
,
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
virtual
void
Compute
(
const
T
a
,
T
*
x
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
03e11f3f
...
@@ -57,6 +57,13 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
...
@@ -57,6 +57,13 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
}
}
}
}
template
<
typename
T
>
void
VScalRefer
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
0
]
*
x
[
i
];
}
}
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
...
@@ -83,6 +90,28 @@ template <>
...
@@ -83,6 +90,28 @@ template <>
void
VAddMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
void
VAddMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
}
}
template
<
typename
T
>
void
VScalMKL
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
);
template
<
>
void
VScalMKL
<
float
>
(
const
float
*
a
,
const
float
*
x
,
float
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_sscal
(
n
,
*
a
,
y
,
1
);
}
else
{
VScalRefer
<
float
>
(
a
,
x
,
y
,
n
);
}
}
template
<
>
void
VScalMKL
<
double
>
(
const
double
*
a
,
const
double
*
x
,
double
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_dscal
(
n
,
*
a
,
y
,
1
);
}
else
{
VScalRefer
<
double
>
(
a
,
x
,
y
,
n
);
}
}
#endif
#endif
#define DECLARE_STATIC_FUNC \
#define DECLARE_STATIC_FUNC \
...
@@ -226,87 +255,60 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
...
@@ -226,87 +255,60 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
}
}
#endif
#endif
#undef DECLARE_STATIC_FUNC
/* VScal JitKernel */
template
<
typename
T
>
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
/* VSCAL JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
public:
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
this
->
num_
=
d
;
}
DECLARE_STATIC_FUNC
;
void
Compute
(
const
T
a
,
const
T
*
x
,
T
*
y
)
const
override
{
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
#ifdef PADDLE_WITH_XBYAK
y
[
i
]
=
a
*
x
[
i
];
if
(
useJIT
(
d
))
{
}
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
}
jitcode_
.
reset
(
new
gen
::
VScalJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
void
Compute
(
const
T
a
,
T
*
x
)
const
override
{
this
->
Compute
=
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
x
[
i
]
=
a
*
x
[
i
];
return
;
}
}
}
};
#endif
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
if
(
useMKL
(
d
))
{
template <> \
this
->
Compute
=
VScalMKL
<
T
>
;
void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
return
;
const { \
platform::dynload::cblas_sscal(this->num_, a, x, 1); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
const { \
platform::dynload::cblas_dscal(this->num_, a, x, 1); \
}
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#endif
this
->
Compute
=
VScalRefer
<
T
>
;
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(x, tmp); \
}
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
private:
INTRI8_FLOAT
(
jit
::
avx
);
std
::
unique_ptr
<
gen
::
VScalJitCode
>
jitcode_
{
nullptr
};
INTRI8_INPLACE_FLOAT
(
jit
::
avx
);
#endif
#endif
#ifdef __AVX2__
};
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI8_INPLACE_FLOAT
(
jit
::
avx2
);
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VScalKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VScalJitCode
::
init
(
d
);
}
#endif
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
#ifdef PADDLE_WITH_MKLML
INTRI8_INPLACE_FLOAT
(
jit
::
avx512f
);
template
<
>
bool
VScalKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VScalKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef DECLARE_STATIC_FUNC
#undef INTRI8_INPLACE_FLOAT
#undef MKL_FLOAT
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
#undef MKL_DOUBLE
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
/* VAddBias JitKernel */
/* VAddBias JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
...
@@ -467,7 +469,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
...
@@ -467,7 +469,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{}
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{}
};
};
REGISTER_JITKERNEL_DEPRECATED
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
03e11f3f
...
@@ -409,9 +409,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -409,9 +409,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_
=
KernelPool
::
Instance
().
template
Get
<
VAddBiasKernel
<
T
>
>
(
d
);
vaddbias_
=
KernelPool
::
Instance
().
template
Get
<
VAddBiasKernel
<
T
>
>
(
d
);
}
}
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{
vscal_
->
Compute
(
static_cast
<
T
>
(
2
),
x
,
y
);
const
T
a
=
static_cast
<
T
>
(
2
);
vscal_
->
Compute
(
&
a
,
x
,
y
,
this
->
num_
);
vsigmoid_
->
Compute
(
y
,
y
);
vsigmoid_
->
Compute
(
y
,
y
);
vscal_
->
Compute
(
static_cast
<
T
>
(
2
),
y
);
vscal_
->
Compute
(
&
a
,
y
,
y
,
this
->
num_
);
vaddbias_
->
Compute
(
static_cast
<
T
>
(
-
1
),
y
,
y
);
vaddbias_
->
Compute
(
static_cast
<
T
>
(
-
1
),
y
,
y
);
}
}
...
@@ -472,9 +473,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -472,9 +473,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
_mm256_storeu_ps(y, tmp); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(
2.f, y);
\
vscal_->Compute(
&a, y, y, this->num_);
\
vaddbias_->Compute(-1.f, y, y); \
vaddbias_->Compute(-1.f, y, y); \
}
}
...
@@ -502,9 +504,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -502,9 +504,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
} \
} \
x += this->end_; \
x += this->end_; \
y += this->end_; \
y += this->end_; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(
2.f, y);
\
vscal_->Compute(
&a, y, y, this->num_);
\
vaddbias_->Compute(-1.f, y, y); \
vaddbias_->Compute(-1.f, y, y); \
}
}
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
03e11f3f
...
@@ -281,9 +281,10 @@ void vtanh_better(
...
@@ -281,9 +281,10 @@ void vtanh_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VAddBiasKernel
<
float
>>&
const
paddle
::
operators
::
math
::
jitkernel
::
VAddBiasKernel
<
float
>>&
vaddbias
,
vaddbias
,
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
int
n
,
const
float
*
x
,
float
*
y
)
{
vscal
->
Compute
(
2.
f
,
x
,
y
);
const
float
tmp1
=
2.
f
;
vscal
->
Compute
(
&
tmp1
,
x
,
y
,
n
);
vsigmoid
->
Compute
(
y
,
y
);
vsigmoid
->
Compute
(
y
,
y
);
vscal
->
Compute
(
2.
f
,
y
);
vscal
->
Compute
(
&
tmp1
,
y
,
y
,
n
);
vaddbias
->
Compute
(
-
1.
f
,
y
,
y
);
vaddbias
->
Compute
(
-
1.
f
,
y
,
y
);
}
}
...
@@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {
...
@@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
a
,
x_data
,
ztgt_data
);
ker
->
Compute
(
&
a
,
x_data
,
ztgt_data
,
d
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
auto
ttgts1
=
GetCurrentUS
();
auto
ttgts1
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
a
,
y_data
);
ker
->
Compute
(
&
a
,
y_data
,
y_data
,
d
);
}
}
auto
ttgte1
=
GetCurrentUS
();
auto
ttgte1
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录