Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
55e44761
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
55e44761
编写于
9月 29, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code and init vsigmoid
上级
2d0ff6a3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
178 addition
and
118 deletion
+178
-118
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+3
-3
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+20
-8
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+58
-58
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+41
-10
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+51
-34
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+5
-5
未找到文件。
paddle/fluid/operators/math/jit_kernel.cc
浏览文件 @
55e44761
...
...
@@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() {
return
g_jit_kernels
;
}
const
std
::
shared_ptr
<
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
std
::
shared_ptr
<
const
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
return
nullptr
;
}
...
...
@@ -36,7 +36,7 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
}
template
<
>
const
std
::
shared_ptr
<
LSTMKernel
<
float
>>
std
::
shared_ptr
<
const
LSTMKernel
<
float
>>
KernelPool
::
Get
<
LSTMKernel
<
float
>
,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
...
...
@@ -49,7 +49,7 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
kers_
.
insert
({
key
,
std
::
dynamic_pointer_cast
<
Kernel
>
(
p
)});
return
p
;
}
return
std
::
dynamic_pointer_cast
<
LSTMKernel
<
float
>>
(
kers_
.
at
(
key
));
return
std
::
dynamic_pointer_cast
<
const
LSTMKernel
<
float
>>
(
kers_
.
at
(
key
));
}
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
55e44761
...
...
@@ -52,13 +52,13 @@ class KernelPool {
static
KernelPool
&
Instance
();
template
<
typename
Ker
,
typename
...
ARGS
>
const
std
::
shared_ptr
<
Ker
>
Get
(
ARGS
...
args
);
std
::
shared_ptr
<
const
Ker
>
Get
(
ARGS
...
args
);
const
std
::
shared_ptr
<
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
std
::
shared_ptr
<
const
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
private:
KernelPool
()
=
default
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Kernel
>>
kers_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
Kernel
>>
kers_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
...
...
@@ -66,26 +66,38 @@ class KernelPool {
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
};
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
};
template
<
typename
T
>
class
VScalKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
const
=
0
;
};
template
<
typename
T
>
class
VExpKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
class
VSigmoidKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
class
VTanhKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
55e44761
...
...
@@ -34,7 +34,7 @@ namespace jit = platform::jit;
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
...
...
@@ -42,33 +42,33 @@ class VMulKernelImpl : public VMulKernel<T> {
};
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block)
\
template <>
\
void VMulKernelImpl<float, isa, block>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
platform::dynload::vsMul(n, x, y, z);
\
#define MKL_FLOAT(isa, block) \
template <> \
void VMulKernelImpl<float, isa, block>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
platform::dynload::vsMul(n, x, y, z); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VMulKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdMul(n, x, y, z); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VMulKernelImpl<double, isa, block>::Compute(
\
const int n, const double* x, const double* y, double* z)
const
{ \
platform::dynload::vdMul(n, x, y, z);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#define INTRI8_FLOAT(isa)
\
template <>
\
void VMulKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
__m256 tmpx, tmpy;
\
tmpx = _mm256_loadu_ps(x);
\
tmpy = _mm256_loadu_ps(y);
\
tmpx = _mm256_mul_ps(tmpx, tmpy);
\
_mm256_storeu_ps(z, tmpx);
\
#define INTRI8_FLOAT(isa) \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
// avx > for > mkl
...
...
@@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f);
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
...
...
@@ -98,33 +98,33 @@ class VAddKernelImpl : public VAddKernel<T> {
};
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block)
\
template <>
\
void VAddKernelImpl<float, isa, block>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
platform::dynload::vsAdd(n, x, y, z);
\
#define MKL_FLOAT(isa, block) \
template <> \
void VAddKernelImpl<float, isa, block>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
platform::dynload::vsAdd(n, x, y, z); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VAddKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
platform::dynload::vdAdd(n, x, y, z); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VAddKernelImpl<double, isa, block>::Compute(
\
const int n, const double* x, const double* y, double* z)
const
{ \
platform::dynload::vdAdd(n, x, y, z);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#define INTRI8_FLOAT(isa)
\
template <>
\
void VAddKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float* x,
\
const float* y, float* z) {
\
__m256 tmpx, tmpy;
\
tmpx = _mm256_loadu_ps(x);
\
tmpy = _mm256_loadu_ps(y);
\
tmpx = _mm256_add_ps(tmpx, tmpy);
\
_mm256_storeu_ps(z, tmpx);
\
#define INTRI8_FLOAT(isa) \
template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute(
\
const int n, const float* x, const float* y, float* z) const {
\
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_add_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
...
...
@@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f);
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
override
{
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
*
x
[
i
];
}
}
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
override
{
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
x
[
i
]
=
a
*
x
[
i
];
}
...
...
@@ -161,35 +161,35 @@ class VScalKernelImpl : public VScalKernel<T> {
#define MKL_FLOAT(isa, block) \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
float* x)
{
\
float* x)
const {
\
platform::dynload::cblas_sscal(n, a, x, 1); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute( \
const int n, const double a, double* x)
{
\
platform::dynload::cblas_dscal(n, a, x, 1); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VScalKernelImpl<double, isa, block>::Compute(
\
const int n, const double a, double* x)
const {
\
platform::dynload::cblas_dscal(n, a, x, 1);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#define INTRI8_FLOAT(isa)
\
template <>
\
void VScalKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float a,
\
const float* x, float* y)
{ \
__m256 tmp;
\
__m256 scalar = _mm256_set1_ps(a);
\
tmp = _mm256_loadu_ps(x);
\
tmp = _mm256_mul_ps(tmp, scalar);
\
_mm256_storeu_ps(y, tmp);
\
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(
\
const int n, const float a, const float* x, float* y) const
{ \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
float* x)
{
\
float* x)
const {
\
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
55e44761
...
...
@@ -34,14 +34,13 @@ __m256 Exp(__m256 a);
#endif
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
/* VExp JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
}
...
...
@@ -52,15 +51,15 @@ class VExpKernelImpl : public VExpKernel<T> {
#define MKL_FLOAT(isa, block) \
template <> \
void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
platform::dynload::vsExp(n, x, y); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VExpKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, double* y)
{
\
platform::dynload::vdExp(n, x, y); \
#define MKL_DOUBLE(isa, block)
\
template <>
\
void VExpKernelImpl<double, isa, block>::Compute(
\
const int n, const double* x, double* y)
const {
\
platform::dynload::vdExp(n, x, y);
\
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kLT8
);
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT8LT16
);
...
...
@@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \
}
...
...
@@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI16_FLOAT(isa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = detail::Exp(tmp0); \
...
...
@@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f);
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
public:
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
vexp_
=
KernelPool
::
Instance
().
template
Get
<
VExpKernel
<
T
>
>
(
d
);
}
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
override
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
vexp_
->
Compute
(
n
,
y
,
y
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
private:
std
::
shared_ptr
<
const
VExpKernel
<
T
>>
vexp_
;
};
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
REGISTER_JITKERNEL_ARGS
(
vsigmoid
,
VSigmoidKernel
,
JITKERNEL_DECLARE
,
JITKERNEL_KEY
,
JITKERNEL_NEW_ACT_IMPL
);
#undef JITKERNEL_NEW_ACT_IMPL
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
55e44761
...
...
@@ -23,51 +23,68 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
#define NEW_JITKERNEL_IMPL(src, t, isa, k) \
p = std::dynamic_pointer_cast<src<t>>( \
std::make_shared<src##Impl<t, isa, k>>())
#define SEARCH_BLOCK(src, t, isa) \
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kLT8);
\
macro_(ker, dtype, isa, kLT8);
\
} else if (d == AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ8);
\
macro_(ker, dtype, isa, kEQ8);
\
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT8LT16);
\
macro_(ker, dtype, isa, kGT8LT16);
\
} else if (d == AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ16);
\
macro_(ker, dtype, isa, kEQ16);
\
} else { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT16);
\
macro_(ker, dtype, isa, kGT16);
\
}
#define SEARCH_ISA_BLOCK(
src, t
) \
if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(
src, t
, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(
src, t
, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(
src, t
, jit::avx); \
} else { \
SEARCH_BLOCK(
src, t
, jit::isa_any); \
#define SEARCH_ISA_BLOCK(
macro_, ker, dtype
) \
if (jit::MayIUse(jit::avx512f)) {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx); \
} else {
\
SEARCH_BLOCK(
macro_, ker, dtype
, jit::isa_any); \
}
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>())
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
macro_impl) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
macro_impl); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \
macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
55e44761
...
...
@@ -388,16 +388,16 @@ TEST(JitKernel, pool) {
const
auto
&
pvmul_f
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
plstm2
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_f
));
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
plstm2
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
));
const
auto
&
pvmul_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
double
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_d
));
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_d
));
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf4"
);
EXPECT_
TRUE
(
pvmul_f
==
pvmul_from_key
);
EXPECT_
EQ
(
pvmul_f
,
pvmul_from_key
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf5"
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录