PaddlePaddle / Paddle

Commit 45bfa70c

complete vmul jit kernel

Authored Dec 03, 2018 by tensor-tang
Parent: 77236e33

Showing 12 changed files with 273 additions and 126 deletions (+273 -126)
paddle/fluid/operators/jitkernels/CMakeLists.txt        +7   -5
paddle/fluid/operators/jitkernels/README.md             +3   -0
paddle/fluid/operators/jitkernels/jitcode/jitcode.cc    +23  -0
paddle/fluid/operators/jitkernels/jitcode/jitcode.h     +4   -3
paddle/fluid/operators/jitkernels/jitcode_base.h        +4   -5
paddle/fluid/operators/jitkernels/kernel_base.h         +9   -4
paddle/fluid/operators/jitkernels/kernels.cc            +5   -2
paddle/fluid/operators/jitkernels/kernels.h             +65  -45
paddle/fluid/operators/jitkernels/refer/refer.cc        +2   -1
paddle/fluid/operators/jitkernels/refer/refer.h         +8   -0
paddle/fluid/operators/jitkernels/registry.h            +68  -58
paddle/fluid/operators/jitkernels/test.cc               +75  -3
paddle/fluid/operators/jitkernels/CMakeLists.txt

# set(use_jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
# file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n")
# file(APPEND ${pass_file} "\#pragma once\n")
# file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")

set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})

add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK)
  add_subdirectory(jitcode)
endif()

# Debug
message(STATUS "-------- ${JIT_KERNEL_DEPS}")

cc_library(jit_kernel SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel)
paddle/fluid/operators/jitkernels/README.md

TBD

# Use me
Add USE_JIT_KERNEL(yourname) to CMakefile.
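Below is a minimal sketch of what "use me" looks like from a consumer translation unit, based on the USE_JITKERNEL_* macros added to registry.h later in this diff; the mkl implementation name is only an assumption borrowed from test.cc, not something this commit provides.

// sketch only: a .cc file that wants the registered vmul kernels linked in
#include "paddle/fluid/operators/jitkernels/registry.h"

USE_JITKERNEL_REFER(vmul);      // pull in the reference vmul registration
USE_JITKERNEL_MORE(vmul, mkl);  // pull in a hypothetical "mkl" implementation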
paddle/fluid/operators/jitkernels/jitcode/jitcode.cc

@@ -13,3 +13,26 @@
 * limitations under the License. */

#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"

namespace paddle {
namespace operators {
namespace jitkernels {

template <>
size_t GetKey<int>(int d) {
  return d;
}

// template <>
// std::shared_ptr<const JitBase> CreateJitCode<KernelType::vmul, int>(int attr)
// {
//   if (UseJitCode<KernelType::vmul, int>(attr)) {
//     return std::make_shared<jitcode::VMulJitCode<int>>(attr,
//                                 CodeSize<KernelType::vmul, int>(attr)));
//   }
//   return nullptr;
// }

}  // namespace jitkernels
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/jitkernels/jitcode/jitcode.h

(The generic JitCode<KernelType, Attr> template from the previous revision is specialized here into VMulJitCode<Attr>.)

@@ -15,6 +15,7 @@
#pragma once

#include <type_traits>
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"

#define XBYAK_USE_MMAP_ALLOCATOR

@@ -31,10 +32,10 @@ constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
    abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
    abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);

template <typename Attr>
class VMulJitCode : public JitBase, public Xbyak::CodeGenerator {
 public:
  VMulJitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
      : Xbyak::CodeGenerator(code_size, code_ptr) {
    this->genCode();
  }
...
paddle/fluid/operators/jitkernels/jitcode_base.h

(The GetKey<int> specialization moves out of this header into jitcode/jitcode.cc; a CreateJitCode declaration is added.)

@@ -15,6 +15,7 @@
#pragma once

#include <gflags/gflags.h>
#include <memory>  // for shared_ptr
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/platform/macros.h"

@@ -42,11 +43,6 @@ bool UseJitCode(Attr attr) {
template <typename Attr>
size_t GetKey(Attr attr);

class JitBase {
 public:
  JitBase() = default;
...
@@ -68,6 +64,9 @@ class JitBase {
  void dumpCode(const unsigned char* code);
};

template <KernelType KT, typename Attr>
std::shared_ptr<const JitBase> CreateJitCode(Attr attr);

}  // namespace jitkernels
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/jitkernels/kernel_base.h

(Kernel gains a virtual destructor, KernelImpl::GetFunc() becomes const, and the ReferKernel base class for reference implementations is new.)

@@ -25,6 +25,7 @@ typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
class Kernel {
 public:
  Kernel() = default;
  virtual ~Kernel() = default;
  DISABLE_COPY_AND_ASSIGN(Kernel);
};

@@ -32,16 +33,20 @@ template <typename T, typename Func, typename Attr>  // TODO(TJ): use tuple
class KernelImpl : public Kernel {
 public:
  using ELEMENT_TYPE = T;  // TODO(TJ): remove me?
  KernelImpl() = default;
  virtual ~KernelImpl() = default;
  virtual Func GetFunc() const { return func; }
  virtual bool UseMe(Attr attr) const = 0;

 protected:
  Func func{nullptr};
};

template <typename T, typename Func, typename Attr>  // TODO(TJ): use tuple
class ReferKernel : public KernelImpl<T, Func, Attr> {
 public:
  // Refer code can always be used
  bool UseMe(Attr attr) const override { return true; }
};

}  // namespace jitkernels
}  // namespace operators
}  // namespace paddle
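For orientation, here is a sketch (not part of this commit) of what a non-reference "more" implementation built on KernelImpl could look like; MyVMul, the class name, and the size threshold in UseMe are all hypothetical.

// Sketch only: a hypothetical "more" kernel that stores its function pointer
// in the protected `func` member and gates selection through UseMe().
#include "paddle/fluid/operators/jitkernels/kernel_base.h"

namespace paddle {
namespace operators {
namespace jitkernels {
namespace more {

template <typename T>
void MyVMul(const T* x, const T* y, T* z, int n) {  // hypothetical helper
  for (int i = 0; i < n; ++i) z[i] = x[i] * y[i];
}

template <typename T>
class MyVMulKernel
    : public KernelImpl<T, void (*)(const T*, const T*, T*, int), int> {
 public:
  MyVMulKernel() { this->func = MyVMul<T>; }
  // Only claim sizes where this implementation is worthwhile (made-up rule).
  bool UseMe(int d) const override { return d > 16; }
};

}  // namespace more
}  // namespace jitkernels
}  // namespace operators
}  // namespace paddle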
paddle/fluid/operators/jitkernels/kernels.cc

@@ -21,13 +21,16 @@ namespace paddle {
namespace operators {
namespace jitkernels {

KernelPool& KernelPool::Instance() {
  static KernelPool g_kernel_pool;
  return g_kernel_pool;
}

ReferKernelPool& ReferKernelPool::Instance() {
  static ReferKernelPool g_refer_kernel_pool;
  return g_refer_kernel_pool;
}

}  // namespace jitkernels
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/jitkernels/kernels.h

(KernelPtr/KernelMap move from inside KernelPool to namespace scope, a separate ReferKernelPool and GetRefer() are added, and Get() now falls back to the reference implementation instead of returning nullptr.)

@@ -18,22 +18,21 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_key.h"
#include "paddle/fluid/platform/place.h"

#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
#endif

namespace paddle {
namespace operators {
namespace jitkernels {

// TODO(TJ): rename file to kernel_pool

template <KernelType KT>
class JitCodePool {
 public:
  JitCodePool() = default;
  static JitCodePool& Instance() {
    static thread_local JitCodePool<KT> g_jit_codes;
    return g_jit_codes;
...
@@ -51,13 +50,11 @@ class JitCodePool {
  }

 private:
  std::unordered_map<size_t, std::shared_ptr<const JitBase>> codes_;
  DISABLE_COPY_AND_ASSIGN(JitCodePool);
};

// TODO(TJ): std::tuple<T, Func, Attr>
template <typename T, typename Func, typename Attr>
struct KernelAttr {
  typedef T data_type;
...
@@ -65,76 +62,99 @@ struct KernelAttr {
  typedef Attr attr_type;
};

typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
    KernelMap;

class KernelPool {
 public:
  static KernelPool& Instance();
  KernelPool() = default;
  KernelMap& AllKernels() { return pool_; }
  void Insert(const KernelKey& key, KernelPtr value) {
    if (pool_.find(key) == pool_.end()) {
      pool_.emplace(key, std::vector<KernelPtr>());
    }
    pool_.at(key).emplace_back(std::move(value));
  }

 private:
  KernelMap pool_;
  DISABLE_COPY_AND_ASSIGN(KernelPool);
};

// TODO(TJ): create_jitcode;

// Every kernel should have refer code and it should be used in unit tests,
// so refer kernels should have it's independent kernel pool
class ReferKernelPool {
 public:
  static ReferKernelPool& Instance();
  ReferKernelPool() = default;
  KernelMap& AllKernels() { return pool_; }
  void Insert(const KernelKey& key, KernelPtr value) {
    if (pool_.find(key) == pool_.end()) {
      pool_.emplace(key, std::vector<KernelPtr>());
    }
    pool_.at(key).emplace_back(std::move(value));
  }

 private:
  KernelMap pool_;
  DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
};

// Refer code do not related with attr, and always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr>
inline Func GetRefer() {
  auto& ref_pool = ReferKernelPool().Instance().AllKernels();
  KernelKey kkey(KT, platform::CPUPlace());
  auto ref_iter = ref_pool.find(kkey);
  PADDLE_ENFORCE(ref_iter != ref_pool.end(),
                 "Every Kernel should have reference function.");
  auto& ref_impls = ref_iter->second;
  for (auto& impl : ref_impls) {
    auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
    if (i) {
      return i->GetFunc();
    }
  }
  return nullptr;
}

// TODO(TJ): make tuple? named KernelAttr
template <KernelType KT, typename T, typename Func, typename Attr,
          typename PlaceType = platform::CPUPlace>
Func Get(Attr attr) {
  // size_t key = GetKey<Attr>(attr);
  // auto jitcode = JitCodePool<KT>().Instance().Get(key);
  // if (jitcode) {
  //   return jitcode->template getCode<Func>();
  // }

#ifdef PADDLE_WITH_XBYAK
  // // jitcode::JitCode is under protection of PADDLE_WITH_XBYAK
  // if (std::is_same<PlaceType, platform::CPUPlace>::value) {
  //   if (UseJitCode<KT, T, Attr>(attr)) {
  //     std::shared_ptr<JitBase> p(std::make_shared<jitcode::JitCode<KT, Attr>>(
  //         attr, CodeSize<KT, Attr>(attr)));
  //     JitCodePool<KT>().Instance().Insert(key, p);
  //     return p->getCode<Func>();
  //   }
  // }
#endif

  // pool: (KernelKey(type, place), vector<Kernel>)
  auto& pool = KernelPool().Instance().AllKernels();
  KernelKey kkey(KT, PlaceType());
  auto iter = pool.find(kkey);
  if (iter != pool.end()) {
    auto& impls = iter->second;
    for (auto& impl : impls) {
      auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get());
      if (i && i->UseMe(attr)) {
        return i->GetFunc();
      }
    }
  }

  // The last implementation should be reference function on CPUPlace.
  // Every kernel should have refer code.
  return GetRefer<KT, T, Func, Attr>();
}

}  // namespace jitkernels
...
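A short usage sketch of the lookup flow that Get() implements, with the same call shape as test.cc later in this diff; the buffers and vector length are illustrative only, and the refer (and any "more") registrations are assumed to be linked in.

// Sketch only: fetch the best available float vmul function, with fallback
// to the reference kernel handled inside Get().
#include <vector>
#include "paddle/fluid/operators/jitkernels/kernels.h"

void RunVMulOnce() {
  namespace jit = paddle::operators::jitkernels;
  using Func = void (*)(const float*, const float*, float*, int);
  const int d = 8;  // illustrative vector length (the Attr for vmul)
  Func f = jit::Get<jit::vmul, float, Func, int, paddle::platform::CPUPlace>(d);
  std::vector<float> x(d, 1.f), y(d, 2.f), z(d);
  f(x.data(), y.data(), z.data(), d);  // z[i] == 2.f
}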
paddle/fluid/operators/jitkernels/refer/refer.cc

@@ -17,4 +17,5 @@
namespace refer = paddle::operators::jitkernels::refer;

REGISTER_JITKERNEL_REFER(vmul, refer::VMulKernel<float>,
                         refer::VMulKernel<double>);
paddle/fluid/operators/jitkernels/refer/refer.h

@@ -13,6 +13,7 @@
 * limitations under the License. */

#pragma once

#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
...
@@ -27,6 +28,13 @@ void VMul(const T* x, const T* y, T* z, int n) {
  }
}

template <typename T>
class VMulKernel
    : public ReferKernel<T, void (*)(const T*, const T*, T*, int), int> {
 public:
  VMulKernel() { this->func = VMul<T>; }
};

}  // namespace refer
}  // namespace jitkernels
}  // namespace operators
...
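Following the VMulKernel pattern above, a reference kernel for another entry in the KernelType enum would be just as small. The sketch below is hypothetical and not part of this commit; it assumes a vadd reference kernel placed in the same refer namespace, plus a matching REGISTER_JITKERNEL_REFER(vadd, ...) line in refer.cc.

// Sketch only (as if inside namespace paddle::operators::jitkernels::refer).
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) {
    z[i] = x[i] + y[i];
  }
}

template <typename T>
class VAddKernel
    : public ReferKernel<T, void (*)(const T*, const T*, T*, int), int> {
 public:
  VAddKernel() { this->func = VAdd<T>; }
};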
paddle/fluid/operators/jitkernels/registry.h

(JitKernelRegistrar gains a Pool template parameter so reference kernels register into ReferKernelPool, the REGISTER_JITKERNEL_REFER and USE_JITKERNEL_* macros are added with Touch functions, and the commented-out pass-registry template text is dropped.)

@@ -20,6 +20,7 @@
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"  // for UNUSED

namespace paddle {
namespace operators {
...
@@ -32,37 +33,40 @@ inline std::unique_ptr<T> make_unique(Args&&... args) {
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

template <typename Pool, typename PlaceType, bool IsEnd, size_t I,
          typename... KernelImpls>
struct JitKernelRegistrarFunctor;

template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
  void operator()(KernelType kt) const {}
};

template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
  using KERNEL_IMPL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;

  void operator()(KernelType kt) const {
    KernelKey kkey(kt, PlaceType());
    Pool().Instance().Insert(kkey,
                             std::move(make_unique<const KERNEL_IMPL_TYPE>()));
    constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
    JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
                              KernelImpls...>
        func;
    func(kt);
  }
};

template <typename Pool, typename PlaceType, typename... KernelImpls>
class JitKernelRegistrar {
 public:
  explicit JitKernelRegistrar(KernelType kt) {
    JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
    func(kt);
  }
  void Touch() {}
};

#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
...
@@ -71,17 +75,40 @@ class JitKernelRegistrar {
      __test_global_namespace_##uniq_name##__>::value, \
      msg)

// Refer always on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...)                      \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                             \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace,                   \
      "REGISTER_KERNEL_REFER must be called in global namespace");      \
  static ::paddle::operators::jitkernels::JitKernelRegistrar<           \
      ::paddle::operators::jitkernels::ReferKernelPool,                 \
      ::paddle::platform::CPUPlace, __VA_ARGS__>                        \
      __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(           \
          ::paddle::operators::jitkernels::KernelType::kernel_type);    \
  int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {             \
    __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();     \
    return 0;                                                           \
  }

// kernel_type: should be in paddle::operators::jitkernels::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...)          \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                    \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type,              \
      "REGISTER_KERNEL_MORE must be called in global namespace");              \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();              \
  static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_  \
      UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();            \
  static ::paddle::operators::jitkernels::JitKernelRegistrar<                  \
      ::paddle::operators::jitkernels::KernelPool,                             \
      ::paddle::platform::place_type, __VA_ARGS__>                             \
      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_(    \
          ::paddle::operators::jitkernels::KernelType::kernel_type);           \
  int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() {      \
    __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_       \
        .Touch();                                                              \
    return 0;                                                                  \
  }

#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
...
@@ -89,45 +116,28 @@ class JitKernelRegistrar {
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)

// refer must be only one and at least one
#define USE_JITKERNEL_REFER(kernel_type)                            \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace_,              \
      "USE_JITKERNEL_REFER must be called in global namespace");    \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
  static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()

#define USE_KERNEL_MORE(kernel_type, impl_type, place_type)              \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                              \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_,     \
      "USE_JITKERNEL_MORE must be called in global namespace");          \
  extern int                                                             \
      TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
  static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
      UNUSED =                                                           \
          TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()

#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
  USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)

}  // namespace jitkernels
}  // namespace operators
...
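For reference, this is how the registration macros are meant to be used at global namespace scope. The refer line mirrors refer/refer.cc in this diff; the mkl names in the commented line are hypothetical placeholders for a future "more" implementation and are not added by this commit.

// Sketch only: registering the reference vmul kernels (as refer/refer.cc does).
namespace refer = paddle::operators::jitkernels::refer;
REGISTER_JITKERNEL_REFER(vmul, refer::VMulKernel<float>,
                         refer::VMulKernel<double>);

// A hypothetical "more" implementation would register on top of it:
// REGISTER_JITKERNEL_MORE(vmul, mkl, mkl::VMulKernel<float>,
//                         mkl::VMulKernel<double>);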
paddle/fluid/operators/jitkernels/test.cc

@@ -19,8 +19,11 @@
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"
// TODO(TJ): remove me
#include "paddle/fluid/operators/jitkernels/registry.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"

constexpr int repeat = 20000;
...
@@ -31,6 +34,75 @@ inline double GetCurrentUS() {
  return 1e+6 * time.tv_sec + time.tv_usec;
}

template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
               const T upper = static_cast<T>(20.f)) {
  static unsigned int seed = 100;
  std::mt19937 rng(seed++);
  std::uniform_real_distribution<double> uniform_dist(0, 1);
  for (int i = 0; i < n; ++i) {
    a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
  }
}

template <typename T>
void ExpectEQ(const T* target, const T* refer, int n) {
  if (std::is_floating_point<T>::value) {
    for (int i = 0; i < n; ++i) {
      EXPECT_NEAR(target[i], refer[i], 1e-3);
    }
  } else {
    for (int i = 0; i < n; ++i) {
      EXPECT_EQ(target[i], refer[i]);
    }
  }
}

// TODO(TJ): remove me
USE_JITKERNEL_MORE(vmul, mkl);
USE_JITKERNEL_REFER(vmul);

TEST(JitKernel, vmul) {
  using T = float;
  using PlaceType = paddle::platform::CPUPlace;
  namespace jit = paddle::operators::jitkernels;
  // TODO(TJ): test more vector size
  for (int d = 1; d < 30; ++d) {
    auto ref = jit::GetRefer<jit::vmul, T,
                             void (*)(const T*, const T*, T*, int), int>();
    auto tgt = jit::Get<jit::vmul, T, void (*)(const T*, const T*, T*, int),
                        int, PlaceType>(d);
    EXPECT_TRUE(ref != nullptr);
    EXPECT_TRUE(tgt != nullptr);

    std::vector<T> x(d), y(d);
    std::vector<T> zref(d), ztgt(d);
    RandomVec<T>(d, x.data());
    RandomVec<T>(d, y.data());
    const float* x_data = x.data();
    const float* y_data = y.data();
    float* ztgt_data = ztgt.data();
    float* zref_data = zref.data();

    tgt(x_data, y_data, ztgt_data, d);
    ref(x_data, y_data, zref_data, d);
    ExpectEQ<T>(ztgt_data, zref_data, d);

    // test inplace x
    std::copy(x.begin(), x.end(), zref.begin());
    std::copy(x.begin(), x.end(), ztgt.begin());
    tgt(ztgt_data, y_data, ztgt_data, d);
    ref(zref_data, y_data, zref_data, d);
    ExpectEQ<T>(ztgt_data, zref_data, d);

    // test inplace y
    std::copy(y.begin(), y.end(), zref.begin());
    std::copy(y.begin(), y.end(), ztgt.begin());
    tgt(x_data, ztgt_data, ztgt_data, d);
    ref(x_data, zref_data, zref_data, d);
    ExpectEQ<T>(ztgt_data, zref_data, d);
  }
}

TEST(JitKernel, pool) {}