Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
505e450d
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
505e450d
编写于
6月 15, 2019
作者:
X
xingzhaolong
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'xzl/incubate/lite' into 'incubate/lite'
merge github code to gitlab. See merge request inference/paddlelite!13
上级
7c02e682
83bd58d4
变更
44
隐藏空白更改
内联
并排
Showing
44 changed file
with
1382 addition
and
176 deletion
+1382
-176
paddle/fluid/lite/CMakeLists.txt
paddle/fluid/lite/CMakeLists.txt
+3
-2
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+2
-5
paddle/fluid/lite/api/cxx_api_bin.cc
paddle/fluid/lite/api/cxx_api_bin.cc
+20
-5
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+3
-1
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+35
-9
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
+37
-0
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
...luid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
+39
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
...fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
...lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
+153
-0
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
+34
-0
paddle/fluid/lite/core/mir/fc_fuse_pass.h
paddle/fluid/lite/core/mir/fc_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
+112
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+22
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
+140
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
+128
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
+57
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
...d/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
+108
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
...id/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
+41
-0
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
+78
-0
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
+38
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+8
-0
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+4
-0
paddle/fluid/lite/core/mir/pattern_matcher.cc
paddle/fluid/lite/core/mir/pattern_matcher.cc
+71
-1
paddle/fluid/lite/core/mir/pattern_matcher.h
paddle/fluid/lite/core/mir/pattern_matcher.h
+17
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
+2
-5
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
+0
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+10
-14
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+33
-42
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+1
-5
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+10
-8
paddle/fluid/lite/core/mir/type_target_transform_pass.h
paddle/fluid/lite/core/mir/type_target_transform_pass.h
+1
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+34
-27
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+1
-1
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+3
-0
paddle/fluid/lite/core/profile/basic_profiler.h
paddle/fluid/lite/core/profile/basic_profiler.h
+1
-1
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+1
-1
paddle/fluid/lite/kernels/x86/conv_compute.cc
paddle/fluid/lite/kernels/x86/conv_compute.cc
+2
-3
paddle/fluid/lite/kernels/x86/fc_compute.cc
paddle/fluid/lite/kernels/x86/fc_compute.cc
+17
-24
paddle/fluid/lite/kernels/x86/relu_compute.cc
paddle/fluid/lite/kernels/x86/relu_compute.cc
+1
-1
paddle/fluid/lite/model_parser/model_parser.cc
paddle/fluid/lite/model_parser/model_parser.cc
+2
-2
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+2
-3
paddle/fluid/lite/operators/conv_op.h
paddle/fluid/lite/operators/conv_op.h
+14
-12
未找到文件。
paddle/fluid/lite/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
message
(
STATUS
"LITE_WITH_PROFILE:
\t
${
LITE_WITH_PROFILE
}
"
)
set
(
LITE_MODEL_DIR
"
${
THIRD_PARTY_PATH
}
/install"
)
set
(
LITE_URL
"http://paddle-inference-dist.bj.bcebos.com"
CACHE STRING
"inference download url"
)
function
(
lite_download_and_uncompress INSTALL_DIR URL FILENAME
)
message
(
STATUS
"Download inference test stuff from
${
URL
}
/
${
FILENAME
}
"
)
...
...
@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
file
(
APPEND
${
offline_test_registry_file
}
"
${
TARGET
}
\n
"
)
endfunction
()
add_subdirectory
(
operators
)
add_subdirectory
(
kernels
)
add_subdirectory
(
core
)
add_subdirectory
(
x86
)
add_subdirectory
(
arm
)
add_subdirectory
(
host
)
add_subdirectory
(
cuda
)
add_subdirectory
(
operators
)
add_subdirectory
(
kernels
)
add_subdirectory
(
model_parser
)
add_subdirectory
(
utils
)
add_subdirectory
(
api
)
...
...
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda
)
endif
()
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS
${
cxx_api_lite_deps
}
${
ops_lite
}
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS
${
cxx_api_lite_deps
}
${
ops_lite
}
program_lite
)
set
(
light_api_deps
scope_lite target_wrapper_host model_parser_lite
)
...
...
@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
message
(
STATUS
"get ARM kernels
${
arm_kernels
}
"
)
include
(
ExternalProject
)
set
(
LITE_URL
"http://paddle-inference-dist.bj.bcebos.com"
CACHE STRING
"inference download url"
)
set
(
LITE_DEMO_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo"
CACHE STRING
"A path setting inference demo download directories."
)
if
((
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
AND WITH_TESTING
)
lite_cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite m
odel_parser_lite target_wrapper_host
DEPS cxx_api_lite m
ir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
PROFILE_DEPS basic_profiler_lite
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_naive_model
--optimized_model=
${
LITE_MODEL_DIR
}
/lite_naive_model_opt SERIAL
)
...
...
@@ -45,7 +43,6 @@ endif()
# lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
# endif()
lite_cc_binary
(
cxx_api_lite_bin SRCS cxx_api_bin.cc
DEPS
cxx_api_lite
...
...
paddle/fluid/lite/api/cxx_api_bin.cc
浏览文件 @
505e450d
...
...
@@ -13,13 +13,22 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
void
Run
(
const
char
*
model_dir
)
{
using
Time
=
decltype
(
std
::
chrono
::
high_resolution_clock
::
now
());
Time
time
()
{
return
std
::
chrono
::
high_resolution_clock
::
now
();
}
double
time_diff
(
Time
t1
,
Time
t2
)
{
typedef
std
::
chrono
::
microseconds
ms
;
auto
diff
=
t2
-
t1
;
ms
counter
=
std
::
chrono
::
duration_cast
<
ms
>
(
diff
);
return
counter
.
count
()
/
1000.0
;
}
void
Run
(
const
char
*
model_dir
,
int
repeat
)
{
#ifdef LITE_WITH_ARM
DeviceInfo
::
Init
();
#endif
...
...
@@ -34,10 +43,16 @@ void Run(const char* model_dir) {
input_tensor
->
Resize
(
DDim
(
std
::
vector
<
DDim
::
value_type
>
({
1
,
3
,
224
,
224
})));
auto
*
data
=
input_tensor
->
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
input_tensor
->
dims
().
production
();
i
++
)
{
data
[
i
]
=
i
;
data
[
i
]
=
1
;
}
predictor
.
Run
();
for
(
int
i
=
0
;
i
<
10
;
i
++
)
predictor
.
Run
();
auto
time1
=
time
();
for
(
int
i
=
0
;
i
<
repeat
;
i
++
)
predictor
.
Run
();
auto
time2
=
time
();
std
::
cout
<<
" predict cost: "
<<
time_diff
(
time1
,
time2
)
/
repeat
<<
"ms"
<<
std
::
endl
;
auto
*
out
=
predictor
.
GetOutput
(
0
);
LOG
(
INFO
)
<<
out
<<
" memory size "
<<
out
->
data_size
();
...
...
@@ -52,7 +67,7 @@ void Run(const char* model_dir) {
int
main
(
int
argc
,
char
**
argv
)
{
CHECK_EQ
(
argc
,
2
)
<<
"usage: ./cmd <model_dir>"
;
paddle
::
lite
::
Run
(
argv
[
1
]);
paddle
::
lite
::
Run
(
argv
[
1
]
,
1
);
return
0
;
}
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -30,7 +30,9 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
lite_cc_library
(
program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto
lite_cc_library
(
program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite
)
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -3,8 +3,13 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program_lite)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
add_subdirectory
(
fusion
)
cc_library
(
mir_passes
SRCS static_kernel_pick_pass.cc
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc
...
...
@@ -13,13 +18,8 @@ cc_library(mir_passes
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite
)
DEPS mir_pass types_lite context_lite
${
mir_fusers
}
)
# for mobile, unnecessary to compile the following testings.
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
return
()
endif
()
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
#cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope_lite op_lite
#fc_op_lite
...
...
@@ -52,11 +52,37 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
lite_cc_library
(
pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite
)
# TODO(wz) replace framework/proto to lite proto.
# for mobile, unnecessary to compile the following testings.
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
return
()
endif
()
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
# TODO(wz) replace framework/proto to lite proto.
if
(
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution.
lite_cc_test
(
test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite
${
ops_lite
}
)
endif
()
message
(
STATUS
"----> Ops lite:
${
ops_lite
}
"
)
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
lite_cc_test
(
test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
arm_kernels
}
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
--optimized_model=
${
LITE_MODEL_DIR
}
/lite_fc_model_opt SERIAL
)
lite_download_and_uncompress
(
${
LITE_MODEL_DIR
}
${
LITE_URL
}
"lite_fc_model.tar.gz"
)
add_dependencies
(
test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz
)
lite_cc_test
(
test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvBNFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvBNFuser
fuser
(
"conv2d"
);
fuser
(
graph
.
get
());
fusion
::
ConvBNFuser
fuser2
(
"depthwise_conv2d"
);
fuser2
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_bn_fuse_pass
,
paddle
::
lite
::
mir
::
ConvBNFusePass
);
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvBNFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAddReLUFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAddReLUFuser
fuser
(
"conv2d"
);
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
depthwise_fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAddReLUFusePass
);
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddReLUFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
conv2d_1
=
main_block
->
AppendOp
();
auto
*
conv2d_2
=
main_block
->
AppendOp
();
auto
*
add_1
=
main_block
->
AppendOp
();
auto
*
relu_1
=
main_block
->
AppendOp
();
auto
*
add_2
=
main_block
->
AppendOp
();
auto
*
relu_2
=
main_block
->
AppendOp
();
main_block
->
Var
(
"input_1"
);
main_block
->
Var
(
"input_2"
);
main_block
->
Var
(
"filter_1"
);
main_block
->
Var
(
"filter_2"
);
main_block
->
Var
(
"conv2d_1_out"
);
main_block
->
Var
(
"conv2d_2_out"
);
main_block
->
Var
(
"bias_1"
);
main_block
->
Var
(
"add_1_out"
);
main_block
->
Var
(
"add_2_out"
);
main_block
->
Var
(
"relu_1_out"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"input_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"input_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"bias_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"relu_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
conv2d_1
->
SetType
(
"conv2d"
);
conv2d_1
->
SetInput
(
"Input"
,
{
"input_1"
});
conv2d_1
->
SetInput
(
"Filter"
,
{
"filter_1"
});
conv2d_1
->
SetOutput
(
"Output"
,
{
"conv2d_1_out"
});
conv2d_1
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_1
->
SetAttr
(
"groups"
,
1
);
conv2d_1
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"fuse_relu"
,
false
);
add_1
->
SetType
(
"elementwise_add"
);
add_1
->
SetInput
(
"X"
,
{
"conv2d_1_out"
});
add_1
->
SetInput
(
"Y"
,
{
"bias_1"
});
add_1
->
SetOutput
(
"Out"
,
{
"add_1_out"
});
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"X"
,
{
"add_1_out"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_1_out"
});
conv2d_2
->
SetType
(
"conv2d"
);
conv2d_2
->
SetInput
(
"Input"
,
{
"input_2"
});
conv2d_2
->
SetInput
(
"Filter"
,
{
"filter_2"
});
conv2d_2
->
SetOutput
(
"Output"
,
{
"conv2d_2_out"
});
conv2d_2
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_2
->
SetAttr
(
"groups"
,
1
);
conv2d_2
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"fuse_relu"
,
false
);
add_2
->
SetType
(
"elementwise_add"
);
add_2
->
SetInput
(
"X"
,
{
"conv2d_2_out"
});
add_2
->
SetInput
(
"Y"
,
{
"relu_1_out"
});
add_2
->
SetOutput
(
"Out"
,
{
"add_2_out"
});
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"X"
,
{
"add_2_out"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
return
graph
;
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
11UL
/*vars*/
+
6UL
/*ops*/
);
Visualize
(
graph
.
get
());
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
fuse_test_op
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ConvElementwiseAddReLUFusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused fc node*/
);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
conv2d
);
USE_LITE_OP
(
depthwise_conv2d
);
USE_LITE_OP
(
relu
);
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
FcFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
FcFuser
fuser
;
fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_fc_fuse_pass
,
paddle
::
lite
::
mir
::
FcFusePass
);
paddle/fluid/lite/core/mir/fc_fuse_pass.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
FcFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
namespace
paddle
{
namespace
lite
{
namespace
mir
{
TEST
(
fc_fuse_pass
,
fuse_test
)
{
lite
::
ExecutorLite
predictor
;
#ifndef LITE_WITH_CUDA
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)}});
#else
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
});
#endif
predictor
.
Build
(
FLAGS_model_dir
,
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)},
// origin cuda
valid_places
);
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
input_tensor
->
Resize
(
DDim
(
std
::
vector
<
DDim
::
value_type
>
({
100
,
100
})));
auto
*
data
=
input_tensor
->
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
100
*
100
;
i
++
)
{
data
[
i
]
=
i
;
}
predictor
.
Run
();
auto
*
out
=
predictor
.
GetOutput
(
0
);
LOG
(
INFO
)
<<
out
<<
" memory size "
<<
out
->
data_size
();
LOG
(
INFO
)
<<
"out "
<<
out
->
data
<
float
>
()[
0
];
LOG
(
INFO
)
<<
"out "
<<
out
->
data
<
float
>
()[
1
];
LOG
(
INFO
)
<<
"dims "
<<
out
->
dims
();
EXPECT_NEAR
(
out
->
data
<
float
>
()[
0
],
38.120617
f
,
1e-5
);
EXPECT_NEAR
(
out
->
data
<
float
>
()[
1
],
10.109812
f
,
1e-5
);
CHECK_EQ
(
out
->
dims
()[
0
],
100
);
CHECK_EQ
(
out
->
dims
()[
1
],
500
);
}
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST
(
fc_fuse_pass
,
save_model_test
)
{
lite
::
ExecutorLite
predictor
;
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)}});
predictor
.
Build
(
FLAGS_model_dir
,
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)},
valid_places
);
LOG
(
INFO
)
<<
"Save optimized model to "
<<
FLAGS_optimized_model
;
predictor
.
SaveModel
(
FLAGS_optimized_model
);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
elementwise_sub
);
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
feed
);
USE_LITE_OP
(
fetch
);
USE_LITE_OP
(
io_copy
);
USE_LITE_OP
(
softmax
);
USE_LITE_OP
(
scale
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL
(
mul
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
fc
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_sub
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_add
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
softmax
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
scale
,
kX86
,
kFloat
,
kNCHW
,
def
);
#endif
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
host_to_device
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
device_to_host
);
#endif
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
0 → 100644
浏览文件 @
505e450d
cc_library
(
fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_conv_elementwise_add_relu
SRCS conv_elementwise_add_relu_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api
)
set
(
mir_fusers
fuse_fc
fuse_conv_elementwise_add_relu
fuse_conv_bn
CACHE INTERNAL
"fusers"
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
return
()
endif
()
lite_cc_test
(
test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api
)
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/program.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Builds a minimal conv2d -> batch_norm program and converts it into an
// SSAGraph, so the conv+bn fuse pass can be exercised in isolation.
// All variables are registered both in the block desc (for the program) and
// in `scope` (as lite::Tensor), since the pass reads weights from the scope.
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
                                     const std::shared_ptr<Scope>& scope,
                                     const std::vector<Place>& valid_places) {
  auto* main_block = program_desc->MutableBlock(0);

  auto* conv_op = main_block->AppendOp();
  auto* bn_op = main_block->AppendOp();

  // Declare every variable the two ops touch in the block desc.
  main_block->Var("conv_i");
  main_block->Var("conv_param");
  main_block->Var("conv_out");
  main_block->Var("bn_scale");
  main_block->Var("bn_bias");
  main_block->Var("bn_mean");
  main_block->Var("bn_var");
  main_block->Var("bn_out");
  main_block->Var("bn_mean_out");
  main_block->Var("bn_var_out");
  main_block->Var("bn_saved_mean");
  main_block->Var("bn_saved_var");

  // Materialize the tensors in the scope. Only the conv filter and the BN
  // statistics get a concrete shape/allocation; the rest are placeholders.
  scope->Var("conv_i")->GetMutable<lite::Tensor>();
  auto conv_param_t = scope->Var("conv_param")->GetMutable<lite::Tensor>();
  // Filter: 3 output channels, 1 input channel, 2x2 kernel.
  std::vector<int64_t> conv_param_shape = {3, 1, 2, 2};
  conv_param_t->Resize(lite::DDim(conv_param_shape));
  conv_param_t->mutable_data<float>();
  scope->Var("conv_out")->GetMutable<lite::Tensor>();

  // BN parameters are all 1-D of length 3 (one entry per output channel).
  auto bn_scale_t = scope->Var("bn_scale")->GetMutable<lite::Tensor>();
  std::vector<int64_t> bn_scale_shape = {3};
  bn_scale_t->Resize(lite::DDim(bn_scale_shape));
  bn_scale_t->mutable_data<float>();
  auto bn_bias_t = scope->Var("bn_bias")->GetMutable<lite::Tensor>();
  std::vector<int64_t> bn_bias_shape = {3};
  bn_bias_t->Resize(lite::DDim(bn_bias_shape));
  bn_bias_t->mutable_data<float>();
  auto bn_mean_t = scope->Var("bn_mean")->GetMutable<lite::Tensor>();
  bn_mean_t->Resize(lite::DDim(bn_bias_shape));
  bn_mean_t->mutable_data<float>();
  auto bn_var_t = scope->Var("bn_var")->GetMutable<lite::Tensor>();
  bn_var_t->Resize(lite::DDim(bn_bias_shape));
  bn_var_t->mutable_data<float>();
  scope->Var("bn_out")->GetMutable<lite::Tensor>();
  scope->Var("bn_mean_out")->GetMutable<lite::Tensor>();
  scope->Var("bn_var_out")->GetMutable<lite::Tensor>();
  scope->Var("bn_saved_mean")->GetMutable<lite::Tensor>();
  scope->Var("bn_saved_var")->GetMutable<lite::Tensor>();

  // Wire the conv2d op.
  conv_op->SetType("conv2d");
  conv_op->SetInput("Input", {"conv_i"});
  conv_op->SetInput("Filter", {"conv_param"});
  conv_op->SetOutput("Output", {"conv_out"});
  const std::vector<int> strides({1, 1});
  const std::vector<int> paddings({1, 1});
  const std::vector<int> dilations({1, 1});
  const int groups = 1;
  conv_op->SetAttr("strides", strides);
  conv_op->SetAttr("paddings", paddings);
  conv_op->SetAttr("dilations", dilations);
  conv_op->SetAttr("groups", groups);
  conv_op->SetAttr("fuse_relu", false);

  // Wire the batch_norm op, consuming conv_out.
  bn_op->SetType("batch_norm");
  bn_op->SetInput("X", {"conv_out"});
  bn_op->SetInput("Bias", {"bn_bias"});
  bn_op->SetInput("Mean", {"bn_mean"});
  bn_op->SetInput("Scale", {"bn_scale"});
  bn_op->SetInput("Variance", {"bn_var"});
  bn_op->SetOutput("Y", {"bn_out"});
  bn_op->SetOutput("MeanOut", {"bn_mean_out"});
  bn_op->SetOutput("VarianceOut", {"bn_var_out"});
  bn_op->SetOutput("SavedMean", {"bn_saved_mean"});
  bn_op->SetOutput("SavedVariance", {"bn_saved_var"});
  float eps = 1e-5;
  bn_op->SetAttr("epsilon", eps);
  // is_test is stored as int; the fuse pass targets inference graphs.
  bn_op->SetAttr("is_test", static_cast<int>(1));
  bn_op->SetAttr("use_global_stats", false);
  bn_op->SetAttr("momentum", 0.9f);
  bn_op->SetAttr("data_layout", std::string("NCHW"));

  // Serialize the desc and lift it into an SSA graph.
  program_desc->Flush();
  lite::Program program(*program_desc->Proto(), scope, valid_places);
  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
  graph->Build(program, valid_places);
  return graph;
}
// Runs the conv+bn fuse pass on the two-op graph built above and checks the
// node-count delta: the pattern removes 8 nodes (bn op, its non-essential
// ins/outs and the bn output links folded away) and inserts a single
// elementwise_add node carrying the folded BN bias.
TEST(pattern_matcher2, test) {
  framework::ProgramDesc program_desc;
  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
  auto scope = std::make_shared<Scope>();
  auto graph = BuildGraph(&program_desc, scope, places);
  // size_t, so the arithmetic below stays unsigned (was a signed int
  // compared against size_t).
  const size_t num_nodes = graph->nodes().size();
  // Stack-allocated pass: the original `new ConvBNFusePass` was never
  // deleted and leaked on every run.
  ConvBNFusePass fuser;
  fuser.Apply(graph);
  ASSERT_EQ(graph->nodes().size(),
            num_nodes - 8UL /*nodes removed */ + 1UL /* eltwise_add node*/);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
// Pull in the operator registrations the test graph needs at link time.
USE_LITE_OP(conv2d);
USE_LITE_OP(batch_norm);
USE_LITE_OP(elementwise_add);
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Declares the conv -> batch_norm subgraph to match.
// Roles: conv input/filter and the BN bias survive the fusion (AsInput);
// the bn op, its statistics inputs and its auxiliary outputs are consumed
// (AsIntermediate); bn's Y output becomes the fused subgraph's output.
// Note conv_out is deliberately neither input, output, nor intermediate:
// it stays in the graph as the input of the inserted elementwise_add.
void ConvBNFuser::BuildPattern() {
  auto* conv_input =
      VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput();
  auto* conv_weight = VarNode("conv_weight")
                          ->assert_is_op_input(conv_type_, "Filter")
                          ->AsInput();
  auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
  auto* conv_out = VarNode("conv_out")
                       ->assert_is_op_output(conv_type_, "Output")
                       ->assert_is_op_input("batch_norm", "X");

  auto* bn_scale = VarNode("bn_scale")
                       ->assert_is_op_input("batch_norm", "Scale")
                       ->AsIntermediate();
  // bn_bias stays: it is reused as the Y input of the new elementwise_add.
  auto* bn_bias =
      VarNode("bn_bias")->assert_is_op_input("batch_norm", "Bias")->AsInput();
  auto* bn_mean = VarNode("bn_mean")
                      ->assert_is_op_input("batch_norm", "Mean")
                      ->AsIntermediate();
  auto* bn_var = VarNode("bn_variance")
                     ->assert_is_op_input("batch_norm", "Variance")
                     ->AsIntermediate();

  auto* bn =
      OpNode("bn", "batch_norm")->assert_is_op("batch_norm")->AsIntermediate();

  auto* bn_out =
      VarNode("bn_out")->assert_is_op_output("batch_norm", "Y")->AsOutput();
  auto* bn_mean_out = VarNode("bn_mean_out")
                          ->assert_is_op_output("batch_norm", "MeanOut")
                          ->AsIntermediate();
  auto* bn_var_out = VarNode("bn_var_out")
                         ->assert_is_op_output("batch_norm", "VarianceOut")
                         ->AsIntermediate();
  auto* bn_saved_mean = VarNode("bn_saved_mean")
                            ->assert_is_op_output("batch_norm", "SavedMean")
                            ->AsIntermediate();
  auto* bn_saved_var = VarNode("bn_saved_var")
                           ->assert_is_op_output("batch_norm", "SavedVariance")
                           ->AsIntermediate();

  // Topology: conv feeds bn through conv_out.
  conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
  bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
      .LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
}
// Folds the matched batch_norm into the conv weights and bias, then inserts
// an elementwise_add node (conv_out + folded bias -> bn_out) in place of the
// bn op.  The conv filter and the bn bias tensors are mutated IN PLACE in
// the scope; ComputeFusedWeight also clobbers the bn variance buffer (it is
// reused to hold scale/sqrt(var+eps)), which is safe because those vars are
// marked intermediate and removed after the fusion.
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
  auto op_desc = GenOpDesc(matched);
  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
  auto conv = matched.at("conv2d")->stmt()->op;
  auto* scope = conv->scope();
  auto& valid_places = conv->valid_places();

  // Conv filter: mutated in place below.
  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
                           ->GetMutable<lite::Tensor>();
  auto conv_weight_d = conv_weight_t->mutable_data<float>();
  auto conv_weight_dims = conv_weight_t->dims();
  size_t weight_num = conv_weight_t->data_size();

  auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
                        ->GetMutable<lite::Tensor>();
  // One BN channel per conv output channel (filter dim 0).
  size_t bias_size = bn_scale_t->data_size();
  auto bn_scale_d = bn_scale_t->mutable_data<float>();
  CHECK(bias_size == conv_weight_dims[0])
      << "The BN bias's size should be equal to the size of the first "
      << "dim size of the conv weights";

  auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
                       ->GetMutable<lite::Tensor>();
  auto bn_mean_d = bn_mean_t->mutable_data<float>();

  auto bn_var_t = scope->FindVar(matched.at("bn_variance")->arg()->name)
                      ->GetMutable<lite::Tensor>();
  auto bn_var_d = bn_var_t->mutable_data<float>();

  // bn_bias is rewritten to the folded bias and becomes the Y input of the
  // new elementwise_add.
  auto bn_bias_t = scope->FindVar(matched.at("bn_bias")->arg()->name)
                       ->GetMutable<lite::Tensor>();
  auto bn_bias_d = bn_bias_t->mutable_data<float>();
  auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");

  // h = #channels, w = weights per channel.
  ComputeFusedWeight(bn_scale_d, bn_mean_d, bn_var_d, bn_bias_d, conv_weight_d,
                     eps, bias_size, weight_num / bias_size);

  eltwise_op->Attach(op_desc, scope);
  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
  // Rewire: conv_out + bn_bias -> new add node -> bn_out.
  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
}
// Builds the desc of the elementwise_add that replaces the bn op:
// bn_out = conv_out + bn_bias (per-channel bias broadcast along axis 1).
cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
  // Resolve the variable names bound to the matched pattern keys up front.
  const auto& x_name = matched.at("conv_out")->arg()->name;
  const auto& bias_name = matched.at("bn_bias")->arg()->name;
  const auto& out_name = matched.at("bn_out")->arg()->name;

  cpp::OpDesc desc;
  desc.SetType("elementwise_add");
  desc.SetInput("X", {x_name});
  desc.SetInput("Y", {bias_name});
  desc.SetOutput("Out", {out_name});
  desc.SetAttr("axis", 1);
  return desc;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cmath>
