Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
900c789a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
900c789a
编写于
12月 10, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use jitcode and use vmul
上级
53709e7e
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
137 addition
and
76 deletion
+137
-76
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+17
-9
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+12
-10
paddle/fluid/operators/jit/gen/jitcode.cc
paddle/fluid/operators/jit/gen/jitcode.cc
+1
-18
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+4
-6
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+5
-0
paddle/fluid/operators/jit/gen_base.h
paddle/fluid/operators/jit/gen_base.h
+28
-23
paddle/fluid/operators/jit/kernel_pool.cc
paddle/fluid/operators/jit/kernel_pool.cc
+5
-0
paddle/fluid/operators/jit/kernel_pool.h
paddle/fluid/operators/jit/kernel_pool.h
+40
-9
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+24
-1
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+1
-0
未找到文件。
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
900c789a
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -103,17 +104,24 @@ void VXXJitCode::genCode() {
...
@@ -103,17 +104,24 @@ void VXXJitCode::genCode() {
ret
();
ret
();
}
}
}
// namespace gen
class
VMulCreator
:
public
JitCodeCreator
<
int
>
{
public:
template
<
>
bool
UseMe
(
const
int
&
attr
)
const
override
{
std
::
unique_ptr
<
GenBase
>
CreateJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
int
attr
)
{
return
platform
::
MayIUse
(
platform
::
avx
);
if
(
UseJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
attr
))
{
return
make_unique
<
gen
::
VMulJitCode
>
(
attr
,
CodeSize
<
KernelType
::
vmul
,
float
,
int
>
(
attr
));
}
}
return
nullptr
;
size_t
CodeSize
(
const
int
&
d
)
const
override
{
}
return
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int
&
attr
)
const
override
{
return
make_unique
<
VMulJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
}
// namespace gen
}
// namespace jit
}
// namespace jit
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
vmul
,
gen
::
VMulCreator
);
paddle/fluid/operators/jit/gen/blas.h
浏览文件 @
900c789a
...
@@ -25,7 +25,18 @@ namespace gen {
...
@@ -25,7 +25,18 @@ namespace gen {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
class
VXXJitCode
:
public
JitCode
{
public:
public:
const
char
*
name
()
const
override
{
explicit
VXXJitCode
(
int
d
,
operand_type
type
,
int
scalar_index
,
bool
with_relu
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{
this
->
genCode
();
}
virtual
const
char
*
name
()
const
{
std
::
string
base
=
"VXXJitCode"
;
std
::
string
base
=
"VXXJitCode"
;
if
(
scalar_index_
==
1
)
{
if
(
scalar_index_
==
1
)
{
base
+=
"_Scalar"
;
base
+=
"_Scalar"
;
...
@@ -45,15 +56,6 @@ class VXXJitCode : public JitCode {
...
@@ -45,15 +56,6 @@ class VXXJitCode : public JitCode {
base
+=
(
with_relu_
?
"_Relu"
:
""
);
base
+=
(
with_relu_
?
"_Relu"
:
""
);
return
base
.
c_str
();
return
base
.
c_str
();
}
}
explicit
VXXJitCode
(
int
d
,
operand_type
type
,
int
scalar_index
,
bool
with_relu
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{}
// static bool init(int d, int scalar_index = 0);
void
genCode
()
override
;
void
genCode
()
override
;
private:
private:
...
...
paddle/fluid/operators/jit/gen/jitcode.cc
浏览文件 @
900c789a
...
@@ -16,23 +16,6 @@
...
@@ -16,23 +16,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
jit
{
namespace
jit
{}
// namespace jit
template
<
>
size_t
GetKey
<
int
>
(
int
d
)
{
return
d
;
}
// template <>
// std::shared_ptr<const GenBase> CreateJitCode<KernelType::vmul, int>(int attr)
// {
// if (UseJitCode<KernelType::vmul, int>(attr)) {
// return std::make_shared<gen::VMulJitCode<int>>(attr,
// CodeSize<KernelType::vmul, int>(attr)));
// }
// return nullptr;
// }
}
// namespace jit
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
900c789a
...
@@ -70,9 +70,10 @@ typedef enum {
...
@@ -70,9 +70,10 @@ typedef enum {
class
JitCode
:
public
GenBase
,
public
Xbyak
::
CodeGenerator
{
class
JitCode
:
public
GenBase
,
public
Xbyak
::
CodeGenerator
{
public:
public:
explicit
JitCode
(
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
explicit
JitCode
(
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{}
this
->
genCode
();
}
virtual
const
char
*
name
()
const
=
0
;
virtual
void
genCode
()
=
0
;
size_t
getSize
()
const
override
{
return
CodeGenerator
::
getSize
();
}
size_t
getSize
()
const
override
{
return
CodeGenerator
::
getSize
();
}
const
unsigned
char
*
getCodeInternal
()
override
{
const
unsigned
char
*
getCodeInternal
()
override
{
...
@@ -80,9 +81,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
...
@@ -80,9 +81,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
return
code
;
return
code
;
}
}
virtual
const
char
*
name
()
const
=
0
;
virtual
void
genCode
()
=
0
;
protected:
protected:
Xbyak
::
Reg64
param1
{
abi_param1
};
Xbyak
::
Reg64
param1
{
abi_param1
};
const
int
EVEX_max_8b_offt
=
0x200
;
const
int
EVEX_max_8b_offt
=
0x200
;
...
...
paddle/fluid/operators/jit/gen_base.cc
浏览文件 @
900c789a
...
@@ -23,6 +23,11 @@ namespace paddle {
...
@@ -23,6 +23,11 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
jit
{
namespace
jit
{
template
<
>
size_t
JitCodeKey
<
int
>
(
int
d
)
{
return
d
;
}
// refer do not need useme, it would be the last one.
// refer do not need useme, it would be the last one.
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
if
(
code
)
{
if
(
code
)
{
...
...
paddle/fluid/operators/jit/gen_base.h
浏览文件 @
900c789a
...
@@ -15,9 +15,8 @@
...
@@ -15,9 +15,8 @@
#pragma once
#pragma once
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <memory> // for
shared
_ptr
#include <memory> // for
unique
_ptr
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/macros.h"
DECLARE_bool
(
dump_jitcode
);
DECLARE_bool
(
dump_jitcode
);
...
@@ -25,29 +24,12 @@ namespace paddle {
...
@@ -25,29 +24,12 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
jit
{
namespace
jit
{
// TODO(TJ): make these functions as virtual of a class
// Every JitCode should estimate the code size itself
template
<
KernelType
KT
,
typename
T
,
typename
Attr
>
size_t
CodeSize
(
Attr
attr
)
{
return
4096
;
}
// Every JitCode should have a condition when to use this JitCode
template
<
KernelType
KT
,
typename
T
,
typename
Attr
>
bool
UseJitCode
(
Attr
attr
)
{
return
false
;
}
// Every JitCode should have a method to get the key from attribution
template
<
typename
Attr
>
size_t
GetKey
(
Attr
attr
);
class
GenBase
:
public
Kernel
{
class
GenBase
:
public
Kernel
{
public:
public:
virtual
~
GenBase
()
=
default
;
virtual
const
char
*
name
()
const
=
0
;
virtual
const
char
*
name
()
const
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
=
0
;
virtual
size_t
getSize
()
const
=
0
;
virtual
size_t
getSize
()
const
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
=
0
;
template
<
typename
FUNC
>
template
<
typename
FUNC
>
const
FUNC
getCode
()
{
const
FUNC
getCode
()
{
const
unsigned
char
*
code
=
this
->
getCodeInternal
();
const
unsigned
char
*
code
=
this
->
getCodeInternal
();
...
@@ -61,8 +43,31 @@ class GenBase : public Kernel {
...
@@ -61,8 +43,31 @@ class GenBase : public Kernel {
void
dumpCode
(
const
unsigned
char
*
code
)
const
;
void
dumpCode
(
const
unsigned
char
*
code
)
const
;
};
};
template
<
KernelType
KT
,
typename
T
,
typename
Attr
>
// Every JitCode should have a method to get the key from attribution
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
Attr
attr
);
template
<
typename
Attr
>
size_t
JitCodeKey
(
Attr
attr
);
// Creator is used to creat the jitcode and save in pool.
// Every JitCode should have one creator.
class
GenCreator
{
public:
virtual
~
GenCreator
()
=
default
;
};
template
<
typename
Attr
>
class
JitCodeCreator
:
public
GenCreator
{
public:
virtual
~
JitCodeCreator
()
=
default
;
// condition when this jit code can be used.
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
// estimate this code size
virtual
size_t
CodeSize
(
const
Attr
&
attr
)
const
=
0
;
// create this code
virtual
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
Attr
&
attr
)
const
=
0
;
};
}
// namespace jit
}
// namespace jit
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/jit/kernel_pool.cc
浏览文件 @
900c789a
...
@@ -21,6 +21,11 @@ namespace paddle {
...
@@ -21,6 +21,11 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
jit
{
namespace
jit
{
JitCodeCreatorPool
&
JitCodeCreatorPool
::
Instance
()
{
static
JitCodeCreatorPool
g_creator_pool
;
return
g_creator_pool
;
}
KernelPool
&
KernelPool
::
Instance
()
{
KernelPool
&
KernelPool
::
Instance
()
{
static
KernelPool
g_kernel_pool
;
static
KernelPool
g_kernel_pool
;
return
g_kernel_pool
;
return
g_kernel_pool
;
...
...
paddle/fluid/operators/jit/kernel_pool.h
浏览文件 @
900c789a
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#pragma once
#pragma once
#include <memory> // for
shared
_ptr
#include <memory> // for
unique
_ptr
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -52,6 +52,28 @@ class JitCodePool {
...
@@ -52,6 +52,28 @@ class JitCodePool {
DISABLE_COPY_AND_ASSIGN
(
JitCodePool
);
DISABLE_COPY_AND_ASSIGN
(
JitCodePool
);
};
};
class
JitCodeCreatorPool
{
typedef
std
::
unique_ptr
<
const
GenCreator
>
GenCreatorPtr
;
typedef
std
::
unordered_map
<
KernelKey
,
std
::
vector
<
GenCreatorPtr
>
,
KernelKey
::
Hash
>
GenCreatorPtrMap
;
public:
JitCodeCreatorPool
()
=
default
;
static
JitCodeCreatorPool
&
Instance
();
GenCreatorPtrMap
&
AllCreators
()
{
return
creators_
;
}
void
Insert
(
const
KernelKey
&
key
,
GenCreatorPtr
value
)
{
if
(
creators_
.
find
(
key
)
==
creators_
.
end
())
{
creators_
.
emplace
(
key
,
std
::
vector
<
GenCreatorPtr
>
());
}
creators_
.
at
(
key
).
emplace_back
(
std
::
move
(
value
));
}
private:
GenCreatorPtrMap
creators_
;
DISABLE_COPY_AND_ASSIGN
(
JitCodeCreatorPool
);
};
typedef
std
::
unique_ptr
<
const
Kernel
>
KernelPtr
;
typedef
std
::
unique_ptr
<
const
Kernel
>
KernelPtr
;
typedef
std
::
unordered_map
<
KernelKey
,
std
::
vector
<
KernelPtr
>
,
KernelKey
::
Hash
>
typedef
std
::
unordered_map
<
KernelKey
,
std
::
vector
<
KernelPtr
>
,
KernelKey
::
Hash
>
KernelMap
;
KernelMap
;
...
@@ -113,24 +135,33 @@ inline Func GetRefer() {
...
@@ -113,24 +135,33 @@ inline Func GetRefer() {
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
typename
PlaceType
=
platform
::
CPUPlace
>
typename
PlaceType
=
platform
::
CPUPlace
>
const
Func
Get
(
Attr
attr
)
{
const
Func
Get
(
Attr
attr
)
{
size_t
key
=
Get
Key
<
Attr
>
(
attr
);
size_t
key
=
JitCode
Key
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KT
>
().
Instance
();
auto
&
codes
=
JitCodePool
<
KT
>
().
Instance
();
if
(
codes
.
Has
(
key
))
{
if
(
codes
.
Has
(
key
))
{
return
codes
.
AllKernels
().
at
(
key
)
->
template
getCode
<
Func
>();
return
codes
.
AllKernels
().
at
(
key
)
->
template
getCode
<
Func
>();
}
}
KernelKey
kkey
(
KT
,
PlaceType
());
if
(
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
)
{
if
(
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
)
{
auto
p
=
CreateJitCode
<
KT
,
T
,
Attr
>
(
attr
);
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
if
(
p
)
{
auto
&
creator_map
=
JitCodeCreatorPool
().
Instance
().
AllCreators
();
auto
f
=
p
->
template
getCode
<
Func
>();
auto
iter
=
creator_map
.
find
(
kkey
);
codes
.
Insert
(
key
,
std
::
move
(
p
));
auto
&
creators
=
iter
->
second
;
return
f
;
for
(
auto
&
cur
:
creators
)
{
auto
i
=
dynamic_cast
<
const
JitCodeCreator
<
Attr
>*>
(
cur
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
p
=
i
->
CreateJitCode
(
attr
);
if
(
p
)
{
auto
f
=
p
->
template
getCode
<
Func
>();
codes
.
Insert
(
key
,
std
::
move
(
p
));
return
f
;
}
}
}
}
}
}
// pool: (KernelKey(type, place), vector<Kernel>)
// pool: (KernelKey(type, place), vector<Kernel
Ptr
>)
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
KernelKey
kkey
(
KT
,
PlaceType
());
auto
iter
=
pool
.
find
(
kkey
);
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
...
...
paddle/fluid/operators/jit/registry.h
浏览文件 @
900c789a
...
@@ -116,7 +116,30 @@ class JitKernelRegistrar {
...
@@ -116,7 +116,30 @@ class JitKernelRegistrar {
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
// REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::JitCodeCreatorPool, \
::paddle::platform::CPUPlace, __VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
}
#define USE_JITKERNEL_GEN(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
#define USE_JITKERNEL_REFER(kernel_type) \
#define USE_JITKERNEL_REFER(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
900c789a
...
@@ -61,6 +61,7 @@ void ExpectEQ(const T* target, const T* refer, int n) {
...
@@ -61,6 +61,7 @@ void ExpectEQ(const T* target, const T* refer, int n) {
// TODO(TJ): remove me
// TODO(TJ): remove me
USE_JITKERNEL_MORE
(
vmul
,
mkl
);
USE_JITKERNEL_MORE
(
vmul
,
mkl
);
USE_JITKERNEL_REFER
(
vmul
);
USE_JITKERNEL_REFER
(
vmul
);
USE_JITKERNEL_GEN
(
vmul
);
TEST
(
JitKernel
,
vmul
)
{
TEST
(
JitKernel
,
vmul
)
{
using
T
=
float
;
using
T
=
float
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录