Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
28eb7d84
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
28eb7d84
编写于
12月 11, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test all impls and all inplace cases
上级
d4cab7d9
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
121 addition
and
49 deletion
+121
-49
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+33
-20
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+88
-29
未找到文件。
paddle/fluid/operators/jit/helper.h
浏览文件 @
28eb7d84
...
@@ -28,33 +28,16 @@ namespace paddle {
...
@@ -28,33 +28,16 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
jit
{
namespace
jit
{
// Refer code do not related with attr, and always on CPUPlace
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
>
inline
Func
GetRefer
()
{
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
KernelKey
kkey
(
KT
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
"Every Kernel should have reference function."
);
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
T
,
Func
,
Attr
>*>
(
impl
.
get
());
if
(
i
)
{
return
i
->
GetFunc
();
}
}
return
nullptr
;
}
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
typename
PlaceType
=
platform
::
CPUPlace
>
typename
PlaceType
>
const
Func
Get
(
Attr
attr
)
{
inline
const
Func
GetJitCode
(
Attr
attr
)
{
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KT
>
().
Instance
();
auto
&
codes
=
JitCodePool
<
KT
>
().
Instance
();
if
(
codes
.
Has
(
key
))
{
if
(
codes
.
Has
(
key
))
{
return
codes
.
AllKernels
().
at
(
key
)
->
template
getCode
<
Func
>();
return
codes
.
AllKernels
().
at
(
key
)
->
template
getCode
<
Func
>();
}
}
// creator is not related with attr, so can use KernelKey as key
KernelKey
kkey
(
KT
,
PlaceType
());
KernelKey
kkey
(
KT
,
PlaceType
());
if
(
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
)
{
if
(
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
)
{
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
...
@@ -73,8 +56,38 @@ const Func Get(Attr attr) {
...
@@ -73,8 +56,38 @@ const Func Get(Attr attr) {
}
}
}
}
}
}
return
nullptr
;
}
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
>
inline
Func
GetRefer
()
{
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
KernelKey
kkey
(
KT
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
"Every Kernel should have reference function."
);
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
T
,
Func
,
Attr
>*>
(
impl
.
get
());
if
(
i
)
{
return
i
->
GetFunc
();
}
}
return
nullptr
;
}
template
<
KernelType
KT
,
typename
T
,
typename
Func
,
typename
Attr
,
typename
PlaceType
=
platform
::
CPUPlace
>
const
Func
Get
(
Attr
attr
)
{
auto
jitfunc
=
GetJitCode
<
KT
,
T
,
Func
,
Attr
,
PlaceType
>
(
attr
);
if
(
jitfunc
)
{
return
jitfunc
;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey
kkey
(
KT
,
PlaceType
());
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
28eb7d84
...
@@ -55,46 +55,105 @@ void ExpectEQ(const T* target, const T* refer, int n) {
...
@@ -55,46 +55,105 @@ void ExpectEQ(const T* target, const T* refer, int n) {
}
}
}
}
TEST
(
JitKernel
,
vmul
)
{
std
::
vector
<
int
>
TestSizes
()
{
using
T
=
float
;
std
::
vector
<
int
>
s
;
using
PlaceType
=
paddle
::
platform
::
CPUPlace
;
for
(
int
i
=
1
;
i
<
30
;
++
i
)
{
s
.
push_back
(
i
);
}
// test some large size
s
.
push_back
(
100
);
s
.
push_back
(
1000
);
return
s
;
}
namespace
jit
=
paddle
::
operators
::
jit
;
template
<
typename
T
,
typename
Func
>
// TODO(TJ): test more vector size
void
TestTartgetFunc
(
const
Func
tgt
,
const
std
::
vector
<
T
>&
x
,
for
(
int
d
=
1
;
d
<
30
;
++
d
)
{
const
std
::
vector
<
T
>&
y
,
const
std
::
vector
<
T
>&
zref
)
{
auto
ref
=
jit
::
GetRefer
<
jit
::
vmul
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
jit
::
VMulTuples
<
T
>::
attr_type
>
();
auto
tgt
=
jit
::
Get
<
jit
::
vmul
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
jit
::
VMulTuples
<
T
>::
attr_type
,
PlaceType
>
(
d
);
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
zref
.
size
(),
x
.
size
());
EXPECT_EQ
(
zref
.
size
(),
y
.
size
());
const
T
*
x_data
=
x
.
data
();
const
T
*
y_data
=
y
.
data
();
const
T
*
zref_data
=
zref
.
data
();
const
int
d
=
zref
.
size
();
std
::
vector
<
T
>
x
(
d
),
y
(
d
);
std
::
vector
<
T
>
ztgt
(
d
);
std
::
vector
<
T
>
zref
(
d
),
ztgt
(
d
);
T
*
ztgt_data
=
ztgt
.
data
();
RandomVec
<
T
>
(
d
,
x
.
data
());
// test normal
RandomVec
<
T
>
(
d
,
y
.
data
());
const
float
*
x_data
=
x
.
data
();
const
float
*
y_data
=
y
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
tgt
(
x_data
,
y_data
,
ztgt_data
,
d
);
tgt
(
x_data
,
y_data
,
ztgt_data
,
d
);
ref
(
x_data
,
y_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
// test inplace x
// test inplace x
std
::
copy
(
x
.
begin
(),
x
.
end
(),
zref
.
begin
());
std
::
copy
(
x
.
begin
(),
x
.
end
(),
ztgt
.
begin
());
std
::
copy
(
x
.
begin
(),
x
.
end
(),
ztgt
.
begin
());
tgt
(
ztgt_data
,
y_data
,
ztgt_data
,
d
);
tgt
(
ztgt_data
,
y_data
,
ztgt_data
,
d
);
ref
(
zref_data
,
y_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
// test inplace y
// test inplace y
std
::
copy
(
y
.
begin
(),
y
.
end
(),
zref
.
begin
());
std
::
copy
(
y
.
begin
(),
y
.
end
(),
ztgt
.
begin
());
std
::
copy
(
y
.
begin
(),
y
.
end
(),
ztgt
.
begin
());
tgt
(
x_data
,
ztgt_data
,
ztgt_data
,
d
);
tgt
(
x_data
,
ztgt_data
,
ztgt_data
,
d
);
ref
(
x_data
,
zref_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
}
TEST
(
JitKernel
,
vmul
)
{
using
T
=
float
;
using
PlaceType
=
paddle
::
platform
::
CPUPlace
;
namespace
jit
=
paddle
::
operators
::
jit
;
const
auto
KT
=
jit
::
vmul
;
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
jit
::
VMulTuples
<
T
>::
attr_type
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
y
.
data
());
std
::
vector
<
T
>
xinp
(
d
),
yinp
(
d
);
// inplace test
std
::
copy
(
x
.
begin
(),
x
.
end
(),
xinp
.
begin
());
std
::
copy
(
y
.
begin
(),
y
.
end
(),
yinp
.
begin
());
const
T
*
x_data
=
x
.
data
();
const
T
*
y_data
=
y
.
data
();
T
*
zref_data
=
zref
.
data
();
T
*
xinp_data
=
xinp
.
data
();
T
*
yinp_data
=
yinp
.
data
();
// test refer code inplace
ref
(
x_data
,
y_data
,
zref_data
,
d
);
ref
(
x_data
,
yinp_data
,
yinp_data
,
d
);
ref
(
xinp_data
,
y_data
,
xinp_data
,
d
);
ExpectEQ
<
T
>
(
xinp_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
yinp_data
,
zref_data
,
d
);
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
jit
::
VMulTuples
<
T
>::
attr_type
,
PlaceType
>
(
d
);
if
(
jitcode
)
{
VLOG
(
10
)
<<
"Test jitcode, size: "
<<
d
;
TestTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
jitcode
,
x
,
y
,
zref
);
}
// test all impls in more
jit
::
KernelKey
kkey
(
KT
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelImpl
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
jit
::
VMulTuples
<
T
>::
attr_type
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
d
))
{
auto
more
=
i
->
GetFunc
();
VLOG
(
10
)
<<
"Test More Kernel, size: "
<<
d
;
TestTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
more
,
x
,
y
,
zref
);
}
}
}
// Test result from Get function
VLOG
(
10
)
<<
"Test Get function, size: "
<<
d
;
auto
tgt
=
jit
::
Get
<
KT
,
T
,
jit
::
VMulTuples
<
T
>::
func_type
,
jit
::
VMulTuples
<
T
>::
attr_type
,
PlaceType
>
(
d
);
TestTartgetFunc
<
T
,
jit
::
VMulTuples
<
T
>::
func_type
>
(
tgt
,
x
,
y
,
zref
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录