#include <memory>
#include <string>

#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fuses a conv2d -> batch_norm pair: the BN affine transform is folded into
// the conv filter and bias at graph-optimization time, and the bn op is
// replaced by a single elementwise_add (see the .cc).
class ConvBNFuser : public FuseBase {
 public:
  // conv_type selects which conv flavor to match (e.g. "conv2d").
  explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

  // Folds BN into the conv parameters, all IN PLACE:
  //   alpha[i]        = scale[i] / sqrt(var[i] + eps)   (stored into var_d!)
  //   bias[i]        += -mean[i] * alpha[i]
  //   weight[i][j]   *= alpha[i]
  // h = number of channels, w = weights per channel (filter size / h).
  // NOTE: var_d is clobbered and must not be reused as a variance afterwards.
  void ComputeFusedWeight(float* scale_d, float* mean_d, float* var_d,
                          float* bias_d, float* conv_weight_d, float eps,
                          int h, int w) {
    // Per-channel multiplier alpha, cached in var_d.
    for (int i = 0; i < h; i++) {
      var_d[i] = scale_d[i] / std::sqrt(var_d[i] + eps);
    }
    // Shift the bias by the normalized mean.
    for (int i = 0; i < h; i++) {
      bias_d[i] += (-mean_d[i]) * var_d[i];
    }
    // Scale every weight of channel i by alpha[i].
    for (int i = 0; i < h; i++) {
      for (int j = 0; j < w; j++) {
        conv_weight_d[i * w + j] *= var_d[i];
      }
    }
  }

