Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
084893a9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
084893a9
编写于
9月 27, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vadd kernel
上级
eeff268a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
148 addition
and
44 deletion
+148
-44
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+27
-19
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+9
-0
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+98
-16
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+14
-9
未找到文件。
paddle/fluid/operators/math/jit_kernel.cc
浏览文件 @
084893a9
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <iostream>
#include <string>
namespace
paddle
{
...
...
@@ -27,29 +28,35 @@ KernelPool& KernelPool::Instance() {
return
g_jit_kernels
;
}
template
<
>
const
std
::
shared_ptr
<
VMulKernel
<
float
>>
KernelPool
::
Get
<
VMulKernel
<
float
>>
(
int
d
)
{
std
::
string
key
=
"f"
+
std
::
to_string
(
d
);
const
std
::
shared_ptr
<
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
auto
p
=
std
::
make_shared
<
VMulKernel
<
float
>>
(
d
);
kers_
.
insert
({
key
,
std
::
dynamic_pointer_cast
<
Kernel
>
(
p
)});
return
p
;
return
nullptr
;
}
return
std
::
dynamic_pointer_cast
<
VMulKernel
<
float
>>
(
kers_
.
at
(
key
)
);
return
kers_
.
at
(
key
);
}
template
<
>
const
std
::
shared_ptr
<
VMulKernel
<
double
>>
KernelPool
::
Get
<
VMulKernel
<
double
>>
(
int
d
)
{
std
::
string
key
=
"d"
+
std
::
to_string
(
d
);
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
auto
p
=
std
::
make_shared
<
VMulKernel
<
double
>>
(
d
);
kers_
.
insert
({
key
,
std
::
dynamic_pointer_cast
<
Kernel
>
(
p
)});
return
p
;
#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
auto p = std::make_shared<ker_class<ker_dtype>>(d); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
}
return
std
::
dynamic_pointer_cast
<
VMulKernel
<
double
>>
(
kers_
.
at
(
key
));
}
#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
REGISTER_BLAS_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_BLAS_JITKERNEL
(
vadd
,
VAddKernel
);
#undef REGISTER_BLAS_JITKERNEL
#undef DEFINE_WITH_DTYPE
template
<
>
const
std
::
shared_ptr
<
LSTMKernel
<
float
>>
...
...
@@ -57,7 +64,8 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
const
std
::
string
&>
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
const
std
::
string
&
act_cell
)
{
std
::
string
key
=
"f"
+
std
::
to_string
(
d
)
+
act_gate
+
act_cand
+
act_cell
;
std
::
string
key
=
"lstmf"
+
std
::
to_string
(
d
)
+
act_gate
+
act_cand
+
act_cell
;
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
auto
p
=
std
::
make_shared
<
LSTMKernel
<
float
>>
(
d
,
act_gate
,
act_cand
,
act_cell
);
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
084893a9
...
...
@@ -54,6 +54,8 @@ class KernelPool {
template
<
typename
Ker
,
typename
...
ARGS
>
const
std
::
shared_ptr
<
Ker
>
Get
(
ARGS
...
args
);
const
std
::
shared_ptr
<
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
private:
KernelPool
()
=
default
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Kernel
>>
kers_
;
...
...
@@ -68,6 +70,13 @@ class VMulKernel : public Kernel {
void
(
*
Compute
)(
const
int
n
,
const
T
*
,
const
T
*
,
T
*
);
};
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
public:
explicit
VAddKernel
(
int
n
);
void
(
*
Compute
)(
const
int
n
,
const
T
*
,
const
T
*
,
T
*
);
};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
public:
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
084893a9
...
...
@@ -74,15 +74,22 @@ namespace jit = platform::jit;
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \
FOR_EACH_ALL_BLOCK(macro_, jit::any)
/* VMUL JitKernel */
#define VMUL_ANY
\
for (int i = 0; i < n; ++i) {
\
z[i] = x[i] * y[i];
\
#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \
template <>
\
ker_class<ker_dtype>::ker_class(int d) {
\
SEARCH_ISA_BLOCK(ker_func, ker_dtype);
\
}
#define BIND_KERNEL(ker_class, ker_func) \
BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, float); \
BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, double)
/* VMUL JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
static
void
VMulCompute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
{
VMUL_ANY
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
#ifdef PADDLE_USE_MKLML
...
...
@@ -107,6 +114,8 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
/// lt8
#ifdef PADDLE_USE_MKLML
VMUL_MKL_FLOAT
(
jit
::
avx
,
kLT8
)
VMUL_MKL_FLOAT
(
jit
::
avx2
,
kLT8
)
VMUL_MKL_FLOAT
(
jit
::
avx512f
,
kLT8
)
#endif
/// eq8
...
...
@@ -143,20 +152,93 @@ VMUL_MKL_FLOAT(jit::avx2, kEQ16)
VMUL_MKL_FLOAT
(
jit
::
avx512f
,
kEQ16
)
#endif
#define USE_VMUL_KERNEL(T, func) \
template <> \
VMulKernel<T>::VMulKernel(int d) { \
SEARCH_ISA_BLOCK(func, T); \
}
USE_VMUL_KERNEL
(
float
,
VMulCompute
);
USE_VMUL_KERNEL
(
double
,
VMulCompute
);
#undef VMUL_ANY
#undef VMUL_INTRI8_FLOAT
#undef VMUL_MKL_FLOAT
#undef VMUL_MKL_DOUBLE
#undef USE_VMUL_KERNEL
/* VADD */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
static
void
VAddCompute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
}
#ifdef PADDLE_USE_MKLML
#define VADD_MKL_FLOAT(isa, block) \
template <> \
void VAddCompute<float, isa, block>(const int n, const float* x, \
const float* y, float* z) { \
platform::dynload::vsAdd(n, x, y, z); \
}
#define VADD_MKL_DOUBLE(isa, block) \
template <> \
void VAddCompute<double, isa, block>(const int n, const double* x, \
const double* y, float* z) { \
platform::dynload::vdAdd(n, x, y, z); \
}
FOR_EACH_ISA_COMMON_BLOCK
(
VADD_MKL_FLOAT
)
FOR_EACH_ISA_ALL_BLOCK
(
VADD_MKL_DOUBLE
)
#endif
/// lt8
#ifdef PADDLE_USE_MKLML
VADD_MKL_FLOAT
(
jit
::
avx
,
kLT8
)
VADD_MKL_FLOAT
(
jit
::
avx2
,
kLT8
)
VADD_MKL_FLOAT
(
jit
::
avx512f
,
kLT8
)
#endif
/// eq8
#define VADD_INTRI8_FLOAT(isa) \
template <> \
void VAddCompute<float, isa, kEQ8>(const int n, const float* x, \
const float* y, float* z) { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
// mkl > avx > for, ">" means better
#ifdef PADDLE_USE_MKLML
VADD_MKL_FLOAT
(
jit
::
avx
,
kEQ8
)
#elif defined __AVX__
VADD_INTRI8_FLOAT
(
jit
::
avx
)
#endif
// avx2 > mkl > for
#ifdef __AVX2__
VADD_INTRI8_FLOAT
(
jit
::
avx2
)
#elif defined PADDLE_USE_MKLML
VADD_MKL_FLOAT
(
jit
::
avx2
,
kEQ8
)
#endif
// TODO(TJ): test and complete avx512
/// eq16
#ifdef PADDLE_USE_MKLML
// TODO(TJ): test and complete me
VADD_MKL_FLOAT
(
jit
::
avx
,
kEQ16
)
VADD_MKL_FLOAT
(
jit
::
avx2
,
kEQ16
)
VADD_MKL_FLOAT
(
jit
::
avx512f
,
kEQ16
)
#endif
#undef VADD_INTRI8_FLOAT
#undef VADD_MKL_FLOAT
#undef VADD_MKL_DOUBLE
BIND_KERNEL
(
VMulKernel
,
VMulCompute
);
BIND_KERNEL
(
VAddKernel
,
VAddCompute
);
#undef BIND_KERNEL
#undef BIND_KERNEL_WITH_DTYPE
#undef FOR_EACH_ISA_ALL_BLOCK
#undef FOR_EACH_ALL_BLOCK
#undef FOR_EACH_ISA_COMMON_BLOCK
#undef FOR_EACH_COMMON_BLOCK
#undef SEARCH_ISA_BLOCK
#undef SEARCH_BLOCK
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
084893a9
...
...
@@ -23,25 +23,30 @@ TEST(JitKernel, pool) {
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
const
int
frame_size
=
4
;
std
::
string
act_gate
=
"sigmoid"
,
act_cand
=
"tanh"
,
act_cell
=
"tanh"
;
const
auto
&
p1
=
const
auto
&
p
lstm
1
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
frame_size
,
act_gate
,
act_cand
,
act_cell
);
const
auto
&
p2
=
const
auto
&
p
lstm
2
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
frame_size
,
act_gate
,
act_cand
,
act_cell
);
EXPECT_EQ
(
p
1
,
p
2
);
EXPECT_EQ
(
p
lstm1
,
plstm
2
);
const
auto
&
p
3
=
const
auto
&
p
vmul_f
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
p2
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
p
3
));
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
p
lstm
2
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
p
vmul_f
));
const
auto
&
p
4
=
const
auto
&
p
vmul_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
double
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
p3
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
p4
));
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_d
));
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf4"
);
EXPECT_TRUE
(
pvmul_f
==
pvmul_from_key
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf5"
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录