Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
505e450d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
505e450d
编写于
6月 15, 2019
作者:
X
xingzhaolong
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'xzl/incubate/lite' into 'incubate/lite'
merge github code to gitlab. See merge request inference/paddlelite!13
上级
7c02e682
83bd58d4
变更
44
隐藏空白更改
内联
并排
Showing
44 changed file
with
1382 addition
and
176 deletion
+1382
-176
paddle/fluid/lite/CMakeLists.txt
paddle/fluid/lite/CMakeLists.txt
+3
-2
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+2
-5
paddle/fluid/lite/api/cxx_api_bin.cc
paddle/fluid/lite/api/cxx_api_bin.cc
+20
-5
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+3
-1
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+35
-9
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
+37
-0
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
...luid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
+39
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
...fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
...lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
+153
-0
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
+34
-0
paddle/fluid/lite/core/mir/fc_fuse_pass.h
paddle/fluid/lite/core/mir/fc_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
+112
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+22
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
+140
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
+128
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
+57
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
...d/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
+108
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
...id/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
+41
-0
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
+78
-0
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
+38
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+8
-0
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+4
-0
paddle/fluid/lite/core/mir/pattern_matcher.cc
paddle/fluid/lite/core/mir/pattern_matcher.cc
+71
-1
paddle/fluid/lite/core/mir/pattern_matcher.h
paddle/fluid/lite/core/mir/pattern_matcher.h
+17
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
+2
-5
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
+0
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+10
-14
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+33
-42
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+1
-5
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+10
-8
paddle/fluid/lite/core/mir/type_target_transform_pass.h
paddle/fluid/lite/core/mir/type_target_transform_pass.h
+1
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+34
-27
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+1
-1
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+3
-0
paddle/fluid/lite/core/profile/basic_profiler.h
paddle/fluid/lite/core/profile/basic_profiler.h
+1
-1
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+1
-1
paddle/fluid/lite/kernels/x86/conv_compute.cc
paddle/fluid/lite/kernels/x86/conv_compute.cc
+2
-3
paddle/fluid/lite/kernels/x86/fc_compute.cc
paddle/fluid/lite/kernels/x86/fc_compute.cc
+17
-24
paddle/fluid/lite/kernels/x86/relu_compute.cc
paddle/fluid/lite/kernels/x86/relu_compute.cc
+1
-1
paddle/fluid/lite/model_parser/model_parser.cc
paddle/fluid/lite/model_parser/model_parser.cc
+2
-2
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+2
-3
paddle/fluid/lite/operators/conv_op.h
paddle/fluid/lite/operators/conv_op.h
+14
-12
未找到文件。
paddle/fluid/lite/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
message
(
STATUS
"LITE_WITH_PROFILE:
\t
${
LITE_WITH_PROFILE
}
"
)
set
(
LITE_MODEL_DIR
"
${
THIRD_PARTY_PATH
}
/install"
)
set
(
LITE_URL
"http://paddle-inference-dist.bj.bcebos.com"
CACHE STRING
"inference download url"
)
function
(
lite_download_and_uncompress INSTALL_DIR URL FILENAME
)
message
(
STATUS
"Download inference test stuff from
${
URL
}
/
${
FILENAME
}
"
)
...
...
@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
file
(
APPEND
${
offline_test_registry_file
}
"
${
TARGET
}
\n
"
)
endfunction
()
add_subdirectory
(
operators
)
add_subdirectory
(
kernels
)
add_subdirectory
(
core
)
add_subdirectory
(
x86
)
add_subdirectory
(
arm
)
add_subdirectory
(
host
)
add_subdirectory
(
cuda
)
add_subdirectory
(
operators
)
add_subdirectory
(
kernels
)
add_subdirectory
(
model_parser
)
add_subdirectory
(
utils
)
add_subdirectory
(
api
)
...
...
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda
)
endif
()
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS
${
cxx_api_lite_deps
}
${
ops_lite
}
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS
${
cxx_api_lite_deps
}
${
ops_lite
}
program_lite
)
set
(
light_api_deps
scope_lite target_wrapper_host model_parser_lite
)
...
...
@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
message
(
STATUS
"get ARM kernels
${
arm_kernels
}
"
)
include
(
ExternalProject
)
set
(
LITE_URL
"http://paddle-inference-dist.bj.bcebos.com"
CACHE STRING
"inference download url"
)
set
(
LITE_DEMO_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo"
CACHE STRING
"A path setting inference demo download directories."
)
if
((
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
AND WITH_TESTING
)
lite_cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite m
odel_parser_lite target_wrapper_host
DEPS cxx_api_lite m
ir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
PROFILE_DEPS basic_profiler_lite
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_naive_model
--optimized_model=
${
LITE_MODEL_DIR
}
/lite_naive_model_opt SERIAL
)
...
...
@@ -45,7 +43,6 @@ endif()
# lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
# endif()
lite_cc_binary
(
cxx_api_lite_bin SRCS cxx_api_bin.cc
DEPS
cxx_api_lite
...
...
paddle/fluid/lite/api/cxx_api_bin.cc
浏览文件 @
505e450d
...
...
@@ -13,13 +13,22 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
void
Run
(
const
char
*
model_dir
)
{
// Time point type used for latency measurement.
// steady_clock is monotonic, so an interval between two readings can never be
// negative; high_resolution_clock may alias the adjustable system clock on
// some standard libraries and is therefore unsuitable for benchmarking.
using Time = std::chrono::steady_clock::time_point;

// Returns the current reading of the benchmarking clock.
Time time() { return std::chrono::steady_clock::now(); }

// Returns the time elapsed from `t1` to `t2`, in milliseconds.
//
// The interval is first truncated to whole microseconds, so the result has
// microsecond resolution. It is negative when `t2` precedes `t1`.
double time_diff(Time t1, Time t2) {
  using Microseconds = std::chrono::microseconds;
  const Microseconds elapsed =
      std::chrono::duration_cast<Microseconds>(t2 - t1);
  return elapsed.count() / 1000.0;
}
void
Run
(
const
char
*
model_dir
,
int
repeat
)
{
#ifdef LITE_WITH_ARM
DeviceInfo
::
Init
();
#endif
...
...
@@ -34,10 +43,16 @@ void Run(const char* model_dir) {
input_tensor
->
Resize
(
DDim
(
std
::
vector
<
DDim
::
value_type
>
({
1
,
3
,
224
,
224
})));
auto
*
data
=
input_tensor
->
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
input_tensor
->
dims
().
production
();
i
++
)
{
data
[
i
]
=
i
;
data
[
i
]
=
1
;
}
predictor
.
Run
();
for
(
int
i
=
0
;
i
<
10
;
i
++
)
predictor
.
Run
();
auto
time1
=
time
();
for
(
int
i
=
0
;
i
<
repeat
;
i
++
)
predictor
.
Run
();
auto
time2
=
time
();
std
::
cout
<<
" predict cost: "
<<
time_diff
(
time1
,
time2
)
/
repeat
<<
"ms"
<<
std
::
endl
;
auto
*
out
=
predictor
.
GetOutput
(
0
);
LOG
(
INFO
)
<<
out
<<
" memory size "
<<
out
->
data_size
();
...
...
@@ -52,7 +67,7 @@ void Run(const char* model_dir) {
// Entry point. Expects exactly one command-line argument, the model
// directory, and aborts with a usage message otherwise.
int main(int argc, char** argv) {
  CHECK_EQ(argc, 2) << "usage: ./cmd <model_dir>";
  // Run(model_dir, repeat): request a single timed iteration.
  paddle::lite::Run(argv[1], 1);
  return 0;
}
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -30,7 +30,9 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
lite_cc_library
(
program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto
lite_cc_library
(
program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite
)
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -3,8 +3,13 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program_lite)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
add_subdirectory
(
fusion
)
cc_library
(
mir_passes
SRCS static_kernel_pick_pass.cc
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc
...
...
@@ -13,13 +18,8 @@ cc_library(mir_passes
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite
)
DEPS mir_pass types_lite context_lite
${
mir_fusers
}
)
# for mobile, unnecessary to compile the following testings.
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
return
()
endif
()
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
#cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope_lite op_lite
#fc_op_lite
...
...
@@ -52,11 +52,37 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
lite_cc_library
(
pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite
)
# TODO(wz) replace framework/proto to lite proto.
# for mobile, unnecessary to compile the following testings.
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
return
()
endif
()
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
# TODO(wz) replace framework/proto to lite proto.
if
(
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution.
lite_cc_test
(
test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite
${
ops_lite
}
)
endif
()
message
(
STATUS
"----> Ops lite:
${
ops_lite
}
"
)
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
lite_cc_test
(
test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
arm_kernels
}
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
--optimized_model=
${
LITE_MODEL_DIR
}
/lite_fc_model_opt SERIAL
)
lite_download_and_uncompress
(
${
LITE_MODEL_DIR
}
${
LITE_URL
}
"lite_fc_model.tar.gz"
)
add_dependencies
(
test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz
)
lite_cc_test
(
test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvBNFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvBNFuser
fuser
(
"conv2d"
);
fuser
(
graph
.
get
());
fusion
::
ConvBNFuser
fuser2
(
"depthwise_conv2d"
);
fuser2
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_bn_fuse_pass
,
paddle
::
lite
::
mir
::
ConvBNFusePass
);
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {

// A ProgramPass whose only customization is the Apply() override below.
class ConvBNFusePass : public ProgramPass {
 public:
  // Runs the pass on `graph`. The graph is accessed through the caller's
  // unique_ptr; ownership is not transferred.
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAddReLUFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAddReLUFuser
fuser
(
"conv2d"
);
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
depthwise_fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAddReLUFusePass
);
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {

// A ProgramPass whose only customization is the Apply() override below.
class ConvElementwiseAddReLUFusePass : public ProgramPass {
 public:
  // Runs the pass on `graph`. The graph is accessed through the caller's
  // unique_ptr; ownership is not transferred.
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
// gflags string flags; both default to "" and carry an empty help text.
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fills block 0 of `program_desc` with a small two-branch network and builds
// an SSAGraph from it:
//
//   input_1 -> conv2d_1 -> add_1 (+ bias_1) -> relu_1 --+
//                                                       v
//   input_2 -> conv2d_2 ------------------------------> add_2 -> relu_2 -> out
//
// Every variable is declared in the block and also created in `scope` as a
// lite::Tensor. The returned graph is built for `valid_places`.
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
                                     const std::shared_ptr<Scope>& scope,
                                     const std::vector<Place>& valid_places) {
  auto* block = program_desc->MutableBlock(0);

  // Append order matters for the resulting program: both convolutions first,
  // then the add/relu pairs.
  auto* conv2d_1 = block->AppendOp();
  auto* conv2d_2 = block->AppendOp();
  auto* add_1 = block->AppendOp();
  auto* relu_1 = block->AppendOp();
  auto* add_2 = block->AppendOp();
  auto* relu_2 = block->AppendOp();

  const char* const var_names[] = {
      "input_1",      "input_2", "filter_1",  "filter_2",  "conv2d_1_out",
      "conv2d_2_out", "bias_1",  "add_1_out", "add_2_out", "relu_1_out",
      "out"};

  // Declare all variables in the block, then materialize them in the scope.
  for (const char* name : var_names) {
    block->Var(name);
  }
  for (const char* name : var_names) {
    scope->Var(name)->GetMutable<lite::Tensor>();
  }

  using OpDescPtr = decltype(conv2d_1);

  // conv2d with strides {1, 1}, paddings {0, 0}, groups 1, dilations {1, 1}
  // and fuse_relu disabled.
  const auto configure_conv2d = [](OpDescPtr op, const char* input,
                                   const char* filter, const char* output) {
    op->SetType("conv2d");
    op->SetInput("Input", {input});
    op->SetInput("Filter", {filter});
    op->SetOutput("Output", {output});
    op->SetAttr("strides", std::vector<int>({1, 1}));
    op->SetAttr("paddings", std::vector<int>({0, 0}));
    op->SetAttr("groups", 1);
    op->SetAttr("dilations", std::vector<int>({1, 1}));
    op->SetAttr("fuse_relu", false);
  };

  // elementwise_add of `x` and `y` with axis = 1.
  const auto configure_add = [](OpDescPtr op, const char* x, const char* y,
                                const char* output) {
    op->SetType("elementwise_add");
    op->SetInput("X", {x});
    op->SetInput("Y", {y});
    op->SetOutput("Out", {output});
    op->SetAttr("axis", 1);
  };

  // relu from `x` to `output`.
  const auto configure_relu = [](OpDescPtr op, const char* x,
                                 const char* output) {
    op->SetType("relu");
    op->SetInput("X", {x});
    op->SetOutput("Out", {output});
  };

  // First branch: conv2d_1 -> add_1 -> relu_1.
  configure_conv2d(conv2d_1, "input_1", "filter_1", "conv2d_1_out");
  configure_add(add_1, "conv2d_1_out", "bias_1", "add_1_out");
  configure_relu(relu_1, "add_1_out", "relu_1_out");

  // Second branch: conv2d_2 joined with relu_1's output by add_2 -> relu_2.
  configure_conv2d(conv2d_2, "input_2", "filter_2", "conv2d_2_out");
  configure_add(add_2, "conv2d_2_out", "relu_1_out", "add_2_out");
  configure_relu(relu_2, "add_2_out", "out");

  program_desc->Flush();

  lite::Program program(*program_desc->Proto(), scope, valid_places);
  std::unique_ptr<SSAGraph> graph(new SSAGraph());
  graph->Build(program, valid_places);
  return graph;
}
// Checks that the graph produced by BuildGraph() has one node per variable
// and one per op before any pass is applied.
TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
  constexpr auto kVarNodes = 11UL;
  constexpr auto kOpNodes = 6UL;

  framework::ProgramDesc desc;
  std::vector<Place> host_places{{TARGET(kHost), PRECISION(kFloat)}};
  auto tensor_scope = std::make_shared<Scope>();
  auto ssa_graph = BuildGraph(&desc, tensor_scope, host_places);

  Visualize(ssa_graph.get());
  ASSERT_EQ(ssa_graph->nodes().size(), kVarNodes + kOpNodes);
  Visualize(ssa_graph.get());
}
// Applies ConvElementwiseAddReLUFusePass to the graph from BuildGraph() and
// checks how many nodes remain afterwards.
TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
  framework::ProgramDesc program_desc;
  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
  auto scope = std::make_shared<Scope>();
  auto graph = BuildGraph(&program_desc, scope, places);
  Visualize(graph.get());

  // Keep the count as size_t so it is not narrowed before the comparison.
  const size_t num_nodes = graph->nodes().size();

  // Stack-allocated so the pass is destroyed when the test body ends.
  ConvElementwiseAddReLUFusePass fuser;
  fuser.Apply(graph);
  Visualize(graph.get());

  // The expectation is written per conv branch (two branches in the graph).
  ASSERT_EQ(graph->nodes().size(),
            num_nodes - 5UL * 2 /* nodes removed */ +
                1UL * 2 /* fused op node */);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
// Ops referenced by this test binary.
// NOTE(review): presumably USE_LITE_OP forces the named op's registration to
// be linked in -- confirm against op_registry.h.
USE_LITE_OP(elementwise_add);
USE_LITE_OP(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(relu);
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
FcFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
FcFuser
fuser
;
fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_fc_fuse_pass
,
paddle
::
lite
::
mir
::
FcFusePass
);
paddle/fluid/lite/core/mir/fc_fuse_pass.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {

// A ProgramPass whose only customization is the Apply() override below.
class FcFusePass : public ProgramPass {
 public:
  // Runs the pass on `graph`. The graph is accessed through the caller's
  // unique_ptr; ownership is not transferred.
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
// gflags string flags read by the tests below as FLAGS_model_dir and
// FLAGS_optimized_model; both default to "" and carry an empty help text.
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace
paddle
{
namespace
lite
{
namespace
mir
{
// Builds a predictor from FLAGS_model_dir, feeds it a 100x100 float input
// holding 0, 1, 2, ... and checks the first two output values and the output
// dimensions.
TEST(fc_fuse_pass, fuse_test) {
  lite::ExecutorLite predictor;

#ifndef LITE_WITH_CUDA
  std::vector<Place> valid_places({
      Place{TARGET(kHost), PRECISION(kFloat)},
      Place{TARGET(kX86), PRECISION(kFloat)},
  });
#else
  std::vector<Place> valid_places({
      Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
      Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
      Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
      Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
      Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
      Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
  });
#endif

  predictor.Build(FLAGS_model_dir,
                  Place{TARGET(kX86), PRECISION(kFloat)},  // origin cuda
                  valid_places);

  // Fill the input with its own flat index.
  auto* input = predictor.GetInput(0);
  input->Resize(DDim(std::vector<DDim::value_type>({100, 100})));
  auto* input_values = input->mutable_data<float>();
  for (int idx = 0; idx < 100 * 100; idx++) {
    input_values[idx] = idx;
  }

  predictor.Run();

  auto* out = predictor.GetOutput(0);
  LOG(INFO) << out << " memory size " << out->data_size();
  LOG(INFO) << "out " << out->data<float>()[0];
  LOG(INFO) << "out " << out->data<float>()[1];
  LOG(INFO) << "dims " << out->dims();

  // Reference values for the model at FLAGS_model_dir.
  EXPECT_NEAR(out->data<float>()[0], 38.120617f, 1e-5);
  EXPECT_NEAR(out->data<float>()[1], 10.109812f, 1e-5);
  CHECK_EQ(out->dims()[0], 100);
  CHECK_EQ(out->dims()[1], 500);
}
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// Builds a predictor from FLAGS_model_dir for host + x86 float places and
// saves the result to FLAGS_optimized_model. Compiled only when
// LITE_WITH_LIGHT_WEIGHT_FRAMEWORK is not defined.
TEST(fc_fuse_pass, save_model_test) {
  lite::ExecutorLite predictor;
  std::vector<Place> valid_places({
      Place{TARGET(kHost), PRECISION(kFloat)},
      Place{TARGET(kX86), PRECISION(kFloat)},
  });

  predictor.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
                  valid_places);

  LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
  predictor.SaveModel(FLAGS_optimized_model);
}
#endif  // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
}
// namespace mir
}
// namespace lite
}
// namespace paddle
// Ops referenced by this test binary.
// NOTE(review): presumably USE_LITE_OP / USE_LITE_KERNEL force the named
// registrations to be linked in -- confirm against op_registry.h.
USE_LITE_OP(mul);
USE_LITE_OP(elementwise_add);
USE_LITE_OP(elementwise_sub);
USE_LITE_OP(fc);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_OP(softmax);
USE_LITE_OP(scale);

// Host kernels, referenced unconditionally.
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);

// x86 float/NCHW kernels.
#ifdef LITE_WITH_X86
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
#endif

// CUDA kernels, including both io_copy variants.
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
0 → 100644
浏览文件 @
505e450d
# One library per fuser; each depends on pattern_matcher_high_api.
cc_library(fuse_fc
    SRCS fc_fuser.cc
    DEPS pattern_matcher_high_api)

cc_library(fuse_conv_elementwise_add_relu
    SRCS conv_elementwise_add_relu_fuser.cc
    DEPS pattern_matcher_high_api)

cc_library(fuse_conv_bn
    SRCS conv_bn_fuser.cc
    DEPS pattern_matcher_high_api)

# Cached list of the fuser libraries defined above.
set(mir_fusers
    fuse_fc
    fuse_conv_elementwise_add_relu
    fuse_conv_bn
    CACHE INTERNAL "fusers")

# Nothing below this point is processed for the light-weight framework build.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
    return()
endif()

lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
    DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api)
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/program.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fills `program_desc` with a two-op program (conv2d followed by batch_norm),
// creates the matching variables in `scope`, and returns the SSAGraph built
// from that program for `valid_places`.
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
                                     const std::shared_ptr<Scope>& scope,
                                     const std::vector<Place>& valid_places) {
  auto* block = program_desc->MutableBlock(0);
  auto* conv = block->AppendOp();
  auto* bn = block->AppendOp();

  // Declare every variable used by the two ops in the block description.
  for (const char* var_name :
       {"conv_i", "conv_param", "conv_out", "bn_scale", "bn_bias", "bn_mean",
        "bn_var", "bn_out", "bn_mean_out", "bn_var_out", "bn_saved_mean",
        "bn_saved_var"}) {
    block->Var(var_name);
  }

  // Creates the scope variable and materializes it as a lite::Tensor.
  auto touch_tensor = [&scope](const char* var_name) {
    return scope->Var(var_name)->GetMutable<lite::Tensor>();
  };
  // Same as above, and additionally shapes the tensor and requests float
  // storage for it.
  auto alloc_float_tensor = [&touch_tensor](
      const char* var_name, const std::vector<int64_t>& shape) {
    auto* tensor = touch_tensor(var_name);
    tensor->Resize(lite::DDim(shape));
    tensor->mutable_data<float>();
  };

  const std::vector<int64_t> filter_shape = {3, 1, 2, 2};
  const std::vector<int64_t> per_channel_shape = {3};

  touch_tensor("conv_i");
  alloc_float_tensor("conv_param", filter_shape);
  touch_tensor("conv_out");
  for (const char* var_name : {"bn_scale", "bn_bias", "bn_mean", "bn_var"}) {
    alloc_float_tensor(var_name, per_channel_shape);
  }
  for (const char* var_name : {"bn_out", "bn_mean_out", "bn_var_out",
                               "bn_saved_mean", "bn_saved_var"}) {
    touch_tensor(var_name);
  }

  // conv2d: conv_i * conv_param -> conv_out.
  conv->SetType("conv2d");
  conv->SetInput("Input", {"conv_i"});
  conv->SetInput("Filter", {"conv_param"});
  conv->SetOutput("Output", {"conv_out"});
  const std::vector<int> strides({1, 1});
  const std::vector<int> paddings({1, 1});
  const std::vector<int> dilations({1, 1});
  const int groups = 1;
  conv->SetAttr("strides", strides);
  conv->SetAttr("paddings", paddings);
  conv->SetAttr("dilations", dilations);
  conv->SetAttr("groups", groups);
  conv->SetAttr("fuse_relu", false);

  // batch_norm consuming conv_out.
  bn->SetType("batch_norm");
  bn->SetInput("X", {"conv_out"});
  bn->SetInput("Bias", {"bn_bias"});
  bn->SetInput("Mean", {"bn_mean"});
  bn->SetInput("Scale", {"bn_scale"});
  bn->SetInput("Variance", {"bn_var"});
  bn->SetOutput("Y", {"bn_out"});
  bn->SetOutput("MeanOut", {"bn_mean_out"});
  bn->SetOutput("VarianceOut", {"bn_var_out"});
  bn->SetOutput("SavedMean", {"bn_saved_mean"});
  bn->SetOutput("SavedVariance", {"bn_saved_var"});
  float eps = 1e-5;
  bn->SetAttr("epsilon", eps);
  bn->SetAttr("is_test", static_cast<int>(1));
  bn->SetAttr("use_global_stats", false);
  bn->SetAttr("momentum", 0.9f);
  bn->SetAttr("data_layout", std::string("NCHW"));

  program_desc->Flush();

  lite::Program program(*program_desc->Proto(), scope, valid_places);
  std::unique_ptr<SSAGraph> graph(new SSAGraph());
  graph->Build(program, valid_places);
  return graph;
}
// Applies ConvBNFusePass to the conv2d -> batch_norm graph from BuildGraph()
// and checks the resulting change in node count.
TEST(pattern_matcher2, test) {
  framework::ProgramDesc program_desc;
  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
  auto scope = std::make_shared<Scope>();
  auto graph = BuildGraph(&program_desc, scope, places);

  // Keep the count unsigned so it has the same type as nodes().size() and
  // the UL literals in the assertion below.
  const size_t num_nodes = graph->nodes().size();

  // Stack-allocated so the pass is destroyed at scope exit (no leak).
  ConvBNFusePass fuser;
  fuser.Apply(graph);

  ASSERT_EQ(graph->nodes().size(),
            num_nodes - 8UL /*nodes removed */ + 1UL /* eltwise_add node*/);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
// Ops referenced by this test: the two ops in the input graph plus the
// elementwise_add op that the fuser creates.
USE_LITE_OP(conv2d);
USE_LITE_OP(batch_norm);
USE_LITE_OP(elementwise_add);
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Describes the subgraph to match: a `conv_type_` op whose "Output" feeds
// the "X" input of a batch_norm op, together with all batch_norm inputs and
// outputs. Nodes marked AsIntermediate() are the ones the fuser removes.
void ConvBNFuser::BuildPattern() {
  const char kBatchNorm[] = "batch_norm";

  // Convolution side.
  auto* input_var = VarNode("conv_input")
                        ->assert_is_op_input(conv_type_, "Input")
                        ->AsInput();
  auto* filter_var = VarNode("conv_weight")
                         ->assert_is_op_input(conv_type_, "Filter")
                         ->AsInput();
  auto* conv_op = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
  auto* conv_result = VarNode("conv_out")
                          ->assert_is_op_output(conv_type_, "Output")
                          ->assert_is_op_input(kBatchNorm, "X");

  // batch_norm inputs. The bias stays in the graph (AsInput); scale, mean
  // and variance are intermediate.
  auto* scale_var = VarNode("bn_scale")
                        ->assert_is_op_input(kBatchNorm, "Scale")
                        ->AsIntermediate();
  auto* bias_var =
      VarNode("bn_bias")->assert_is_op_input(kBatchNorm, "Bias")->AsInput();
  auto* mean_var = VarNode("bn_mean")
                       ->assert_is_op_input(kBatchNorm, "Mean")
                       ->AsIntermediate();
  auto* variance_var = VarNode("bn_variance")
                           ->assert_is_op_input(kBatchNorm, "Variance")
                           ->AsIntermediate();

  // batch_norm op and its outputs; only "Y" is kept as the pattern output.
  auto* bn_op =
      OpNode("bn", kBatchNorm)->assert_is_op(kBatchNorm)->AsIntermediate();
  auto* y_var =
      VarNode("bn_out")->assert_is_op_output(kBatchNorm, "Y")->AsOutput();
  auto* mean_out_var = VarNode("bn_mean_out")
                           ->assert_is_op_output(kBatchNorm, "MeanOut")
                           ->AsIntermediate();
  auto* variance_out_var = VarNode("bn_var_out")
                               ->assert_is_op_output(kBatchNorm, "VarianceOut")
                               ->AsIntermediate();
  auto* saved_mean_var = VarNode("bn_saved_mean")
                             ->assert_is_op_output(kBatchNorm, "SavedMean")
                             ->AsIntermediate();
  auto* saved_variance_var =
      VarNode("bn_saved_var")
          ->assert_is_op_output(kBatchNorm, "SavedVariance")
          ->AsIntermediate();

  // Topology.
  conv_op->LinksFrom({input_var, filter_var}).LinksTo({conv_result});
  bn_op->LinksFrom({conv_result, scale_var, bias_var, mean_var, variance_var})
      .LinksTo({y_var, mean_out_var, saved_mean_var, saved_variance_var,
                variance_out_var});
}
// Rewrites one matched conv -> batch_norm subgraph: folds the batch_norm
// statistics into the conv filter and the BN bias tensors in place (see
// ComputeFusedWeight), then inserts an elementwise_add op that adds the
// updated bias to the conv output and writes the original batch_norm output.
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
  auto op_desc = GenOpDesc(matched);
  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");

  auto conv = matched.at("conv2d")->stmt()->op;
  auto* scope = conv->scope();
  auto& valid_places = conv->valid_places();

  // Returns the tensor behind a matched variable node, failing with a clear
  // message (instead of a null dereference) when the scope lacks it.
  auto find_tensor = [&matched, scope](const std::string& key) {
    const std::string& var_name = matched.at(key)->arg()->name;
    auto* var = scope->FindVar(var_name);
    CHECK(var) << "Variable '" << var_name << "' matched as '" << key
               << "' was not found in the scope";
    return var->GetMutable<lite::Tensor>();
  };

  auto conv_weight_t = find_tensor("conv_weight");
  auto conv_weight_d = conv_weight_t->mutable_data<float>();
  auto conv_weight_dims = conv_weight_t->dims();
  size_t weight_num = conv_weight_t->data_size();

  auto bn_scale_t = find_tensor("bn_scale");
  size_t bias_size = bn_scale_t->data_size();
  auto bn_scale_d = bn_scale_t->mutable_data<float>();

  // bias_size is used as a divisor below, so reject empty BN parameters.
  CHECK(bias_size > 0) << "The BN scale tensor should not be empty";
  CHECK(bias_size == static_cast<size_t>(conv_weight_dims[0]))
      << "The BN bias's size should be equal to the size of the first "
      << "dim size of the conv weights";

  auto bn_mean_d = find_tensor("bn_mean")->mutable_data<float>();
  auto bn_var_d = find_tensor("bn_variance")->mutable_data<float>();
  auto bn_bias_d = find_tensor("bn_bias")->mutable_data<float>();

  auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");

  // The filter is treated as a [bias_size x (weight_num / bias_size)] matrix.
  // ComputeFusedWeight takes int extents, so make the narrowing explicit.
  ComputeFusedWeight(bn_scale_d, bn_mean_d, bn_var_d, bn_bias_d, conv_weight_d,
                     eps, static_cast<int>(bias_size),
                     static_cast<int>(weight_num / bias_size));

  eltwise_op->Attach(op_desc, scope);
  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);

  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
}
// Builds the description of the elementwise_add op that replaces batch_norm:
// X = conv output, Y = BN bias, Out = former batch_norm output, axis = 1.
cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
  // Name of the graph variable bound to a pattern key.
  auto var_name = [&matched](const std::string& key) {
    return matched.at(key)->arg()->name;
  };

  cpp::OpDesc add_desc;
  add_desc.SetType("elementwise_add");
  add_desc.SetInput("X", {var_name("conv_out")});
  add_desc.SetInput("Y", {var_name("bn_bias")});
  add_desc.SetOutput("Out", {var_name("bn_out")});
  add_desc.SetAttr("axis", 1);
  return add_desc;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <memory>
#include <string>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fuser for a `conv_type` op followed by batch_norm. The batch_norm op is
// replaced by an elementwise_add, and its statistics are folded into the
// conv filter and the BN bias (see ComputeFusedWeight).
class ConvBNFuser : public FuseBase {
 public:
  // `conv_type` is the op type of the convolution to match.
  explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

  // Folds the batch_norm parameters into the conv weights and the BN bias.
  // For each of the `h` rows i:
  //   alpha            = scale_d[i] / sqrt(var_d[i] + eps)
  //   var_d[i]         = alpha            (the variance is overwritten)
  //   bias_d[i]       += -mean_d[i] * alpha
  //   conv_weight_d[i * w + j] *= alpha   for j in [0, w)
  // `conv_weight_d` is indexed as a row-major h x w matrix.
  void ComputeFusedWeight(float* scale_d, float* mean_d, float* var_d,
                          float* bias_d, float* conv_weight_d, float eps, int h,
                          int w) {
    for (int i = 0; i < h; ++i) {
      const float alpha = scale_d[i] / std::sqrt(var_d[i] + eps);
      var_d[i] = alpha;
      bias_d[i] += (-mean_d[i]) * alpha;

      // Compute the row offset in size_t so that i * w cannot overflow int.
      float* row = conv_weight_d + static_cast<size_t>(i) * w;
      for (int j = 0; j < w; ++j) {
        row[j] *= alpha;
      }
    }
  }

 private:
  std::string conv_type_{"conv2d"};
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
ConvElementwiseAddReLUFuser
::
BuildPattern
()
{
// create input nodes.
auto
*
input
=
VarNode
(
"input"
)
->
assert_is_op_input
(
conv_type_
,
"Input"
)
->
AsInput
();
auto
*
filter
=
VarNode
(
"filter"
)
->
assert_is_op_input
(
conv_type_
,
"Filter"
)
->
AsInput
();
auto
*
bias
=
VarNode
(
"bias"
)
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
AsInput
();
// create op nodes
auto
*
conv2d
=
OpNode
(
"conv2d"
,
conv_type_
)
->
assert_is_op
(
conv_type_
)
->
AsIntermediate
();
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
relu
=
OpNode
(
"relu"
,
"relu"
)
->
assert_is_op
(
"relu"
)
->
AsIntermediate
();
// create intermediate nodes
auto
*
conv2d_out
=
VarNode
(
"conv2d_out"
)
->
assert_is_op_output
(
conv_type_
,
"Output"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
AsIntermediate
();
auto
*
add_out
=
VarNode
(
"add_out"
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
// create output node
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
AsOutput
();
// create topology.
std
::
vector
<
PMNode
*>
conv2d_inputs
{
filter
,
input
};
std
::
vector
<
PMNode
*>
add_inputs
{
conv2d_out
,
bias
};
conv2d_inputs
>>
*
conv2d
>>
*
conv2d_out
;
add_inputs
>>
*
add
>>
*
add_out
;
*
add_out
>>
*
relu
>>
*
out
;
}
// Creates the fused `conv_type_` op for one match, attaches it to the scope
// of the matched conv op and wires it to the input, filter, bias and output
// variable nodes.
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
                                                const key2nodes_t& matched) {
  auto fused_desc = GenOpDesc(matched);
  auto fused_conv = LiteOpRegistry::Global().Create(conv_type_);

  // Scope and valid places are taken over from the matched conv op.
  auto matched_conv = matched.at("conv2d")->stmt()->op;
  auto* scope = matched_conv->scope();
  auto& valid_places = matched_conv->valid_places();
  fused_conv->Attach(fused_desc, scope);

  auto* fused_node = graph->GraphCreateInstructNode(fused_conv, valid_places);

  IR_NODE_LINK_TO(matched.at("input"), fused_node);
  IR_NODE_LINK_TO(matched.at("filter"), fused_node);
  IR_NODE_LINK_TO(matched.at("bias"), fused_node);
  IR_NODE_LINK_TO(fused_node, matched.at("output"));
}
// Builds the description of the fused conv op: inputs/outputs come from the
// matched variables, the convolution attributes are copied from the matched
// conv op, and fuse_relu is set to true.
cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
  auto* conv_info = matched.at("conv2d")->stmt()->op_info();

  // Name of the graph variable bound to a pattern key.
  auto var_name = [&matched](const std::string& key) {
    return matched.at(key)->arg()->name;
  };

  cpp::OpDesc fused;
  fused.SetType(conv_type_);
  fused.SetInput("Input", {var_name("input")});
  fused.SetInput("Filter", {var_name("filter")});
  fused.SetInput("Bias", {var_name("bias")});
  fused.SetOutput("Output", {var_name("output")});

  // Other inputs. See operators/conv_op.h
  std::vector<std::string> conv_inputs = conv_info->InputArgumentNames();
  const bool has_residual =
      std::find(conv_inputs.begin(), conv_inputs.end(), "ResidualData") !=
      conv_inputs.end();
  if (has_residual) {
    fused.SetInput("ResidualData", conv_info->Input("ResidualData"));
  }

  // Only strides, paddings, groups, dilations and fuse_relu are carried over.
  fused.SetAttr("strides", conv_info->GetAttr<std::vector<int>>("strides"));
  fused.SetAttr("paddings", conv_info->GetAttr<std::vector<int>>("paddings"));
  fused.SetAttr("groups", conv_info->GetAttr<int>("groups"));
  fused.SetAttr("dilations", conv_info->GetAttr<std::vector<int>>("dilations"));
  fused.SetAttr("fuse_relu", true);
  return fused;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fuser for the chain `conv_type` op -> elementwise_add -> relu; the match
// is replaced by a single `conv_type` op (see the .cc file).
class ConvElementwiseAddReLUFuser : public FuseBase {
 public:
  // `conv_type` is the op type of the convolution to match and to emit.
  explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
      : conv_type_{conv_type} {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

  // Op type of the convolution handled by this fuser.
  std::string conv_type_;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
FcFuser
::
BuildPattern
()
{
// create nodes.
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"mul"
,
"X"
);
auto
*
W
=
VarNode
(
"W"
)
->
assert_is_op_input
(
"mul"
,
"Y"
);
auto
*
b
=
VarNode
(
"b"
);
auto
*
mul
=
OpNode
(
"mul"
,
"mul"
);
auto
*
mul_out
=
VarNode
(
"mul_out"
);
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
);
auto
*
Out
=
VarNode
(
"Out"
);
// create topology.
std
::
vector
<
PMNode
*>
mul_inputs
{
W
,
x
};
std
::
vector
<
PMNode
*>
add_inputs
{
mul_out
,
b
};
mul_inputs
>>
*
mul
>>
*
mul_out
;
add_inputs
>>
*
add
>>
*
Out
;
// Some op specialities.
mul_out
->
AsIntermediate
();
mul
->
AsIntermediate
();
add
->
AsIntermediate
();
}
// Creates the "fc" op for one match, attaches it to the scope of the matched
// mul op and wires it to the W, x, b and Out variable nodes.
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
  auto fc_desc = GenOpDesc(matched);
  auto fc = LiteOpRegistry::Global().Create("fc");

  // Scope and valid places are taken over from the matched mul op.
  auto matched_mul = matched.at("mul")->stmt()->op;
  auto* scope = matched_mul->scope();
  auto& valid_places = matched_mul->valid_places();
  fc->Attach(fc_desc, scope);

  auto* fc_node = graph->GraphCreateInstructNode(fc, valid_places);

  IR_NODE_LINK_TO(matched.at("W"), fc_node);
  IR_NODE_LINK_TO(matched.at("x"), fc_node);
  IR_NODE_LINK_TO(matched.at("b"), fc_node);
  IR_NODE_LINK_TO(fc_node, matched.at("Out"));
}
// Builds the description of the fused "fc" op. Input/W/Bias/Out are bound to
// the matched x/W/b/Out variables, and in_num_col_dims is copied from the
// matched mul op's x_num_col_dims attribute.
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
  // Name of the graph variable bound to a pattern key.
  auto var_name = [&matched](const std::string& key) {
    return matched.at(key)->arg()->name;
  };

  cpp::OpDesc fc_desc;
  fc_desc.SetType("fc");
  fc_desc.SetInput("Input", {var_name("x")});
  fc_desc.SetInput("W", {var_name("W")});
  fc_desc.SetInput("Bias", {var_name("b")});
  fc_desc.SetOutput("Out", {var_name("Out")});

  auto* mul_info = matched.at("mul")->stmt()->op_info();
  fc_desc.SetAttr("in_num_col_dims", mul_info->GetAttr<int>("x_num_col_dims"));
  return fc_desc;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
0 → 100644
浏览文件 @
505e450d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// Fuser that replaces a mul op followed by an elementwise_add op with a
// single "fc" op (see fc_fuser.cc for the pattern and the rewrite).
class FcFuser : public FuseBase {
 public:
  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  // Description of the fused "fc" op for one match.
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
505e450d
...
...
@@ -28,7 +28,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for
(
auto
&
item
:
graph
->
StmtTopologicalOrder
())
{
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
LOG
(
INFO
)
<<
stmt
;
VLOG
(
4
)
<<
stmt
;
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
}
}
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
505e450d
...
...
@@ -71,12 +71,20 @@ class Node {
// Data carried by an argument (variable) node.
struct Arg {
  std::string name;
  int id = 0;
  const Type* type = nullptr;
  // Weight is a special kind of argument; it is marked explicitly so that
  // weight-related optimizations can take place.
  bool is_weight = false;
};
// Turns this node into an argument node via AsArg(), then records `name` and
// `id` on it; returns the argument payload.
Arg& AsArg(const std::string& name, int id) {
  auto& arg = AsArg();
  arg.id = id;
  arg.name = name;
  return arg;
}
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
505e450d
...
...
@@ -31,3 +31,7 @@ USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS
(
argument_type_display_pass
);
#endif
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
);
paddle/fluid/lite/core/mir/pattern_matcher.cc
浏览文件 @
505e450d
...
...
@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
return
*
this
;
}
void
operator
>>
(
std
::
vector
<
PMNode
*>
&
others
,
PMNode
&
me
)
{
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>
&
others
,
PMNode
&
me
)
{
for
(
auto
*
o
:
others
)
{
*
o
>>
me
;
}
return
me
;
}
PMNode
*
PMPattern
::
NewNode
(
const
std
::
string
&
name
)
{
...
...
@@ -406,6 +407,67 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
return
this
;
}
// Returns true when `var` is the `nth` entry of `op`'s output list named
// `argument`; false when the list has no such index or the names differ.
// `var` must be an argument node and `op` a statement node.
bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
                 size_t nth) {
  CHECK(var->IsArg());
  CHECK(op->IsStmt());
  auto info = op->stmt()->op_info();
  const bool index_in_range = nth < info->Output(argument).size();
  return index_in_range && var->arg()->name == info->Output(argument)[nth];
}
// Returns true when `var` is the `nth` entry of `op`'s input list named
// `argument`; false when the list has no such index or the names differ.
// `var` must be an argument node and `op` a statement node.
bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
                size_t nth) {
  CHECK(var->IsArg());
  CHECK(op->IsStmt());
  auto info = op->stmt()->op_info();
  const bool index_in_range = nth < info->Input(argument).size();
  return index_in_range && var->arg()->name == info->Input(argument)[nth];
}
// Requires the node to be a variable that is the first (index 0) entry of
// input `argument` of some op of type `op_type`.
PMNode *PMNode::assert_is_op_input(const std::string &op_type,
                                   const std::string &argument) {
  const int kFirst = 0;
  assert_is_var();
  assert_is_op_nth_input(op_type, argument, kFirst);
  return this;
}
// Requires the node to be a variable that is an input of an `op_type` op and,
// in addition, the `nth` entry of that op's input list named `argument`.
PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
                                       const std::string &argument, int nth) {
  assert_is_var();
  assert_is_op_input(op_type);
  asserts_.emplace_back([=](const Node *var) {
    // Accept as soon as one consumer op satisfies every condition.
    for (auto *consumer : var->outlinks) {
      if (!consumer || !consumer->IsStmt()) continue;
      if (!(consumer->stmt()->op_info()->Type() == op_type)) continue;
      if (IsNthInput(var, consumer, argument, nth)) return true;
    }
    return false;
  });
  return this;
}
// Requires the node to be a variable that is the first (index 0) entry of
// output `argument` of some op of type `op_type`.
PMNode *PMNode::assert_is_op_output(const std::string &op_type,
                                    const std::string &argument) {
  const int kFirst = 0;
  assert_is_var();
  assert_is_op_nth_output(op_type, argument, kFirst);
  return this;
}
// Requires the node to be a variable that is the `nth` entry of the output
// list named `argument` of some producing op of type `op_type`.
PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type,
                                        const std::string &argument, int nth) {
  assert_is_var();
  asserts_.emplace_back([=](const Node *var) {
    // Accept as soon as one producer op satisfies every condition.
    for (auto *producer : var->inlinks) {
      if (!producer || !producer->IsStmt()) continue;
      if (!(producer->stmt()->op_info()->Type() == op_type)) continue;
      if (IsNthOutput(var, producer, argument, nth)) return true;
    }
    return false;
  });
  return this;
}
PMNode
*
PMNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
const
Node
*
x
)
{
...
...
@@ -422,6 +484,14 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return
this
;
}
// Returns true when the statement node `op` lists `argument` among its input
// argument names. `op` must be a statement node.
bool HasInput(const Node &op, const std::string &argument) {
  CHECK(op.IsStmt());
  auto const &arg_names = op.stmt()->op_info()->input_argnames();
  return std::find(arg_names.begin(), arg_names.end(), argument) !=
         arg_names.end();
}
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>
&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
...
...
paddle/fluid/lite/core/mir/pattern_matcher.h
浏览文件 @
505e450d
...
...
@@ -62,7 +62,7 @@ struct PMNode {
PMNode
&
operator
>>
(
PMNode
&
right
);
// Link many nodes to this node.
friend
void
operator
>>
(
std
::
vector
<
PMNode
*>&
others
,
PMNode
&
me
);
friend
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>&
others
,
PMNode
&
me
);
// Link this to many other nodes.
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>&
nodes
);
...
...
@@ -127,6 +127,15 @@ struct PMNode {
PMNode
*
assert_is_persistable_var
();
PMNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PMNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PMNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PMNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
template
<
typename
T
>
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
...
...
@@ -297,6 +306,13 @@ class PatternMatcher {
std
::
unordered_map
<
const
PMNode
*
,
std
::
unordered_set
<
Node
*>>
pmnodes2nodes_
;
};
// Check whether a var node is an op node's nth input / nth output under the
// argument name `argument`. These signatures match the definitions in
// pattern_matcher.cc (pointer parameters, size_t index).
bool IsNthInput(const Node* var, const Node* op, const std::string& argument,
                size_t nth);
bool IsNthOutput(const Node* var, const Node* op, const std::string& argument,
                 size_t nth);
// NOTE(review): by-reference overload kept for backward compatibility. No
// definition with this signature is visible in pattern_matcher.cc -- confirm
// whether it can be removed.
bool IsNthInput(const Node& var, const Node& op, const std::string& argument,
                int nth);
// Check whether the op node has input of given name.
bool
HasInput
(
const
Node
&
op
,
const
std
::
string
&
argument
);
// Graph safely remove some nodes, will automatically clean up the edges.
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
);
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
浏览文件 @
505e450d
...
...
@@ -20,7 +20,7 @@ namespace lite {
namespace
mir
{
void
FuseBase
::
PerformPatternMatcher
(
SSAGraph
*
graph
)
{
LOG
(
INFO
)
<<
"
\n
"
<<
matcher_
.
pattern
().
DotString
();
VLOG
(
4
)
<<
"
\n
"
<<
matcher_
.
pattern
().
DotString
();
// Get subgraphs and record the mir::Node pointers for each PMNode.
auto
handler
=
[
&
](
const
PatternMatcher
::
subgraph_t
&
subgraph
,
SSAGraph
*
g
)
{
// get all the registered nodes.
...
...
@@ -41,17 +41,14 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
LOG
(
INFO
)
<<
"keys.size "
<<
keys
.
size
();
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
for
(
auto
&
matched
:
key2nodes_
)
{
LOG
(
INFO
)
<<
"get matched "
<<
matched
.
size
();
for
(
const
auto
&
key
:
keys
)
{
nodes2rm
.
insert
(
matched
.
at
(
key
));
}
}
LOG
(
INFO
)
<<
"clean nodes "
<<
nodes2rm
.
size
();
VLOG
(
3
)
<<
"clean nodes "
<<
nodes2rm
.
size
();
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
浏览文件 @
505e450d
...
...
@@ -64,7 +64,6 @@ class FuseBase {
// Delete nodes that are marked as Intermediate
void
DeleteInterNodes
(
SSAGraph
*
graph
);
private:
PMNode
*
GetOrCreateNode
(
const
std
::
string
&
key
);
protected:
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
浏览文件 @
505e450d
...
...
@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
public:
void
BuildPattern
()
override
{
// create nodes.
auto
*
x
=
VarNode
(
"x"
);
auto
*
W
=
VarNode
(
"W"
);
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"mul"
,
"X"
)
;
auto
*
W
=
VarNode
(
"W"
)
->
assert_is_op_input
(
"mul"
,
"Y"
)
;
auto
*
b
=
VarNode
(
"b"
);
auto
*
mul
=
OpNode
(
"mul"
,
"mul"
);
auto
*
mul_out
=
VarNode
(
"mul_out"
);
...
...
@@ -38,12 +38,10 @@ class FcFuser : public FuseBase {
auto
*
Out
=
VarNode
(
"Out"
);
// create topology.
// std::vector<PMNode*>({W, x}) >> *mul >> *mul_out;
// std::vector<PMNode*>({mul_out, b}) >> *add >> *Out;
*
W
>>
*
mul
;
*
x
>>
*
mul
>>
*
mul_out
;
*
b
>>
*
add
;
*
mul_out
>>
*
add
>>
*
Out
;
std
::
vector
<
PMNode
*>
mul_inputs
{
W
,
x
};
std
::
vector
<
PMNode
*>
add_inputs
{
mul_out
,
b
};
mul_inputs
>>
*
mul
>>
*
mul_out
;
add_inputs
>>
*
add
>>
*
Out
;
// Some op specialities.
mul_out
->
AsIntermediate
();
...
...
@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block
->
Var
(
"mul_out"
);
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"out"
);
main_block
->
Var
(
"out1"
);
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out1"
)
->
GetMutable
<
lite
::
Tensor
>
();
mul
->
SetInput
(
"X"
,
{
"x"
});
mul
->
SetInput
(
"Y"
,
{
"w"
});
...
...
@@ -122,18 +118,17 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
return
graph
;
}
TEST
(
pattern_matcher
2
,
graph_test
)
{
TEST
(
pattern_matcher
_high_api
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
8UL
/*real nodes*/
+
2UL
/*feed op + fetch op*/
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
7UL
/*real nodes*/
);
Visualize
(
graph
.
get
());
}
TEST
(
pattern_matcher
2
,
test
)
{
TEST
(
pattern_matcher
_high_api
,
fuse_
test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
...
...
@@ -143,6 +138,7 @@ TEST(pattern_matcher2, test) {
fuser
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
3UL
/*nodes removed */
+
1UL
/* fused fc node*/
);
Visualize
(
graph
.
get
());
}
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
505e450d
...
...
@@ -16,6 +16,7 @@
#include <algorithm>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
namespace
paddle
{
...
...
@@ -93,31 +94,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
return
res
;
}
void
SSAGraph
::
GraphCreateTmpVarNodes
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
())
{
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate creating temp variable: "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
void
SSAGraph
::
GraphCreateWeightVarNodes
(
const
Program
&
program
)
{
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
())
{
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate creating weight variable: "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
Node
*
SSAGraph
::
GraphCreateInstructNode
(
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
node_storage_
.
emplace_back
();
...
...
@@ -135,29 +111,45 @@ Node *SSAGraph::GraphCreateInstructNode(
void
SSAGraph
::
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
CHECK
(
node_storage_
.
empty
());
GraphCreateTmpVarNodes
(
program
);
GraphCreateWeightVarNodes
(
program
);
CHECK
(
CheckNodesRoleSet
());
auto
weights_name
=
program
.
weights
();
auto
is_weights
=
[
&
](
const
std
::
string
&
name
)
->
bool
{
auto
it
=
std
::
find
(
weights_name
.
begin
(),
weights_name
.
end
(),
name
);
if
(
it
==
weights_name
.
end
())
return
false
;
return
true
;
};
std
::
unordered_map
<
std
::
string
,
mir
::
Node
*>
arg_update_node_map_
;
for
(
auto
&
op
:
program
.
ops
())
{
auto
*
op_node
=
GraphCreateInstructNode
(
op
,
valid_places
);
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
Argument
(
name
);
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
arg
,
op_node
);
mir
::
Node
*
arg_node
=
nullptr
;
if
(
arg_update_node_map_
.
count
(
name
))
{
arg_node
=
arg_update_node_map_
.
at
(
name
);
}
else
{
node_storage_
.
emplace_back
();
arg_node
=
&
node_storage_
.
back
();
arg_node
->
AsArg
(
name
,
node_storage_
.
size
()
-
1
);
arg_update_node_map_
[
name
]
=
arg_node
;
}
if
(
is_weights
(
name
))
arg_node
->
AsArg
().
is_weight
=
true
;
CHECK
(
arg_node
->
IsRoleSet
());
DirectedLink
(
arg_node
,
op_node
);
}
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
output_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
NewArgumentNode
(
name
);
}
auto
*
arg
=
arguments_
.
at
(
name
);
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
op_node
,
arg
);
node_storage_
.
emplace_back
();
auto
*
arg_node
=
&
node_storage_
.
back
();
arg_node
->
AsArg
(
name
,
node_storage_
.
size
()
-
1
);
arg_update_node_map_
[
name
]
=
arg_node
;
if
(
is_weights
(
name
))
arg_node
->
AsArg
().
is_weight
=
true
;
CHECK
(
arg_node
->
IsRoleSet
());
DirectedLink
(
op_node
,
arg_node
);
}
CHECK
(
CheckLinksRoleSet
());
}
MarkArgumentWeights
(
program
);
CHECK
(
CheckNodesRoleSet
()
);
CheckValid
();
}
...
...
@@ -227,10 +219,9 @@ bool SSAGraph::CheckLinksRoleSet() {
Node
*
SSAGraph
::
NewArgumentNode
(
const
std
::
string
&
name
)
{
node_storage_
.
emplace_back
();
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
arguments_
[
name
]
=
&
node_storage_
.
back
();
node_storage_
.
back
().
AsArg
(
name
);
return
&
node_storage_
.
back
();
auto
&
arg_node
=
node_storage_
.
back
();
arg_node
.
AsArg
(
name
,
node_storage_
.
size
()
-
1
);
return
&
arg_node
;
}
Node
*
SSAGraph
::
NewInstructNode
()
{
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
505e450d
...
...
@@ -40,8 +40,6 @@ class SSAGraph : GraphBase {
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
);
void
RemoveNode
(
const
mir
::
Node
*
node
);
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
std
::
vector
<
mir
::
Node
*>
StmtTopologicalOrder
();
// The inputs of the graph.
...
...
@@ -68,9 +66,7 @@ class SSAGraph : GraphBase {
const
std
::
vector
<
Place
>
&
valid_places
);
private:
void
GraphCreateTmpVarNodes
(
const
Program
&
program
);
void
GraphCreateWeightVarNodes
(
const
Program
&
program
);
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
// Check the bidirectional connection.
bool
CheckBidirectionalConnection
();
bool
CheckNodesRoleSet
();
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
505e450d
...
...
@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
->
AsArg
().
name
,
graph
,
inst_node
,
valid_places_
);
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
,
graph
,
inst_node
,
valid_places_
);
}
}
void
TypeTargetTransformPass
::
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
const
Type
&
from
,
const
Type
&
to
,
Node
*
in
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node.
CHECK
(
in
->
IsArg
());
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
auto
io_copy_output_name
=
in
->
AsArg
().
name
+
"/trans/"
+
std
::
to_string
(
node_id
());
auto
*
io_copy_output_arg
=
graph
->
NewArgumentNode
(
io_copy_output_name
);
auto
*
io_copy_inst
=
graph
->
NewInstructNode
();
...
...
@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Create IoCopy Instruction.
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetInput
(
"Input"
,
{
in
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
->
scope
());
...
...
@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst(
io_copy_inst
->
AsStmt
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
// Remove the old link
RemoveDirectedLink
(
graph
->
Argument
(
var
)
,
inst_node
);
RemoveDirectedLink
(
in
,
inst_node
);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink
(
graph
->
Argument
(
var
)
,
io_copy_inst
);
DirectedLink
(
in
,
io_copy_inst
);
DirectedLink
(
io_copy_inst
,
io_copy_output_arg
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
var
,
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
in
->
AsArg
().
name
,
io_copy_output_name
);
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.h
浏览文件 @
505e450d
...
...
@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass {
void
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
Node
*
in
);
void
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
void
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
Node
*
in
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
);
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
505e450d
...
...
@@ -13,7 +13,10 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
...
...
@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass {
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
VLOG
(
4
)
<<
"inferencing op "
<<
inst
.
op_type
;
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
input_argnames
())
{
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto
get_argname
=
[
&
](
const
std
::
string
&
node_name
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
argname_map
)
->
std
::
string
{
for
(
auto
&
ele
:
argname_map
)
{
auto
it
=
std
::
find
(
ele
.
second
.
begin
(),
ele
.
second
.
end
(),
node_name
);
if
(
it
!=
ele
.
second
.
end
())
return
ele
.
first
;
}
return
""
;
};
for
(
auto
*
x_in
:
x
->
inlinks
)
{
std
::
string
node_name
=
x_in
->
AsArg
().
name
;
std
::
string
arg_name
=
get_argname
(
node_name
,
inst
.
op_info
()
->
inputs
());
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
<<
node_name
;
VLOG
(
3
)
<<
"-- input arg_name "
<<
arg_name
;
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
inputs
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
node
;
arg_node
.
type
=
type
;
}
if
(
!
x_in
->
AsArg
().
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
x_in
;
x_in
->
AsArg
().
type
=
type
;
}
}
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
for
(
auto
*
x_out
:
x
->
outlinks
)
{
std
::
string
node_name
=
x_out
->
AsArg
().
name
;
std
::
string
arg_name
=
get_argname
(
node_name
,
inst
.
op_info
()
->
outputs
());
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
<<
node_name
;
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
outputs
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
node
->
AsArg
().
type
=
type
;
VLOG
(
3
)
<<
"set type "
<<
*
type
;
}
if
(
!
x_out
->
AsArg
().
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
x_out
;
x_out
->
AsArg
().
type
=
type
;
}
}
}
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
505e450d
...
...
@@ -59,7 +59,7 @@ class OpLite : public Registry {
}
void
SetValidPlaces
(
const
std
::
vector
<
Place
>
&
places
)
{
LOG
(
INFO
)
<<
"valid places "
<<
valid_places_
.
size
();
VLOG
(
3
)
<<
"valid places "
<<
valid_places_
.
size
();
valid_places_
=
places
;
}
const
std
::
vector
<
Place
>
&
valid_places
()
const
{
return
valid_places_
;
}
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
505e450d
...
...
@@ -48,6 +48,9 @@ class Optimizer {
if
(
passes
.
empty
())
{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_act_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
...
...
paddle/fluid/lite/core/profile/basic_profiler.h
浏览文件 @
505e450d
...
...
@@ -152,8 +152,8 @@ class BasicProfiler {
}
record_t
*
mutable_record
(
int
id
)
{
CHECK_LT
(
id
,
records_
.
size
());
CHECK_GE
(
id
,
0
);
CHECK_LT
(
static_cast
<
size_t
>
(
id
),
records_
.
size
());
return
&
records_
[
id
];
}
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
505e450d
...
...
@@ -140,7 +140,7 @@ class RuntimeProgram {
void
Run
()
{
for
(
auto
&
inst
:
instructions_
)
{
LOG
(
INFO
)
<<
">> Running kernel: "
<<
inst
;
VLOG
(
4
)
<<
">> Running kernel: "
<<
inst
;
inst
.
Run
();
}
}
...
...
paddle/fluid/lite/kernels/x86/conv_compute.cc
浏览文件 @
505e450d
...
...
@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
lite
::
Tensor
col_matrix
;
if
(
is_expand
)
{
col
.
Resize
(
col_shape
);
col
.
mutable_data
<
T
>
();
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
...
...
@@ -104,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param
.
x
->
raw_tensor
().
Slice
(
i
,
i
+
1
).
Resize
(
input_shape
.
data
()));
lite
::
Tensor
out_batch
;
out_batch
.
ShareDataWith
(
param
.
output
->
raw_tensor
().
Slice
(
i
,
i
+
1
).
Resize
(
input
_shape
.
data
()));
output_matrix
_shape
.
data
()));
for
(
int
g
=
0
;
g
<
param
.
groups
;
g
++
)
{
lite
::
Tensor
in_slice
;
...
...
@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
...
...
@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW,
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
paddle/fluid/lite/kernels/x86/fc_compute.cc
浏览文件 @
505e450d
...
...
@@ -27,8 +27,8 @@ namespace kernels {
namespace
x86
{
template
<
typename
T
>
void
fc_compute_eigen
(
const
T
*
x
,
int
x_
w
,
int
x_h
,
//
const
T
*
w
,
int
w_
w
,
int
w_h
,
//
void
fc_compute_eigen
(
const
T
*
x
,
int
x_
h
,
int
x_w
,
//
const
T
*
w
,
int
w_
h
,
int
w_w
,
//
const
T
*
b
,
//
T
*
out
)
{
using
matrix_t
=
...
...
@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, //
Eigen
::
Map
<
const
matrix_t
>
X
(
x
,
x_h
,
x_w
);
Eigen
::
Map
<
const
matrix_t
>
W
(
w
,
w_h
,
w_w
);
Eigen
::
Map
<
matrix_t
>
Out
(
out
,
x_h
,
w_
h
);
Eigen
::
Map
<
matrix_t
>
Out
(
out
,
x_h
,
w_
w
);
Out
=
X
*
W
.
transpose
()
;
Out
=
X
*
W
;
if
(
b
)
{
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
1
>>
B
(
b
,
w_
h
);
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
1
>>
B
(
b
,
w_
w
);
Out
=
Out
.
array
().
rowwise
()
+
B
.
transpose
().
array
();
}
}
template
<
typename
T
>
__attribute__
((
optimize
(
"unroll-loops"
)))
//
T
dot
(
const
T
*
x
,
const
T
*
y
,
int
dim
)
{
T
out
{};
for
(
int
i
=
0
;
i
<
dim
;
i
++
)
{
out
+=
x
[
i
]
*
y
[
i
];
}
return
out
;
}
template
<
typename
T
>
void
fc_compute_naive
(
const
T
*
x
,
int
x_w
,
int
x_h
,
//
const
T
*
w
,
int
w_w
,
int
w_h
,
//
void
fc_compute_naive
(
const
T
*
x
,
int
x_h
,
int
x_w
,
//
const
T
*
w
,
int
w_h
,
int
w_w
,
//
const
T
*
b
,
//
T
*
out
)
{
CHECK_EQ
(
x_w
,
w_
w
);
CHECK_EQ
(
x_w
,
w_
h
);
// out shape: (x_h, w_w)
memset
(
out
,
0
,
x_h
*
w_h
*
sizeof
(
T
));
for
(
int
r
=
0
;
r
<
x_h
;
r
++
)
{
for
(
int
c
=
0
;
c
<
w_h
;
c
++
)
{
out
[
r
*
w_h
+
c
]
=
dot
(
&
x
[
r
*
x_w
],
&
w
[
c
*
w_w
],
w_w
)
+
b
[
c
];
memset
(
out
,
0
,
x_h
*
w_w
*
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
x_h
;
i
++
)
{
for
(
int
j
=
0
;
j
<
w_w
;
j
++
)
{
T
tmp
=
static_cast
<
T
>
(
0
);
for
(
int
k
=
0
;
k
<
x_w
;
k
++
)
{
tmp
+=
x
[
i
*
x_w
+
k
]
*
w
[
k
*
w_w
+
j
];
}
out
[
i
*
w_w
+
j
]
=
tmp
+
b
[
j
];
}
}
}
...
...
@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
.
Slice
(
param
.
in_num_col_dims
,
param
.
input
->
dims
().
size
())
.
production
(),
param
.
w
->
data
<
T
>
(),
// w
param
.
w
->
dims
()[
1
],
// w_w
param
.
w
->
dims
()[
0
],
// w_h
param
.
w
->
dims
()[
1
],
// w_w
param
.
bias
->
data
<
T
>
(),
// b
param
.
output
->
mutable_data
<
T
>
());
}
...
...
paddle/fluid/lite/kernels/x86/relu_compute.cc
浏览文件 @
505e450d
...
...
@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
relu
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
ReluCompute
<
float
>
,
def
)
.
BindInput
(
"
Input
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"
X
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
paddle/fluid/lite/model_parser/model_parser.cc
浏览文件 @
505e450d
...
...
@@ -91,7 +91,7 @@ void LoadLoDTensor(std::istream &is, Variable *var) {
auto
*
tensor
=
var
->
GetMutable
<
lite
::
Tensor
>
();
uint32_t
version
{};
is
.
read
(
reinterpret_cast
<
char
*>
(
&
version
),
sizeof
(
version
));
LOG
(
INFO
)
<<
"model version "
<<
version
;
VLOG
(
3
)
<<
"model version "
<<
version
;
// Load LoD information
uint64_t
lod_level
{};
...
...
@@ -154,7 +154,7 @@ void LoadModel(const std::string &model_dir, Scope *scope,
continue
;
std
::
string
file_path
=
model_dir
+
"/"
+
var
.
name
();
LOG
(
INFO
)
<<
"reading weight "
<<
var
.
name
();
VLOG
(
4
)
<<
"reading weight "
<<
var
.
name
();
std
::
ifstream
file
(
file_path
);
switch
(
var
.
type
().
type
())
{
...
...
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
505e450d
...
...
@@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library
(
op_params_lite SRCS op_params.cc DEPS
${
tensor_lite
}
any_lite framework_proto_lite
)
cc_library
(
dropout_op_lite SRCS dropout_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
concat_op_lite SRCS concat_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
split_op_lite SRCS split_op.cc DEPS
${
op_DEPS
}
)
#
cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
set
(
ops_lite
conv_op_lite
...
...
@@ -41,7 +41,7 @@ set(ops_lite
activation_ops_lite
dropout_op_lite
concat_op_lite
split_op_lite
#
split_op_lite
PARENT_SCOPE
)
lite_cc_test
(
test_fc_op_lite SRCS fc_op_test.cc
...
...
@@ -56,4 +56,3 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test
(
test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite
)
lite_cc_test
(
test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite
)
lite_cc_test
(
test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite
)
paddle/fluid/lite/operators/conv_op.h
浏览文件 @
505e450d
...
...
@@ -30,25 +30,27 @@ class ConvOpLite : public OpLite {
public:
ConvOpLite
()
{}
explicit
ConvOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
explicit
ConvOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
filter
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
output
=
op_desc
.
Output
(
"Output"
).
front
();
param_
.
x
=
scope
->
FindVar
(
input
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
filter
=
scope
->
FindVar
(
filter
)
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
scope
->
FindVar
(
output
));
param_
.
output
=
scope
->
FindVar
(
output
)
->
GetMutable
<
lite
::
Tensor
>
();
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
X
=
op_desc
.
Input
(
"Input"
).
front
();
auto
Filter
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
Out
=
op_desc
.
Output
(
"Output"
).
front
();
param_
.
x
=
scope
->
FindVar
(
X
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
filter
=
scope
->
FindVar
(
Filter
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
Out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
strides
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"strides"
);
param_
.
paddings
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
);
param_
.
groups
=
op_desc
.
GetAttr
<
int
>
(
"groups"
);
param_
.
dilations
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"dilations"
);
// optional params
std
::
vector
<
std
::
string
>
input_arg_names
=
op_desc
.
InputArgumentNames
();
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"Bias"
)
!=
...
...
@@ -58,7 +60,7 @@ class ConvOpLite : public OpLite {
auto
bias_var
=
scope
->
FindVar
(
bias_arguments
.
front
());
if
(
bias_var
!=
nullptr
)
{
param_
.
bias
=
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
}
}
}
...
...
@@ -68,7 +70,7 @@ class ConvOpLite : public OpLite {
if
(
res_data_arguments
.
size
()
>
0
)
{
auto
residual_data_var
=
scope
->
FindVar
(
res_data_arguments
.
front
());
if
(
residual_data_var
!=
nullptr
)
{
param_
.
residualData
=
const_cast
<
lite
::
Tensor
*>
(
param_
.
residualData
=
const_cast
<
lite
::
Tensor
*>
(
&
(
residual_data_var
->
Get
<
lite
::
Tensor
>
()));
}
}
...
...
@@ -77,7 +79,7 @@ class ConvOpLite : public OpLite {
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"conv2d"
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录