Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a7fc3d42
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a7fc3d42
编写于
1月 15, 2019
作者:
T
tensor-tang
提交者:
GitHub
1月 15, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15304 from tensor-tang/fuse/second_order_mul_sub
Fuse/second order mul sub and fuse repeated fc relu
上级
a152a5c7
1a95cd22
变更
31
显示空白变更内容
内联
并排
Showing
31 changed file
with
1584 addition
and
19 deletion
+1584
-19
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
+386
-0
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h
+41
-0
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
+2
-1
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
+379
-0
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
+41
-0
paddle/fluid/inference/api/paddle_pass_builder.h
paddle/fluid/inference/api/paddle_pass_builder.h
+2
-0
paddle/fluid/inference/tests/api/CMakeLists.txt
paddle/fluid/inference/tests/api/CMakeLists.txt
+7
-6
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
...le/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
+15
-7
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
+149
-0
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
+41
-0
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
+137
-0
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
+42
-0
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+22
-0
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+2
-1
paddle/fluid/operators/jit/gen/act.cc
paddle/fluid/operators/jit/gen/act.cc
+6
-0
paddle/fluid/operators/jit/gen/act.h
paddle/fluid/operators/jit/gen/act.h
+14
-1
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+6
-2
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+4
-1
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+1
-0
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+2
-0
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+9
-0
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+2
-0
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+38
-0
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+10
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+2
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+3
-0
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+27
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+54
-0
python/paddle/fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py
.../fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py
+85
-0
python/paddle/fluid/tests/unittests/test_fusion_squared_mat_sub_op.py
...e/fluid/tests/unittests/test_fusion_squared_mat_sub_op.py
+53
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
a7fc3d42
...
@@ -43,6 +43,8 @@ pass_library(multi_batch_merge_pass base)
...
@@ -43,6 +43,8 @@ pass_library(multi_batch_merge_pass base)
pass_library
(
conv_bn_fuse_pass inference
)
pass_library
(
conv_bn_fuse_pass inference
)
pass_library
(
seqconv_eltadd_relu_fuse_pass inference
)
pass_library
(
seqconv_eltadd_relu_fuse_pass inference
)
pass_library
(
seqpool_concat_fuse_pass inference
)
pass_library
(
seqpool_concat_fuse_pass inference
)
pass_library
(
repeated_fc_relu_fuse_pass inference
)
pass_library
(
squared_mat_sub_fuse_pass inference
)
pass_library
(
is_test_pass base
)
pass_library
(
is_test_pass base
)
pass_library
(
conv_elementwise_add_act_fuse_pass inference
)
pass_library
(
conv_elementwise_add_act_fuse_pass inference
)
pass_library
(
conv_elementwise_add2_act_fuse_pass inference
)
pass_library
(
conv_elementwise_add2_act_fuse_pass inference
)
...
...
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
#include <algorithm> // for max
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#define MAX_NUM_FC 10
namespace
paddle
{
namespace
framework
{
namespace
ir
{
PDNode
*
BuildRepeatedFCReluPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
int
num_fc
)
{
auto
var_next_is_fc_act
=
[
=
](
Node
*
x
,
const
std
::
string
&
act_type
=
"relu"
,
bool
check_in_has_only_one_out
=
true
,
int
fc_idx
=
0
)
->
bool
{
bool
next_is_fc
=
x
&&
x
->
IsVar
()
&&
VarLinksToOp
(
x
,
"fc"
);
if
(
check_in_has_only_one_out
)
{
next_is_fc
=
next_is_fc
&&
x
->
outputs
.
size
()
==
1
;
}
if
(
!
next_is_fc
)
{
return
false
;
}
auto
*
fc_op
=
x
->
outputs
[
fc_idx
];
bool
next_is_act
=
fc_op
&&
fc_op
->
IsOp
()
&&
fc_op
->
outputs
.
size
()
==
1
&&
fc_op
->
outputs
[
0
]
&&
fc_op
->
outputs
[
0
]
->
IsVar
()
&&
VarLinksToOp
(
fc_op
->
outputs
[
0
],
act_type
)
&&
fc_op
->
outputs
[
0
]
->
outputs
.
size
()
==
1
;
if
(
!
next_is_act
)
{
return
false
;
}
auto
*
act_op
=
fc_op
->
outputs
[
0
]
->
outputs
[
0
];
return
act_op
&&
act_op
->
IsOp
()
&&
act_op
->
outputs
.
size
()
==
1
;
};
auto
find_fc_idx
=
[
=
](
Node
*
x
,
const
std
::
string
&
act_type
=
"relu"
)
->
int
{
bool
next_is_fc
=
x
&&
x
->
IsVar
()
&&
VarLinksToOp
(
x
,
"fc"
);
if
(
!
next_is_fc
)
{
return
0
;
}
for
(
size_t
k
=
0
;
k
<
x
->
outputs
.
size
();
++
k
)
{
auto
*
fc_op
=
x
->
outputs
[
k
];
bool
next_is_act
=
fc_op
&&
fc_op
->
IsOp
()
&&
fc_op
->
outputs
.
size
()
==
1
&&
fc_op
->
outputs
[
0
]
&&
fc_op
->
outputs
[
0
]
->
IsVar
()
&&
VarLinksToOp
(
fc_op
->
outputs
[
0
],
act_type
)
&&
fc_op
->
outputs
[
0
]
->
outputs
.
size
()
==
1
;
if
(
!
next_is_act
)
{
continue
;
}
auto
*
act_op
=
fc_op
->
outputs
[
0
]
->
outputs
[
0
];
if
(
act_op
&&
act_op
->
IsOp
()
&&
act_op
->
outputs
.
size
()
==
1
)
{
return
k
;
}
}
return
0
;
};
auto
next_var_of_part
=
[
=
](
Node
*
x
,
int
fc_idx
=
0
)
->
Node
*
{
return
x
->
outputs
[
fc_idx
]
->
outputs
[
0
]
->
outputs
[
0
]
->
outputs
[
0
];
};
auto
var_next_is_fc_act_repeated_n_times
=
[
=
](
Node
*
x
,
int
repeated_times
,
const
std
::
string
&
act_type
=
"relu"
,
bool
check_in_has_only_one_out
=
true
)
->
bool
{
for
(
int
i
=
0
;
i
<
repeated_times
;
++
i
)
{
if
(
!
var_next_is_fc_act
(
x
,
act_type
,
i
==
0
&&
check_in_has_only_one_out
))
{
return
false
;
}
x
=
next_var_of_part
(
x
);
}
return
true
;
};
auto
var_before_is_fc_act
=
[
=
](
Node
*
x
,
const
std
::
string
&
act_type
=
"relu"
,
bool
at_top
=
false
)
->
bool
{
bool
before_is_act
=
x
&&
x
->
IsVar
()
&&
x
->
inputs
.
size
()
==
1
&&
VarLinksFromOp
(
x
,
"relu"
);
if
(
!
before_is_act
)
{
return
false
;
}
auto
*
relu_op
=
x
->
inputs
[
0
];
bool
before_is_fc
=
relu_op
->
IsOp
()
&&
relu_op
->
inputs
.
size
()
==
1
&&
relu_op
->
inputs
[
0
]
->
IsVar
()
&&
VarLinksFromOp
(
relu_op
->
inputs
[
0
],
"fc"
)
&&
relu_op
->
inputs
[
0
]
->
inputs
.
size
()
==
1
;
if
(
!
before_is_fc
)
{
return
false
;
}
auto
*
fc_op
=
relu_op
->
inputs
[
0
]
->
inputs
[
0
];
bool
is_fc
=
fc_op
->
IsOp
()
&&
fc_op
->
inputs
.
size
()
==
3
;
if
(
!
is_fc
)
{
return
false
;
}
for
(
auto
*
fc_i
:
fc_op
->
inputs
)
{
if
(
!
fc_i
->
inputs
.
empty
())
{
if
(
at_top
)
{
return
true
;
}
else
{
return
VarLinksFromOp
(
fc_i
,
"relu"
);
}
}
}
return
false
;
};
auto
before_var_of_part
=
[
=
](
Node
*
x
)
->
Node
*
{
auto
*
fc_op
=
x
->
inputs
[
0
]
->
inputs
[
0
];
for
(
auto
*
fc_i
:
fc_op
->
inputs
)
{
if
(
!
fc_i
->
inputs
.
empty
())
{
return
fc_i
->
inputs
[
0
];
}
}
return
nullptr
;
};
auto
var_before_is_fc_act_repeated_n_times
=
[
=
](
Node
*
x
,
int
repeated_times
,
const
std
::
string
&
act_type
=
"relu"
)
->
bool
{
for
(
int
i
=
0
;
i
<
repeated_times
;
++
i
)
{
if
(
!
var_before_is_fc_act
(
x
,
act_type
,
i
==
repeated_times
-
1
))
{
return
false
;
}
x
=
before_var_of_part
(
x
);
}
return
true
;
};
std
::
vector
<
PDNode
*>
fc_input_var
(
num_fc
);
std
::
vector
<
PDNode
*>
fc_output_var
(
num_fc
);
std
::
vector
<
PDNode
*>
fc_weight_var
(
num_fc
);
std
::
vector
<
PDNode
*>
fc_bias_var
(
num_fc
);
std
::
vector
<
PDNode
*>
fc_ops
(
num_fc
);
std
::
vector
<
PDNode
*>
relu_ops
(
num_fc
);
for
(
int
i
=
0
;
i
<
num_fc
;
++
i
)
{
fc_input_var
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
if
(
i
==
0
&&
x
->
outputs
.
size
()
>
0
)
{
bool
ok
=
x
->
inputs
.
size
()
>
0
;
if
(
!
ok
)
{
return
false
;
}
int
idx
=
find_fc_idx
(
x
);
if
(
idx
==
0
)
{
return
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
);
}
else
{
x
=
next_var_of_part
(
x
,
idx
);
return
var_next_is_fc_act_repeated_n_times
(
x
,
std
::
max
(
1
,
num_fc
-
i
-
1
),
"relu"
);
}
}
else
{
return
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
)
&&
x
->
inputs
.
size
()
>
0
&&
var_before_is_fc_act_repeated_n_times
(
x
,
i
,
"relu"
);
}
},
name_scope
+
"/fc_in_"
+
std
::
to_string
(
i
));
fc_weight_var
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
)
&&
x
->
inputs
.
empty
()
&&
var_before_is_fc_act_repeated_n_times
(
x
->
outputs
[
0
]
->
inputs
[
0
],
i
,
"relu"
)
&&
x
->
Name
()
==
x
->
outputs
[
0
]
->
Op
()
->
Input
(
"W"
)[
0
];
},
name_scope
+
"/fc_weight_"
+
std
::
to_string
(
i
));
fc_bias_var
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
)
&&
x
->
inputs
.
empty
()
&&
var_before_is_fc_act_repeated_n_times
(
x
->
outputs
[
0
]
->
inputs
[
0
],
i
,
"relu"
)
&&
x
->
Name
()
==
x
->
outputs
[
0
]
->
Op
()
->
Input
(
"Bias"
)[
0
];
},
name_scope
+
"/fc_bias_"
+
std
::
to_string
(
i
));
fc_output_var
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
bool
basic
=
x
&&
x
->
IsVar
()
&&
VarLinksFromOp
(
x
,
"fc"
)
&&
VarLinksToOp
(
x
,
"relu"
)
&&
x
->
inputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
->
inputs
.
size
()
==
3
;
if
(
!
basic
)
{
return
false
;
}
x
=
x
->
inputs
[
0
]
->
inputs
[
0
];
if
(
i
==
0
&&
x
->
outputs
.
size
()
>
0
)
{
bool
ok
=
x
->
inputs
.
size
()
>
0
;
if
(
!
ok
)
{
return
false
;
}
int
idx
=
find_fc_idx
(
x
);
if
(
idx
==
0
)
{
return
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
);
}
else
{
x
=
next_var_of_part
(
x
,
idx
);
return
var_next_is_fc_act_repeated_n_times
(
x
,
std
::
max
(
1
,
num_fc
-
i
-
1
),
"relu"
);
}
}
else
{
return
var_next_is_fc_act_repeated_n_times
(
x
,
num_fc
-
i
,
"relu"
)
&&
x
->
inputs
.
size
()
>
0
&&
var_before_is_fc_act_repeated_n_times
(
x
,
i
,
"relu"
);
}
},
name_scope
+
"/fc_out_"
+
std
::
to_string
(
i
));
fc_ops
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
bool
basic
=
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"fc"
&&
x
->
inputs
.
size
()
==
3
&&
x
->
outputs
.
size
()
==
1
;
if
(
!
basic
)
{
return
false
;
}
auto
*
fc_out_var
=
x
->
outputs
[
0
];
return
fc_out_var
&&
fc_out_var
->
IsVar
()
&&
fc_out_var
->
outputs
.
size
()
==
1
&&
VarLinksToOp
(
fc_out_var
,
"relu"
)
&&
fc_out_var
->
outputs
[
0
]
->
outputs
.
size
()
==
1
&&
var_next_is_fc_act_repeated_n_times
(
fc_out_var
->
outputs
[
0
]
->
outputs
[
0
],
num_fc
-
i
-
1
,
"relu"
)
&&
var_before_is_fc_act_repeated_n_times
(
fc_out_var
->
outputs
[
0
]
->
outputs
[
0
],
i
+
1
,
"relu"
);
},
name_scope
+
"/fc_op_"
+
std
::
to_string
(
i
));
relu_ops
[
i
]
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"relu"
&&
x
->
inputs
.
size
()
==
1
&&
x
->
outputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
->
IsVar
()
&&
VarLinksFromOp
(
x
->
inputs
[
0
],
"fc"
)
&&
x
->
outputs
[
0
]
->
IsVar
()
&&
var_next_is_fc_act_repeated_n_times
(
x
->
outputs
[
0
],
num_fc
-
i
-
1
,
"relu"
)
&&
var_before_is_fc_act_repeated_n_times
(
x
->
outputs
[
0
],
i
+
1
,
"relu"
);
},
name_scope
+
"/act_op_"
+
std
::
to_string
(
i
));
fc_ops
[
i
]
->
LinksFrom
({
fc_input_var
[
i
],
fc_weight_var
[
i
],
fc_bias_var
[
i
]})
.
LinksTo
({
fc_output_var
[
i
]});
relu_ops
[
i
]
->
LinksFrom
({
fc_output_var
[
i
]});
}
auto
*
last_out_var
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
var_before_is_fc_act_repeated_n_times
(
x
,
num_fc
,
"relu"
);
},
name_scope
+
"/act_out"
);
for
(
int
i
=
0
;
i
<
num_fc
-
1
;
++
i
)
{
relu_ops
[
i
]
->
LinksTo
({
fc_input_var
[
i
+
1
]});
}
relu_ops
[
num_fc
-
1
]
->
LinksTo
({
last_out_var
});
return
last_out_var
;
}
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
int
num_fc
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
BuildRepeatedFCReluPattern
(
pattern
,
name_scope
,
num_fc
);
auto
retrieve_node
=
[](
const
std
::
string
&
name
,
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
const
PDPattern
&
pat
)
->
Node
*
{
PADDLE_ENFORCE
(
subgraph
.
count
(
pat
.
RetrieveNode
(
name
)),
"pattern has no Node called %s"
,
name
.
c_str
());
Node
*
p
=
subgraph
.
at
(
pat
.
RetrieveNode
(
name
));
PADDLE_ENFORCE_NOT_NULL
(
p
,
"subgraph has no node %s"
,
name
.
c_str
());
return
p
;
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
LOG
(
INFO
)
<<
"handle Repeated FC Act fuse"
;
std
::
vector
<
Node
*>
weights_vars
(
num_fc
);
std
::
vector
<
Node
*>
bias_vars
(
num_fc
);
std
::
vector
<
Node
*>
relu_vars
(
num_fc
-
1
);
std
::
vector
<
std
::
string
>
weight_names
(
num_fc
);
std
::
vector
<
std
::
string
>
bias_names
(
num_fc
);
std
::
vector
<
std
::
string
>
relu_names
(
num_fc
-
1
);
auto
&
fused_pattern
=
gpd
.
pattern
();
for
(
int
i
=
0
;
i
<
num_fc
;
++
i
)
{
if
(
i
>=
1
)
{
relu_vars
[
i
-
1
]
=
retrieve_node
(
name_scope
+
"/fc_in_"
+
std
::
to_string
(
i
),
subgraph
,
fused_pattern
);
relu_names
[
i
-
1
]
=
relu_vars
[
i
-
1
]
->
Name
();
}
weights_vars
[
i
]
=
retrieve_node
(
name_scope
+
"/fc_weight_"
+
std
::
to_string
(
i
),
subgraph
,
fused_pattern
);
weight_names
[
i
]
=
weights_vars
[
i
]
->
Name
();
bias_vars
[
i
]
=
retrieve_node
(
name_scope
+
"/fc_bias_"
+
std
::
to_string
(
i
),
subgraph
,
fused_pattern
);
bias_names
[
i
]
=
bias_vars
[
i
]
->
Name
();
}
auto
*
input_var
=
retrieve_node
(
name_scope
+
"/fc_in_0"
,
subgraph
,
fused_pattern
);
auto
*
last_out_var
=
retrieve_node
(
name_scope
+
"/act_out"
,
subgraph
,
fused_pattern
);
// Create New OpDesc
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_repeated_fc_relu"
);
op_desc
.
SetInput
(
"X"
,
{
input_var
->
Name
()});
op_desc
.
SetInput
(
"W"
,
weight_names
);
op_desc
.
SetInput
(
"Bias"
,
bias_names
);
op_desc
.
SetOutput
(
"ReluOut"
,
relu_names
);
op_desc
.
SetOutput
(
"Out"
,
{
last_out_var
->
Name
()});
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
IR_NODE_LINK_TO
(
input_var
,
op
);
for
(
size_t
i
=
0
;
i
<
weights_vars
.
size
();
++
i
)
{
IR_NODE_LINK_TO
(
weights_vars
[
i
],
op
);
IR_NODE_LINK_TO
(
bias_vars
[
i
],
op
);
}
for
(
size_t
i
=
0
;
i
<
relu_vars
.
size
();
++
i
)
{
IR_NODE_LINK_TO
(
op
,
relu_vars
[
i
]);
}
IR_NODE_LINK_TO
(
op
,
last_out_var
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
;
for
(
auto
&
item
:
subgraph
)
{
marked_nodes
.
insert
(
item
.
second
);
}
for
(
size_t
i
=
0
;
i
<
weights_vars
.
size
();
++
i
)
{
marked_nodes
.
erase
(
weights_vars
[
i
]);
marked_nodes
.
erase
(
bias_vars
[
i
]);
}
for
(
size_t
i
=
0
;
i
<
relu_vars
.
size
();
++
i
)
{
marked_nodes
.
erase
(
relu_vars
[
i
]);
}
marked_nodes
.
erase
(
input_var
);
marked_nodes
.
erase
(
last_out_var
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
std
::
unique_ptr
<
ir
::
Graph
>
RepeatedFCReluFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
int
fusion_count
=
0
;
for
(
int
i
=
MAX_NUM_FC
;
i
>
1
;
--
i
)
{
fusion_count
+=
BuildFusion
(
graph
.
get
(),
name_scope_
+
"/"
+
std
::
to_string
(
i
),
i
);
}
AddStatis
(
fusion_count
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
repeated_fc_relu_fuse_pass
,
paddle
::
framework
::
ir
::
RepeatedFCReluFusePass
);
paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/**
* Fuse Repeated FC Relu
*/
class
RepeatedFCReluFusePass
:
public
FusePassBase
{
public:
virtual
~
RepeatedFCReluFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"repeated_fc_relu_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc
浏览文件 @
a7fc3d42
...
@@ -129,7 +129,8 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
...
@@ -129,7 +129,8 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
return
concat_out_var
;
return
concat_out_var
;
}
}
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
int
num_inputs
)
{
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
int
num_inputs
)
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
auto
*
pattern
=
gpd
.
mutable_pattern
();
BuildSeqPoolConcatPattern
(
pattern
,
name_scope
,
num_inputs
);
BuildSeqPoolConcatPattern
(
pattern
,
name_scope
,
num_inputs
);
...
...
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
PDNode
*
BuildSquaredMatSubPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
{
auto
var_is_op_input
=
[
=
](
Node
*
x
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
=
""
)
->
bool
{
if
(
!
(
x
&&
x
->
IsVar
()))
{
return
false
;
}
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
if
(
arg_name
.
empty
())
{
return
true
;
}
for
(
auto
&
name
:
op
->
Op
()
->
Input
(
arg_name
))
{
if
(
name
==
x
->
Name
())
{
return
true
;
}
}
}
}
return
false
;
};
auto
var_is_op_only_output
=
[](
Node
*
x
,
const
std
::
string
&
op_type
)
->
bool
{
return
x
&&
x
->
IsVar
()
&&
x
->
inputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
&&
x
->
inputs
[
0
]
->
IsOp
()
&&
x
->
inputs
[
0
]
->
Op
()
->
Type
()
==
op_type
&&
x
->
inputs
[
0
]
->
outputs
.
size
()
==
1
;
};
auto
next_op
=
[
=
](
Node
*
x
,
const
std
::
string
&
op_type
)
->
Node
*
{
if
(
!
(
x
&&
x
->
IsVar
()))
{
return
nullptr
;
}
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
return
op
;
}
}
return
nullptr
;
};
auto
get_op_input_var
=
[
=
](
Node
*
x
,
const
std
::
string
&
arg_name
)
->
Node
*
{
if
(
!
(
x
&&
x
->
IsOp
()))
{
return
nullptr
;
}
for
(
auto
*
var
:
x
->
inputs
)
{
for
(
auto
name
:
x
->
Op
()
->
Input
(
arg_name
))
{
if
(
var
->
Name
()
==
name
)
{
return
var
;
}
}
}
return
nullptr
;
};
auto
is_fusion_input_var
=
[
=
](
Node
*
x
,
const
std
::
string
&
arg_name
)
{
bool
basic
=
var_is_op_input
(
x
,
"matmul"
,
arg_name
)
&&
var_is_op_input
(
x
,
"square"
,
"X"
);
if
(
!
basic
)
{
return
false
;
}
auto
*
squared_x_op
=
next_op
(
x
,
"square"
);
if
(
!
(
squared_x_op
&&
squared_x_op
->
outputs
.
size
()
==
1
))
{
return
false
;
}
auto
*
squared_x
=
squared_x_op
->
outputs
[
0
];
bool
next_is_matmul_from_arg
=
var_is_op_input
(
squared_x
,
"matmul"
,
arg_name
)
&&
squared_x
->
outputs
.
size
()
==
1
&&
squared_x
->
outputs
[
0
]
->
outputs
.
size
()
==
1
;
if
(
!
next_is_matmul_from_arg
)
{
return
false
;
}
auto
*
sub_y_in
=
squared_x
->
outputs
[
0
]
->
outputs
[
0
];
return
var_is_op_input
(
sub_y_in
,
"elementwise_sub"
,
"Y"
)
&&
sub_y_in
->
outputs
[
0
]
->
outputs
.
size
()
==
1
&&
var_is_op_input
(
sub_y_in
->
outputs
[
0
]
->
outputs
[
0
],
"elementwise_mul"
);
};
auto
is_fusion_first_mul_out
=
[
=
](
Node
*
x
)
->
bool
{
bool
input_is_matmul_op
=
x
&&
x
->
inputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
->
IsOp
()
&&
x
->
inputs
[
0
]
->
Op
()
->
Type
()
==
"matmul"
;
if
(
!
input_is_matmul_op
)
{
return
false
;
}
auto
*
mat_x
=
get_op_input_var
(
x
->
inputs
[
0
],
"X"
);
auto
*
mat_y
=
get_op_input_var
(
x
->
inputs
[
0
],
"Y"
);
bool
input_mul_is_valid
=
mat_x
&&
is_fusion_input_var
(
mat_x
,
"X"
)
&&
mat_y
&&
is_fusion_input_var
(
mat_y
,
"Y"
);
if
(
!
input_mul_is_valid
)
{
return
false
;
}
bool
next_is_square
=
var_is_op_input
(
x
,
"square"
,
"X"
)
&&
x
->
outputs
.
size
()
==
1
&&
x
->
outputs
[
0
]
->
outputs
.
size
()
==
1
;
if
(
!
next_is_square
)
{
return
false
;
}
auto
*
sub_x_in
=
x
->
outputs
[
0
]
->
outputs
[
0
];
return
var_is_op_input
(
sub_x_in
,
"elementwise_sub"
,
"X"
)
&&
sub_x_in
->
outputs
[
0
]
->
outputs
.
size
()
==
1
&&
var_is_op_input
(
sub_x_in
->
outputs
[
0
]
->
outputs
[
0
],
"elementwise_mul"
);
};
auto
*
x
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
is_fusion_input_var
(
x
,
"X"
);
},
name_scope
+
"/x"
);
auto
*
y
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
is_fusion_input_var
(
x
,
"Y"
);
},
name_scope
+
"/y"
);
auto
*
square_x_op
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"square"
&&
is_fusion_input_var
(
x
->
inputs
[
0
],
"X"
);
},
name_scope
+
"/squared_x_op"
);
auto
*
square_y_op
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"square"
&&
is_fusion_input_var
(
x
->
inputs
[
0
],
"Y"
);
},
name_scope
+
"/squared_y_op"
);
auto
*
squared_x
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
inputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
->
inputs
.
size
()
==
1
&&
is_fusion_input_var
(
x
->
inputs
[
0
]
->
inputs
[
0
],
"X"
);
},
name_scope
+
"/squared_x"
);
auto
*
squared_y
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
inputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
->
inputs
.
size
()
==
1
&&
is_fusion_input_var
(
x
->
inputs
[
0
]
->
inputs
[
0
],
"Y"
);
},
name_scope
+
"/squared_y"
);
auto
*
matmuled_xy
=
pattern
->
NewNode
([
=
](
Node
*
x
)
{
return
is_fusion_first_mul_out
(
x
);
},
name_scope
+
"/matmuled_xy"
);
auto
*
matmul_xy_op
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"matmul"
&&
is_fusion_first_mul_out
(
x
->
outputs
[
0
]);
},
name_scope
+
"/matmul_xy_op"
);
auto
*
square_matmuled_xy_op
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"square"
&&
is_fusion_first_mul_out
(
x
->
inputs
[
0
]);
},
name_scope
+
"/square_matmuled_xy_op"
);
auto
*
squared_xmuly
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
x
->
inputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
->
IsOp
()
&&
x
->
inputs
[
0
]
->
Op
()
->
Type
()
==
"square"
&&
is_fusion_first_mul_out
(
x
->
inputs
[
0
]
->
inputs
[
0
]);
},
name_scope
+
"/squared_xmuly"
);
auto
is_fusion_mat_squared_x_y_op_out
=
[
=
](
Node
*
x
)
->
bool
{
bool
basic
=
x
&&
x
->
IsVar
()
&&
x
->
inputs
.
size
()
==
1
&&
x
->
inputs
[
0
]
->
IsOp
()
&&
x
->
inputs
[
0
]
->
Op
()
->
Type
()
==
"matmul"
;
if
(
!
basic
)
{
return
false
;
}
auto
*
sqx
=
get_op_input_var
(
x
->
inputs
[
0
],
"X"
);
auto
*
sqy
=
get_op_input_var
(
x
->
inputs
[
0
],
"Y"
);
return
var_is_op_only_output
(
sqx
,
"square"
)
&&
var_is_op_only_output
(
sqy
,
"square"
)
&&
sqx
->
inputs
[
0
]
&&
sqx
->
inputs
[
0
]
->
inputs
.
size
()
==
1
&&
is_fusion_input_var
(
sqx
->
inputs
[
0
]
->
inputs
[
0
],
"X"
)
&&
sqy
->
inputs
[
0
]
&&
sqy
->
inputs
[
0
]
->
inputs
.
size
()
==
1
&&
is_fusion_input_var
(
sqy
->
inputs
[
0
]
->
inputs
[
0
],
"Y"
);
};
auto
*
matmul_squared_x_y_op
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"matmul"
&&
is_fusion_mat_squared_x_y_op_out
(
x
->
outputs
[
0
]);
},
name_scope
+
"/matmul_squared_x_y_op"
);
auto
*
mat_squared_x_y_op_out
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
is_fusion_mat_squared_x_y_op_out
(
x
);
},
name_scope
+
"/mat_squared_x_y_op_out"
);
auto
is_fusion_sub_op
=
[
=
](
Node
*
x
)
->
bool
{
bool
is_sub_op
=
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"elementwise_sub"
;
if
(
!
is_sub_op
)
{
return
false
;
}
auto
*
matmul_sqx_sqy_var
=
get_op_input_var
(
x
,
"Y"
);
return
is_fusion_mat_squared_x_y_op_out
(
matmul_sqx_sqy_var
);
};
auto
*
sub_op
=
pattern
->
NewNode
([
=
](
Node
*
x
)
{
return
is_fusion_sub_op
(
x
);
},
name_scope
+
"/sub_op"
);
auto
*
sub_op_out
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
x
->
inputs
.
size
()
==
1
&&
is_fusion_sub_op
(
x
->
inputs
[
0
]);
},
name_scope
+
"/sub_op_out"
);
auto
is_fusion_element_op
=
[
=
](
Node
*
x
)
->
bool
{
bool
is_elemul_op
=
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"elementwise_mul"
;
if
(
!
is_elemul_op
)
{
return
false
;
}
for
(
auto
*
in
:
x
->
inputs
)
{
if
(
in
&&
in
->
inputs
[
0
]
&&
is_fusion_sub_op
(
in
->
inputs
[
0
]))
{
return
true
;
}
}
return
false
;
};
auto
*
elementmul_op
=
pattern
->
NewNode
([
=
](
Node
*
x
)
{
return
is_fusion_element_op
(
x
);
},
name_scope
+
"/elementmul_op"
);
auto
*
constant_op
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
"fill_constant"
&&
x
->
outputs
.
size
()
==
1
&&
is_fusion_element_op
(
x
->
outputs
[
0
]
->
outputs
[
0
]);
},
name_scope
+
"/fill_constant_op"
);
auto
*
constant_op_out
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
var_is_op_input
(
x
,
"elementwise_mul"
)
&&
x
->
inputs
[
0
]
&&
x
->
inputs
[
0
]
->
IsOp
()
&&
x
->
inputs
[
0
]
->
Op
()
->
Type
()
==
"fill_constant"
&&
x
->
outputs
[
0
]
&&
is_fusion_element_op
(
x
->
outputs
[
0
]);
},
name_scope
+
"/constant_op_out"
);
auto
*
last_out_var
=
pattern
->
NewNode
(
[
=
](
Node
*
x
)
{
return
var_is_op_only_output
(
x
,
"elementwise_mul"
)
&&
is_fusion_element_op
(
x
->
inputs
[
0
]);
},
name_scope
+
"/out"
);
square_x_op
->
LinksFrom
({
x
}).
LinksTo
({
squared_x
});
square_y_op
->
LinksFrom
({
y
}).
LinksTo
({
squared_y
});
matmul_xy_op
->
LinksFrom
({
x
,
y
}).
LinksTo
({
matmuled_xy
});
matmul_squared_x_y_op
->
LinksFrom
({
squared_x
,
squared_y
})
.
LinksTo
({
mat_squared_x_y_op_out
});
square_matmuled_xy_op
->
LinksFrom
({
matmuled_xy
}).
LinksTo
({
squared_xmuly
});
sub_op
->
LinksFrom
({
squared_xmuly
,
mat_squared_x_y_op_out
})
.
LinksTo
({
sub_op_out
});
constant_op
->
LinksFrom
({}).
LinksTo
({
constant_op_out
});
elementmul_op
->
LinksFrom
({
constant_op_out
,
sub_op_out
})
.
LinksTo
({
last_out_var
});
return
last_out_var
;
}
static
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
BuildSquaredMatSubPattern
(
pattern
,
name_scope
);
auto
retrieve_node
=
[](
const
std
::
string
&
name
,
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
const
PDPattern
&
pat
)
->
Node
*
{
PADDLE_ENFORCE
(
subgraph
.
count
(
pat
.
RetrieveNode
(
name
)),
"pattern has no Node called %s"
,
name
.
c_str
());
Node
*
p
=
subgraph
.
at
(
pat
.
RetrieveNode
(
name
));
PADDLE_ENFORCE_NOT_NULL
(
p
,
"subgraph has no node %s"
,
name
.
c_str
());
return
p
;
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
LOG
(
INFO
)
<<
"handle sqaure mat sub fuse"
;
auto
&
fused_pattern
=
gpd
.
pattern
();
auto
*
matx
=
retrieve_node
(
name_scope
+
"/x"
,
subgraph
,
fused_pattern
);
auto
*
maty
=
retrieve_node
(
name_scope
+
"/y"
,
subgraph
,
fused_pattern
);
auto
*
squaredx
=
retrieve_node
(
name_scope
+
"/squared_x"
,
subgraph
,
fused_pattern
);
auto
*
squaredy
=
retrieve_node
(
name_scope
+
"/squared_y"
,
subgraph
,
fused_pattern
);
auto
*
squaredxy
=
retrieve_node
(
name_scope
+
"/squared_xmuly"
,
subgraph
,
fused_pattern
);
auto
*
last_out_var
=
retrieve_node
(
name_scope
+
"/out"
,
subgraph
,
fused_pattern
);
auto
*
fill_constant_op
=
retrieve_node
(
name_scope
+
"/fill_constant_op"
,
subgraph
,
fused_pattern
);
// Create New OpDesc
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_squared_mat_sub"
);
op_desc
.
SetInput
(
"X"
,
{
matx
->
Name
()});
op_desc
.
SetInput
(
"Y"
,
{
maty
->
Name
()});
op_desc
.
SetOutput
(
"SquaredX"
,
{
squaredx
->
Name
()});
op_desc
.
SetOutput
(
"SquaredY"
,
{
squaredy
->
Name
()});
op_desc
.
SetOutput
(
"SquaredXY"
,
{
squaredxy
->
Name
()});
op_desc
.
SetOutput
(
"Out"
,
{
last_out_var
->
Name
()});
op_desc
.
SetAttr
(
"scalar"
,
fill_constant_op
->
Op
()
->
GetAttr
(
"value"
));
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
IR_NODE_LINK_TO
(
matx
,
op
);
IR_NODE_LINK_TO
(
maty
,
op
);
IR_NODE_LINK_TO
(
op
,
squaredx
);
IR_NODE_LINK_TO
(
op
,
squaredy
);
IR_NODE_LINK_TO
(
op
,
squaredxy
);
IR_NODE_LINK_TO
(
op
,
last_out_var
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
;
for
(
auto
&
item
:
subgraph
)
{
marked_nodes
.
insert
(
item
.
second
);
}
marked_nodes
.
erase
(
matx
);
marked_nodes
.
erase
(
maty
);
marked_nodes
.
erase
(
squaredx
);
marked_nodes
.
erase
(
squaredy
);
marked_nodes
.
erase
(
squaredxy
);
marked_nodes
.
erase
(
last_out_var
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
std
::
unique_ptr
<
ir
::
Graph
>
SquaredMatSubFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
int
fusion_count
=
BuildFusion
(
graph
.
get
(),
name_scope_
);
AddStatis
(
fusion_count
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
squared_mat_sub_fuse_pass
,
paddle
::
framework
::
ir
::
SquaredMatSubFusePass
);
paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/**
* Fuse ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar
*/
class
SquaredMatSubFusePass
:
public
FusePassBase
{
public:
virtual
~
SquaredMatSubFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"squared_mat_sub_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/api/paddle_pass_builder.h
浏览文件 @
a7fc3d42
...
@@ -98,6 +98,8 @@ class CpuPassStrategy : public PassStrategy {
...
@@ -98,6 +98,8 @@ class CpuPassStrategy : public PassStrategy {
"mul_gru_fuse_pass"
,
//
"mul_gru_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"repeated_fc_relu_fuse_pass"
,
//
"squared_mat_sub_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"is_test_pass"
,
//
"is_test_pass"
,
//
...
...
paddle/fluid/inference/tests/api/CMakeLists.txt
浏览文件 @
a7fc3d42
...
@@ -37,15 +37,21 @@ function(inference_analysis_api_test_with_refer_result target install_dir filena
...
@@ -37,15 +37,21 @@ function(inference_analysis_api_test_with_refer_result target install_dir filena
--refer_result=
${
install_dir
}
/result.txt
)
--refer_result=
${
install_dir
}
/result.txt
)
endfunction
()
endfunction
()
# RNN1
if
(
NOT APPLE AND WITH_MKLML
)
if
(
NOT APPLE AND WITH_MKLML
)
# RNN1
set
(
RNN1_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/rnn1"
)
set
(
RNN1_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/rnn1"
)
download_model_and_data
(
${
RNN1_INSTALL_DIR
}
"rnn1%2Fmodel.tar.gz"
"rnn1%2Fdata.txt.tar.gz"
)
download_model_and_data
(
${
RNN1_INSTALL_DIR
}
"rnn1%2Fmodel.tar.gz"
"rnn1%2Fdata.txt.tar.gz"
)
inference_analysis_api_test
(
test_analyzer_rnn1
${
RNN1_INSTALL_DIR
}
analyzer_rnn1_tester.cc SERIAL
)
inference_analysis_api_test
(
test_analyzer_rnn1
${
RNN1_INSTALL_DIR
}
analyzer_rnn1_tester.cc SERIAL
)
# seq_pool1
set
(
SEQ_POOL1_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/seq_pool"
)
download_model_and_data
(
${
SEQ_POOL1_INSTALL_DIR
}
"seq_pool1_model_.tar.gz"
"seq_pool1_data.txt.tar.gz"
)
inference_analysis_api_test
(
test_analyzer_seq_pool1
${
SEQ_POOL1_INSTALL_DIR
}
analyzer_seq_pool1_tester.cc SERIAL
)
else
()
else
()
# TODO: fix this test on MACOS and OPENBLAS, the reason is that
# TODO: fix this test on MACOS and OPENBLAS, the reason is that
# fusion_seqexpand_concat_fc_op is not supported on MACOS and OPENBLAS
# fusion_seqexpand_concat_fc_op is not supported on MACOS and OPENBLAS
message
(
WARNING
"These tests has been disabled in OSX or WITH_MKL=OFF before being fixed:
\n
test_analyzer_rnn1"
)
message
(
WARNING
"These tests has been disabled in OSX or WITH_MKL=OFF before being fixed:
\n
test_analyzer_rnn1"
)
message
(
WARNING
"These tests has been disabled in OSX or WITH_MKL=OFF before being fixed:
\n
test_analyzer_seq_pool1"
)
endif
()
endif
()
# RNN2
# RNN2
...
@@ -90,11 +96,6 @@ set(SEQ_CONV1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_conv1")
...
@@ -90,11 +96,6 @@ set(SEQ_CONV1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_conv1")
download_model_and_data
(
${
SEQ_CONV1_INSTALL_DIR
}
"seq_conv1_model.tar.gz"
"seq_conv1_data.txt.tar.gz"
)
download_model_and_data
(
${
SEQ_CONV1_INSTALL_DIR
}
"seq_conv1_model.tar.gz"
"seq_conv1_data.txt.tar.gz"
)
inference_analysis_api_test
(
test_analyzer_seq_conv1
${
SEQ_CONV1_INSTALL_DIR
}
analyzer_seq_conv1_tester.cc
)
inference_analysis_api_test
(
test_analyzer_seq_conv1
${
SEQ_CONV1_INSTALL_DIR
}
analyzer_seq_conv1_tester.cc
)
# seq_pool1
set
(
SEQ_POOL1_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/seq_pool"
)
download_model_and_data
(
${
SEQ_POOL1_INSTALL_DIR
}
"seq_pool1_model_.tar.gz"
"seq_pool1_data.txt.tar.gz"
)
inference_analysis_api_test
(
test_analyzer_seq_pool1
${
SEQ_POOL1_INSTALL_DIR
}
analyzer_seq_pool1_tester.cc
)
# ocr
# ocr
set
(
OCR_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/ocr"
)
set
(
OCR_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/ocr"
)
if
(
NOT EXISTS
${
OCR_INSTALL_DIR
}
)
if
(
NOT EXISTS
${
OCR_INSTALL_DIR
}
)
...
...
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
浏览文件 @
a7fc3d42
...
@@ -21,6 +21,12 @@ namespace paddle {
...
@@ -21,6 +21,12 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
// diff: similarity_norm.tmp_0, for speed: fc_4.tmp_1
static
const
char
out_var_name
[]
=
"reduce_sum_0.tmp_0"
;
// for diff: 154, for speed 111
constexpr
int
num_slots
=
154
;
struct
OneSlotInBatch
{
struct
OneSlotInBatch
{
std
::
string
name
;
std
::
string
name
;
std
::
vector
<
std
::
vector
<
float
>>
data
;
std
::
vector
<
std
::
vector
<
float
>>
data
;
...
@@ -41,7 +47,6 @@ struct DataRecord {
...
@@ -41,7 +47,6 @@ struct DataRecord {
void
Load
(
const
std
::
string
&
path
)
{
void
Load
(
const
std
::
string
&
path
)
{
std
::
ifstream
file
(
path
);
std
::
ifstream
file
(
path
);
constexpr
int
num_slots
=
154
;
std
::
string
line
;
std
::
string
line
;
int
num_lines
=
0
;
int
num_lines
=
0
;
while
(
std
::
getline
(
file
,
line
))
{
while
(
std
::
getline
(
file
,
line
))
{
...
@@ -187,11 +192,15 @@ void analysis_fuse_statis(bool use_zerocopy) {
...
@@ -187,11 +192,15 @@ void analysis_fuse_statis(bool use_zerocopy) {
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
cfg
);
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
cfg
);
auto
fuse_statis
=
GetFuseStatis
(
predictor
.
get
(),
&
num_ops
);
auto
fuse_statis
=
GetFuseStatis
(
predictor
.
get
(),
&
num_ops
);
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc_fuse"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc_fuse"
));
ASSERT_EQ
(
fuse_statis
.
at
(
"fc_fuse"
),
10
);
ASSERT_TRUE
(
fuse_statis
.
count
(
"seqpool_concat_fuse"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"seqpool_concat_fuse"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"squared_mat_sub_fuse"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"repeated_fc_relu_fuse"
));
ASSERT_EQ
(
fuse_statis
.
at
(
"fc_fuse"
),
10
);
EXPECT_EQ
(
fuse_statis
.
at
(
"seqpool_concat_fuse"
),
2
);
EXPECT_EQ
(
fuse_statis
.
at
(
"seqpool_concat_fuse"
),
2
);
EXPECT_EQ
(
fuse_statis
.
at
(
"squared_mat_sub_fuse"
),
2
);
EXPECT_EQ
(
fuse_statis
.
at
(
"repeated_fc_relu_fuse"
),
2
);
LOG
(
INFO
)
<<
"num_ops: "
<<
num_ops
;
LOG
(
INFO
)
<<
"num_ops: "
<<
num_ops
;
EXPECT_EQ
(
num_ops
,
1
95
);
EXPECT_EQ
(
num_ops
,
1
71
);
}
}
// Check the fuse status
// Check the fuse status
...
@@ -214,9 +223,6 @@ void PrepareZeroCopyInputs(
...
@@ -214,9 +223,6 @@ void PrepareZeroCopyInputs(
}
}
}
}
// diff: similarity_norm.tmp_0, // speed: fc_4.tmp_1
static
const
char
out_var_name
[]
=
"reduce_sum_0.tmp_0"
;
// return the output values
// return the output values
std
::
vector
<
float
>
zerocopy_profile
(
int
repeat_times
)
{
std
::
vector
<
float
>
zerocopy_profile
(
int
repeat_times
)
{
AnalysisConfig
config
;
AnalysisConfig
config
;
...
@@ -322,7 +328,9 @@ TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
...
@@ -322,7 +328,9 @@ TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
native_outputs
.
front
().
data
.
length
());
native_outputs
.
front
().
data
.
length
());
auto
*
native_data
=
static_cast
<
float
*>
(
native_outputs
.
front
().
data
.
data
());
auto
*
native_data
=
static_cast
<
float
*>
(
native_outputs
.
front
().
data
.
data
());
for
(
size_t
i
=
0
;
i
<
zerocopy_output
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
zerocopy_output
.
size
();
++
i
)
{
EXPECT_NEAR
(
zerocopy_output
[
i
],
native_data
[
i
],
1e-3
);
EXPECT_LT
(
std
::
fabs
((
zerocopy_output
[
i
]
-
native_data
[
i
])
/
zerocopy_output
[
i
]),
1e-3
);
}
}
}
}
...
...
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace
paddle
{
namespace
operators
{
void
FusionRepeatedFCReluOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FusionRepeatedFCReluOp should not be null."
);
auto
sz
=
ctx
->
Inputs
(
"W"
).
size
();
PADDLE_ENFORCE_GT
(
sz
,
1UL
,
"Inputs(W) of FusionRepeatedFCReluOp should larger than 1."
);
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"Bias"
).
size
(),
sz
,
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
"equal to inputs size."
);
PADDLE_ENFORCE_EQ
(
ctx
->
Outputs
(
"ReluOut"
).
size
(),
sz
-
1
,
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
"be equal to inputs size -1."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FusionRepeatedFCReluOp should not be null."
);
auto
i_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
i_dims
.
size
(),
2UL
,
"Input shape size should be 2"
);
auto
w_dims
=
ctx
->
GetInputsDim
(
"W"
);
auto
b_dims
=
ctx
->
GetInputsDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
b_dims
.
size
(),
"Shape size of weight and bias should be equal"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
sz
,
"Shape size of weight and bias should be equal"
);
PADDLE_ENFORCE_EQ
(
i_dims
[
1
],
w_dims
[
0
][
0
],
"inpute width should be equal with weight height"
);
for
(
size_t
i
=
1
;
i
<
sz
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
w_dims
[
i
].
size
(),
2UL
,
"Every weight shape size should be 2."
);
PADDLE_ENFORCE_EQ
(
framework
::
product
(
b_dims
[
i
]),
w_dims
[
i
][
1
],
"The length of Bias must be equal with w_dims[1]."
);
}
ctx
->
SetOutputDim
(
"Out"
,
{
i_dims
[
0
],
w_dims
[
sz
-
1
][
1
]});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
framework
::
OpKernelType
FusionRepeatedFCReluOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"X"
)),
ctx
.
GetPlace
());
}
void
FusionRepeatedFCReluOpMaker
::
Make
()
{
AddInput
(
"X"
,
"(LoDTensor) Input tensors of this operator."
);
AddInput
(
"W"
,
"(Tensor) The weight tensors of this operator."
).
AsDuplicable
();
AddInput
(
"Bias"
,
"(Tensor) The bias tensors of this operator."
)
.
AsDuplicable
();
AddOutput
(
"ReluOut"
,
"(Tensor) The output tensor of each relu operator."
)
.
AsDuplicable
()
.
AsIntermediate
();
AddOutput
(
"Out"
,
"(LoDTensor) Output tensor of this operator."
);
AddComment
(
R"DOC(
Fusion Repeated FC with Relu Operator.
)DOC"
);
}
template
<
typename
T
>
static
void
fc_relu
(
const
T
*
x
,
const
T
*
w
,
const
T
*
b
,
T
*
y
,
int
m
,
int
n
,
int
k
)
{
auto
matmul
=
jit
::
Get
<
jit
::
kMatMul
,
jit
::
MatMulTuples
<
T
>
,
platform
::
CPUPlace
>
(
k
);
auto
addbias_relu
=
jit
::
Get
<
jit
::
kVAddRelu
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
matmul
(
x
,
w
,
y
,
m
,
n
,
k
);
T
*
dst
=
y
;
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
addbias_relu
(
b
,
dst
,
dst
,
n
);
dst
+=
n
;
}
}
template
<
typename
T
>
class
FusionRepeatedFCReluKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
in
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
weights
=
ctx
.
MultiInput
<
Tensor
>
(
"W"
);
auto
biases
=
ctx
.
MultiInput
<
Tensor
>
(
"Bias"
);
auto
relus
=
ctx
.
MultiOutput
<
Tensor
>
(
"ReluOut"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
place
=
ctx
.
GetPlace
();
int
weight_sz
=
static_cast
<
int
>
(
weights
.
size
());
auto
i_dims
=
in
->
dims
();
auto
w_dims
=
weights
[
0
]
->
dims
();
int
m
=
i_dims
[
0
];
int
n
=
w_dims
[
1
];
int
k
=
w_dims
[
0
];
relus
[
0
]
->
Resize
({
m
,
n
});
fc_relu
(
in
->
data
<
T
>
(),
weights
[
0
]
->
data
<
T
>
(),
biases
[
0
]
->
data
<
T
>
(),
relus
[
0
]
->
mutable_data
<
T
>
(
place
),
m
,
n
,
k
);
for
(
int
i
=
1
;
i
<
weight_sz
-
1
;
++
i
)
{
auto
i_dims
=
relus
[
i
-
1
]
->
dims
();
auto
w_dims
=
weights
[
i
]
->
dims
();
int
m
=
i_dims
[
0
];
int
n
=
w_dims
[
1
];
int
k
=
w_dims
[
0
];
relus
[
i
]
->
Resize
({
m
,
n
});
fc_relu
(
relus
[
i
-
1
]
->
data
<
T
>
(),
weights
[
i
]
->
data
<
T
>
(),
biases
[
i
]
->
data
<
T
>
(),
relus
[
i
]
->
mutable_data
<
T
>
(
place
),
m
,
n
,
k
);
}
auto
i_dims_last
=
relus
[
weight_sz
-
2
]
->
dims
();
auto
w_dims_last
=
weights
[
weight_sz
-
1
]
->
dims
();
m
=
i_dims_last
[
0
];
n
=
w_dims_last
[
1
];
k
=
w_dims_last
[
0
];
fc_relu
(
relus
[
weight_sz
-
2
]
->
data
<
T
>
(),
weights
[
weight_sz
-
1
]
->
data
<
T
>
(),
biases
[
weight_sz
-
1
]
->
data
<
T
>
(),
out
->
mutable_data
<
T
>
(
place
),
m
,
n
,
k
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fusion_repeated_fc_relu
,
ops
::
FusionRepeatedFCReluOp
,
ops
::
FusionRepeatedFCReluOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OP_CPU_KERNEL
(
fusion_repeated_fc_relu
,
ops
::
FusionRepeatedFCReluKernel
<
float
>
,
ops
::
FusionRepeatedFCReluKernel
<
double
>
);
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
class
FusionRepeatedFCReluOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
class
FusionRepeatedFCReluOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace
paddle
{
namespace
operators
{
void
FusionSquaredMatSubOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FusionSquaredMatSubOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of FusionSquaredMatSubOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SquaredX"
),
"Output(SquaredX) of FusionSquaredMatSubOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SquaredY"
),
"Output(SquaredY) of FusionSquaredMatSubOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SquaredXY"
),
"Output(SquaredXY) of FusionSquaredMatSubOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FusionSquaredMatSubOp should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
y_dims
.
size
(),
"Input tensors dims size should be equal."
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2UL
,
"Input tensors should be a Matrix."
);
PADDLE_ENFORCE_EQ
(
x_dims
[
1
],
y_dims
[
0
],
"Inputs Matrix should be multiply."
);
ctx
->
SetOutputDim
(
"SquaredX"
,
x_dims
);
ctx
->
SetOutputDim
(
"SquaredY"
,
y_dims
);
ctx
->
SetOutputDim
(
"SquaredXY"
,
{
x_dims
[
0
],
y_dims
[
1
]});
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
y_dims
[
1
]});
}
framework
::
OpKernelType
FusionSquaredMatSubOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"X"
)),
ctx
.
GetPlace
());
}
void
FusionSquaredMatSubOpMaker
::
Make
()
{
AddInput
(
"X"
,
"(Tensor) Input Mat A of this operator."
);
AddInput
(
"Y"
,
"(Tensor) Input Mat B of this operator."
);
AddOutput
(
"SquaredX"
,
"(Tensor) Squared X."
).
AsIntermediate
();
AddOutput
(
"SquaredY"
,
"(Tensor) Squared Y."
).
AsIntermediate
();
AddOutput
(
"SquaredXY"
,
"(Tensor) Squared X*Y."
).
AsIntermediate
();
AddOutput
(
"Out"
,
"(Tensor) Output tensor of concat operator."
);
AddAttr
<
float
>
(
"scalar"
,
"The scalar on output matrix."
).
SetDefault
(
1.
f
);
AddComment
(
R"DOC(
Fusion Squared Matrix and substrct operator.
( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
)DOC"
);
}
template
<
typename
T
>
class
FusionSquaredMatSubKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
squared_x
=
ctx
.
Output
<
Tensor
>
(
"SquaredX"
);
auto
*
squared_y
=
ctx
.
Output
<
Tensor
>
(
"SquaredY"
);
auto
*
squared_xy
=
ctx
.
Output
<
Tensor
>
(
"SquaredXY"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
place
=
ctx
.
GetPlace
();
T
scalar
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"scalar"
));
auto
x_dims
=
x
->
dims
();
auto
y_dims
=
y
->
dims
();
int
m
=
x_dims
[
0
];
int
k
=
x_dims
[
1
];
int
n
=
y_dims
[
1
];
int
o_numel
=
m
*
n
;
auto
vsquare_x
=
jit
::
Get
<
jit
::
kVSquare
,
jit
::
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
m
*
k
);
auto
vsquare_y
=
jit
::
Get
<
jit
::
kVSquare
,
jit
::
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
k
*
n
);
auto
vsquare_xy
=
jit
::
Get
<
jit
::
kVSquare
,
jit
::
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
o_numel
);
auto
vsub
=
jit
::
Get
<
jit
::
kVSub
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
o_numel
);
auto
vscal
=
jit
::
Get
<
jit
::
kVScal
,
jit
::
AXYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
o_numel
);
auto
matmul
=
jit
::
Get
<
jit
::
kMatMul
,
jit
::
MatMulTuples
<
T
>
,
platform
::
CPUPlace
>
(
k
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
y_data
=
y
->
data
<
T
>
();
T
*
squared_x_data
=
squared_x
->
mutable_data
<
T
>
(
place
);
T
*
squared_y_data
=
squared_y
->
mutable_data
<
T
>
(
place
);
T
*
squared_xy_data
=
squared_xy
->
mutable_data
<
T
>
(
place
);
T
*
o_data
=
out
->
mutable_data
<
T
>
(
place
);
matmul
(
x_data
,
y_data
,
squared_xy_data
,
m
,
n
,
k
);
vsquare_xy
(
squared_xy_data
,
squared_xy_data
,
o_numel
);
vsquare_x
(
x_data
,
squared_x_data
,
m
*
k
);
vsquare_y
(
y_data
,
squared_y_data
,
k
*
n
);
matmul
(
squared_x_data
,
squared_y_data
,
o_data
,
m
,
n
,
k
);
vsub
(
squared_xy_data
,
o_data
,
o_data
,
o_numel
);
vscal
(
&
scalar
,
o_data
,
o_data
,
o_numel
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fusion_squared_mat_sub
,
ops
::
FusionSquaredMatSubOp
,
ops
::
FusionSquaredMatSubOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OP_CPU_KERNEL
(
fusion_squared_mat_sub
,
ops
::
FusionSquaredMatSubKernel
<
float
>
,
ops
::
FusionSquaredMatSubKernel
<
double
>
);
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
0 → 100644
浏览文件 @
a7fc3d42
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
// ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar
class
FusionSquaredMatSubOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
class
FusionSquaredMatSubOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
a7fc3d42
...
@@ -210,6 +210,24 @@ void BenchSeqPoolKernel() {
...
@@ -210,6 +210,24 @@ void BenchSeqPoolKernel() {
}
}
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchMatMulKernel
()
{
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
TestSizes
())
{
for
(
int
k
:
TestSizes
())
{
std
::
vector
<
T
>
a
(
m
*
k
),
b
(
k
*
n
),
c
(
m
*
n
);
RandomVec
<
T
>
(
m
*
k
,
a
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
k
*
n
,
b
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
a_data
=
a
.
data
();
const
T
*
b_data
=
b
.
data
();
T
*
c_data
=
c
.
data
();
BenchAllImpls
<
KT
,
jit
::
MatMulTuples
<
T
>
,
PlaceType
>
(
k
,
a_data
,
b_data
,
c_data
,
m
,
n
,
k
);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// To use this tool, run command: ./benchmark [options...]
// Options:
// Options:
...
@@ -236,6 +254,7 @@ int main(int argc, char* argv[]) {
...
@@ -236,6 +254,7 @@ int main(int argc, char* argv[]) {
// xyn
// xyn
BenchXYNKernel
<
jit
::
kVRelu
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVRelu
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVIdentity
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVIdentity
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVSquare
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVExp
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVExp
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVSigmoid
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVSigmoid
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVTanh
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVTanh
,
T
,
PlaceType
>
();
...
@@ -251,4 +270,7 @@ int main(int argc, char* argv[]) {
...
@@ -251,4 +270,7 @@ int main(int argc, char* argv[]) {
// seq pool function
// seq pool function
BenchSeqPoolKernel
<
jit
::
kSeqPool
,
T
,
PlaceType
>
();
BenchSeqPoolKernel
<
jit
::
kSeqPool
,
T
,
PlaceType
>
();
// matmul
BenchMatMulKernel
<
jit
::
kMatMul
,
T
,
PlaceType
>
();
}
}
paddle/fluid/operators/jit/gen/CMakeLists.txt
浏览文件 @
a7fc3d42
...
@@ -11,11 +11,12 @@ endfunction()
...
@@ -11,11 +11,12 @@ endfunction()
# use gen jitcode kernel by name
# use gen jitcode kernel by name
USE_JITKERNEL_GEN
(
kVMul
)
USE_JITKERNEL_GEN
(
kVMul
)
USE_JITKERNEL_GEN
(
kVAdd
)
USE_JITKERNEL_GEN
(
kVAdd
)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN
(
kVSub
)
USE_JITKERNEL_GEN
(
kVAddRelu
)
USE_JITKERNEL_GEN
(
kVAddRelu
)
USE_JITKERNEL_GEN
(
kVScal
)
USE_JITKERNEL_GEN
(
kVScal
)
USE_JITKERNEL_GEN
(
kVAddBias
)
USE_JITKERNEL_GEN
(
kVAddBias
)
USE_JITKERNEL_GEN
(
kVRelu
)
USE_JITKERNEL_GEN
(
kVRelu
)
USE_JITKERNEL_GEN
(
kVSquare
)
USE_JITKERNEL_GEN
(
kVIdentity
)
USE_JITKERNEL_GEN
(
kVIdentity
)
USE_JITKERNEL_GEN
(
kVExp
)
USE_JITKERNEL_GEN
(
kVExp
)
USE_JITKERNEL_GEN
(
kVSigmoid
)
USE_JITKERNEL_GEN
(
kVSigmoid
)
...
...
paddle/fluid/operators/jit/gen/act.cc
浏览文件 @
a7fc3d42
...
@@ -91,6 +91,7 @@ void VActJitCode::genCode() {
...
@@ -91,6 +91,7 @@ void VActJitCode::genCode() {
}
}
DECLARE_ACT_CREATOR
(
VRelu
);
DECLARE_ACT_CREATOR
(
VRelu
);
DECLARE_ACT_CREATOR
(
VSquare
);
DECLARE_ACT_CREATOR
(
VIdentity
);
DECLARE_ACT_CREATOR
(
VIdentity
);
DECLARE_ACT_CREATOR
(
VExp
);
DECLARE_ACT_CREATOR
(
VExp
);
DECLARE_ACT_CREATOR
(
VSigmoid
);
DECLARE_ACT_CREATOR
(
VSigmoid
);
...
@@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const {
...
@@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const {
8
/* average bytes for each instruction */
;
8
/* average bytes for each instruction */
;
}
}
size_t
VSquareCreator
::
CodeSize
(
const
int
&
d
)
const
{
return
96
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
4
*
8
;
}
size_t
VIdentityCreator
::
CodeSize
(
const
int
&
d
)
const
{
size_t
VIdentityCreator
::
CodeSize
(
const
int
&
d
)
const
{
return
96
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
4
*
8
;
return
96
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
4
*
8
;
}
}
...
@@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const {
...
@@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const {
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kVRelu
,
gen
::
VReluCreator
);
REGISTER_JITKERNEL_GEN
(
kVRelu
,
gen
::
VReluCreator
);
REGISTER_JITKERNEL_GEN
(
kVSquare
,
gen
::
VSquareCreator
);
REGISTER_JITKERNEL_GEN
(
kVIdentity
,
gen
::
VIdentityCreator
);
REGISTER_JITKERNEL_GEN
(
kVIdentity
,
gen
::
VIdentityCreator
);
REGISTER_JITKERNEL_GEN
(
kVExp
,
gen
::
VExpCreator
);
REGISTER_JITKERNEL_GEN
(
kVExp
,
gen
::
VExpCreator
);
REGISTER_JITKERNEL_GEN
(
kVSigmoid
,
gen
::
VSigmoidCreator
);
REGISTER_JITKERNEL_GEN
(
kVSigmoid
,
gen
::
VSigmoidCreator
);
...
...
paddle/fluid/operators/jit/gen/act.h
浏览文件 @
a7fc3d42
...
@@ -75,6 +75,12 @@ class VActFunc : public JitCode {
...
@@ -75,6 +75,12 @@ class VActFunc : public JitCode {
vmaxps
(
dst
,
src
,
zero
);
vmaxps
(
dst
,
src
,
zero
);
}
}
// compute SQUARE with ymm, xmm
template
<
typename
JMM
>
void
square_jmm
(
JMM
&
dst
,
JMM
&
src
)
{
// NOLINT
vmulps
(
dst
,
src
,
src
);
}
// compute EXP with ymm, xmm
// compute EXP with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
...
@@ -228,6 +234,9 @@ class VActFunc : public JitCode {
...
@@ -228,6 +234,9 @@ class VActFunc : public JitCode {
case
operand_type
::
RELU
:
case
operand_type
::
RELU
:
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
break
;
case
operand_type
::
SQUARE
:
square_jmm
<
JMM
>
(
dst
,
src
);
break
;
case
operand_type
::
EXP
:
case
operand_type
::
EXP
:
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
break
;
...
@@ -254,7 +263,7 @@ class VActJitCode : public VActFunc {
...
@@ -254,7 +263,7 @@ class VActJitCode : public VActFunc {
:
VActFunc
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
:
VActFunc
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
if
(
!
(
type_
==
operand_type
::
RELU
||
type_
==
operand_type
::
EXP
||
if
(
!
(
type_
==
operand_type
::
RELU
||
type_
==
operand_type
::
EXP
||
type_
==
operand_type
::
SIGMOID
||
type_
==
operand_type
::
TANH
||
type_
==
operand_type
::
SIGMOID
||
type_
==
operand_type
::
TANH
||
type_
==
operand_type
::
IDENTITY
))
{
type_
==
operand_type
::
IDENTITY
||
type_
==
operand_type
::
SQUARE
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
}
this
->
genCode
();
this
->
genCode
();
...
@@ -266,6 +275,9 @@ class VActJitCode : public VActFunc {
...
@@ -266,6 +275,9 @@ class VActJitCode : public VActFunc {
case
operand_type
::
RELU
:
case
operand_type
::
RELU
:
base
+=
"_Relu"
;
base
+=
"_Relu"
;
break
;
break
;
case
operand_type
::
SQUARE
:
base
+=
"_Square"
;
break
;
case
operand_type
::
EXP
:
case
operand_type
::
EXP
:
base
+=
"_Exp"
;
base
+=
"_Exp"
;
break
;
break
;
...
@@ -306,6 +318,7 @@ class VActJitCode : public VActFunc {
...
@@ -306,6 +318,7 @@ class VActJitCode : public VActFunc {
};
};
DECLARE_ACT_JITCODE
(
VRelu
,
operand_type
::
RELU
);
DECLARE_ACT_JITCODE
(
VRelu
,
operand_type
::
RELU
);
DECLARE_ACT_JITCODE
(
VSquare
,
operand_type
::
SQUARE
);
DECLARE_ACT_JITCODE
(
VIdentity
,
operand_type
::
IDENTITY
);
DECLARE_ACT_JITCODE
(
VIdentity
,
operand_type
::
IDENTITY
);
DECLARE_ACT_JITCODE
(
VExp
,
operand_type
::
EXP
);
DECLARE_ACT_JITCODE
(
VExp
,
operand_type
::
EXP
);
DECLARE_ACT_JITCODE
(
VSigmoid
,
operand_type
::
SIGMOID
);
DECLARE_ACT_JITCODE
(
VSigmoid
,
operand_type
::
SIGMOID
);
...
...
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
a7fc3d42
...
@@ -43,6 +43,8 @@ void VXXJitCode::genCode() {
...
@@ -43,6 +43,8 @@ void VXXJitCode::genCode() {
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
else
if
(
type_
==
operand_type
::
ADD
)
{
}
else
if
(
type_
==
operand_type
::
ADD
)
{
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
else
if
(
type_
==
operand_type
::
SUB
)
{
vsubps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
}
if
(
with_relu_
)
{
if
(
with_relu_
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
...
@@ -85,6 +87,9 @@ void VXXJitCode::genCode() {
...
@@ -85,6 +87,9 @@ void VXXJitCode::genCode() {
case
operand_type
::
ADD
:
case
operand_type
::
ADD
:
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
break
;
case
operand_type
::
SUB
:
vsubps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
default:
default:
break
;
break
;
}
}
...
@@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen;
...
@@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN
(
kVMul
,
gen
::
VMulCreator
);
REGISTER_JITKERNEL_GEN
(
kVMul
,
gen
::
VMulCreator
);
REGISTER_JITKERNEL_GEN
(
kVAdd
,
gen
::
VAddCreator
);
REGISTER_JITKERNEL_GEN
(
kVAdd
,
gen
::
VAddCreator
);
// TODO(TJ): enable sub
REGISTER_JITKERNEL_GEN
(
kVSub
,
gen
::
VSubCreator
);
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN
(
kVAddRelu
,
gen
::
VAddReluCreator
);
REGISTER_JITKERNEL_GEN
(
kVAddRelu
,
gen
::
VAddReluCreator
);
REGISTER_JITKERNEL_GEN
(
kVScal
,
gen
::
VScalCreator
);
REGISTER_JITKERNEL_GEN
(
kVScal
,
gen
::
VScalCreator
);
REGISTER_JITKERNEL_GEN
(
kVAddBias
,
gen
::
VAddBiasCreator
);
REGISTER_JITKERNEL_GEN
(
kVAddBias
,
gen
::
VAddBiasCreator
);
...
...
paddle/fluid/operators/jit/gen/blas.h
浏览文件 @
a7fc3d42
...
@@ -34,7 +34,8 @@ class VXXJitCode : public JitCode {
...
@@ -34,7 +34,8 @@ class VXXJitCode : public JitCode {
type_
(
type
),
type_
(
type
),
scalar_index_
(
scalar_index
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{
with_relu_
(
with_relu
)
{
if
(
!
(
type_
==
operand_type
::
MUL
||
type_
==
operand_type
::
ADD
))
{
if
(
!
(
type_
==
operand_type
::
MUL
||
type_
==
operand_type
::
ADD
||
type_
==
operand_type
::
SUB
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
}
this
->
genCode
();
this
->
genCode
();
...
@@ -51,6 +52,8 @@ class VXXJitCode : public JitCode {
...
@@ -51,6 +52,8 @@ class VXXJitCode : public JitCode {
base
+=
"_Mul"
;
base
+=
"_Mul"
;
}
else
if
(
type_
==
operand_type
::
ADD
)
{
}
else
if
(
type_
==
operand_type
::
ADD
)
{
base
+=
"_Add"
;
base
+=
"_Add"
;
}
else
if
(
type_
==
operand_type
::
SUB
)
{
base
+=
"_SUB"
;
}
}
if
(
scalar_index_
==
2
)
{
if
(
scalar_index_
==
2
)
{
base
+=
"_Scalar"
;
base
+=
"_Scalar"
;
...
...
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
a7fc3d42
...
@@ -51,6 +51,7 @@ typedef enum {
...
@@ -51,6 +51,7 @@ typedef enum {
SUB
,
SUB
,
RELU
,
RELU
,
EXP
,
EXP
,
SQUARE
,
SIGMOID
,
SIGMOID
,
TANH
,
TANH
,
IDENTITY
IDENTITY
...
...
paddle/fluid/operators/jit/helper.cc
浏览文件 @
a7fc3d42
...
@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
...
@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kVRelu
);
ONE_CASE
(
kVRelu
);
ONE_CASE
(
kVIdentity
);
ONE_CASE
(
kVIdentity
);
ONE_CASE
(
kVExp
);
ONE_CASE
(
kVExp
);
ONE_CASE
(
kVSquare
);
ONE_CASE
(
kVSigmoid
);
ONE_CASE
(
kVSigmoid
);
ONE_CASE
(
kVTanh
);
ONE_CASE
(
kVTanh
);
ONE_CASE
(
kLSTMCtHt
);
ONE_CASE
(
kLSTMCtHt
);
...
@@ -47,6 +48,7 @@ const char* to_string(KernelType kt) {
...
@@ -47,6 +48,7 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kLayerNorm
);
ONE_CASE
(
kLayerNorm
);
ONE_CASE
(
kNCHW16CMulNC
);
ONE_CASE
(
kNCHW16CMulNC
);
ONE_CASE
(
kSeqPool
);
ONE_CASE
(
kSeqPool
);
ONE_CASE
(
kMatMul
);
default:
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
return
"NOT JITKernel"
;
return
"NOT JITKernel"
;
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
a7fc3d42
...
@@ -30,6 +30,7 @@ typedef enum {
...
@@ -30,6 +30,7 @@ typedef enum {
kVAddBias
,
kVAddBias
,
kVRelu
,
kVRelu
,
kVIdentity
,
kVIdentity
,
kVSquare
,
kVExp
,
kVExp
,
kVSigmoid
,
kVSigmoid
,
kVTanh
,
kVTanh
,
...
@@ -42,6 +43,7 @@ typedef enum {
...
@@ -42,6 +43,7 @@ typedef enum {
kLayerNorm
,
kLayerNorm
,
kNCHW16CMulNC
,
kNCHW16CMulNC
,
kSeqPool
,
kSeqPool
,
kMatMul
,
}
KernelType
;
}
KernelType
;
typedef
enum
{
typedef
enum
{
...
@@ -135,6 +137,13 @@ struct SeqPoolTuples {
...
@@ -135,6 +137,13 @@ struct SeqPoolTuples {
typedef
void
(
*
func_type
)(
const
T
*
,
T
*
,
const
seq_pool_attr_t
*
);
typedef
void
(
*
func_type
)(
const
T
*
,
T
*
,
const
seq_pool_attr_t
*
);
};
};
template
<
typename
T
>
struct
MatMulTuples
{
typedef
T
data_type
;
typedef
int
attr_type
;
typedef
void
(
*
func_type
)(
const
T
*
,
const
T
*
,
T
*
,
int
,
int
,
int
);
};
template
<
typename
T
>
template
<
typename
T
>
struct
CRFDecodingTuples
{
struct
CRFDecodingTuples
{
typedef
T
data_type
;
typedef
T
data_type
;
...
...
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
浏览文件 @
a7fc3d42
...
@@ -3,10 +3,12 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
...
@@ -3,10 +3,12 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
dynload_mklml jit_kernel_mkl PARENT_SCOPE
)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
dynload_mklml jit_kernel_mkl PARENT_SCOPE
)
# use mkl kernels by name and type
# use mkl kernels by name and type
USE_JITKERNEL_MORE
(
kMatMul, mkl
)
USE_JITKERNEL_MORE
(
kVMul, mkl
)
USE_JITKERNEL_MORE
(
kVMul, mkl
)
USE_JITKERNEL_MORE
(
kVAdd, mkl
)
USE_JITKERNEL_MORE
(
kVAdd, mkl
)
USE_JITKERNEL_MORE
(
kVScal, mkl
)
USE_JITKERNEL_MORE
(
kVScal, mkl
)
USE_JITKERNEL_MORE
(
kVExp, mkl
)
USE_JITKERNEL_MORE
(
kVExp, mkl
)
USE_JITKERNEL_MORE
(
kVSquare, mkl
)
USE_JITKERNEL_MORE
(
kVSigmoid, mkl
)
USE_JITKERNEL_MORE
(
kVSigmoid, mkl
)
USE_JITKERNEL_MORE
(
kVTanh, mkl
)
USE_JITKERNEL_MORE
(
kVTanh, mkl
)
USE_JITKERNEL_MORE
(
kSeqPool, mkl
)
USE_JITKERNEL_MORE
(
kSeqPool, mkl
)
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
a7fc3d42
...
@@ -24,6 +24,20 @@ namespace jit {
...
@@ -24,6 +24,20 @@ namespace jit {
namespace
more
{
namespace
more
{
namespace
mkl
{
namespace
mkl
{
template
<
>
void
MatMul
<
float
>
(
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
m
,
int
n
,
int
k
)
{
platform
::
dynload
::
cblas_sgemm
(
CblasRowMajor
,
CblasNoTrans
,
CblasNoTrans
,
m
,
n
,
k
,
1.
f
,
a
,
k
,
b
,
n
,
0.
f
,
c
,
n
);
}
template
<
>
void
MatMul
<
double
>
(
const
double
*
a
,
const
double
*
b
,
double
*
c
,
int
m
,
int
n
,
int
k
)
{
platform
::
dynload
::
cblas_dgemm
(
CblasRowMajor
,
CblasNoTrans
,
CblasNoTrans
,
m
,
n
,
k
,
1.0
,
a
,
k
,
b
,
n
,
0.0
,
c
,
n
);
}
template
<
>
template
<
>
void
VMul
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
void
VMul
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
...
@@ -72,6 +86,16 @@ void VExp<double>(const double* x, double* y, int n) {
...
@@ -72,6 +86,16 @@ void VExp<double>(const double* x, double* y, int n) {
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
}
}
template
<
>
void
VSquare
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
vsSqr
(
n
,
x
,
y
);
}
template
<
>
void
VSquare
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
vdSqr
(
n
,
x
,
y
);
}
template
<
>
template
<
>
void
VCopy
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
void
VCopy
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
cblas_scopy
(
n
,
x
,
1
,
y
,
1
);
platform
::
dynload
::
cblas_scopy
(
n
,
x
,
1
,
y
,
1
);
...
@@ -93,6 +117,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
...
@@ -93,6 +117,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
}
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template
<
>
bool
MatMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
}
template
<
>
template
<
>
bool
VMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
...
@@ -113,6 +142,11 @@ bool VExpKernel<float>::UseMe(const int& d) const {
...
@@ -113,6 +142,11 @@ bool VExpKernel<float>::UseMe(const int& d) const {
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
bool
VSquareKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
}
template
<
>
template
<
>
bool
VSigmoidKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VSigmoidKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
...
@@ -139,12 +173,14 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
...
@@ -139,12 +173,14 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true; \
return true; \
}
}
AWALYS_USE_ME_WITH_DOUBLE
(
MatMul
);
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
AWALYS_USE_ME_WITH_DOUBLE
(
VAdd
);
AWALYS_USE_ME_WITH_DOUBLE
(
VAdd
);
AWALYS_USE_ME_WITH_DOUBLE
(
VScal
);
AWALYS_USE_ME_WITH_DOUBLE
(
VScal
);
AWALYS_USE_ME_WITH_DOUBLE
(
VExp
);
AWALYS_USE_ME_WITH_DOUBLE
(
VExp
);
AWALYS_USE_ME_WITH_DOUBLE
(
VSigmoid
);
AWALYS_USE_ME_WITH_DOUBLE
(
VSigmoid
);
AWALYS_USE_ME_WITH_DOUBLE
(
VTanh
);
AWALYS_USE_ME_WITH_DOUBLE
(
VTanh
);
AWALYS_USE_ME_WITH_DOUBLE
(
VSquare
);
#undef AWALYS_USE_ME_WITH_DOUBLE
#undef AWALYS_USE_ME_WITH_DOUBLE
}
// namespace mkl
}
// namespace mkl
...
@@ -159,10 +195,12 @@ namespace mkl = paddle::operators::jit::more::mkl;
...
@@ -159,10 +195,12 @@ namespace mkl = paddle::operators::jit::more::mkl;
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL
(
kMatMul
,
MatMul
);
REGISTER_MKL_KERNEL
(
kVMul
,
VMul
);
REGISTER_MKL_KERNEL
(
kVMul
,
VMul
);
REGISTER_MKL_KERNEL
(
kVAdd
,
VAdd
);
REGISTER_MKL_KERNEL
(
kVAdd
,
VAdd
);
REGISTER_MKL_KERNEL
(
kVScal
,
VScal
);
REGISTER_MKL_KERNEL
(
kVScal
,
VScal
);
REGISTER_MKL_KERNEL
(
kVExp
,
VExp
);
REGISTER_MKL_KERNEL
(
kVExp
,
VExp
);
REGISTER_MKL_KERNEL
(
kVSquare
,
VSquare
);
REGISTER_MKL_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_MKL_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_MKL_KERNEL
(
kVTanh
,
VTanh
);
REGISTER_MKL_KERNEL
(
kVTanh
,
VTanh
);
REGISTER_MKL_KERNEL
(
kSeqPool
,
SeqPool
);
REGISTER_MKL_KERNEL
(
kSeqPool
,
SeqPool
);
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
a7fc3d42
...
@@ -24,6 +24,9 @@ namespace jit {
...
@@ -24,6 +24,9 @@ namespace jit {
namespace
more
{
namespace
more
{
namespace
mkl
{
namespace
mkl
{
template
<
typename
T
>
void
MatMul
(
const
T
*
a
,
const
T
*
b
,
T
*
c
,
int
m
,
int
n
,
int
k
);
template
<
typename
T
>
template
<
typename
T
>
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
...
@@ -36,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n);
...
@@ -36,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n);
template
<
typename
T
>
template
<
typename
T
>
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
);
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
typename
T
>
void
VSquare
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
typename
T
>
template
<
typename
T
>
void
VCopy
(
const
T
*
x
,
T
*
y
,
int
n
);
void
VCopy
(
const
T
*
x
,
T
*
y
,
int
n
);
...
@@ -93,6 +99,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
...
@@ -93,6 +99,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
const char* ImplType() const override { return "MKL"; } \
const char* ImplType() const override { return "MKL"; } \
}
}
// ABCMNK
DECLARE_MKL_KERNEL
(
MatMul
,
MatMulTuples
);
// XYZN
// XYZN
DECLARE_MKL_KERNEL
(
VMul
,
XYZNTuples
);
DECLARE_MKL_KERNEL
(
VMul
,
XYZNTuples
);
DECLARE_MKL_KERNEL
(
VAdd
,
XYZNTuples
);
DECLARE_MKL_KERNEL
(
VAdd
,
XYZNTuples
);
...
@@ -104,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples);
...
@@ -104,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples);
DECLARE_MKL_KERNEL
(
VExp
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VExp
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VTanh
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VTanh
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VSquare
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
SeqPool
,
SeqPoolTuples
);
DECLARE_MKL_KERNEL
(
SeqPool
,
SeqPoolTuples
);
...
...
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
a7fc3d42
...
@@ -27,3 +27,5 @@ USE_JITKERNEL_REFER(kCRFDecoding)
...
@@ -27,3 +27,5 @@ USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER
(
kLayerNorm
)
USE_JITKERNEL_REFER
(
kLayerNorm
)
USE_JITKERNEL_REFER
(
kNCHW16CMulNC
)
USE_JITKERNEL_REFER
(
kNCHW16CMulNC
)
USE_JITKERNEL_REFER
(
kSeqPool
)
USE_JITKERNEL_REFER
(
kSeqPool
)
USE_JITKERNEL_REFER
(
kMatMul
)
USE_JITKERNEL_REFER
(
kVSquare
)
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
a7fc3d42
...
@@ -31,6 +31,7 @@ REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
...
@@ -31,6 +31,7 @@ REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL
(
kVRelu
,
VRelu
);
REGISTER_REFER_KERNEL
(
kVRelu
,
VRelu
);
REGISTER_REFER_KERNEL
(
kVIdentity
,
VIdentity
);
REGISTER_REFER_KERNEL
(
kVIdentity
,
VIdentity
);
REGISTER_REFER_KERNEL
(
kVSquare
,
VSquare
);
REGISTER_REFER_KERNEL
(
kVExp
,
VExp
);
REGISTER_REFER_KERNEL
(
kVExp
,
VExp
);
REGISTER_REFER_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_REFER_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_REFER_KERNEL
(
kVTanh
,
VTanh
);
REGISTER_REFER_KERNEL
(
kVTanh
,
VTanh
);
...
@@ -49,4 +50,6 @@ REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
...
@@ -49,4 +50,6 @@ REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
REGISTER_REFER_KERNEL
(
kSeqPool
,
SeqPool
);
REGISTER_REFER_KERNEL
(
kSeqPool
,
SeqPool
);
REGISTER_REFER_KERNEL
(
kMatMul
,
MatMul
);
#undef REGISTER_REFER_KERNEL
#undef REGISTER_REFER_KERNEL
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
a7fc3d42
...
@@ -83,6 +83,13 @@ inline void VIdentity(const T* x, T* y, int n) {
...
@@ -83,6 +83,13 @@ inline void VIdentity(const T* x, T* y, int n) {
}
}
}
}
template
<
typename
T
>
inline
void
VSquare
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
x
[
i
]
*
x
[
i
];
}
}
template
<
typename
T
>
template
<
typename
T
>
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
)
{
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
...
@@ -354,6 +361,23 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
...
@@ -354,6 +361,23 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
}
}
}
}
// A(M,K) * B(K,N) = C(M,N)
template
<
typename
T
>
void
MatMul
(
const
T
*
A
,
const
T
*
B
,
T
*
C
,
int
M
,
int
N
,
int
K
)
{
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
const
T
*
pa
=
A
+
m
*
K
;
T
*
pc
=
C
+
m
*
N
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
const
T
*
pb
=
B
+
n
;
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
sum
+=
(
pa
[
k
]
*
pb
[
k
*
N
]);
}
*
(
pc
+
n
)
=
sum
;
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
class name##Kernel : public ReferKernel<tuples<T>> { \
...
@@ -377,6 +401,7 @@ DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
...
@@ -377,6 +401,7 @@ DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL
(
VExp
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VExp
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VTanh
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VTanh
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VSquare
,
XYNTuples
);
// lstm_t*, const lstm_attr_t*
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL
(
LSTMCtHt
,
LSTMTuples
);
DECLARE_REFER_KERNEL
(
LSTMCtHt
,
LSTMTuples
);
...
@@ -394,6 +419,8 @@ DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
...
@@ -394,6 +419,8 @@ DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
DECLARE_REFER_KERNEL
(
SeqPool
,
SeqPoolTuples
);
DECLARE_REFER_KERNEL
(
SeqPool
,
SeqPoolTuples
);
DECLARE_REFER_KERNEL
(
MatMul
,
MatMulTuples
);
#undef DECLARE_REFER_KERNEL
#undef DECLARE_REFER_KERNEL
}
// namespace refer
}
// namespace refer
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
a7fc3d42
...
@@ -229,6 +229,26 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
...
@@ -229,6 +229,26 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
}
}
};
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
MatMulTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
int
,
int
,
int
>
{
void
operator
()(
const
typename
jit
::
MatMulTuples
<
T
>::
func_type
tgt
,
const
std
::
vector
<
T
>&
a
,
const
std
::
vector
<
T
>&
b
,
const
std
::
vector
<
T
>&
cref
,
int
m
,
int
n
,
int
k
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
a
.
size
(),
static_cast
<
size_t
>
(
m
*
k
));
EXPECT_EQ
(
b
.
size
(),
static_cast
<
size_t
>
(
k
*
n
));
EXPECT_EQ
(
cref
.
size
(),
static_cast
<
size_t
>
(
m
*
n
));
std
::
vector
<
T
>
c
(
cref
.
size
());
const
T
*
a_data
=
a
.
data
();
const
T
*
b_data
=
b
.
data
();
const
T
*
cref_data
=
cref
.
data
();
T
*
c_data
=
c
.
data
();
tgt
(
a_data
,
b_data
,
c_data
,
m
,
n
,
k
);
ExpectEQ
<
T
>
(
c_data
,
cref_data
,
m
*
n
);
}
};
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
KernelTuples
,
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
,
typename
...
Args
>
typename
PlaceType
,
typename
...
Args
>
void
TestAllImpls
(
const
typename
KernelTuples
::
attr_type
&
attr
,
Args
...
args
)
{
void
TestAllImpls
(
const
typename
KernelTuples
::
attr_type
&
attr
,
Args
...
args
)
{
...
@@ -458,6 +478,28 @@ void TestSeqPoolKernel() {
...
@@ -458,6 +478,28 @@ void TestSeqPoolKernel() {
}
}
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestMatMulKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
{
1
,
2
,
3
,
4
})
{
for
(
int
k
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
MatMulTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
a
(
m
*
k
),
b
(
k
*
n
),
c
(
m
*
n
);
RandomVec
<
T
>
(
m
*
k
,
a
.
data
(),
-
0.2
f
,
0.2
f
);
RandomVec
<
T
>
(
k
*
n
,
b
.
data
(),
-
0.2
f
,
0.2
f
);
const
T
*
a_data
=
a
.
data
();
const
T
*
b_data
=
b
.
data
();
T
*
c_data
=
c
.
data
();
ref
(
a_data
,
b_data
,
c_data
,
m
,
n
,
k
);
TestAllImpls
<
KT
,
jit
::
MatMulTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
(
k
,
a
,
b
,
c
,
m
,
n
,
k
);
}
}
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestNCHW16CMulNCKernel
()
{
void
TestNCHW16CMulNCKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
...
@@ -562,6 +604,12 @@ TEST(JITKernel, kVIdentity) {
...
@@ -562,6 +604,12 @@ TEST(JITKernel, kVIdentity) {
TestXYNKernel
<
jit
::
kVIdentity
,
double
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVIdentity
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
}
TEST
(
JITKernel
,
kVSquare
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYNKernel
<
jit
::
kVSquare
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVSquare
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVExp
)
{
TEST
(
JITKernel
,
kVExp
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYNKernel
<
jit
::
kVExp
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVExp
,
float
,
paddle
::
platform
::
CPUPlace
>
();
...
@@ -618,6 +666,12 @@ TEST(JITKernel, kSeqPool) {
...
@@ -618,6 +666,12 @@ TEST(JITKernel, kSeqPool) {
TestSeqPoolKernel
<
jit
::
kSeqPool
,
double
,
paddle
::
platform
::
CPUPlace
>
();
TestSeqPoolKernel
<
jit
::
kSeqPool
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
}
TEST
(
JITKernel
,
kMatMul
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestMatMulKernel
<
jit
::
kMatMul
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestMatMulKernel
<
jit
::
kMatMul
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
namespace
jit
=
paddle
::
operators
::
jit
;
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
...
...
python/paddle/fluid/tests/unittests/test_fusion_repeated_fc_relu_op.py
0 → 100644
浏览文件 @
a7fc3d42
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
from
test_fc_op
import
fc_refer
,
MatrixGenerate
class
TestFusionRepeatedFCReluOp
(
OpTest
):
def
setUp
(
self
):
self
.
bs
=
3
self
.
ic
=
9
self
.
oc
=
[
2
,
4
,
3
]
assert
len
(
self
.
oc
)
>
1
,
'Should larger than 1'
self
.
set_conf
()
self
.
op_type
=
'fusion_repeated_fc_relu'
sz
=
len
(
self
.
oc
)
ics
=
[
self
.
ic
]
+
self
.
oc
[
0
:
sz
-
1
]
assert
len
(
ics
)
==
len
(
self
.
oc
)
weights
=
[]
biases
=
[]
outs
=
[]
i
=
0
matrix
=
MatrixGenerate
(
self
.
bs
,
ics
[
i
],
self
.
oc
[
i
],
1
,
1
)
inp
=
np
.
reshape
(
matrix
.
input
,
[
self
.
bs
,
ics
[
i
]])
weights
.
append
((
'W_{0}'
.
format
(
i
),
np
.
reshape
(
matrix
.
weights
,
[
ics
[
i
],
self
.
oc
[
i
]])))
biases
.
append
((
'B_{0}'
.
format
(
i
),
matrix
.
bias
))
outs
.
append
(
np
.
reshape
(
np
.
maximum
(
fc_refer
(
matrix
,
True
),
0
),
[
self
.
bs
,
self
.
oc
[
i
]]))
for
i
in
range
(
sz
-
1
):
matrix
=
MatrixGenerate
(
self
.
bs
,
ics
[
i
+
1
],
self
.
oc
[
i
+
1
],
1
,
1
)
matrix
.
input
=
np
.
reshape
(
outs
[
i
],
[
self
.
bs
,
ics
[
i
+
1
],
1
,
1
])
out
=
fc_refer
(
matrix
,
True
)
weights
.
append
(
(
'W_{0}'
.
format
(
i
+
1
),
np
.
reshape
(
matrix
.
weights
,
[
ics
[
i
+
1
],
self
.
oc
[
i
+
1
]])))
biases
.
append
((
'B_{0}'
.
format
(
i
+
1
),
matrix
.
bias
))
outs
.
append
(
np
.
reshape
(
np
.
maximum
(
out
,
0
),
[
self
.
bs
,
self
.
oc
[
i
+
1
]]))
relu_outs
=
[]
for
i
in
range
(
sz
-
1
):
relu_outs
.
append
((
'ReluOut_{0}'
.
format
(
i
),
outs
[
i
]))
self
.
inputs
=
{
'X'
:
inp
,
'W'
:
weights
,
'Bias'
:
biases
,
}
self
.
outputs
=
{
'Out'
:
outs
[
-
1
],
'ReluOut'
:
relu_outs
}
def
test_check_output
(
self
):
self
.
check_output
()
def
set_conf
(
self
):
pass
class
TestFusionRepeatedFCReluOpBS1
(
TestFusionRepeatedFCReluOp
):
def
set_conf
(
self
):
self
.
bs
=
1
self
.
oc
=
[
4
,
2
,
7
,
5
]
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_fusion_squared_mat_sub_op.py
0 → 100644
浏览文件 @
a7fc3d42
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestFusionSquaredMatSubOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
'fusion_squared_mat_sub'
self
.
m
=
11
self
.
n
=
12
self
.
k
=
4
self
.
scalar
=
0.5
self
.
set_conf
()
matx
=
np
.
random
.
random
((
self
.
m
,
self
.
k
)).
astype
(
"float32"
)
maty
=
np
.
random
.
random
((
self
.
k
,
self
.
n
)).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
matx
,
'Y'
:
maty
}
self
.
outputs
=
{
'Out'
:
(
np
.
dot
(
matx
,
maty
)
**
2
-
np
.
dot
(
matx
**
2
,
maty
**
2
))
*
self
.
scalar
}
self
.
attrs
=
{
'scalar'
:
self
.
scalar
,
}
def
set_conf
(
self
):
pass
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFusionSquaredMatSubOpCase1
(
TestFusionSquaredMatSubOp
):
def
set_conf
(
self
):
self
.
scalar
=
-
0.3
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录