Commit 45bfa70c
Authored on Dec 03, 2018 by tensor-tang
Parent: 77236e33

complete vmul jit kernel

Showing 12 changed files with 273 additions and 126 deletions (+273 -126)
paddle/fluid/operators/jitkernels/CMakeLists.txt        +7   -5
paddle/fluid/operators/jitkernels/README.md             +3   -0
paddle/fluid/operators/jitkernels/jitcode/jitcode.cc    +23  -0
paddle/fluid/operators/jitkernels/jitcode/jitcode.h     +4   -3
paddle/fluid/operators/jitkernels/jitcode_base.h        +4   -5
paddle/fluid/operators/jitkernels/kernel_base.h         +9   -4
paddle/fluid/operators/jitkernels/kernels.cc            +5   -2
paddle/fluid/operators/jitkernels/kernels.h             +65  -45
paddle/fluid/operators/jitkernels/refer/refer.cc        +2   -1
paddle/fluid/operators/jitkernels/refer/refer.h         +8   -0
paddle/fluid/operators/jitkernels/registry.h            +68  -58
paddle/fluid/operators/jitkernels/test.cc               +75  -3
paddle/fluid/operators/jitkernels/CMakeLists.txt
# set(use_jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
# file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt.  DO NOT EDIT!\n\n")
# file(APPEND ${pass_file} "\#pragma once\n")
# file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")

set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)

cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})

add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK)
    add_subdirectory(jitcode)
endif()

# Debug
message(STATUS "-------- ${JIT_KERNEL_DEPS}")

cc_library(jit_kernel SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel)
paddle/fluid/operators/jitkernels/README.md
TBD
# Use me
Add USE_JIT_KERNEL(yourname) to CMakefile.
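The README is still a stub. For orientation, registry.h and test.cc below suggest that a consumer obtains a registered kernel roughly as follows; this is a sketch modeled on test.cc, not part of the commit, and the helper name RunVMul is illustrative.

// Sketch of a call site, based on the usage in test.cc below.
#include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/operators/jitkernels/registry.h"
#include "paddle/fluid/platform/place.h"

USE_JITKERNEL_REFER(vmul);  // make sure the reference vmul registration is linked in

void RunVMul(const float* x, const float* y, float* z, int d) {
  namespace jit = paddle::operators::jitkernels;
  using Func = void (*)(const float*, const float*, float*, int);
  // Get() returns the best implementation registered for this place and attribute.
  auto vmul = jit::Get<jit::vmul, float, Func, int,
                       paddle::platform::CPUPlace>(d);
  vmul(x, y, z, d);
}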
paddle/fluid/operators/jitkernels/jitcode/jitcode.cc
@@ -13,3 +13,26 @@
  * limitations under the License. */

 #include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"

+namespace paddle {
+namespace operators {
+namespace jitkernels {
+
+template <>
+size_t GetKey<int>(int d) {
+  return d;
+}
+
+// template <>
+// std::shared_ptr<const JitBase> CreateJitCode<KernelType::vmul, int>(int attr)
+// {
+//   if (UseJitCode<KernelType::vmul, int>(attr)) {
+//     return std::make_shared<jitcode::VMulJitCode<int>>(attr,
+//     CodeSize<KernelType::vmul, int>(attr)));
+//   }
+//   return nullptr;
+// }
+
+}  // namespace jitkernels
+}  // namespace operators
+}  // namespace paddle
paddle/fluid/operators/jitkernels/jitcode/jitcode.h
@@ -15,6 +15,7 @@
 #pragma once

 #include <type_traits>
 #include "paddle/fluid/operators/jitkernels/jitcode_base.h"
+#include "paddle/fluid/operators/jitkernels/kernels.h"

 #define XBYAK_USE_MMAP_ALLOCATOR

@@ -31,10 +32,10 @@ constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
     abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
     abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);

-template <KernelType KT, typename Attr>
-class JitCode : public JitBase, public Xbyak::CodeGenerator {
+template <typename Attr>
+class VMulJitCode : public JitBase, public Xbyak::CodeGenerator {
  public:
-  JitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
+  VMulJitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
       : Xbyak::CodeGenerator(code_size, code_ptr) {
     this->genCode();
   }
paddle/fluid/operators/jitkernels/jitcode_base.h
@@ -15,6 +15,7 @@
 #pragma once

 #include <gflags/gflags.h>
 #include <memory>  // for shared_ptr
+#include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/platform/macros.h"

@@ -42,11 +43,6 @@ bool UseJitCode(Attr attr) {
 template <typename Attr>
 size_t GetKey(Attr attr);

-template <>
-size_t GetKey<int>(int d) {
-  return d;
-}
-
 class JitBase {
  public:
   JitBase() = default;

@@ -68,6 +64,9 @@ class JitBase {
   void dumpCode(const unsigned char* code);
 };

+template <KernelType KT, typename Attr>
+std::shared_ptr<const JitBase> CreateJitCode(Attr attr);
+
 }  // namespace jitkernels
 }  // namespace operators
 }  // namespace paddle
paddle/fluid/operators/jitkernels/kernel_base.h
@@ -25,6 +25,7 @@ typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;

 class Kernel {
  public:
   Kernel() = default;
   virtual ~Kernel() = default;
+  DISABLE_COPY_AND_ASSIGN(Kernel);
 };

@@ -32,16 +33,20 @@ template <typename T, typename Func, typename Attr>  // TODO(TJ): use tuple
 class KernelImpl : public Kernel {
  public:
   using ELEMENT_TYPE = T;  // TODO(TJ): remove me?
   KernelImpl() = default;
   virtual ~KernelImpl() = default;

-  virtual Func GetFunc() { return func; }
+  virtual Func GetFunc() const { return func; }
   virtual bool UseMe(Attr attr) const = 0;

  protected:
  Func func{nullptr};
 };

+template <typename T, typename Func, typename Attr>  // TODO(TJ): use tuple
+class ReferKernel : public KernelImpl<T, Func, Attr> {
+ public:
+  // Refer code can always be used
+  bool UseMe(Attr attr) const override { return true; }
+};
+
 }  // namespace jitkernels
 }  // namespace operators
 }  // namespace paddle
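The hierarchy above expects each optimized backend to subclass KernelImpl and report via UseMe() whether it wants to handle a given attribute; the new ReferKernel is simply the subclass that always says yes. A hypothetical backend implementation (the mkl sources live under more/ and are not part of this diff) might look like the sketch below; the names and the threshold are illustrative only.

// Hypothetical sketch of a backend kernel under more/; not part of this diff.
namespace mkl {

template <typename T>
void VMul(const T* x, const T* y, T* z, int n);  // assumed to exist in the mkl backend

template <typename T>
class VMulKernel
    : public KernelImpl<T, void (*)(const T*, const T*, T*, int), int> {
 public:
  VMulKernel() { this->func = VMul<T>; }
  // Only claim the work when the vector is long enough to pay off;
  // the threshold here is purely illustrative.
  bool UseMe(int d) const override { return d > 15; }
};

}  // namespace mkl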
paddle/fluid/operators/jitkernels/kernels.cc
@@ -21,13 +21,16 @@ namespace paddle {
 namespace operators {
 namespace jitkernels {

+// refer do not need useme, it would be the last one.
 KernelPool& KernelPool::Instance() {
   static KernelPool g_kernel_pool;
   return g_kernel_pool;
 }

+ReferKernelPool& ReferKernelPool::Instance() {
+  static ReferKernelPool g_refer_kernel_pool;
+  return g_refer_kernel_pool;
+}
+
 }  // namespace jitkernels
 }  // namespace operators
 }  // namespace paddle
paddle/fluid/operators/jitkernels/kernels.h
@@ -18,22 +18,21 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
 #include "paddle/fluid/operators/jitkernels/jitcode_base.h"
 #include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/operators/jitkernels/kernel_key.h"
-#ifdef PADDLE_WITH_XBYAK
-#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
-#endif
 #include "paddle/fluid/platform/place.h"

 namespace paddle {
 namespace operators {
 namespace jitkernels {

+// TODO(TJ): rename file to kernel_pool
+
 template <KernelType KT>
 class JitCodePool {
  public:
+  JitCodePool() = default;
   static JitCodePool& Instance() {
     static thread_local JitCodePool<KT> g_jit_codes;
     return g_jit_codes;

@@ -51,13 +50,11 @@ class JitCodePool {
   }

  private:
-  JitCodePool() = default;
   std::unordered_map<size_t, std::shared_ptr<const JitBase>> codes_;
   DISABLE_COPY_AND_ASSIGN(JitCodePool);
 };

-// std::tuple<T, Func, Attr>
+// TODO(TJ): std::tuple<T, Func, Attr>
 template <typename T, typename Func, typename Attr>
 struct KernelAttr {
   typedef T data_type;

@@ -65,76 +62,99 @@ struct KernelAttr {
   typedef Attr attr_type;
 };

+typedef std::unique_ptr<const Kernel> KernelPtr;
+typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
+    KernelMap;
+
 class KernelPool {
  public:
   static KernelPool& Instance();
-  typedef std::unique_ptr<const Kernel> KernelPtr;
-  typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
-      KernelMap;
+  KernelPool() = default;
   KernelMap& AllKernels() { return pool_; }
   void Insert(const KernelKey& key, KernelPtr value) {
     if (pool_.find(key) == pool_.end()) {
       pool_.emplace(key, std::vector<KernelPtr>());
     }
     pool_.at(key).emplace_back(std::move(value));
   }
-  KernelPool() = default;

  private:
   KernelMap pool_;
   DISABLE_COPY_AND_ASSIGN(KernelPool);
 };

-// TODO(TJ): create_jitcode;
+// Every kernel should have refer code and it should be used in unit tests,
+// so refer kernels should have it's independent kernel pool
+class ReferKernelPool {
+ public:
+  static ReferKernelPool& Instance();
+  ReferKernelPool() = default;
+  KernelMap& AllKernels() { return pool_; }
+  void Insert(const KernelKey& key, KernelPtr value) {
+    if (pool_.find(key) == pool_.end()) {
+      pool_.emplace(key, std::vector<KernelPtr>());
+    }
+    pool_.at(key).emplace_back(std::move(value));
+  }
+
+ private:
+  KernelMap pool_;
+  DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
+};
+
+// Refer code do not related with attr, and always on CPUPlace
+template <KernelType KT, typename T, typename Func, typename Attr>
+inline Func GetRefer() {
+  auto& ref_pool = ReferKernelPool().Instance().AllKernels();
+  KernelKey kkey(KT, platform::CPUPlace());
+  auto ref_iter = ref_pool.find(kkey);
+  PADDLE_ENFORCE(ref_iter != ref_pool.end(),
+                 "Every Kernel should have reference function.");
+  auto& ref_impls = ref_iter->second;
+  for (auto& impl : ref_impls) {
+    auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
+    if (i) {
+      return i->GetFunc();
+    }
+  }
+  return nullptr;
+}
+
 // TODO(TJ): make tuple? named KernelAttr
 template <KernelType KT, typename T, typename Func, typename Attr,
           typename PlaceType = platform::CPUPlace>
 Func Get(Attr attr) {
-  size_t key = GetKey<Attr>(attr);
-  auto jitcode = JitCodePool<KT>().Instance().Get(key);
-  if (jitcode) {
-    return jitcode->template getCode<Func>();
-  }
+  // size_t key = GetKey<Attr>(attr);
+  // auto jitcode = JitCodePool<KT>().Instance().Get(key);
+  // if (jitcode) {
+  //   return jitcode->template getCode<Func>();
+  // }

-  if (std::is_same<PlaceType, platform::CPUPlace>::value &&
-      std::is_same<T, float>::value) {  // TODO(TJ): float move to create
-    // auto p = CreateJitCode<KT, Attr>(attr);
-    // if (p) {
-    //   JitCodePool<KT>().Instance().Insert(key, p);
-    //   return p->template getCode<Func>();
-    // }
-  }

 #ifdef PADDLE_WITH_XBYAK
+  // // jitcode::JitCode is under protection of PADDLE_WITH_XBYAK
+  // if (std::is_same<PlaceType, platform::CPUPlace>::value) {
+  //   if (UseJitCode<KT, T, Attr>(attr)) {
+  //     std::shared_ptr<JitBase> p(std::make_shared<jitcode::JitCode<KT, Attr>>(
+  //         attr, CodeSize<KT, Attr>(attr)));
+  //     JitCodePool<KT>().Instance().Insert(key, p);
+  //     return p->getCode<Func>();
+  //   }
+  // }
 #endif

-  // (KernelKey(type, place), vector<Kernel>)
+  // pool: (KernelKey(type, place), vector<Kernel>)
   auto& pool = KernelPool().Instance().AllKernels();
   KernelKey kkey(KT, PlaceType());
   auto iter = pool.find(kkey);
   if (iter != pool.end()) {
-    auto impls = iter->second;
-    for (auto impl : impls) {
-      auto i = std::dynamic_pointer_cast<KernelImpl<T, Func, Attr>>(impl.get());
+    auto& impls = iter->second;
+    for (auto& impl : impls) {
+      auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get());
       if (i && i->UseMe(attr)) {
         return i->GetFunc();
       }
     }
   }

-  // The last implementation should be reference function on CPU
-  // Every kernel should have refer code.
-  // because of test refer should have it's own pool
-  // PADDLE_ENFORCE_GT(list.size(), 1) << "Should have refer implemtation";
-  // const auto& refer = KernelRefer<KT, T>().AllKernels();
-  // return refer.Get<Func>();
-  return nullptr;
+  // The last implementation should be reference function on CPUPlace.
+  return GetRefer<KT, T, Func, Attr>();
 }

 }  // namespace jitkernels
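The rewritten Get() above resolves a kernel in stages: the generated-code pool (still commented out here), then any KernelPool entry whose UseMe(attr) accepts the request, and finally the reference pool via the new GetRefer(). The core mechanism is a pool of type-erased Kernel pointers probed with dynamic_cast. A standalone miniature of that dispatch pattern follows (illustrative names, not the Paddle API; it compiles on its own).

// Miniature of the dynamic_cast-based dispatch used by Get(); illustrative only.
#include <iostream>
#include <memory>
#include <vector>

struct Kernel {
  virtual ~Kernel() = default;
};

template <typename Func, typename Attr>
struct KernelImpl : public Kernel {
  virtual Func GetFunc() const { return func; }
  virtual bool UseMe(Attr attr) const = 0;
  Func func{nullptr};
};

using VMulFunc = void (*)(const float*, const float*, float*, int);

void RefVMul(const float* x, const float* y, float* z, int n) {
  for (int i = 0; i < n; ++i) z[i] = x[i] * y[i];
}

struct RefVMulKernel : public KernelImpl<VMulFunc, int> {
  RefVMulKernel() { func = RefVMul; }
  bool UseMe(int) const override { return true; }  // reference always works
};

template <typename Func, typename Attr>
Func Get(const std::vector<std::unique_ptr<const Kernel>>& pool, Attr attr) {
  for (const auto& k : pool) {
    // Probe each type-erased entry; take the first one that claims the request.
    auto* impl = dynamic_cast<const KernelImpl<Func, Attr>*>(k.get());
    if (impl && impl->UseMe(attr)) return impl->GetFunc();
  }
  return nullptr;
}

int main() {
  std::vector<std::unique_ptr<const Kernel>> pool;
  pool.push_back(std::make_unique<RefVMulKernel>());
  float x[3] = {1, 2, 3}, y[3] = {4, 5, 6}, z[3];
  if (auto f = Get<VMulFunc, int>(pool, 3)) f(x, y, z, 3);
  std::cout << z[0] << " " << z[1] << " " << z[2] << "\n";  // prints: 4 10 18
  return 0;
}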
paddle/fluid/operators/jitkernels/refer/refer.cc
@@ -17,4 +17,5 @@
 namespace refer = paddle::operators::jitkernels::refer;

-// REGISTER_JITKERNEL_REFER(vmul, refer::VMul<float>, refer::VMul<double>);
+REGISTER_JITKERNEL_REFER(vmul, refer::VMulKernel<float>,
+                         refer::VMulKernel<double>);
paddle/fluid/operators/jitkernels/refer/refer.h
@@ -13,6 +13,7 @@
  * limitations under the License. */

 #pragma once
 #include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/platform/enforce.h"

 namespace paddle {

@@ -27,6 +28,13 @@ void VMul(const T* x, const T* y, T* z, int n) {
   }
 }

+template <typename T>
+class VMulKernel
+    : public ReferKernel<T, void (*)(const T*, const T*, T*, int), int> {
+ public:
+  VMulKernel() { this->func = VMul<T>; }
+};
+
 }  // namespace refer
 }  // namespace jitkernels
 }  // namespace operators
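The VMul body itself is unchanged context that this hunk does not show; for vmul the reference implementation is presumably the plain element-wise loop, so the new VMulKernel wrapper only has to stash that function pointer for the refer pool to hand out. A sketch of what that reference loop looks like:

// Sketch of the reference element-wise multiply; the actual body is not shown in this diff.
template <typename T>
void VMul(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) {
    z[i] = x[i] * y[i];
  }
}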
paddle/fluid/operators/jitkernels/registry.h
@@ -20,6 +20,7 @@
 #include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/operators/jitkernels/kernels.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/platform/variant.h"  // for UNUSED

 namespace paddle {
 namespace operators {

@@ -32,37 +33,40 @@ inline std::unique_ptr<T> make_unique(Args&&... args) {
   return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
 }

-template <typename PlaceType, bool IsEnd, size_t I, typename... KernelImpls>
+template <typename Pool, typename PlaceType, bool IsEnd, size_t I,
+          typename... KernelImpls>
 struct JitKernelRegistrarFunctor;

-template <typename PlaceType, size_t I, typename... KernelImpls>
-struct JitKernelRegistrarFunctor<PlaceType, true, I, KernelImpls...> {
+template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
+struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
   void operator()(KernelType kt) const {}
 };

-template <typename PlaceType, size_t I, typename... KernelImpls>
-struct JitKernelRegistrarFunctor<PlaceType, false, I, KernelImpls...> {
+template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
+struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
   using KERNEL_IMPL_TYPE =
       typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;

   void operator()(KernelType kt) const {
     KernelKey kkey(kt, PlaceType());
-    KernelPool().Instance().Insert(kkey, std::move(make_unique<const KERNEL_IMPL_TYPE>()));
+    Pool().Instance().Insert(kkey, std::move(make_unique<const KERNEL_IMPL_TYPE>()));
     constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
-    JitKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelImpls...>
-        func;
+    JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
+                              KernelImpls...>
+        func;
     func(kt);
   }
 };

-template <typename PlaceType, typename... KernelImpls>
+template <typename Pool, typename PlaceType, typename... KernelImpls>
 class JitKernelRegistrar {
  public:
   explicit JitKernelRegistrar(KernelType kt) {
-    JitKernelRegistrarFunctor<PlaceType, false, 0, KernelImpls...> func;
+    JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
     func(kt);
   }
+  void Touch() {}
 };

 #define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \

@@ -71,17 +75,40 @@ class JitKernelRegistrar {
                __test_global_namespace_##uniq_name##__>::value, \
                msg)

+// Refer always on CPUPlace
+#define REGISTER_JITKERNEL_REFER(kernel_type, ...)                    \
+  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                           \
+      __reg_jitkernel_##kernel_type##_refer_CPUPlace,                 \
+      "REGISTER_KERNEL_REFER must be called in global namespace");    \
+  static ::paddle::operators::jitkernels::JitKernelRegistrar<         \
+      ::paddle::operators::jitkernels::ReferKernelPool,               \
+      ::paddle::platform::CPUPlace, __VA_ARGS__>                      \
+      __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(         \
+          ::paddle::operators::jitkernels::KernelType::kernel_type);  \
+  int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {           \
+    __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();   \
+    return 0;                                                         \
+  }
+
 // kernel_type: should be in paddle::operators::jitkernels::KernelType
 // place_type: should be one of CPUPlace and GPUPlace in paddle::platform
 #define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...)          \
   STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                    \
       __reg_jitkernel_##kernel_type##_##impl_type##_##place_type,              \
       "REGISTER_KERNEL_MORE must be called in global namespace");              \
+  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();              \
+  static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_  \
+      UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();            \
   static ::paddle::operators::jitkernels::JitKernelRegistrar<                  \
+      ::paddle::operators::jitkernels::KernelPool,                             \
       ::paddle::platform::place_type, __VA_ARGS__>                             \
-      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##__(   \
-          ::paddle::operators::jitkernels::KernelType::kernel_type)
-// TODO(TJ): Add Touch and use me
+      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_(    \
+          ::paddle::operators::jitkernels::KernelType::kernel_type);           \
+  int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() {      \
+    __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_       \
+        .Touch();                                                              \
+    return 0;                                                                  \
+  }

 #define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
   REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)

@@ -89,45 +116,28 @@
 #define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
   REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)

 /*
-REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
-
-// refer must be only one and at least one
-REGISTER_JITKERNEL_REFER(vmul, VMul); // Refer need support dtype
+// REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);

 // you can register more implementations and the condition when use it
 REGISTER_JITKERNEL_MORE(vmul, mkl::VMUL<float>, UseMe<float>, mkl::VMUL<double>,
                         UseMe<double>)
 */

-#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg)                   \
-  struct __test_global_namespace_##uniq_name##__ {};                          \
-  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
-                             __test_global_namespace_##uniq_name##__>::value, \
-                msg)
-
-// Register a new pass that can be applied on the IR.
-#define REGISTER_PASS(pass_type, pass_class)                 \
-  STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(                       \
-      __reg_pass__##pass_type,                               \
-      "REGISTER_PASS must be called in global namespace");   \
-  static ::paddle::framework::ir::PassRegistrar<pass_class>  \
-      __pass_registrar_##pass_type##__(#pass_type);          \
-  int TouchPassRegistrar_##pass_type() {                     \
-    __pass_registrar_##pass_type##__.Touch();                \
-    return 0;                                                \
-  }                                                          \
-  static ::paddle::framework::ir::PassRegistrar<pass_class>& \
-      __pass_tmp_registrar_##pass_type##__ UNUSED =          \
-          __pass_registrar_##pass_type##__
-
-#define USE_PASS(pass_type)                               \
-  STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(                    \
-      __use_pass_itself_##pass_type,                      \
-      "USE_PASS must be called in global namespace");     \
-  extern int TouchPassRegistrar_##pass_type();            \
-  static int use_pass_itself_##pass_type##_ UNUSED =      \
-      TouchPassRegistrar_##pass_type()

+#define USE_JITKERNEL_REFER(kernel_type)                            \
+  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
+      __reg_jitkernel_##kernel_type##_refer_CPUPlace_,              \
+      "USE_JITKERNEL_REFER must be called in global namespace");    \
+  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
+  static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
+      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
+
+#define USE_KERNEL_MORE(kernel_type, impl_type, place_type)               \
+  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                               \
+      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_,      \
+      "USE_JITKERNEL_MORE must be called in global namespace");           \
+  extern int                                                              \
+      TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_();  \
+  static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_  \
+      UNUSED =                                                            \
+          TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
+
+#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
+  USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)

 }  // namespace jitkernels
 }  // namespace operators
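The registration side works by macro expansion: REGISTER_JITKERNEL_REFER and REGISTER_KERNEL_MORE instantiate a JitKernelRegistrar, which walks the KernelImpls parameter pack with a recursive functor and inserts one instance per type into the chosen pool. A standalone miniature of that pack-walking pattern follows (illustrative names only, not the Paddle API; it compiles on its own).

// Miniature of the registrar-functor recursion: visit each type in a parameter
// pack via std::tuple_element, constructing one object per type.
#include <iostream>
#include <tuple>

template <bool IsEnd, size_t I, typename... Impls>
struct RegisterFunctor;

template <size_t I, typename... Impls>
struct RegisterFunctor<true, I, Impls...> {
  void operator()() const {}  // end of the pack: nothing left to register
};

template <size_t I, typename... Impls>
struct RegisterFunctor<false, I, Impls...> {
  using ImplType = typename std::tuple_element<I, std::tuple<Impls...>>::type;
  void operator()() const {
    ImplType{}.Hello();  // "insert" the I-th implementation
    constexpr auto size = sizeof...(Impls);
    RegisterFunctor<I + 1 == size, I + 1, Impls...>()();  // recurse to I + 1
  }
};

struct A { void Hello() { std::cout << "register A\n"; } };
struct B { void Hello() { std::cout << "register B\n"; } };

int main() {
  RegisterFunctor<false, 0, A, B>()();  // prints: register A, register B
  return 0;
}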
paddle/fluid/operators/jitkernels/test.cc
@@ -19,8 +19,11 @@
 #include "gflags/gflags.h"
 #include "glog/logging.h"
 #include "gtest/gtest.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
-#include "paddle/fluid/operators/math/jit_kernel_refer.h"
+#include "paddle/fluid/operators/jitkernels/kernels.h"
+// TODO(TJ): remove me
+#include "paddle/fluid/operators/jitkernels/registry.h"
+#include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/port.h"

 constexpr int repeat = 20000;

@@ -31,6 +34,75 @@ inline double GetCurrentUS() {
   return 1e+6 * time.tv_sec + time.tv_usec;
 }

-TEST(JitKernel, vmul) {}
+template <typename T>
+void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
+               const T upper = static_cast<T>(20.f)) {
+  static unsigned int seed = 100;
+  std::mt19937 rng(seed++);
+  std::uniform_real_distribution<double> uniform_dist(0, 1);
+  for (int i = 0; i < n; ++i) {
+    a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
+  }
+}
+
+template <typename T>
+void ExpectEQ(const T* target, const T* refer, int n) {
+  if (std::is_floating_point<T>::value) {
+    for (int i = 0; i < n; ++i) {
+      EXPECT_NEAR(target[i], refer[i], 1e-3);
+    }
+  } else {
+    for (int i = 0; i < n; ++i) {
+      EXPECT_EQ(target[i], refer[i]);
+    }
+  }
+}
+
+// TODO(TJ): remove me
+USE_JITKERNEL_MORE(vmul, mkl);
+USE_JITKERNEL_REFER(vmul);
+
+TEST(JitKernel, vmul) {
+  using T = float;
+  using PlaceType = paddle::platform::CPUPlace;
+  namespace jit = paddle::operators::jitkernels;
+  // TODO(TJ): test more vector size
+  for (int d = 1; d < 30; ++d) {
+    auto ref = jit::GetRefer<jit::vmul, T,
+                             void (*)(const T*, const T*, T*, int), int>();
+    auto tgt = jit::Get<jit::vmul, T, void (*)(const T*, const T*, T*, int),
+                        int, PlaceType>(d);
+    EXPECT_TRUE(ref != nullptr);
+    EXPECT_TRUE(tgt != nullptr);
+
+    std::vector<T> x(d), y(d);
+    std::vector<T> zref(d), ztgt(d);
+    RandomVec<T>(d, x.data());
+    RandomVec<T>(d, y.data());
+
+    const float* x_data = x.data();
+    const float* y_data = y.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+
+    tgt(x_data, y_data, ztgt_data, d);
+    ref(x_data, y_data, zref_data, d);
+    ExpectEQ<T>(ztgt_data, zref_data, d);
+
+    // test inplace x
+    std::copy(x.begin(), x.end(), zref.begin());
+    std::copy(x.begin(), x.end(), ztgt.begin());
+    tgt(ztgt_data, y_data, ztgt_data, d);
+    ref(zref_data, y_data, zref_data, d);
+    ExpectEQ<T>(ztgt_data, zref_data, d);
+
+    // test inplace y
+    std::copy(y.begin(), y.end(), zref.begin());
+    std::copy(y.begin(), y.end(), ztgt.begin());
+    tgt(x_data, ztgt_data, ztgt_data, d);
+    ref(x_data, zref_data, zref_data, d);
+    ExpectEQ<T>(ztgt_data, zref_data, d);
+  }
+}
+
 TEST(JitKernel, pool) {}