 private:
  std::string conv_type_{"conv2d"};
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Declares the conv -> elementwise_add -> relu chain to match.  All three
// ops and both intermediate vars are consumed by the fusion; only the conv
// input/filter and the add's bias survive, and relu's output becomes the
// output of the fused conv.
void ConvElementwiseAddReLUFuser::BuildPattern() {
  // create input nodes.
  auto* input =
      VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
  auto* filter =
      VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
  auto* bias =
      VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();

  // create op nodes
  auto* conv2d =
      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
  auto* add = OpNode("add", "elementwise_add")
                  ->assert_is_op("elementwise_add")
                  ->AsIntermediate();
  auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();

  // create intermediate nodes
  auto* conv2d_out = VarNode("conv2d_out")
                         ->assert_is_op_output(conv_type_, "Output")
                         ->assert_is_op_input("elementwise_add", "X")
                         ->AsIntermediate();
  auto* add_out = VarNode("add_out")
                      ->assert_is_op_output("elementwise_add", "Out")
                      ->assert_is_op_input("relu", "X")
                      ->AsIntermediate();

  // create output node
  auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();

  // create topology.
  std::vector<PMNode*> conv2d_inputs{filter, input};
  std::vector<PMNode*> add_inputs{conv2d_out, bias};
  conv2d_inputs >> *conv2d >> *conv2d_out;
  add_inputs >> *add >> *add_out;
  *add_out >> *relu >> *out;
}
// Replaces the matched conv/add/relu triple with a single conv op whose
// desc carries the bias and the fuse_relu flag (see GenOpDesc), then
// rewires the surviving variable nodes onto the new instruction node.
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
                                                const key2nodes_t& matched) {
  auto fused_desc = GenOpDesc(matched);
  auto fused_conv = LiteOpRegistry::Global().Create(conv_type_);

  // The fused op inherits scope and target places from the matched conv.
  const auto& old_conv = matched.at("conv2d")->stmt()->op;
  auto* scope = old_conv->scope();
  auto& places = old_conv->valid_places();

  fused_conv->Attach(fused_desc, scope);
  auto* fused_node = graph->GraphCreateInstructNode(fused_conv, places);

  // input/filter/bias feed the fused conv; it produces the relu output.
  IR_NODE_LINK_TO(matched.at("input"), fused_node);
  IR_NODE_LINK_TO(matched.at("filter"), fused_node);
  IR_NODE_LINK_TO(matched.at("bias"), fused_node);
  IR_NODE_LINK_TO(fused_node, matched.at("output"));
}
// Builds the fused conv desc: copies the matched conv's attributes, binds
// the add's Y as the conv "Bias" input, retargets the output to relu's
// output var, and sets fuse_relu so the kernel applies the activation.
cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
  auto* desc = matched.at("conv2d")->stmt()->op_info();
  cpp::OpDesc op_desc;
  op_desc.SetType(conv_type_);
  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
  op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
  // Other inputs. See operators/conv_op.h
  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
  // ResidualData is optional on the original conv; forward it when present.
  if (std::find(input_arg_names.begin(), input_arg_names.end(),
                "ResidualData") != input_arg_names.end()) {
    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
  }
  // Only consider strides, padding, groups, dilations, fuse_relu for now
  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
  op_desc.SetAttr("fuse_relu", true);
  return op_desc;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fuses conv -> elementwise_add -> relu into a single conv op with a Bias
