From 43d97500ba67931169fb7ead5477e008525123fd Mon Sep 17 00:00:00 2001 From: Chunwei Date: Wed, 19 Jun 2019 09:44:29 +0000 Subject: [PATCH] add scale eliminate pass --- paddle/fluid/lite/api/apis_test.cc | 2 +- paddle/fluid/lite/api/cxx_api_bin.cc | 2 +- paddle/fluid/lite/api/lite_api_test_helper.cc | 1 + paddle/fluid/lite/core/CMakeLists.txt | 1 - paddle/fluid/lite/core/mir/CMakeLists.txt | 42 ++--- .../lite/core/mir/elimination/CMakeLists.txt | 7 + .../identity_scale_eliminate_pass.cc | 72 +++++++++ .../identity_scale_eliminate_pass_test.cc | 93 +++++++++++ .../fluid/lite/core/mir/fusion/CMakeLists.txt | 2 +- .../mir/{ => fusion}/conv_bn_fuse_pass.cc | 2 +- .../core/mir/{ => fusion}/conv_bn_fuse_pass.h | 0 .../core/mir/fusion/conv_bn_fuse_pass_test.cc | 2 +- .../lite/core/mir/fusion/conv_bn_fuser.cc | 2 +- ...nv_elementwise_add_activation_fuse_pass.cc | 2 +- ...onv_elementwise_add_activation_fuse_pass.h | 0 ...ementwise_add_activation_fuse_pass_test.cc | 2 +- .../conv_elementwise_add_activation_fuser.cc | 2 +- .../conv_elementwise_add_relu_fuse_pass.cc | 39 +++++ .../conv_elementwise_add_relu_fuse_pass.h | 32 ++++ ...onv_elementwise_add_relu_fuse_pass_test.cc | 153 ++++++++++++++++++ .../elementwise_add_activation_fuse_pass.cc | 2 +- .../elementwise_add_activation_fuse_pass.h | 0 ...ementwise_add_activation_fuse_pass_test.cc | 2 +- .../elementwise_add_activation_fuser.cc | 2 +- .../core/mir/{ => fusion}/fc_fuse_pass.cc | 2 +- .../lite/core/mir/{ => fusion}/fc_fuse_pass.h | 0 .../mir/{ => fusion}/fc_fuse_pass_test.cc | 2 +- paddle/fluid/lite/core/mir/fusion/fc_fuser.cc | 2 +- .../lite/core/mir/generate_program_pass.cc | 2 +- .../lite/core/mir/graph_visualize_pass.cc | 2 +- .../lite/core/mir/io_copy_kernel_pick_pass.cc | 4 +- paddle/fluid/lite/core/mir/node.cc | 59 +++++++ paddle/fluid/lite/core/mir/node.h | 66 ++++---- paddle/fluid/lite/core/mir/pattern_matcher.h | 5 +- .../lite/core/mir/pattern_matcher_high_api.cc | 1 + .../lite/core/mir/pattern_matcher_high_api.h | 8 +- .../core/mir/pattern_matcher_high_api_test.cc | 4 +- .../lite/core/mir/pattern_matcher_test.cc | 26 +-- paddle/fluid/lite/core/mir/ssa_graph.h | 5 + .../lite/core/mir/static_kernel_pick_pass.cc | 12 +- .../core/mir/type_target_transform_pass.cc | 16 +- paddle/fluid/lite/core/mir/use_passes.h | 2 + .../core/mir/variable_place_inference_pass.h | 10 +- paddle/fluid/lite/core/op_lite.cc | 3 +- paddle/fluid/lite/core/op_lite.h | 16 ++ paddle/fluid/lite/core/optimizer.h | 3 + paddle/fluid/lite/core/program.h | 2 +- paddle/fluid/lite/core/tensor.h | 1 - .../kernels/x86/elementwise_compute_test.cc | 2 + paddle/fluid/lite/kernels/x86/mul_compute.h | 42 +++-- .../lite/kernels/x86/mul_compute_test.cc | 2 + paddle/fluid/lite/model_parser/CMakeLists.txt | 3 +- .../lite/model_parser/cpp/CMakeLists.txt | 1 - paddle/fluid/lite/model_parser/desc_apis.h | 22 +++ paddle/fluid/lite/model_parser/pb/op_desc.h | 2 + paddle/fluid/lite/operators/mul_op.h | 8 +- paddle/fluid/lite/operators/op_params.h | 4 +- paddle/fluid/lite/utils/varient.h | 7 +- 58 files changed, 675 insertions(+), 135 deletions(-) create mode 100644 paddle/fluid/lite/core/mir/elimination/CMakeLists.txt create mode 100644 paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc create mode 100644 paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc rename paddle/fluid/lite/core/mir/{ => fusion}/conv_bn_fuse_pass.cc (94%) rename paddle/fluid/lite/core/mir/{ => fusion}/conv_bn_fuse_pass.h (100%) rename paddle/fluid/lite/core/mir/{ => 
fusion}/conv_elementwise_add_activation_fuse_pass.cc (94%) rename paddle/fluid/lite/core/mir/{ => fusion}/conv_elementwise_add_activation_fuse_pass.h (100%) rename paddle/fluid/lite/core/mir/{ => fusion}/conv_elementwise_add_activation_fuse_pass_test.cc (98%) create mode 100644 paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc create mode 100644 paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h create mode 100644 paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc rename paddle/fluid/lite/core/mir/{ => fusion}/elementwise_add_activation_fuse_pass.cc (93%) rename paddle/fluid/lite/core/mir/{ => fusion}/elementwise_add_activation_fuse_pass.h (100%) rename paddle/fluid/lite/core/mir/{ => fusion}/elementwise_add_activation_fuse_pass_test.cc (97%) rename paddle/fluid/lite/core/mir/{ => fusion}/fc_fuse_pass.cc (94%) rename paddle/fluid/lite/core/mir/{ => fusion}/fc_fuse_pass.h (100%) rename paddle/fluid/lite/core/mir/{ => fusion}/fc_fuse_pass_test.cc (98%) diff --git a/paddle/fluid/lite/api/apis_test.cc b/paddle/fluid/lite/api/apis_test.cc index 4d99f238dd6..7dd6a119375 100644 --- a/paddle/fluid/lite/api/apis_test.cc +++ b/paddle/fluid/lite/api/apis_test.cc @@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) { ASSERT_TRUE(TensorCompareWith(*cxx_out, *light_out)); std::vector tensors_with_order({ - "a", "fc_0.w_0", "fc_0.tmp_0", "scale_0.tmp_0", + "a", "fc_0.w_0", "scale_0.tmp_0", }); for (const auto& tensor_name : tensors_with_order) { diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc index b064c663d11..af6f8e44d69 100644 --- a/paddle/fluid/lite/api/cxx_api_bin.cc +++ b/paddle/fluid/lite/api/cxx_api_bin.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/api/cxx_api.h" -#include +#include // NOLINT #include "paddle/fluid/lite/core/mir/use_passes.h" #include "paddle/fluid/lite/core/op_registry.h" diff --git a/paddle/fluid/lite/api/lite_api_test_helper.cc b/paddle/fluid/lite/api/lite_api_test_helper.cc index 490a64bb512..b8254172330 100644 --- a/paddle/fluid/lite/api/lite_api_test_helper.cc +++ b/paddle/fluid/lite/api/lite_api_test_helper.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "paddle/fluid/lite/api/lite_api_test_helper.h" +#include DEFINE_string(model_dir, "", ""); DEFINE_string(optimized_model, "", ""); diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 665d7555e37..1e95668cddc 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite) lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator) - diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 0d7fcf8b3b2..e67ade8cbef 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -5,21 +5,25 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) add_subdirectory(fusion) +add_subdirectory(elimination) + cc_library(mir_passes - SRCS fc_fuse_pass.cc - conv_elementwise_add_activation_fuse_pass.cc - elementwise_add_activation_fuse_pass.cc - conv_bn_fuse_pass.cc - static_kernel_pick_pass.cc - variable_place_inference_pass.cc - type_target_transform_pass.cc - io_copy_kernel_pick_pass.cc - graph_visualize_pass.cc - generate_program_pass.cc - argument_type_display_pass.cc - demo_pass.cc - runtime_context_assign_pass.cc - DEPS mir_pass types_lite context_lite ${mir_fusers}) + SRCS + fusion/fc_fuse_pass.cc + fusion/conv_elementwise_add_activation_fuse_pass.cc + fusion/conv_bn_fuse_pass.cc + fusion/elementwise_add_activation_fuse_pass.cc + elimination/identity_scale_eliminate_pass.cc + static_kernel_pick_pass.cc + variable_place_inference_pass.cc + type_target_transform_pass.cc + io_copy_kernel_pick_pass.cc + graph_visualize_pass.cc + generate_program_pass.cc + argument_type_display_pass.cc + demo_pass.cc + runtime_context_assign_pass.cc + DEPS mir_pass types_lite context_lite ${mir_fusers}) #cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS #mir_ssa_graph scope_lite op_lite @@ -73,7 +77,7 @@ message(STATUS "----> Ops lite: ${ops_lite}") message(STATUS "----> Host kernels: ${host_kernels}") message(STATUS "----> X86 kernels: ${x86_kernels}") -lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc +lite_cc_test(test_lite_fc_fuse SRCS fusion/fc_fuse_pass_test.cc DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model @@ -83,11 +87,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz) -lite_cc_test(test_lite_conv_elementwise_add_activation_fuse - SRCS conv_elementwise_add_activation_fuse_pass_test.cc +lite_cc_test(test_lite_conv_elementwise_add_activation_fuse + SRCS fusion/conv_elementwise_add_activation_fuse_pass_test.cc DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${x86_kernels}) -lite_cc_test(test_lite_elementwise_add_activation_fuse - SRCS elementwise_add_activation_fuse_pass_test.cc +lite_cc_test(test_lite_elementwise_add_activation_fuse + SRCS fusion/elementwise_add_activation_fuse_pass_test.cc DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${x86_kernels}) diff --git a/paddle/fluid/lite/core/mir/elimination/CMakeLists.txt 
b/paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
new file mode 100644
index 00000000000..9fda8ec29a4
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/elimination/CMakeLists.txt
@@ -0,0 +1,7 @@
+if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  lite_cc_test(test_identity_scale_eliminate_pass_lite
+          SRCS identity_scale_eliminate_pass_test.cc
+          DEPS mir_passes program_lite proto_desc cpp_op_desc_lite
+          ${ops_lite}
+          )
+endif()
diff --git a/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
new file mode 100644
index 00000000000..6f8aeb65c05
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/pass.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+namespace {
+
+class Eliminator : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* pre_op = OpNode("preop");  // the previous op, whose output needs updating
+    // TODO(Superjomn) check that it has only one output
+    auto* x = VarNode("x")->assert_is_op_input("scale", "X");
+    auto* scale_op = OpNode("scale", "scale")
+                         ->assert_op_attr("scale", 1.)
+                         ->assert_op_attr("bias", 0.);
+    auto* out = VarNode("out")->assert_is_op_output("scale", "Out");
+
+    *pre_op >> *x >> *scale_op >> *out;
+
+    // The scale op will be removed, and pre_op's output will be redirected to out.
+    x->AsIntermediate();  // x is pre_op's output, which needs to be updated
+  }
+
+ private:
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    auto& pre_op = matched.at("preop")->AsStmt();
+    auto op_info = *pre_op.op_info();
+
+    op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
+                             matched.at("out")->AsArg().name);
+    pre_op.ResetOp(op_info, graph->valid_places());
+
+    GraphSafeRemoveNodes(graph, {matched.at("scale")});
+
+    IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
+  }
+};
+
+}  // namespace
+
+class IdentityScaleEliminatePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    Eliminator eliminator;
+    eliminator(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(identity_scale_eliminate_pass,
+                  paddle::lite::mir::IdentityScaleEliminatePass);
diff --git a/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
new file mode 100644
index 00000000000..89db35fe0e8
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+#include "paddle/fluid/lite/core/mir/ssa_graph.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
+                                     const std::shared_ptr<Scope>& scope,
+                                     const std::vector<Place>& valid_places) {
+  // Op list:
+  // (x)->feed -> (feed) -> scale -> (scale_out) -> fetch -> (fetch)
+  // After the pass:
+  // (x)->feed -> (scale_out) -> fetch -> (fetch)
+  auto* main_block = program_desc->MutableBlock(0);
+  auto* feed_op = main_block->AppendOp();
+  auto* scale_op = main_block->AppendOp();
+  auto* fetch_op = main_block->AppendOp();
+  main_block->Var("x");
+  main_block->Var("feed");
+  main_block->Var("scale_out");
+  main_block->Var("fetch_out");
+
+  scope->Var("x")->GetMutable<lite::Tensor>();
+  scope->Var("feed")->GetMutable<lite::Tensor>();
+  scope->Var("scale_out")->GetMutable<lite::Tensor>();
+  scope->Var("fetch_out")->GetMutable<lite::Tensor>();
+
+  feed_op->SetType("feed");
+  feed_op->SetInput("X", {"x"});
+  feed_op->SetAttr("col", 1);
+  feed_op->SetOutput("Out", {"feed"});
+
+  scale_op->SetType("scale");
+  scale_op->SetInput("X", {"feed"});
+  scale_op->SetOutput("Out", {"scale_out"});
+  scale_op->SetAttr("scale", 1.f);
+  scale_op->SetAttr("bias", 0.f);
+  scale_op->SetAttr("bias_after_scale", true);
+
+  fetch_op->SetType("fetch");
+  fetch_op->SetInput("X", {"scale_out"});
+  fetch_op->SetOutput("Out", {"fetch"});
+  fetch_op->SetAttr("col", 1);
+
+  program_desc->Flush();
+
+  lite::Program program(*program_desc->Proto(), scope, valid_places);
+  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
+  graph->Build(program, valid_places);
+
+  LOG(INFO) << Visualize(graph.get());
+
+  return graph;
+}
+
+TEST(identity_test, test) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  const int num_nodes = graph->nodes().size();
+  auto pass = PassManager::Global().LookUp("identity_scale_eliminate_pass");
+  ASSERT_TRUE(pass);
+  pass->Apply(graph);
+  ASSERT_EQ(graph->nodes().size(), num_nodes - 2UL);
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(feed)
+USE_LITE_OP(fetch)
+USE_LITE_OP(scale)
+USE_MIR_PASS(identity_scale_eliminate_pass)
diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
index 9139293c8aa..db092e17679 100644
--- a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
@@ -11,7 +11,7 @@ cc_library(fuse_elementwise_add_activation
         SRCS elementwise_add_activation_fuser.cc
         DEPS pattern_matcher_high_api)
 
-set(mir_fusers 
+set(mir_fusers
     fuse_fc
     fuse_conv_elementwise_add_activation
     fuse_conv_bn
diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc
similarity index 94% rename from paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc index 562ec7f4507..1e7d7bc5774 100644 --- a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h" diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc index 79436a9fa3d..3a8573b4f8c 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc index d29f078513e..0a73d1e39d9 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc @@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() { void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { auto op_desc = GenOpDesc(matched); auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add"); - auto conv = matched.at("conv2d")->stmt()->op; + auto conv = matched.at("conv2d")->stmt()->op(); auto* scope = conv->scope(); auto& valid_places = conv->valid_places(); diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc similarity index 94% rename from paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc index 27f6413c47b..f4eb5a00ad2 100644 --- a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h" diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc similarity index 98% rename from paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc rename to paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc index a67e577505f..e7751d801ea 100644 --- a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc index b26b758fb23..a085b139c86 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc @@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode( SSAGraph* graph, const key2nodes_t& matched) { auto op_desc = GenOpDesc(matched); auto conv_op = LiteOpRegistry::Global().Create(conv_type_); - auto conv_old = matched.at("conv2d")->stmt()->op; + auto conv_old = matched.at("conv2d")->stmt()->op(); auto* scope = conv_old->scope(); auto& valid_places = conv_old->valid_places(); conv_op->Attach(op_desc, scope); diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc new file mode 100644 index 00000000000..4ace19f304b --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ConvElementwiseAddReLUFusePass::Apply( + const std::unique_ptr& graph) { + fusion::ConvElementwiseAddReLUFuser fuser("conv2d"); + fuser(graph.get()); + + fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d"); + depthwise_fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass, + paddle::lite::mir::ConvElementwiseAddReLUFusePass); diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h new file mode 100644 index 00000000000..4276f1ffc8c --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ConvElementwiseAddReLUFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc new file mode 100644 index 00000000000..00c9eaf8c07 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/mir/passes.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/program.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + + auto* conv2d_1 = main_block->AppendOp(); + auto* conv2d_2 = main_block->AppendOp(); + auto* add_1 = main_block->AppendOp(); + auto* relu_1 = main_block->AppendOp(); + auto* add_2 = main_block->AppendOp(); + auto* relu_2 = main_block->AppendOp(); + + main_block->Var("input_1"); + main_block->Var("input_2"); + main_block->Var("filter_1"); + main_block->Var("filter_2"); + main_block->Var("conv2d_1_out"); + main_block->Var("conv2d_2_out"); + main_block->Var("bias_1"); + main_block->Var("add_1_out"); + main_block->Var("add_2_out"); + main_block->Var("relu_1_out"); + main_block->Var("out"); + + scope->Var("input_1")->GetMutable(); + scope->Var("input_2")->GetMutable(); + scope->Var("filter_1")->GetMutable(); + scope->Var("filter_2")->GetMutable(); + scope->Var("conv2d_1_out")->GetMutable(); + scope->Var("conv2d_2_out")->GetMutable(); + scope->Var("bias_1")->GetMutable(); + scope->Var("add_1_out")->GetMutable(); + scope->Var("add_2_out")->GetMutable(); + scope->Var("relu_1_out")->GetMutable(); + scope->Var("out")->GetMutable(); + + conv2d_1->SetType("conv2d"); + conv2d_1->SetInput("Input", {"input_1"}); + conv2d_1->SetInput("Filter", {"filter_1"}); + conv2d_1->SetOutput("Output", {"conv2d_1_out"}); + conv2d_1->SetAttr("strides", std::vector({1, 1})); + conv2d_1->SetAttr("paddings", std::vector({0, 0})); + conv2d_1->SetAttr("groups", 1); + conv2d_1->SetAttr("dilations", std::vector({1, 1})); + conv2d_1->SetAttr("fuse_relu", false); + + add_1->SetType("elementwise_add"); + add_1->SetInput("X", {"conv2d_1_out"}); + add_1->SetInput("Y", {"bias_1"}); + add_1->SetOutput("Out", {"add_1_out"}); + add_1->SetAttr("axis", 1); + + relu_1->SetType("relu"); + relu_1->SetInput("X", {"add_1_out"}); + relu_1->SetOutput("Out", {"relu_1_out"}); + + conv2d_2->SetType("conv2d"); + conv2d_2->SetInput("Input", {"input_2"}); + conv2d_2->SetInput("Filter", {"filter_2"}); + conv2d_2->SetOutput("Output", {"conv2d_2_out"}); + conv2d_2->SetAttr("strides", std::vector({1, 1})); + conv2d_2->SetAttr("paddings", std::vector({0, 0})); + conv2d_2->SetAttr("groups", 1); + conv2d_2->SetAttr("dilations", std::vector({1, 1})); + conv2d_2->SetAttr("fuse_relu", false); + + add_2->SetType("elementwise_add"); + add_2->SetInput("X", {"conv2d_2_out"}); + add_2->SetInput("Y", {"relu_1_out"}); + add_2->SetOutput("Out", {"add_2_out"}); + add_2->SetAttr("axis", 1); + + relu_2->SetType("relu"); + relu_2->SetInput("X", {"add_2_out"}); + relu_2->SetOutput("Out", {"out"}); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(conv_elementwise_add_relu_fuse_pass, graph_test) { + framework::ProgramDesc program_desc; + 
std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/); + Visualize(graph.get()); +} + +TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + Visualize(graph.get()); + const int num_nodes = graph->nodes().size(); + auto* fuser = new ConvElementwiseAddReLUFusePass; + fuser->Apply(graph); + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ + + 1UL * 2 /* fused fc node*/); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(elementwise_add); +USE_LITE_OP(conv2d); +USE_LITE_OP(depthwise_conv2d); +USE_LITE_OP(relu); diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc similarity index 93% rename from paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc index 9ce455dcdaf..20d1eaa82a8 100644 --- a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h" diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc similarity index 97% rename from paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc rename to paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc index 7f64eead9ea..4b7742e059d 100644 --- a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc index 83b916eea3e..cafbc42d85b 100644 --- a/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc @@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph, auto op_desc = GenOpDesc(matched); auto op = LiteOpRegistry::Global().Create("fusion_elementwise_add_activation"); - auto old_op = matched.at("add")->stmt()->op; + auto old_op = matched.at("add")->stmt()->op(); auto* scope = old_op->scope(); auto& valid_places = old_op->valid_places(); op->Attach(op_desc, scope); diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc similarity index 94% rename from paddle/fluid/lite/core/mir/fc_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc index 008f05ce5cb..f50db9c17b3 100644 --- a/paddle/fluid/lite/core/mir/fc_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h" diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/fc_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc similarity index 98% rename from paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc rename to paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc index e2f7dd1a87d..b64a436f925 100644 --- a/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include "fc_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc index a8b6336595c..bb350c731c6 100644 --- a/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc @@ -46,7 +46,7 @@ void FcFuser::BuildPattern() { void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { auto op_desc = GenOpDesc(matched); auto fc_op = LiteOpRegistry::Global().Create("fc"); - auto mul = matched.at("mul")->stmt()->op; + auto mul = matched.at("mul")->stmt()->op(); auto* scope = mul->scope(); auto& valid_places = mul->valid_places(); fc_op->Attach(op_desc, scope); diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc index 75ff159015d..97586d74842 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.cc +++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc @@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr& graph) { if (item->IsStmt()) { auto& stmt = item->AsStmt(); VLOG(4) << stmt; - insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front())); + insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); } } } diff --git a/paddle/fluid/lite/core/mir/graph_visualize_pass.cc b/paddle/fluid/lite/core/mir/graph_visualize_pass.cc index 6a13bafd67c..90a99b5deb1 100644 --- a/paddle/fluid/lite/core/mir/graph_visualize_pass.cc +++ b/paddle/fluid/lite/core/mir/graph_visualize_pass.cc @@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) { if (node.IsArg()) { key = node.AsArg().name; } else { - key = node.AsStmt().op_type + std::to_string(id++); + key = node.AsStmt().op_type() + std::to_string(id++); } if (node.IsStmt()) { diff --git a/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc index ebf9e5a57bf..9f38ce01ba1 100644 --- a/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc +++ b/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass { for (auto& node : graph->mutable_nodes()) { if (!node.IsStmt()) continue; auto& inst = node.AsStmt(); - if (inst.op_type != "io_copy") continue; + if (inst.op_type() != "io_copy") continue; LOG(INFO) << "....> picking a IO COPY kernel"; - auto& kernels = node.AsStmt().valid_kernels; + auto& kernels = node.AsStmt().kernels(); CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op"; const auto* inty = node.inlinks.front()->AsArg().type; const auto* outy = node.outlinks.front()->AsArg().type; diff --git a/paddle/fluid/lite/core/mir/node.cc b/paddle/fluid/lite/core/mir/node.cc index 711ff508f23..814df2b61a2 100644 --- a/paddle/fluid/lite/core/mir/node.cc +++ b/paddle/fluid/lite/core/mir/node.cc @@ -13,3 +13,62 @@ // limitations under the License. 
#include "paddle/fluid/lite/core/mir/node.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +const OpInfo *mir::Node::Stmt::op_info() const { + CHECK(op_); + return op_->op_info(); +} + +Place mir::Node::Stmt::place() const { + CHECK(!valid_kernels_.empty()); + return valid_kernels_.front()->place(); +} + +KernelBase &mir::Node::Stmt::picked_kernel() { + CHECK(!valid_kernels_.empty()) << "no kernel for " << op_type(); + return *valid_kernels_.front(); +} + +OpInfo *mir::Node::Stmt::mutable_op_info() { + CHECK(op_); + return op_->mutable_op_info(); +} + +void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc, + const std::vector &valid_places, + lite::Scope *scope) { + CHECK((op_ && op_->scope()) || scope) << "Either scope should be set"; + lite::Scope *the_scope = scope ? scope : op_->scope(); + op_->Attach(op_desc, the_scope); + // Recreate the kernels with the latest OpInfo. + valid_kernels_.clear(); + + if (!op_ || op_->op_info()->Type() != op_desc.Type()) { + op_ = LiteOpRegistry::Global().Create(op_desc.Type()); + CHECK(op_) << "No op found for " << op_desc.Type(); + } + valid_kernels_ = op_->CreateKernels(valid_places); +} + +std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) { + os << "Statement " << other.op_type() << " " << other.place(); + return os; +} + +mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) { + auto &x = AsArg(); + x.name = name; + x.id = id; + return x; +} +mir::Node::Arg &mir::Node::AsArg(const std::string &name) { + auto &x = AsArg(); + x.name = name; + return x; +} +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index a5fd90dac48..08b7a963e79 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -41,32 +41,40 @@ class Node { kUnk, }; - struct Stmt { - std::string op_type; + class Stmt { // The kernel instances this Statement contains. - std::vector> valid_kernels; + std::vector> valid_kernels_; // TODO(Superjomn) make this a shared_ptr for resource safety. - std::shared_ptr op; // we hold op to run InferShape + std::shared_ptr op_; // we hold op to run InferShape - const OpInfo* op_info() { - CHECK(op); - return op->op_info(); - } + public: + // Refresh the operator and kernels with the latest OpInfo. + void ResetOp(const cpp::OpDesc& op_desc, + const std::vector& valid_places, + lite::Scope* scope = nullptr); - Place place() const { - CHECK(!valid_kernels.empty()); - return valid_kernels.front()->place(); - } + std::string op_type() const { return op_info()->Type(); } + const OpInfo* op_info() const; + OpInfo* mutable_op_info(); - KernelBase& picked_kernel() { - CHECK(!valid_kernels.empty()) << "no kernel for " << op_type; - return *valid_kernels.front(); + void SetKernels(std::vector>&& kernels) { + valid_kernels_ = std::move(kernels); } - - friend std::ostream& operator<<(std::ostream& os, const Stmt& other) { - os << "Statement " << other.op_type << " " << other.place(); - return os; + std::vector>& kernels() { + return valid_kernels_; } + + void SetOp(const std::shared_ptr& op) { op_ = op; } + const std::shared_ptr op() const { return op_; } + + Place place() const; + + KernelBase& picked_kernel(); + + friend std::ostream& operator<<(std::ostream& os, const Stmt& other); + + // Description. 
+ std::string desc; }; struct Arg { @@ -78,26 +86,16 @@ class Node { bool is_weight{false}; }; - Arg& AsArg(const std::string& name, int id) { - auto& x = AsArg(); - x.name = name; - x.id = id; - return x; - } + Arg& AsArg(const std::string& name, int id); - Arg& AsArg(const std::string& name) { - auto& x = AsArg(); - x.name = name; - return x; - } + Arg& AsArg(const std::string& name); Stmt& AsStmt(const std::string& op_type, std::vector>&& kernels, const std::shared_ptr& op) { auto& x = AsStmt(); - x.op_type = op_type; - x.op = op; - x.valid_kernels = std::move(kernels); + x.SetOp(op); + x.SetKernels(std::move(kernels)); return x; } @@ -142,7 +140,7 @@ class Node { } if (other.IsStmt()) { auto& arg = other.AsStmt(); - os << "Statement " << arg.op_type; + os << "Statement " << arg.op_type(); } return os; } diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h index ff9fbce35dd..76ed5f1dd0f 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher.h +++ b/paddle/fluid/lite/core/mir/pattern_matcher.h @@ -139,14 +139,13 @@ struct PMNode { template PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { - asserts_.emplace_back([=](Node* x) { + asserts_.push_back([=](const Node* x) { if (x && x->IsStmt()) { auto* op_info = x->stmt()->op_info(); return op_info->HasAttr(attr_name) && op_info->GetAttr(attr_name) == attr; - } else { - return false; } + return false; }); return this; } diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc index 57bba3aad14..9f0b2e1f322 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc @@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) { } } + LOG(INFO) << "keys: " << key2nodes_.size(); std::unordered_set nodes2rm; for (auto &matched : key2nodes_) { for (const auto &key : keys) { diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h index b3a23c654bd..7c3f890383d 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h @@ -49,7 +49,13 @@ class FuseBase { virtual void BuildPattern() = 0; // Generate an operator desc with a matched subgraph. 
- virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0; + virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) { + return cpp::OpDesc(); + } + + PMNode* OpNode(const std::string& key) { + return GetOrCreateNode(key)->assert_is_op(); + } PMNode* OpNode(const std::string& key, const std::string& op_type); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc index 7a46bb9a93d..d0844b0b7ef 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc @@ -52,7 +52,7 @@ class FcFuser : public FuseBase { void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { auto op_desc = GenOpDesc(matched); auto fc_op = LiteOpRegistry::Global().Create("fc"); - auto mul = matched.at("mul")->stmt()->op; + auto mul = matched.at("mul")->stmt()->op(); auto* scope = mul->scope(); auto& valid_places = mul->valid_places(); fc_op->Attach(op_desc, scope); @@ -90,7 +90,7 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, main_block->Var("w"); main_block->Var("out"); - scope->Var("w")->GetMutable(); + scope->Var("x")->GetMutable(); scope->Var("b")->GetMutable(); scope->Var("mul_out")->GetMutable(); scope->Var("w")->GetMutable(); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_test.cc index 3b082060fe2..8f2ca38f1cc 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_test.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher_test.cc @@ -23,19 +23,19 @@ namespace mir { void BuildGraph(SSAGraph* g) { g->mutable_nodes().emplace_back(); Node& o1 = g->mutable_nodes().back(); - o1.AsStmt().op_type = "op1"; + o1.AsStmt().desc = "op1"; g->mutable_nodes().emplace_back(); Node& o2 = g->mutable_nodes().back(); - o2.AsStmt().op_type = "op2"; + o2.AsStmt().desc = "op2"; g->mutable_nodes().emplace_back(); Node& o3 = g->mutable_nodes().back(); - o3.AsStmt().op_type = "op3"; + o3.AsStmt().desc = "op3"; g->mutable_nodes().emplace_back(); Node& o4 = g->mutable_nodes().back(); - o4.AsStmt().op_type = "op4"; + o4.AsStmt().desc = "op4"; g->mutable_nodes().emplace_back(); Node& o5 = g->mutable_nodes().back(); - o5.AsStmt().op_type = "op5"; + o5.AsStmt().desc = "op5"; g->mutable_nodes().emplace_back(); Node& v1 = g->mutable_nodes().back(); v1.AsArg("var1"); @@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) { // v2 -> o3(a node named o3) auto* o2 = x.pattern_.NewNode([](const Node* node) { // The teller can be any condition, such as op type, or variable's shape. - return node && node->IsStmt() && node->stmt()->op_type == "op2"; + return node && node->IsStmt() && node->stmt()->desc == "op2"; }); auto* o3 = x.pattern_.NewNode([](const Node* node) { // The teller can be any condition, such as op type, or variable's shape. - return node && node->IsStmt() && node->stmt()->op_type == "op3"; + return node && node->IsStmt() && node->stmt()->desc == "op3"; }); auto* v2 = x.pattern_.NewNode([](const Node* node) { // The teller can be any condition, such as op type, or variable's shape. 
@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) { // op -> var auto* any_op = x.mutable_pattern()->NewNode( [](const Node* node) { - return node->IsStmt() && (node->stmt()->op_type == "op2" || - node->stmt()->op_type == "op3"); + return node->IsStmt() && + (node->stmt()->desc == "op2" || node->stmt()->desc == "op3"); }, "OP0"); auto* any_var = @@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) { int count = 0; PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s, SSAGraph* g) { - LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> " + LOG(INFO) << "Detect " << s.at(any_op)->stmt()->desc << " -> " << s.at(any_var)->arg()->name << " -> " - << s.at(any_op1)->stmt()->op_type; + << s.at(any_op1)->stmt()->desc; count++; }; @@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) { PatternMatcher matcher; auto* op2 = matcher.mutable_pattern()->NewNode( [](const Node* x) { - return x && x->IsStmt() && x->stmt()->op_type == "op2"; + return x && x->IsStmt() && x->stmt()->desc == "op2"; }, "op2"); auto* op3 = matcher.mutable_pattern()->NewNode( [](const Node* x) { - return x && x->IsStmt() && x->stmt()->op_type == "op3"; + return x && x->IsStmt() && x->stmt()->desc == "op3"; }, "op3"); auto* v2 = matcher.mutable_pattern() diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 7c0e6cef498..0a6f4022dd9 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -65,6 +65,10 @@ class SSAGraph : GraphBase { Node *GraphCreateInstructNode(const std::shared_ptr &op, const std::vector &valid_places); + // Device related attributes + const std::vector &valid_places() const { return valid_places_; } + void SetValidPlaces(const std::vector &x) { valid_places_ = x; } + private: mir::Node *Argument(const std::string &name); // Check the bidirectional connection. @@ -89,6 +93,7 @@ class SSAGraph : GraphBase { private: std::list node_storage_; std::map arguments_; + std::vector valid_places_; }; // Remove the link between a -> b. diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc index 9d48c123a0c..93ee96bbf0a 100644 --- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc +++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc @@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (!node.IsStmt()) continue; auto& instruct = node.AsStmt(); std::vector>> scored; - CHECK(!instruct.valid_kernels.empty()) << "No kernels found for " - << instruct.op_type; - for (auto&& kernel : instruct.valid_kernels) { + CHECK(!instruct.kernels().empty()) << "No kernels found for " + << instruct.op_type(); + for (auto&& kernel : instruct.kernels()) { size_t score = KernelGrade(*kernel); scored.emplace_back(score, std::move(kernel)); } @@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { // Move kernel back // Just keep a single best kernel. // TODO(Superjomn) reconsider this. 
- instruct.valid_kernels.clear(); - instruct.valid_kernels.emplace_back(std::move(scored.front().second)); - VLOG(2) << "pick " << instruct.valid_kernels.front()->name(); + instruct.kernels().clear(); + instruct.kernels().emplace_back(std::move(scored.front().second)); + VLOG(2) << "pick " << instruct.kernels().front()->name(); } } diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc index 12dd2dcff06..951e3423e56 100644 --- a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc @@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node, CHECK(in->AsArg().type); if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) { LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name - << " for kernel " << inst.op->DebugString() << " " + << " for kernel " << inst.op()->DebugString() << " " << *in->AsArg().type << " -> " << *decl_arg_type; // Add an IoCopy instruction to make the input compatible with other dist. AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node, @@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst( CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; // CHECK(io_copy_op); // Create the new var manually. - inst_node->AsStmt().op->scope()->Var(io_copy_output_name); + inst_node->AsStmt().op()->scope()->Var(io_copy_output_name); // Create IoCopy Instruction. cpp::OpDesc op_desc; @@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst( op_desc.SetInput("Input", {in->AsArg().name}); op_desc.SetOutput("Out", {io_copy_output_name}); - io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope()); + io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); auto kernels = io_copy_op->CreateKernels(valid_places); io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op); @@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst( DirectedLink(io_copy_output_arg, inst_node); // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name, + UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), in->AsArg().name, io_copy_output_name); - inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(), - inst_node->AsStmt().op->scope()); + inst_node->AsStmt().ResetOp(*inst_node->AsStmt().op_info(), + graph->valid_places()); std::string tmp; if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { CHECK(false) << "get old a " << tmp; } - for (auto& kernel : inst_node->AsStmt().valid_kernels) { - inst_node->AsStmt().op->AttachKernel(kernel.get()); + for (auto& kernel : inst_node->AsStmt().kernels()) { + inst_node->AsStmt().op()->AttachKernel(kernel.get()); } graph->CheckValid(); diff --git a/paddle/fluid/lite/core/mir/use_passes.h b/paddle/fluid/lite/core/mir/use_passes.h index f2c81ee57e3..5203ad3f141 100644 --- a/paddle/fluid/lite/core/mir/use_passes.h +++ b/paddle/fluid/lite/core/mir/use_passes.h @@ -24,9 +24,11 @@ USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(argument_type_display_pass); #endif + USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(graph_visualze); USE_MIR_PASS(lite_fc_fuse_pass); +USE_MIR_PASS(identity_scale_eliminate_pass); USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass); diff --git 
a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 2128c6d2014..0a5b3c341ab 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass { for (const auto& v : graph->inputs()) { // the feed op might in the inputs if (v->IsStmt()) { - LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type; + LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type(); continue; } } @@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass { for (auto& x : graph->StmtTopologicalOrder()) { auto& inst = x->AsStmt(); // The IoCopyOp is a tool operator, it won't support the type inference. - if (inst.op_type == "io_copy") continue; + if (inst.op_type() == "io_copy") continue; // LOG(INFO) << "- inferencing type " << // deal with inputs - VLOG(4) << "inferencing op " << inst.op_type; + VLOG(4) << "Infering op " << inst.op_info()->Repr(); // TODO(zhaolong): Add check if the node's name in op's arguments. auto get_argname = [&]( @@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass { } } + VLOG(3) << "inst " << inst.op_info()->Repr(); for (auto* x_out : x->outlinks) { std::string node_name = x_out->AsArg().name; std::string arg_name = get_argname(node_name, inst.op_info()->outputs()); CHECK(arg_name.size() > 0) << "can not found op arguments for node " - << node_name; + << node_name << " in Inst " + << inst.op_type(); VLOG(3) << "-- output arg_name " << arg_name; auto type = inst.picked_kernel().GetOutputDeclType(arg_name); if (!x_out->AsArg().type) { diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index 484d22abf52..31c339a5e63 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -61,7 +61,6 @@ std::vector> OpLite::CreateKernels( targets.insert(place.target); } - // CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels"; return kernels; } @@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) { scope_ = scope; op_info_.reset( new OpInfo(opdesc)); // Force clean the out-of-date infomation. 
- return AttachImpl(opdesc, scope); + return AttachImpl(*op_info(), scope); } const Tensor *OpLite::GetTensor(lite::Scope *scope, diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 38cce73d291..cd7d9ef8449 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc { } return false; } + + void UpdateAllInputs(const std::string &from, const std::string &to) { + for (auto &item : inputs_) { + for (auto &var : item.second) { + if (var == from) var = to; + } + } + } + + void UpdateAllOutputs(const std::string &from, const std::string &to) { + for (auto &item : outputs_) { + for (auto &var : item.second) { + if (var == from) var = to; + } + } + } }; } // namespace lite diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index c2c1121f53e..b936a139cbc 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -43,6 +43,8 @@ class Optimizer { CHECK(!graph_) << "duplicate optimize found"; graph_.reset(new mir::SSAGraph); graph_->Build(program, valid_places); + graph_->SetValidPlaces(valid_places); + SpecifyKernelPickTactic(kernel_pick_factor); InitTargetTypeTransformPass(); @@ -51,6 +53,7 @@ class Optimizer { "lite_conv_bn_fuse_pass", // "lite_conv_elementwise_add_activation_fuse_pass", // "lite_fc_fuse_pass", // + "identity_scale_eliminate_pass", // #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "lite_elementwise_add_activation_fuse_pass", // #endif diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 2f3e078462a..46da1815f19 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -140,7 +140,7 @@ class RuntimeProgram { void Run() { for (auto& inst : instructions_) { - VLOG(4) << ">> Running kernel: " << inst; + VLOG(3) << ">> Running kernel: " << inst.op()->op_info()->Repr(); inst.Run(); } } diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index 1d61f72063b..2c001c84e4c 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -191,7 +191,6 @@ class TensorBase { template bool TensorCompareWith(const TensorT &a, const TensorT &b) { if (a.dims() != b.dims()) return false; - LOG(INFO) << "data_size: " << a.data_size(); if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; return true; } diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc b/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc index abb28e2bb58..8bade95f58c 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc +++ b/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/lite/kernels/x86/elementwise_compute.h" #include #include +#include +#include #include #include "paddle/fluid/lite/core/op_registry.h" diff --git a/paddle/fluid/lite/kernels/x86/mul_compute.h b/paddle/fluid/lite/kernels/x86/mul_compute.h index 96f90842f69..0f95fea934a 100644 --- a/paddle/fluid/lite/kernels/x86/mul_compute.h +++ b/paddle/fluid/lite/kernels/x86/mul_compute.h @@ -40,12 +40,20 @@ class MulCompute : public KernelLite { auto* x = ¶m.x->raw_tensor(); auto* y = ¶m.y->raw_tensor(); - const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix( - *x, param.x_num_col_dims) - : *x; - const Tensor y_matrix = y->dims().size() > 2 ? 
framework::ReshapeToMatrix( - *y, param.y_num_col_dims) - : *y; + Tensor x_matrix, y_matrix; + + if (x->dims().size() > 2) { + x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims); + } else { + x_matrix = *x; + } + + if (y->dims().size() > 2) { + y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims); + + } else { + y_matrix = *y; + } auto* z = ¶m.output->raw_tensor(); auto z_dim = z->dims(); @@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite { auto* x = ¶m.x->raw_tensor(); auto* y = ¶m.y->raw_tensor(); - auto x_matrix = x->dims().size() > 2 - ? framework::ReshapeToMatrix(*x, param.x_num_col_dims) - : static_cast(*x); - auto y_matrix = y->dims().size() > 2 - ? framework::ReshapeToMatrix(*y, param.y_num_col_dims) - : static_cast(*y); + + Tensor x_matrix, y_matrix; + + if (x->dims().size() > 2) { + x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims); + } else { + x_matrix = *x; + } + + if (y->dims().size() > 2) { + y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims); + + } else { + y_matrix = *y; + } + auto* dout = ¶m.output_grad->raw_tensor(); Tensor dout_mat; diff --git a/paddle/fluid/lite/kernels/x86/mul_compute_test.cc b/paddle/fluid/lite/kernels/x86/mul_compute_test.cc index 50854d29d09..c551754328e 100644 --- a/paddle/fluid/lite/kernels/x86/mul_compute_test.cc +++ b/paddle/fluid/lite/kernels/x86/mul_compute_test.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/lite/kernels/x86/mul_compute.h" #include #include +#include +#include #include #include "paddle/fluid/lite/core/op_registry.h" diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt index c539e409a65..2690fa0206b 100644 --- a/paddle/fluid/lite/model_parser/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/CMakeLists.txt @@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE) endif() -cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite) +cc_library(compatible_pb_lite SRCS compatible_pb.cc + DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite) lite_cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite ${tensor_lite} scope_lite diff --git a/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt index e6e2fc77f00..71073179991 100644 --- a/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt @@ -1,2 +1 @@ cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite) - diff --git a/paddle/fluid/lite/model_parser/desc_apis.h b/paddle/fluid/lite/model_parser/desc_apis.h index d28f82a0e73..5981b873f7c 100644 --- a/paddle/fluid/lite/model_parser/desc_apis.h +++ b/paddle/fluid/lite/model_parser/desc_apis.h @@ -14,6 +14,7 @@ #pragma once #include +#include #include #include @@ -79,6 +80,27 @@ class OpDescAPI { /// Get an attribute. 
template T GetAttr(const std::string& name) const; + + std::string Repr() const { + std::stringstream ss; + ss << Type(); + ss << "("; + for (auto& arg : InputArgumentNames()) { + ss << arg << ":"; + for (auto val : Input(arg)) { + ss << val << " "; + } + } + ss << ") -> ("; + for (auto& arg : OutputArgumentNames()) { + ss << arg << ":"; + for (auto val : Output(arg)) { + ss << val << " "; + } + } + ss << ")"; + return ss.str(); + } }; } // namespace lite diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.h b/paddle/fluid/lite/model_parser/pb/op_desc.h index e8772e162a5..b64ba5452d6 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.h +++ b/paddle/fluid/lite/model_parser/pb/op_desc.h @@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI { template T GetAttr(const std::string &name) const; + std::string DebugString() const { return desc_.DebugString(); } + private: std::vector GetArguments( const google::protobuf::RepeatedPtrField diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index 7aa1581bb2a..a01427b1f4c 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -38,15 +38,19 @@ class MulOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + CHECK(!op_desc.Input("X").empty()); + CHECK(!op_desc.Input("Y").empty()); + CHECK(!op_desc.Output("Out").empty()); + auto input = op_desc.Input("X").front(); auto W = op_desc.Input("Y").front(); auto out = op_desc.Output("Out").front(); auto *var = scope->FindVar(input); CHECK(var); - param_.x = var->GetMutable(); + param_.x = &var->Get(); var = scope->FindVar(W); CHECK(var) << "no var called " << W; - param_.y = var->GetMutable(); + param_.y = &var->Get(); var = scope->FindVar(out); CHECK(var) << "no var called " << out; param_.output = var->GetMutable(); diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 0cc1e6b78e9..b50e14a4855 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -67,8 +67,8 @@ struct ReluParam { // For Mul Op struct MulParam { - lite::Tensor* x{}; - lite::Tensor* y{}; + const lite::Tensor* x{}; + const lite::Tensor* y{}; lite::Tensor* output{}; int x_num_col_dims{1}; diff --git a/paddle/fluid/lite/utils/varient.h b/paddle/fluid/lite/utils/varient.h index 2d2a3061108..52bbcffcef9 100644 --- a/paddle/fluid/lite/utils/varient.h +++ b/paddle/fluid/lite/utils/varient.h @@ -20,6 +20,7 @@ #include #include #include "paddle/fluid/lite/utils/cp_logging.h" +#include "paddle/fluid/lite/utils/string.h" // This is an equivalent implementation of boost::any. We implement this to // avoid including the whole boost library and keep the inference library small. @@ -116,9 +117,9 @@ struct variant { if (type_id == typeid(T).hash_code()) return *reinterpret_cast(&data); else - throw std::invalid_argument("unmatched type"); - // LOG(FATAL) << "unmatched type get, should be " << type_id << " but get " - // << typeid(T).name(); + throw std::invalid_argument( + string_format("unmatched type, store as %d, but want to get %s", + type_id, typeid(T).name())); return *reinterpret_cast(&data); } -- GitLab
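A minimal sketch of how a graph pass can use the new OpInfo::UpdateAllInputs helper introduced above; this is illustrative only and is not taken from the patch. The actual elimination logic lives in identity_scale_eliminate_pass.cc, and the function name RedirectConsumer and its arguments below are hypothetical. Only UpdateAllInputs itself comes from this change.

// Minimal sketch, not the actual pass: once an identity scale node has been
// matched, point the consumer op's inputs from the scale output variable back
// to the scale input variable so the scale op can be dropped from the graph.
// "RedirectConsumer" and both name parameters are hypothetical.
#include <string>
#include "paddle/fluid/lite/core/op_lite.h"

namespace paddle {
namespace lite {

void RedirectConsumer(OpInfo* consumer_info,
                      const std::string& scale_out_name,
                      const std::string& scale_in_name) {
  // Replaces every occurrence of scale_out_name in the consumer's input slots
  // with scale_in_name (UpdateAllInputs is added in this patch).
  consumer_info->UpdateAllInputs(scale_out_name, scale_in_name);
}

}  // namespace lite
}  // namespace paddle

UpdateAllOutputs works the same way on the producer side when the variable being removed feeds an op's output slots.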
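For reference, the string built by the new OpDescAPI::Repr() (and printed by the updated VLOG statements in variable_place_inference_pass.h and program.h) looks like the sketch below. The op and variable names are made up, and the cpp::OpDesc setters are assumed to follow the existing desc_apis.h interface; this is not code from the patch.

// Illustrative only: what Repr() prints for a hypothetical mul op.
#include <iostream>
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"

int main() {
  paddle::lite::cpp::OpDesc desc;
  desc.SetType("mul");
  desc.SetInput("X", {"x"});
  desc.SetInput("Y", {"w"});
  desc.SetOutput("Out", {"out"});
  // Given the Repr() implementation above, this prints something like:
  //   mul(X:x Y:w ) -> (Out:out )
  std::cout << desc.Repr() << std::endl;
  return 0;
}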