Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
45bfa70c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
45bfa70c
编写于
12月 03, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
complete vmul jit kernel
上级
77236e33
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
273 addition
and
126 deletion
+273
-126
paddle/fluid/operators/jitkernels/CMakeLists.txt
paddle/fluid/operators/jitkernels/CMakeLists.txt
+7
-5
paddle/fluid/operators/jitkernels/README.md
paddle/fluid/operators/jitkernels/README.md
+3
-0
paddle/fluid/operators/jitkernels/jitcode/jitcode.cc
paddle/fluid/operators/jitkernels/jitcode/jitcode.cc
+23
-0
paddle/fluid/operators/jitkernels/jitcode/jitcode.h
paddle/fluid/operators/jitkernels/jitcode/jitcode.h
+4
-3
paddle/fluid/operators/jitkernels/jitcode_base.h
paddle/fluid/operators/jitkernels/jitcode_base.h
+4
-5
paddle/fluid/operators/jitkernels/kernel_base.h
paddle/fluid/operators/jitkernels/kernel_base.h
+9
-4
paddle/fluid/operators/jitkernels/kernels.cc
paddle/fluid/operators/jitkernels/kernels.cc
+5
-2
paddle/fluid/operators/jitkernels/kernels.h
paddle/fluid/operators/jitkernels/kernels.h
+65
-45
paddle/fluid/operators/jitkernels/refer/refer.cc
paddle/fluid/operators/jitkernels/refer/refer.cc
+2
-1
paddle/fluid/operators/jitkernels/refer/refer.h
paddle/fluid/operators/jitkernels/refer/refer.h
+8
-0
paddle/fluid/operators/jitkernels/registry.h
paddle/fluid/operators/jitkernels/registry.h
+68
-58
paddle/fluid/operators/jitkernels/test.cc
paddle/fluid/operators/jitkernels/test.cc
+75
-3
未找到文件。
paddle/fluid/operators/jitkernels/CMakeLists.txt
浏览文件 @
45bfa70c
# set(use_jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
# file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n")
# file(APPEND ${pass_file} "\#pragma once\n")
# file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
set
(
JIT_KERNEL_DEPS cpu_info cblas gflags enforce place
)
set
(
JIT_KERNEL_DEPS cpu_info cblas gflags enforce place
)
cc_library
(
jit_kernel_base SRCS kernels.cc DEPS
${
JIT_KERNEL_DEPS
}
)
cc_library
(
jit_kernel_base SRCS kernels.cc DEPS
${
JIT_KERNEL_DEPS
}
)
add_subdirectory
(
more
)
add_subdirectory
(
refer
)
add_subdirectory
(
refer
)
add_subdirectory
(
more
)
if
(
WITH_XBYAK
)
if
(
WITH_XBYAK
)
add_subdirectory
(
jitcode
)
add_subdirectory
(
jitcode
)
endif
()
endif
()
# Debug
message
(
STATUS
"--------
${
JIT_KERNEL_DEPS
}
"
)
cc_library
(
jit_kernel SRCS kernels.cc DEPS
${
JIT_KERNEL_DEPS
}
)
cc_library
(
jit_kernel SRCS kernels.cc DEPS
${
JIT_KERNEL_DEPS
}
)
cc_test
(
jit_kernel_test SRCS test.cc DEPS jit_kernel
)
cc_test
(
jit_kernel_test SRCS test.cc DEPS jit_kernel
)
paddle/fluid/operators/jitkernels/README.md
浏览文件 @
45bfa70c
TBD
TBD
# Use me
Add USE_JIT_KERNEL(yourname) to CMakefile.
paddle/fluid/operators/jitkernels/jitcode/jitcode.cc
浏览文件 @
45bfa70c
...
@@ -13,3 +13,26 @@
...
@@ -13,3 +13,26 @@
* limitations under the License. */
* limitations under the License. */
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jitkernels
{
template
<
>
size_t
GetKey
<
int
>
(
int
d
)
{
return
d
;
}
// template <>
// std::shared_ptr<const JitBase> CreateJitCode<KernelType::vmul, int>(int attr)
// {
// if (UseJitCode<KernelType::vmul, int>(attr)) {
// return std::make_shared<jitcode::VMulJitCode<int>>(attr,
// CodeSize<KernelType::vmul, int>(attr)));
// }
// return nullptr;
// }
}
// namespace jitkernels
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jitkernels/jitcode/jitcode.h
浏览文件 @
45bfa70c
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <type_traits>
#include <type_traits>
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#define XBYAK_USE_MMAP_ALLOCATOR
...
@@ -31,10 +32,10 @@ constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
...
@@ -31,10 +32,10 @@ constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param4
(
Xbyak
::
Operand
::
RCX
),
abi_not_param1
(
Xbyak
::
Operand
::
RCX
);
abi_param4
(
Xbyak
::
Operand
::
RCX
),
abi_not_param1
(
Xbyak
::
Operand
::
RCX
);
template
<
KernelType
KT
,
typename
Attr
>
template
<
typename
Attr
>
class
JitCode
:
public
JitBase
,
public
Xbyak
::
CodeGenerator
{
class
VMul
JitCode
:
public
JitBase
,
public
Xbyak
::
CodeGenerator
{
public:
public:
JitCode
(
Attr
attr
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
VMul
JitCode
(
Attr
attr
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{
this
->
genCode
();
this
->
genCode
();
}
}
...
...
paddle/fluid/operators/jitkernels/jitcode_base.h
浏览文件 @
45bfa70c
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <memory> // for shared_ptr
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
...
@@ -42,11 +43,6 @@ bool UseJitCode(Attr attr) {
...
@@ -42,11 +43,6 @@ bool UseJitCode(Attr attr) {
template
<
typename
Attr
>
template
<
typename
Attr
>
size_t
GetKey
(
Attr
attr
);
size_t
GetKey
(
Attr
attr
);
template
<
>
size_t
GetKey
<
int
>
(
int
d
)
{
return
d
;
}
class
JitBase
{
class
JitBase
{
public:
public:
JitBase
()
=
default
;
JitBase
()
=
default
;
...
@@ -68,6 +64,9 @@ class JitBase {
...
@@ -68,6 +64,9 @@ class JitBase {
void
dumpCode
(
const
unsigned
char
*
code
);
void
dumpCode
(
const
unsigned
char
*
code
);
};
};
template
<
KernelType
KT
,
typename
Attr
>
std
::
shared_ptr
<
const
JitBase
>
CreateJitCode
(
Attr
attr
);
}
// namespace jitkernels
}
// namespace jitkernels
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jitkernels/kernel_base.h
浏览文件 @
45bfa70c
...
@@ -25,6 +25,7 @@ typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
...
@@ -25,6 +25,7 @@ typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
class
Kernel
{
class
Kernel
{
public:
public:
Kernel
()
=
default
;
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
};
...
@@ -32,16 +33,20 @@ template <typename T, typename Func, typename Attr> // TODO(TJ): use tuple
...
@@ -32,16 +33,20 @@ template <typename T, typename Func, typename Attr> // TODO(TJ): use tuple
class
KernelImpl
:
public
Kernel
{
class
KernelImpl
:
public
Kernel
{
public:
public:
using
ELEMENT_TYPE
=
T
;
// TODO(TJ): remove me?
using
ELEMENT_TYPE
=
T
;
// TODO(TJ): remove me?
KernelImpl
()
=
default
;
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
~
KernelImpl
()
=
default
;
virtual
Func
GetFunc
()
{
return
func
;
}
virtual
bool
UseMe
(
Attr
attr
)
const
=
0
;
virtual
bool
UseMe
(
Attr
attr
)
const
=
0
;
protected:
protected:
Func
func
{
nullptr
};
Func
func
{
nullptr
};
};
};
template
<
typename
T
,
typename
Func
,
typename
Attr
>
// TODO(TJ): use tuple
class
ReferKernel
:
public
KernelImpl
<
T
,
Func
,
Attr
>
{
public:
// Refer code can always be used
bool
UseMe
(
Attr
attr
)
const
override
{
return
true
;
}
};
}
// namespace jitkernels
}
// namespace jitkernels
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jitkernels/kernels.cc
浏览文件 @
45bfa70c
...
@@ -21,13 +21,16 @@ namespace paddle {
...
@@ -21,13 +21,16 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
jitkernels
{
namespace
jitkernels
{
// refer do not need useme, it would be the last one.
KernelPool
&
KernelPool
::
Instance
()
{
KernelPool
&
KernelPool
::
Instance
()
{
static
KernelPool
g_kernel_pool
;
static
KernelPool
g_kernel_pool
;
return
g_kernel_pool
;
return
g_kernel_pool
;
}
}
ReferKernelPool
&
ReferKernelPool
::
Instance
()
{
static
ReferKernelPool
g_refer_kernel_pool
;
return
g_refer_kernel_pool
;
}
}
// namespace jitkernels
}
// namespace jitkernels
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jitkernels/kernels.h
浏览文件 @
45bfa70c
...
@@ -18,22 +18,21 @@
...
@@ -18,22 +18,21 @@
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_key.h"
#include "paddle/fluid/operators/jitkernels/kernel_key.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
jitkernels
{
namespace
jitkernels
{
// TODO(TJ): rename file to kernel_pool
template
<
KernelType
KT
>
template
<
KernelType
KT
>
class
JitCodePool
{
class
JitCodePool
{
public:
public:
JitCodePool
()
=
default
;
static
JitCodePool
&
Instance
()
{
static
JitCodePool
&
Instance
()
{
static
thread_local
JitCodePool
<
KT
>
g_jit_codes
;
static
thread_local
JitCodePool
<
KT
>
g_jit_codes
;
return
g_jit_codes
;
return
g_jit_codes
;
...
@@ -51,13 +50,11 @@ class JitCodePool {
...
@@ -51,13 +50,11 @@ class JitCodePool {
}
}
private:
private:
JitCodePool
()
=
default
;
std
::
unordered_map
<
size_t
,
std
::
shared_ptr
<
const
JitBase
>>
codes_
;
std
::
unordered_map
<
size_t
,
std
::
shared_ptr
<
const
JitBase
>>
codes_
;
DISABLE_COPY_AND_ASSIGN
(
JitCodePool
);
DISABLE_COPY_AND_ASSIGN
(
JitCodePool
);
};
};
// std::tuple<T, Func, Attr>
//
TODO(TJ):
std::tuple<T, Func, Attr>
template
<
typename
T
,
typename
Func
,
typename
Attr
>
template
<
typename
T
,
typename
Func
,
typename
Attr
>
struct
KernelAttr
{
struct
KernelAttr
{
typedef
T
data_type
;
typedef
T
data_type
;
...
@@ -65,76 +62,99 @@ struct KernelAttr {
...
@@ -65,76 +62,99 @@ struct KernelAttr {
typedef
Attr
attr_type
;
typedef
Attr
attr_type
;
};
};
typedef
std
::
unique_ptr
<
const
Kernel
>
KernelPtr
;
typedef
std
::
unordered_map
<
KernelKey
,
std
::
vector
<
KernelPtr
>
,
KernelKey
::
Hash
>
KernelMap
;
class
KernelPool
{
class
KernelPool
{
public:
public:
static
KernelPool
&
Instance
();
static
KernelPool
&
Instance
();
KernelPool
()
=
default
;
typedef
std
::
unique_ptr
<
const
Kernel
>
KernelPtr
;
typedef
std
::
unordered_map
<
KernelKey
,
std
::
vector
<
KernelPtr
>
,
KernelKey
::
Hash
>
KernelMap
;
KernelMap
&
AllKernels
()
{
return
pool_
;
}
KernelMap
&
AllKernels
()
{
return
pool_
;
}
void
Insert
(
const
KernelKey
&
key
,
KernelPtr
value
)
{
void
Insert
(
const
KernelKey
&
key
,
KernelPtr
value
)
{
if
(
pool_
.
find
(
key
)
==
pool_
.
end
())
{
if
(
pool_
.
find
(
key
)
==
pool_
.
end
())
{
pool_
.
emplace
(
key
,
std
::
vector
<
KernelPtr
>
());
pool_
.
emplace
(
key
,
std
::
vector
<
KernelPtr
>
());
}
}
pool_
.
at
(
key
).
emplace_back
(
std
::
move
(
value
));
pool_
.
at
(
key
).
emplace_back
(
std
::
move
(
value
));
}
}
KernelPool
()
=
default
;
private:
private:
KernelMap
pool_
;
KernelMap
pool_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
};
// TODO(TJ): create_jitcode;
// Every kernel should have refer code and it should be used in unit tests,
// so refer kernels should have it's independent kernel pool
class
ReferKernelPool
{
public:
static
ReferKernelPool
&
Instance
();
ReferKernelPool
()
=
default
;
KernelMap
&
AllKernels
()
{
return
pool_
;
}
void
Insert
(
const
KernelKey
&
key
,
KernelPtr
value
)
{
if
(
pool_
.
find
(
key
)
==
pool_
.
end
())
{
pool_
.
emplace
(
key
,
std
::
vector
<
KernelPtr
>
());
}
pool_
.
at
(
key
).
emplace_back
(
std
::
move
(
value
));
}
private:
KernelMap
pool_
;
DISABLE_COPY_AND_ASSIGN
(
ReferKernelPool
);
};
// Refer code do not related with attr, and always on CPUPlace
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
>
inline
Func
GetRefer
()
{
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
KernelKey
kkey
(
KT
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
"Every Kernel should have reference function."
);
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
T
,
Func
,
Attr
>*>
(
impl
.
get
());
if
(
i
)
{
return
i
->
GetFunc
();
}
}
return
nullptr
;
}
// TODO(TJ): make tuple? named KernelAttr
// TODO(TJ): make tuple? named KernelAttr
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
typename
PlaceType
=
platform
::
CPUPlace
>
typename
PlaceType
=
platform
::
CPUPlace
>
Func
Get
(
Attr
attr
)
{
Func
Get
(
Attr
attr
)
{
size_t
key
=
GetKey
<
Attr
>
(
attr
);
// size_t key = GetKey<Attr>(attr);
auto
jitcode
=
JitCodePool
<
KT
>
().
Instance
().
Get
(
key
);
// auto jitcode = JitCodePool<KT>().Instance().Get(key);
if
(
jitcode
)
{
// if (jitcode) {
return
jitcode
->
template
getCode
<
Func
>();
// return jitcode->template getCode<Func>();
// }
if
(
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
&&
std
::
is_same
<
T
,
float
>::
value
)
{
// TODO(TJ): float move to create
// auto p = CreateJitCode<KT, Attr>(attr);
// if (p) {
// JitCodePool<KT>().Instance().Insert(key, p);
// return p->template getCode<Func>();
// }
}
}
#ifdef PADDLE_WITH_XBYAK
// pool: (KernelKey(type, place), vector<Kernel>)
// // jitcode::JitCode is under protection of PADDLE_WITH_XBYAK
// if (std::is_same<PlaceType, platform::CPUPlace>::value) {
// if (UseJitCode<KT, T, Attr>(attr)) {
// std::shared_ptr<JitBase> p(std::make_shared<jitcode::JitCode<KT, Attr>>(
// attr, CodeSize<KT, Attr>(attr)));
// JitCodePool<KT>().Instance().Insert(key, p);
// return p->getCode<Func>();
// }
// }
#endif
// (KernelKey(type, place), vector<Kernel>)
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
KernelKey
kkey
(
KT
,
PlaceType
());
KernelKey
kkey
(
KT
,
PlaceType
());
auto
iter
=
pool
.
find
(
kkey
);
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
std
::
dynamic_pointer_cast
<
KernelImpl
<
T
,
Func
,
Attr
>
>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
KernelImpl
<
T
,
Func
,
Attr
>*
>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
UseMe
(
attr
))
{
return
i
->
GetFunc
();
return
i
->
GetFunc
();
}
}
}
}
}
}
// The last implementation should be reference function on CPU
// The last implementation should be reference function on CPUPlace.
// Every kernel should have refer code.
return
GetRefer
<
KT
,
T
,
Func
,
Attr
>
();
// because of test refer should have it's own pool
// PADDLE_ENFORCE_GT(list.size(), 1) << "Should have refer implemtation";
// const auto& refer = KernelRefer<KT, T>().AllKernels();
// return refer.Get<Func>();
return
nullptr
;
}
}
}
// namespace jitkernels
}
// namespace jitkernels
...
...
paddle/fluid/operators/jitkernels/refer/refer.cc
浏览文件 @
45bfa70c
...
@@ -17,4 +17,5 @@
...
@@ -17,4 +17,5 @@
namespace
refer
=
paddle
::
operators
::
jitkernels
::
refer
;
namespace
refer
=
paddle
::
operators
::
jitkernels
::
refer
;
// REGISTER_JITKERNEL_REFER(vmul, refer::VMul<float>, refer::VMul<double>);
REGISTER_JITKERNEL_REFER
(
vmul
,
refer
::
VMulKernel
<
float
>
,
refer
::
VMulKernel
<
double
>
);
paddle/fluid/operators/jitkernels/refer/refer.h
浏览文件 @
45bfa70c
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
* limitations under the License. */
* limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -27,6 +28,13 @@ void VMul(const T* x, const T* y, T* z, int n) {
...
@@ -27,6 +28,13 @@ void VMul(const T* x, const T* y, T* z, int n) {
}
}
}
}
template
<
typename
T
>
class
VMulKernel
:
public
ReferKernel
<
T
,
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
),
int
>
{
public:
VMulKernel
()
{
this
->
func
=
VMul
<
T
>
;
}
};
}
// namespace refer
}
// namespace refer
}
// namespace jitkernels
}
// namespace jitkernels
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/jitkernels/registry.h
浏览文件 @
45bfa70c
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" // for UNUSED
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -32,37 +33,40 @@ inline std::unique_ptr<T> make_unique(Args&&... args) {
...
@@ -32,37 +33,40 @@ inline std::unique_ptr<T> make_unique(Args&&... args) {
return
std
::
unique_ptr
<
T
>
(
new
T
(
std
::
forward
<
Args
>
(
args
)...));
return
std
::
unique_ptr
<
T
>
(
new
T
(
std
::
forward
<
Args
>
(
args
)...));
}
}
template
<
typename
PlaceType
,
bool
IsEnd
,
size_t
I
,
typename
...
KernelImpls
>
template
<
typename
Pool
,
typename
PlaceType
,
bool
IsEnd
,
size_t
I
,
typename
...
KernelImpls
>
struct
JitKernelRegistrarFunctor
;
struct
JitKernelRegistrarFunctor
;
template
<
typename
PlaceType
,
size_t
I
,
typename
...
KernelImpls
>
template
<
typename
P
ool
,
typename
P
laceType
,
size_t
I
,
typename
...
KernelImpls
>
struct
JitKernelRegistrarFunctor
<
PlaceType
,
true
,
I
,
KernelImpls
...
>
{
struct
JitKernelRegistrarFunctor
<
P
ool
,
P
laceType
,
true
,
I
,
KernelImpls
...
>
{
void
operator
()(
KernelType
kt
)
const
{}
void
operator
()(
KernelType
kt
)
const
{}
};
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
KernelImpls
>
template
<
typename
P
ool
,
typename
P
laceType
,
size_t
I
,
typename
...
KernelImpls
>
struct
JitKernelRegistrarFunctor
<
PlaceType
,
false
,
I
,
KernelImpls
...
>
{
struct
JitKernelRegistrarFunctor
<
P
ool
,
P
laceType
,
false
,
I
,
KernelImpls
...
>
{
using
KERNEL_IMPL_TYPE
=
using
KERNEL_IMPL_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
KernelImpls
...
>>::
type
;
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
KernelImpls
...
>>::
type
;
void
operator
()(
KernelType
kt
)
const
{
void
operator
()(
KernelType
kt
)
const
{
KernelKey
kkey
(
kt
,
PlaceType
());
KernelKey
kkey
(
kt
,
PlaceType
());
KernelPool
().
Instance
().
Insert
(
Pool
().
Instance
().
Insert
(
kkey
,
kkey
,
std
::
move
(
make_unique
<
const
KERNEL_IMPL_TYPE
>
()));
std
::
move
(
make_unique
<
const
KERNEL_IMPL_TYPE
>
()));
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelImpls
...
>>::
value
;
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelImpls
...
>>::
value
;
JitKernelRegistrarFunctor
<
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelImpls
...
>
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelImpls
...
>
func
;
func
;
func
(
kt
);
func
(
kt
);
}
}
};
};
template
<
typename
PlaceType
,
typename
...
KernelImpls
>
template
<
typename
P
ool
,
typename
P
laceType
,
typename
...
KernelImpls
>
class
JitKernelRegistrar
{
class
JitKernelRegistrar
{
public:
public:
explicit
JitKernelRegistrar
(
KernelType
kt
)
{
explicit
JitKernelRegistrar
(
KernelType
kt
)
{
JitKernelRegistrarFunctor
<
PlaceType
,
false
,
0
,
KernelImpls
...
>
func
;
JitKernelRegistrarFunctor
<
P
ool
,
P
laceType
,
false
,
0
,
KernelImpls
...
>
func
;
func
(
kt
);
func
(
kt
);
}
}
void
Touch
()
{}
};
};
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
...
@@ -71,17 +75,40 @@ class JitKernelRegistrar {
...
@@ -71,17 +75,40 @@ class JitKernelRegistrar {
__test_global_namespace_##uniq_name##__>::value, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
msg)
// Refer always on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::operators::jitkernels::JitKernelRegistrar< \
::paddle::operators::jitkernels::ReferKernelPool, \
::paddle::platform::CPUPlace, __VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
::paddle::operators::jitkernels::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
return 0; \
}
// kernel_type: should be in paddle::operators::jitkernels::KernelType
// kernel_type: should be in paddle::operators::jitkernels::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
"REGISTER_KERNEL_MORE must be called in global namespace"); \
"REGISTER_KERNEL_MORE must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jitkernels::JitKernelRegistrar< \
static ::paddle::operators::jitkernels::JitKernelRegistrar< \
::paddle::operators::jitkernels::KernelPool, \
::paddle::platform::place_type, __VA_ARGS__> \
::paddle::platform::place_type, __VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##__( \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jitkernels::KernelType::kernel_type)
::paddle::operators::jitkernels::KernelType::kernel_type); \
// TODO(TJ): Add Touch and use me
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
return 0; \
}
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
...
@@ -89,45 +116,28 @@ class JitKernelRegistrar {
...
@@ -89,45 +116,28 @@ class JitKernelRegistrar {
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
/*
// REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
// refer must be only one and at least one
REGISTER_JITKERNEL_REFER(vmul, VMul); // Refer need support dtype
// you can register more implementations and the condition when use it
#define USE_JITKERNEL_REFER(kernel_type) \
REGISTER_JITKERNEL_MORE(vmul, mkl::VMUL<float>, UseMe<float>, mkl::VMUL<double>,
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
UseMe<double>)
__reg_jitkernel_##kernel_type##_refer_CPUPlace_, \
"USE_JITKERNEL_REFER must be called in global namespace"); \
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
struct __test_global_namespace_##uniq_name##__ {}; \
static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \
__reg_pass__##pass_type, \
"USE_JITKERNEL_MORE must be called in global namespace"); \
"REGISTER_PASS must be called in global namespace"); \
extern int \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
__pass_registrar_##pass_type##__(#pass_type); \
static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
int TouchPassRegistrar_##pass_type() { \
UNUSED = \
__pass_registrar_##pass_type##__.Touch(); \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
return 0; \
} \
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
static ::paddle::framework::ir::PassRegistrar<pass_class>& \
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
__pass_tmp_registrar_##pass_type##__ UNUSED = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ UNUSED = \
TouchPassRegistrar_##pass_type()
*/
}
// namespace jitkernels
}
// namespace jitkernels
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/jitkernels/test.cc
浏览文件 @
45bfa70c
...
@@ -19,8 +19,11 @@
...
@@ -19,8 +19,11 @@
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/jitkernels/kernels.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
// TODO(TJ): remove me
#include "paddle/fluid/operators/jitkernels/registry.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/port.h"
constexpr
int
repeat
=
20000
;
constexpr
int
repeat
=
20000
;
...
@@ -31,6 +34,75 @@ inline double GetCurrentUS() {
...
@@ -31,6 +34,75 @@ inline double GetCurrentUS() {
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
}
}
TEST
(
JitKernel
,
vmul
)
{}
template
<
typename
T
>
void
RandomVec
(
const
int
n
,
T
*
a
,
const
T
lower
=
static_cast
<
T
>
(
-
20.
f
),
const
T
upper
=
static_cast
<
T
>
(
20.
f
))
{
static
unsigned
int
seed
=
100
;
std
::
mt19937
rng
(
seed
++
);
std
::
uniform_real_distribution
<
double
>
uniform_dist
(
0
,
1
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
a
[
i
]
=
static_cast
<
T
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
);
}
}
template
<
typename
T
>
void
ExpectEQ
(
const
T
*
target
,
const
T
*
refer
,
int
n
)
{
if
(
std
::
is_floating_point
<
T
>::
value
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
1e-3
);
}
}
else
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_EQ
(
target
[
i
],
refer
[
i
]);
}
}
}
// TODO(TJ): remove me
USE_JITKERNEL_MORE
(
vmul
,
mkl
);
USE_JITKERNEL_REFER
(
vmul
);
TEST
(
JitKernel
,
vmul
)
{
using
T
=
float
;
using
PlaceType
=
paddle
::
platform
::
CPUPlace
;
namespace
jit
=
paddle
::
operators
::
jitkernels
;
// TODO(TJ): test more vector size
for
(
int
d
=
1
;
d
<
30
;
++
d
)
{
auto
ref
=
jit
::
GetRefer
<
jit
::
vmul
,
T
,
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
),
int
>
();
auto
tgt
=
jit
::
Get
<
jit
::
vmul
,
T
,
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
),
int
,
PlaceType
>
(
d
);
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
tgt
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
);
std
::
vector
<
T
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
y
.
data
());
const
float
*
x_data
=
x
.
data
();
const
float
*
y_data
=
y
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
tgt
(
x_data
,
y_data
,
ztgt_data
,
d
);
ref
(
x_data
,
y_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
// test inplace x
std
::
copy
(
x
.
begin
(),
x
.
end
(),
zref
.
begin
());
std
::
copy
(
x
.
begin
(),
x
.
end
(),
ztgt
.
begin
());
tgt
(
ztgt_data
,
y_data
,
ztgt_data
,
d
);
ref
(
zref_data
,
y_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
// test inplace y
std
::
copy
(
y
.
begin
(),
y
.
end
(),
zref
.
begin
());
std
::
copy
(
y
.
begin
(),
y
.
end
(),
ztgt
.
begin
());
tgt
(
x_data
,
ztgt_data
,
ztgt_data
,
d
);
ref
(
x_data
,
zref_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
}
}
TEST
(
JitKernel
,
pool
)
{}
TEST
(
JitKernel
,
pool
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录