BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit e9216e82
Authored Dec 12, 2018 by tensor-tang
add refer vscal, vaddbias and test and benchmark
Parent: a3703888
Showing 9 changed files with 216 additions and 28 deletions (+216 −28)
paddle/fluid/operators/jit/README.md              +5   -3
paddle/fluid/operators/jit/benchmark.cc           +82  -7
paddle/fluid/operators/jit/helper.cc              +4   -0
paddle/fluid/operators/jit/kernel_base.h          +12  -1
paddle/fluid/operators/jit/refer/CMakeLists.txt   +5   -0
paddle/fluid/operators/jit/refer/refer.cc         +3   -0
paddle/fluid/operators/jit/refer/refer.h          +12  -0
paddle/fluid/operators/jit/test.cc                +93  -10
paddle/fluid/operators/math/jit_kernel_refer.h    +0   -7
paddle/fluid/operators/jit/README.md
@@ -37,10 +37,12 @@ PaddlePaddle/Paddle/paddle/fluid/
 ## Tests
 - Logic tests
-  All implementations must be checked against the refer code and must meet the precision requirement
+  All implementations must be checked against the refer code and must meet the precision requirement, for both float and double data types
 - Performance tests
+  Compare the performance of all implementations, and compare against the final `jit::Get` method; the kernel obtained from that method must give the best performance.
 
 # How to add a new operator
-- Add `your_key` to `KernelType`
+- Add `your_key` to `KernelType`.
-- Implement the Reference logic. A Reference implementation is mandatory for every jitkernel and must not depend on any third-party library. Also add `USE_JITKERNEL_REFER(your_key)` in `refer/CmakeLists.txt`
+- Implement the Reference logic. A Reference implementation is mandatory for every jitkernel and must not depend on any third-party library. Also add `USE_JITKERNEL_REFER(your_key)` in `refer/CmakeLists.txt`.
+- Add a new `KernelTuples` when necessary; `XYZNTuples` can be used as a reference.
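
For illustration, here is a minimal sketch of what those steps amount to for a hypothetical new kernel key; the `vtanh` / `VTanh` names and the body below are assumptions made for this example, not part of this commit:

```cpp
// Hypothetical example of the README steps for a new key `vtanh`.
#include <cmath>
#include <cstdio>

// Step 1 (kernel_base.h): append the new key to the KernelType enum, e.g.
//   typedef enum { ..., vscal, vaddbias, vexp, vtanh } KernelType;

// Step 2 (refer/refer.h): the mandatory Reference implementation,
// plain C++ with no third-party dependencies.
template <typename T>
void VTanh(const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = std::tanh(x[i]);
  }
}

// Step 3 (refer/refer.cc):        REGISTER_REFER_KERNEL(vtanh, VTanh);
// Step 4 (refer/CMakeLists.txt):  USE_JITKERNEL_REFER(vtanh)

int main() {
  float x[4] = {-1.f, 0.f, 0.5f, 2.f};
  float y[4];
  VTanh(x, y, 4);
  for (int i = 0; i < 4; ++i) {
    printf("%f\n", y[i]);
  }
  return 0;
}
```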
paddle/fluid/operators/jit/benchmark.cc
@@ -53,7 +53,7 @@ std::vector<int> TestSizes() {
 
 // return this function avg time
 template <typename T, typename KernelTuples>
-double BenchTartgetFunc(const typename KernelTuples::func_type tgt,
+double BenchXYZNFunc(const typename KernelTuples::func_type tgt,
                         const std::vector<T>& x, const std::vector<T>& y,
                         std::vector<T>& z) {  // NOLINT
   const T* x_data = x.data();
@@ -83,14 +83,14 @@ void BenchXYZNKernel() {
     // refer
     auto refer = jit::GetRefer<KT, jit::XYZNTuples<T>>();
     if (refer) {
-      auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(refer, x, y, z);
+      auto res = BenchXYZNFunc<T, jit::XYZNTuples<T>>(refer, x, y, z);
       infos.push_back(std::make_pair("Refer", res));
     }
 
     // test jitcode
     auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
     if (jitcode) {
-      auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, z);
+      auto res = BenchXYZNFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, z);
       infos.push_back(std::make_pair("JitCode", res));
     }
@@ -105,7 +105,7 @@ void BenchXYZNKernel() {
                                                                  impl.get());
       if (i && i->UseMe(d)) {
         auto more = i->GetFunc();
-        auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(more, x, y, z);
+        auto res = BenchXYZNFunc<T, jit::XYZNTuples<T>>(more, x, y, z);
         infos.push_back(std::make_pair("More", res));
       }
     }
@@ -116,7 +116,7 @@ void BenchXYZNKernel() {
     if (!tgt) {
       LOG(ERROR) << "Target can not be empty!";
     }
-    auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(tgt, x, y, z);
+    auto res = BenchXYZNFunc<T, jit::XYZNTuples<T>>(tgt, x, y, z);
     infos.push_back(std::make_pair("Target", res));
 
     // print
@@ -129,6 +129,78 @@ void BenchXYZNKernel() {
   }
 }
 
+// return this function avg time
+template <typename T, typename KernelTuples>
+double BenchAXYNFunc(const typename KernelTuples::func_type tgt, const T a,
+                     const std::vector<T>& x,
+                     std::vector<T>& y) {  // NOLINT
+  const T* x_data = x.data();
+  T* y_data = y.data();
+  const int d = y.size();
+  for (int i = 0; i < FLAGS_burning; ++i) {
+    tgt(&a, x_data, y_data, d);
+  }
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeat; ++i) {
+    tgt(&a, x_data, y_data, d);
+  }
+  auto end = GetCurrentUS();
+  return (end - start) / FLAGS_repeat;
+}
+
+template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
+void BenchAXYNKernel() {
+  namespace jit = paddle::operators::jit;
+  for (int d : TestSizes()) {
+    std::vector<std::pair<std::string, double>> infos;
+    const T a = static_cast<T>(3);
+    std::vector<T> x(d), y(d);
+    RandomVec<T>(d, x.data());
+    // test refer
+    auto refer = jit::GetRefer<KT, jit::AXYNTuples<T>>();
+    if (refer) {
+      auto res = BenchAXYNFunc<T, jit::AXYNTuples<T>>(refer, a, x, y);
+      infos.push_back(std::make_pair("Refer", res));
+    }
+    // test jitcode
+    auto jitcode = jit::GetJitCode<KT, jit::AXYNTuples<T>, PlaceType>(d);
+    if (jitcode) {
+      auto res = BenchAXYNFunc<T, jit::AXYNTuples<T>>(jitcode, a, x, y);
+      infos.push_back(std::make_pair("JitCode", res));
+    }
+    // test all impls in more
+    jit::KernelKey kkey(KT, PlaceType());
+    auto& pool = jit::KernelPool().Instance().AllKernels();
+    auto iter = pool.find(kkey);
+    if (iter != pool.end()) {
+      auto& impls = iter->second;
+      for (auto& impl : impls) {
+        auto i = dynamic_cast<const jit::KernelImpl<jit::AXYNTuples<T>>*>(
+            impl.get());
+        if (i && i->UseMe(d)) {
+          auto more = i->GetFunc();
+          auto res = BenchAXYNFunc<T, jit::AXYNTuples<T>>(more, a, x, y);
+          infos.push_back(std::make_pair("More", res));
+        }
+      }
+    }
+    // Test result from Get function
+    auto tgt = jit::Get<KT, jit::AXYNTuples<T>, PlaceType>(d);
+    if (!tgt) {
+      LOG(ERROR) << "Target can not be empty!";
+    }
+    auto res = BenchAXYNFunc<T, jit::AXYNTuples<T>>(tgt, a, x, y);
+    infos.push_back(std::make_pair("Target", res));
+    // print
+    std::ostringstream loginfos;
+    loginfos << "Kernel Type: " << jit::to_string(KT) << ", size " << d << ": ";
+    for (auto pair : infos) {
+      loginfos << pair.first << " takes " << pair.second << " us; ";
+    }
+    LOG(INFO) << loginfos.str();
+  }
+}
+
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
 // Options:
@@ -147,4 +219,7 @@ int main(int argc, char* argv[]) {
   BenchXYZNKernel<jit::vadd, T, PlaceType>();
   BenchXYZNKernel<jit::vaddrelu, T, PlaceType>();
   BenchXYZNKernel<jit::vsub, T, PlaceType>();
+
+  BenchAXYNKernel<jit::vscal, T, PlaceType>();
+  BenchAXYNKernel<jit::vaddbias, T, PlaceType>();
 }
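
As a side note, the averaging logic in `BenchAXYNFunc` (burn-in iterations first, then timing over the repeat count) can be sketched in isolation as below; `NowUS()` and the literal burn/repeat counts are stand-ins for the benchmark's `GetCurrentUS()` helper and its gflags, used here only so the sketch is self-contained:

```cpp
#include <chrono>
#include <cstdio>
#include <vector>

// Stand-in for the benchmark's GetCurrentUS(): current time in microseconds.
static double NowUS() {
  auto ts = std::chrono::high_resolution_clock::now().time_since_epoch();
  return std::chrono::duration<double, std::micro>(ts).count();
}

// Burn in first, then return the average time of one call in microseconds.
template <typename Func>
double AvgTimeUS(Func&& f, int burning, int repeat) {
  for (int i = 0; i < burning; ++i) f();  // warm up, result discarded
  double start = NowUS();
  for (int i = 0; i < repeat; ++i) f();
  return (NowUS() - start) / repeat;
}

int main() {
  const int d = 1024;
  const float a = 3.f;
  std::vector<float> x(d, 1.f), y(d);
  // vaddbias-style body, following the AXYN convention y = a + x.
  double us = AvgTimeUS(
      [&]() {
        for (int i = 0; i < d; ++i) y[i] = a + x[i];
      },
      10, 100);
  printf("avg: %f us\n", us);
  return 0;
}
```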
paddle/fluid/operators/jit/helper.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/helper.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace operators {
@@ -32,7 +33,10 @@ const char* to_string(KernelType kt) {
       return "vscal";
     case vexp:
       return "vexp";
+    case vaddbias:
+      return "vaddbias";
     default:
+      PADDLE_THROW("Not support type: %d", kt);
       return "NOT JITKernel";
   }
   return nullptr;
paddle/fluid/operators/jit/kernel_base.h
@@ -19,7 +19,15 @@ namespace paddle {
 namespace operators {
 namespace jit {
 
-typedef enum { vmul = 0, vadd = 1, vaddrelu, vsub, vscal, vexp } KernelType;
+typedef enum {
+  vmul = 0,
+  vadd = 1,
+  vaddrelu,
+  vsub,
+  vscal,
+  vaddbias,
+  vexp
+} KernelType;
 
 template <typename T>
 struct XYZNTuples {
@@ -28,6 +36,9 @@ struct XYZNTuples {
   typedef void (*func_type)(const T*, const T*, T*, int);
 };
 
+template <typename T>
+struct AXYNTuples : public XYZNTuples<T> {};
+
 // Just for adding to kernel pool without template
 class Kernel {
  public:
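
`AXYNTuples` reuses the `XYZNTuples` function signature `void(const T*, const T*, T*, int)`, but with the calling convention `(a, x, y, n)` where the first pointer refers to a single scalar. Below is a minimal sketch of that convention; `VScalRef` is a stand-in for a kernel that would normally be obtained through `jit::Get` (an assumption made for the example):

```cpp
#include <cstdio>
#include <vector>

// Stand-in kernel with the AXYNTuples<T>::func_type signature: y = a * x.
template <typename T>
void VScalRef(const T* a, const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = a[0] * x[i];
  }
}

int main() {
  const int n = 4;
  const float a = 3.f;
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f}, y(n);
  // Same pointer-to-scalar convention the tests and benchmark use:
  //   tgt(&a, x_data, y_data, d);
  VScalRef(&a, x.data(), y.data(), n);
  for (float v : y) printf("%f ", v);  // 3 6 9 12
  printf("\n");
  return 0;
}
```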
paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -8,3 +8,8 @@ endfunction()
 
 # use refer kernel by name
 USE_JITKERNEL_REFER(vmul)
+USE_JITKERNEL_REFER(vadd)
+USE_JITKERNEL_REFER(vaddrelu)
+USE_JITKERNEL_REFER(vsub)
+USE_JITKERNEL_REFER(vscal)
+USE_JITKERNEL_REFER(vaddbias)
paddle/fluid/operators/jit/refer/refer.cc
@@ -26,4 +26,7 @@ REGISTER_REFER_KERNEL(vadd, VAdd);
 REGISTER_REFER_KERNEL(vaddrelu, VAddRelu);
 REGISTER_REFER_KERNEL(vsub, VSub);
 
+REGISTER_REFER_KERNEL(vscal, VScal);
+REGISTER_REFER_KERNEL(vaddbias, VAddBias);
+
 #undef REGISTER_REFER_KERNEL
paddle/fluid/operators/jit/refer/refer.h
@@ -59,6 +59,13 @@ void VScal(const T* a, const T* x, T* y, int n) {
   }
 }
 
+template <typename T>
+void VAddBias(const T* a, const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = a[0] + x[i];
+  }
+}
+
 #define DECLARE_REFER_KERNEL(name, tuples)             \
   template <typename T>                                \
   class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -66,11 +73,16 @@ void VScal(const T* a, const T* x, T* y, int n) {
     name##Kernel() { this->func = name<T>; }           \
   }
 
+// const T* x, const T* y, T* z, int n
 DECLARE_REFER_KERNEL(VMul, XYZNTuples);
 DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
 DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
 DECLARE_REFER_KERNEL(VSub, XYZNTuples);
 
+// const T* a, const T* x, T* y, int n
+DECLARE_REFER_KERNEL(VScal, AXYNTuples);
+DECLARE_REFER_KERNEL(VAddBias, AXYNTuples);
+
 #undef DECLARE_REFER_KERNEL
 
 }  // namespace refer
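
To show how `DECLARE_REFER_KERNEL` wires a free function such as `VAddBias` into a kernel class through `this->func`, here is a simplified, self-contained sketch; `ReferKernelStub` is an assumed stand-in for the real `ReferKernel` hierarchy, and the expanded class mirrors only the macro lines visible in this hunk:

```cpp
#include <cstdio>

// The refer implementation added in this commit: y = a[0] + x.
template <typename T>
void VAddBias(const T* a, const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = a[0] + x[i];
  }
}

// Simplified stand-in for ReferKernel<AXYNTuples<T>>: just holds a func pointer.
template <typename T>
struct ReferKernelStub {
  void (*func)(const T*, const T*, T*, int) = nullptr;
};

// Roughly what DECLARE_REFER_KERNEL(VAddBias, AXYNTuples) declares.
template <typename T>
class VAddBiasKernel : public ReferKernelStub<T> {
 public:
  VAddBiasKernel() { this->func = VAddBias<T>; }
};

int main() {
  VAddBiasKernel<float> k;
  const float a = 2.f;
  float x[3] = {1.f, 2.f, 3.f};
  float y[3];
  k.func(&a, x, y, 3);
  for (float v : y) printf("%f ", v);  // 3 4 5
  printf("\n");
  return 0;
}
```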
paddle/fluid/operators/jit/test.cc
@@ -12,7 +12,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
 
-#include <cstring>  // for memcpy
 #include <random>
 #include <string>
 #include <vector>
@@ -59,7 +58,7 @@ std::vector<int> TestSizes() {
 }
 
 template <typename T, typename KernelTuples>
-void TestTartgetFunc(const typename KernelTuples::func_type tgt,
+void TestXYZNFunc(const typename KernelTuples::func_type tgt,
                      const std::vector<T>& x, const std::vector<T>& y,
                      const std::vector<T>& zref) {
   EXPECT_TRUE(tgt != nullptr);
@@ -88,9 +87,8 @@ void TestTartgetFunc(const typename KernelTuples::func_type tgt,
 template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
 void TestXYZNKernel() {
   namespace jit = paddle::operators::jit;
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
   for (int d : TestSizes()) {
-    VLOG(10) << "===== Test JITKernel " << jit::to_string(KT)
-             << ", size: " << d;
     auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
     EXPECT_TRUE(ref != nullptr);
@@ -119,7 +117,7 @@ void TestXYZNKernel() {
     auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
     if (jitcode) {
       VLOG(10) << "Test Jitcode Kernel, size: " << d;
-      TestTartgetFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, zref);
+      TestXYZNFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, zref);
     }
 
     // test all impls in more
@@ -134,14 +132,14 @@ void TestXYZNKernel() {
         if (i && i->UseMe(d)) {
           auto more = i->GetFunc();
           VLOG(10) << "Test More Kernel, size: " << d;
-          TestTartgetFunc<T, jit::XYZNTuples<T>>(more, x, y, zref);
+          TestXYZNFunc<T, jit::XYZNTuples<T>>(more, x, y, zref);
         }
       }
     }
     // Test result from Get function
     VLOG(10) << "Test Get function, size: " << d;
     auto tgt = jit::Get<KT, jit::XYZNTuples<T>, PlaceType>(d);
-    TestTartgetFunc<T, jit::XYZNTuples<T>>(tgt, x, y, zref);
+    TestXYZNFunc<T, jit::XYZNTuples<T>>(tgt, x, y, zref);
   }
 }
@@ -169,4 +167,89 @@ TEST(JITKernel, vsub) {
   TestXYZNKernel<jit::vsub, double, paddle::platform::CPUPlace>();
 }
 
-TEST(JITKernel, pool) {}
+template <typename T, typename KernelTuples>
+void TestAXYNFunc(const typename KernelTuples::func_type tgt, const T a,
+                  const std::vector<T>& x, const std::vector<T>& yref) {
+  EXPECT_TRUE(tgt != nullptr);
+  EXPECT_EQ(yref.size(), x.size());
+  const T* x_data = x.data();
+  const T* yref_data = yref.data();
+  const int d = yref.size();
+  std::vector<T> ytgt(d);
+  T* ytgt_data = ytgt.data();
+  // test normal
+  tgt(&a, x_data, ytgt_data, d);
+  ExpectEQ<T>(ytgt_data, yref_data, d);
+  // test inplace x
+  std::copy(x.begin(), x.end(), ytgt.begin());
+  tgt(&a, ytgt_data, ytgt_data, d);
+  ExpectEQ<T>(ytgt_data, yref_data, d);
+}
+
+template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
+void TestAXYNKernel() {
+  namespace jit = paddle::operators::jit;
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  for (int d : TestSizes()) {
+    auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
+    EXPECT_TRUE(ref != nullptr);
+
+    const T a = static_cast<T>(3);
+    std::vector<T> x(d), yref(d);
+    std::vector<T> xinp(d);  // inplace test
+    RandomVec<T>(d, x.data());
+    std::copy(x.begin(), x.end(), xinp.begin());
+
+    const T* x_data = x.data();
+    T* yref_data = yref.data();
+    T* xinp_data = xinp.data();
+    // test refer code inplace
+    ref(&a, x_data, yref_data, d);
+    ref(&a, xinp_data, xinp_data, d);
+    ExpectEQ<T>(xinp_data, yref_data, d);
+
+    // test jitcode
+    auto jitcode = jit::GetJitCode<KT, jit::AXYNTuples<T>, PlaceType>(d);
+    if (jitcode) {
+      VLOG(10) << "Test Jitcode Kernel, size: " << d;
+      TestAXYNFunc<T, jit::AXYNTuples<T>>(jitcode, a, x, yref);
+    }
+
+    // test all impls in more
+    jit::KernelKey kkey(KT, PlaceType());
+    auto& pool = jit::KernelPool().Instance().AllKernels();
+    auto iter = pool.find(kkey);
+    if (iter != pool.end()) {
+      auto& impls = iter->second;
+      for (auto& impl : impls) {
+        auto i = dynamic_cast<const jit::KernelImpl<jit::AXYNTuples<T>>*>(
+            impl.get());
+        if (i && i->UseMe(d)) {
+          auto more = i->GetFunc();
+          VLOG(10) << "Test More Kernel, size: " << d;
+          TestAXYNFunc<T, jit::AXYNTuples<T>>(more, a, x, yref);
+        }
+      }
+    }
+    // Test result from Get function
+    VLOG(10) << "Test Get function, size: " << d;
+    auto tgt = jit::Get<KT, jit::AXYNTuples<T>, PlaceType>(d);
+    TestAXYNFunc<T, jit::AXYNTuples<T>>(tgt, a, x, yref);
+  }
+}
+
+TEST(JITKernel, vscal) {
+  namespace jit = paddle::operators::jit;
+  TestAXYNKernel<jit::vscal, float, paddle::platform::CPUPlace>();
+  TestAXYNKernel<jit::vscal, double, paddle::platform::CPUPlace>();
+}
+
+TEST(JITKernel, vaddbias) {
+  namespace jit = paddle::operators::jit;
+  TestAXYNKernel<jit::vaddbias, float, paddle::platform::CPUPlace>();
+  TestAXYNKernel<jit::vaddbias, double, paddle::platform::CPUPlace>();
+}
+
+TEST(JITKernel, pool) {
+  // TODO(TJ): add some test
+}
paddle/fluid/operators/math/jit_kernel_refer.h
@@ -24,13 +24,6 @@ namespace math {
 namespace jitkernel {
 namespace refer {
 
-template <typename T>
-void VAddBias(const T* a, const T* x, T* y, int n) {
-  for (int i = 0; i < n; ++i) {
-    y[i] = a[0] + x[i];
-  }
-}
-
 template <typename T>
 void VRelu(const T* x, T* y, int n) {
   for (int i = 0; i < n; ++i) {