Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
61fdc38e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
61fdc38e
编写于
11月 05, 2018
作者:
T
tensor-tang
提交者:
GitHub
11月 05, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14206 from tensor-tang/fea/jit/gen
Fea/jit/gen
上级
d4c771c6
85bcb286
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
481 addition
and
120 deletion
+481
-120
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+2
-2
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+53
-0
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+63
-0
paddle/fluid/operators/math/jit_gen.cc
paddle/fluid/operators/math/jit_gen.cc
+90
-0
paddle/fluid/operators/math/jit_gen.h
paddle/fluid/operators/math/jit_gen.h
+80
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+2
-1
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+67
-54
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+1
-1
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+3
-3
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+93
-32
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+20
-20
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+7
-7
未找到文件。
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
61fdc38e
...
...
@@ -76,6 +76,6 @@ endif()
cc_test
(
concat_test SRCS concat_test.cc DEPS concat_and_split
)
cc_test
(
cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info
)
cc_library
(
jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
DEPS cpu_info cblas
)
SRCS jit_kernel.cc jit_
gen.cc jit_code.cc jit_
kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
DEPS cpu_info cblas
gflags enforce
)
cc_test
(
jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel
)
paddle/fluid/operators/math/jit_code.cc
0 → 100644
浏览文件 @
61fdc38e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
using
namespace
platform
::
jit
;
// NOLINT
bool
VMulJitCode
::
init
(
int
d
)
{
// TODO(TJ): maybe one AVX is enough, AVX above would slow down freq
// try more with avx2 or avx512
if
(
MayIUse
(
avx
)
||
MayIUse
(
avx2
))
{
return
d
%
AVX_FLOAT_BLOCK
==
0
;
}
else
{
return
false
;
}
}
void
VMulJitCode
::
generate
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
int
stride
=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
i
*
stride
]);
vmovups
(
ymm_src2
,
ptr
[
param2
+
i
*
stride
]);
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vmovups
(
ptr
[
param3
+
stride
*
i
],
ymm_dst
);
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_code.h
0 → 100644
浏览文件 @
61fdc38e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/jit_gen.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
using
reg64_t
=
const
Xbyak
::
Reg64
;
using
reg32_t
=
const
Xbyak
::
Reg32
;
using
xmm_t
=
const
Xbyak
::
Xmm
;
using
ymm_t
=
const
Xbyak
::
Ymm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
class
VMulJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
VMulJitCode
);
explicit
VMulJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
xmm_t
xmm_src1
=
xmm_t
(
0
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
zmm_t
zmm_src1
=
zmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
zmm_t
zmm_src2
=
zmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
zmm_t
zmm_dst
=
zmm_t
(
2
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.cc
0 → 100644
浏览文件 @
61fdc38e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_gen.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool
(
dump_jitcode
,
false
,
"Whether to dump the jitcode to file"
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
constexpr
Xbyak
::
Operand
::
Code
g_abi_regs
[]
=
{
Xbyak
::
Operand
::
RBX
,
Xbyak
::
Operand
::
RBP
,
Xbyak
::
Operand
::
R12
,
Xbyak
::
Operand
::
R13
,
Xbyak
::
Operand
::
R14
,
Xbyak
::
Operand
::
R15
};
constexpr
int
num_g_abi_regs
=
sizeof
(
g_abi_regs
)
/
sizeof
(
g_abi_regs
[
0
]);
void
JitCode
::
preCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
push
(
Xbyak
::
Reg64
(
g_abi_regs
[
i
]));
}
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512f
))
{
mov
(
reg_EVEX_max_8b_offt
,
2
*
EVEX_max_8b_offt
);
}
}
void
JitCode
::
postCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
pop
(
Xbyak
::
Reg64
(
g_abi_regs
[
num_g_abi_regs
-
1
-
i
]));
}
ret
();
}
void
JitCode
::
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
{
if
(
code
)
{
static
int
counter
=
0
;
std
::
ostringstream
filename
;
filename
<<
"paddle_jitcode_"
<<
name
()
<<
"."
<<
counter
<<
".bin"
;
counter
++
;
std
::
ofstream
fout
(
filename
.
str
(),
std
::
ios
::
out
);
if
(
fout
.
is_open
())
{
fout
.
write
(
reinterpret_cast
<
const
char
*>
(
code
),
getSize
());
fout
.
close
();
}
}
}
Xbyak
::
Address
JitCode
::
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
)
{
int
scale
=
0
;
if
(
EVEX_max_8b_offt
<=
offt
&&
offt
<
3
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
2
*
EVEX_max_8b_offt
;
scale
=
1
;
}
else
if
(
3
*
EVEX_max_8b_offt
<=
offt
&&
offt
<
5
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
4
*
EVEX_max_8b_offt
;
scale
=
2
;
}
auto
re
=
Xbyak
::
RegExp
()
+
base
+
offt
;
if
(
scale
)
{
re
=
re
+
reg_EVEX_max_8b_offt
*
scale
;
}
if
(
bcast
)
{
return
zword_b
[
re
];
}
else
{
return
zword
[
re
];
}
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.h
0 → 100644
浏览文件 @
61fdc38e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <type_traits>
#include "paddle/fluid/platform/macros.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
DECLARE_bool
(
dump_jitcode
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
#define DECLARE_JIT_CODE(codename) \
const char *name() const override { return #codename; }
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param4
(
Xbyak
::
Operand
::
RCX
),
abi_not_param1
(
Xbyak
::
Operand
::
RCX
);
class
JitCode
:
public
Xbyak
::
CodeGenerator
{
public:
explicit
JitCode
(
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{}
virtual
~
JitCode
()
{}
virtual
const
char
*
name
()
const
=
0
;
virtual
void
generate
()
=
0
;
template
<
typename
FUNC
>
const
FUNC
getCode
()
{
this
->
generate
();
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
if
(
FLAGS_dump_jitcode
)
{
this
->
dumpCode
(
code
);
}
return
reinterpret_cast
<
const
FUNC
>
(
code
);
}
DISABLE_COPY_AND_ASSIGN
(
JitCode
);
protected:
Xbyak
::
Reg64
param1
{
abi_param1
};
const
int
EVEX_max_8b_offt
=
0x200
;
const
Xbyak
::
Reg64
reg_EVEX_max_8b_offt
=
rbp
;
void
preCode
();
void
postCode
();
void
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
;
void
L
(
const
char
*
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
void
L
(
const
Xbyak
::
Label
&
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
// Enhanced vector extension
Xbyak
::
Address
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
=
false
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
61fdc38e
...
...
@@ -39,6 +39,7 @@ class Kernel {
public:
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
// TODO(TJ): below members should be deprecated.
int
num_
{
0
};
int
end_
{
0
};
int
rest_
{
0
};
...
...
@@ -64,7 +65,7 @@ class KernelPool {
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
public:
v
irtual
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
v
oid
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
61fdc38e
...
...
@@ -14,7 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
...
...
@@ -27,65 +30,76 @@ namespace paddle {
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
template
<
typename
T
>
void
VMulRefer
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
>
void
VMulMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VMulMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdMul
(
n
,
x
,
y
,
z
);
}
#endif
/* VMUL JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
static
inline
std
::
string
name
(
int
d
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
};
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
if
(
useJIT
(
d
))
{
constexpr
size_t
sz
=
256
*
1024
;
// TODO(TJ): should be related with d
jitcode_
.
reset
(
new
gen
::
VMulJitCode
(
d
,
sz
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VMulKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
platform::dynload::vsMul(this->num_, x, y, z); \
if
(
useMKL
(
d
))
{
this
->
Compute
=
VMulMKL
<
T
>
;
return
;
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VMulKernelImpl<double, isa, block>::Compute( \
const double* x, const double* y, double* z) const { \
platform::dynload::vdMul(this->num_, x, y, z); \
#endif
this
->
Compute
=
VMulRefer
<
T
>
;
}
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
)
;
#endif
private:
std
::
unique_ptr
<
gen
::
VMulJitCode
>
jitcode_
{
nullptr
}
;
};
#define INTRI8_FLOAT(isa) \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
template
<
>
bool
VMulKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VMulJitCode
::
init
(
d
);
}
// avx > for > mkl
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
template
<
>
bool
VMulKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
jit
::
MayIUse
(
jit
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VMulKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
/* VADD JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -465,13 +479,12 @@ INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
61fdc38e
...
...
@@ -288,7 +288,7 @@ INTRIAVX512_FLOAT(kGT16);
#undef INIT_ALPHA
#undef UPDATE_ALPHA
REGISTER_JITKERNEL
(
crf_decode
,
CRFDecodeKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
crf_decode
,
CRFDecodeKernel
);
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
61fdc38e
...
...
@@ -250,7 +250,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#undef MKL_FLOAT
#undef MKL_DOUBLE
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
vexp
,
VExpKernel
);
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -396,7 +396,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#undef INTRI_GT16_FLOAT
#undef INTRI_VSIGMOID
REGISTER_JITKERNEL
(
vsigmoid
,
VSigmoidKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
vsigmoid
,
VSigmoidKernel
);
/* VTanh JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -531,7 +531,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#undef INTRI_GT16_FLOAT
#undef INTRI_VTANH
REGISTER_JITKERNEL
(
vtanh
,
VTanhKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
vtanh
,
VTanhKernel
);
#undef JITKERNEL_NEW_ACT_IMPL
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
61fdc38e
...
...
@@ -21,8 +21,71 @@ namespace operators {
namespace
math
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
std::string key(#ker_key "f"); \
if (useJIT(d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(d); \
} else if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(int d) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(d)
#define JITKERNEL_IMPL(ker_class, ker_dtype) \
p = std::dynamic_pointer_cast<ker_class<ker_dtype>>( \
std::make_shared<ker_class##Impl<ker_dtype>>(d))
#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \
macro_find_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
macro_find_key(ker_class, ker_dtype); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
macro_impl(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
JITKERNEL_IMPL)
namespace
jit
=
platform
::
jit
;
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \
...
...
@@ -47,20 +110,16 @@ namespace jit = platform::jit;
SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
}
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
#define JITKERNEL_NEW_IMPL
_DEPRECATED
(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \
#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \
dtype_key, marco_declare, macro_key, \
macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
...
...
@@ -73,18 +132,20 @@ namespace jit = platform::jit;
kers_.at(key)); \
}
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
macro_impl) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
macro_impl); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \
macro_key, macro_impl)
#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED)
#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \
macro_key, macro_impl) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \
macro_key, macro_impl); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
marco_declare, macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
61fdc38e
...
...
@@ -179,23 +179,23 @@ class LSTMKernelImpl : public LSTMKernel<T> {
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
private:
...
...
@@ -289,36 +289,36 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
/* get fgated and igated*/
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
);
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
private:
...
...
@@ -352,7 +352,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
act_cell, d)); \
}
REGISTER_JITKERNEL_ARGS
(
lstm
,
LSTMKernel
,
JITKERNEL_DECLARE_LSTM
,
REGISTER_JITKERNEL_ARGS
_DEPRECATED
(
lstm
,
LSTMKernel
,
JITKERNEL_DECLARE_LSTM
,
JITKERNEL_KEY_LSTM
,
JITKERNEL_NEW_LSTM_IMPL
);
#undef INTRI8_FLOAT
...
...
@@ -378,13 +378,13 @@ class GRUKernelImpl : public GRUKernel<T> {
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
act_gate_d_
->
Compute
(
gates
,
gates
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
}
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
...
...
@@ -472,7 +472,7 @@ INTRI8_FLOAT(jit::avx512f);
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS
(
gru
,
GRUKernel
,
JITKERNEL_DECLARE_GRU
,
REGISTER_JITKERNEL_ARGS
_DEPRECATED
(
gru
,
GRUKernel
,
JITKERNEL_DECLARE_GRU
,
JITKERNEL_KEY_GRU
,
JITKERNEL_NEW_GRU_IMPL
);
#undef INTRI8_FLOAT
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
61fdc38e
...
...
@@ -369,12 +369,12 @@ void lstm_ctht_better(
int
d2
=
d
*
2
;
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
(
gates
,
gates
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
);
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
(
ct
,
gates
+
d2
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
}
TEST
(
JitKernel
,
lstm
)
{
...
...
@@ -578,7 +578,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) {
TEST
(
JitKernel
,
vmul
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
256
,
512
})
{
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
256
,
512
,
1000
,
1024
})
{
std
::
vector
<
float
>
x
(
d
),
y
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
());
...
...
@@ -616,7 +616,7 @@ TEST(JitKernel, vmul) {
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -800,8 +800,8 @@ TEST(JitKernel, pool) {
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_d
));
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf
4
"
);
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf
any
"
);
EXPECT_EQ
(
pvmul_f
,
pvmul_from_key
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf
5
"
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf
jit
"
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录