Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
4fe5c8aa
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4fe5c8aa
编写于
6月 20, 2019
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'incubate/lite' of
http://10.87.145.36/inference/paddlelite
into xzl/incubate/lite
fix comments
上级
8ead391d
6f705068
变更
65
显示空白变更内容
内联
并排
Showing
65 changed file
with
859 addition
and
224 deletion
+859
-224
CMakeLists.txt
CMakeLists.txt
+2
-2
paddle/fluid/lite/api/apis_test.cc
paddle/fluid/lite/api/apis_test.cc
+1
-1
paddle/fluid/lite/api/cxx_api_bin.cc
paddle/fluid/lite/api/cxx_api_bin.cc
+1
-1
paddle/fluid/lite/api/lite_api_test_helper.cc
paddle/fluid/lite/api/lite_api_test_helper.cc
+1
-0
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+0
-1
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+24
-20
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
+7
-0
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
...ite/core/mir/elimination/identity_scale_eliminate_pass.cc
+72
-0
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
...ore/mir/elimination/identity_scale_eliminate_pass_test.cc
+93
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+0
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc
...e/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h
...re/mir/fusion/conv_elementwise_add_activation_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc
.../fusion/conv_elementwise_add_activation_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
.../core/mir/fusion/conv_elementwise_add_activation_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
...te/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
+39
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
...ite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
...re/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
+153
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
...e/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h
...te/core/mir/fusion/elementwise_add_activation_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc
...e/mir/fusion/elementwise_add_activation_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
.../lite/core/mir/fusion/elementwise_add_activation_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+2
-2
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
+1
-1
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+2
-2
paddle/fluid/lite/core/mir/node.cc
paddle/fluid/lite/core/mir/node.cc
+59
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+32
-34
paddle/fluid/lite/core/mir/pattern_matcher.h
paddle/fluid/lite/core/mir/pattern_matcher.h
+2
-3
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
+1
-0
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
+7
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+2
-2
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
+13
-13
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+5
-0
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+6
-6
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+8
-8
paddle/fluid/lite/core/mir/use_passes.h
paddle/fluid/lite/core/mir/use_passes.h
+3
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+6
-4
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+1
-2
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+16
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+4
-0
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+1
-1
paddle/fluid/lite/core/tensor.h
paddle/fluid/lite/core/tensor.h
+0
-1
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
+29
-0
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
+8
-1
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
+2
-0
paddle/fluid/lite/kernels/x86/mul_compute.h
paddle/fluid/lite/kernels/x86/mul_compute.h
+30
-12
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
+2
-0
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+2
-1
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
+0
-1
paddle/fluid/lite/model_parser/desc_apis.h
paddle/fluid/lite/model_parser/desc_apis.h
+22
-0
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+2
-0
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+6
-2
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+2
-2
paddle/fluid/lite/tools/build.sh
paddle/fluid/lite/tools/build.sh
+140
-81
paddle/fluid/lite/utils/varient.h
paddle/fluid/lite/utils/varient.h
+4
-3
未找到文件。
CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -44,9 +44,9 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
...
@@ -44,9 +44,9 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# check arch abi
# check arch abi
if
(
NOT DEFINED ARM_TARGET_LANG
)
if
(
NOT DEFINED ARM_TARGET_LANG
)
set
(
ARM_TARGET_LANG
"
clang
"
CACHE STRING
"Choose ARM Target Language"
)
set
(
ARM_TARGET_LANG
"
gcc
"
CACHE STRING
"Choose ARM Target Language"
)
endif
()
endif
()
set
(
ARM_TARGET_LANG_LIST
"gcc"
"clang"
)
set
(
ARM_TARGET_LANG_LIST
"gcc"
"clang"
""
)
set_property
(
CACHE ARM_TARGET_LANG PROPERTY STRINGS
${
ARM_TARGET_LANG_LIST
}
)
set_property
(
CACHE ARM_TARGET_LANG PROPERTY STRINGS
${
ARM_TARGET_LANG_LIST
}
)
if
(
NOT ARM_TARGET_LANG IN_LIST ARM_TARGET_LANG_LIST
)
if
(
NOT ARM_TARGET_LANG IN_LIST ARM_TARGET_LANG_LIST
)
message
(
FATAL_ERROR
"ARM_TARGET_LANG must be in one of
${
ARM_TARGET_LANG_LIST
}
"
)
message
(
FATAL_ERROR
"ARM_TARGET_LANG must be in one of
${
ARM_TARGET_LANG_LIST
}
"
)
...
...
paddle/fluid/lite/api/apis_test.cc
浏览文件 @
4fe5c8aa
...
@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
...
@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
ASSERT_TRUE
(
TensorCompareWith
(
*
cxx_out
,
*
light_out
));
ASSERT_TRUE
(
TensorCompareWith
(
*
cxx_out
,
*
light_out
));
std
::
vector
<
std
::
string
>
tensors_with_order
({
std
::
vector
<
std
::
string
>
tensors_with_order
({
"a"
,
"fc_0.w_0"
,
"
fc_0.tmp_0"
,
"
scale_0.tmp_0"
,
"a"
,
"fc_0.w_0"
,
"scale_0.tmp_0"
,
});
});
for
(
const
auto
&
tensor_name
:
tensors_with_order
)
{
for
(
const
auto
&
tensor_name
:
tensors_with_order
)
{
...
...
paddle/fluid/lite/api/cxx_api_bin.cc
浏览文件 @
4fe5c8aa
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include <chrono>
// NOLINT
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/api/lite_api_test_helper.cc
浏览文件 @
4fe5c8aa
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include <vector>
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
...
@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test
(
test_types_lite SRCS types_test.cc DEPS types_lite
)
lite_cc_test
(
test_types_lite SRCS types_test.cc DEPS types_lite
)
lite_cc_test
(
test_memory_lite SRCS memory_test.cc DEPS memory_lite
)
lite_cc_test
(
test_memory_lite SRCS memory_test.cc DEPS memory_lite
)
lite_cc_test
(
test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator
)
lite_cc_test
(
test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator
)
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -5,12 +5,16 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
...
@@ -5,12 +5,16 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
add_subdirectory
(
fusion
)
add_subdirectory
(
fusion
)
add_subdirectory
(
elimination
)
cc_library
(
mir_passes
cc_library
(
mir_passes
SRCS fc_fuse_pass.cc
SRCS
conv_elementwise_add_activation_fuse_pass.cc
fusion/fc_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
fusion/conv_elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
quant_dequant_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
static_kernel_pick_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
type_target_transform_pass.cc
...
@@ -74,7 +78,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
...
@@ -74,7 +78,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
lite_cc_test
(
test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
lite_cc_test
(
test_lite_fc_fuse SRCS f
usion/f
c_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
arm_kernels
}
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
arm_kernels
}
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
...
@@ -85,10 +89,10 @@ add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
...
@@ -85,10 +89,10 @@ add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test
(
test_lite_conv_elementwise_add_activation_fuse
lite_cc_test
(
test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
SRCS
fusion/
conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
lite_cc_test
(
test_lite_elementwise_add_activation_fuse
lite_cc_test
(
test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
SRCS
fusion/
elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
0 → 100644
浏览文件 @
4fe5c8aa
if
(
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
lite_cc_test
(
test_identity_scale_eliminate_pass_lite
SRCS identity_scale_eliminate_pass_test.cc
DEPS mir_passes program_lite proto_desc cpp_op_desc_lite
${
ops_lite
}
)
endif
()
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
{
class
Eliminator
:
public
FuseBase
{
public:
void
BuildPattern
()
override
{
auto
*
pre_op
=
OpNode
(
"preop"
);
// the previous op's output need update
// TODO(Superjomn) check has only one output
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"scale"
,
"X"
);
auto
*
scale_op
=
OpNode
(
"scale"
,
"scale"
)
->
assert_op_attr
<
float
>
(
"scale"
,
1.
)
->
assert_op_attr
<
float
>
(
"bias"
,
0.
);
auto
*
out
=
VarNode
(
"out"
)
->
assert_is_op_output
(
"scale"
,
"Out"
);
*
pre_op
>>
*
x
>>
*
scale_op
>>
*
out
;
// The pre_op will be eliminated, and a new output-updated op will insert.
x
->
AsIntermediate
();
// x is pre_op's output, need to update
}
private:
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
auto
&
pre_op
=
matched
.
at
(
"preop"
)
->
AsStmt
();
auto
op_info
=
*
pre_op
.
op_info
();
op_info
.
UpdateAllOutputs
(
matched
.
at
(
"x"
)
->
AsArg
().
name
,
matched
.
at
(
"out"
)
->
AsArg
().
name
);
pre_op
.
ResetOp
(
op_info
,
graph
->
valid_places
());
GraphSafeRemoveNodes
(
graph
,
{
matched
.
at
(
"scale"
)});
IR_NODE_LINK_TO
(
matched
.
at
(
"preop"
),
matched
.
at
(
"out"
));
}
};
}
// namespace
class
IdentityScaleEliminatePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
{
Eliminator
eliminator
;
eliminator
(
graph
.
get
());
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
identity_scale_eliminate_pass
,
paddle
::
lite
::
mir
::
IdentityScaleEliminatePass
);
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
// Op list:
// (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch)
// After pass
// (x)->feed->(scale_out)->fetch->(fetch)
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
feed_op
=
main_block
->
AppendOp
();
auto
*
scale_op
=
main_block
->
AppendOp
();
auto
*
fetch_op
=
main_block
->
AppendOp
();
main_block
->
Var
(
"x"
);
main_block
->
Var
(
"feed"
);
main_block
->
Var
(
"scale_out"
);
main_block
->
Var
(
"fetch_out"
);
scope
->
Var
(
"x"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"feed"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"scale_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"fetch_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
feed_op
->
SetType
(
"feed"
);
feed_op
->
SetInput
(
"X"
,
{
"x"
});
feed_op
->
SetAttr
(
"col"
,
1
);
feed_op
->
SetOutput
(
"Out"
,
{
"feed"
});
scale_op
->
SetType
(
"scale"
);
scale_op
->
SetInput
(
"X"
,
{
"feed"
});
scale_op
->
SetOutput
(
"Out"
,
{
"scale_out"
});
scale_op
->
SetAttr
(
"scale"
,
1.
f
);
scale_op
->
SetAttr
(
"bias"
,
0.
f
);
scale_op
->
SetAttr
(
"bias_after_scale"
,
true
);
fetch_op
->
SetType
(
"fetch"
);
fetch_op
->
SetInput
(
"X"
,
{
"scale_out"
});
fetch_op
->
SetOutput
(
"Out"
,
{
"fetch"
});
fetch_op
->
SetAttr
(
"col"
,
1
);
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
LOG
(
INFO
)
<<
Visualize
(
graph
.
get
());
return
graph
;
}
TEST
(
identity_test
,
test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
pass
=
PassManager
::
Global
().
LookUp
(
"identity_scale_eliminate_pass"
);
ASSERT_TRUE
(
pass
);
pass
->
Apply
(
graph
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
2UL
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
feed
)
USE_LITE_OP
(
fetch
)
USE_LITE_OP
(
scale
)
USE_MIR_PASS
(
identity_scale_eliminate_pass
)
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -10,7 +10,6 @@ cc_library(fuse_conv_bn
...
@@ -10,7 +10,6 @@ cc_library(fuse_conv_bn
cc_library
(
fuse_elementwise_add_activation
cc_library
(
fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api
)
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_quant_dequant
cc_library
(
fuse_quant_dequant
SRCS quant_dequant_op_fuser.cc
SRCS quant_dequant_op_fuser.cc
DEPS pattern_matcher_high_api
)
DEPS pattern_matcher_high_api
)
...
...
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
...
...
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
...
@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
void
ConvBNFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
void
ConvBNFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
eltwise_op
=
LiteOpRegistry
::
Global
().
Create
(
"elementwise_add"
);
auto
eltwise_op
=
LiteOpRegistry
::
Global
().
Create
(
"elementwise_add"
);
auto
conv
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
conv
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
conv
->
scope
();
auto
*
scope
=
conv
->
scope
();
auto
&
valid_places
=
conv
->
valid_places
();
auto
&
valid_places
=
conv
->
valid_places
();
...
...
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
...
...
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
...
@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
conv_old
->
scope
();
auto
*
scope
=
conv_old
->
scope
();
auto
&
valid_places
=
conv_old
->
valid_places
();
auto
&
valid_places
=
conv_old
->
valid_places
();
conv_op
->
Attach
(
op_desc
,
scope
);
conv_op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAddReLUFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAddReLUFuser
fuser
(
"conv2d"
);
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
depthwise_fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAddReLUFusePass
);
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddReLUFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
conv2d_1
=
main_block
->
AppendOp
();
auto
*
conv2d_2
=
main_block
->
AppendOp
();
auto
*
add_1
=
main_block
->
AppendOp
();
auto
*
relu_1
=
main_block
->
AppendOp
();
auto
*
add_2
=
main_block
->
AppendOp
();
auto
*
relu_2
=
main_block
->
AppendOp
();
main_block
->
Var
(
"input_1"
);
main_block
->
Var
(
"input_2"
);
main_block
->
Var
(
"filter_1"
);
main_block
->
Var
(
"filter_2"
);
main_block
->
Var
(
"conv2d_1_out"
);
main_block
->
Var
(
"conv2d_2_out"
);
main_block
->
Var
(
"bias_1"
);
main_block
->
Var
(
"add_1_out"
);
main_block
->
Var
(
"add_2_out"
);
main_block
->
Var
(
"relu_1_out"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"input_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"input_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"bias_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"relu_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
conv2d_1
->
SetType
(
"conv2d"
);
conv2d_1
->
SetInput
(
"Input"
,
{
"input_1"
});
conv2d_1
->
SetInput
(
"Filter"
,
{
"filter_1"
});
conv2d_1
->
SetOutput
(
"Output"
,
{
"conv2d_1_out"
});
conv2d_1
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_1
->
SetAttr
(
"groups"
,
1
);
conv2d_1
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"fuse_relu"
,
false
);
add_1
->
SetType
(
"elementwise_add"
);
add_1
->
SetInput
(
"X"
,
{
"conv2d_1_out"
});
add_1
->
SetInput
(
"Y"
,
{
"bias_1"
});
add_1
->
SetOutput
(
"Out"
,
{
"add_1_out"
});
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"X"
,
{
"add_1_out"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_1_out"
});
conv2d_2
->
SetType
(
"conv2d"
);
conv2d_2
->
SetInput
(
"Input"
,
{
"input_2"
});
conv2d_2
->
SetInput
(
"Filter"
,
{
"filter_2"
});
conv2d_2
->
SetOutput
(
"Output"
,
{
"conv2d_2_out"
});
conv2d_2
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_2
->
SetAttr
(
"groups"
,
1
);
conv2d_2
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"fuse_relu"
,
false
);
add_2
->
SetType
(
"elementwise_add"
);
add_2
->
SetInput
(
"X"
,
{
"conv2d_2_out"
});
add_2
->
SetInput
(
"Y"
,
{
"relu_1_out"
});
add_2
->
SetOutput
(
"Out"
,
{
"add_2_out"
});
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"X"
,
{
"add_2_out"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
return
graph
;
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
11UL
/*vars*/
+
6UL
/*ops*/
);
Visualize
(
graph
.
get
());
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
fuse_test_op
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ConvElementwiseAddReLUFusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused fc node*/
);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
conv2d
);
USE_LITE_OP
(
depthwise_conv2d
);
USE_LITE_OP
(
relu
);
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
...
...
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
...
@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op
=
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
"fusion_elementwise_add_activation"
);
LiteOpRegistry
::
Global
().
Create
(
"fusion_elementwise_add_activation"
);
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
;
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
old_op
->
scope
();
auto
*
scope
=
old_op
->
scope
();
auto
&
valid_places
=
old_op
->
valid_places
();
auto
&
valid_places
=
old_op
->
valid_places
();
op
->
Attach
(
op_desc
,
scope
);
op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
...
...
paddle/fluid/lite/core/mir/fc_fuse_pass.h
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
...
@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
void
FcFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
void
FcFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
;
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
mul
->
scope
();
auto
*
scope
=
mul
->
scope
();
auto
&
valid_places
=
mul
->
valid_places
();
auto
&
valid_places
=
mul
->
valid_places
();
fc_op
->
Attach
(
op_desc
,
scope
);
fc_op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
...
...
paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -115,8 +115,8 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -115,8 +115,8 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
nodes
.
push_back
(
matched
.
at
(
"dequant_op_out"
+
std
::
to_string
(
i
)));
nodes
.
push_back
(
matched
.
at
(
"dequant_op_out"
+
std
::
to_string
(
i
)));
}
}
int
bit_length
=
quant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
bit_length
=
quant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
auto
*
scope
=
quant_op
->
stmt
()
->
op
->
scope
();
auto
*
scope
=
quant_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quant_op
->
stmt
()
->
op
->
valid_places
();
auto
&
valid_places
=
quant_op
->
stmt
()
->
op
()
->
valid_places
();
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
input_scale_t
=
scope
->
FindVar
(
quant_op_in_scale
->
arg
()
->
name
)
auto
input_scale_t
=
scope
->
FindVar
(
quant_op_in_scale
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
->
GetMutable
<
lite
::
Tensor
>
();
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if
(
item
->
IsStmt
())
{
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
auto
&
stmt
=
item
->
AsStmt
();
VLOG
(
4
)
<<
stmt
;
VLOG
(
4
)
<<
stmt
;
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
insts_
.
emplace_back
(
stmt
.
op
(),
std
::
move
(
stmt
.
kernels
()
.
front
()));
}
}
}
}
}
}
...
...
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
...
@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
if
(
node
.
IsArg
())
{
if
(
node
.
IsArg
())
{
key
=
node
.
AsArg
().
name
;
key
=
node
.
AsArg
().
name
;
}
else
{
}
else
{
key
=
node
.
AsStmt
().
op_type
+
std
::
to_string
(
id
++
);
key
=
node
.
AsStmt
().
op_type
()
+
std
::
to_string
(
id
++
);
}
}
if
(
node
.
IsStmt
())
{
if
(
node
.
IsStmt
())
{
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
...
@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsStmt
())
continue
;
if
(
!
node
.
IsStmt
())
continue
;
auto
&
inst
=
node
.
AsStmt
();
auto
&
inst
=
node
.
AsStmt
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
if
(
inst
.
op_type
()
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
AsStmt
().
valid_kernels
;
auto
&
kernels
=
node
.
AsStmt
().
kernels
()
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
...
...
paddle/fluid/lite/core/mir/node.cc
浏览文件 @
4fe5c8aa
...
@@ -13,3 +13,62 @@
...
@@ -13,3 +13,62 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
const
OpInfo
*
mir
::
Node
::
Stmt
::
op_info
()
const
{
CHECK
(
op_
);
return
op_
->
op_info
();
}
Place
mir
::
Node
::
Stmt
::
place
()
const
{
CHECK
(
!
valid_kernels_
.
empty
());
return
valid_kernels_
.
front
()
->
place
();
}
KernelBase
&
mir
::
Node
::
Stmt
::
picked_kernel
()
{
CHECK
(
!
valid_kernels_
.
empty
())
<<
"no kernel for "
<<
op_type
();
return
*
valid_kernels_
.
front
();
}
OpInfo
*
mir
::
Node
::
Stmt
::
mutable_op_info
()
{
CHECK
(
op_
);
return
op_
->
mutable_op_info
();
}
void
mir
::
Node
::
Stmt
::
ResetOp
(
const
cpp
::
OpDesc
&
op_desc
,
const
std
::
vector
<
Place
>
&
valid_places
,
lite
::
Scope
*
scope
)
{
CHECK
((
op_
&&
op_
->
scope
())
||
scope
)
<<
"Either scope should be set"
;
lite
::
Scope
*
the_scope
=
scope
?
scope
:
op_
->
scope
();
op_
->
Attach
(
op_desc
,
the_scope
);
// Recreate the kernels with the latest OpInfo.
valid_kernels_
.
clear
();
if
(
!
op_
||
op_
->
op_info
()
->
Type
()
!=
op_desc
.
Type
())
{
op_
=
LiteOpRegistry
::
Global
().
Create
(
op_desc
.
Type
());
CHECK
(
op_
)
<<
"No op found for "
<<
op_desc
.
Type
();
}
valid_kernels_
=
op_
->
CreateKernels
(
valid_places
);
}
std
::
ostream
&
mir
::
operator
<<
(
std
::
ostream
&
os
,
const
mir
::
Node
::
Stmt
&
other
)
{
os
<<
"Statement "
<<
other
.
op_type
()
<<
" "
<<
other
.
place
();
return
os
;
}
mir
::
Node
::
Arg
&
mir
::
Node
::
AsArg
(
const
std
::
string
&
name
,
int
id
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
id
=
id
;
return
x
;
}
mir
::
Node
::
Arg
&
mir
::
Node
::
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/node.h
浏览文件 @
4fe5c8aa
...
@@ -41,32 +41,40 @@ class Node {
...
@@ -41,32 +41,40 @@ class Node {
kUnk
,
kUnk
,
};
};
struct
Stmt
{
class
Stmt
{
std
::
string
op_type
;
// The kernel instances this Statement contains.
// The kernel instances this Statement contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
_
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
std
::
shared_ptr
<
OpLite
>
op
_
;
// we hold op to run InferShape
const
OpInfo
*
op_info
()
{
public:
CHECK
(
op
);
// Refresh the operator and kernels with the latest OpInfo.
return
op
->
op_info
();
void
ResetOp
(
const
cpp
::
OpDesc
&
op_desc
,
}
const
std
::
vector
<
Place
>&
valid_places
,
lite
::
Scope
*
scope
=
nullptr
);
Place
place
()
const
{
std
::
string
op_type
()
const
{
return
op_info
()
->
Type
();
}
CHECK
(
!
valid_kernels
.
empty
());
const
OpInfo
*
op_info
()
const
;
return
valid_kernels
.
front
()
->
place
();
OpInfo
*
mutable_op_info
();
}
KernelBase
&
picked_kernel
()
{
void
SetKernels
(
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
)
{
CHECK
(
!
valid_kernels
.
empty
())
<<
"no kernel for "
<<
op_type
;
valid_kernels_
=
std
::
move
(
kernels
);
return
*
valid_kernels
.
front
();
}
}
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&
kernels
()
{
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stmt
&
other
)
{
return
valid_kernels_
;
os
<<
"Statement "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
}
}
void
SetOp
(
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
op_
=
op
;
}
const
std
::
shared_ptr
<
OpLite
>
op
()
const
{
return
op_
;
}
Place
place
()
const
;
KernelBase
&
picked_kernel
();
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stmt
&
other
);
// Description.
std
::
string
desc
;
};
};
struct
Arg
{
struct
Arg
{
...
@@ -78,26 +86,16 @@ class Node {
...
@@ -78,26 +86,16 @@ class Node {
bool
is_weight
{
false
};
bool
is_weight
{
false
};
};
};
Arg
&
AsArg
(
const
std
::
string
&
name
,
int
id
)
{
Arg
&
AsArg
(
const
std
::
string
&
name
,
int
id
);
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
id
=
id
;
return
x
;
}
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
Arg
&
AsArg
(
const
std
::
string
&
name
);
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
Stmt
&
AsStmt
(
const
std
::
string
&
op_type
,
Stmt
&
AsStmt
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
AsStmt
();
auto
&
x
=
AsStmt
();
x
.
op_type
=
op_type
;
x
.
SetOp
(
op
);
x
.
op
=
op
;
x
.
SetKernels
(
std
::
move
(
kernels
));
x
.
valid_kernels
=
std
::
move
(
kernels
);
return
x
;
return
x
;
}
}
...
@@ -142,7 +140,7 @@ class Node {
...
@@ -142,7 +140,7 @@ class Node {
}
}
if
(
other
.
IsStmt
())
{
if
(
other
.
IsStmt
())
{
auto
&
arg
=
other
.
AsStmt
();
auto
&
arg
=
other
.
AsStmt
();
os
<<
"Statement "
<<
arg
.
op_type
;
os
<<
"Statement "
<<
arg
.
op_type
()
;
}
}
return
os
;
return
os
;
}
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher.h
浏览文件 @
4fe5c8aa
...
@@ -139,14 +139,13 @@ struct PMNode {
...
@@ -139,14 +139,13 @@ struct PMNode {
template
<
typename
T
>
template
<
typename
T
>
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
asserts_
.
push_back
([
=
](
const
Node
*
x
)
{
if
(
x
&&
x
->
IsStmt
())
{
if
(
x
&&
x
->
IsStmt
())
{
auto
*
op_info
=
x
->
stmt
()
->
op_info
();
auto
*
op_info
=
x
->
stmt
()
->
op_info
();
return
op_info
->
HasAttr
(
attr_name
)
&&
return
op_info
->
HasAttr
(
attr_name
)
&&
op_info
->
GetAttr
<
T
>
(
attr_name
)
==
attr
;
op_info
->
GetAttr
<
T
>
(
attr_name
)
==
attr
;
}
else
{
return
false
;
}
}
return
false
;
});
});
return
this
;
return
this
;
}
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
浏览文件 @
4fe5c8aa
...
@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
...
@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
}
}
LOG
(
INFO
)
<<
"keys: "
<<
key2nodes_
.
size
();
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
for
(
auto
&
matched
:
key2nodes_
)
{
for
(
auto
&
matched
:
key2nodes_
)
{
for
(
const
auto
&
key
:
keys
)
{
for
(
const
auto
&
key
:
keys
)
{
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
浏览文件 @
4fe5c8aa
...
@@ -49,7 +49,13 @@ class FuseBase {
...
@@ -49,7 +49,13 @@ class FuseBase {
virtual
void
BuildPattern
()
=
0
;
virtual
void
BuildPattern
()
=
0
;
// Generate an operator desc with a matched subgraph.
// Generate an operator desc with a matched subgraph.
virtual
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
=
0
;
virtual
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
return
cpp
::
OpDesc
();
}
PMNode
*
OpNode
(
const
std
::
string
&
key
)
{
return
GetOrCreateNode
(
key
)
->
assert_is_op
();
}
PMNode
*
OpNode
(
const
std
::
string
&
key
,
const
std
::
string
&
op_type
);
PMNode
*
OpNode
(
const
std
::
string
&
key
,
const
std
::
string
&
op_type
);
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
浏览文件 @
4fe5c8aa
...
@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
...
@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
;
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
mul
->
scope
();
auto
*
scope
=
mul
->
scope
();
auto
&
valid_places
=
mul
->
valid_places
();
auto
&
valid_places
=
mul
->
valid_places
();
fc_op
->
Attach
(
op_desc
,
scope
);
fc_op
->
Attach
(
op_desc
,
scope
);
...
@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
...
@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"out"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"
w
"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"
x
"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
...
...
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
浏览文件 @
4fe5c8aa
...
@@ -23,19 +23,19 @@ namespace mir {
...
@@ -23,19 +23,19 @@ namespace mir {
void
BuildGraph
(
SSAGraph
*
g
)
{
void
BuildGraph
(
SSAGraph
*
g
)
{
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o1
=
g
->
mutable_nodes
().
back
();
Node
&
o1
=
g
->
mutable_nodes
().
back
();
o1
.
AsStmt
().
op_type
=
"op1"
;
o1
.
AsStmt
().
desc
=
"op1"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o2
=
g
->
mutable_nodes
().
back
();
Node
&
o2
=
g
->
mutable_nodes
().
back
();
o2
.
AsStmt
().
op_type
=
"op2"
;
o2
.
AsStmt
().
desc
=
"op2"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o3
=
g
->
mutable_nodes
().
back
();
Node
&
o3
=
g
->
mutable_nodes
().
back
();
o3
.
AsStmt
().
op_type
=
"op3"
;
o3
.
AsStmt
().
desc
=
"op3"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o4
=
g
->
mutable_nodes
().
back
();
Node
&
o4
=
g
->
mutable_nodes
().
back
();
o4
.
AsStmt
().
op_type
=
"op4"
;
o4
.
AsStmt
().
desc
=
"op4"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o5
=
g
->
mutable_nodes
().
back
();
Node
&
o5
=
g
->
mutable_nodes
().
back
();
o5
.
AsStmt
().
op_type
=
"op5"
;
o5
.
AsStmt
().
desc
=
"op5"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
v1
=
g
->
mutable_nodes
().
back
();
Node
&
v1
=
g
->
mutable_nodes
().
back
();
v1
.
AsArg
(
"var1"
);
v1
.
AsArg
(
"var1"
);
...
@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
...
@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
// v2 -> o3(a node named o3)
// v2 -> o3(a node named o3)
auto
*
o2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
auto
*
o2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op2"
;
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
desc
==
"op2"
;
});
});
auto
*
o3
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
auto
*
o3
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op3"
;
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
desc
==
"op3"
;
});
});
auto
*
v2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
auto
*
v2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
// The teller can be any condition, such as op type, or variable's shape.
...
@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
...
@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
// op -> var
// op -> var
auto
*
any_op
=
x
.
mutable_pattern
()
->
NewNode
(
auto
*
any_op
=
x
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
node
)
{
[](
const
Node
*
node
)
{
return
node
->
IsStmt
()
&&
(
node
->
stmt
()
->
op_type
==
"op2"
||
return
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op3"
);
(
node
->
stmt
()
->
desc
==
"op2"
||
node
->
stmt
()
->
desc
==
"op3"
);
},
},
"OP0"
);
"OP0"
);
auto
*
any_var
=
auto
*
any_var
=
...
@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
...
@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
int
count
=
0
;
int
count
=
0
;
PatternMatcher
::
handle_t
handle
=
[
&
](
const
PatternMatcher
::
subgraph_t
&
s
,
PatternMatcher
::
handle_t
handle
=
[
&
](
const
PatternMatcher
::
subgraph_t
&
s
,
SSAGraph
*
g
)
{
SSAGraph
*
g
)
{
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
stmt
()
->
op_type
<<
" -> "
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
stmt
()
->
desc
<<
" -> "
<<
s
.
at
(
any_var
)
->
arg
()
->
name
<<
" -> "
<<
s
.
at
(
any_var
)
->
arg
()
->
name
<<
" -> "
<<
s
.
at
(
any_op1
)
->
stmt
()
->
op_type
;
<<
s
.
at
(
any_op1
)
->
stmt
()
->
desc
;
count
++
;
count
++
;
};
};
...
@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
...
@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
PatternMatcher
matcher
;
PatternMatcher
matcher
;
auto
*
op2
=
matcher
.
mutable_pattern
()
->
NewNode
(
auto
*
op2
=
matcher
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
x
)
{
[](
const
Node
*
x
)
{
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
op_type
==
"op2"
;
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
desc
==
"op2"
;
},
},
"op2"
);
"op2"
);
auto
*
op3
=
matcher
.
mutable_pattern
()
->
NewNode
(
auto
*
op3
=
matcher
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
x
)
{
[](
const
Node
*
x
)
{
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
op_type
==
"op3"
;
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
desc
==
"op3"
;
},
},
"op3"
);
"op3"
);
auto
*
v2
=
matcher
.
mutable_pattern
()
auto
*
v2
=
matcher
.
mutable_pattern
()
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
4fe5c8aa
...
@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
...
@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
Node
*
GraphCreateInstructNode
(
const
std
::
shared_ptr
<
OpLite
>
&
op
,
Node
*
GraphCreateInstructNode
(
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
);
const
std
::
vector
<
Place
>
&
valid_places
);
// Device related attributes
const
std
::
vector
<
Place
>
&
valid_places
()
const
{
return
valid_places_
;
}
void
SetValidPlaces
(
const
std
::
vector
<
Place
>
&
x
)
{
valid_places_
=
x
;
}
private:
private:
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
// Check the bidirectional connection.
// Check the bidirectional connection.
...
@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
...
@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
private:
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
vector
<
Place
>
valid_places_
;
};
};
// Remove the link between a -> b.
// Remove the link between a -> b.
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if
(
!
node
.
IsStmt
())
continue
;
if
(
!
node
.
IsStmt
())
continue
;
auto
&
instruct
=
node
.
AsStmt
();
auto
&
instruct
=
node
.
AsStmt
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
CHECK
(
!
instruct
.
valid_kernels
.
empty
())
<<
"No kernels found for "
CHECK
(
!
instruct
.
kernels
()
.
empty
())
<<
"No kernels found for "
<<
instruct
.
op_type
;
<<
instruct
.
op_type
()
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
for
(
auto
&&
kernel
:
instruct
.
kernels
()
)
{
size_t
score
=
KernelGrade
(
*
kernel
);
size_t
score
=
KernelGrade
(
*
kernel
);
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
}
}
...
@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Move kernel back
// Move kernel back
// Just keep a single best kernel.
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
kernels
()
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
kernels
()
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
VLOG
(
2
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
VLOG
(
2
)
<<
"pick "
<<
instruct
.
kernels
()
.
front
()
->
name
();
}
}
}
}
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
...
@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
CHECK
(
in
->
AsArg
().
type
);
CHECK
(
in
->
AsArg
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
" for kernel "
<<
inst
.
op
()
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
,
graph
,
inst_node
,
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
,
graph
,
inst_node
,
...
@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
// CHECK(io_copy_op);
// CHECK(io_copy_op);
// Create the new var manually.
// Create the new var manually.
inst_node
->
AsStmt
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
AsStmt
().
op
()
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
// Create IoCopy Instruction.
cpp
::
OpDesc
op_desc
;
cpp
::
OpDesc
op_desc
;
...
@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc
.
SetInput
(
"Input"
,
{
in
->
AsArg
().
name
});
op_desc
.
SetInput
(
"Input"
,
{
in
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
->
scope
());
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
()
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
AsStmt
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
io_copy_inst
->
AsStmt
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
...
@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
// reset opdesc and update kernel information
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
in
->
AsArg
().
name
,
UpdateInputTo
(
inst_node
->
AsStmt
().
op
()
->
mutable_op_info
(),
in
->
AsArg
().
name
,
io_copy_output_name
);
io_copy_output_name
);
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
inst_node
->
AsStmt
().
ResetOp
(
*
inst_node
->
AsStmt
().
op_info
(),
inst_node
->
AsStmt
().
op
->
scope
());
graph
->
valid_places
());
std
::
string
tmp
;
std
::
string
tmp
;
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
}
for
(
auto
&
kernel
:
inst_node
->
AsStmt
().
valid_kernels
)
{
for
(
auto
&
kernel
:
inst_node
->
AsStmt
().
kernels
()
)
{
inst_node
->
AsStmt
().
op
->
AttachKernel
(
kernel
.
get
());
inst_node
->
AsStmt
().
op
()
->
AttachKernel
(
kernel
.
get
());
}
}
graph
->
CheckValid
();
graph
->
CheckValid
();
...
...
paddle/fluid/lite/core/mir/use_passes.h
浏览文件 @
4fe5c8aa
...
@@ -23,9 +23,11 @@ USE_MIR_PASS(generate_program_pass);
...
@@ -23,9 +23,11 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
identity_scale_eliminate_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_quant_dequant_fuse_pass
);
USE_MIR_PASS
(
lite_quant_dequant_fuse_pass
);
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
4fe5c8aa
...
@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
const
auto
&
v
:
graph
->
inputs
())
{
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
// the feed op might in the inputs
if
(
v
->
IsStmt
())
{
if
(
v
->
IsStmt
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsStmt
().
op_type
;
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsStmt
().
op_type
()
;
continue
;
continue
;
}
}
}
}
...
@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
auto
&
x
:
graph
->
StmtTopologicalOrder
())
{
for
(
auto
&
x
:
graph
->
StmtTopologicalOrder
())
{
auto
&
inst
=
x
->
AsStmt
();
auto
&
inst
=
x
->
AsStmt
();
// The IoCopyOp is a tool operator, it won't support the type inference.
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
if
(
inst
.
op_type
()
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
// deal with inputs
VLOG
(
4
)
<<
"
inferencing op "
<<
inst
.
op_type
;
VLOG
(
4
)
<<
"
Infering op "
<<
inst
.
op_info
()
->
Repr
()
;
// TODO(zhaolong): Add check if the node's name in op's arguments.
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto
get_argname
=
[
&
](
auto
get_argname
=
[
&
](
...
@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
}
}
}
}
VLOG
(
3
)
<<
"inst "
<<
inst
.
op_info
()
->
Repr
();
for
(
auto
*
x_out
:
x
->
outlinks
)
{
for
(
auto
*
x_out
:
x
->
outlinks
)
{
std
::
string
node_name
=
x_out
->
AsArg
().
name
;
std
::
string
node_name
=
x_out
->
AsArg
().
name
;
std
::
string
arg_name
=
std
::
string
arg_name
=
get_argname
(
node_name
,
inst
.
op_info
()
->
outputs
());
get_argname
(
node_name
,
inst
.
op_info
()
->
outputs
());
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
<<
node_name
;
<<
node_name
<<
" in Inst "
<<
inst
.
op_type
();
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
if
(
!
x_out
->
AsArg
().
type
)
{
if
(
!
x_out
->
AsArg
().
type
)
{
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
4fe5c8aa
...
@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
...
@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets
.
insert
(
place
.
target
);
targets
.
insert
(
place
.
target
);
}
}
// CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG
(
2
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
VLOG
(
2
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
return
kernels
;
return
kernels
;
}
}
...
@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
...
@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope_
=
scope
;
scope_
=
scope
;
op_info_
.
reset
(
op_info_
.
reset
(
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
return
AttachImpl
(
opdesc
,
scope
);
return
AttachImpl
(
*
op_info
()
,
scope
);
}
}
const
Tensor
*
OpLite
::
GetTensor
(
lite
::
Scope
*
scope
,
const
Tensor
*
OpLite
::
GetTensor
(
lite
::
Scope
*
scope
,
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
4fe5c8aa
...
@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
...
@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
}
}
return
false
;
return
false
;
}
}
void
UpdateAllInputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
inputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
if
(
var
==
from
)
var
=
to
;
}
}
}
void
UpdateAllOutputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
outputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
if
(
var
==
from
)
var
=
to
;
}
}
}
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
4fe5c8aa
...
@@ -43,6 +43,8 @@ class Optimizer {
...
@@ -43,6 +43,8 @@ class Optimizer {
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
SetValidPlaces
(
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
InitTargetTypeTransformPass
();
InitTargetTypeTransformPass
();
...
@@ -51,6 +53,8 @@ class Optimizer {
...
@@ -51,6 +53,8 @@ class Optimizer {
"lite_quant_dequant_fuse_pass"
,
//
"lite_quant_dequant_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
"identity_scale_eliminate_pass"
,
//
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass"
,
//
"lite_elementwise_add_activation_fuse_pass"
,
//
#endif
#endif
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
4fe5c8aa
...
@@ -140,7 +140,7 @@ class RuntimeProgram {
...
@@ -140,7 +140,7 @@ class RuntimeProgram {
void
Run
()
{
void
Run
()
{
for
(
auto
&
inst
:
instructions_
)
{
for
(
auto
&
inst
:
instructions_
)
{
VLOG
(
4
)
<<
">> Running kernel: "
<<
inst
;
VLOG
(
3
)
<<
">> Running kernel: "
<<
inst
.
op
()
->
op_info
()
->
Repr
()
;
inst
.
Run
();
inst
.
Run
();
}
}
}
}
...
...
paddle/fluid/lite/core/tensor.h
浏览文件 @
4fe5c8aa
...
@@ -191,7 +191,6 @@ class TensorBase {
...
@@ -191,7 +191,6 @@ class TensorBase {
template
<
typename
TensorT
>
template
<
typename
TensorT
>
bool
TensorCompareWith
(
const
TensorT
&
a
,
const
TensorT
&
b
)
{
bool
TensorCompareWith
(
const
TensorT
&
a
,
const
TensorT
&
b
)
{
if
(
a
.
dims
()
!=
b
.
dims
())
return
false
;
if
(
a
.
dims
()
!=
b
.
dims
())
return
false
;
LOG
(
INFO
)
<<
"data_size: "
<<
a
.
data_size
();
if
(
memcmp
(
a
.
raw_data
(),
b
.
raw_data
(),
a
.
data_size
())
!=
0
)
return
false
;
if
(
memcmp
(
a
.
raw_data
(),
b
.
raw_data
(),
a
.
data_size
())
!=
0
)
return
false
;
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -117,6 +117,19 @@ TEST(elementwise_add, compute) {
...
@@ -117,6 +117,19 @@ TEST(elementwise_add, compute) {
operators
::
ElementwiseParam
param
;
operators
::
ElementwiseParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
#if 1
for
(
auto
n
:
{
1
,
3
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
4
})
{
for
(
auto
h
:
{
1
,
3
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#else
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
h
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
h
:
{
1
,
3
,
4
,
11
})
{
...
@@ -129,6 +142,7 @@ TEST(elementwise_add, compute) {
...
@@ -129,6 +142,7 @@ TEST(elementwise_add, compute) {
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#endif
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
...
@@ -192,6 +206,20 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
...
@@ -192,6 +206,20 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
operators
::
FusionElementwiseActivationParam
param
;
operators
::
FusionElementwiseActivationParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
#if 1
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
n
:
{
1
,
3
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
4
})
{
for
(
auto
h
:
{
1
,
3
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#else
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
...
@@ -206,6 +234,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
...
@@ -206,6 +234,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#endif
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
...
...
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -80,12 +80,19 @@ TEST(softmax_arm, compute) {
...
@@ -80,12 +80,19 @@ TEST(softmax_arm, compute) {
lite
::
Tensor
x
;
lite
::
Tensor
x
;
lite
::
Tensor
output
;
lite
::
Tensor
output
;
lite
::
Tensor
output_ref
;
lite
::
Tensor
output_ref
;
#if 1
for
(
auto
n
:
{
1
,
3
})
{
for
(
auto
c
:
{
1
,
4
})
{
for
(
auto
h
:
{
5
,
1
})
{
for
(
auto
w
:
{
1
,
6
})
{
for
(
auto
axis
:
{
-
2
,
-
1
,
0
,
1
,
2
})
{
#else
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
11
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
11
,
4
})
{
for
(
auto
h
:
{
3
,
1
,
11
,
4
})
{
for
(
auto
h
:
{
3
,
1
,
11
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
,
12
})
{
for
(
auto
w
:
{
1
,
3
,
4
,
12
})
{
for
(
auto
axis
:
{
-
4
,
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
3
})
{
for
(
auto
axis
:
{
-
4
,
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
3
})
{
#endif
x
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
x
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output_ref
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output_ref
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
...
...
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <iostream>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/kernels/x86/mul_compute.h
浏览文件 @
4fe5c8aa
...
@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
...
@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
const
Tensor
x_matrix
=
x
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
Tensor
x_matrix
,
y_matrix
;
*
x
,
param
.
x_num_col_dims
)
:
*
x
;
if
(
x
->
dims
().
size
()
>
2
)
{
const
Tensor
y_matrix
=
y
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
x_matrix
=
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
);
*
y
,
param
.
y_num_col_dims
)
}
else
{
:
*
y
;
x_matrix
=
*
x
;
}
if
(
y
->
dims
().
size
()
>
2
)
{
y_matrix
=
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
);
}
else
{
y_matrix
=
*
y
;
}
auto
*
z
=
&
param
.
output
->
raw_tensor
();
auto
*
z
=
&
param
.
output
->
raw_tensor
();
auto
z_dim
=
z
->
dims
();
auto
z_dim
=
z
->
dims
();
...
@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
...
@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
auto
x_matrix
=
x
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
)
Tensor
x_matrix
,
y_matrix
;
:
static_cast
<
const
Tensor
&>
(
*
x
);
auto
y_matrix
=
y
->
dims
().
size
()
>
2
if
(
x
->
dims
().
size
()
>
2
)
{
?
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
)
x_matrix
=
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
);
:
static_cast
<
const
Tensor
&>
(
*
y
);
}
else
{
x_matrix
=
*
x
;
}
if
(
y
->
dims
().
size
()
>
2
)
{
y_matrix
=
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
);
}
else
{
y_matrix
=
*
y
;
}
auto
*
dout
=
&
param
.
output_grad
->
raw_tensor
();
auto
*
dout
=
&
param
.
output_grad
->
raw_tensor
();
Tensor
dout_mat
;
Tensor
dout_mat
;
...
...
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <iostream>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
...
@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
endif
()
endif
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc
DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite
)
lite_cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
lite_cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
variable_lite scope_lite
${
tensor_lite
}
scope_lite
variable_lite scope_lite
${
tensor_lite
}
scope_lite
...
...
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
浏览文件 @
4fe5c8aa
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
paddle/fluid/lite/model_parser/desc_apis.h
浏览文件 @
4fe5c8aa
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <map>
#include <map>
#include <sstream>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -79,6 +80,27 @@ class OpDescAPI {
...
@@ -79,6 +80,27 @@ class OpDescAPI {
/// Get an attribute.
/// Get an attribute.
template
<
typename
T
>
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
T
GetAttr
(
const
std
::
string
&
name
)
const
;
std
::
string
Repr
()
const
{
std
::
stringstream
ss
;
ss
<<
Type
();
ss
<<
"("
;
for
(
auto
&
arg
:
InputArgumentNames
())
{
ss
<<
arg
<<
":"
;
for
(
auto
val
:
Input
(
arg
))
{
ss
<<
val
<<
" "
;
}
}
ss
<<
") -> ("
;
for
(
auto
&
arg
:
OutputArgumentNames
())
{
ss
<<
arg
<<
":"
;
for
(
auto
val
:
Output
(
arg
))
{
ss
<<
val
<<
" "
;
}
}
ss
<<
")"
;
return
ss
.
str
();
}
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
4fe5c8aa
...
@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
...
@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
template
<
typename
T
>
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
T
GetAttr
(
const
std
::
string
&
name
)
const
;
std
::
string
DebugString
()
const
{
return
desc_
.
DebugString
();
}
private:
private:
std
::
vector
<
std
::
string
>
GetArguments
(
std
::
vector
<
std
::
string
>
GetArguments
(
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
4fe5c8aa
...
@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
...
@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
CHECK
(
!
op_desc
.
Input
(
"X"
).
empty
());
CHECK
(
!
op_desc
.
Input
(
"Y"
).
empty
());
CHECK
(
!
op_desc
.
Output
(
"Out"
).
empty
());
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
*
var
=
scope
->
FindVar
(
input
);
auto
*
var
=
scope
->
FindVar
(
input
);
CHECK
(
var
);
CHECK
(
var
);
param_
.
x
=
var
->
GetMutable
<
Tensor
>
();
param_
.
x
=
&
var
->
Get
<
Tensor
>
();
var
=
scope
->
FindVar
(
W
);
var
=
scope
->
FindVar
(
W
);
CHECK
(
var
)
<<
"no var called "
<<
W
;
CHECK
(
var
)
<<
"no var called "
<<
W
;
param_
.
y
=
var
->
GetMutable
<
Tensor
>
();
param_
.
y
=
&
var
->
Get
<
Tensor
>
();
var
=
scope
->
FindVar
(
out
);
var
=
scope
->
FindVar
(
out
);
CHECK
(
var
)
<<
"no var called "
<<
out
;
CHECK
(
var
)
<<
"no var called "
<<
out
;
param_
.
output
=
var
->
GetMutable
<
Tensor
>
();
param_
.
output
=
var
->
GetMutable
<
Tensor
>
();
...
...
paddle/fluid/lite/operators/op_params.h
浏览文件 @
4fe5c8aa
...
@@ -67,8 +67,8 @@ struct ReluParam {
...
@@ -67,8 +67,8 @@ struct ReluParam {
// For Mul Op
// For Mul Op
struct
MulParam
{
struct
MulParam
{
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
int
x_num_col_dims
{
1
};
int
x_num_col_dims
{
1
};
...
...
paddle/fluid/lite/tools/build.sh
浏览文件 @
4fe5c8aa
...
@@ -54,22 +54,6 @@ function check_style {
...
@@ -54,22 +54,6 @@ function check_style {
fi
fi
}
}
function
cmake_arm
{
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake ..
\
-DWITH_GPU
=
OFF
\
-DWITH_MKL
=
OFF
\
-DWITH_LITE
=
ON
\
-DLITE_WITH_CUDA
=
OFF
\
-DLITE_WITH_X86
=
OFF
\
-DLITE_WITH_ARM
=
ON
\
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK
=
ON
\
-DWITH_TESTING
=
ON
\
-DARM_TARGET_OS
=
$1
-DARM_TARGET_ARCH_ABI
=
$2
-DARM_TARGET_LANG
=
$3
}
function
build_single
{
function
build_single
{
#make $1 -j$(expr $(nproc) - 2)
#make $1 -j$(expr $(nproc) - 2)
make
$1
-j
$NUM_CORES_FOR_COMPILE
make
$1
-j
$NUM_CORES_FOR_COMPILE
...
@@ -153,33 +137,53 @@ function test_arm_model {
...
@@ -153,33 +137,53 @@ function test_arm_model {
adb
-s
emulator-
${
port
}
shell
chmod
+x
"
${
adb_work_dir
}
/
${
test_name
}
"
adb
-s
emulator-
${
port
}
shell
chmod
+x
"
${
adb_work_dir
}
/
${
test_name
}
"
local
adb_model_path
=
"./
${
adb_work_dir
}
/
`
basename
${
model_dir
}
`
"
local
adb_model_path
=
"./
${
adb_work_dir
}
/
`
basename
${
model_dir
}
`
"
adb
-s
emulator-
${
port
}
shell
"./
${
adb_work_dir
}
/
${
test_name
}
--eval_model_dir=
$adb_model_path
"
adb
-s
emulator-
${
port
}
shell
"./
${
adb_work_dir
}
/
${
test_name
}
--eval_model_dir=
$adb_model_path
"
}
}
# Build the code and run lite arm tests. This is executed in the CI system.
function
cmake_arm
{
function
build_test_arm
{
# $1: ARM_TARGET_OS in "android" , "armlinux"
# 1. Build goes first
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake ..
\
-DWITH_GPU
=
OFF
\
-DWITH_MKL
=
OFF
\
-DWITH_LITE
=
ON
\
-DLITE_WITH_CUDA
=
OFF
\
-DLITE_WITH_X86
=
OFF
\
-DLITE_WITH_ARM
=
ON
\
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK
=
ON
\
-DWITH_TESTING
=
ON
\
-DARM_TARGET_OS
=
$1
-DARM_TARGET_ARCH_ABI
=
$2
-DARM_TARGET_LANG
=
$3
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
function
build_arm
{
os
=
$1
abi
=
$2
lang
=
$3
cur_dir
=
$(
pwd
)
cur_dir
=
$(
pwd
)
for
lang
in
"gcc"
"clang"
;
do
if
[[
${
os
}
==
"armlinux"
]]
;
then
for
os
in
"android"
"armlinux"
;
do
# TODO(hongming): enable compile armv7 and armv7hf on armlinux, and clang compile
if
[[
${
os
}
==
"armlinux"
&&
${
lang
}
==
"clang"
]]
;
then
if
[[
${
lang
}
==
"clang"
]]
;
then
continue
echo
"clang is not enabled on armlinux yet"
return
0
fi
fi
for
abi
in
"armv8"
"armv7"
"armv7hf"
;
do
# TODO(hongming): enable compile armv7 and armv7hf on armlinux
if
[[
${
abi
}
==
"armv7hf"
]]
;
then
if
[[
${
abi
}
==
"armv7hf"
]]
;
then
echo
"armv7hf is not supported on both android and
armlinux yet"
echo
"armv7hf is not supported on
armlinux yet"
continue
return
0
fi
fi
if
[[
${
abi
}
==
"armv7"
]]
;
then
# TODO(hongming): enable armv7 on armlinux
if
[[
${
os
}
==
"armlinux"
&&
${
abi
}
==
"armv7"
]]
;
then
echo
"armv7 is not supported on armlinux yet"
echo
"armv7 is not supported on armlinux yet"
continue
return
0
fi
fi
fi
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
echo
"android do not need armv7hf"
echo
"android do not need armv7hf"
continue
return
0
fi
fi
build_dir
=
$cur_dir
/build.lite.
${
os
}
.
${
abi
}
.
${
lang
}
build_dir
=
$cur_dir
/build.lite.
${
os
}
.
${
abi
}
.
${
lang
}
...
@@ -188,11 +192,47 @@ function build_test_arm {
...
@@ -188,11 +192,47 @@ function build_test_arm {
cmake_arm
${
os
}
${
abi
}
${
lang
}
cmake_arm
${
os
}
${
abi
}
${
lang
}
build
$TESTS_FILE
build
$TESTS_FILE
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
# $4: android test port
# Note: test must be in build dir
function
test_arm
{
os
=
$1
abi
=
$2
lang
=
$3
port
=
$4
if
[[
${
os
}
==
"armlinux"
]]
;
then
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
echo
"Skip test arm linux yet. armlinux must in another docker"
return
0
fi
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
echo
"android do not need armv7hf"
return
0
fi
# TODO(yuanshuai): enable armv7 on android
if
[[
${
abi
}
==
"armv7"
]]
;
then
echo
"skip android v7 test yet"
return
0
fi
echo
"test file:
${
TESTS_FILE
}
"
for
_test
in
$(
cat
$TESTS_FILE
)
;
do
test_arm_android
$_test
$port
done
done
done
# TODO(sangoly): refine this
done
test_arm_model
"test_cxx_api_lite"
$port
"./third_party/install/mobilenet_v2_relu"
}
# 2. Then test
# Build the code and run lite arm tests. This is executed in the CI system.
function
build_test_arm
{
########################################################################
# job 1-4 must be in one runner
port_armv8
=
5554
port_armv8
=
5554
port_armv7
=
5556
port_armv7
=
5556
...
@@ -206,39 +246,46 @@ function build_test_arm {
...
@@ -206,39 +246,46 @@ function build_test_arm {
echo
-ne
'\n'
|
${
ANDROID_HOME
}
/emulator/emulator
-avd
paddle-armv7
-noaudio
-no-window
-gpu
off
-verbose
-port
${
port_armv7
}
&
echo
-ne
'\n'
|
${
ANDROID_HOME
}
/emulator/emulator
-avd
paddle-armv7
-noaudio
-no-window
-gpu
off
-verbose
-port
${
port_armv7
}
&
sleep
1m
sleep
1m
# now can only test android.
# job 1
for
lang
in
"gcc"
"clang"
;
do
build_arm
"android"
"armv8"
"gcc"
for
abi
in
"armv8"
"armv7"
;
do
test_arm
"android"
"armv8"
"gcc"
${
port_armv8
}
# TODO(yuanshuai): enable armv7 on android
cd
-
if
[[
${
abi
}
==
"armv7"
]]
;
then
continue
fi
build_dir
=
$cur_dir
/build.lite.android.
${
abi
}
.
${
lang
}
# job 2
cd
$build_dir
build_arm
"android"
"armv8"
"clang"
test_arm
"android"
"armv8"
"clang"
${
port_armv8
}
cd
-
local
port
=
# job 3
if
[[
${
abi
}
==
"armv7"
]]
;
then
build_arm
"android"
"armv7"
"gcc"
port
=
${
port_armv7
}
test_arm
"android"
"armv7"
"gcc"
${
port_armv7
}
fi
cd
-
if
[[
${
abi
}
==
"armv8"
]]
;
then
# job 4
port
=
${
port_armv8
}
build_arm
"android"
"armv7"
"clang"
fi
test_arm
"android"
"armv7"
"clang"
${
port_armv7
}
echo
"test file:
${
TESTS_FILE
}
"
cd
-
for
_test
in
$(
cat
$TESTS_FILE
)
;
do
test_arm_android
$_test
$port
done
# TODO(sangoly): refine this
test_arm_model
"test_cxx_api_lite"
$port
"./third_party/install/mobilenet_v2_relu"
done
done
# armlinux need in another docker
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
adb devices |
grep
emulator |
cut
-f1
|
while
read
line
;
do
adb
-s
$line
emu
kill
;
done
adb devices |
grep
emulator |
cut
-f1
|
while
read
line
;
do
adb
-s
$line
emu
kill
;
done
echo
"Done"
echo
"Done"
########################################################################
# job 5
build_arm
"armlinux"
"armv8"
test_arm
"armlinux"
"armv8"
cd
-
# job 6
build_arm
"armlinux"
"armv7"
test_arm
"armlinux"
"armv7"
cd
-
# job 7
build_arm
"armlinux"
"armv7hf"
test_arm
"armlinux"
"armv7hf"
cd
-
echo
"Done"
}
}
############################# MAIN #################################
############################# MAIN #################################
...
@@ -279,6 +326,10 @@ function main {
...
@@ -279,6 +326,10 @@ function main {
ARM_ABI
=
"
${
i
#*=
}
"
ARM_ABI
=
"
${
i
#*=
}
"
shift
shift
;;
;;
--arm_lang
=
*
)
ARM_LANG
=
"
${
i
#*=
}
"
shift
;;
--arm_port
=
*
)
--arm_port
=
*
)
ARM_PORT
=
"
${
i
#*=
}
"
ARM_PORT
=
"
${
i
#*=
}
"
shift
shift
...
@@ -301,13 +352,21 @@ function main {
...
@@ -301,13 +352,21 @@ function main {
shift
shift
;;
;;
cmake_arm
)
cmake_arm
)
cmake_arm
$ARM_OS
$ARM_ABI
cmake_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
shift
;;
build_arm
)
build_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
shift
shift
;;
;;
test_server
)
test_server
)
test_lite
$TESTS_FILE
test_lite
$TESTS_FILE
shift
shift
;;
;;
test_arm
)
build_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
$ARM_PORT
shift
;;
test_arm_android
)
test_arm_android
)
test_arm_android
$TEST_NAME
$ARM_PORT
test_arm_android
$TEST_NAME
$ARM_PORT
shift
shift
...
...
paddle/fluid/lite/utils/varient.h
浏览文件 @
4fe5c8aa
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <typeinfo>
#include <typeinfo>
#include <utility>
#include <utility>
#include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/string.h"
// This is an equivalent implementation of boost::any. We implement this to
// This is an equivalent implementation of boost::any. We implement this to
// avoid including the whole boost library and keep the inference library small.
// avoid including the whole boost library and keep the inference library small.
...
@@ -116,9 +117,9 @@ struct variant {
...
@@ -116,9 +117,9 @@ struct variant {
if
(
type_id
==
typeid
(
T
).
hash_code
())
if
(
type_id
==
typeid
(
T
).
hash_code
())
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
else
else
throw
std
::
invalid_argument
(
"unmatched type"
);
throw
std
::
invalid_argument
(
// LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
string_format
(
"unmatched type, store as %d, but want to get %s"
,
// << typeid(T).name(
);
type_id
,
typeid
(
T
).
name
())
);
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录