diff --git a/paddle/fluid/lite/api/apis_test.cc b/paddle/fluid/lite/api/apis_test.cc index 4d99f238dd6b6af6597b2a5f0b41ac7d4580da79..7dd6a1193754437a32957f081b3be3fd5c1fc403 100644 --- a/paddle/fluid/lite/api/apis_test.cc +++ b/paddle/fluid/lite/api/apis_test.cc @@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) { ASSERT_TRUE(TensorCompareWith(*cxx_out, *light_out)); std::vector tensors_with_order({ - "a", "fc_0.w_0", "fc_0.tmp_0", "scale_0.tmp_0", + "a", "fc_0.w_0", "scale_0.tmp_0", }); for (const auto& tensor_name : tensors_with_order) { diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc index b064c663d11fb9c56c2d6c38501adceb0bf76814..af6f8e44d69b2d228fc32b9ae3d926fe6ecca69f 100644 --- a/paddle/fluid/lite/api/cxx_api_bin.cc +++ b/paddle/fluid/lite/api/cxx_api_bin.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/api/cxx_api.h" -#include +#include // NOLINT #include "paddle/fluid/lite/core/mir/use_passes.h" #include "paddle/fluid/lite/core/op_registry.h" diff --git a/paddle/fluid/lite/api/lite_api_test_helper.cc b/paddle/fluid/lite/api/lite_api_test_helper.cc index 490a64bb512bdf31359b6204399b1e1767bb4f17..b82541723308f4748e28c64affa6899bf2d9b727 100644 --- a/paddle/fluid/lite/api/lite_api_test_helper.cc +++ b/paddle/fluid/lite/api/lite_api_test_helper.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/api/lite_api_test_helper.h" +#include DEFINE_string(model_dir, "", ""); DEFINE_string(optimized_model, "", ""); diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 665d7555e3757188f8a7b76496fa85cb20192670..1e95668cddc722e32ea784fe2331380ea3a3940e 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite) lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator) - diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 0d7fcf8b3b2843c0d36be24288743a86b8c7ea24..e67ade8cbef5c574ce911bee403a152a23aa045e 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -5,21 +5,25 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) add_subdirectory(fusion) +add_subdirectory(elimination) + cc_library(mir_passes - SRCS fc_fuse_pass.cc - conv_elementwise_add_activation_fuse_pass.cc - elementwise_add_activation_fuse_pass.cc - conv_bn_fuse_pass.cc - static_kernel_pick_pass.cc - variable_place_inference_pass.cc - type_target_transform_pass.cc - io_copy_kernel_pick_pass.cc - graph_visualize_pass.cc - generate_program_pass.cc - argument_type_display_pass.cc - demo_pass.cc - runtime_context_assign_pass.cc - DEPS mir_pass types_lite context_lite ${mir_fusers}) + SRCS + fusion/fc_fuse_pass.cc + fusion/conv_elementwise_add_activation_fuse_pass.cc + fusion/conv_bn_fuse_pass.cc + fusion/elementwise_add_activation_fuse_pass.cc + elimination/identity_scale_eliminate_pass.cc + static_kernel_pick_pass.cc + variable_place_inference_pass.cc + type_target_transform_pass.cc + io_copy_kernel_pick_pass.cc + graph_visualize_pass.cc + generate_program_pass.cc + argument_type_display_pass.cc + demo_pass.cc + runtime_context_assign_pass.cc + DEPS mir_pass types_lite context_lite ${mir_fusers}) #cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS #mir_ssa_graph scope_lite op_lite @@ -73,7 +77,7 @@ message(STATUS "----> Ops lite: ${ops_lite}") message(STATUS "----> Host kernels: ${host_kernels}") message(STATUS "----> X86 kernels: ${x86_kernels}") -lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc +lite_cc_test(test_lite_fc_fuse SRCS fusion/fc_fuse_pass_test.cc DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model @@ -83,11 +87,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz) -lite_cc_test(test_lite_conv_elementwise_add_activation_fuse - SRCS conv_elementwise_add_activation_fuse_pass_test.cc +lite_cc_test(test_lite_conv_elementwise_add_activation_fuse + SRCS fusion/conv_elementwise_add_activation_fuse_pass_test.cc DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${x86_kernels}) -lite_cc_test(test_lite_elementwise_add_activation_fuse - SRCS elementwise_add_activation_fuse_pass_test.cc +lite_cc_test(test_lite_elementwise_add_activation_fuse + SRCS fusion/elementwise_add_activation_fuse_pass_test.cc DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${x86_kernels}) diff --git a/paddle/fluid/lite/core/mir/elimination/CMakeLists.txt b/paddle/fluid/lite/core/mir/elimination/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9fda8ec29a4da3a3a9b443448f10e27b93ce61e8 --- /dev/null +++ b/paddle/fluid/lite/core/mir/elimination/CMakeLists.txt @@ -0,0 +1,7 @@ +if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + lite_cc_test(test_identity_scale_eliminate_pass_lite + SRCS identity_scale_eliminate_pass_test.cc + DEPS mir_passes program_lite proto_desc cpp_op_desc_lite + ${ops_lite} + ) +endif() diff --git a/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..6f8aeb65c0592a184ee436f8cefeb9d241a6943f --- /dev/null +++ b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pass.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { + +namespace { + +class Eliminator : public FuseBase { + public: + void BuildPattern() override { + auto* pre_op = OpNode("preop"); // the previous op's output need update + // TODO(Superjomn) check has only one output + auto* x = VarNode("x")->assert_is_op_input("scale", "X"); + auto* scale_op = OpNode("scale", "scale") + ->assert_op_attr("scale", 1.) + ->assert_op_attr("bias", 0.); + auto* out = VarNode("out")->assert_is_op_output("scale", "Out"); + + *pre_op >> *x >> *scale_op >> *out; + + // The pre_op will be eliminated, and a new output-updated op will insert. + x->AsIntermediate(); // x is pre_op's output, need to update + } + + private: + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto& pre_op = matched.at("preop")->AsStmt(); + auto op_info = *pre_op.op_info(); + + op_info.UpdateAllOutputs(matched.at("x")->AsArg().name, + matched.at("out")->AsArg().name); + pre_op.ResetOp(op_info, graph->valid_places()); + + GraphSafeRemoveNodes(graph, {matched.at("scale")}); + + IR_NODE_LINK_TO(matched.at("preop"), matched.at("out")); + } +}; + +} // namespace + +class IdentityScaleEliminatePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + Eliminator eliminator; + eliminator(graph.get()); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(identity_scale_eliminate_pass, + paddle::lite::mir::IdentityScaleEliminatePass); diff --git a/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..89db35fe0e8b943b7691f51ed4febacee83ebd41 --- /dev/null +++ b/paddle/fluid/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" +#include "paddle/fluid/lite/core/mir/ssa_graph.h" + +namespace paddle { +namespace lite { +namespace mir { + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + // Op list: + // (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch) + // After pass + // (x)->feed->(scale_out)->fetch->(fetch) + auto* main_block = program_desc->MutableBlock(0); + auto* feed_op = main_block->AppendOp(); + auto* scale_op = main_block->AppendOp(); + auto* fetch_op = main_block->AppendOp(); + main_block->Var("x"); + main_block->Var("feed"); + main_block->Var("scale_out"); + main_block->Var("fetch_out"); + + scope->Var("x")->GetMutable(); + scope->Var("feed")->GetMutable(); + scope->Var("scale_out")->GetMutable(); + scope->Var("fetch_out")->GetMutable(); + + feed_op->SetType("feed"); + feed_op->SetInput("X", {"x"}); + feed_op->SetAttr("col", 1); + feed_op->SetOutput("Out", {"feed"}); + + scale_op->SetType("scale"); + scale_op->SetInput("X", {"feed"}); + scale_op->SetOutput("Out", {"scale_out"}); + scale_op->SetAttr("scale", 1.f); + scale_op->SetAttr("bias", 0.f); + scale_op->SetAttr("bias_after_scale", true); + + fetch_op->SetType("fetch"); + fetch_op->SetInput("X", {"scale_out"}); + fetch_op->SetOutput("Out", {"fetch"}); + fetch_op->SetAttr("col", 1); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + LOG(INFO) << Visualize(graph.get()); + + return graph; +} + +TEST(identity_test, test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + const int num_nodes = graph->nodes().size(); + auto pass = PassManager::Global().LookUp("identity_scale_eliminate_pass"); + ASSERT_TRUE(pass); + pass->Apply(graph); + ASSERT_EQ(graph->nodes().size(), num_nodes - 2UL); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(feed) +USE_LITE_OP(fetch) +USE_LITE_OP(scale) +USE_MIR_PASS(identity_scale_eliminate_pass) diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt index 9139293c8aa59d5664e29afba97c02226f9338bf..db092e17679fb2f7ed33cda7d4e92b99b5039776 100644 --- a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt @@ -11,7 +11,7 @@ cc_library(fuse_elementwise_add_activation SRCS elementwise_add_activation_fuser.cc DEPS pattern_matcher_high_api) -set(mir_fusers +set(mir_fusers fuse_fc fuse_conv_elementwise_add_activation fuse_conv_bn diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc similarity index 94% rename from paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc index 562ec7f45073a13f37c7f44ebcae0fb13fbb8b42..1e7d7bc5774c7902f2aea80678b05886f9482415 100644 --- a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h" diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc index 79436a9fa3d71111a5e805a804a77b9bda137134..3a8573b4f8c13684c2077164805966b3887a7f4f 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc index d29f078513e2113db12c67be4d694a6dc8de99f9..0a73d1e39d99ba0d3b4e4790bb689ed403e63f5e 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc @@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() { void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { auto op_desc = GenOpDesc(matched); auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add"); - auto conv = matched.at("conv2d")->stmt()->op; + auto conv = matched.at("conv2d")->stmt()->op(); auto* scope = conv->scope(); auto& valid_places = conv->valid_places(); diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc similarity index 94% rename from paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc index 27f6413c47b514d3203c5879d7ee7b9697d8cf5a..f4eb5a00ad24900fc97abbc6ce4e890d288e0872 100644 --- a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h" diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc similarity index 98% rename from paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc rename to paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc index a67e577505f3ee1e099a5a3be3801116210c197d..e7751d801eaa1239a70e3fb0d128165029a31669 100644 --- a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc index b26b758fb2318b7c9a645503687f994b73009310..a085b139c86725360b4939c979cec685bf11879b 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc @@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode( SSAGraph* graph, const key2nodes_t& matched) { auto op_desc = GenOpDesc(matched); auto conv_op = LiteOpRegistry::Global().Create(conv_type_); - auto conv_old = matched.at("conv2d")->stmt()->op; + auto conv_old = matched.at("conv2d")->stmt()->op(); auto* scope = conv_old->scope(); auto& valid_places = conv_old->valid_places(); conv_op->Attach(op_desc, scope); diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ace19f304bf1f935c82d138e3980e85e417d6f8 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ConvElementwiseAddReLUFusePass::Apply( + const std::unique_ptr& graph) { + fusion::ConvElementwiseAddReLUFuser fuser("conv2d"); + fuser(graph.get()); + + fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d"); + depthwise_fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass, + paddle::lite::mir::ConvElementwiseAddReLUFusePass); diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..4276f1ffc8c258b0b4266abd950fa1ccf541c4a7 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ConvElementwiseAddReLUFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..00c9eaf8c07ce4f853ee51c39c752c51bf0c6ccd --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/mir/passes.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/program.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + + auto* conv2d_1 = main_block->AppendOp(); + auto* conv2d_2 = main_block->AppendOp(); + auto* add_1 = main_block->AppendOp(); + auto* relu_1 = main_block->AppendOp(); + auto* add_2 = main_block->AppendOp(); + auto* relu_2 = main_block->AppendOp(); + + main_block->Var("input_1"); + main_block->Var("input_2"); + main_block->Var("filter_1"); + main_block->Var("filter_2"); + main_block->Var("conv2d_1_out"); + main_block->Var("conv2d_2_out"); + main_block->Var("bias_1"); + main_block->Var("add_1_out"); + main_block->Var("add_2_out"); + main_block->Var("relu_1_out"); + main_block->Var("out"); + + scope->Var("input_1")->GetMutable(); + scope->Var("input_2")->GetMutable(); + scope->Var("filter_1")->GetMutable(); + scope->Var("filter_2")->GetMutable(); + scope->Var("conv2d_1_out")->GetMutable(); + scope->Var("conv2d_2_out")->GetMutable(); + scope->Var("bias_1")->GetMutable(); + scope->Var("add_1_out")->GetMutable(); + scope->Var("add_2_out")->GetMutable(); + scope->Var("relu_1_out")->GetMutable(); + scope->Var("out")->GetMutable(); + + conv2d_1->SetType("conv2d"); + conv2d_1->SetInput("Input", {"input_1"}); + conv2d_1->SetInput("Filter", {"filter_1"}); + conv2d_1->SetOutput("Output", {"conv2d_1_out"}); + conv2d_1->SetAttr("strides", std::vector({1, 1})); + conv2d_1->SetAttr("paddings", std::vector({0, 0})); + conv2d_1->SetAttr("groups", 1); + conv2d_1->SetAttr("dilations", std::vector({1, 1})); + conv2d_1->SetAttr("fuse_relu", false); + + add_1->SetType("elementwise_add"); + add_1->SetInput("X", {"conv2d_1_out"}); + add_1->SetInput("Y", {"bias_1"}); + add_1->SetOutput("Out", {"add_1_out"}); + add_1->SetAttr("axis", 1); + + relu_1->SetType("relu"); + relu_1->SetInput("X", {"add_1_out"}); + relu_1->SetOutput("Out", {"relu_1_out"}); + + conv2d_2->SetType("conv2d"); + conv2d_2->SetInput("Input", {"input_2"}); + conv2d_2->SetInput("Filter", {"filter_2"}); + conv2d_2->SetOutput("Output", {"conv2d_2_out"}); + conv2d_2->SetAttr("strides", std::vector({1, 1})); + conv2d_2->SetAttr("paddings", std::vector({0, 0})); + conv2d_2->SetAttr("groups", 1); + conv2d_2->SetAttr("dilations", std::vector({1, 1})); + conv2d_2->SetAttr("fuse_relu", false); + + add_2->SetType("elementwise_add"); + add_2->SetInput("X", {"conv2d_2_out"}); + add_2->SetInput("Y", {"relu_1_out"}); + add_2->SetOutput("Out", {"add_2_out"}); + add_2->SetAttr("axis", 1); + + relu_2->SetType("relu"); + relu_2->SetInput("X", {"add_2_out"}); + relu_2->SetOutput("Out", {"out"}); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(conv_elementwise_add_relu_fuse_pass, graph_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/); + Visualize(graph.get()); +} + +TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + Visualize(graph.get()); + const int num_nodes = graph->nodes().size(); + auto* fuser = new ConvElementwiseAddReLUFusePass; + fuser->Apply(graph); + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ + + 1UL * 2 /* fused fc node*/); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(elementwise_add); +USE_LITE_OP(conv2d); +USE_LITE_OP(depthwise_conv2d); +USE_LITE_OP(relu); diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc similarity index 93% rename from paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc index 9ce455dcdafb0d2e8f040bc3244495b2968eebd0..20d1eaa82a8cdaafb21252a60b7977e0b4bae1cd 100644 --- a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h" diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc similarity index 97% rename from paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc rename to paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc index 7f64eead9ea82457f504be9955f42ededa3650f4..4b7742e059d8206d880859968c7f11a17ca213b7 100644 --- a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc index 83b916eea3e47947083d4a41406d2ebd6918dfd2..cafbc42d85b1f8159b0d5b010847348d1150777d 100644 --- a/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc @@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph, auto op_desc = GenOpDesc(matched); auto op = LiteOpRegistry::Global().Create("fusion_elementwise_add_activation"); - auto old_op = matched.at("add")->stmt()->op; + auto old_op = matched.at("add")->stmt()->op(); auto* scope = old_op->scope(); auto& valid_places = old_op->valid_places(); op->Attach(op_desc, scope); diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc similarity index 94% rename from paddle/fluid/lite/core/mir/fc_fuse_pass.cc rename to paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc index 008f05ce5cbd5f6f14d67e79f732e51ab2aa3ddd..f50db9c17b3dd81ad37558996f58164c057abd97 100644 --- a/paddle/fluid/lite/core/mir/fc_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include "paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h" #include #include #include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h" diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.h b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h similarity index 100% rename from paddle/fluid/lite/core/mir/fc_fuse_pass.h rename to paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc similarity index 98% rename from paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc rename to paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc index e2f7dd1a87d2ef576d175857ae880c5828b61a79..b64a436f925d291929079703c8687930b97a8a13 100644 --- a/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include "fc_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc index a8b6336595c0fe63d64d75d6434fcfd559c185c9..bb350c731c657c54d071dd490bbba953fbbe83fd 100644 --- a/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc @@ -46,7 +46,7 @@ void FcFuser::BuildPattern() { void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { auto op_desc = GenOpDesc(matched); auto fc_op = LiteOpRegistry::Global().Create("fc"); - auto mul = matched.at("mul")->stmt()->op; + auto mul = matched.at("mul")->stmt()->op(); auto* scope = mul->scope(); auto& valid_places = mul->valid_places(); fc_op->Attach(op_desc, scope); diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc index 75ff159015d6a090b0b0b926328e30ac4ec087a9..97586d7484204a4eccee9385c05aafbc11460f62 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.cc +++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc @@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr& graph) { if (item->IsStmt()) { auto& stmt = item->AsStmt(); VLOG(4) << stmt; - insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front())); + insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); } } } diff --git a/paddle/fluid/lite/core/mir/graph_visualize_pass.cc b/paddle/fluid/lite/core/mir/graph_visualize_pass.cc index 6a13bafd67ca710691f1a20a62ea411c90064e85..90a99b5deb199cc69e6732eddc60b77964b92d03 100644 --- a/paddle/fluid/lite/core/mir/graph_visualize_pass.cc +++ b/paddle/fluid/lite/core/mir/graph_visualize_pass.cc @@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) { if (node.IsArg()) { key = node.AsArg().name; } else { - key = node.AsStmt().op_type + std::to_string(id++); + key = node.AsStmt().op_type() + std::to_string(id++); } if (node.IsStmt()) { diff --git a/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc index ebf9e5a57bfb9395cbd661c4e69ec2980eebbd17..9f38ce01ba12627b8eae3ff51d5de620c971fc46 100644 --- a/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc +++ b/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass { for (auto& node : graph->mutable_nodes()) { if (!node.IsStmt()) continue; auto& inst = node.AsStmt(); - if (inst.op_type != "io_copy") continue; + if (inst.op_type() != "io_copy") continue; LOG(INFO) << "....> picking a IO COPY kernel"; - auto& kernels = node.AsStmt().valid_kernels; + auto& kernels = node.AsStmt().kernels(); CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op"; const auto* inty = node.inlinks.front()->AsArg().type; const auto* outy = node.outlinks.front()->AsArg().type; diff --git a/paddle/fluid/lite/core/mir/node.cc b/paddle/fluid/lite/core/mir/node.cc index 711ff508f23c7d5218a7d788e90b3fe58f154018..814df2b61a268f6ef71989f831edd53c9bfdb41d 100644 --- a/paddle/fluid/lite/core/mir/node.cc +++ b/paddle/fluid/lite/core/mir/node.cc @@ -13,3 +13,62 @@ // limitations under the License. #include "paddle/fluid/lite/core/mir/node.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +const OpInfo *mir::Node::Stmt::op_info() const { + CHECK(op_); + return op_->op_info(); +} + +Place mir::Node::Stmt::place() const { + CHECK(!valid_kernels_.empty()); + return valid_kernels_.front()->place(); +} + +KernelBase &mir::Node::Stmt::picked_kernel() { + CHECK(!valid_kernels_.empty()) << "no kernel for " << op_type(); + return *valid_kernels_.front(); +} + +OpInfo *mir::Node::Stmt::mutable_op_info() { + CHECK(op_); + return op_->mutable_op_info(); +} + +void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc, + const std::vector &valid_places, + lite::Scope *scope) { + CHECK((op_ && op_->scope()) || scope) << "Either scope should be set"; + lite::Scope *the_scope = scope ? scope : op_->scope(); + op_->Attach(op_desc, the_scope); + // Recreate the kernels with the latest OpInfo. + valid_kernels_.clear(); + + if (!op_ || op_->op_info()->Type() != op_desc.Type()) { + op_ = LiteOpRegistry::Global().Create(op_desc.Type()); + CHECK(op_) << "No op found for " << op_desc.Type(); + } + valid_kernels_ = op_->CreateKernels(valid_places); +} + +std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) { + os << "Statement " << other.op_type() << " " << other.place(); + return os; +} + +mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) { + auto &x = AsArg(); + x.name = name; + x.id = id; + return x; +} +mir::Node::Arg &mir::Node::AsArg(const std::string &name) { + auto &x = AsArg(); + x.name = name; + return x; +} +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index a5fd90dac482d434afb624216aad875e12350c36..08b7a963e797b31ab015dab4761fb5f41d855faa 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -41,32 +41,40 @@ class Node { kUnk, }; - struct Stmt { - std::string op_type; + class Stmt { // The kernel instances this Statement contains. - std::vector> valid_kernels; + std::vector> valid_kernels_; // TODO(Superjomn) make this a shared_ptr for resource safety. - std::shared_ptr op; // we hold op to run InferShape + std::shared_ptr op_; // we hold op to run InferShape - const OpInfo* op_info() { - CHECK(op); - return op->op_info(); - } + public: + // Refresh the operator and kernels with the latest OpInfo. + void ResetOp(const cpp::OpDesc& op_desc, + const std::vector& valid_places, + lite::Scope* scope = nullptr); - Place place() const { - CHECK(!valid_kernels.empty()); - return valid_kernels.front()->place(); - } + std::string op_type() const { return op_info()->Type(); } + const OpInfo* op_info() const; + OpInfo* mutable_op_info(); - KernelBase& picked_kernel() { - CHECK(!valid_kernels.empty()) << "no kernel for " << op_type; - return *valid_kernels.front(); + void SetKernels(std::vector>&& kernels) { + valid_kernels_ = std::move(kernels); } - - friend std::ostream& operator<<(std::ostream& os, const Stmt& other) { - os << "Statement " << other.op_type << " " << other.place(); - return os; + std::vector>& kernels() { + return valid_kernels_; } + + void SetOp(const std::shared_ptr& op) { op_ = op; } + const std::shared_ptr op() const { return op_; } + + Place place() const; + + KernelBase& picked_kernel(); + + friend std::ostream& operator<<(std::ostream& os, const Stmt& other); + + // Description. + std::string desc; }; struct Arg { @@ -78,26 +86,16 @@ class Node { bool is_weight{false}; }; - Arg& AsArg(const std::string& name, int id) { - auto& x = AsArg(); - x.name = name; - x.id = id; - return x; - } + Arg& AsArg(const std::string& name, int id); - Arg& AsArg(const std::string& name) { - auto& x = AsArg(); - x.name = name; - return x; - } + Arg& AsArg(const std::string& name); Stmt& AsStmt(const std::string& op_type, std::vector>&& kernels, const std::shared_ptr& op) { auto& x = AsStmt(); - x.op_type = op_type; - x.op = op; - x.valid_kernels = std::move(kernels); + x.SetOp(op); + x.SetKernels(std::move(kernels)); return x; } @@ -142,7 +140,7 @@ class Node { } if (other.IsStmt()) { auto& arg = other.AsStmt(); - os << "Statement " << arg.op_type; + os << "Statement " << arg.op_type(); } return os; } diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h index ff9fbce35ddf3f601a441bb6105dc658505cbe0e..76ed5f1dd0fdbe337bd2baf63ec9664773b00f8b 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher.h +++ b/paddle/fluid/lite/core/mir/pattern_matcher.h @@ -139,14 +139,13 @@ struct PMNode { template PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { - asserts_.emplace_back([=](Node* x) { + asserts_.push_back([=](const Node* x) { if (x && x->IsStmt()) { auto* op_info = x->stmt()->op_info(); return op_info->HasAttr(attr_name) && op_info->GetAttr(attr_name) == attr; - } else { - return false; } + return false; }); return this; } diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc index 57bba3aad140b4c8f8e1a2c6db27792773c018cd..9f0b2e1f3225d708f0e71c255bad2eec71628f76 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc @@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) { } } + LOG(INFO) << "keys: " << key2nodes_.size(); std::unordered_set nodes2rm; for (auto &matched : key2nodes_) { for (const auto &key : keys) { diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h index b3a23c654bdb36974fd1a0419c199ba04a1d66bf..7c3f890383d75ff364db5f9018827d2ddd5e9507 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h @@ -49,7 +49,13 @@ class FuseBase { virtual void BuildPattern() = 0; // Generate an operator desc with a matched subgraph. - virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0; + virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) { + return cpp::OpDesc(); + } + + PMNode* OpNode(const std::string& key) { + return GetOrCreateNode(key)->assert_is_op(); + } PMNode* OpNode(const std::string& key, const std::string& op_type); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc index 7a46bb9a93d95b9379c961d8044fbdfcd04e7ab4..d0844b0b7ef2fa805e042bdf9b66cd478a0de5d0 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc @@ -52,7 +52,7 @@ class FcFuser : public FuseBase { void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { auto op_desc = GenOpDesc(matched); auto fc_op = LiteOpRegistry::Global().Create("fc"); - auto mul = matched.at("mul")->stmt()->op; + auto mul = matched.at("mul")->stmt()->op(); auto* scope = mul->scope(); auto& valid_places = mul->valid_places(); fc_op->Attach(op_desc, scope); @@ -90,7 +90,7 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, main_block->Var("w"); main_block->Var("out"); - scope->Var("w")->GetMutable(); + scope->Var("x")->GetMutable(); scope->Var("b")->GetMutable(); scope->Var("mul_out")->GetMutable(); scope->Var("w")->GetMutable(); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_test.cc index 3b082060fe21731000394f6941e0803af7da74d6..8f2ca38f1cc13e80219aad33fd9c5e03cba52283 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_test.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher_test.cc @@ -23,19 +23,19 @@ namespace mir { void BuildGraph(SSAGraph* g) { g->mutable_nodes().emplace_back(); Node& o1 = g->mutable_nodes().back(); - o1.AsStmt().op_type = "op1"; + o1.AsStmt().desc = "op1"; g->mutable_nodes().emplace_back(); Node& o2 = g->mutable_nodes().back(); - o2.AsStmt().op_type = "op2"; + o2.AsStmt().desc = "op2"; g->mutable_nodes().emplace_back(); Node& o3 = g->mutable_nodes().back(); - o3.AsStmt().op_type = "op3"; + o3.AsStmt().desc = "op3"; g->mutable_nodes().emplace_back(); Node& o4 = g->mutable_nodes().back(); - o4.AsStmt().op_type = "op4"; + o4.AsStmt().desc = "op4"; g->mutable_nodes().emplace_back(); Node& o5 = g->mutable_nodes().back(); - o5.AsStmt().op_type = "op5"; + o5.AsStmt().desc = "op5"; g->mutable_nodes().emplace_back(); Node& v1 = g->mutable_nodes().back(); v1.AsArg("var1"); @@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) { // v2 -> o3(a node named o3) auto* o2 = x.pattern_.NewNode([](const Node* node) { // The teller can be any condition, such as op type, or variable's shape. - return node && node->IsStmt() && node->stmt()->op_type == "op2"; + return node && node->IsStmt() && node->stmt()->desc == "op2"; }); auto* o3 = x.pattern_.NewNode([](const Node* node) { // The teller can be any condition, such as op type, or variable's shape. - return node && node->IsStmt() && node->stmt()->op_type == "op3"; + return node && node->IsStmt() && node->stmt()->desc == "op3"; }); auto* v2 = x.pattern_.NewNode([](const Node* node) { // The teller can be any condition, such as op type, or variable's shape. @@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) { // op -> var auto* any_op = x.mutable_pattern()->NewNode( [](const Node* node) { - return node->IsStmt() && (node->stmt()->op_type == "op2" || - node->stmt()->op_type == "op3"); + return node->IsStmt() && + (node->stmt()->desc == "op2" || node->stmt()->desc == "op3"); }, "OP0"); auto* any_var = @@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) { int count = 0; PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s, SSAGraph* g) { - LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> " + LOG(INFO) << "Detect " << s.at(any_op)->stmt()->desc << " -> " << s.at(any_var)->arg()->name << " -> " - << s.at(any_op1)->stmt()->op_type; + << s.at(any_op1)->stmt()->desc; count++; }; @@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) { PatternMatcher matcher; auto* op2 = matcher.mutable_pattern()->NewNode( [](const Node* x) { - return x && x->IsStmt() && x->stmt()->op_type == "op2"; + return x && x->IsStmt() && x->stmt()->desc == "op2"; }, "op2"); auto* op3 = matcher.mutable_pattern()->NewNode( [](const Node* x) { - return x && x->IsStmt() && x->stmt()->op_type == "op3"; + return x && x->IsStmt() && x->stmt()->desc == "op3"; }, "op3"); auto* v2 = matcher.mutable_pattern() diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 7c0e6cef498c5c555c1cee6ab334e6be556a9897..0a6f4022dd90f45013ae52f795ecd1f1591c0f7a 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -65,6 +65,10 @@ class SSAGraph : GraphBase { Node *GraphCreateInstructNode(const std::shared_ptr &op, const std::vector &valid_places); + // Device related attributes + const std::vector &valid_places() const { return valid_places_; } + void SetValidPlaces(const std::vector &x) { valid_places_ = x; } + private: mir::Node *Argument(const std::string &name); // Check the bidirectional connection. @@ -89,6 +93,7 @@ class SSAGraph : GraphBase { private: std::list node_storage_; std::map arguments_; + std::vector valid_places_; }; // Remove the link between a -> b. diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc index 9d48c123a0c8e322f3ec6eb2b9788b9f115e9247..93ee96bbf0a0d1be8a4be7b4ce5f8c9e9b616498 100644 --- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc +++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc @@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (!node.IsStmt()) continue; auto& instruct = node.AsStmt(); std::vector>> scored; - CHECK(!instruct.valid_kernels.empty()) << "No kernels found for " - << instruct.op_type; - for (auto&& kernel : instruct.valid_kernels) { + CHECK(!instruct.kernels().empty()) << "No kernels found for " + << instruct.op_type(); + for (auto&& kernel : instruct.kernels()) { size_t score = KernelGrade(*kernel); scored.emplace_back(score, std::move(kernel)); } @@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { // Move kernel back // Just keep a single best kernel. // TODO(Superjomn) reconsider this. - instruct.valid_kernels.clear(); - instruct.valid_kernels.emplace_back(std::move(scored.front().second)); - VLOG(2) << "pick " << instruct.valid_kernels.front()->name(); + instruct.kernels().clear(); + instruct.kernels().emplace_back(std::move(scored.front().second)); + VLOG(2) << "pick " << instruct.kernels().front()->name(); } } diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc index 12dd2dcff0607bea46f41e7f5698ad2fb7e12404..951e3423e56f07d45fde38484769a5de5c67f2cc 100644 --- a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc @@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node, CHECK(in->AsArg().type); if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) { LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name - << " for kernel " << inst.op->DebugString() << " " + << " for kernel " << inst.op()->DebugString() << " " << *in->AsArg().type << " -> " << *decl_arg_type; // Add an IoCopy instruction to make the input compatible with other dist. AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node, @@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst( CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; // CHECK(io_copy_op); // Create the new var manually. - inst_node->AsStmt().op->scope()->Var(io_copy_output_name); + inst_node->AsStmt().op()->scope()->Var(io_copy_output_name); // Create IoCopy Instruction. cpp::OpDesc op_desc; @@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst( op_desc.SetInput("Input", {in->AsArg().name}); op_desc.SetOutput("Out", {io_copy_output_name}); - io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope()); + io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); auto kernels = io_copy_op->CreateKernels(valid_places); io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op); @@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst( DirectedLink(io_copy_output_arg, inst_node); // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name, + UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), in->AsArg().name, io_copy_output_name); - inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(), - inst_node->AsStmt().op->scope()); + inst_node->AsStmt().ResetOp(*inst_node->AsStmt().op_info(), + graph->valid_places()); std::string tmp; if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { CHECK(false) << "get old a " << tmp; } - for (auto& kernel : inst_node->AsStmt().valid_kernels) { - inst_node->AsStmt().op->AttachKernel(kernel.get()); + for (auto& kernel : inst_node->AsStmt().kernels()) { + inst_node->AsStmt().op()->AttachKernel(kernel.get()); } graph->CheckValid(); diff --git a/paddle/fluid/lite/core/mir/use_passes.h b/paddle/fluid/lite/core/mir/use_passes.h index f2c81ee57e364ee233b128fcc460111348dc5acb..5203ad3f141b4580aab8eaea4170d19831049e07 100644 --- a/paddle/fluid/lite/core/mir/use_passes.h +++ b/paddle/fluid/lite/core/mir/use_passes.h @@ -24,9 +24,11 @@ USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(argument_type_display_pass); #endif + USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(graph_visualze); USE_MIR_PASS(lite_fc_fuse_pass); +USE_MIR_PASS(identity_scale_eliminate_pass); USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass); diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 2128c6d2014bf8879743ebf7190b3a95a3bc4186..0a5b3c341ab7e6661903bb189bf4ee8452ccec32 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass { for (const auto& v : graph->inputs()) { // the feed op might in the inputs if (v->IsStmt()) { - LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type; + LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type(); continue; } } @@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass { for (auto& x : graph->StmtTopologicalOrder()) { auto& inst = x->AsStmt(); // The IoCopyOp is a tool operator, it won't support the type inference. - if (inst.op_type == "io_copy") continue; + if (inst.op_type() == "io_copy") continue; // LOG(INFO) << "- inferencing type " << // deal with inputs - VLOG(4) << "inferencing op " << inst.op_type; + VLOG(4) << "Infering op " << inst.op_info()->Repr(); // TODO(zhaolong): Add check if the node's name in op's arguments. auto get_argname = [&]( @@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass { } } + VLOG(3) << "inst " << inst.op_info()->Repr(); for (auto* x_out : x->outlinks) { std::string node_name = x_out->AsArg().name; std::string arg_name = get_argname(node_name, inst.op_info()->outputs()); CHECK(arg_name.size() > 0) << "can not found op arguments for node " - << node_name; + << node_name << " in Inst " + << inst.op_type(); VLOG(3) << "-- output arg_name " << arg_name; auto type = inst.picked_kernel().GetOutputDeclType(arg_name); if (!x_out->AsArg().type) { diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index 484d22abf52dda9832b524146114e2b2e093bb99..31c339a5e63a2c49134d43b8357bae519bf3a29f 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -61,7 +61,6 @@ std::vector> OpLite::CreateKernels( targets.insert(place.target); } - // CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels"; return kernels; } @@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) { scope_ = scope; op_info_.reset( new OpInfo(opdesc)); // Force clean the out-of-date infomation. - return AttachImpl(opdesc, scope); + return AttachImpl(*op_info(), scope); } const Tensor *OpLite::GetTensor(lite::Scope *scope, diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 38cce73d29133b947b49a7e13e4c44f6a37f2455..cd7d9ef84494f2d07859d7008187119ff75eefb1 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc { } return false; } + + void UpdateAllInputs(const std::string &from, const std::string &to) { + for (auto &item : inputs_) { + for (auto &var : item.second) { + if (var == from) var = to; + } + } + } + + void UpdateAllOutputs(const std::string &from, const std::string &to) { + for (auto &item : outputs_) { + for (auto &var : item.second) { + if (var == from) var = to; + } + } + } }; } // namespace lite diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index c2c1121f53e100ffc747579d6ad826459b47c169..b936a139cbcede98cdf79ca744abab04f87d93f4 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -43,6 +43,8 @@ class Optimizer { CHECK(!graph_) << "duplicate optimize found"; graph_.reset(new mir::SSAGraph); graph_->Build(program, valid_places); + graph_->SetValidPlaces(valid_places); + SpecifyKernelPickTactic(kernel_pick_factor); InitTargetTypeTransformPass(); @@ -51,6 +53,7 @@ class Optimizer { "lite_conv_bn_fuse_pass", // "lite_conv_elementwise_add_activation_fuse_pass", // "lite_fc_fuse_pass", // + "identity_scale_eliminate_pass", // #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "lite_elementwise_add_activation_fuse_pass", // #endif diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 2f3e078462a7a5ff61217aa6c10b6e3973a29143..46da1815f197a2107a2ab3c3d844f1c4d87b44f2 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -140,7 +140,7 @@ class RuntimeProgram { void Run() { for (auto& inst : instructions_) { - VLOG(4) << ">> Running kernel: " << inst; + VLOG(3) << ">> Running kernel: " << inst.op()->op_info()->Repr(); inst.Run(); } } diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index 1d61f72063b8f6e40975e10ae6907c8264d4c117..2c001c84e4c98f68ebc90729ba8bfbc4acdde6d3 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -191,7 +191,6 @@ class TensorBase { template bool TensorCompareWith(const TensorT &a, const TensorT &b) { if (a.dims() != b.dims()) return false; - LOG(INFO) << "data_size: " << a.data_size(); if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; return true; } diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc b/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc index abb28e2bb5868e6188c13c6ae145de74881801ae..8bade95f58ca386adb2b9a94da888a58f15158ac 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc +++ b/paddle/fluid/lite/kernels/x86/elementwise_compute_test.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/lite/kernels/x86/elementwise_compute.h" #include #include +#include +#include #include #include "paddle/fluid/lite/core/op_registry.h" diff --git a/paddle/fluid/lite/kernels/x86/mul_compute.h b/paddle/fluid/lite/kernels/x86/mul_compute.h index 96f90842f69f12a1c7baee9f66f055bb21d73126..0f95fea934a26fee17dd52bf9746b96828af1948 100644 --- a/paddle/fluid/lite/kernels/x86/mul_compute.h +++ b/paddle/fluid/lite/kernels/x86/mul_compute.h @@ -40,12 +40,20 @@ class MulCompute : public KernelLite { auto* x = ¶m.x->raw_tensor(); auto* y = ¶m.y->raw_tensor(); - const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix( - *x, param.x_num_col_dims) - : *x; - const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix( - *y, param.y_num_col_dims) - : *y; + Tensor x_matrix, y_matrix; + + if (x->dims().size() > 2) { + x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims); + } else { + x_matrix = *x; + } + + if (y->dims().size() > 2) { + y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims); + + } else { + y_matrix = *y; + } auto* z = ¶m.output->raw_tensor(); auto z_dim = z->dims(); @@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite { auto* x = ¶m.x->raw_tensor(); auto* y = ¶m.y->raw_tensor(); - auto x_matrix = x->dims().size() > 2 - ? framework::ReshapeToMatrix(*x, param.x_num_col_dims) - : static_cast(*x); - auto y_matrix = y->dims().size() > 2 - ? framework::ReshapeToMatrix(*y, param.y_num_col_dims) - : static_cast(*y); + + Tensor x_matrix, y_matrix; + + if (x->dims().size() > 2) { + x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims); + } else { + x_matrix = *x; + } + + if (y->dims().size() > 2) { + y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims); + + } else { + y_matrix = *y; + } + auto* dout = ¶m.output_grad->raw_tensor(); Tensor dout_mat; diff --git a/paddle/fluid/lite/kernels/x86/mul_compute_test.cc b/paddle/fluid/lite/kernels/x86/mul_compute_test.cc index 50854d29d0902baf28770a5320daee92408732c2..c551754328ebe005aeaadf06846d82b48da511e6 100644 --- a/paddle/fluid/lite/kernels/x86/mul_compute_test.cc +++ b/paddle/fluid/lite/kernels/x86/mul_compute_test.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/lite/kernels/x86/mul_compute.h" #include #include +#include +#include #include #include "paddle/fluid/lite/core/op_registry.h" diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt index c539e409a655d73136b3c5c5ebc84ce1ecc697bd..2690fa0206b1f60b506f7d7b6a76d7abff359fec 100644 --- a/paddle/fluid/lite/model_parser/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/CMakeLists.txt @@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE) endif() -cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite) +cc_library(compatible_pb_lite SRCS compatible_pb.cc + DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite) lite_cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite ${tensor_lite} scope_lite diff --git a/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt index e6e2fc77f00c691176aa5c20c455964bd9bd5e66..71073179991294aadef40d5df6d23662ec41fcfe 100644 --- a/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt @@ -1,2 +1 @@ cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite) - diff --git a/paddle/fluid/lite/model_parser/desc_apis.h b/paddle/fluid/lite/model_parser/desc_apis.h index d28f82a0e730850f6a05b1a1bc749e856fe7afd9..5981b873f7ce9c878f5d2e79b2d4b547f8b00c80 100644 --- a/paddle/fluid/lite/model_parser/desc_apis.h +++ b/paddle/fluid/lite/model_parser/desc_apis.h @@ -14,6 +14,7 @@ #pragma once #include +#include #include #include @@ -79,6 +80,27 @@ class OpDescAPI { /// Get an attribute. template T GetAttr(const std::string& name) const; + + std::string Repr() const { + std::stringstream ss; + ss << Type(); + ss << "("; + for (auto& arg : InputArgumentNames()) { + ss << arg << ":"; + for (auto val : Input(arg)) { + ss << val << " "; + } + } + ss << ") -> ("; + for (auto& arg : OutputArgumentNames()) { + ss << arg << ":"; + for (auto val : Output(arg)) { + ss << val << " "; + } + } + ss << ")"; + return ss.str(); + } }; } // namespace lite diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.h b/paddle/fluid/lite/model_parser/pb/op_desc.h index e8772e162a5e7229d57afe18c435a6fa635a87ec..b64ba5452d63a3cb6e4670880a6aed9ac603ac94 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.h +++ b/paddle/fluid/lite/model_parser/pb/op_desc.h @@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI { template T GetAttr(const std::string &name) const; + std::string DebugString() const { return desc_.DebugString(); } + private: std::vector GetArguments( const google::protobuf::RepeatedPtrField diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index 7aa1581bb2adb6214877b33382d09f32ca5e225c..a01427b1f4c87f0d29d073879c799720ddd987d7 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -38,15 +38,19 @@ class MulOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + CHECK(!op_desc.Input("X").empty()); + CHECK(!op_desc.Input("Y").empty()); + CHECK(!op_desc.Output("Out").empty()); + auto input = op_desc.Input("X").front(); auto W = op_desc.Input("Y").front(); auto out = op_desc.Output("Out").front(); auto *var = scope->FindVar(input); CHECK(var); - param_.x = var->GetMutable(); + param_.x = &var->Get(); var = scope->FindVar(W); CHECK(var) << "no var called " << W; - param_.y = var->GetMutable(); + param_.y = &var->Get(); var = scope->FindVar(out); CHECK(var) << "no var called " << out; param_.output = var->GetMutable(); diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 0cc1e6b78e99902724857af7b13cf2fd84500243..b50e14a485526369777cbf3b44fd6e6f21e4ae33 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -67,8 +67,8 @@ struct ReluParam { // For Mul Op struct MulParam { - lite::Tensor* x{}; - lite::Tensor* y{}; + const lite::Tensor* x{}; + const lite::Tensor* y{}; lite::Tensor* output{}; int x_num_col_dims{1}; diff --git a/paddle/fluid/lite/utils/varient.h b/paddle/fluid/lite/utils/varient.h index 2d2a3061108978364cfebfd1c2b4389e008c5115..52bbcffcef980a13591d13d2bfc2bbc17069aaed 100644 --- a/paddle/fluid/lite/utils/varient.h +++ b/paddle/fluid/lite/utils/varient.h @@ -20,6 +20,7 @@ #include #include #include "paddle/fluid/lite/utils/cp_logging.h" +#include "paddle/fluid/lite/utils/string.h" // This is an equivalent implementation of boost::any. We implement this to // avoid including the whole boost library and keep the inference library small. @@ -116,9 +117,9 @@ struct variant { if (type_id == typeid(T).hash_code()) return *reinterpret_cast(&data); else - throw std::invalid_argument("unmatched type"); - // LOG(FATAL) << "unmatched type get, should be " << type_id << " but get " - // << typeid(T).name(); + throw std::invalid_argument( + string_format("unmatched type, store as %d, but want to get %s", + type_id, typeid(T).name())); return *reinterpret_cast(&data); }