Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3d928d4f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3d928d4f
编写于
9月 28, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine and seepdup
上级
77fc42d2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
103 addition
and
108 deletion
+103
-108
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+0
-23
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+3
-5
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+100
-80
未找到文件。
paddle/fluid/operators/math/jit_kernel.cc
浏览文件 @
3d928d4f
...
...
@@ -35,29 +35,6 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
return
kers_
.
at
(
key
);
}
#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
auto p = std::make_shared<ker_class<ker_dtype>>(d); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
}
#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
REGISTER_BLAS_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_BLAS_JITKERNEL
(
vadd
,
VAddKernel
);
#undef REGISTER_BLAS_JITKERNEL
#undef DEFINE_WITH_DTYPE
template
<
>
const
std
::
shared_ptr
<
LSTMKernel
<
float
>>
KernelPool
::
Get
<
LSTMKernel
<
float
>
,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
3d928d4f
...
...
@@ -40,7 +40,7 @@ typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class
Kernel
{
public:
Kernel
()
{}
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
private:
...
...
@@ -66,15 +66,13 @@ class KernelPool {
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
public:
explicit
VMulKernel
(
int
n
);
void
(
*
Compute
)(
const
int
n
,
const
T
*
,
const
T
*
,
T
*
);
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
};
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
public:
explicit
VAddKernel
(
int
n
);
void
(
*
Compute
)(
const
int
n
,
const
T
*
,
const
T
*
,
T
*
);
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
3d928d4f
...
...
@@ -29,17 +29,21 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
#define NEW_IMPL(src, t, isa, k) \
p = std::dynamic_pointer_cast<src<t>>( \
std::make_shared<src##Impl<t, isa, k>>())
#define SEARCH_BLOCK(src, t, isa) \
if (d < AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kLT8>
; \
NEW_IMPL(src, t, isa, kLT8)
; \
} else if (d == AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ8>
; \
NEW_IMPL(src, t, isa, kEQ8)
; \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kGT8LT16>
; \
NEW_IMPL(src, t, isa, kGT8LT16)
; \
} else if (d == AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ16>
; \
NEW_IMPL(src, t, isa, kEQ16)
; \
} else { \
Compute = src<t, isa, kGT16>
; \
NEW_IMPL(src, t, isa, kGT16)
; \
}
#define SEARCH_ISA_BLOCK(src, t) \
...
...
@@ -53,6 +57,24 @@ namespace jit = platform::jit;
SEARCH_BLOCK(src, t, jit::isa_any); \
}
#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
}
#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
// do not include lt8, eq8, eq16
#define FOR_EACH_COMMON_BLOCK(macro_, isa) \
macro_(isa, kGT8LT16) macro_(isa, kGT16)
...
...
@@ -73,132 +95,130 @@ namespace jit = platform::jit;
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \
FOR_EACH_ALL_BLOCK(macro_, jit::isa_any)
#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \
template <> \
ker_class<ker_dtype>::ker_class(int d) { \
SEARCH_ISA_BLOCK(ker_func, ker_dtype); \
}
#define BIND_KERNEL(ker_class, ker_func) \
BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, float); \
BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, double)
/* VMUL JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
static
void
VMulCompute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
}
}
;
#ifdef PADDLE_WITH_MKLML
#define VMUL_MKL_FLOAT(isa, block) \
template <> \
void VMul
Compute<float, isa, block>
(const int n, const float* x, \
const float* y, float* z) { \
platform::dynload::vsMul(n, x, y, z); \
#define VMUL_MKL_FLOAT(isa, block)
\
template <>
\
void VMul
KernelImpl<float, isa, block>::Compute
(const int n, const float* x, \
const float* y, float* z) { \
platform::dynload::vsMul(n, x, y, z);
\
}
#define VMUL_MKL_DOUBLE(isa, block)
\
template <>
\
void VMul
Compute<double, isa, block>(const int n, const double* x,
\
const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z);
\
#define VMUL_MKL_DOUBLE(isa, block) \
template <> \
void VMul
KernelImpl<double, isa, block>::Compute(
\
const int n, const double* x,
const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z); \
}
FOR_EACH_ISA_COMMON_BLOCK
(
VMUL_MKL_FLOAT
)
FOR_EACH_ISA_ALL_BLOCK
(
VMUL_MKL_DOUBLE
)
FOR_EACH_ISA_COMMON_BLOCK
(
VMUL_MKL_FLOAT
)
;
FOR_EACH_ISA_ALL_BLOCK
(
VMUL_MKL_DOUBLE
)
;
#endif
/// eq8
#define VMUL_INTRI8_FLOAT(isa) \
template <> \
void VMulCompute<float, isa, kEQ8>(const int n, const float* x, \
const float* y, float* z) { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
#define VMUL_INTRI8_FLOAT(isa) \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
const float* y, float* z) { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
// avx > for > mkl
#ifdef __AVX__
VMUL_INTRI8_FLOAT
(
jit
::
avx
);
#endif
// avx2 > for > mkl
#ifdef __AVX2__
VMUL_INTRI8_FLOAT
(
jit
::
avx2
)
VMUL_INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
VMUL_INTRI8_FLOAT
(
jit
::
avx512f
);
#endif
// TODO(TJ): test and complete avx512
// TODO(TJ): eq16 test and complete avx512
#undef VMUL_INTRI8_FLOAT
#undef VMUL_MKL_FLOAT
#undef VMUL_MKL_DOUBLE
/* VADD */
/* VADD
JitKernel
*/
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
static
void
VAddCompute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
}
}
}
;
#ifdef PADDLE_WITH_MKLML
#define VADD_MKL_FLOAT(isa, block) \
template <> \
void VAdd
Compute<float, isa, block>
(const int n, const float* x, \
const float* y, float* z) { \
platform::dynload::vsAdd(n, x, y, z); \
#define VADD_MKL_FLOAT(isa, block)
\
template <>
\
void VAdd
KernelImpl<float, isa, block>::Compute
(const int n, const float* x, \
const float* y, float* z) { \
platform::dynload::vsAdd(n, x, y, z);
\
}
#define VADD_MKL_DOUBLE(isa, block)
\
template <>
\
void VAdd
Compute<double, isa, block>(const int n, const double* x,
\
const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z);
\
#define VADD_MKL_DOUBLE(isa, block) \
template <> \
void VAdd
KernelImpl<double, isa, block>::Compute(
\
const int n, const double* x,
const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z); \
}
FOR_EACH_ISA_COMMON_BLOCK
(
VADD_MKL_FLOAT
)
FOR_EACH_ISA_ALL_BLOCK
(
VADD_MKL_DOUBLE
)
FOR_EACH_ISA_COMMON_BLOCK
(
VADD_MKL_FLOAT
)
;
FOR_EACH_ISA_ALL_BLOCK
(
VADD_MKL_DOUBLE
)
;
#endif
/// eq8
#define VADD_INTRI8_FLOAT(isa) \
template <> \
void VAddCompute<float, isa, kEQ8>(const int n, const float* x, \
const float* y, float* z) { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
#define VADD_INTRI8_FLOAT(isa) \
template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
const float* y, float* z) { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
#ifdef __AVX__
VADD_INTRI8_FLOAT
(
jit
::
avx
)
VADD_INTRI8_FLOAT
(
jit
::
avx
)
;
#endif
#ifdef __AVX2__
VADD_INTRI8_FLOAT
(
jit
::
avx2
)
VADD_INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
VADD_INTRI8_FLOAT
(
jit
::
avx512f
);
#endif
// TODO(TJ): test and complete avx512
// TODO(TJ):
eq16
test and complete avx512
#undef VADD_INTRI8_FLOAT
#undef VADD_MKL_FLOAT
#undef VADD_MKL_DOUBLE
BIND_KERNEL
(
VMulKernel
,
VMulCompute
);
BIND_KERNEL
(
VAddKernel
,
VAddCompute
);
REGISTER_BLAS_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_BLAS_JITKERNEL
(
vadd
,
VAddKernel
);
#undef BIND_KERNEL
#undef BIND_KERNEL_WITH_DTYPE
#undef FOR_EACH_ISA_ALL_BLOCK
#undef FOR_EACH_ALL_BLOCK
#undef FOR_EACH_ISA_COMMON_BLOCK
#undef FOR_EACH_COMMON_BLOCK
#undef REGISTER_BLAS_JITKERNEL
#undef DEFINE_WITH_DTYPE
#undef SEARCH_ISA_BLOCK
#undef SEARCH_BLOCK
#undef NEW_IMPL
}
// namespace jitkernel
}
// namespace math
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录