Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
55e44761
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
55e44761
编写于
9月 29, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code and init vsigmoid
上级
2d0ff6a3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
178 addition
and
118 deletion
+178
-118
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+3
-3
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+20
-8
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+58
-58
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+41
-10
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+51
-34
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+5
-5
未找到文件。
paddle/fluid/operators/math/jit_kernel.cc
浏览文件 @
55e44761
...
...
@@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() {
return
g_jit_kernels
;
}
const
std
::
shared_ptr
<
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
std
::
shared_ptr
<
const
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
return
nullptr
;
}
...
...
@@ -36,7 +36,7 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
}
template
<
>
const
std
::
shared_ptr
<
LSTMKernel
<
float
>>
std
::
shared_ptr
<
const
LSTMKernel
<
float
>>
KernelPool
::
Get
<
LSTMKernel
<
float
>
,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
...
...
@@ -49,7 +49,7 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
kers_
.
insert
({
key
,
std
::
dynamic_pointer_cast
<
Kernel
>
(
p
)});
return
p
;
}
return
std
::
dynamic_pointer_cast
<
LSTMKernel
<
float
>>
(
kers_
.
at
(
key
));
return
std
::
dynamic_pointer_cast
<
const
LSTMKernel
<
float
>>
(
kers_
.
at
(
key
));
}
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
55e44761
...
...
@@ -52,13 +52,13 @@ class KernelPool {
static
KernelPool
&
Instance
();
template
<
typename
Ker
,
typename
...
ARGS
>
const
std
::
shared_ptr
<
Ker
>
Get
(
ARGS
...
args
);
std
::
shared_ptr
<
const
Ker
>
Get
(
ARGS
...
args
);
const
std
::
shared_ptr
<
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
std
::
shared_ptr
<
const
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
private:
KernelPool
()
=
default
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Kernel
>>
kers_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
Kernel
>>
kers_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
...
...
@@ -66,26 +66,38 @@ class KernelPool {
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
};
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
};
template
<
typename
T
>
class
VScalKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
const
=
0
;
};
template
<
typename
T
>
class
VExpKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
class
VSigmoidKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
class
VTanhKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
55e44761
...
...
@@ -34,7 +34,7 @@ namespace jit = platform::jit;
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
...
...
@@ -42,33 +42,33 @@ class VMulKernelImpl : public VMulKernel<T> {
};
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block)
\
template <>
\
void VMulKernelImpl<float, isa, block>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
platform::dynload::vsMul(n, x, y, z);
\
#define MKL_FLOAT(isa, block) \
template <> \
void VMulKernelImpl<float, isa, block>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
platform::dynload::vsMul(n, x, y, z); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VMulKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VMulKernelImpl<double, isa, block>::Compute(
\
const int n, const double* x, const double* y, double* z)
const
{ \
platform::dynload::vdMul(n, x, y, z);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#define INTRI8_FLOAT(isa)
\
template <>
\
void VMulKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
__m256 tmpx, tmpy;
\
tmpx = _mm256_loadu_ps(x);
\
tmpy = _mm256_loadu_ps(y);
\
tmpx = _mm256_mul_ps(tmpx, tmpy);
\
_mm256_storeu_ps(z, tmpx);
\
#define INTRI8_FLOAT(isa) \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
// avx > for > mkl
...
...
@@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f);
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
...
...
@@ -98,33 +98,33 @@ class VAddKernelImpl : public VAddKernel<T> {
};
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block)
\
template <>
\
void VAddKernelImpl<float, isa, block>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
platform::dynload::vsAdd(n, x, y, z);
\
#define MKL_FLOAT(isa, block) \
template <> \
void VAddKernelImpl<float, isa, block>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
platform::dynload::vsAdd(n, x, y, z); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VAddKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VAddKernelImpl<double, isa, block>::Compute(
\
const int n, const double* x, const double* y, double* z)
const
{ \
platform::dynload::vdAdd(n, x, y, z);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#define INTRI8_FLOAT(isa)
\
template <>
\
void VAddKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
__m256 tmpx, tmpy;
\
tmpx = _mm256_loadu_ps(x);
\
tmpy = _mm256_loadu_ps(y);
\
tmpx = _mm256_add_ps(tmpx, tmpy);
\
_mm256_storeu_ps(z, tmpx);
\
#define INTRI8_FLOAT(isa) \
template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
...
...
@@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f);
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
override
{
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
*
x
[
i
];
}
}
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
override
{
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
x
[
i
]
=
a
*
x
[
i
];
}
...
...
@@ -161,35 +161,35 @@ class VScalKernelImpl : public VScalKernel<T> {
#define MKL_FLOAT(isa, block) \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
float* x)
{
\
float* x)
const {
\
platform::dynload::cblas_sscal(n, a, x, 1); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute( \
const int n, const double a, double* x)
{
\
platform::dynload::cblas_dscal(n, a, x, 1); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VScalKernelImpl<double, isa, block>::Compute(
\
const int n, const double a, double* x)
const {
\
platform::dynload::cblas_dscal(n, a, x, 1);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#define INTRI8_FLOAT(isa)
\
template <>
\
void VScalKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float a,
\
const float* x, float* y)
{ \
__m256 tmp;
\
__m256 scalar = _mm256_set1_ps(a);
\
tmp = _mm256_loadu_ps(x);
\
tmp = _mm256_mul_ps(tmp, scalar);
\
_mm256_storeu_ps(y, tmp);
\
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(
\
const int n, const float a, const float* x, float* y) const
{ \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
float* x)
{
\
float* x)
const {
\
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
55e44761
...
...
@@ -34,14 +34,13 @@ __m256 Exp(__m256 a);
#endif
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
/* VExp JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
}
...
...
@@ -52,15 +51,15 @@ class VExpKernelImpl : public VExpKernel<T> {
#define MKL_FLOAT(isa, block) \
template <> \
void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
platform::dynload::vsExp(n, x, y); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VExpKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, double* y)
{
\
platform::dynload::vdExp(n, x, y); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VExpKernelImpl<double, isa, block>::Compute(
\
const int n, const double* x, double* y)
const {
\
platform::dynload::vdExp(n, x, y);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kLT8
);
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT8LT16
);
...
...
@@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \
}
...
...
@@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI16_FLOAT(isa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = detail::Exp(tmp0); \
...
...
@@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f);
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
public:
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
vexp_
=
KernelPool
::
Instance
().
template
Get
<
VExpKernel
<
T
>
>
(
d
);
}
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
override
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
vexp_
->
Compute
(
n
,
y
,
y
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
private:
std
::
shared_ptr
<
const
VExpKernel
<
T
>>
vexp_
;
};
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
REGISTER_JITKERNEL_ARGS
(
vsigmoid
,
VSigmoidKernel
,
JITKERNEL_DECLARE
,
JITKERNEL_KEY
,
JITKERNEL_NEW_ACT_IMPL
);
#undef JITKERNEL_NEW_ACT_IMPL
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
55e44761
...
...
@@ -23,51 +23,68 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
#define NEW_JITKERNEL_IMPL(src, t, isa, k) \
p = std::dynamic_pointer_cast<src<t>>( \
std::make_shared<src##Impl<t, isa, k>>())
#define SEARCH_BLOCK(src, t, isa) \
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kLT8);
\
macro_(ker, dtype, isa, kLT8);
\
} else if (d == AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ8);
\
macro_(ker, dtype, isa, kEQ8);
\
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT8LT16);
\
macro_(ker, dtype, isa, kGT8LT16);
\
} else if (d == AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ16);
\
macro_(ker, dtype, isa, kEQ16);
\
} else { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT16);
\
macro_(ker, dtype, isa, kGT16);
\
}
#define SEARCH_ISA_BLOCK(
src, t
) \
if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(
src, t
, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(
src, t
, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(
src, t
, jit::avx); \
} else { \
SEARCH_BLOCK(
src, t
, jit::isa_any); \
#define SEARCH_ISA_BLOCK(
macro_, ker, dtype
) \
if (jit::MayIUse(jit::avx512f)) {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx); \
} else {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::isa_any); \
}
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>())
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
macro_impl) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
macro_impl); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \
macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
55e44761
...
...
@@ -388,16 +388,16 @@ TEST(JitKernel, pool) {
const
auto
&
pvmul_f
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
plstm2
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_f
));
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
plstm2
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
));
const
auto
&
pvmul_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
double
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_d
));
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_d
));
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf4"
);
EXPECT_
TRUE
(
pvmul_f
==
pvmul_from_key
);
EXPECT_
EQ
(
pvmul_f
,
pvmul_from_key
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf5"
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录