diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 8e824727468f0878f99e7deb4038bee686581ad8..ba6c363b86287b4ccdaaf1faccfa12031d338a57 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -6,6 +6,7 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) add_subdirectory(fusion) cc_library(mir_passes SRCS fc_fuse_pass.cc + conv_elementwise_add_relu_fuse_pass.cc static_kernel_pick_pass.cc variable_place_inference_pass.cc type_target_transform_pass.cc @@ -15,7 +16,7 @@ cc_library(mir_passes argument_type_display_pass.cc demo_pass.cc runtime_context_assign_pass.cc - DEPS mir_pass types_lite context_lite mir_fusers) + DEPS mir_pass types_lite context_lite ${mir_fusers}) # for mobile, unnecessary to compile the following testings. if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) @@ -74,3 +75,8 @@ lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz") add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz) + +lite_cc_test(test_lite_conv_elementwise_add_relu_fuse + SRCS conv_elementwise_add_relu_fuse_pass_test.cc + DEPS cxx_api_lite mir_passes + ${ops_lite} ${host_kernels} ${x86_kernels}) diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..065e8ceca3f5bf146f22a66d2eb4350c655c79a2 --- /dev/null +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ConvElementwiseAddReLUFusePass::Apply( + const std::unique_ptr& graph) { + fusion::ConvElementwiseAddReLUFuser fuser; + fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass, + paddle::lite::mir::ConvElementwiseAddReLUFusePass); diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..4276f1ffc8c258b0b4266abd950fa1ccf541c4a7 --- /dev/null +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ConvElementwiseAddReLUFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ada0a2c60dabf36c5d1081dcb5023baac3b173a --- /dev/null +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/mir/passes.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/program.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + + auto* conv2d_1 = main_block->AppendOp(); + auto* conv2d_2 = main_block->AppendOp(); + auto* add_1 = main_block->AppendOp(); + auto* add_2 = main_block->AppendOp(); + auto* relu_1 = main_block->AppendOp(); + auto* relu_2 = main_block->AppendOp(); + + main_block->Var("input_1"); + main_block->Var("input_2"); + main_block->Var("filter_1"); + main_block->Var("filter_2"); + main_block->Var("conv2d_1_out"); + main_block->Var("conv2d_2_out"); + main_block->Var("bias_1"); + main_block->Var("add_1_out"); + main_block->Var("add_2_out"); + main_block->Var("relu_1_out"); + main_block->Var("out"); + + scope->Var("input_1")->GetMutable(); + scope->Var("input_2")->GetMutable(); + scope->Var("filter_1")->GetMutable(); + scope->Var("filter_2")->GetMutable(); + scope->Var("conv2d_1_out")->GetMutable(); + scope->Var("conv2d_2_out")->GetMutable(); + scope->Var("bias_1")->GetMutable(); + scope->Var("add_1_out")->GetMutable(); + scope->Var("add_2_out")->GetMutable(); + scope->Var("relu_1_out")->GetMutable(); + scope->Var("out")->GetMutable(); + + conv2d_1->SetType("conv2d"); + conv2d_1->SetInput("Input", {"input_1"}); + conv2d_1->SetInput("Filter", {"filter_1"}); + conv2d_1->SetOutput("Output", {"conv2d_1_out"}); + conv2d_1->SetAttr("strides", std::vector({1, 1})); + conv2d_1->SetAttr("paddings", std::vector({0, 0})); + conv2d_1->SetAttr("groups", 1); + conv2d_1->SetAttr("dilations", std::vector({1, 1})); + conv2d_1->SetAttr("fuse_relu", false); + + add_1->SetType("elementwise_add"); + add_1->SetInput("X", {"conv2d_1_out"}); + add_1->SetInput("Y", {"bias_1"}); + add_1->SetOutput("Out", {"add_1_out"}); + add_1->SetAttr("axis", 1); + + relu_1->SetType("relu"); + relu_1->SetInput("Input", {"add_1_out"}); + relu_1->SetOutput("Out", {"relu_1_out"}); + + conv2d_2->SetType("conv2d"); + conv2d_2->SetInput("Input", {"input_2"}); + conv2d_2->SetInput("Filter", {"filter_2"}); + conv2d_2->SetOutput("Output", {"conv2d_2_out"}); + conv2d_2->SetAttr("strides", std::vector({1, 1})); + conv2d_2->SetAttr("paddings", std::vector({0, 0})); + conv2d_2->SetAttr("groups", 1); + conv2d_2->SetAttr("dilations", std::vector({1, 1})); + conv2d_2->SetAttr("fuse_relu", false); + + add_2->SetType("elementwise_add"); + add_2->SetInput("X", {"conv2d_2_out"}); + add_2->SetInput("Y", {"relu_1_out"}); + add_2->SetOutput("Out", {"add_2_out"}); + add_2->SetAttr("axis", 1); + + relu_2->SetType("relu"); + relu_2->SetInput("Input", {"add_2_out"}); + relu_2->SetOutput("Out", {"out"}); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(conv_elementwise_add_relu_fuse_pass, graph_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + ASSERT_EQ(graph->nodes().size(), + 11UL /*vars*/ + 6UL /*ops*/ + 2UL /*feed op + fetch op*/); + Visualize(graph.get()); +} + +TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + Visualize(graph.get()); + const int num_nodes = graph->nodes().size(); + auto* fuser = new ConvElementwiseAddReLUFusePass; + fuser->Apply(graph); + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ + + 1UL * 2 /* fused fc node*/); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(elementwise_add); +USE_LITE_OP(conv2d); +USE_LITE_OP(depthwise_conv2d); +USE_LITE_OP(relu); diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt index 6a0626b9eb56c0dbb9406ba8a2faa66d898ddd73..e0816c1be56665602bff7dce685f21084d82a0a1 100644 --- a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt @@ -1,3 +1,11 @@ -cc_library(mir_fusers +cc_library(fuse_fc SRCS fc_fuser.cc DEPS pattern_matcher_high_api) +cc_library(conv_elementwise_add_relu + SRCS conv_elementwise_add_relu_fuser.cc + DEPS pattern_matcher_high_api) + +set(mir_fusers + fuse_fc + conv_elementwise_add_relu + CACHE INTERNAL "fusers") diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1322386348581476df3db33df9836f1abe3d661 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void ConvElementwiseAddReLUFuser::BuildPattern() { + // create input nodes. + auto* input = VarNode("input"); + auto* filter = VarNode("filter"); + auto* bias = VarNode("bias"); + + // create op nodes + auto* conv2d = OpNode("conv2d", "conv2d"); + auto* add = OpNode("add", "elementwise_add"); + auto* relu = OpNode("relu", "relu"); + + // create intermediate nodes + auto* conv2d_out = VarNode("conv2d_out"); + auto* add_out = VarNode("add_out"); + + // create output node + auto* out = VarNode("output"); + + // create topology. + std::vector conv2d_inputs{filter, input}; + std::vector add_inputs{conv2d_out, bias}; + conv2d_inputs >> *conv2d >> *conv2d_out; + add_inputs >> *add >> *add_out; + *add_out >> *relu >> *out; + + // Some op specialities. + conv2d_out->AsIntermediate(); + add_out->AsIntermediate(); + conv2d->AsIntermediate(); + add->AsIntermediate(); + relu->AsIntermediate(); +} + +void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto conv_op = LiteOpRegistry::Global().Create("conv2d"); + auto conv_old = matched.at("conv2d")->stmt()->op; + auto* scope = conv_old->scope(); + auto& valid_places = conv_old->valid_places(); + conv_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places); + + IR_NODE_LINK_TO(matched.at("input"), new_op_node); + IR_NODE_LINK_TO(matched.at("filter"), new_op_node); + IR_NODE_LINK_TO(matched.at("bias"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("output")); +} + +cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) { + auto* desc = matched.at("conv2d")->stmt()->op_info(); + + cpp::OpDesc op_desc; + op_desc.SetType("conv2d"); + op_desc.SetInput("Input", {matched.at("input")->arg()->name}); + op_desc.SetInput("Filter", {matched.at("filter")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("bias")->arg()->name}); + op_desc.SetOutput("Output", {matched.at("output")->arg()->name}); + // Other inputs. See operators/conv_op.h + std::vector input_arg_names = desc->InputArgumentNames(); + for (auto name : input_arg_names) LOG(INFO) << name; + + if (std::find(input_arg_names.begin(), input_arg_names.end(), + "ResidualData") != input_arg_names.end()) { + op_desc.SetInput("ResidualData", desc->Input("ResidualData")); + } + + // Only consider strides, padding, groups, dilations, fuse_relu for now + op_desc.SetAttr("strides", desc->GetAttr>("strides")); + op_desc.SetAttr("paddings", desc->GetAttr>("paddings")); + op_desc.SetAttr("groups", desc->GetAttr("groups")); + op_desc.SetAttr("dilations", desc->GetAttr>("dilations")); + op_desc.SetAttr("fuse_relu", true); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h new file mode 100644 index 0000000000000000000000000000000000000000..5ba0ee268415c93b65e32735c8ba7aabff1ea7c4 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class ConvElementwiseAddReLUFuser : public FuseBase { + public: + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index 217e4d5dbdfd48efa3249ddec83257bfd8f1c8a5..e0110a8e3b27fd11c03fa15732b634e3d7978682 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -23,6 +23,7 @@ namespace mir {} // namespace mir USE_MIR_PASS(demo); USE_MIR_PASS(lite_fc_fuse_pass); +USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass); USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(type_target_transform_pass); diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h index 3ab30eb787bd9574a10cc9198f4c08b744eb0c27..e5ad7fe67f9561fd9fda6f1b99c3a00828008422 100644 --- a/paddle/fluid/lite/operators/conv_op.h +++ b/paddle/fluid/lite/operators/conv_op.h @@ -64,17 +64,31 @@ class ConvOpLite : public OpLite { bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { auto X = op_desc.Input("Input").front(); auto Filter = op_desc.Input("Filter").front(); - auto Bias = op_desc.Input("Bias").front(); - // auto ResidualData = op_desc.Input("ResidualData"); auto Out = op_desc.Output("Output").front(); param_.x = scope->FindVar(X)->GetMutable(); param_.filter = scope->FindVar(Filter)->GetMutable(); - param_.bias = scope->FindVar(Bias)->GetMutable(); - // param_.residualData = - // scope->FindVar(ResidualData)->GetMutable(); param_.output = scope->FindVar(Out)->GetMutable(); + std::vector input_arg_names = op_desc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != + input_arg_names.end()) { + auto bias_var = scope->FindVar(op_desc.Input("Bias").front()); + if (bias_var != nullptr) { + param_.bias = + const_cast(&(bias_var->Get())); + } + } + if (std::find(input_arg_names.begin(), input_arg_names.end(), + "ResidualData") != input_arg_names.end()) { + auto residual_data_var = + scope->FindVar(op_desc.Input("ResidualData").front()); + if (residual_data_var != nullptr) { + param_.residualData = const_cast( + &(residual_data_var->Get())); + } + } + param_.strides = op_desc.GetAttr>("strides"); param_.paddings = op_desc.GetAttr>("paddings"); param_.groups = op_desc.GetAttr("groups"); diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc index b073e2db43a4891defeb95750424941969323ba0..a588b1c8cbf101c2f795253960dc4a47d79a7e60 100644 --- a/paddle/fluid/lite/operators/relu_op.cc +++ b/paddle/fluid/lite/operators/relu_op.cc @@ -37,7 +37,6 @@ bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { scope->FindVar(opdesc.Output("Out").front())->GetMutable(); CHECK(param_.input); CHECK(param_.output); - kernel_->SetParam(param_); return true; }