Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
4fe5c8aa
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4fe5c8aa
编写于
6月 20, 2019
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'incubate/lite' of
http://10.87.145.36/inference/paddlelite
into xzl/incubate/lite
fix comments
上级
8ead391d
6f705068
变更
65
显示空白变更内容
内联
并排
Showing
65 changed file
with
859 addition
and
224 deletion
+859
-224
CMakeLists.txt
CMakeLists.txt
+2
-2
paddle/fluid/lite/api/apis_test.cc
paddle/fluid/lite/api/apis_test.cc
+1
-1
paddle/fluid/lite/api/cxx_api_bin.cc
paddle/fluid/lite/api/cxx_api_bin.cc
+1
-1
paddle/fluid/lite/api/lite_api_test_helper.cc
paddle/fluid/lite/api/lite_api_test_helper.cc
+1
-0
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+0
-1
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+24
-20
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
+7
-0
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
...ite/core/mir/elimination/identity_scale_eliminate_pass.cc
+72
-0
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
...ore/mir/elimination/identity_scale_eliminate_pass_test.cc
+93
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+0
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc
...e/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h
...re/mir/fusion/conv_elementwise_add_activation_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc
.../fusion/conv_elementwise_add_activation_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
.../core/mir/fusion/conv_elementwise_add_activation_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
...te/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
+39
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
...ite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
...re/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
+153
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
...e/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h
...te/core/mir/fusion/elementwise_add_activation_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc
...e/mir/fusion/elementwise_add_activation_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
.../lite/core/mir/fusion/elementwise_add_activation_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+2
-2
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
+1
-1
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+2
-2
paddle/fluid/lite/core/mir/node.cc
paddle/fluid/lite/core/mir/node.cc
+59
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+32
-34
paddle/fluid/lite/core/mir/pattern_matcher.h
paddle/fluid/lite/core/mir/pattern_matcher.h
+2
-3
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
+1
-0
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
+7
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+2
-2
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
+13
-13
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+5
-0
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+6
-6
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+8
-8
paddle/fluid/lite/core/mir/use_passes.h
paddle/fluid/lite/core/mir/use_passes.h
+3
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+6
-4
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+1
-2
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+16
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+4
-0
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+1
-1
paddle/fluid/lite/core/tensor.h
paddle/fluid/lite/core/tensor.h
+0
-1
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
+29
-0
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
+8
-1
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
+2
-0
paddle/fluid/lite/kernels/x86/mul_compute.h
paddle/fluid/lite/kernels/x86/mul_compute.h
+30
-12
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
+2
-0
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+2
-1
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
+0
-1
paddle/fluid/lite/model_parser/desc_apis.h
paddle/fluid/lite/model_parser/desc_apis.h
+22
-0
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+2
-0
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+6
-2
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+2
-2
paddle/fluid/lite/tools/build.sh
paddle/fluid/lite/tools/build.sh
+140
-81
paddle/fluid/lite/utils/varient.h
paddle/fluid/lite/utils/varient.h
+4
-3
未找到文件。
CMakeLists.txt
浏览文件 @
4fe5c8aa
...
...
@@ -44,9 +44,9 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# check arch abi
if
(
NOT DEFINED ARM_TARGET_LANG
)
set
(
ARM_TARGET_LANG
"
clang
"
CACHE STRING
"Choose ARM Target Language"
)
set
(
ARM_TARGET_LANG
"
gcc
"
CACHE STRING
"Choose ARM Target Language"
)
endif
()
set
(
ARM_TARGET_LANG_LIST
"gcc"
"clang"
)
set
(
ARM_TARGET_LANG_LIST
"gcc"
"clang"
""
)
set_property
(
CACHE ARM_TARGET_LANG PROPERTY STRINGS
${
ARM_TARGET_LANG_LIST
}
)
if
(
NOT ARM_TARGET_LANG IN_LIST ARM_TARGET_LANG_LIST
)
message
(
FATAL_ERROR
"ARM_TARGET_LANG must be in one of
${
ARM_TARGET_LANG_LIST
}
"
)
...
...
paddle/fluid/lite/api/apis_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
ASSERT_TRUE
(
TensorCompareWith
(
*
cxx_out
,
*
light_out
));
std
::
vector
<
std
::
string
>
tensors_with_order
({
"a"
,
"fc_0.w_0"
,
"
fc_0.tmp_0"
,
"
scale_0.tmp_0"
,
"a"
,
"fc_0.w_0"
,
"scale_0.tmp_0"
,
});
for
(
const
auto
&
tensor_name
:
tensors_with_order
)
{
...
...
paddle/fluid/lite/api/cxx_api_bin.cc
浏览文件 @
4fe5c8aa
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include <chrono>
// NOLINT
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/api/lite_api_test_helper.cc
浏览文件 @
4fe5c8aa
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include <vector>
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
...
@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test
(
test_types_lite SRCS types_test.cc DEPS types_lite
)
lite_cc_test
(
test_memory_lite SRCS memory_test.cc DEPS memory_lite
)
lite_cc_test
(
test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator
)
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
...
@@ -5,12 +5,16 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
add_subdirectory
(
fusion
)
add_subdirectory
(
elimination
)
cc_library
(
mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_activation_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
quant_dequant_fuse_pass.cc
SRCS
fusion/fc_fuse_pass.cc
fusion/conv_elementwise_add_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
...
...
@@ -74,7 +78,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
lite_cc_test
(
test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
lite_cc_test
(
test_lite_fc_fuse SRCS f
usion/f
c_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
arm_kernels
}
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
...
...
@@ -85,10 +89,10 @@ add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test
(
test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
SRCS
fusion/
conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
lite_cc_test
(
test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
SRCS
fusion/
elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
0 → 100644
浏览文件 @
4fe5c8aa
if
(
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
lite_cc_test
(
test_identity_scale_eliminate_pass_lite
SRCS identity_scale_eliminate_pass_test.cc
DEPS mir_passes program_lite proto_desc cpp_op_desc_lite
${
ops_lite
}
)
endif
()
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
{
class
Eliminator
:
public
FuseBase
{
public:
void
BuildPattern
()
override
{
auto
*
pre_op
=
OpNode
(
"preop"
);
// the previous op's output need update
// TODO(Superjomn) check has only one output
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"scale"
,
"X"
);
auto
*
scale_op
=
OpNode
(
"scale"
,
"scale"
)
->
assert_op_attr
<
float
>
(
"scale"
,
1.
)
->
assert_op_attr
<
float
>
(
"bias"
,
0.
);
auto
*
out
=
VarNode
(
"out"
)
->
assert_is_op_output
(
"scale"
,
"Out"
);
*
pre_op
>>
*
x
>>
*
scale_op
>>
*
out
;
// The pre_op will be eliminated, and a new output-updated op will insert.
x
->
AsIntermediate
();
// x is pre_op's output, need to update
}
private:
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
auto
&
pre_op
=
matched
.
at
(
"preop"
)
->
AsStmt
();
auto
op_info
=
*
pre_op
.
op_info
();
op_info
.
UpdateAllOutputs
(
matched
.
at
(
"x"
)
->
AsArg
().
name
,
matched
.
at
(
"out"
)
->
AsArg
().
name
);
pre_op
.
ResetOp
(
op_info
,
graph
->
valid_places
());
GraphSafeRemoveNodes
(
graph
,
{
matched
.
at
(
"scale"
)});
IR_NODE_LINK_TO
(
matched
.
at
(
"preop"
),
matched
.
at
(
"out"
));
}
};
}
// namespace
class
IdentityScaleEliminatePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
{
Eliminator
eliminator
;
eliminator
(
graph
.
get
());
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
identity_scale_eliminate_pass
,
paddle
::
lite
::
mir
::
IdentityScaleEliminatePass
);
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
// Op list:
// (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch)
// After pass
// (x)->feed->(scale_out)->fetch->(fetch)
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
feed_op
=
main_block
->
AppendOp
();
auto
*
scale_op
=
main_block
->
AppendOp
();
auto
*
fetch_op
=
main_block
->
AppendOp
();
main_block
->
Var
(
"x"
);
main_block
->
Var
(
"feed"
);
main_block
->
Var
(
"scale_out"
);
main_block
->
Var
(
"fetch_out"
);
scope
->
Var
(
"x"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"feed"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"scale_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"fetch_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
feed_op
->
SetType
(
"feed"
);
feed_op
->
SetInput
(
"X"
,
{
"x"
});
feed_op
->
SetAttr
(
"col"
,
1
);
feed_op
->
SetOutput
(
"Out"
,
{
"feed"
});
scale_op
->
SetType
(
"scale"
);
scale_op
->
SetInput
(
"X"
,
{
"feed"
});
scale_op
->
SetOutput
(
"Out"
,
{
"scale_out"
});
scale_op
->
SetAttr
(
"scale"
,
1.
f
);
scale_op
->
SetAttr
(
"bias"
,
0.
f
);
scale_op
->
SetAttr
(
"bias_after_scale"
,
true
);
fetch_op
->
SetType
(
"fetch"
);
fetch_op
->
SetInput
(
"X"
,
{
"scale_out"
});
fetch_op
->
SetOutput
(
"Out"
,
{
"fetch"
});
fetch_op
->
SetAttr
(
"col"
,
1
);
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
LOG
(
INFO
)
<<
Visualize
(
graph
.
get
());
return
graph
;
}
TEST
(
identity_test
,
test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
pass
=
PassManager
::
Global
().
LookUp
(
"identity_scale_eliminate_pass"
);
ASSERT_TRUE
(
pass
);
pass
->
Apply
(
graph
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
2UL
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
feed
)
USE_LITE_OP
(
fetch
)
USE_LITE_OP
(
scale
)
USE_MIR_PASS
(
identity_scale_eliminate_pass
)
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
...
@@ -10,7 +10,6 @@ cc_library(fuse_conv_bn
cc_library
(
fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_quant_dequant
SRCS quant_dequant_op_fuser.cc
DEPS pattern_matcher_high_api
)
...
...
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
...
...
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
浏览文件 @
4fe5c8aa
...
...
@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
void
ConvBNFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
eltwise_op
=
LiteOpRegistry
::
Global
().
Create
(
"elementwise_add"
);
auto
conv
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
conv
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
conv
->
scope
();
auto
&
valid_places
=
conv
->
valid_places
();
...
...
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
...
...
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
浏览文件 @
4fe5c8aa
...
...
@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
conv_old
->
scope
();
auto
&
valid_places
=
conv_old
->
valid_places
();
conv_op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAddReLUFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAddReLUFuser
fuser
(
"conv2d"
);
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
depthwise_fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAddReLUFusePass
);
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddReLUFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
conv2d_1
=
main_block
->
AppendOp
();
auto
*
conv2d_2
=
main_block
->
AppendOp
();
auto
*
add_1
=
main_block
->
AppendOp
();
auto
*
relu_1
=
main_block
->
AppendOp
();
auto
*
add_2
=
main_block
->
AppendOp
();
auto
*
relu_2
=
main_block
->
AppendOp
();
main_block
->
Var
(
"input_1"
);
main_block
->
Var
(
"input_2"
);
main_block
->
Var
(
"filter_1"
);
main_block
->
Var
(
"filter_2"
);
main_block
->
Var
(
"conv2d_1_out"
);
main_block
->
Var
(
"conv2d_2_out"
);
main_block
->
Var
(
"bias_1"
);
main_block
->
Var
(
"add_1_out"
);
main_block
->
Var
(
"add_2_out"
);
main_block
->
Var
(
"relu_1_out"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"input_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"input_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"bias_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"relu_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
conv2d_1
->
SetType
(
"conv2d"
);
conv2d_1
->
SetInput
(
"Input"
,
{
"input_1"
});
conv2d_1
->
SetInput
(
"Filter"
,
{
"filter_1"
});
conv2d_1
->
SetOutput
(
"Output"
,
{
"conv2d_1_out"
});
conv2d_1
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_1
->
SetAttr
(
"groups"
,
1
);
conv2d_1
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"fuse_relu"
,
false
);
add_1
->
SetType
(
"elementwise_add"
);
add_1
->
SetInput
(
"X"
,
{
"conv2d_1_out"
});
add_1
->
SetInput
(
"Y"
,
{
"bias_1"
});
add_1
->
SetOutput
(
"Out"
,
{
"add_1_out"
});
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"X"
,
{
"add_1_out"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_1_out"
});
conv2d_2
->
SetType
(
"conv2d"
);
conv2d_2
->
SetInput
(
"Input"
,
{
"input_2"
});
conv2d_2
->
SetInput
(
"Filter"
,
{
"filter_2"
});
conv2d_2
->
SetOutput
(
"Output"
,
{
"conv2d_2_out"
});
conv2d_2
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_2
->
SetAttr
(
"groups"
,
1
);
conv2d_2
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"fuse_relu"
,
false
);
add_2
->
SetType
(
"elementwise_add"
);
add_2
->
SetInput
(
"X"
,
{
"conv2d_2_out"
});
add_2
->
SetInput
(
"Y"
,
{
"relu_1_out"
});
add_2
->
SetOutput
(
"Out"
,
{
"add_2_out"
});
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"X"
,
{
"add_2_out"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
return
graph
;
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
11UL
/*vars*/
+
6UL
/*ops*/
);
Visualize
(
graph
.
get
());
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
fuse_test_op
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ConvElementwiseAddReLUFusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused fc node*/
);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
conv2d
);
USE_LITE_OP
(
depthwise_conv2d
);
USE_LITE_OP
(
relu
);
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
...
...
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
浏览文件 @
4fe5c8aa
...
...
@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
"fusion_elementwise_add_activation"
);
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
;
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
old_op
->
scope
();
auto
&
valid_places
=
old_op
->
valid_places
();
op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
...
...
paddle/fluid/lite/core/mir/fc_fuse_pass.h
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
浏览文件 @
4fe5c8aa
...
...
@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
void
FcFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
;
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
mul
->
scope
();
auto
&
valid_places
=
mul
->
valid_places
();
fc_op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
...
...
paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
浏览文件 @
4fe5c8aa
...
...
@@ -115,8 +115,8 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
nodes
.
push_back
(
matched
.
at
(
"dequant_op_out"
+
std
::
to_string
(
i
)));
}
int
bit_length
=
quant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
auto
*
scope
=
quant_op
->
stmt
()
->
op
->
scope
();
auto
&
valid_places
=
quant_op
->
stmt
()
->
op
->
valid_places
();
auto
*
scope
=
quant_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quant_op
->
stmt
()
->
op
()
->
valid_places
();
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
input_scale_t
=
scope
->
FindVar
(
quant_op_in_scale
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
VLOG
(
4
)
<<
stmt
;
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
insts_
.
emplace_back
(
stmt
.
op
(),
std
::
move
(
stmt
.
kernels
()
.
front
()));
}
}
}
...
...
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
if
(
node
.
IsArg
())
{
key
=
node
.
AsArg
().
name
;
}
else
{
key
=
node
.
AsStmt
().
op_type
+
std
::
to_string
(
id
++
);
key
=
node
.
AsStmt
().
op_type
()
+
std
::
to_string
(
id
++
);
}
if
(
node
.
IsStmt
())
{
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsStmt
())
continue
;
auto
&
inst
=
node
.
AsStmt
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
if
(
inst
.
op_type
()
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
AsStmt
().
valid_kernels
;
auto
&
kernels
=
node
.
AsStmt
().
kernels
()
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
...
...
paddle/fluid/lite/core/mir/node.cc
浏览文件 @
4fe5c8aa
...
...
@@ -13,3 +13,62 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
const
OpInfo
*
mir
::
Node
::
Stmt
::
op_info
()
const
{
CHECK
(
op_
);
return
op_
->
op_info
();
}
Place
mir
::
Node
::
Stmt
::
place
()
const
{
CHECK
(
!
valid_kernels_
.
empty
());
return
valid_kernels_
.
front
()
->
place
();
}
KernelBase
&
mir
::
Node
::
Stmt
::
picked_kernel
()
{
CHECK
(
!
valid_kernels_
.
empty
())
<<
"no kernel for "
<<
op_type
();
return
*
valid_kernels_
.
front
();
}
OpInfo
*
mir
::
Node
::
Stmt
::
mutable_op_info
()
{
CHECK
(
op_
);
return
op_
->
mutable_op_info
();
}
void
mir
::
Node
::
Stmt
::
ResetOp
(
const
cpp
::
OpDesc
&
op_desc
,
const
std
::
vector
<
Place
>
&
valid_places
,
lite
::
Scope
*
scope
)
{
CHECK
((
op_
&&
op_
->
scope
())
||
scope
)
<<
"Either scope should be set"
;
lite
::
Scope
*
the_scope
=
scope
?
scope
:
op_
->
scope
();
op_
->
Attach
(
op_desc
,
the_scope
);
// Recreate the kernels with the latest OpInfo.
valid_kernels_
.
clear
();
if
(
!
op_
||
op_
->
op_info
()
->
Type
()
!=
op_desc
.
Type
())
{
op_
=
LiteOpRegistry
::
Global
().
Create
(
op_desc
.
Type
());
CHECK
(
op_
)
<<
"No op found for "
<<
op_desc
.
Type
();
}
valid_kernels_
=
op_
->
CreateKernels
(
valid_places
);
}
std
::
ostream
&
mir
::
operator
<<
(
std
::
ostream
&
os
,
const
mir
::
Node
::
Stmt
&
other
)
{
os
<<
"Statement "
<<
other
.
op_type
()
<<
" "
<<
other
.
place
();
return
os
;
}
mir
::
Node
::
Arg
&
mir
::
Node
::
AsArg
(
const
std
::
string
&
name
,
int
id
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
id
=
id
;
return
x
;
}
mir
::
Node
::
Arg
&
mir
::
Node
::
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/node.h
浏览文件 @
4fe5c8aa
...
...
@@ -41,32 +41,40 @@ class Node {
kUnk
,
};
struct
Stmt
{
std
::
string
op_type
;
class
Stmt
{
// The kernel instances this Statement contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
_
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
std
::
shared_ptr
<
OpLite
>
op
_
;
// we hold op to run InferShape
const
OpInfo
*
op_info
()
{
CHECK
(
op
);
return
op
->
op_info
();
}
public:
// Refresh the operator and kernels with the latest OpInfo.
void
ResetOp
(
const
cpp
::
OpDesc
&
op_desc
,
const
std
::
vector
<
Place
>&
valid_places
,
lite
::
Scope
*
scope
=
nullptr
);
Place
place
()
const
{
CHECK
(
!
valid_kernels
.
empty
());
return
valid_kernels
.
front
()
->
place
();
}
std
::
string
op_type
()
const
{
return
op_info
()
->
Type
();
}
const
OpInfo
*
op_info
()
const
;
OpInfo
*
mutable_op_info
();
KernelBase
&
picked_kernel
()
{
CHECK
(
!
valid_kernels
.
empty
())
<<
"no kernel for "
<<
op_type
;
return
*
valid_kernels
.
front
();
void
SetKernels
(
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
)
{
valid_kernels_
=
std
::
move
(
kernels
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stmt
&
other
)
{
os
<<
"Statement "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&
kernels
()
{
return
valid_kernels_
;
}
void
SetOp
(
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
op_
=
op
;
}
const
std
::
shared_ptr
<
OpLite
>
op
()
const
{
return
op_
;
}
Place
place
()
const
;
KernelBase
&
picked_kernel
();
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stmt
&
other
);
// Description.
std
::
string
desc
;
};
struct
Arg
{
...
...
@@ -78,26 +86,16 @@ class Node {
bool
is_weight
{
false
};
};
Arg
&
AsArg
(
const
std
::
string
&
name
,
int
id
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
id
=
id
;
return
x
;
}
Arg
&
AsArg
(
const
std
::
string
&
name
,
int
id
);
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
Arg
&
AsArg
(
const
std
::
string
&
name
);
Stmt
&
AsStmt
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
AsStmt
();
x
.
op_type
=
op_type
;
x
.
op
=
op
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
x
.
SetOp
(
op
);
x
.
SetKernels
(
std
::
move
(
kernels
));
return
x
;
}
...
...
@@ -142,7 +140,7 @@ class Node {
}
if
(
other
.
IsStmt
())
{
auto
&
arg
=
other
.
AsStmt
();
os
<<
"Statement "
<<
arg
.
op_type
;
os
<<
"Statement "
<<
arg
.
op_type
()
;
}
return
os
;
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher.h
浏览文件 @
4fe5c8aa
...
...
@@ -139,14 +139,13 @@ struct PMNode {
template
<
typename
T
>
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
asserts_
.
push_back
([
=
](
const
Node
*
x
)
{
if
(
x
&&
x
->
IsStmt
())
{
auto
*
op_info
=
x
->
stmt
()
->
op_info
();
return
op_info
->
HasAttr
(
attr_name
)
&&
op_info
->
GetAttr
<
T
>
(
attr_name
)
==
attr
;
}
else
{
return
false
;
}
return
false
;
});
return
this
;
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
浏览文件 @
4fe5c8aa
...
...
@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
LOG
(
INFO
)
<<
"keys: "
<<
key2nodes_
.
size
();
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
for
(
auto
&
matched
:
key2nodes_
)
{
for
(
const
auto
&
key
:
keys
)
{
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
浏览文件 @
4fe5c8aa
...
...
@@ -49,7 +49,13 @@ class FuseBase {
virtual
void
BuildPattern
()
=
0
;
// Generate an operator desc with a matched subgraph.
virtual
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
=
0
;
virtual
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
return
cpp
::
OpDesc
();
}
PMNode
*
OpNode
(
const
std
::
string
&
key
)
{
return
GetOrCreateNode
(
key
)
->
assert_is_op
();
}
PMNode
*
OpNode
(
const
std
::
string
&
key
,
const
std
::
string
&
op_type
);
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
;
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
mul
->
scope
();
auto
&
valid_places
=
mul
->
valid_places
();
fc_op
->
Attach
(
op_desc
,
scope
);
...
...
@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"
w
"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"
x
"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
...
...
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -23,19 +23,19 @@ namespace mir {
void
BuildGraph
(
SSAGraph
*
g
)
{
g
->
mutable_nodes
().
emplace_back
();
Node
&
o1
=
g
->
mutable_nodes
().
back
();
o1
.
AsStmt
().
op_type
=
"op1"
;
o1
.
AsStmt
().
desc
=
"op1"
;
g
->
mutable_nodes
().
emplace_back
();
Node
&
o2
=
g
->
mutable_nodes
().
back
();
o2
.
AsStmt
().
op_type
=
"op2"
;
o2
.
AsStmt
().
desc
=
"op2"
;
g
->
mutable_nodes
().
emplace_back
();
Node
&
o3
=
g
->
mutable_nodes
().
back
();
o3
.
AsStmt
().
op_type
=
"op3"
;
o3
.
AsStmt
().
desc
=
"op3"
;
g
->
mutable_nodes
().
emplace_back
();
Node
&
o4
=
g
->
mutable_nodes
().
back
();
o4
.
AsStmt
().
op_type
=
"op4"
;
o4
.
AsStmt
().
desc
=
"op4"
;
g
->
mutable_nodes
().
emplace_back
();
Node
&
o5
=
g
->
mutable_nodes
().
back
();
o5
.
AsStmt
().
op_type
=
"op5"
;
o5
.
AsStmt
().
desc
=
"op5"
;
g
->
mutable_nodes
().
emplace_back
();
Node
&
v1
=
g
->
mutable_nodes
().
back
();
v1
.
AsArg
(
"var1"
);
...
...
@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
// v2 -> o3(a node named o3)
auto
*
o2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op2"
;
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
desc
==
"op2"
;
});
auto
*
o3
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op3"
;
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
desc
==
"op3"
;
});
auto
*
v2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
...
...
@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
// op -> var
auto
*
any_op
=
x
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
node
)
{
return
node
->
IsStmt
()
&&
(
node
->
stmt
()
->
op_type
==
"op2"
||
node
->
stmt
()
->
op_type
==
"op3"
);
return
node
->
IsStmt
()
&&
(
node
->
stmt
()
->
desc
==
"op2"
||
node
->
stmt
()
->
desc
==
"op3"
);
},
"OP0"
);
auto
*
any_var
=
...
...
@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
int
count
=
0
;
PatternMatcher
::
handle_t
handle
=
[
&
](
const
PatternMatcher
::
subgraph_t
&
s
,
SSAGraph
*
g
)
{
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
stmt
()
->
op_type
<<
" -> "
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
stmt
()
->
desc
<<
" -> "
<<
s
.
at
(
any_var
)
->
arg
()
->
name
<<
" -> "
<<
s
.
at
(
any_op1
)
->
stmt
()
->
op_type
;
<<
s
.
at
(
any_op1
)
->
stmt
()
->
desc
;
count
++
;
};
...
...
@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
PatternMatcher
matcher
;
auto
*
op2
=
matcher
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
x
)
{
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
op_type
==
"op2"
;
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
desc
==
"op2"
;
},
"op2"
);
auto
*
op3
=
matcher
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
x
)
{
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
op_type
==
"op3"
;
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
desc
==
"op3"
;
},
"op3"
);
auto
*
v2
=
matcher
.
mutable_pattern
()
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
4fe5c8aa
...
...
@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
Node
*
GraphCreateInstructNode
(
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
);
// Device related attributes
const
std
::
vector
<
Place
>
&
valid_places
()
const
{
return
valid_places_
;
}
void
SetValidPlaces
(
const
std
::
vector
<
Place
>
&
x
)
{
valid_places_
=
x
;
}
private:
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
// Check the bidirectional connection.
...
...
@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
vector
<
Place
>
valid_places_
;
};
// Remove the link between a -> b.
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if
(
!
node
.
IsStmt
())
continue
;
auto
&
instruct
=
node
.
AsStmt
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
CHECK
(
!
instruct
.
valid_kernels
.
empty
())
<<
"No kernels found for "
<<
instruct
.
op_type
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
CHECK
(
!
instruct
.
kernels
()
.
empty
())
<<
"No kernels found for "
<<
instruct
.
op_type
()
;
for
(
auto
&&
kernel
:
instruct
.
kernels
()
)
{
size_t
score
=
KernelGrade
(
*
kernel
);
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
}
...
...
@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
VLOG
(
2
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
instruct
.
kernels
()
.
clear
();
instruct
.
kernels
()
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
VLOG
(
2
)
<<
"pick "
<<
instruct
.
kernels
()
.
front
()
->
name
();
}
}
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
4fe5c8aa
...
...
@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
CHECK
(
in
->
AsArg
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
" for kernel "
<<
inst
.
op
()
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
,
graph
,
inst_node
,
...
...
@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
// CHECK(io_copy_op);
// Create the new var manually.
inst_node
->
AsStmt
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
AsStmt
().
op
()
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
cpp
::
OpDesc
op_desc
;
...
...
@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc
.
SetInput
(
"Input"
,
{
in
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
->
scope
());
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
()
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
AsStmt
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
...
...
@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
in
->
AsArg
().
name
,
UpdateInputTo
(
inst_node
->
AsStmt
().
op
()
->
mutable_op_info
(),
in
->
AsArg
().
name
,
io_copy_output_name
);
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
inst_node
->
AsStmt
().
op
->
scope
());
inst_node
->
AsStmt
().
ResetOp
(
*
inst_node
->
AsStmt
().
op_info
(),
graph
->
valid_places
());
std
::
string
tmp
;
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
for
(
auto
&
kernel
:
inst_node
->
AsStmt
().
valid_kernels
)
{
inst_node
->
AsStmt
().
op
->
AttachKernel
(
kernel
.
get
());
for
(
auto
&
kernel
:
inst_node
->
AsStmt
().
kernels
()
)
{
inst_node
->
AsStmt
().
op
()
->
AttachKernel
(
kernel
.
get
());
}
graph
->
CheckValid
();
...
...
paddle/fluid/lite/core/mir/use_passes.h
浏览文件 @
4fe5c8aa
...
...
@@ -23,9 +23,11 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
identity_scale_eliminate_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_quant_dequant_fuse_pass
);
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
4fe5c8aa
...
...
@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
if
(
v
->
IsStmt
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsStmt
().
op_type
;
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsStmt
().
op_type
()
;
continue
;
}
}
...
...
@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
auto
&
x
:
graph
->
StmtTopologicalOrder
())
{
auto
&
inst
=
x
->
AsStmt
();
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
if
(
inst
.
op_type
()
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
VLOG
(
4
)
<<
"
inferencing op "
<<
inst
.
op_type
;
VLOG
(
4
)
<<
"
Infering op "
<<
inst
.
op_info
()
->
Repr
()
;
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto
get_argname
=
[
&
](
...
...
@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
}
}
VLOG
(
3
)
<<
"inst "
<<
inst
.
op_info
()
->
Repr
();
for
(
auto
*
x_out
:
x
->
outlinks
)
{
std
::
string
node_name
=
x_out
->
AsArg
().
name
;
std
::
string
arg_name
=
get_argname
(
node_name
,
inst
.
op_info
()
->
outputs
());
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
<<
node_name
;
<<
node_name
<<
" in Inst "
<<
inst
.
op_type
();
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
if
(
!
x_out
->
AsArg
().
type
)
{
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
4fe5c8aa
...
...
@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets
.
insert
(
place
.
target
);
}
// CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG
(
2
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
return
kernels
;
}
...
...
@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
return
AttachImpl
(
opdesc
,
scope
);
return
AttachImpl
(
*
op_info
()
,
scope
);
}
const
Tensor
*
OpLite
::
GetTensor
(
lite
::
Scope
*
scope
,
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
4fe5c8aa
...
...
@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
}
return
false
;
}
void
UpdateAllInputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
inputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
if
(
var
==
from
)
var
=
to
;
}
}
}
void
UpdateAllOutputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
outputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
if
(
var
==
from
)
var
=
to
;
}
}
}
};
}
// namespace lite
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
4fe5c8aa
...
...
@@ -43,6 +43,8 @@ class Optimizer {
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
SetValidPlaces
(
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
InitTargetTypeTransformPass
();
...
...
@@ -51,6 +53,8 @@ class Optimizer {
"lite_quant_dequant_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
"identity_scale_eliminate_pass"
,
//
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass"
,
//
#endif
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
4fe5c8aa
...
...
@@ -140,7 +140,7 @@ class RuntimeProgram {
void
Run
()
{
for
(
auto
&
inst
:
instructions_
)
{
VLOG
(
4
)
<<
">> Running kernel: "
<<
inst
;
VLOG
(
3
)
<<
">> Running kernel: "
<<
inst
.
op
()
->
op_info
()
->
Repr
()
;
inst
.
Run
();
}
}
...
...
paddle/fluid/lite/core/tensor.h
浏览文件 @
4fe5c8aa
...
...
@@ -191,7 +191,6 @@ class TensorBase {
template
<
typename
TensorT
>
bool
TensorCompareWith
(
const
TensorT
&
a
,
const
TensorT
&
b
)
{
if
(
a
.
dims
()
!=
b
.
dims
())
return
false
;
LOG
(
INFO
)
<<
"data_size: "
<<
a
.
data_size
();
if
(
memcmp
(
a
.
raw_data
(),
b
.
raw_data
(),
a
.
data_size
())
!=
0
)
return
false
;
return
true
;
}
...
...
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -117,6 +117,19 @@ TEST(elementwise_add, compute) {
operators
::
ElementwiseParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
#if 1
for
(
auto
n
:
{
1
,
3
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
4
})
{
for
(
auto
h
:
{
1
,
3
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#else
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
h
:
{
1
,
3
,
4
,
11
})
{
...
...
@@ -129,6 +142,7 @@ TEST(elementwise_add, compute) {
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#endif
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
...
...
@@ -192,6 +206,20 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
operators
::
FusionElementwiseActivationParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
#if 1
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
n
:
{
1
,
3
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
4
})
{
for
(
auto
h
:
{
1
,
3
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#else
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
...
...
@@ -206,6 +234,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#endif
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
...
...
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -80,12 +80,19 @@ TEST(softmax_arm, compute) {
lite
::
Tensor
x
;
lite
::
Tensor
output
;
lite
::
Tensor
output_ref
;
#if 1
for
(
auto
n
:
{
1
,
3
})
{
for
(
auto
c
:
{
1
,
4
})
{
for
(
auto
h
:
{
5
,
1
})
{
for
(
auto
w
:
{
1
,
6
})
{
for
(
auto
axis
:
{
-
2
,
-
1
,
0
,
1
,
2
})
{
#else
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
11
,
4
})
{
for
(
auto
h
:
{
3
,
1
,
11
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
,
12
})
{
for
(
auto
axis
:
{
-
4
,
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
3
})
{
#endif
x
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output_ref
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
...
...
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/kernels/x86/mul_compute.h
浏览文件 @
4fe5c8aa
...
...
@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
const
Tensor
x_matrix
=
x
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
)
:
*
x
;
const
Tensor
y_matrix
=
y
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
)
:
*
y
;
Tensor
x_matrix
,
y_matrix
;
if
(
x
->
dims
().
size
()
>
2
)
{
x_matrix
=
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
);
}
else
{
x_matrix
=
*
x
;
}
if
(
y
->
dims
().
size
()
>
2
)
{
y_matrix
=
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
);
}
else
{
y_matrix
=
*
y
;
}
auto
*
z
=
&
param
.
output
->
raw_tensor
();
auto
z_dim
=
z
->
dims
();
...
...
@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
auto
x_matrix
=
x
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
)
:
static_cast
<
const
Tensor
&>
(
*
x
);
auto
y_matrix
=
y
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
)
:
static_cast
<
const
Tensor
&>
(
*
y
);
Tensor
x_matrix
,
y_matrix
;
if
(
x
->
dims
().
size
()
>
2
)
{
x_matrix
=
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
);
}
else
{
x_matrix
=
*
x
;
}
if
(
y
->
dims
().
size
()
>
2
)
{
y_matrix
=
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
);
}
else
{
y_matrix
=
*
y
;
}
auto
*
dout
=
&
param
.
output_grad
->
raw_tensor
();
Tensor
dout_mat
;
...
...
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
浏览文件 @
4fe5c8aa
...
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
...
@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
endif
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc
DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite
)
lite_cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
variable_lite scope_lite
${
tensor_lite
}
scope_lite
...
...
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
浏览文件 @
4fe5c8aa
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
paddle/fluid/lite/model_parser/desc_apis.h
浏览文件 @
4fe5c8aa
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <map>
#include <sstream>
#include <string>
#include <vector>
...
...
@@ -79,6 +80,27 @@ class OpDescAPI {
/// Get an attribute.
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
std
::
string
Repr
()
const
{
std
::
stringstream
ss
;
ss
<<
Type
();
ss
<<
"("
;
for
(
auto
&
arg
:
InputArgumentNames
())
{
ss
<<
arg
<<
":"
;
for
(
auto
val
:
Input
(
arg
))
{
ss
<<
val
<<
" "
;
}
}
ss
<<
") -> ("
;
for
(
auto
&
arg
:
OutputArgumentNames
())
{
ss
<<
arg
<<
":"
;
for
(
auto
val
:
Output
(
arg
))
{
ss
<<
val
<<
" "
;
}
}
ss
<<
")"
;
return
ss
.
str
();
}
};
}
// namespace lite
...
...
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
4fe5c8aa
...
...
@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
std
::
string
DebugString
()
const
{
return
desc_
.
DebugString
();
}
private:
std
::
vector
<
std
::
string
>
GetArguments
(
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
4fe5c8aa
...
...
@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
CHECK
(
!
op_desc
.
Input
(
"X"
).
empty
());
CHECK
(
!
op_desc
.
Input
(
"Y"
).
empty
());
CHECK
(
!
op_desc
.
Output
(
"Out"
).
empty
());
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
*
var
=
scope
->
FindVar
(
input
);
CHECK
(
var
);
param_
.
x
=
var
->
GetMutable
<
Tensor
>
();
param_
.
x
=
&
var
->
Get
<
Tensor
>
();
var
=
scope
->
FindVar
(
W
);
CHECK
(
var
)
<<
"no var called "
<<
W
;
param_
.
y
=
var
->
GetMutable
<
Tensor
>
();
param_
.
y
=
&
var
->
Get
<
Tensor
>
();
var
=
scope
->
FindVar
(
out
);
CHECK
(
var
)
<<
"no var called "
<<
out
;
param_
.
output
=
var
->
GetMutable
<
Tensor
>
();
...
...
paddle/fluid/lite/operators/op_params.h
浏览文件 @
4fe5c8aa
...
...
@@ -67,8 +67,8 @@ struct ReluParam {
// For Mul Op
struct
MulParam
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
output
{};
int
x_num_col_dims
{
1
};
...
...
paddle/fluid/lite/tools/build.sh
浏览文件 @
4fe5c8aa
...
...
@@ -54,22 +54,6 @@ function check_style {
fi
}
function
cmake_arm
{
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake ..
\
-DWITH_GPU
=
OFF
\
-DWITH_MKL
=
OFF
\
-DWITH_LITE
=
ON
\
-DLITE_WITH_CUDA
=
OFF
\
-DLITE_WITH_X86
=
OFF
\
-DLITE_WITH_ARM
=
ON
\
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK
=
ON
\
-DWITH_TESTING
=
ON
\
-DARM_TARGET_OS
=
$1
-DARM_TARGET_ARCH_ABI
=
$2
-DARM_TARGET_LANG
=
$3
}
function
build_single
{
#make $1 -j$(expr $(nproc) - 2)
make
$1
-j
$NUM_CORES_FOR_COMPILE
...
...
@@ -153,33 +137,53 @@ function test_arm_model {
adb
-s
emulator-
${
port
}
shell
chmod
+x
"
${
adb_work_dir
}
/
${
test_name
}
"
local
adb_model_path
=
"./
${
adb_work_dir
}
/
`
basename
${
model_dir
}
`
"
adb
-s
emulator-
${
port
}
shell
"./
${
adb_work_dir
}
/
${
test_name
}
--eval_model_dir=
$adb_model_path
"
}
# Build the code and run lite arm tests. This is executed in the CI system.
function
build_test_arm
{
# 1. Build goes first
function
cmake_arm
{
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake ..
\
-DWITH_GPU
=
OFF
\
-DWITH_MKL
=
OFF
\
-DWITH_LITE
=
ON
\
-DLITE_WITH_CUDA
=
OFF
\
-DLITE_WITH_X86
=
OFF
\
-DLITE_WITH_ARM
=
ON
\
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK
=
ON
\
-DWITH_TESTING
=
ON
\
-DARM_TARGET_OS
=
$1
-DARM_TARGET_ARCH_ABI
=
$2
-DARM_TARGET_LANG
=
$3
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
function
build_arm
{
os
=
$1
abi
=
$2
lang
=
$3
cur_dir
=
$(
pwd
)
for
lang
in
"gcc"
"clang"
;
do
for
os
in
"android"
"armlinux"
;
do
if
[[
${
os
}
==
"armlinux"
&&
${
lang
}
==
"clang"
]]
;
then
continue
if
[[
${
os
}
==
"armlinux"
]]
;
then
# TODO(hongming): enable compile armv7 and armv7hf on armlinux, and clang compile
if
[[
${
lang
}
==
"clang"
]]
;
then
echo
"clang is not enabled on armlinux yet"
return
0
fi
for
abi
in
"armv8"
"armv7"
"armv7hf"
;
do
# TODO(hongming): enable compile armv7 and armv7hf on armlinux
if
[[
${
abi
}
==
"armv7hf"
]]
;
then
echo
"armv7hf is not supported on both android and
armlinux yet"
continue
echo
"armv7hf is not supported on
armlinux yet"
return
0
fi
# TODO(hongming): enable armv7 on armlinux
if
[[
${
os
}
==
"armlinux"
&&
${
abi
}
==
"armv7"
]]
;
then
if
[[
${
abi
}
==
"armv7"
]]
;
then
echo
"armv7 is not supported on armlinux yet"
continue
return
0
fi
fi
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
echo
"android do not need armv7hf"
continue
return
0
fi
build_dir
=
$cur_dir
/build.lite.
${
os
}
.
${
abi
}
.
${
lang
}
...
...
@@ -188,11 +192,47 @@ function build_test_arm {
cmake_arm
${
os
}
${
abi
}
${
lang
}
build
$TESTS_FILE
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
# $4: android test port
# Note: test must be in build dir
function
test_arm
{
os
=
$1
abi
=
$2
lang
=
$3
port
=
$4
if
[[
${
os
}
==
"armlinux"
]]
;
then
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
echo
"Skip test arm linux yet. armlinux must in another docker"
return
0
fi
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
echo
"android do not need armv7hf"
return
0
fi
# TODO(yuanshuai): enable armv7 on android
if
[[
${
abi
}
==
"armv7"
]]
;
then
echo
"skip android v7 test yet"
return
0
fi
echo
"test file:
${
TESTS_FILE
}
"
for
_test
in
$(
cat
$TESTS_FILE
)
;
do
test_arm_android
$_test
$port
done
done
done
# TODO(sangoly): refine this
test_arm_model
"test_cxx_api_lite"
$port
"./third_party/install/mobilenet_v2_relu"
}
# 2. Then test
# Build the code and run lite arm tests. This is executed in the CI system.
function
build_test_arm
{
########################################################################
# job 1-4 must be in one runner
port_armv8
=
5554
port_armv7
=
5556
...
...
@@ -206,39 +246,46 @@ function build_test_arm {
echo
-ne
'\n'
|
${
ANDROID_HOME
}
/emulator/emulator
-avd
paddle-armv7
-noaudio
-no-window
-gpu
off
-verbose
-port
${
port_armv7
}
&
sleep
1m
# now can only test android.
for
lang
in
"gcc"
"clang"
;
do
for
abi
in
"armv8"
"armv7"
;
do
# TODO(yuanshuai): enable armv7 on android
if
[[
${
abi
}
==
"armv7"
]]
;
then
continue
fi
# job 1
build_arm
"android"
"armv8"
"gcc"
test_arm
"android"
"armv8"
"gcc"
${
port_armv8
}
cd
-
build_dir
=
$cur_dir
/build.lite.android.
${
abi
}
.
${
lang
}
cd
$build_dir
# job 2
build_arm
"android"
"armv8"
"clang"
test_arm
"android"
"armv8"
"clang"
${
port_armv8
}
cd
-
local
port
=
if
[[
${
abi
}
==
"armv7"
]]
;
then
port
=
${
port_armv7
}
fi
# job 3
build_arm
"android"
"armv7"
"gcc"
test_arm
"android"
"armv7"
"gcc"
${
port_armv7
}
cd
-
if
[[
${
abi
}
==
"armv8"
]]
;
then
port
=
${
port_armv8
}
fi
echo
"test file:
${
TESTS_FILE
}
"
for
_test
in
$(
cat
$TESTS_FILE
)
;
do
test_arm_android
$_test
$port
done
# TODO(sangoly): refine this
test_arm_model
"test_cxx_api_lite"
$port
"./third_party/install/mobilenet_v2_relu"
done
done
# armlinux need in another docker
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
# job 4
build_arm
"android"
"armv7"
"clang"
test_arm
"android"
"armv7"
"clang"
${
port_armv7
}
cd
-
adb devices |
grep
emulator |
cut
-f1
|
while
read
line
;
do
adb
-s
$line
emu
kill
;
done
echo
"Done"
########################################################################
# job 5
build_arm
"armlinux"
"armv8"
test_arm
"armlinux"
"armv8"
cd
-
# job 6
build_arm
"armlinux"
"armv7"
test_arm
"armlinux"
"armv7"
cd
-
# job 7
build_arm
"armlinux"
"armv7hf"
test_arm
"armlinux"
"armv7hf"
cd
-
echo
"Done"
}
############################# MAIN #################################
...
...
@@ -279,6 +326,10 @@ function main {
ARM_ABI
=
"
${
i
#*=
}
"
shift
;;
--arm_lang
=
*
)
ARM_LANG
=
"
${
i
#*=
}
"
shift
;;
--arm_port
=
*
)
ARM_PORT
=
"
${
i
#*=
}
"
shift
...
...
@@ -301,13 +352,21 @@ function main {
shift
;;
cmake_arm
)
cmake_arm
$ARM_OS
$ARM_ABI
cmake_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
shift
;;
build_arm
)
build_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
shift
;;
test_server
)
test_lite
$TESTS_FILE
shift
;;
test_arm
)
build_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
$ARM_PORT
shift
;;
test_arm_android
)
test_arm_android
$TEST_NAME
$ARM_PORT
shift
...
...
paddle/fluid/lite/utils/varient.h
浏览文件 @
4fe5c8aa
...
...
@@ -20,6 +20,7 @@
#include <typeinfo>
#include <utility>
#include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/string.h"
// This is an equivalent implementation of boost::any. We implement this to
// avoid including the whole boost library and keep the inference library small.
...
...
@@ -116,9 +117,9 @@ struct variant {
if
(
type_id
==
typeid
(
T
).
hash_code
())
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
else
throw
std
::
invalid_argument
(
"unmatched type"
);
// LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
// << typeid(T).name(
);
throw
std
::
invalid_argument
(
string_format
(
"unmatched type, store as %d, but want to get %s"
,
type_id
,
typeid
(
T
).
name
())
);
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录