Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4fe5c8aa
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4fe5c8aa
编写于
6月 20, 2019
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'incubate/lite' of
http://10.87.145.36/inference/paddlelite
into xzl/incubate/lite
fix comments
上级
8ead391d
6f705068
变更
65
显示空白变更内容
内联
并排
Showing
65 changed file
with
859 addition
and
224 deletion
+859
-224
CMakeLists.txt
CMakeLists.txt
+2
-2
paddle/fluid/lite/api/apis_test.cc
paddle/fluid/lite/api/apis_test.cc
+1
-1
paddle/fluid/lite/api/cxx_api_bin.cc
paddle/fluid/lite/api/cxx_api_bin.cc
+1
-1
paddle/fluid/lite/api/lite_api_test_helper.cc
paddle/fluid/lite/api/lite_api_test_helper.cc
+1
-0
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+0
-1
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+24
-20
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
+7
-0
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
...ite/core/mir/elimination/identity_scale_eliminate_pass.cc
+72
-0
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
...ore/mir/elimination/identity_scale_eliminate_pass_test.cc
+93
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+0
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc
...e/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h
...re/mir/fusion/conv_elementwise_add_activation_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc
.../fusion/conv_elementwise_add_activation_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
.../core/mir/fusion/conv_elementwise_add_activation_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
...te/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
+39
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
...ite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
...re/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
+153
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
...e/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h
...te/core/mir/fusion/elementwise_add_activation_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc
...e/mir/fusion/elementwise_add_activation_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
.../lite/core/mir/fusion/elementwise_add_activation_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
+0
-0
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+2
-2
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
+1
-1
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+2
-2
paddle/fluid/lite/core/mir/node.cc
paddle/fluid/lite/core/mir/node.cc
+59
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+32
-34
paddle/fluid/lite/core/mir/pattern_matcher.h
paddle/fluid/lite/core/mir/pattern_matcher.h
+2
-3
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
+1
-0
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
+7
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+2
-2
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
+13
-13
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+5
-0
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+6
-6
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+8
-8
paddle/fluid/lite/core/mir/use_passes.h
paddle/fluid/lite/core/mir/use_passes.h
+3
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+6
-4
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+1
-2
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+16
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+4
-0
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+1
-1
paddle/fluid/lite/core/tensor.h
paddle/fluid/lite/core/tensor.h
+0
-1
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
+29
-0
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
+8
-1
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
+2
-0
paddle/fluid/lite/kernels/x86/mul_compute.h
paddle/fluid/lite/kernels/x86/mul_compute.h
+30
-12
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
+2
-0
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+2
-1
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
+0
-1
paddle/fluid/lite/model_parser/desc_apis.h
paddle/fluid/lite/model_parser/desc_apis.h
+22
-0
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+2
-0
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+6
-2
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+2
-2
paddle/fluid/lite/tools/build.sh
paddle/fluid/lite/tools/build.sh
+140
-81
paddle/fluid/lite/utils/varient.h
paddle/fluid/lite/utils/varient.h
+4
-3
未找到文件。
CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -44,9 +44,9 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
...
@@ -44,9 +44,9 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# check arch abi
# check arch abi
if
(
NOT DEFINED ARM_TARGET_LANG
)
if
(
NOT DEFINED ARM_TARGET_LANG
)
set
(
ARM_TARGET_LANG
"
clang
"
CACHE STRING
"Choose ARM Target Language"
)
set
(
ARM_TARGET_LANG
"
gcc
"
CACHE STRING
"Choose ARM Target Language"
)
endif
()
endif
()
set
(
ARM_TARGET_LANG_LIST
"gcc"
"clang"
)
set
(
ARM_TARGET_LANG_LIST
"gcc"
"clang"
""
)
set_property
(
CACHE ARM_TARGET_LANG PROPERTY STRINGS
${
ARM_TARGET_LANG_LIST
}
)
set_property
(
CACHE ARM_TARGET_LANG PROPERTY STRINGS
${
ARM_TARGET_LANG_LIST
}
)
if
(
NOT ARM_TARGET_LANG IN_LIST ARM_TARGET_LANG_LIST
)
if
(
NOT ARM_TARGET_LANG IN_LIST ARM_TARGET_LANG_LIST
)
message
(
FATAL_ERROR
"ARM_TARGET_LANG must be in one of
${
ARM_TARGET_LANG_LIST
}
"
)
message
(
FATAL_ERROR
"ARM_TARGET_LANG must be in one of
${
ARM_TARGET_LANG_LIST
}
"
)
...
...
paddle/fluid/lite/api/apis_test.cc
浏览文件 @
4fe5c8aa
...
@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
...
@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
ASSERT_TRUE
(
TensorCompareWith
(
*
cxx_out
,
*
light_out
));
ASSERT_TRUE
(
TensorCompareWith
(
*
cxx_out
,
*
light_out
));
std
::
vector
<
std
::
string
>
tensors_with_order
({
std
::
vector
<
std
::
string
>
tensors_with_order
({
"a"
,
"fc_0.w_0"
,
"
fc_0.tmp_0"
,
"
scale_0.tmp_0"
,
"a"
,
"fc_0.w_0"
,
"scale_0.tmp_0"
,
});
});
for
(
const
auto
&
tensor_name
:
tensors_with_order
)
{
for
(
const
auto
&
tensor_name
:
tensors_with_order
)
{
...
...
paddle/fluid/lite/api/cxx_api_bin.cc
浏览文件 @
4fe5c8aa
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include <chrono>
// NOLINT
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/api/lite_api_test_helper.cc
浏览文件 @
4fe5c8aa
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include <vector>
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
...
@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test
(
test_types_lite SRCS types_test.cc DEPS types_lite
)
lite_cc_test
(
test_types_lite SRCS types_test.cc DEPS types_lite
)
lite_cc_test
(
test_memory_lite SRCS memory_test.cc DEPS memory_lite
)
lite_cc_test
(
test_memory_lite SRCS memory_test.cc DEPS memory_lite
)
lite_cc_test
(
test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator
)
lite_cc_test
(
test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator
)
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -5,12 +5,16 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
...
@@ -5,12 +5,16 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
add_subdirectory
(
fusion
)
add_subdirectory
(
fusion
)
add_subdirectory
(
elimination
)
cc_library
(
mir_passes
cc_library
(
mir_passes
SRCS fc_fuse_pass.cc
SRCS
conv_elementwise_add_activation_fuse_pass.cc
fusion/fc_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
fusion/conv_elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
quant_dequant_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
static_kernel_pick_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
type_target_transform_pass.cc
...
@@ -74,7 +78,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
...
@@ -74,7 +78,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
lite_cc_test
(
test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
lite_cc_test
(
test_lite_fc_fuse SRCS f
usion/f
c_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
arm_kernels
}
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
arm_kernels
}
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
...
@@ -85,10 +89,10 @@ add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
...
@@ -85,10 +89,10 @@ add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test
(
test_lite_conv_elementwise_add_activation_fuse
lite_cc_test
(
test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
SRCS
fusion/
conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
lite_cc_test
(
test_lite_elementwise_add_activation_fuse
lite_cc_test
(
test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
SRCS
fusion/
elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
0 → 100644
浏览文件 @
4fe5c8aa
if
(
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
lite_cc_test
(
test_identity_scale_eliminate_pass_lite
SRCS identity_scale_eliminate_pass_test.cc
DEPS mir_passes program_lite proto_desc cpp_op_desc_lite
${
ops_lite
}
)
endif
()
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
{
class
Eliminator
:
public
FuseBase
{
public:
void
BuildPattern
()
override
{
auto
*
pre_op
=
OpNode
(
"preop"
);
// the previous op's output need update
// TODO(Superjomn) check has only one output
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"scale"
,
"X"
);
auto
*
scale_op
=
OpNode
(
"scale"
,
"scale"
)
->
assert_op_attr
<
float
>
(
"scale"
,
1.
)
->
assert_op_attr
<
float
>
(
"bias"
,
0.
);
auto
*
out
=
VarNode
(
"out"
)
->
assert_is_op_output
(
"scale"
,
"Out"
);
*
pre_op
>>
*
x
>>
*
scale_op
>>
*
out
;
// The pre_op will be eliminated, and a new output-updated op will insert.
x
->
AsIntermediate
();
// x is pre_op's output, need to update
}
private:
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
auto
&
pre_op
=
matched
.
at
(
"preop"
)
->
AsStmt
();
auto
op_info
=
*
pre_op
.
op_info
();
op_info
.
UpdateAllOutputs
(
matched
.
at
(
"x"
)
->
AsArg
().
name
,
matched
.
at
(
"out"
)
->
AsArg
().
name
);
pre_op
.
ResetOp
(
op_info
,
graph
->
valid_places
());
GraphSafeRemoveNodes
(
graph
,
{
matched
.
at
(
"scale"
)});
IR_NODE_LINK_TO
(
matched
.
at
(
"preop"
),
matched
.
at
(
"out"
));
}
};
}
// namespace
class
IdentityScaleEliminatePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
{
Eliminator
eliminator
;
eliminator
(
graph
.
get
());
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
identity_scale_eliminate_pass
,
paddle
::
lite
::
mir
::
IdentityScaleEliminatePass
);
paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
// Op list:
// (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch)
// After pass
// (x)->feed->(scale_out)->fetch->(fetch)
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
feed_op
=
main_block
->
AppendOp
();
auto
*
scale_op
=
main_block
->
AppendOp
();
auto
*
fetch_op
=
main_block
->
AppendOp
();
main_block
->
Var
(
"x"
);
main_block
->
Var
(
"feed"
);
main_block
->
Var
(
"scale_out"
);
main_block
->
Var
(
"fetch_out"
);
scope
->
Var
(
"x"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"feed"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"scale_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"fetch_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
feed_op
->
SetType
(
"feed"
);
feed_op
->
SetInput
(
"X"
,
{
"x"
});
feed_op
->
SetAttr
(
"col"
,
1
);
feed_op
->
SetOutput
(
"Out"
,
{
"feed"
});
scale_op
->
SetType
(
"scale"
);
scale_op
->
SetInput
(
"X"
,
{
"feed"
});
scale_op
->
SetOutput
(
"Out"
,
{
"scale_out"
});
scale_op
->
SetAttr
(
"scale"
,
1.
f
);
scale_op
->
SetAttr
(
"bias"
,
0.
f
);
scale_op
->
SetAttr
(
"bias_after_scale"
,
true
);
fetch_op
->
SetType
(
"fetch"
);
fetch_op
->
SetInput
(
"X"
,
{
"scale_out"
});
fetch_op
->
SetOutput
(
"Out"
,
{
"fetch"
});
fetch_op
->
SetAttr
(
"col"
,
1
);
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
LOG
(
INFO
)
<<
Visualize
(
graph
.
get
());
return
graph
;
}
TEST
(
identity_test
,
test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
pass
=
PassManager
::
Global
().
LookUp
(
"identity_scale_eliminate_pass"
);
ASSERT_TRUE
(
pass
);
pass
->
Apply
(
graph
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
2UL
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
feed
)
USE_LITE_OP
(
fetch
)
USE_LITE_OP
(
scale
)
USE_MIR_PASS
(
identity_scale_eliminate_pass
)
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -10,7 +10,6 @@ cc_library(fuse_conv_bn
...
@@ -10,7 +10,6 @@ cc_library(fuse_conv_bn
cc_library
(
fuse_elementwise_add_activation
cc_library
(
fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api
)
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_quant_dequant
cc_library
(
fuse_quant_dequant
SRCS quant_dequant_op_fuser.cc
SRCS quant_dequant_op_fuser.cc
DEPS pattern_matcher_high_api
)
DEPS pattern_matcher_high_api
)
...
...
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
...
...
paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
...
@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
void
ConvBNFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
void
ConvBNFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
eltwise_op
=
LiteOpRegistry
::
Global
().
Create
(
"elementwise_add"
);
auto
eltwise_op
=
LiteOpRegistry
::
Global
().
Create
(
"elementwise_add"
);
auto
conv
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
conv
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
conv
->
scope
();
auto
*
scope
=
conv
->
scope
();
auto
&
valid_places
=
conv
->
valid_places
();
auto
&
valid_places
=
conv
->
valid_places
();
...
...
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
...
...
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
conv_elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
...
@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
conv_old
->
scope
();
auto
*
scope
=
conv_old
->
scope
();
auto
&
valid_places
=
conv_old
->
valid_places
();
auto
&
valid_places
=
conv_old
->
valid_places
();
conv_op
->
Attach
(
op_desc
,
scope
);
conv_op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAddReLUFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAddReLUFuser
fuser
(
"conv2d"
);
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
depthwise_fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAddReLUFusePass
);
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddReLUFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
0 → 100644
浏览文件 @
4fe5c8aa
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
conv2d_1
=
main_block
->
AppendOp
();
auto
*
conv2d_2
=
main_block
->
AppendOp
();
auto
*
add_1
=
main_block
->
AppendOp
();
auto
*
relu_1
=
main_block
->
AppendOp
();
auto
*
add_2
=
main_block
->
AppendOp
();
auto
*
relu_2
=
main_block
->
AppendOp
();
main_block
->
Var
(
"input_1"
);
main_block
->
Var
(
"input_2"
);
main_block
->
Var
(
"filter_1"
);
main_block
->
Var
(
"filter_2"
);
main_block
->
Var
(
"conv2d_1_out"
);
main_block
->
Var
(
"conv2d_2_out"
);
main_block
->
Var
(
"bias_1"
);
main_block
->
Var
(
"add_1_out"
);
main_block
->
Var
(
"add_2_out"
);
main_block
->
Var
(
"relu_1_out"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"input_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"input_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"bias_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"relu_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
conv2d_1
->
SetType
(
"conv2d"
);
conv2d_1
->
SetInput
(
"Input"
,
{
"input_1"
});
conv2d_1
->
SetInput
(
"Filter"
,
{
"filter_1"
});
conv2d_1
->
SetOutput
(
"Output"
,
{
"conv2d_1_out"
});
conv2d_1
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_1
->
SetAttr
(
"groups"
,
1
);
conv2d_1
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"fuse_relu"
,
false
);
add_1
->
SetType
(
"elementwise_add"
);
add_1
->
SetInput
(
"X"
,
{
"conv2d_1_out"
});
add_1
->
SetInput
(
"Y"
,
{
"bias_1"
});
add_1
->
SetOutput
(
"Out"
,
{
"add_1_out"
});
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"X"
,
{
"add_1_out"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_1_out"
});
conv2d_2
->
SetType
(
"conv2d"
);
conv2d_2
->
SetInput
(
"Input"
,
{
"input_2"
});
conv2d_2
->
SetInput
(
"Filter"
,
{
"filter_2"
});
conv2d_2
->
SetOutput
(
"Output"
,
{
"conv2d_2_out"
});
conv2d_2
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_2
->
SetAttr
(
"groups"
,
1
);
conv2d_2
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"fuse_relu"
,
false
);
add_2
->
SetType
(
"elementwise_add"
);
add_2
->
SetInput
(
"X"
,
{
"conv2d_2_out"
});
add_2
->
SetInput
(
"Y"
,
{
"relu_1_out"
});
add_2
->
SetOutput
(
"Out"
,
{
"add_2_out"
});
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"X"
,
{
"add_2_out"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
return
graph
;
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
11UL
/*vars*/
+
6UL
/*ops*/
);
Visualize
(
graph
.
get
());
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
fuse_test_op
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ConvElementwiseAddReLUFusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused fc node*/
);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
conv2d
);
USE_LITE_OP
(
depthwise_conv2d
);
USE_LITE_OP
(
relu
);
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
...
...
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
...
@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op
=
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
"fusion_elementwise_add_activation"
);
LiteOpRegistry
::
Global
().
Create
(
"fusion_elementwise_add_activation"
);
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
;
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
old_op
->
scope
();
auto
*
scope
=
old_op
->
scope
();
auto
&
valid_places
=
old_op
->
valid_places
();
auto
&
valid_places
=
old_op
->
valid_places
();
op
->
Attach
(
op_desc
,
scope
);
op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
...
...
paddle/fluid/lite/core/mir/fc_fuse_pass.h
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass_test.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/f
usion/f
c_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
...
@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
void
FcFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
void
FcFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
;
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
mul
->
scope
();
auto
*
scope
=
mul
->
scope
();
auto
&
valid_places
=
mul
->
valid_places
();
auto
&
valid_places
=
mul
->
valid_places
();
fc_op
->
Attach
(
op_desc
,
scope
);
fc_op
->
Attach
(
op_desc
,
scope
);
...
...
paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc
→
paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
...
...
paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h
→
paddle/fluid/lite/core/mir/
fusion/
quant_dequant_fuse_pass.h
浏览文件 @
4fe5c8aa
文件已移动
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
浏览文件 @
4fe5c8aa
...
@@ -115,8 +115,8 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -115,8 +115,8 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
nodes
.
push_back
(
matched
.
at
(
"dequant_op_out"
+
std
::
to_string
(
i
)));
nodes
.
push_back
(
matched
.
at
(
"dequant_op_out"
+
std
::
to_string
(
i
)));
}
}
int
bit_length
=
quant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
bit_length
=
quant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
auto
*
scope
=
quant_op
->
stmt
()
->
op
->
scope
();
auto
*
scope
=
quant_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quant_op
->
stmt
()
->
op
->
valid_places
();
auto
&
valid_places
=
quant_op
->
stmt
()
->
op
()
->
valid_places
();
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
input_scale_t
=
scope
->
FindVar
(
quant_op_in_scale
->
arg
()
->
name
)
auto
input_scale_t
=
scope
->
FindVar
(
quant_op_in_scale
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
->
GetMutable
<
lite
::
Tensor
>
();
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if
(
item
->
IsStmt
())
{
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
auto
&
stmt
=
item
->
AsStmt
();
VLOG
(
4
)
<<
stmt
;
VLOG
(
4
)
<<
stmt
;
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
insts_
.
emplace_back
(
stmt
.
op
(),
std
::
move
(
stmt
.
kernels
()
.
front
()));
}
}
}
}
}
}
...
...
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
...
@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
if
(
node
.
IsArg
())
{
if
(
node
.
IsArg
())
{
key
=
node
.
AsArg
().
name
;
key
=
node
.
AsArg
().
name
;
}
else
{
}
else
{
key
=
node
.
AsStmt
().
op_type
+
std
::
to_string
(
id
++
);
key
=
node
.
AsStmt
().
op_type
()
+
std
::
to_string
(
id
++
);
}
}
if
(
node
.
IsStmt
())
{
if
(
node
.
IsStmt
())
{
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
...
@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsStmt
())
continue
;
if
(
!
node
.
IsStmt
())
continue
;
auto
&
inst
=
node
.
AsStmt
();
auto
&
inst
=
node
.
AsStmt
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
if
(
inst
.
op_type
()
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
AsStmt
().
valid_kernels
;
auto
&
kernels
=
node
.
AsStmt
().
kernels
()
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
...
...
paddle/fluid/lite/core/mir/node.cc
浏览文件 @
4fe5c8aa
...
@@ -13,3 +13,62 @@
...
@@ -13,3 +13,62 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
const
OpInfo
*
mir
::
Node
::
Stmt
::
op_info
()
const
{
CHECK
(
op_
);
return
op_
->
op_info
();
}
Place
mir
::
Node
::
Stmt
::
place
()
const
{
CHECK
(
!
valid_kernels_
.
empty
());
return
valid_kernels_
.
front
()
->
place
();
}
KernelBase
&
mir
::
Node
::
Stmt
::
picked_kernel
()
{
CHECK
(
!
valid_kernels_
.
empty
())
<<
"no kernel for "
<<
op_type
();
return
*
valid_kernels_
.
front
();
}
OpInfo
*
mir
::
Node
::
Stmt
::
mutable_op_info
()
{
CHECK
(
op_
);
return
op_
->
mutable_op_info
();
}
void
mir
::
Node
::
Stmt
::
ResetOp
(
const
cpp
::
OpDesc
&
op_desc
,
const
std
::
vector
<
Place
>
&
valid_places
,
lite
::
Scope
*
scope
)
{
CHECK
((
op_
&&
op_
->
scope
())
||
scope
)
<<
"Either scope should be set"
;
lite
::
Scope
*
the_scope
=
scope
?
scope
:
op_
->
scope
();
op_
->
Attach
(
op_desc
,
the_scope
);
// Recreate the kernels with the latest OpInfo.
valid_kernels_
.
clear
();
if
(
!
op_
||
op_
->
op_info
()
->
Type
()
!=
op_desc
.
Type
())
{
op_
=
LiteOpRegistry
::
Global
().
Create
(
op_desc
.
Type
());
CHECK
(
op_
)
<<
"No op found for "
<<
op_desc
.
Type
();
}
valid_kernels_
=
op_
->
CreateKernels
(
valid_places
);
}
std
::
ostream
&
mir
::
operator
<<
(
std
::
ostream
&
os
,
const
mir
::
Node
::
Stmt
&
other
)
{
os
<<
"Statement "
<<
other
.
op_type
()
<<
" "
<<
other
.
place
();
return
os
;
}
mir
::
Node
::
Arg
&
mir
::
Node
::
AsArg
(
const
std
::
string
&
name
,
int
id
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
id
=
id
;
return
x
;
}
mir
::
Node
::
Arg
&
mir
::
Node
::
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/node.h
浏览文件 @
4fe5c8aa
...
@@ -41,32 +41,40 @@ class Node {
...
@@ -41,32 +41,40 @@ class Node {
kUnk
,
kUnk
,
};
};
struct
Stmt
{
class
Stmt
{
std
::
string
op_type
;
// The kernel instances this Statement contains.
// The kernel instances this Statement contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
_
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
std
::
shared_ptr
<
OpLite
>
op
_
;
// we hold op to run InferShape
const
OpInfo
*
op_info
()
{
public:
CHECK
(
op
);
// Refresh the operator and kernels with the latest OpInfo.
return
op
->
op_info
();
void
ResetOp
(
const
cpp
::
OpDesc
&
op_desc
,
}
const
std
::
vector
<
Place
>&
valid_places
,
lite
::
Scope
*
scope
=
nullptr
);
Place
place
()
const
{
std
::
string
op_type
()
const
{
return
op_info
()
->
Type
();
}
CHECK
(
!
valid_kernels
.
empty
());
const
OpInfo
*
op_info
()
const
;
return
valid_kernels
.
front
()
->
place
();
OpInfo
*
mutable_op_info
();
}
KernelBase
&
picked_kernel
()
{
void
SetKernels
(
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
)
{
CHECK
(
!
valid_kernels
.
empty
())
<<
"no kernel for "
<<
op_type
;
valid_kernels_
=
std
::
move
(
kernels
);
return
*
valid_kernels
.
front
();
}
}
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&
kernels
()
{
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stmt
&
other
)
{
return
valid_kernels_
;
os
<<
"Statement "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
}
}
void
SetOp
(
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
op_
=
op
;
}
const
std
::
shared_ptr
<
OpLite
>
op
()
const
{
return
op_
;
}
Place
place
()
const
;
KernelBase
&
picked_kernel
();
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stmt
&
other
);
// Description.
std
::
string
desc
;
};
};
struct
Arg
{
struct
Arg
{
...
@@ -78,26 +86,16 @@ class Node {
...
@@ -78,26 +86,16 @@ class Node {
bool
is_weight
{
false
};
bool
is_weight
{
false
};
};
};
Arg
&
AsArg
(
const
std
::
string
&
name
,
int
id
)
{
Arg
&
AsArg
(
const
std
::
string
&
name
,
int
id
);
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
id
=
id
;
return
x
;
}
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
Arg
&
AsArg
(
const
std
::
string
&
name
);
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
Stmt
&
AsStmt
(
const
std
::
string
&
op_type
,
Stmt
&
AsStmt
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
AsStmt
();
auto
&
x
=
AsStmt
();
x
.
op_type
=
op_type
;
x
.
SetOp
(
op
);
x
.
op
=
op
;
x
.
SetKernels
(
std
::
move
(
kernels
));
x
.
valid_kernels
=
std
::
move
(
kernels
);
return
x
;
return
x
;
}
}
...
@@ -142,7 +140,7 @@ class Node {
...
@@ -142,7 +140,7 @@ class Node {
}
}
if
(
other
.
IsStmt
())
{
if
(
other
.
IsStmt
())
{
auto
&
arg
=
other
.
AsStmt
();
auto
&
arg
=
other
.
AsStmt
();
os
<<
"Statement "
<<
arg
.
op_type
;
os
<<
"Statement "
<<
arg
.
op_type
()
;
}
}
return
os
;
return
os
;
}
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher.h
浏览文件 @
4fe5c8aa
...
@@ -139,14 +139,13 @@ struct PMNode {
...
@@ -139,14 +139,13 @@ struct PMNode {
template
<
typename
T
>
template
<
typename
T
>
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
asserts_
.
push_back
([
=
](
const
Node
*
x
)
{
if
(
x
&&
x
->
IsStmt
())
{
if
(
x
&&
x
->
IsStmt
())
{
auto
*
op_info
=
x
->
stmt
()
->
op_info
();
auto
*
op_info
=
x
->
stmt
()
->
op_info
();
return
op_info
->
HasAttr
(
attr_name
)
&&
return
op_info
->
HasAttr
(
attr_name
)
&&
op_info
->
GetAttr
<
T
>
(
attr_name
)
==
attr
;
op_info
->
GetAttr
<
T
>
(
attr_name
)
==
attr
;
}
else
{
return
false
;
}
}
return
false
;
});
});
return
this
;
return
this
;
}
}
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc
浏览文件 @
4fe5c8aa
...
@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
...
@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
}
}
LOG
(
INFO
)
<<
"keys: "
<<
key2nodes_
.
size
();
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
for
(
auto
&
matched
:
key2nodes_
)
{
for
(
auto
&
matched
:
key2nodes_
)
{
for
(
const
auto
&
key
:
keys
)
{
for
(
const
auto
&
key
:
keys
)
{
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
浏览文件 @
4fe5c8aa
...
@@ -49,7 +49,13 @@ class FuseBase {
...
@@ -49,7 +49,13 @@ class FuseBase {
virtual
void
BuildPattern
()
=
0
;
virtual
void
BuildPattern
()
=
0
;
// Generate an operator desc with a matched subgraph.
// Generate an operator desc with a matched subgraph.
virtual
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
=
0
;
virtual
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
return
cpp
::
OpDesc
();
}
PMNode
*
OpNode
(
const
std
::
string
&
key
)
{
return
GetOrCreateNode
(
key
)
->
assert_is_op
();
}
PMNode
*
OpNode
(
const
std
::
string
&
key
,
const
std
::
string
&
op_type
);
PMNode
*
OpNode
(
const
std
::
string
&
key
,
const
std
::
string
&
op_type
);
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
浏览文件 @
4fe5c8aa
...
@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
...
@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op_desc
=
GenOpDesc
(
matched
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
;
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
()
;
auto
*
scope
=
mul
->
scope
();
auto
*
scope
=
mul
->
scope
();
auto
&
valid_places
=
mul
->
valid_places
();
auto
&
valid_places
=
mul
->
valid_places
();
fc_op
->
Attach
(
op_desc
,
scope
);
fc_op
->
Attach
(
op_desc
,
scope
);
...
@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
...
@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"out"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"
w
"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"
x
"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
...
...
paddle/fluid/lite/core/mir/pattern_matcher_test.cc
浏览文件 @
4fe5c8aa
...
@@ -23,19 +23,19 @@ namespace mir {
...
@@ -23,19 +23,19 @@ namespace mir {
void
BuildGraph
(
SSAGraph
*
g
)
{
void
BuildGraph
(
SSAGraph
*
g
)
{
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o1
=
g
->
mutable_nodes
().
back
();
Node
&
o1
=
g
->
mutable_nodes
().
back
();
o1
.
AsStmt
().
op_type
=
"op1"
;
o1
.
AsStmt
().
desc
=
"op1"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o2
=
g
->
mutable_nodes
().
back
();
Node
&
o2
=
g
->
mutable_nodes
().
back
();
o2
.
AsStmt
().
op_type
=
"op2"
;
o2
.
AsStmt
().
desc
=
"op2"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o3
=
g
->
mutable_nodes
().
back
();
Node
&
o3
=
g
->
mutable_nodes
().
back
();
o3
.
AsStmt
().
op_type
=
"op3"
;
o3
.
AsStmt
().
desc
=
"op3"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o4
=
g
->
mutable_nodes
().
back
();
Node
&
o4
=
g
->
mutable_nodes
().
back
();
o4
.
AsStmt
().
op_type
=
"op4"
;
o4
.
AsStmt
().
desc
=
"op4"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
o5
=
g
->
mutable_nodes
().
back
();
Node
&
o5
=
g
->
mutable_nodes
().
back
();
o5
.
AsStmt
().
op_type
=
"op5"
;
o5
.
AsStmt
().
desc
=
"op5"
;
g
->
mutable_nodes
().
emplace_back
();
g
->
mutable_nodes
().
emplace_back
();
Node
&
v1
=
g
->
mutable_nodes
().
back
();
Node
&
v1
=
g
->
mutable_nodes
().
back
();
v1
.
AsArg
(
"var1"
);
v1
.
AsArg
(
"var1"
);
...
@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
...
@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
// v2 -> o3(a node named o3)
// v2 -> o3(a node named o3)
auto
*
o2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
auto
*
o2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op2"
;
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
desc
==
"op2"
;
});
});
auto
*
o3
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
auto
*
o3
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
// The teller can be any condition, such as op type, or variable's shape.
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op3"
;
return
node
&&
node
->
IsStmt
()
&&
node
->
stmt
()
->
desc
==
"op3"
;
});
});
auto
*
v2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
auto
*
v2
=
x
.
pattern_
.
NewNode
([](
const
Node
*
node
)
{
// The teller can be any condition, such as op type, or variable's shape.
// The teller can be any condition, such as op type, or variable's shape.
...
@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
...
@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
// op -> var
// op -> var
auto
*
any_op
=
x
.
mutable_pattern
()
->
NewNode
(
auto
*
any_op
=
x
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
node
)
{
[](
const
Node
*
node
)
{
return
node
->
IsStmt
()
&&
(
node
->
stmt
()
->
op_type
==
"op2"
||
return
node
->
IsStmt
()
&&
node
->
stmt
()
->
op_type
==
"op3"
);
(
node
->
stmt
()
->
desc
==
"op2"
||
node
->
stmt
()
->
desc
==
"op3"
);
},
},
"OP0"
);
"OP0"
);
auto
*
any_var
=
auto
*
any_var
=
...
@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
...
@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
int
count
=
0
;
int
count
=
0
;
PatternMatcher
::
handle_t
handle
=
[
&
](
const
PatternMatcher
::
subgraph_t
&
s
,
PatternMatcher
::
handle_t
handle
=
[
&
](
const
PatternMatcher
::
subgraph_t
&
s
,
SSAGraph
*
g
)
{
SSAGraph
*
g
)
{
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
stmt
()
->
op_type
<<
" -> "
LOG
(
INFO
)
<<
"Detect "
<<
s
.
at
(
any_op
)
->
stmt
()
->
desc
<<
" -> "
<<
s
.
at
(
any_var
)
->
arg
()
->
name
<<
" -> "
<<
s
.
at
(
any_var
)
->
arg
()
->
name
<<
" -> "
<<
s
.
at
(
any_op1
)
->
stmt
()
->
op_type
;
<<
s
.
at
(
any_op1
)
->
stmt
()
->
desc
;
count
++
;
count
++
;
};
};
...
@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
...
@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
PatternMatcher
matcher
;
PatternMatcher
matcher
;
auto
*
op2
=
matcher
.
mutable_pattern
()
->
NewNode
(
auto
*
op2
=
matcher
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
x
)
{
[](
const
Node
*
x
)
{
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
op_type
==
"op2"
;
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
desc
==
"op2"
;
},
},
"op2"
);
"op2"
);
auto
*
op3
=
matcher
.
mutable_pattern
()
->
NewNode
(
auto
*
op3
=
matcher
.
mutable_pattern
()
->
NewNode
(
[](
const
Node
*
x
)
{
[](
const
Node
*
x
)
{
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
op_type
==
"op3"
;
return
x
&&
x
->
IsStmt
()
&&
x
->
stmt
()
->
desc
==
"op3"
;
},
},
"op3"
);
"op3"
);
auto
*
v2
=
matcher
.
mutable_pattern
()
auto
*
v2
=
matcher
.
mutable_pattern
()
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
4fe5c8aa
...
@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
...
@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
Node
*
GraphCreateInstructNode
(
const
std
::
shared_ptr
<
OpLite
>
&
op
,
Node
*
GraphCreateInstructNode
(
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
);
const
std
::
vector
<
Place
>
&
valid_places
);
// Device related attributes
const
std
::
vector
<
Place
>
&
valid_places
()
const
{
return
valid_places_
;
}
void
SetValidPlaces
(
const
std
::
vector
<
Place
>
&
x
)
{
valid_places_
=
x
;
}
private:
private:
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
// Check the bidirectional connection.
// Check the bidirectional connection.
...
@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
...
@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
private:
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
vector
<
Place
>
valid_places_
;
};
};
// Remove the link between a -> b.
// Remove the link between a -> b.
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if
(
!
node
.
IsStmt
())
continue
;
if
(
!
node
.
IsStmt
())
continue
;
auto
&
instruct
=
node
.
AsStmt
();
auto
&
instruct
=
node
.
AsStmt
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
CHECK
(
!
instruct
.
valid_kernels
.
empty
())
<<
"No kernels found for "
CHECK
(
!
instruct
.
kernels
()
.
empty
())
<<
"No kernels found for "
<<
instruct
.
op_type
;
<<
instruct
.
op_type
()
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
for
(
auto
&&
kernel
:
instruct
.
kernels
()
)
{
size_t
score
=
KernelGrade
(
*
kernel
);
size_t
score
=
KernelGrade
(
*
kernel
);
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
}
}
...
@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Move kernel back
// Move kernel back
// Just keep a single best kernel.
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
kernels
()
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
kernels
()
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
VLOG
(
2
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
VLOG
(
2
)
<<
"pick "
<<
instruct
.
kernels
()
.
front
()
->
name
();
}
}
}
}
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
4fe5c8aa
...
@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
...
@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
CHECK
(
in
->
AsArg
().
type
);
CHECK
(
in
->
AsArg
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
" for kernel "
<<
inst
.
op
()
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
,
graph
,
inst_node
,
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
,
graph
,
inst_node
,
...
@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
// CHECK(io_copy_op);
// CHECK(io_copy_op);
// Create the new var manually.
// Create the new var manually.
inst_node
->
AsStmt
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
AsStmt
().
op
()
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
// Create IoCopy Instruction.
cpp
::
OpDesc
op_desc
;
cpp
::
OpDesc
op_desc
;
...
@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc
.
SetInput
(
"Input"
,
{
in
->
AsArg
().
name
});
op_desc
.
SetInput
(
"Input"
,
{
in
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
->
scope
());
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsStmt
().
op
()
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
AsStmt
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
io_copy_inst
->
AsStmt
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
...
@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
// reset opdesc and update kernel information
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
in
->
AsArg
().
name
,
UpdateInputTo
(
inst_node
->
AsStmt
().
op
()
->
mutable_op_info
(),
in
->
AsArg
().
name
,
io_copy_output_name
);
io_copy_output_name
);
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
inst_node
->
AsStmt
().
ResetOp
(
*
inst_node
->
AsStmt
().
op_info
(),
inst_node
->
AsStmt
().
op
->
scope
());
graph
->
valid_places
());
std
::
string
tmp
;
std
::
string
tmp
;
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
}
for
(
auto
&
kernel
:
inst_node
->
AsStmt
().
valid_kernels
)
{
for
(
auto
&
kernel
:
inst_node
->
AsStmt
().
kernels
()
)
{
inst_node
->
AsStmt
().
op
->
AttachKernel
(
kernel
.
get
());
inst_node
->
AsStmt
().
op
()
->
AttachKernel
(
kernel
.
get
());
}
}
graph
->
CheckValid
();
graph
->
CheckValid
();
...
...
paddle/fluid/lite/core/mir/use_passes.h
浏览文件 @
4fe5c8aa
...
@@ -23,9 +23,11 @@ USE_MIR_PASS(generate_program_pass);
...
@@ -23,9 +23,11 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
identity_scale_eliminate_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_quant_dequant_fuse_pass
);
USE_MIR_PASS
(
lite_quant_dequant_fuse_pass
);
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
4fe5c8aa
...
@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
const
auto
&
v
:
graph
->
inputs
())
{
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
// the feed op might in the inputs
if
(
v
->
IsStmt
())
{
if
(
v
->
IsStmt
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsStmt
().
op_type
;
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsStmt
().
op_type
()
;
continue
;
continue
;
}
}
}
}
...
@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
auto
&
x
:
graph
->
StmtTopologicalOrder
())
{
for
(
auto
&
x
:
graph
->
StmtTopologicalOrder
())
{
auto
&
inst
=
x
->
AsStmt
();
auto
&
inst
=
x
->
AsStmt
();
// The IoCopyOp is a tool operator, it won't support the type inference.
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
if
(
inst
.
op_type
()
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
// deal with inputs
VLOG
(
4
)
<<
"
inferencing op "
<<
inst
.
op_type
;
VLOG
(
4
)
<<
"
Infering op "
<<
inst
.
op_info
()
->
Repr
()
;
// TODO(zhaolong): Add check if the node's name in op's arguments.
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto
get_argname
=
[
&
](
auto
get_argname
=
[
&
](
...
@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
}
}
}
}
VLOG
(
3
)
<<
"inst "
<<
inst
.
op_info
()
->
Repr
();
for
(
auto
*
x_out
:
x
->
outlinks
)
{
for
(
auto
*
x_out
:
x
->
outlinks
)
{
std
::
string
node_name
=
x_out
->
AsArg
().
name
;
std
::
string
node_name
=
x_out
->
AsArg
().
name
;
std
::
string
arg_name
=
std
::
string
arg_name
=
get_argname
(
node_name
,
inst
.
op_info
()
->
outputs
());
get_argname
(
node_name
,
inst
.
op_info
()
->
outputs
());
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not found op arguments for node "
<<
node_name
;
<<
node_name
<<
" in Inst "
<<
inst
.
op_type
();
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
if
(
!
x_out
->
AsArg
().
type
)
{
if
(
!
x_out
->
AsArg
().
type
)
{
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
4fe5c8aa
...
@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
...
@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets
.
insert
(
place
.
target
);
targets
.
insert
(
place
.
target
);
}
}
// CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG
(
2
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
VLOG
(
2
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
return
kernels
;
return
kernels
;
}
}
...
@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
...
@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope_
=
scope
;
scope_
=
scope
;
op_info_
.
reset
(
op_info_
.
reset
(
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
return
AttachImpl
(
opdesc
,
scope
);
return
AttachImpl
(
*
op_info
()
,
scope
);
}
}
const
Tensor
*
OpLite
::
GetTensor
(
lite
::
Scope
*
scope
,
const
Tensor
*
OpLite
::
GetTensor
(
lite
::
Scope
*
scope
,
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
4fe5c8aa
...
@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
...
@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
}
}
return
false
;
return
false
;
}
}
void
UpdateAllInputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
inputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
if
(
var
==
from
)
var
=
to
;
}
}
}
void
UpdateAllOutputs
(
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
outputs_
)
{
for
(
auto
&
var
:
item
.
second
)
{
if
(
var
==
from
)
var
=
to
;
}
}
}
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
4fe5c8aa
...
@@ -43,6 +43,8 @@ class Optimizer {
...
@@ -43,6 +43,8 @@ class Optimizer {
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
SetValidPlaces
(
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
InitTargetTypeTransformPass
();
InitTargetTypeTransformPass
();
...
@@ -51,6 +53,8 @@ class Optimizer {
...
@@ -51,6 +53,8 @@ class Optimizer {
"lite_quant_dequant_fuse_pass"
,
//
"lite_quant_dequant_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
"identity_scale_eliminate_pass"
,
//
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass"
,
//
"lite_elementwise_add_activation_fuse_pass"
,
//
#endif
#endif
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
4fe5c8aa
...
@@ -140,7 +140,7 @@ class RuntimeProgram {
...
@@ -140,7 +140,7 @@ class RuntimeProgram {
void
Run
()
{
void
Run
()
{
for
(
auto
&
inst
:
instructions_
)
{
for
(
auto
&
inst
:
instructions_
)
{
VLOG
(
4
)
<<
">> Running kernel: "
<<
inst
;
VLOG
(
3
)
<<
">> Running kernel: "
<<
inst
.
op
()
->
op_info
()
->
Repr
()
;
inst
.
Run
();
inst
.
Run
();
}
}
}
}
...
...
paddle/fluid/lite/core/tensor.h
浏览文件 @
4fe5c8aa
...
@@ -191,7 +191,6 @@ class TensorBase {
...
@@ -191,7 +191,6 @@ class TensorBase {
template
<
typename
TensorT
>
template
<
typename
TensorT
>
bool
TensorCompareWith
(
const
TensorT
&
a
,
const
TensorT
&
b
)
{
bool
TensorCompareWith
(
const
TensorT
&
a
,
const
TensorT
&
b
)
{
if
(
a
.
dims
()
!=
b
.
dims
())
return
false
;
if
(
a
.
dims
()
!=
b
.
dims
())
return
false
;
LOG
(
INFO
)
<<
"data_size: "
<<
a
.
data_size
();
if
(
memcmp
(
a
.
raw_data
(),
b
.
raw_data
(),
a
.
data_size
())
!=
0
)
return
false
;
if
(
memcmp
(
a
.
raw_data
(),
b
.
raw_data
(),
a
.
data_size
())
!=
0
)
return
false
;
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -117,6 +117,19 @@ TEST(elementwise_add, compute) {
...
@@ -117,6 +117,19 @@ TEST(elementwise_add, compute) {
operators
::
ElementwiseParam
param
;
operators
::
ElementwiseParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
#if 1
for
(
auto
n
:
{
1
,
3
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
4
})
{
for
(
auto
h
:
{
1
,
3
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#else
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
h
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
h
:
{
1
,
3
,
4
,
11
})
{
...
@@ -129,6 +142,7 @@ TEST(elementwise_add, compute) {
...
@@ -129,6 +142,7 @@ TEST(elementwise_add, compute) {
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#endif
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
...
@@ -192,6 +206,20 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
...
@@ -192,6 +206,20 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
operators
::
FusionElementwiseActivationParam
param
;
operators
::
FusionElementwiseActivationParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
#if 1
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
n
:
{
1
,
3
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
4
})
{
for
(
auto
h
:
{
1
,
3
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#else
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
...
@@ -206,6 +234,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
...
@@ -206,6 +234,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
#endif
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
...
...
paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -80,12 +80,19 @@ TEST(softmax_arm, compute) {
...
@@ -80,12 +80,19 @@ TEST(softmax_arm, compute) {
lite
::
Tensor
x
;
lite
::
Tensor
x
;
lite
::
Tensor
output
;
lite
::
Tensor
output
;
lite
::
Tensor
output_ref
;
lite
::
Tensor
output_ref
;
#if 1
for
(
auto
n
:
{
1
,
3
})
{
for
(
auto
c
:
{
1
,
4
})
{
for
(
auto
h
:
{
5
,
1
})
{
for
(
auto
w
:
{
1
,
6
})
{
for
(
auto
axis
:
{
-
2
,
-
1
,
0
,
1
,
2
})
{
#else
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
11
,
4
})
{
for
(
auto
c
:
{
1
,
3
,
11
,
4
})
{
for
(
auto
h
:
{
3
,
1
,
11
,
4
})
{
for
(
auto
h
:
{
3
,
1
,
11
,
4
})
{
for
(
auto
w
:
{
1
,
3
,
4
,
12
})
{
for
(
auto
w
:
{
1
,
3
,
4
,
12
})
{
for
(
auto
axis
:
{
-
4
,
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
3
})
{
for
(
auto
axis
:
{
-
4
,
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
3
})
{
#endif
x
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
x
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output_ref
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
output_ref
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})));
...
...
paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <iostream>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/kernels/x86/mul_compute.h
浏览文件 @
4fe5c8aa
...
@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
...
@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
const
Tensor
x_matrix
=
x
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
Tensor
x_matrix
,
y_matrix
;
*
x
,
param
.
x_num_col_dims
)
:
*
x
;
if
(
x
->
dims
().
size
()
>
2
)
{
const
Tensor
y_matrix
=
y
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
x_matrix
=
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
);
*
y
,
param
.
y_num_col_dims
)
}
else
{
:
*
y
;
x_matrix
=
*
x
;
}
if
(
y
->
dims
().
size
()
>
2
)
{
y_matrix
=
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
);
}
else
{
y_matrix
=
*
y
;
}
auto
*
z
=
&
param
.
output
->
raw_tensor
();
auto
*
z
=
&
param
.
output
->
raw_tensor
();
auto
z_dim
=
z
->
dims
();
auto
z_dim
=
z
->
dims
();
...
@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
...
@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
x
=
&
param
.
x
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
auto
*
y
=
&
param
.
y
->
raw_tensor
();
auto
x_matrix
=
x
->
dims
().
size
()
>
2
?
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
)
Tensor
x_matrix
,
y_matrix
;
:
static_cast
<
const
Tensor
&>
(
*
x
);
auto
y_matrix
=
y
->
dims
().
size
()
>
2
if
(
x
->
dims
().
size
()
>
2
)
{
?
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
)
x_matrix
=
framework
::
ReshapeToMatrix
(
*
x
,
param
.
x_num_col_dims
);
:
static_cast
<
const
Tensor
&>
(
*
y
);
}
else
{
x_matrix
=
*
x
;
}
if
(
y
->
dims
().
size
()
>
2
)
{
y_matrix
=
framework
::
ReshapeToMatrix
(
*
y
,
param
.
y_num_col_dims
);
}
else
{
y_matrix
=
*
y
;
}
auto
*
dout
=
&
param
.
output_grad
->
raw_tensor
();
auto
*
dout
=
&
param
.
output_grad
->
raw_tensor
();
Tensor
dout_mat
;
Tensor
dout_mat
;
...
...
paddle/fluid/lite/kernels/x86/mul_compute_test.cc
浏览文件 @
4fe5c8aa
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <iostream>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
4fe5c8aa
...
@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
...
@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
endif
()
endif
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc
DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite
)
lite_cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
lite_cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
variable_lite scope_lite
${
tensor_lite
}
scope_lite
variable_lite scope_lite
${
tensor_lite
}
scope_lite
...
...
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
浏览文件 @
4fe5c8aa
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
paddle/fluid/lite/model_parser/desc_apis.h
浏览文件 @
4fe5c8aa
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <map>
#include <map>
#include <sstream>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -79,6 +80,27 @@ class OpDescAPI {
...
@@ -79,6 +80,27 @@ class OpDescAPI {
/// Get an attribute.
/// Get an attribute.
template
<
typename
T
>
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
T
GetAttr
(
const
std
::
string
&
name
)
const
;
std
::
string
Repr
()
const
{
std
::
stringstream
ss
;
ss
<<
Type
();
ss
<<
"("
;
for
(
auto
&
arg
:
InputArgumentNames
())
{
ss
<<
arg
<<
":"
;
for
(
auto
val
:
Input
(
arg
))
{
ss
<<
val
<<
" "
;
}
}
ss
<<
") -> ("
;
for
(
auto
&
arg
:
OutputArgumentNames
())
{
ss
<<
arg
<<
":"
;
for
(
auto
val
:
Output
(
arg
))
{
ss
<<
val
<<
" "
;
}
}
ss
<<
")"
;
return
ss
.
str
();
}
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
4fe5c8aa
...
@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
...
@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
template
<
typename
T
>
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
T
GetAttr
(
const
std
::
string
&
name
)
const
;
std
::
string
DebugString
()
const
{
return
desc_
.
DebugString
();
}
private:
private:
std
::
vector
<
std
::
string
>
GetArguments
(
std
::
vector
<
std
::
string
>
GetArguments
(
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
4fe5c8aa
...
@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
...
@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
CHECK
(
!
op_desc
.
Input
(
"X"
).
empty
());
CHECK
(
!
op_desc
.
Input
(
"Y"
).
empty
());
CHECK
(
!
op_desc
.
Output
(
"Out"
).
empty
());
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
*
var
=
scope
->
FindVar
(
input
);
auto
*
var
=
scope
->
FindVar
(
input
);
CHECK
(
var
);
CHECK
(
var
);
param_
.
x
=
var
->
GetMutable
<
Tensor
>
();
param_
.
x
=
&
var
->
Get
<
Tensor
>
();
var
=
scope
->
FindVar
(
W
);
var
=
scope
->
FindVar
(
W
);
CHECK
(
var
)
<<
"no var called "
<<
W
;
CHECK
(
var
)
<<
"no var called "
<<
W
;
param_
.
y
=
var
->
GetMutable
<
Tensor
>
();
param_
.
y
=
&
var
->
Get
<
Tensor
>
();
var
=
scope
->
FindVar
(
out
);
var
=
scope
->
FindVar
(
out
);
CHECK
(
var
)
<<
"no var called "
<<
out
;
CHECK
(
var
)
<<
"no var called "
<<
out
;
param_
.
output
=
var
->
GetMutable
<
Tensor
>
();
param_
.
output
=
var
->
GetMutable
<
Tensor
>
();
...
...
paddle/fluid/lite/operators/op_params.h
浏览文件 @
4fe5c8aa
...
@@ -67,8 +67,8 @@ struct ReluParam {
...
@@ -67,8 +67,8 @@ struct ReluParam {
// For Mul Op
// For Mul Op
struct
MulParam
{
struct
MulParam
{
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
int
x_num_col_dims
{
1
};
int
x_num_col_dims
{
1
};
...
...
paddle/fluid/lite/tools/build.sh
浏览文件 @
4fe5c8aa
...
@@ -54,22 +54,6 @@ function check_style {
...
@@ -54,22 +54,6 @@ function check_style {
fi
fi
}
}
function
cmake_arm
{
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake ..
\
-DWITH_GPU
=
OFF
\
-DWITH_MKL
=
OFF
\
-DWITH_LITE
=
ON
\
-DLITE_WITH_CUDA
=
OFF
\
-DLITE_WITH_X86
=
OFF
\
-DLITE_WITH_ARM
=
ON
\
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK
=
ON
\
-DWITH_TESTING
=
ON
\
-DARM_TARGET_OS
=
$1
-DARM_TARGET_ARCH_ABI
=
$2
-DARM_TARGET_LANG
=
$3
}
function
build_single
{
function
build_single
{
#make $1 -j$(expr $(nproc) - 2)
#make $1 -j$(expr $(nproc) - 2)
make
$1
-j
$NUM_CORES_FOR_COMPILE
make
$1
-j
$NUM_CORES_FOR_COMPILE
...
@@ -153,33 +137,53 @@ function test_arm_model {
...
@@ -153,33 +137,53 @@ function test_arm_model {
adb
-s
emulator-
${
port
}
shell
chmod
+x
"
${
adb_work_dir
}
/
${
test_name
}
"
adb
-s
emulator-
${
port
}
shell
chmod
+x
"
${
adb_work_dir
}
/
${
test_name
}
"
local
adb_model_path
=
"./
${
adb_work_dir
}
/
`
basename
${
model_dir
}
`
"
local
adb_model_path
=
"./
${
adb_work_dir
}
/
`
basename
${
model_dir
}
`
"
adb
-s
emulator-
${
port
}
shell
"./
${
adb_work_dir
}
/
${
test_name
}
--eval_model_dir=
$adb_model_path
"
adb
-s
emulator-
${
port
}
shell
"./
${
adb_work_dir
}
/
${
test_name
}
--eval_model_dir=
$adb_model_path
"
}
}
# Build the code and run lite arm tests. This is executed in the CI system.
function
cmake_arm
{
function
build_test_arm
{
# $1: ARM_TARGET_OS in "android" , "armlinux"
# 1. Build goes first
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake ..
\
-DWITH_GPU
=
OFF
\
-DWITH_MKL
=
OFF
\
-DWITH_LITE
=
ON
\
-DLITE_WITH_CUDA
=
OFF
\
-DLITE_WITH_X86
=
OFF
\
-DLITE_WITH_ARM
=
ON
\
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK
=
ON
\
-DWITH_TESTING
=
ON
\
-DARM_TARGET_OS
=
$1
-DARM_TARGET_ARCH_ABI
=
$2
-DARM_TARGET_LANG
=
$3
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
function
build_arm
{
os
=
$1
abi
=
$2
lang
=
$3
cur_dir
=
$(
pwd
)
cur_dir
=
$(
pwd
)
for
lang
in
"gcc"
"clang"
;
do
if
[[
${
os
}
==
"armlinux"
]]
;
then
for
os
in
"android"
"armlinux"
;
do
# TODO(hongming): enable compile armv7 and armv7hf on armlinux, and clang compile
if
[[
${
os
}
==
"armlinux"
&&
${
lang
}
==
"clang"
]]
;
then
if
[[
${
lang
}
==
"clang"
]]
;
then
continue
echo
"clang is not enabled on armlinux yet"
return
0
fi
fi
for
abi
in
"armv8"
"armv7"
"armv7hf"
;
do
# TODO(hongming): enable compile armv7 and armv7hf on armlinux
if
[[
${
abi
}
==
"armv7hf"
]]
;
then
if
[[
${
abi
}
==
"armv7hf"
]]
;
then
echo
"armv7hf is not supported on both android and
armlinux yet"
echo
"armv7hf is not supported on
armlinux yet"
continue
return
0
fi
fi
if
[[
${
abi
}
==
"armv7"
]]
;
then
# TODO(hongming): enable armv7 on armlinux
if
[[
${
os
}
==
"armlinux"
&&
${
abi
}
==
"armv7"
]]
;
then
echo
"armv7 is not supported on armlinux yet"
echo
"armv7 is not supported on armlinux yet"
continue
return
0
fi
fi
fi
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
echo
"android do not need armv7hf"
echo
"android do not need armv7hf"
continue
return
0
fi
fi
build_dir
=
$cur_dir
/build.lite.
${
os
}
.
${
abi
}
.
${
lang
}
build_dir
=
$cur_dir
/build.lite.
${
os
}
.
${
abi
}
.
${
lang
}
...
@@ -188,11 +192,47 @@ function build_test_arm {
...
@@ -188,11 +192,47 @@ function build_test_arm {
cmake_arm
${
os
}
${
abi
}
${
lang
}
cmake_arm
${
os
}
${
abi
}
${
lang
}
build
$TESTS_FILE
build
$TESTS_FILE
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
# $4: android test port
# Note: test must be in build dir
function
test_arm
{
os
=
$1
abi
=
$2
lang
=
$3
port
=
$4
if
[[
${
os
}
==
"armlinux"
]]
;
then
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
echo
"Skip test arm linux yet. armlinux must in another docker"
return
0
fi
if
[[
${
os
}
==
"android"
&&
${
abi
}
==
"armv7hf"
]]
;
then
echo
"android do not need armv7hf"
return
0
fi
# TODO(yuanshuai): enable armv7 on android
if
[[
${
abi
}
==
"armv7"
]]
;
then
echo
"skip android v7 test yet"
return
0
fi
echo
"test file:
${
TESTS_FILE
}
"
for
_test
in
$(
cat
$TESTS_FILE
)
;
do
test_arm_android
$_test
$port
done
done
done
# TODO(sangoly): refine this
done
test_arm_model
"test_cxx_api_lite"
$port
"./third_party/install/mobilenet_v2_relu"
}
# 2. Then test
# Build the code and run lite arm tests. This is executed in the CI system.
function
build_test_arm
{
########################################################################
# job 1-4 must be in one runner
port_armv8
=
5554
port_armv8
=
5554
port_armv7
=
5556
port_armv7
=
5556
...
@@ -206,39 +246,46 @@ function build_test_arm {
...
@@ -206,39 +246,46 @@ function build_test_arm {
echo
-ne
'\n'
|
${
ANDROID_HOME
}
/emulator/emulator
-avd
paddle-armv7
-noaudio
-no-window
-gpu
off
-verbose
-port
${
port_armv7
}
&
echo
-ne
'\n'
|
${
ANDROID_HOME
}
/emulator/emulator
-avd
paddle-armv7
-noaudio
-no-window
-gpu
off
-verbose
-port
${
port_armv7
}
&
sleep
1m
sleep
1m
# now can only test android.
# job 1
for
lang
in
"gcc"
"clang"
;
do
build_arm
"android"
"armv8"
"gcc"
for
abi
in
"armv8"
"armv7"
;
do
test_arm
"android"
"armv8"
"gcc"
${
port_armv8
}
# TODO(yuanshuai): enable armv7 on android
cd
-
if
[[
${
abi
}
==
"armv7"
]]
;
then
continue
fi
build_dir
=
$cur_dir
/build.lite.android.
${
abi
}
.
${
lang
}
# job 2
cd
$build_dir
build_arm
"android"
"armv8"
"clang"
test_arm
"android"
"armv8"
"clang"
${
port_armv8
}
cd
-
local
port
=
# job 3
if
[[
${
abi
}
==
"armv7"
]]
;
then
build_arm
"android"
"armv7"
"gcc"
port
=
${
port_armv7
}
test_arm
"android"
"armv7"
"gcc"
${
port_armv7
}
fi
cd
-
if
[[
${
abi
}
==
"armv8"
]]
;
then
# job 4
port
=
${
port_armv8
}
build_arm
"android"
"armv7"
"clang"
fi
test_arm
"android"
"armv7"
"clang"
${
port_armv7
}
echo
"test file:
${
TESTS_FILE
}
"
cd
-
for
_test
in
$(
cat
$TESTS_FILE
)
;
do
test_arm_android
$_test
$port
done
# TODO(sangoly): refine this
test_arm_model
"test_cxx_api_lite"
$port
"./third_party/install/mobilenet_v2_relu"
done
done
# armlinux need in another docker
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
adb devices |
grep
emulator |
cut
-f1
|
while
read
line
;
do
adb
-s
$line
emu
kill
;
done
adb devices |
grep
emulator |
cut
-f1
|
while
read
line
;
do
adb
-s
$line
emu
kill
;
done
echo
"Done"
echo
"Done"
########################################################################
# job 5
build_arm
"armlinux"
"armv8"
test_arm
"armlinux"
"armv8"
cd
-
# job 6
build_arm
"armlinux"
"armv7"
test_arm
"armlinux"
"armv7"
cd
-
# job 7
build_arm
"armlinux"
"armv7hf"
test_arm
"armlinux"
"armv7hf"
cd
-
echo
"Done"
}
}
############################# MAIN #################################
############################# MAIN #################################
...
@@ -279,6 +326,10 @@ function main {
...
@@ -279,6 +326,10 @@ function main {
ARM_ABI
=
"
${
i
#*=
}
"
ARM_ABI
=
"
${
i
#*=
}
"
shift
shift
;;
;;
--arm_lang
=
*
)
ARM_LANG
=
"
${
i
#*=
}
"
shift
;;
--arm_port
=
*
)
--arm_port
=
*
)
ARM_PORT
=
"
${
i
#*=
}
"
ARM_PORT
=
"
${
i
#*=
}
"
shift
shift
...
@@ -301,13 +352,21 @@ function main {
...
@@ -301,13 +352,21 @@ function main {
shift
shift
;;
;;
cmake_arm
)
cmake_arm
)
cmake_arm
$ARM_OS
$ARM_ABI
cmake_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
shift
;;
build_arm
)
build_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
shift
shift
;;
;;
test_server
)
test_server
)
test_lite
$TESTS_FILE
test_lite
$TESTS_FILE
shift
shift
;;
;;
test_arm
)
build_arm
$ARM_OS
$ARM_ABI
$ARM_LANG
$ARM_PORT
shift
;;
test_arm_android
)
test_arm_android
)
test_arm_android
$TEST_NAME
$ARM_PORT
test_arm_android
$TEST_NAME
$ARM_PORT
shift
shift
...
...
paddle/fluid/lite/utils/varient.h
浏览文件 @
4fe5c8aa
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <typeinfo>
#include <typeinfo>
#include <utility>
#include <utility>
#include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/string.h"
// This is an equivalent implementation of boost::any. We implement this to
// This is an equivalent implementation of boost::any. We implement this to
// avoid including the whole boost library and keep the inference library small.
// avoid including the whole boost library and keep the inference library small.
...
@@ -116,9 +117,9 @@ struct variant {
...
@@ -116,9 +117,9 @@ struct variant {
if
(
type_id
==
typeid
(
T
).
hash_code
())
if
(
type_id
==
typeid
(
T
).
hash_code
())
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
else
else
throw
std
::
invalid_argument
(
"unmatched type"
);
throw
std
::
invalid_argument
(
// LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
string_format
(
"unmatched type, store as %d, but want to get %s"
,
// << typeid(T).name(
);
type_id
,
typeid
(
T
).
name
())
);
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
return
*
reinterpret_cast
<
const
T
*>
(
&
data
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录