Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ae179269
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae179269
编写于
12月 14, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable jitkernel mkl vmul, vadd and vscal
上级
77907a35
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
90 addition
and
63 deletion
+90
-63
paddle/fluid/operators/jit/README.md
paddle/fluid/operators/jit/README.md
+2
-0
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+2
-0
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+66
-2
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+20
-11
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+0
-50
未找到文件。
paddle/fluid/operators/jit/README.md
浏览文件 @
ae179269
...
...
@@ -45,6 +45,8 @@ PaddlePaddle/Paddle/paddle/fluid/
-
在
`KernelType`
中添加
`your_key`
.
-
实现Reference 的逻辑,每个jitkernel的Reference 实现是必须的。不要依赖任何第三方库。并在
`refer/CmakeLists.txt`
中
`USE_JITKERNEL_REFER(your_key)`
.
-
(optional) 实现更多的算法在
`more`
目录下,可以依赖mkl,openblas,或者mkldnn等第三方库。
-
(optional) 实现基于Xbyak的生成code,在
`gen`
目下。
-
必要时可以添加新的
`KernelTuples`
,可以参考
`XYZNTuples`
,新加的Attr类型需要特例化
`JitCodeKey`
方法。
-
添加unit test,需要测试float和double
-
添加benchmark确保get得到的速度是最快。
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
浏览文件 @
ae179269
...
...
@@ -4,3 +4,5 @@ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE
# use mkl kernels by name and type
USE_JITKERNEL_MORE
(
vmul, mkl
)
USE_JITKERNEL_MORE
(
vadd, mkl
)
USE_JITKERNEL_MORE
(
vscal, mkl
)
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
ae179269
...
...
@@ -13,7 +13,9 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace
paddle
{
...
...
@@ -32,6 +34,61 @@ void VMul<double>(const double* x, const double* y, double* z, int n) {
platform
::
dynload
::
vdMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VAdd
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsAdd
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VAdd
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VScal
<
float
>
(
const
float
*
a
,
const
float
*
x
,
float
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_sscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
float
>
(
a
,
x
,
y
,
n
);
}
}
template
<
>
void
VScal
<
double
>
(
const
double
*
a
,
const
double
*
x
,
double
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_dscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
double
>
(
a
,
x
,
y
,
n
);
}
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template
<
>
bool
VMulKernel
<
float
>::
UseMe
(
int
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VAddKernel
<
float
>::
UseMe
(
int
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VScalKernel
<
float
>::
UseMe
(
int
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(int d) const { \
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
AWALYS_USE_ME_WITH_DOUBLE
(
VAdd
);
AWALYS_USE_ME_WITH_DOUBLE
(
VScal
);
#undef AWALYS_USE_ME_WITH_DOUBLE
}
// namespace mkl
}
// namespace more
}
// namespace jit
...
...
@@ -40,5 +97,12 @@ void VMul<double>(const double* x, const double* y, double* z, int n) {
namespace
mkl
=
paddle
::
operators
::
jit
::
more
::
mkl
;
REGISTER_JITKERNEL_MORE
(
vmul
,
mkl
,
mkl
::
VMulKernel
<
float
>
,
mkl
::
VMulKernel
<
double
>
);
#define REGISTER_MKL_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL
(
vmul
,
VMul
);
REGISTER_MKL_KERNEL
(
vadd
,
VAdd
);
REGISTER_MKL_KERNEL
(
vscal
,
VScal
);
#undef REGISTER_MKL_KERNEL
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
ae179269
...
...
@@ -16,7 +16,6 @@
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -28,17 +27,27 @@ template <typename T>
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
typename
T
>
class
VMulKernel
:
public
KernelImpl
<
XYZNTuples
<
T
>>
{
public:
VMulKernel
()
{
this
->
func
=
VMul
<
T
>
;
}
bool
UseMe
(
int
d
)
const
override
{
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
else
{
return
true
;
}
void
VAdd
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
typename
T
>
void
VScal
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
);
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelImpl<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(typename tuples<T>::attr_type) const override; \
}
};
// XYZN
DECLARE_MKL_KERNEL
(
VMul
,
XYZNTuples
);
DECLARE_MKL_KERNEL
(
VAdd
,
XYZNTuples
);
// AXYN
DECLARE_MKL_KERNEL
(
VScal
,
AXYNTuples
);
#undef DECLARE_MKL_KERNEL
}
// namespace mkl
}
// namespace more
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
ae179269
...
...
@@ -31,56 +31,6 @@ namespace operators {
namespace
math
{
namespace
jitkernel
{
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
>
void
VMulMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VMulMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdMul
(
n
,
x
,
y
,
z
);
}
template
<
typename
T
>
void
VAddMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
>
void
VAddMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsAdd
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VAddMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
}
template
<
typename
T
>
void
VScalMKL
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
);
template
<
>
void
VScalMKL
<
float
>
(
const
float
*
a
,
const
float
*
x
,
float
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_sscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
float
>
(
a
,
x
,
y
,
n
);
}
}
template
<
>
void
VScalMKL
<
double
>
(
const
double
*
a
,
const
double
*
x
,
double
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_dscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
double
>
(
a
,
x
,
y
,
n
);
}
}
#endif
/* VMUL JitKernel */
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录