Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
55e44761
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
55e44761
编写于
9月 29, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code and init vsigmoid
上级
2d0ff6a3
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
178 addition
and
118 deletion
+178
-118
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+3
-3
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+20
-8
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+58
-58
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+41
-10
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+51
-34
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+5
-5
未找到文件。
paddle/fluid/operators/math/jit_kernel.cc
浏览文件 @
55e44761
...
@@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() {
...
@@ -28,7 +28,7 @@ KernelPool& KernelPool::Instance() {
return
g_jit_kernels
;
return
g_jit_kernels
;
}
}
const
std
::
shared_ptr
<
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
std
::
shared_ptr
<
const
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
return
nullptr
;
return
nullptr
;
}
}
...
@@ -36,7 +36,7 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
...
@@ -36,7 +36,7 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
}
}
template
<
>
template
<
>
const
std
::
shared_ptr
<
LSTMKernel
<
float
>>
std
::
shared_ptr
<
const
LSTMKernel
<
float
>>
KernelPool
::
Get
<
LSTMKernel
<
float
>
,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
KernelPool
::
Get
<
LSTMKernel
<
float
>
,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&>
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
const
std
::
string
&
act_cand
,
...
@@ -49,7 +49,7 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
...
@@ -49,7 +49,7 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
kers_
.
insert
({
key
,
std
::
dynamic_pointer_cast
<
Kernel
>
(
p
)});
kers_
.
insert
({
key
,
std
::
dynamic_pointer_cast
<
Kernel
>
(
p
)});
return
p
;
return
p
;
}
}
return
std
::
dynamic_pointer_cast
<
LSTMKernel
<
float
>>
(
kers_
.
at
(
key
));
return
std
::
dynamic_pointer_cast
<
const
LSTMKernel
<
float
>>
(
kers_
.
at
(
key
));
}
}
}
// namespace jitkernel
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
55e44761
...
@@ -52,13 +52,13 @@ class KernelPool {
...
@@ -52,13 +52,13 @@ class KernelPool {
static
KernelPool
&
Instance
();
static
KernelPool
&
Instance
();
template
<
typename
Ker
,
typename
...
ARGS
>
template
<
typename
Ker
,
typename
...
ARGS
>
const
std
::
shared_ptr
<
Ker
>
Get
(
ARGS
...
args
);
std
::
shared_ptr
<
const
Ker
>
Get
(
ARGS
...
args
);
const
std
::
shared_ptr
<
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
std
::
shared_ptr
<
const
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
private:
private:
KernelPool
()
=
default
;
KernelPool
()
=
default
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Kernel
>>
kers_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
Kernel
>>
kers_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
};
...
@@ -66,26 +66,38 @@ class KernelPool {
...
@@ -66,26 +66,38 @@ class KernelPool {
template
<
typename
T
>
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
class
VMulKernel
:
public
Kernel
{
public:
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
class
VAddKernel
:
public
Kernel
{
public:
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VScalKernel
:
public
Kernel
{
class
VScalKernel
:
public
Kernel
{
public:
public:
virtual
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VExpKernel
:
public
Kernel
{
class
VExpKernel
:
public
Kernel
{
public:
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
=
0
;
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
class
VSigmoidKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
class
VTanhKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
55e44761
...
@@ -34,7 +34,7 @@ namespace jit = platform::jit;
...
@@ -34,7 +34,7 @@ namespace jit = platform::jit;
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
...
@@ -44,15 +44,15 @@ class VMulKernelImpl : public VMulKernel<T> {
...
@@ -44,15 +44,15 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
#define MKL_FLOAT(isa, block) \
template <> \
template <> \
void VMulKernelImpl<float, isa, block>::Compute(
const int n, const float* x,
\
void VMulKernelImpl<float, isa, block>::Compute(
\
const float* y, float* z) {
\
const int n, const float* x, const float* y, float* z) const {
\
platform::dynload::vsMul(n, x, y, z); \
platform::dynload::vsMul(n, x, y, z); \
}
}
#define MKL_DOUBLE(isa, block) \
#define MKL_DOUBLE(isa, block) \
template <> \
template <> \
void VMulKernelImpl<double, isa, block>::Compute( \
void VMulKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
const int n, const double* x, const double* y, double* z)
const
{ \
platform::dynload::vdMul(n, x, y, z); \
platform::dynload::vdMul(n, x, y, z); \
}
}
...
@@ -62,8 +62,8 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
...
@@ -62,8 +62,8 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float* x,
\
void VMulKernelImpl<float, isa, kEQ8>::Compute(
\
const float* y, float* z) {
\
const int n, const float* x, const float* y, float* z) const {
\
__m256 tmpx, tmpy; \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpy = _mm256_loadu_ps(y); \
...
@@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f);
...
@@ -90,7 +90,7 @@ INTRI8_FLOAT(jit::avx512f);
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
z
[
i
]
=
x
[
i
]
+
y
[
i
];
}
}
...
@@ -100,15 +100,15 @@ class VAddKernelImpl : public VAddKernel<T> {
...
@@ -100,15 +100,15 @@ class VAddKernelImpl : public VAddKernel<T> {
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
#define MKL_FLOAT(isa, block) \
template <> \
template <> \
void VAddKernelImpl<float, isa, block>::Compute(
const int n, const float* x,
\
void VAddKernelImpl<float, isa, block>::Compute(
\
const float* y, float* z) {
\
const int n, const float* x, const float* y, float* z) const {
\
platform::dynload::vsAdd(n, x, y, z); \
platform::dynload::vsAdd(n, x, y, z); \
}
}
#define MKL_DOUBLE(isa, block) \
#define MKL_DOUBLE(isa, block) \
template <> \
template <> \
void VAddKernelImpl<double, isa, block>::Compute( \
void VAddKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, const double* y, double* z) { \
const int n, const double* x, const double* y, double* z)
const
{ \
platform::dynload::vdAdd(n, x, y, z); \
platform::dynload::vdAdd(n, x, y, z); \
}
}
...
@@ -118,8 +118,8 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
...
@@ -118,8 +118,8 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
void VAddKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float* x,
\
void VAddKernelImpl<float, isa, kEQ8>::Compute(
\
const float* y, float* z) {
\
const int n, const float* x, const float* y, float* z) const {
\
__m256 tmpx, tmpy; \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpy = _mm256_loadu_ps(y); \
...
@@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f);
...
@@ -145,12 +145,12 @@ INTRI8_FLOAT(jit::avx512f);
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
public:
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
override
{
void
Compute
(
const
int
n
,
const
T
a
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
*
x
[
i
];
y
[
i
]
=
a
*
x
[
i
];
}
}
}
}
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
override
{
void
Compute
(
const
int
n
,
const
T
a
,
T
*
x
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
x
[
i
]
=
a
*
x
[
i
];
x
[
i
]
=
a
*
x
[
i
];
}
}
...
@@ -161,14 +161,14 @@ class VScalKernelImpl : public VScalKernel<T> {
...
@@ -161,14 +161,14 @@ class VScalKernelImpl : public VScalKernel<T> {
#define MKL_FLOAT(isa, block) \
#define MKL_FLOAT(isa, block) \
template <> \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
float* x)
{
\
float* x)
const {
\
platform::dynload::cblas_sscal(n, a, x, 1); \
platform::dynload::cblas_sscal(n, a, x, 1); \
}
}
#define MKL_DOUBLE(isa, block) \
#define MKL_DOUBLE(isa, block) \
template <> \
template <> \
void VScalKernelImpl<double, isa, block>::Compute( \
void VScalKernelImpl<double, isa, block>::Compute( \
const int n, const double a, double* x)
{
\
const int n, const double a, double* x)
const {
\
platform::dynload::cblas_dscal(n, a, x, 1); \
platform::dynload::cblas_dscal(n, a, x, 1); \
}
}
...
@@ -178,8 +178,8 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
...
@@ -178,8 +178,8 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(
const int n, const float a,
\
void VScalKernelImpl<float, isa, kEQ8>::Compute(
\
const float* x, float* y)
{ \
const int n, const float a, const float* x, float* y) const
{ \
__m256 tmp; \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_loadu_ps(x); \
...
@@ -189,7 +189,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
...
@@ -189,7 +189,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_INPLACE_FLOAT(isa) \
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
float* x)
{
\
float* x)
const {
\
__m256 tmp; \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_loadu_ps(x); \
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
55e44761
...
@@ -34,14 +34,13 @@ __m256 Exp(__m256 a);
...
@@ -34,14 +34,13 @@ __m256 Exp(__m256 a);
#endif
#endif
namespace
jitkernel
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
namespace
jit
=
platform
::
jit
;
/* VExp JitKernel */
/* VExp JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
public:
public:
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
override
{
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
override
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
std
::
exp
(
x
[
i
]);
y
[
i
]
=
std
::
exp
(
x
[
i
]);
}
}
...
@@ -52,14 +51,14 @@ class VExpKernelImpl : public VExpKernel<T> {
...
@@ -52,14 +51,14 @@ class VExpKernelImpl : public VExpKernel<T> {
#define MKL_FLOAT(isa, block) \
#define MKL_FLOAT(isa, block) \
template <> \
template <> \
void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
void VExpKernelImpl<float, isa, block>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
platform::dynload::vsExp(n, x, y); \
platform::dynload::vsExp(n, x, y); \
}
}
#define MKL_DOUBLE(isa, block) \
#define MKL_DOUBLE(isa, block) \
template <> \
template <> \
void VExpKernelImpl<double, isa, block>::Compute( \
void VExpKernelImpl<double, isa, block>::Compute( \
const int n, const double* x, double* y)
{
\
const int n, const double* x, double* y)
const {
\
platform::dynload::vdExp(n, x, y); \
platform::dynload::vdExp(n, x, y); \
}
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kLT8
);
FOR_EACH_ISA
(
MKL_FLOAT
,
kLT8
);
...
@@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
...
@@ -71,7 +70,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
__m256 tmp = _mm256_loadu_ps(x); \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \
}
}
...
@@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
...
@@ -79,7 +78,7 @@ FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#define INTRI16_FLOAT(isa) \
#define INTRI16_FLOAT(isa) \
template <> \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const int n, const float* x, \
float* y)
{
\
float* y)
const {
\
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = detail::Exp(tmp0); \
tmp0 = detail::Exp(tmp0); \
...
@@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f);
...
@@ -109,6 +108,38 @@ INTRI16_FLOAT(jit::avx512f);
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
public:
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
vexp_
=
KernelPool
::
Instance
().
template
Get
<
VExpKernel
<
T
>
>
(
d
);
}
void
Compute
(
const
int
n
,
const
T
*
x
,
T
*
y
)
const
override
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
vexp_
->
Compute
(
n
,
y
,
y
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
private:
std
::
shared_ptr
<
const
VExpKernel
<
T
>>
vexp_
;
};
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
REGISTER_JITKERNEL_ARGS
(
vsigmoid
,
VSigmoidKernel
,
JITKERNEL_DECLARE
,
JITKERNEL_KEY
,
JITKERNEL_NEW_ACT_IMPL
);
#undef JITKERNEL_NEW_ACT_IMPL
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
55e44761
...
@@ -23,51 +23,68 @@ namespace jitkernel {
...
@@ -23,51 +23,68 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
namespace
jit
=
platform
::
jit
;
#define NEW_JITKERNEL_IMPL(src, t, isa, k) \
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
p = std::dynamic_pointer_cast<src<t>>( \
std::make_shared<src##Impl<t, isa, k>>())
#define SEARCH_BLOCK(src, t, isa) \
if (d < AVX_FLOAT_BLOCK) { \
if (d < AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kLT8);
\
macro_(ker, dtype, isa, kLT8);
\
} else if (d == AVX_FLOAT_BLOCK) { \
} else if (d == AVX_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ8);
\
macro_(ker, dtype, isa, kEQ8);
\
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT8LT16);
\
macro_(ker, dtype, isa, kGT8LT16);
\
} else if (d == AVX512_FLOAT_BLOCK) { \
} else if (d == AVX512_FLOAT_BLOCK) { \
NEW_JITKERNEL_IMPL(src, t, isa, kEQ16);
\
macro_(ker, dtype, isa, kEQ16);
\
} else { \
} else { \
NEW_JITKERNEL_IMPL(src, t, isa, kGT16);
\
macro_(ker, dtype, isa, kGT16);
\
}
}
#define SEARCH_ISA_BLOCK(
src, t
) \
#define SEARCH_ISA_BLOCK(
macro_, ker, dtype
) \
if (jit::MayIUse(jit::avx512f)) { \
if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(
src, t
, jit::avx512f); \
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(
src, t
, jit::avx2); \
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(
src, t
, jit::avx); \
SEARCH_BLOCK(
macro_, ker, dtype
, jit::avx); \
} else { \
} else { \
SEARCH_BLOCK(
src, t
, jit::isa_any); \
SEARCH_BLOCK(
macro_, ker, dtype
, jit::isa_any); \
}
}
#define JITKERNEL_
WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key)
\
#define JITKERNEL_
DECLARE(ker_class, ker_dtype)
\
template <> \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
std::string key = #ker_key #dtype_key + std::to_string(d); \
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>())
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(
ker_class, ker_dtype);
\
SEARCH_ISA_BLOCK(
macro_impl, ker_class, ker_dtype);
\
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
return p; \
} \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
}
#define REGISTER_JITKERNEL(ker_key, ker_class) \
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d)
JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
macro_impl) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
macro_impl); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \
macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
#define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \
macro_(jit::avx512f, block); \
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
55e44761
...
@@ -388,16 +388,16 @@ TEST(JitKernel, pool) {
...
@@ -388,16 +388,16 @@ TEST(JitKernel, pool) {
const
auto
&
pvmul_f
=
const
auto
&
pvmul_f
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
4
);
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
plstm2
)
!=
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
plstm2
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_f
));
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
));
const
auto
&
pvmul_d
=
const
auto
&
pvmul_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
double
>
>
(
4
);
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
double
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_f
)
!=
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
jit
::
Kernel
>
(
pvmul_d
));
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_d
));
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf4"
);
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf4"
);
EXPECT_
TRUE
(
pvmul_f
==
pvmul_from_key
);
EXPECT_
EQ
(
pvmul_f
,
pvmul_from_key
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf5"
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf5"
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录