Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bc0df6a9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bc0df6a9
编写于
12月 12, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make typename tuples
上级
194ce2e9
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
35 addition
and
40 deletion
+35
-40
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+5
-10
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+13
-11
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+10
-5
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+1
-2
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+1
-2
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+5
-10
未找到文件。
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
bc0df6a9
...
@@ -94,8 +94,7 @@ int main(int argc, char* argv[]) {
...
@@ -94,8 +94,7 @@ int main(int argc, char* argv[]) {
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
y
.
data
());
RandomVec
<
T
>
(
d
,
y
.
data
());
// refer
// refer
auto
refer
=
jit
::
GetRefer
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
auto
refer
=
jit
::
GetRefer
<
KT
,
jit
::
VMulTuples
<
T
>>
();
jit
::
VMulTuples
<
T
>::
attr_type
>
();
if
(
refer
)
{
if
(
refer
)
{
auto
res
=
auto
res
=
BenchTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
refer
,
x
,
y
,
z
);
BenchTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
refer
,
x
,
y
,
z
);
...
@@ -103,8 +102,7 @@ int main(int argc, char* argv[]) {
...
@@ -103,8 +102,7 @@ int main(int argc, char* argv[]) {
}
}
// test jitcode
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
jit
::
VMulTuples
<
T
>
,
PlaceType
>
(
d
);
jit
::
VMulTuples
<
T
>::
attr_type
,
PlaceType
>
(
d
);
if
(
jitcode
)
{
if
(
jitcode
)
{
auto
res
=
auto
res
=
BenchTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
jitcode
,
x
,
y
,
z
);
BenchTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
jitcode
,
x
,
y
,
z
);
...
@@ -118,10 +116,8 @@ int main(int argc, char* argv[]) {
...
@@ -118,10 +116,8 @@ int main(int argc, char* argv[]) {
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
auto
i
=
dynamic_cast
<
const
jit
::
KernelImpl
<
jit
::
VMulTuples
<
T
>>*>
(
dynamic_cast
<
const
jit
::
KernelImpl
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
impl
.
get
());
jit
::
VMulTuples
<
T
>::
attr_type
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
d
))
{
if
(
i
&&
i
->
UseMe
(
d
))
{
auto
more
=
i
->
GetFunc
();
auto
more
=
i
->
GetFunc
();
auto
res
=
auto
res
=
...
@@ -132,8 +128,7 @@ int main(int argc, char* argv[]) {
...
@@ -132,8 +128,7 @@ int main(int argc, char* argv[]) {
}
}
// Test result from Get function
// Test result from Get function
auto
tgt
=
jit
::
Get
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
auto
tgt
=
jit
::
Get
<
KT
,
jit
::
VMulTuples
<
T
>
,
PlaceType
>
(
d
);
jit
::
VMulTuples
<
T
>::
attr_type
,
PlaceType
>
(
d
);
if
(
!
tgt
)
{
if
(
!
tgt
)
{
LOG
(
ERROR
)
<<
"Target can not be empty!"
;
LOG
(
ERROR
)
<<
"Target can not be empty!"
;
}
}
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
bc0df6a9
...
@@ -32,9 +32,11 @@ namespace jit {
...
@@ -32,9 +32,11 @@ namespace jit {
#define SIGMOID_THRESHOLD_MAX 13.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define EXP_MAX_INPUT 40.0
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
template
<
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
>
typename
PlaceType
>
inline
typename
KernelTuples
::
func_type
GetJitCode
(
inline
Func
GetJitCode
(
Attr
attr
)
{
typename
KernelTuples
::
attr_type
attr
)
{
using
Func
=
typename
KernelTuples
::
func_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KT
>
().
Instance
();
auto
&
codes
=
JitCodePool
<
KT
>
().
Instance
();
if
(
codes
.
Has
(
key
))
{
if
(
codes
.
Has
(
key
))
{
...
@@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) {
...
@@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) {
// Refer code do not related with attr, which is just for cast
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
// Refer is always on CPUPlace
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
>
template
<
KernelType
KT
,
typename
KernelTuples
>
inline
Func
GetRefer
()
{
inline
typename
KernelTuples
::
func_type
GetRefer
()
{
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
KernelKey
kkey
(
KT
,
platform
::
CPUPlace
());
KernelKey
kkey
(
KT
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
...
@@ -74,7 +76,7 @@ inline Func GetRefer() {
...
@@ -74,7 +76,7 @@ inline Func GetRefer() {
"Every Kernel should have reference function."
);
"Every Kernel should have reference function."
);
auto
&
ref_impls
=
ref_iter
->
second
;
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
T
,
Func
,
Attr
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
)
{
if
(
i
)
{
return
i
->
GetFunc
();
return
i
->
GetFunc
();
}
}
...
@@ -82,10 +84,10 @@ inline Func GetRefer() {
...
@@ -82,10 +84,10 @@ inline Func GetRefer() {
return
nullptr
;
return
nullptr
;
}
}
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
template
<
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
=
platform
::
CPUPlace
>
typename
PlaceType
=
platform
::
CPUPlace
>
Func
Get
(
Attr
attr
)
{
typename
KernelTuples
::
func_type
Get
(
typename
KernelTuples
::
attr_type
attr
)
{
auto
jitfunc
=
GetJitCode
<
KT
,
T
,
Func
,
Attr
,
PlaceType
>
(
attr
);
auto
jitfunc
=
GetJitCode
<
KT
,
KernelTuples
,
PlaceType
>
(
attr
);
if
(
jitfunc
)
{
if
(
jitfunc
)
{
return
jitfunc
;
return
jitfunc
;
}
}
...
@@ -97,7 +99,7 @@ Func Get(Attr attr) {
...
@@ -97,7 +99,7 @@ Func Get(Attr attr) {
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
KernelImpl
<
T
,
Func
,
Attr
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
KernelImpl
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
UseMe
(
attr
))
{
return
i
->
GetFunc
();
return
i
->
GetFunc
();
}
}
...
@@ -105,7 +107,7 @@ Func Get(Attr attr) {
...
@@ -105,7 +107,7 @@ Func Get(Attr attr) {
}
}
// The last implementation should be reference function on CPUPlace.
// The last implementation should be reference function on CPUPlace.
return
GetRefer
<
KT
,
T
,
Func
,
Attr
>
();
return
GetRefer
<
KT
,
KernelTuples
>
();
}
}
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
bc0df6a9
...
@@ -36,10 +36,13 @@ class Kernel {
...
@@ -36,10 +36,13 @@ class Kernel {
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
};
template
<
typename
T
,
typename
Func
,
typename
Attr
>
template
<
typename
KernelTuples
>
class
KernelImpl
:
public
Kernel
{
class
KernelImpl
:
public
Kernel
{
using
T
=
typename
KernelTuples
::
data_type
;
using
Func
=
typename
KernelTuples
::
func_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
public:
public:
using
ELEMENT_TYPE
=
T
;
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
bool
UseMe
(
Attr
attr
)
const
=
0
;
virtual
bool
UseMe
(
Attr
attr
)
const
=
0
;
...
@@ -47,11 +50,13 @@ class KernelImpl : public Kernel {
...
@@ -47,11 +50,13 @@ class KernelImpl : public Kernel {
Func
func
{
nullptr
};
Func
func
{
nullptr
};
};
};
template
<
typename
T
,
typename
Func
,
typename
Attr
>
template
<
typename
KernelTuples
>
class
ReferKernel
:
public
KernelImpl
<
T
,
Func
,
Attr
>
{
class
ReferKernel
:
public
KernelImpl
<
KernelTuples
>
{
public:
public:
// Refer code can always be used
// Refer code can always be used
bool
UseMe
(
Attr
attr
)
const
override
{
return
true
;
}
bool
UseMe
(
typename
KernelTuples
::
attr_type
attr
)
const
override
{
return
true
;
}
};
};
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
bc0df6a9
...
@@ -28,8 +28,7 @@ template <typename T>
...
@@ -28,8 +28,7 @@ template <typename T>
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
typename
T
>
template
<
typename
T
>
class
VMulKernel
:
public
KernelImpl
<
T
,
typename
VMulTuples
<
T
>::
func_type
,
class
VMulKernel
:
public
KernelImpl
<
VMulTuples
<
T
>>
{
typename
VMulTuples
<
T
>::
attr_type
>
{
public:
public:
VMulKernel
()
{
this
->
func
=
VMul
<
T
>
;
}
VMulKernel
()
{
this
->
func
=
VMul
<
T
>
;
}
bool
UseMe
(
int
d
)
const
override
{
bool
UseMe
(
int
d
)
const
override
{
...
...
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
bc0df6a9
...
@@ -29,8 +29,7 @@ void VMul(const T* x, const T* y, T* z, int n) {
...
@@ -29,8 +29,7 @@ void VMul(const T* x, const T* y, T* z, int n) {
}
}
template
<
typename
T
>
template
<
typename
T
>
class
VMulKernel
:
public
ReferKernel
<
T
,
typename
VMulTuples
<
T
>::
func_type
,
class
VMulKernel
:
public
ReferKernel
<
VMulTuples
<
T
>>
{
typename
VMulTuples
<
T
>::
attr_type
>
{
public:
public:
VMulKernel
()
{
this
->
func
=
VMul
<
T
>
;
}
VMulKernel
()
{
this
->
func
=
VMul
<
T
>
;
}
};
};
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
bc0df6a9
...
@@ -89,8 +89,7 @@ TEST(JitKernel, vmul) {
...
@@ -89,8 +89,7 @@ TEST(JitKernel, vmul) {
namespace
jit
=
paddle
::
operators
::
jit
;
namespace
jit
=
paddle
::
operators
::
jit
;
const
auto
KT
=
jit
::
vmul
;
const
auto
KT
=
jit
::
vmul
;
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
VMulTuples
<
T
>>
();
jit
::
VMulTuples
<
T
>::
attr_type
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
...
@@ -115,8 +114,7 @@ TEST(JitKernel, vmul) {
...
@@ -115,8 +114,7 @@ TEST(JitKernel, vmul) {
ExpectEQ
<
T
>
(
yinp_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
yinp_data
,
zref_data
,
d
);
// test jitcode
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
jit
::
VMulTuples
<
T
>
,
PlaceType
>
(
d
);
jit
::
VMulTuples
<
T
>::
attr_type
,
PlaceType
>
(
d
);
if
(
jitcode
)
{
if
(
jitcode
)
{
VLOG
(
10
)
<<
"Test jitcode, size: "
<<
d
;
VLOG
(
10
)
<<
"Test jitcode, size: "
<<
d
;
TestTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
jitcode
,
x
,
y
,
zref
);
TestTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
jitcode
,
x
,
y
,
zref
);
...
@@ -129,10 +127,8 @@ TEST(JitKernel, vmul) {
...
@@ -129,10 +127,8 @@ TEST(JitKernel, vmul) {
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
auto
i
=
dynamic_cast
<
const
jit
::
KernelImpl
<
jit
::
VMulTuples
<
T
>>*>
(
dynamic_cast
<
const
jit
::
KernelImpl
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
impl
.
get
());
jit
::
VMulTuples
<
T
>::
attr_type
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
d
))
{
if
(
i
&&
i
->
UseMe
(
d
))
{
auto
more
=
i
->
GetFunc
();
auto
more
=
i
->
GetFunc
();
VLOG
(
10
)
<<
"Test More Kernel, size: "
<<
d
;
VLOG
(
10
)
<<
"Test More Kernel, size: "
<<
d
;
...
@@ -142,8 +138,7 @@ TEST(JitKernel, vmul) {
...
@@ -142,8 +138,7 @@ TEST(JitKernel, vmul) {
}
}
// Test result from Get function
// Test result from Get function
VLOG
(
10
)
<<
"Test Get function, size: "
<<
d
;
VLOG
(
10
)
<<
"Test Get function, size: "
<<
d
;
auto
tgt
=
jit
::
Get
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
auto
tgt
=
jit
::
Get
<
KT
,
jit
::
VMulTuples
<
T
>
,
PlaceType
>
(
d
);
jit
::
VMulTuples
<
T
>::
attr_type
,
PlaceType
>
(
d
);
TestTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
tgt
,
x
,
y
,
zref
);
TestTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
tgt
,
x
,
y
,
zref
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录