// input and fuse_relu=true (see the .cc for the pattern and rewrite).
class ConvElementwiseAddReLUFuser : public FuseBase {
 public:
  // conv_type selects which conv flavor to match (e.g. "conv2d").
  explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
      : conv_type_(conv_type) {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

  std::string conv_type_;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
FcFuser
::
BuildPattern
()
{
// create nodes.
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"mul"
,
"X"
);
auto
*
W
=
VarNode
(
"W"
)
->
assert_is_op_input
(
"mul"
,
"Y"
);
auto
*
b
=
VarNode
(
"b"
);
auto
*
mul
=
OpNode
(
"mul"
,
"mul"
);
auto
*
mul_out
=
VarNode
(
"mul_out"
);
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
);
auto
*
Out
=
VarNode
(
"Out"
);
// create topology.
std
::
vector
<
PMNode
*>
mul_inputs
{
W
,
x
};
std
::
vector
<
PMNode
*>
add_inputs
{
mul_out
,
b
};
mul_inputs
>>
*
mul
>>
*
mul_out
;
add_inputs
>>
*
add
>>
*
Out
;
// Some op specialities.
mul_out
->
AsIntermediate
();
mul
->
AsIntermediate
();
add
->
AsIntermediate
();
}
// Instantiates the fused fc op and splices it into the graph in place of
// the matched mul + elementwise_add pair.
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
  auto desc = GenOpDesc(matched);
  auto fc = LiteOpRegistry::Global().Create("fc");

  // The fc op inherits scope and target places from the matched mul.
  const auto& mul_op = matched.at("mul")->stmt()->op;
  auto* scope = mul_op->scope();
  auto& places = mul_op->valid_places();
  fc->Attach(desc, scope);

  auto* fc_node = graph->GraphCreateInstructNode(fc, places);
  // W, x and b feed the fc node; it produces Out.
  IR_NODE_LINK_TO(matched.at("W"), fc_node);
  IR_NODE_LINK_TO(matched.at("x"), fc_node);
  IR_NODE_LINK_TO(matched.at("b"), fc_node);
  IR_NODE_LINK_TO(fc_node, matched.at("Out"));
}
// Builds the fc desc from the matched nodes: fc(Input=x, W, Bias=b) -> Out,
// carrying over mul's x_num_col_dims as fc's in_num_col_dims.
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
  // Fetch the variable name bound to a pattern key.
  auto var_name = [&matched](const std::string& key) {
    return matched.at(key)->arg()->name;
  };

  cpp::OpDesc desc;
  desc.SetType("fc");
  desc.SetInput("Input", {var_name("x")});
  desc.SetInput("W", {var_name("W")});
  desc.SetInput("Bias", {var_name("b")});
  desc.SetOutput("Out", {var_name("Out")});
  desc.SetAttr(
      "in_num_col_dims",
      matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
  return desc;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fuses mul + elementwise_add into a single fc op (see the .cc for the
// pattern and the rewrite).
class FcFuser : public FuseBase {
 public:
  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
505e450d
...
...
@@ -28,7 +28,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for
(
auto
&
item
:
graph
->
StmtTopologicalOrder
())
{
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
LOG
(
INFO
)
<<
stmt
;
VLOG
(
4
)
<<
stmt
;
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
}
}
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
505e450d
...
...
@@ -71,12 +71,20 @@ class Node {
struct
Arg
{
std
::
string
name
;
int
id
{
0
};
const
Type
*
type
{};
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool
is_weight
{
false
};
};
Arg
&
AsArg
(
const
std
::
string
&
name
,
int
id
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
id
=
id
;
return
x
;
}
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
505e450d
...
...
@@ -31,3 +31,7 @@ USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS
(
argument_type_display_pass
);
#endif
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
);
paddle/fluid/lite/core/mir/pattern_matcher.cc
浏览文件 @
505e450d
...
...
@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
return
*
this
;
}
void
operator
>>
(
std
::
vector
<
PMNode
*>
&
others
,
PMNode
&
me
)
{
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>
&
others
,
PMNode
&
me
)
{
for
(
auto
*
o
:
others
)
{
*
o
>>
me
;
}
return
me
;
}
PMNode
*
PMPattern
::
NewNode
(
const
std
::
string
&
name
)
{
...
...
@@ -406,6 +407,67 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
return
this
;
}
// Returns true iff `var` is the nth name listed under op's output argument
// `argument`.  `var` must be an argument node, `op` a statement node.
// NOTE(review): pattern_matcher.h appears to declare IsNthOutput/IsNthInput
// with (const Node&, const Node&, ..., int) — this pointer/size_t signature
// is a different overload; confirm the header declaration matches.
bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
                 size_t nth) {
  CHECK(var->IsArg());
  CHECK(op->IsStmt());
  auto op_info = op->stmt()->op_info();
  // Fewer than nth+1 outputs under this argument: cannot match.
  if (op_info->Output(argument).size() <= nth) return false;
  return var->arg()->name == op_info->Output(argument)[nth];
}
// Returns true iff `var` is the nth name listed under op's input argument
// `argument`.  `var` must be an argument node, `op` a statement node.
// NOTE(review): see IsNthOutput — the header seems to declare a
// reference/int signature; confirm it matches this definition.
bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
                size_t nth) {
  CHECK(var->IsArg());
  CHECK(op->IsStmt());
  auto op_info = op->stmt()->op_info();
  // Fewer than nth+1 inputs under this argument: cannot match.
  if (op_info->Input(argument).size() <= nth) return false;
  return var->arg()->name == op_info->Input(argument)[nth];
}
// Convenience overload: assert this var is slot 0 of `argument` of an op of
// type `op_type`.
PMNode *PMNode::assert_is_op_input(const std::string &op_type,
                                   const std::string &argument) {
  assert_is_var();
  assert_is_op_nth_input(op_type, argument, 0);
  return this;
}
// Adds a predicate: this var node must appear as the nth input of argument
// `argument` on at least one downstream op of type `op_type`.
PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
                                       const std::string &argument, int nth) {
  assert_is_var();
  assert_is_op_input(op_type);
  // Captures by value: the assert may run long after this call returns.
  asserts_.emplace_back([=](const Node *x) {
    for (auto *op : x->outlinks) {
      if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
          IsNthInput(x, op, argument, nth))
        return true;
    }
    return false;
  });
  return this;
}
// Convenience overload: assert this var is slot 0 of output `argument` of
// an op of type `op_type`.
PMNode *PMNode::assert_is_op_output(const std::string &op_type,
                                    const std::string &argument) {
  assert_is_var();
  assert_is_op_nth_output(op_type, argument, 0);
  return this;
}
// Adds a predicate: this var node must appear as the nth output of argument
// `argument` on at least one upstream op of type `op_type`.
PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type,
                                        const std::string &argument, int nth) {
  assert_is_var();
  // Captures by value: the assert may run long after this call returns.
  asserts_.emplace_back([=](const Node *x) {
    for (auto *op : x->inlinks) {
      if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
          IsNthOutput(x, op, argument, nth))
        return true;
    }
    return false;
  });
  return this;
}
PMNode
*
PMNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
const
Node
*
x
)
{
...
...
@@ -422,6 +484,14 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return
this
;
}
// Returns true iff the op (statement node) declares an input argument slot
// named `argument` in its op_info.
bool HasInput(const Node &op, const std::string &argument) {
  CHECK(op.IsStmt());
  auto const &names = op.stmt()->op_info()->input_argnames();
  if (std::find(names.begin(), names.end(), argument) == names.end())
    return false;
  return true;
}
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>
&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
...
...
paddle/fluid/lite/core/mir/pattern_matcher.h
浏览文件 @
505e450d
...
...
@@ -62,7 +62,7 @@ struct PMNode {
PMNode
&
operator
>>
(
PMNode
&
right
);
// Link many nodes to this node.
friend
void
operator
>>
(
std
::
vector
<
PMNode
*>&
others
,
PMNode
&
me
);
friend
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>&
others
,
PMNode
&
me
);
// Link this to many other nodes.
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>&
nodes
);
...
...
@@ -127,6 +127,15 @@ struct PMNode {
PMNode
*
assert_is_persistable_var
();
PMNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PMNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PMNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PMNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
template
<
typename
T
>
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
...
...
@@ -297,6 +306,13 @@ class PatternMatcher {
std
::
unordered_map
<
const
PMNode
*
,
std
::
unordered_set
<
Node
*>>
pmnodes2nodes_
;
};
// Check whether a var node is a op node's nth input.
bool
IsNthInput
(
const
Node
&
var
,
const
Node
&
op
,
const
std
::
string
&
argument
,
int
nth
);
// Check whether the op node has input of given name.
bool
HasInput
(
const
Node
&
op
,
const
std
::
string
&
argument
);
// Graph safely remove some nodes, will automatically clean up the edges.
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
);
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
浏览文件 @
505e450d
...
...
@@ -20,7 +20,7 @@ namespace lite {
namespace
mir
{
void
FuseBase
::
PerformPatternMatcher
(
SSAGraph
*
graph
)
{
LOG
(
INFO
)
<<
"
\n
"
<<
matcher_
.
pattern
().
DotString
();
VLOG
(
4
)
<<
"
\n
"
<<
matcher_
.
pattern
().
DotString
();
// Get subgraphs and record the mir::Node pointers for each PMNode.
auto
handler
=
[
&
](
const
PatternMatcher
::
subgraph_t
&
subgraph
,
SSAGraph
*
g
)
{
// get all the reigistered nodes.
...
...
@@ -41,17 +41,14 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
LOG
(
INFO
)
<<
"keys.size "
<<
keys
.
size
();
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
for
(
auto
&
matched
:
key2nodes_
)
{
LOG
(
INFO
)
<<
"get matched "
<<
matched
.
size
();
for
(
const
auto
&
key
:
keys
)
{
nodes2rm
.
insert
(
matched
.
at
(
key
));
}
}
LOG
(
INFO
)
<<
"clean nodes "
<<
nodes2rm
.
size
();
VLOG
(
3
)
<<
"clean nodes "
<<
nodes2rm
.
size
();
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
浏览文件 @
505e450d
...
...
@@ -64,7 +64,6 @@ class FuseBase {
// Delete nodes that are marked as Intermediate
void
DeleteInterNodes
(
SSAGraph
*
graph
);
private:
PMNode
*
GetOrCreateNode
(
const
std
::
string
&
key
);
protected:
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
浏览文件 @
505e450d
...
...
@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
public:
void
BuildPattern
()
override
{
// create nodes.
auto
*
x
=
VarNode
(
"x"
);
auto
*
W
=
VarNode
(
"W"
);
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"mul"
,
"X"
)
;
auto
*
W
=
VarNode
(
"W"
)
->
assert_is_op_input
(
"mul"
,
"Y"
)
;
auto
*
b
=
VarNode
(
"b"
);
auto
*
mul
=
OpNode
(
"mul"
,
"mul"
);
auto
*
mul_out
=
VarNode
(
"mul_out"
);
...
...
@@ -38,12 +38,10 @@ class FcFuser : public FuseBase {
auto
*
Out
=
VarNode
(
"Out"
);
// create topology.
// std::vector<PMNode*>({W, x}) >> *mul >> *mul_out;
// std::vector<PMNode*>({mul_out, b}) >> *add >> *Out;
*
W
>>
*
mul
;
*
x
>>
*
mul
>>
*
mul_out
;
*
b
>>
*
add
;
*
mul_out
>>
*
add
>>
*
Out
;
std
::
vector
<
PMNode
*>
mul_inputs
{
W
,
x
};
std
::
vector
<
PMNode
*>
add_inputs
{
mul_out
,
b
};
mul_inputs
>>
*
mul
>>
*
mul_out
;
add_inputs
>>
*
add
>>
*
Out
;
// Some op specialities.
mul_out
->
AsIntermediate
();
...
...
@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block
->
Var
(
"mul_out"
);
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"out"
);
main_block
->
Var
(
"out1"
);
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out1"
)
->
GetMutable
<
lite
::
Tensor
>
();
mul
->
SetInput
(
"X"
,
{
"x"
});
mul
->
SetInput
(
"Y"
,
{
"w"
});
...
...
@@ -122,18 +118,17 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
return
graph
;
}
TEST
(
pattern_matcher
2
,
graph_test
)
{
TEST
(
pattern_matcher
_high_api
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
8UL
/*real nodes*/
+
2UL
/*feed op + fetch op*/
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
7UL
/*real nodes*/
);
Visualize
(
graph
.
get
());
}
TEST
(
pattern_matcher
2
,
test
)
{
TEST
(
pattern_matcher
_high_api
,
fuse_
test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
...
...
@@ -143,6 +138,7 @@ TEST(pattern_matcher2, test) {
fuser
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
3UL
/*nodes removed */
+
1UL
/* fused fc node*/
);
Visualize
(
graph
.
get
());
}
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
505e450d
...
...
@@ -16,6 +16,7 @@
#include <algorithm>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
namespace
paddle
{
...
...
@@ -93,31 +94,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
return
res
;
}
void
SSAGraph
::
GraphCreateTmpVarNodes
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
())
{
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate creating temp variable: "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
void
SSAGraph
::
GraphCreateWeightVarNodes
(
const
Program
&
program
)
{
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
())
{
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate creating weight variable: "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
Node
*
SSAGraph
::
GraphCreateInstructNode
(
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
node_storage_
.
emplace_back
();
...
...
@@ -135,29 +111,45 @@ Node *SSAGraph::GraphCreateInstructNode(
void
SSAGraph
::
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
CHECK
(
node_storage_
.
empty
());
GraphCreateTmpVarNodes
(
program
);
GraphCreateWeightVarNodes
(
program
);
CHECK
(
CheckNodesRoleSet
());
auto
weights_name
=
program
.
weights
();
auto
is_weights
=
[
&
](
const
std
::
string
&
name
)
->
bool
{
auto
it
=
std
::
find
(
weights_name
.
begin
(),
weights_name
.
end
(),
name
);
if
(
it
==
weights_name
.
end
())
return
false
;
return
true
;
};
std
::
unordered_map
<
std
::
string
,
mir
::
Node
*>
arg_update_node_map_
;
for
(
auto
&
op
:
program
.
ops
())
{
auto
*
op_node
=
GraphCreateInstructNode
(
op
,
valid_places
);
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
Argument
(
name
);
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
arg
,
op_node
);
mir
::
Node
*
arg_node
=
nullptr
;
if
(
arg_update_node_map_
.
count
(
name
))
{
arg_node
=
arg_update_node_map_
.
at
(
name
);
}
else
{
node_storage_
.
emplace_back
();
arg_node
=
&
node_storage_
.
back
();
arg_node
->
AsArg
(
name
,
node_storage_
.
size
()
-
1
);
arg_update_node_map_
[
name
]
=
arg_node
;
}
if
(
is_weights
(
name
))
arg_node
->
AsArg
().
is_weight
=
true
;
CHECK
(
arg_node
->
IsRoleSet
());
DirectedLink
(
arg_node
,
op_node
);
}
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
output_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
NewArgumentNode
(
name
);
}
auto
*
arg
=
arguments_
.
at
(
name
);
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
op_node
,
arg
);
node_storage_
.
emplace_back
();
auto
*
arg_node
=
&
node_storage_
.
back
();
arg_node
->
AsArg
(
name
,
node_storage_
.
size
()
-
1
);
arg_update_node_map_
[
name
]
=
arg_node
;
if
(
is_weights
(
name
))
arg_node
->
AsArg
().
is_weight
=
true
;
CHECK
(
arg_node
->
IsRoleSet
());
DirectedLink
(
op_node
,
arg_node
);
}
CHECK
(
CheckLinksRoleSet
());
}
MarkArgumentWeights
(
program
);
CHECK
(
CheckNodesRoleSet
()
);
CheckValid
();
}
...
...
@@ -227,10 +219,9 @@ bool SSAGraph::CheckLinksRoleSet() {
Node
*
SSAGraph
::
NewArgumentNode
(
const
std
::
string
&
name
)
{
node_storage_
.
emplace_back
();
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
arguments_
[
name
]
=
&
node_storage_
.
back
();
node_storage_
.
back
().
AsArg
(
name
);
return
&
node_storage_
.
back
();
auto
&
arg_node
=
node_storage_
.
back
();
arg_node
.
AsArg
(
name
,
node_storage_
.
size
()
-
1
);
return
&
arg_node
;
}
Node
*
SSAGraph
::
NewInstructNode
()
{
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
505e450d
...
...
@@ -40,8 +40,6 @@ class SSAGraph : GraphBase {
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
);
void
RemoveNode
(
const
mir
::
Node
*
node
);
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
std
::
vector
<
mir
::
Node
*>
StmtTopologicalOrder
();
// The inputs of the graph.
...
...
@@ -68,9 +66,7 @@ class SSAGraph : GraphBase {
const
std
::
vector
<
Place
>
&
valid_places
);
private:
void
GraphCreateTmpVarNodes
(
const
Program
&
program
);
void
GraphCreateWeightVarNodes
(
const
Program
&
program
);
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
// Check the bidirectional connection.
bool
CheckBidirectionalConnection
();
bool
CheckNodesRoleSet
();
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
505e450d
...
...
@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
->
AsArg
().
name
,
graph
,
inst_node
,
valid_places_
);
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
,
graph
,
inst_node
,
valid_places_
);
}
}
void
TypeTargetTransformPass
::
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
const
Type
&
from
,
const
Type
&
to
,
Node
*
in
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node.
CHECK
(
in
->
IsArg
());
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
auto
io_copy_output_name
=
in
->
AsArg
().
name
+
"/trans/"
+
std
::
to_string
(
node_id
());
auto
*
io_copy_output_arg
=
graph
->
NewArgumentNode
(
io_copy_output_name
);
auto
*
io_copy_inst
=
graph
->
NewInstructNode
();
...
...
@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Create IoCopy Instruction.
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetInput
(
"Input"
,
{
in
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
->
scope
());
...
...
@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst(
io_copy_inst
->
AsStmt
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
// Remove the old link
RemoveDirectedLink
(
graph
->
Argument
(
var
)
,
inst_node
);
RemoveDirectedLink
(
in
,
inst_node
);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink
(
graph
->
Argument
(
var
)
,
io_copy_inst
);
DirectedLink
(
in
,
io_copy_inst
);
DirectedLink
(
io_copy_inst
,
io_copy_output_arg
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
var
,
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
in
->
AsArg
().
name
,
io_copy_output_name
);
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.h
浏览文件 @
505e450d
...
...
@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass {
void
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
Node
*
in
);
void
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
void
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
Node
*
in
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
);
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
505e450d
...
...
@@ -13,7 +13,10 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
...
...
@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass {
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
VLOG
(
4
)
<<
"inferencing op "
<<
inst
.
op_type
;
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
input_argnames
())
{
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto
get_argname
=
[
&
](
const
std
::
string
&
node_name
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
argname_map
)
->
std
::
string
{
for
(
auto
&
ele
:
argname_map
)
{
auto
it
=
std
::
find
(
ele
.
second
.
begin
(),
ele
.
second
.
end
(),
node_name
);
if
(
it
!=
ele
.
second
.
end
())
return
ele
.
first
;
}
return
""
;
};
for
(
auto
*
x_in
:
x
->
inlinks
)
{
std
::
string
node_name
=
x_in
->
AsArg
().
name
;
std
::
string
arg_name
=
get_argname
(
node_name
,
inst
.
op_info
()
->
inputs
());
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
<<
node_name
;
VLOG
(
3
)
<<
"-- input arg_name "
<<
arg_name
;
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
inputs
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
node
;
arg_node
.
type
=
type
;
}
if
(
!
x_in
->
AsArg
().
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
x_in
;
x_in
->
AsArg
().
type
=
type
;
}
}
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
for
(
auto
*
x_out
:
x
->
outlinks
)
{
std
::
string
node_name
=
x_out
->
AsArg
().
name
;
std
::
string
arg_name
=
get_argname
(
node_name
,
inst
.
op_info
()
->
outputs
());
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
<<
node_name
;
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
outputs
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
node
->
AsArg
().
type
=
type
;
VLOG
(
3
)
<<
"set type "
<<
*
type
;
}
if
(
!
x_out
->
AsArg
().
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
x_out
;
x_out
->
AsArg
().
type
=
type
;
}
}
}
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
505e450d
...
...
@@ -59,7 +59,7 @@ class OpLite : public Registry {
}
void
SetValidPlaces
(
const
std
::
vector
<
Place
>
&
places
)
{
LOG
(
INFO
)
<<
"valid places "
<<
valid_places_
.
size
();
VLOG
(
3
)
<<
"valid places "
<<
valid_places_
.
size
();
valid_places_
=
places
;
}
const
std
::
vector
<
Place
>
&
valid_places
()
const
{
return
valid_places_
;
}
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
505e450d
...
...
@@ -48,6 +48,9 @@ class Optimizer {
if
(
passes
.
empty
())
{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_act_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
...
...
paddle/fluid/lite/core/profile/basic_profiler.h
浏览文件 @
505e450d
...
...
@@ -152,8 +152,8 @@ class BasicProfiler {
}
record_t
*
mutable_record
(
int
id
)
{
CHECK_LT
(
id
,
records_
.
size
());
CHECK_GE
(
id
,
0
);
CHECK_LT
(
static_cast
<
size_t
>
(
id
),
records_
.
size
());
return
&
records_
[
id
];
}
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
505e450d
...
...
@@ -140,7 +140,7 @@ class RuntimeProgram {
void
Run
()
{
for
(
auto
&
inst
:
instructions_
)
{
LOG
(
INFO
)
<<
">> Running kernel: "
<<
inst
;
VLOG
(
4
)
<<
">> Running kernel: "
<<
inst
;
inst
.
Run
();
}
}
...
...
paddle/fluid/lite/kernels/x86/conv_compute.cc
浏览文件 @
505e450d
...
...
@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
lite
::
Tensor
col_matrix
;
if
(
is_expand
)
{
col
.
Resize
(
col_shape
);
col
.
mutable_data
<
T
>
();
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
...
...
@@ -104,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param
.
x
->
raw_tensor
().
Slice
(
i
,
i
+
1
).
Resize
(
input_shape
.
data
()));
lite
::
Tensor
out_batch
;
out_batch
.
ShareDataWith
(
param
.
output
->
raw_tensor
().
Slice
(
i
,
i
+
1
).
Resize
(
input
_shape
.
data
()));
output_matrix
_shape
.
data
()));
for
(
int
g
=
0
;
g
<
param
.
groups
;
g
++
)
{
lite
::
Tensor
in_slice
;
...
...
@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
...
...
@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW,
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
paddle/fluid/lite/kernels/x86/fc_compute.cc
浏览文件 @
505e450d
...
...
@@ -27,8 +27,8 @@ namespace kernels {
namespace
x86
{
template
<
typename
T
>
void
fc_compute_eigen
(
const
T
*
x
,
int
x_
w
,
int
x_h
,
//
const
T
*
w
,
int
w_
w
,
int
w_h
,
//
void
fc_compute_eigen
(
const
T
*
x
,
int
x_
h
,
int
x_w
,
//
const
T
*
w
,
int
w_
h
,
int
w_w
,
//
const
T
*
b
,
//
T
*
out
)
{
using
matrix_t
=
...
...
@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, //
Eigen
::
Map
<
const
matrix_t
>
X
(
x
,
x_h
,
x_w
);
Eigen
::
Map
<
const
matrix_t
>
W
(
w
,
w_h
,
w_w
);
Eigen
::
Map
<
matrix_t
>
Out
(
out
,
x_h
,
w_
h
);
Eigen
::
Map
<
matrix_t
>
Out
(
out
,
x_h
,
w_
w
);
Out
=
X
*
W
.
transpose
()
;
Out
=
X
*
W
;
if
(
b
)
{
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
1
>>
B
(
b
,
w_
h
);
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
1
>>
B
(
b
,
w_
w
);
Out
=
Out
.
array
().
rowwise
()
+
B
.
transpose
().
array
();
}
}
template
<
typename
T
>
__attribute__
((
optimize
(
"unroll-loops"
)))
//
T
dot
(
const
T
*
x
,
const
T
*
y
,
int
dim
)
{
T
out
{};
for
(
int
i
=
0
;
i
<
dim
;
i
++
)
{
out
+=
x
[
i
]
*
y
[
i
];
}
return
out
;
}
template
<
typename
T
>
void
fc_compute_naive
(
const
T
*
x
,
int
x_w
,
int
x_h
,
//
const
T
*
w
,
int
w_w
,
int
w_h
,
//
void
fc_compute_naive
(
const
T
*
x
,
int
x_h
,
int
x_w
,
//
const
T
*
w
,
int
w_h
,
int
w_w
,
//
const
T
*
b
,
//
T
*
out
)
{
CHECK_EQ
(
x_w
,
w_
w
);
CHECK_EQ
(
x_w
,
w_
h
);
// out shape: (x_h, w_w)
memset
(
out
,
0
,
x_h
*
w_h
*
sizeof
(
T
));
for
(
int
r
=
0
;
r
<
x_h
;
r
++
)
{
for
(
int
c
=
0
;
c
<
w_h
;
c
++
)
{
out
[
r
*
w_h
+
c
]
=
dot
(
&
x
[
r
*
x_w
],
&
w
[
c
*
w_w
],
w_w
)
+
b
[
c
];
memset
(
out
,
0
,
x_h
*
w_w
*
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
x_h
;
i
++
)
{
for
(
int
j
=
0
;
j
<
w_w
;
j
++
)
{
T
tmp
=
static_cast
<
T
>
(
0
);
for
(
int
k
=
0
;
k
<
x_w
;
k
++
)
{
tmp
+=
x
[
i
*
x_w
+
k
]
*
w
[
k
*
w_w
+
j
];
}
out
[
i
*
w_w
+
j
]
=
tmp
+
b
[
j
];
}
}
}
...
...
@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
.
Slice
(
param
.
in_num_col_dims
,
param
.
input
->
dims
().
size
())
.
production
(),
param
.
w
->
data
<
T
>
(),
// w
param
.
w
->
dims
()[
1
],
// w_w
param
.
w
->
dims
()[
0
],
// w_h
param
.
w
->
dims
()[
1
],
// w_w
param
.
bias
->
data
<
T
>
(),
// b
param
.
output
->
mutable_data
<
T
>
());
}
...
...
paddle/fluid/lite/kernels/x86/relu_compute.cc
浏览文件 @
505e450d
...
...
@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
relu
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
ReluCompute
<
float
>
,
def
)
.
BindInput
(
"
Input
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"
X
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
paddle/fluid/lite/model_parser/model_parser.cc
浏览文件 @
505e450d
...
...
@@ -91,7 +91,7 @@ void LoadLoDTensor(std::istream &is, Variable *var) {
auto
*
tensor
=
var
->
GetMutable
<
lite
::
Tensor
>
();
uint32_t
version
{};
is
.
read
(
reinterpret_cast
<
char
*>
(
&
version
),
sizeof
(
version
));
LOG
(
INFO
)
<<
"model version "
<<
version
;
VLOG
(
3
)
<<
"model version "
<<
version
;
// Load LoD information
uint64_t
lod_level
{};
...
...
@@ -154,7 +154,7 @@ void LoadModel(const std::string &model_dir, Scope *scope,
continue
;
std
::
string
file_path
=
model_dir
+
"/"
+
var
.
name
();
LOG
(
INFO
)
<<
"reading weight "
<<
var
.
name
();
VLOG
(
4
)
<<
"reading weight "
<<
var
.
name
();
std
::
ifstream
file
(
file_path
);
switch
(
var
.
type
().
type
())
{
...
...
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library
(
op_params_lite SRCS op_params.cc DEPS
${
tensor_lite
}
any_lite framework_proto_lite
)
cc_library
(
dropout_op_lite SRCS dropout_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
concat_op_lite SRCS concat_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
split_op_lite SRCS split_op.cc DEPS
${
op_DEPS
}
)
#
cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
set
(
ops_lite
conv_op_lite
...
...
@@ -41,7 +41,7 @@ set(ops_lite
activation_ops_lite
dropout_op_lite
concat_op_lite
split_op_lite
#
split_op_lite
PARENT_SCOPE
)
lite_cc_test
(
test_fc_op_lite SRCS fc_op_test.cc
...
...
@@ -56,4 +56,3 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test
(
test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite
)
lite_cc_test
(
test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite
)
lite_cc_test
(
test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite
)
paddle/fluid/lite/operators/conv_op.h
浏览文件 @
505e450d
...
...
@@ -30,25 +30,27 @@ class ConvOpLite : public OpLite {
public:
ConvOpLite
()
{}
explicit
ConvOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
explicit
ConvOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
filter
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
output
=
op_desc
.
Output
(
"Output"
).
front
();
param_
.
x
=
scope
->
FindVar
(
input
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
filter
=
scope
->
FindVar
(
filter
)
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
scope
->
FindVar
(
output
));
param_
.
output
=
scope
->
FindVar
(
output
)
->
GetMutable
<
lite
::
Tensor
>
();
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
X
=
op_desc
.
Input
(
"Input"
).
front
();
auto
Filter
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
Out
=
op_desc
.
Output
(
"Output"
).
front
();
param_
.
x
=
scope
->
FindVar
(
X
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
filter
=
scope
->
FindVar
(
Filter
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
Out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
strides
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"strides"
);
param_
.
paddings
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
);
param_
.
groups
=
op_desc
.
GetAttr
<
int
>
(
"groups"
);
param_
.
dilations
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"dilations"
);
// optional params
std
::
vector
<
std
::
string
>
input_arg_names
=
op_desc
.
InputArgumentNames
();
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"Bias"
)
!=
...
...
@@ -58,7 +60,7 @@ class ConvOpLite : public OpLite {
auto
bias_var
=
scope
->
FindVar
(
bias_arguments
.
front
());
if
(
bias_var
!=
nullptr
)
{
param_
.
bias
=
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
}
}
}
...
...
@@ -68,7 +70,7 @@ class ConvOpLite : public OpLite {
if
(
res_data_arguments
.
size
()
>
0
)
{
auto
residual_data_var
=
scope
->
FindVar
(
res_data_arguments
.
front
());
if
(
residual_data_var
!=
nullptr
)
{
param_
.
residualData
=
const_cast
<
lite
::
Tensor
*>
(
param_
.
residualData
=
const_cast
<
lite
::
Tensor
*>
(
&
(
residual_data_var
->
Get
<
lite
::
Tensor
>
()));
}
}
...
...
@@ -77,7 +79,7 @@ class ConvOpLite : public OpLite {
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"conv2d"
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录