diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index fe7defcf73e6bea6819c62ae36c87b59eb4f09b2..0d7fcf8b3b2843c0d36be24288743a86b8c7ea24 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -7,7 +7,8 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
 add_subdirectory(fusion)
 cc_library(mir_passes
   SRCS fc_fuse_pass.cc
-       conv_elementwise_add_relu_fuse_pass.cc
+       conv_elementwise_add_activation_fuse_pass.cc
+       elementwise_add_activation_fuse_pass.cc
        conv_bn_fuse_pass.cc
        static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
@@ -82,7 +83,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz"
 add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)

-lite_cc_test(test_lite_conv_elementwise_add_relu_fuse
-  SRCS conv_elementwise_add_relu_fuse_pass_test.cc
+lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
+  SRCS conv_elementwise_add_activation_fuse_pass_test.cc
+  DEPS cxx_api_lite mir_passes
+  ${ops_lite} ${host_kernels} ${x86_kernels})
+lite_cc_test(test_lite_elementwise_add_activation_fuse
+  SRCS elementwise_add_activation_fuse_pass_test.cc
   DEPS cxx_api_lite mir_passes
   ${ops_lite} ${host_kernels} ${x86_kernels})
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
similarity index 66%
rename from paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
rename to paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
index 3110c7aa6d408d2520d982ec76a77baea7babdbc..27f6413c47b514d3203c5879d7ee7b9697d8cf5a 100644
--- a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
+++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
@@ -12,22 +12,23 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
+#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
 #include <memory>
 #include <vector>
-#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
+#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
 #include "paddle/fluid/lite/core/mir/pass_registry.h"

 namespace paddle {
 namespace lite {
 namespace mir {

-void ConvElementwiseAddReLUFusePass::Apply(
+void ConvElementwiseAddActivationFusePass::Apply(
     const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
+  fusion::ConvElementwiseAddActivationFuser fuser("conv2d", "relu");
   fuser(graph.get());

-  fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
+  fusion::ConvElementwiseAddActivationFuser depthwise_fuser("depthwise_conv2d",
+                                                            "relu");
   depthwise_fuser(graph.get());
 }

@@ -35,5 +36,5 @@ void ConvElementwiseAddReLUFusePass::Apply(
 }  // namespace lite
 }  // namespace paddle

-REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
-                  paddle::lite::mir::ConvElementwiseAddReLUFusePass);
+REGISTER_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass,
+                  paddle::lite::mir::ConvElementwiseAddActivationFusePass);
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h b/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..a5a619f4d0d06da52661282e68f6a3c34c987bc9
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class ConvElementwiseAddActivationFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
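Editor's note: the pass above folds a conv2d -> elementwise_add -> activation chain into a single conv node carrying a fuse_relu attribute. A minimal standalone sketch of the semantics being preserved, using toy 1-D data and plain C++ (no Lite APIs; all names here are illustrative only):

// Toy check: the fused path relu(conv_out + bias) must match the staged
// three-op reference path that the pass removes. Not Lite code.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<float> conv_out{-2.f, 0.5f, 3.f};  // pretend conv2d output
  std::vector<float> bias{1.f, -1.f, 0.25f};

  // Reference path: elementwise_add, then relu, via an intermediate tensor.
  std::vector<float> add_out(conv_out.size()), ref(conv_out.size());
  for (size_t i = 0; i < conv_out.size(); ++i) add_out[i] = conv_out[i] + bias[i];
  for (size_t i = 0; i < conv_out.size(); ++i) ref[i] = std::max(add_out[i], 0.f);

  // Fused path: one loop, no intermediate tensor (what fuse_relu=true enables).
  std::vector<float> fused(conv_out.size());
  for (size_t i = 0; i < conv_out.size(); ++i)
    fused[i] = std::max(conv_out[i] + bias[i], 0.f);

  for (size_t i = 0; i < ref.size(); ++i) assert(fused[i] == ref[i]);
  return 0;
}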
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
similarity index 95%
rename from paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
rename to paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
index 30991313ad3ed9ef39c3fb8183f4cfc43c9c49b9..5a5fdc134b810b67df33b2d385d982a306a0dddc 100644
--- a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
+++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
+#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
 #include <gtest/gtest.h>
 #include <memory>
 #include <vector>
@@ -135,11 +135,11 @@ TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
   auto graph = BuildGraph(&program_desc, scope, places);
   Visualize(graph.get());
   const int num_nodes = graph->nodes().size();
-  auto* fuser = new ConvElementwiseAddReLUFusePass;
+  auto* fuser = new ConvElementwiseAddActivationFusePass;
   fuser->Apply(graph);
   Visualize(graph.get());
-  ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /* nodes removed */ +
-                                       1UL * 2 /* fused fc node */);
+  ASSERT_EQ(graph->nodes().size(),
+            num_nodes - 5UL * 2 /* nodes removed */ + 1UL * 2 /* fused nodes */);
 }

 }  // namespace fusion
diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc b/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9ce455dcdafb0d2e8f040bc3244495b2968eebd0
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
@@ -0,0 +1,36 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void ElementwiseAddActivationFusePass::Apply(
+    const std::unique_ptr<SSAGraph>& graph) {
+  fusion::ElementwiseAddActivationFuser fuser("relu");
+  fuser(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
+                  paddle::lite::mir::ElementwiseAddActivationFusePass);
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h b/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
similarity index 93%
rename from paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
rename to paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
index 4276f1ffc8c258b0b4266abd950fa1ccf541c4a7..213c3f68f6008bfc9c522b3896a678a137e92201 100644
--- a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
+++ b/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
@@ -22,7 +22,7 @@ namespace paddle {
 namespace lite {
 namespace mir {

-class ConvElementwiseAddReLUFusePass : public ProgramPass {
+class ElementwiseAddActivationFusePass : public ProgramPass {
  public:
   void Apply(const std::unique_ptr<SSAGraph>& graph) override;
 };
diff --git a/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f70da091bda94c5e88c0022ad9aa97828b9fb947
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
@@ -0,0 +1,117 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/lite/api/cxx_api.h"
+#include "paddle/fluid/lite/core/compatible_tensor.h"
+#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
+#include "paddle/fluid/lite/core/mir/passes.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/program.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
+                                     const std::shared_ptr<Scope>& scope,
+                                     const std::vector<Place>& valid_places) {
+  auto* main_block = program_desc->MutableBlock(0);
+
+  auto* add_1 = main_block->AppendOp();
+  auto* add_2 = main_block->AppendOp();
+  auto* relu_1 = main_block->AppendOp();
+  auto* relu_2 = main_block->AppendOp();
+
+  main_block->Var("x_1");
+  main_block->Var("y_1");
+  main_block->Var("add_out_1");
+  main_block->Var("relu_out_1");
+  main_block->Var("y_2");
+  main_block->Var("add_out_2");
+  main_block->Var("out");
+
+  scope->Var("x_1")->GetMutable<lite::Tensor>();
+  scope->Var("y_1")->GetMutable<lite::Tensor>();
+  scope->Var("add_out_1")->GetMutable<lite::Tensor>();
+  scope->Var("relu_out_1")->GetMutable<lite::Tensor>();
+  scope->Var("y_2")->GetMutable<lite::Tensor>();
+  scope->Var("add_out_2")->GetMutable<lite::Tensor>();
+  scope->Var("out")->GetMutable<lite::Tensor>();
+
+  add_1->SetType("elementwise_add");
+  add_1->SetInput("X", {"x_1"});
+  add_1->SetInput("Y", {"y_1"});
+  add_1->SetOutput("Out", {"add_out_1"});
+  add_1->SetAttr("axis", 1);
+
+  relu_1->SetType("relu");
+  relu_1->SetInput("X", {"add_out_1"});
+  relu_1->SetOutput("Out", {"relu_out_1"});
+
+  add_2->SetType("elementwise_add");
+  add_2->SetInput("X", {"relu_out_1"});
+  add_2->SetInput("Y", {"y_2"});
+  add_2->SetOutput("Out", {"add_out_2"});
+  add_2->SetAttr("axis", 1);
+
+  relu_2->SetType("relu");
+  relu_2->SetInput("X", {"add_out_2"});
+  relu_2->SetOutput("Out", {"out"});
+
+  program_desc->Flush();
+
+  lite::Program program(*program_desc->Proto(), scope, valid_places);
+  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
+  graph->Build(program, valid_places);
+
+  return graph;
+}
+
+TEST(elementwise_add_activation_fuse_pass, graph_test) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  ASSERT_EQ(graph->nodes().size(),
+            7UL /* vars */ + 4UL /* ops */ + 1UL /* SSAGraph tmp node */);
+}
+
+TEST(elementwise_add_activation_fuse_pass, fuse_test_op) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  Visualize(graph.get());
+  const int num_nodes = graph->nodes().size();
+  auto* fuser = new ElementwiseAddActivationFusePass;
+  fuser->Apply(graph);
+  Visualize(graph.get());
+  ASSERT_EQ(graph->nodes().size(),
+            num_nodes - 3UL * 2 /* nodes removed */ + 1UL * 2 /* fused nodes */);
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(elementwise_add);
+USE_LITE_OP(fusion_elementwise_add_activation);
+USE_LITE_OP(relu);
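Editor's note on the arithmetic in the final assertion: each matched add -> relu pair removes three graph nodes (the elementwise_add op node, the relu op node, and the intermediate add_out variable node) and inserts one fused op node; the graph above contains two such pairs, hence `- 3UL * 2 + 1UL * 2`. A hedged, self-contained sketch of that bookkeeping (toy node counts only, not the Lite SSAGraph API):

// Toy node accounting for the test graph above. Purely illustrative.
#include <cassert>
#include <cstddef>

int main() {
  const std::size_t vars = 7, ops = 4, tmp = 1;  // as asserted in graph_test
  std::size_t nodes = vars + ops + tmp;          // 12 nodes before fusion
  const std::size_t matches = 2;                 // two add->relu pairs
  // Each match drops {add op, relu op, add_out var} and adds one fused op.
  nodes = nodes - 3 * matches + 1 * matches;
  assert(nodes == 8);
  return 0;
}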
diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
index fbc7ffe730bca1e2d1c5c9fa48e81bc3b98de45c..9139293c8aa59d5664e29afba97c02226f9338bf 100644
--- a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
@@ -1,17 +1,21 @@
 cc_library(fuse_fc
   SRCS fc_fuser.cc
   DEPS pattern_matcher_high_api)
-cc_library(fuse_conv_elementwise_add_relu
-  SRCS conv_elementwise_add_relu_fuser.cc
+cc_library(fuse_conv_elementwise_add_activation
+  SRCS conv_elementwise_add_activation_fuser.cc
   DEPS pattern_matcher_high_api)
 cc_library(fuse_conv_bn
   SRCS conv_bn_fuser.cc
   DEPS pattern_matcher_high_api)
+cc_library(fuse_elementwise_add_activation
+  SRCS elementwise_add_activation_fuser.cc
+  DEPS pattern_matcher_high_api)

 set(mir_fusers
   fuse_fc
-  fuse_conv_elementwise_add_relu
+  fuse_conv_elementwise_add_activation
   fuse_conv_bn
+  fuse_elementwise_add_activation
   CACHE INTERNAL "fusers")

 if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
similarity index 86%
rename from paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
rename to paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
index 421c920e6214756a823622b4f24dfb651d63951b..b26b758fb2318b7c9a645503687f994b73009310 100644
--- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
+++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
+#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
 #include <memory>
 #include <vector>
@@ -21,7 +21,7 @@ namespace lite {
 namespace mir {
 namespace fusion {

-void ConvElementwiseAddReLUFuser::BuildPattern() {
+void ConvElementwiseAddActivationFuser::BuildPattern() {
   // create input nodes.
   auto* input =
       VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
@@ -36,7 +36,8 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
   auto* add = OpNode("add", "elementwise_add")
                   ->assert_is_op("elementwise_add")
                   ->AsIntermediate();
-  auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
+  auto* act =
+      OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();

   // create intermediate nodes
   auto* conv2d_out = VarNode("conv2d_out")
@@ -45,22 +46,23 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
                          ->AsIntermediate();
   auto* add_out = VarNode("add_out")
                       ->assert_is_op_output("elementwise_add", "Out")
-                      ->assert_is_op_input("relu", "X")
+                      ->assert_is_op_input(act_type_, "X")
                       ->AsIntermediate();

   // create output node
-  auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
+  auto* out =
+      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();

   // create topology.
   std::vector<PMNode*> conv2d_inputs{filter, input};
   std::vector<PMNode*> add_inputs{conv2d_out, bias};
   conv2d_inputs >> *conv2d >> *conv2d_out;
   add_inputs >> *add >> *add_out;
-  *add_out >> *relu >> *out;
+  *add_out >> *act >> *out;
 }

-void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
-                                                const key2nodes_t& matched) {
+void ConvElementwiseAddActivationFuser::InsertNewNode(
+    SSAGraph* graph, const key2nodes_t& matched) {
   auto op_desc = GenOpDesc(matched);
   auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
   auto conv_old = matched.at("conv2d")->stmt()->op;
@@ -76,7 +78,8 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(new_op_node, matched.at("output"));
 }

-cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
+cpp::OpDesc ConvElementwiseAddActivationFuser::GenOpDesc(
+    const key2nodes_t& matched) {
   auto* desc = matched.at("conv2d")->stmt()->op_info();

   cpp::OpDesc op_desc;
@@ -98,6 +101,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
   op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
   op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
   op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
+  // TODO(sangoly): support other activation types
   op_desc.SetAttr("fuse_relu", true);
   return op_desc;
 }
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..14a33613fdffce8c2d9d4044a11b5de4b5652da3
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class ConvElementwiseAddActivationFuser : public FuseBase {
+ public:
+  explicit ConvElementwiseAddActivationFuser(const std::string& conv_type,
+                                             const std::string& act_type) {
+    CHECK(act_type == "relu") << "Only relu activation is supported now";
+    conv_type_ = conv_type;
+    act_type_ = act_type;
+  }
+
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  std::string conv_type_;
+  std::string act_type_;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
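Editor's note: the `add_inputs >> *add >> *add_out` lines above use the pattern matcher's overloaded shift operator to declare graph topology. A hedged, standalone sketch of how such a chaining DSL can be built (toy types, not the Lite PMNode API):

// Minimal edge-declaration DSL in the spirit of `inputs >> *op >> *out`.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
};
std::vector<std::pair<std::string, std::string>> edges;  // recorded topology

// Chaining: each >> records one edge and returns the right-hand node.
Node& operator>>(Node& from, Node& to) {
  edges.emplace_back(from.name, to.name);
  return to;
}
// A vector on the left fans in: every element links to the same target.
Node& operator>>(std::vector<Node*>& froms, Node& to) {
  for (auto* f : froms) *f >> to;
  return to;
}

int main() {
  Node x{"x"}, y{"y"}, add{"add"}, add_out{"add_out"}, act{"act"}, out{"out"};
  std::vector<Node*> add_inputs{&x, &y};
  add_inputs >> add >> add_out;  // records x->add, y->add, add->add_out
  add_out >> act >> out;
  for (auto& e : edges) std::cout << e.first << " -> " << e.second << "\n";
  return 0;
}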
diff --git a/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..83b916eea3e47947083d4a41406d2ebd6918dfd2
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void ElementwiseAddActivationFuser::BuildPattern() {
+  // create input nodes.
+  auto* x = VarNode("x")->assert_is_op_input("elementwise_add", "X")->AsInput();
+  auto* y = VarNode("y")->assert_is_op_input("elementwise_add", "Y")->AsInput();
+
+  // create op nodes
+  auto* add = OpNode("add", "elementwise_add")
+                  ->assert_is_op("elementwise_add")
+                  ->AsIntermediate();
+  auto* act =
+      OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
+
+  // create intermediate nodes
+  auto* add_out = VarNode("add_out")
+                      ->assert_is_op_output("elementwise_add", "Out")
+                      ->assert_is_op_input(act_type_, "X")
+                      ->AsIntermediate();
+
+  // create output node
+  auto* out =
+      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
+
+  // create topology.
+  std::vector<PMNode*> add_inputs{x, y};
+  add_inputs >> *add >> *add_out;
+  *add_out >> *act >> *out;
+}
+
+void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
+                                                  const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto op =
+      LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
+  auto old_op = matched.at("add")->stmt()->op;
+  auto* scope = old_op->scope();
+  auto& valid_places = old_op->valid_places();
+  op->Attach(op_desc, scope);
+
+  auto* new_op_node = graph->GraphCreateInstructNode(op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("x"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("y"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
+}
+
+cpp::OpDesc ElementwiseAddActivationFuser::GenOpDesc(
+    const key2nodes_t& matched) {
+  auto* desc = matched.at("add")->stmt()->op_info();
+
+  cpp::OpDesc op_desc;
+  op_desc.SetType("fusion_elementwise_add_activation");
+  op_desc.SetInput("X", {matched.at("x")->arg()->name});
+  op_desc.SetInput("Y", {matched.at("y")->arg()->name});
+  op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
+
+  op_desc.SetAttr("axis", desc->GetAttr<int>("axis"));
+  op_desc.SetAttr("act_type", act_type_);
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h
similarity index 85%
rename from paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
rename to paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h
index 3e21368234f36a5afafb08958930943599955090..bcd7b4cbcda84538f01cc4e418ce201500edbb26 100644
--- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
+++ b/paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h
@@ -23,16 +23,16 @@ namespace lite {
 namespace mir {
 namespace fusion {

-class ConvElementwiseAddReLUFuser : public FuseBase {
+class ElementwiseAddActivationFuser : public FuseBase {
  public:
-  explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
-      : conv_type_(conv_type) {}
+  explicit ElementwiseAddActivationFuser(const std::string& act_type)
+      : act_type_(act_type) {}

   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-  std::string conv_type_;
+  std::string act_type_;
 };

 }  // namespace fusion
diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h
index a6abb16e3eaabe6a0f12b75248f3db1f7cfeeb81..cea15e12f64e94f4b4acfdd91e73a2abf4f05dee 100644
--- a/paddle/fluid/lite/core/mir/passes.h
+++ b/paddle/fluid/lite/core/mir/passes.h
@@ -34,4 +34,5 @@ USE_MIR_PASS(runtime_context_assign_pass);
 USE_MIR_PASS(lite_conv_bn_fuse_pass);
 USE_MIR_PASS(graph_visualze);
 USE_MIR_PASS(lite_fc_fuse_pass);
-USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
+USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
+USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
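Editor's note: REGISTER_MIR_PASS and USE_MIR_PASS above follow the usual static-registrar idiom: registration runs from a namespace-scope initializer in the pass's translation unit, and the USE_ macro references a symbol from that unit so the linker cannot drop the object file out of a static library. A hedged, one-file sketch of the idiom (toy macros and names, not the Lite implementation):

// Toy static-registration idiom in one translation unit. Illustrative only.
#include <functional>
#include <iostream>
#include <map>
#include <string>

using PassFn = std::function<void()>;

std::map<std::string, PassFn>& Registry() {
  static std::map<std::string, PassFn> m;  // construct-on-first-use
  return m;
}

// REGISTER_MIR_PASS-style: runs before main() and inserts the pass by name.
#define TOY_REGISTER_PASS(name__, fn__) \
  int toy_registered_##name__ = (Registry()[#name__] = (fn__), 0);

// USE_MIR_PASS-style: referencing the registration symbol forces linkage.
#define TOY_USE_PASS(name__)          \
  extern int toy_registered_##name__; \
  static int toy_use_##name__ = toy_registered_##name__;

TOY_REGISTER_PASS(lite_elementwise_add_activation_fuse_pass,
                  [] { std::cout << "running fuse pass\n"; });
TOY_USE_PASS(lite_elementwise_add_activation_fuse_pass);

int main() {
  Registry().at("lite_elementwise_add_activation_fuse_pass")();  // run by name
  return 0;
}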
diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h
index a3e0641b1c7a44809e2a8fdc1b34a49772f71085..c2c1121f53e100ffc747579d6ad826459b47c169 100644
--- a/paddle/fluid/lite/core/optimizer.h
+++ b/paddle/fluid/lite/core/optimizer.h
@@ -48,9 +48,12 @@ class Optimizer {
     if (passes.empty()) {
       RunPasses(std::vector<std::string>{{
-          "lite_conv_bn_fuse_pass",                   //
-          "lite_conv_elementwise_add_act_fuse_pass",  //
-          "lite_fc_fuse_pass",                        //
+          "lite_conv_bn_fuse_pass",                          //
+          "lite_conv_elementwise_add_activation_fuse_pass",  //
+          "lite_fc_fuse_pass",                               //
+#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+          "lite_elementwise_add_activation_fuse_pass",       //
+#endif
 #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
           "static_kernel_pick_pass",        //
           "variable_place_inference_pass",  //
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 1a7ac5bd5704e040498eb1dfb2f51f5429b9d7c0..ba23181387bc58464eae205877a8d9ccf6959146 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -14,6 +14,7 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
 cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
 cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS})
 cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS})
+cc_library(fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops_lite ${op_DEPS})
 cc_library(mean_op_lite SRCS mean_op.cc DEPS ${op_DEPS})
 cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
 #cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
@@ -36,6 +37,7 @@ set(ops_lite
     fetch_op_lite
     io_copy_op_lite
     elementwise_ops_lite
+    fusion_elementwise_activation_ops_lite
     mean_op_lite
     fill_constant_op_lite
     activation_ops_lite
@@ -56,3 +58,6 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
 lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
 lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
 lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
+lite_cc_test(test_fusion_elementwise_activation_ops_lite
+  SRCS fusion_elementwise_activation_ops_test.cc
+  DEPS fusion_elementwise_activation_ops_lite memory_lite)
diff --git a/paddle/fluid/lite/operators/elementwise_ops.cc b/paddle/fluid/lite/operators/elementwise_ops.cc
index b400b1ab26c137fbbee830e1992706e586ae152e..2c6d4e709082b11ab643d6d8b8571efcba4e5f7b 100644
--- a/paddle/fluid/lite/operators/elementwise_ops.cc
+++ b/paddle/fluid/lite/operators/elementwise_ops.cc
@@ -12,92 +12,67 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/operators/elementwise_ops.h"
 #include "paddle/fluid/lite/core/op_registry.h"

 namespace paddle {
 namespace lite {
 namespace operators {

-class ElementwiseOp : public OpLite {
- public:
-  explicit ElementwiseOp(const std::string& type) : OpLite(type) {}
-
-  bool CheckShape() const override {
-    CHECK_OR_FALSE(param_.X);
-    CHECK_OR_FALSE(param_.Y);
-    CHECK_OR_FALSE(param_.Out);
-    return true;
-  }
-
-  bool InferShape() const override {
-    CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
-    param_.Out->Resize(param_.X->dims());
-    return true;
-  }
-
-  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
-    auto X_name = opdesc.Input("X").front();
-    auto Y_name = opdesc.Input("Y").front();
-    auto Out_name = opdesc.Output("Out").front();
-
-    param_.X = GetVar<lite::Tensor>(scope, X_name);
-    param_.Y = GetVar<lite::Tensor>(scope, Y_name);
-    param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
-    param_.axis = opdesc.GetAttr<int>("axis");
-    return true;
-  }
-
-  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
-
-  std::string DebugString() const override { return "elementwise_op"; }
-
- private:
-  mutable operators::ElementwiseParam param_;
-};
+bool ElementwiseOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Y);
+  CHECK_OR_FALSE(param_.Out);
+  return true;
+}
+
+bool ElementwiseOp::InferShape() const {
+  CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
+  param_.Out->Resize(param_.X->dims());
+  return true;
+}
+
+bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
+  auto X_name = opdesc.Input("X").front();
+  auto Y_name = opdesc.Input("Y").front();
+  auto Out_name = opdesc.Output("Out").front();
+
+  param_.X = GetVar<lite::Tensor>(scope, X_name);
+  param_.Y = GetVar<lite::Tensor>(scope, Y_name);
+  param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
+  param_.axis = opdesc.GetAttr<int>("axis");
+  return true;
+}

 #ifdef LITE_WITH_X86
-class ElementwiseGradExplicitOp : public OpLite {
- public:
-  explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
-
-  bool CheckShape() const override {
-    CHECK_OR_FALSE(param_.Y);
-    CHECK_OR_FALSE(param_.X_grad);
-    CHECK_OR_FALSE(param_.Y_grad);
-    CHECK_OR_FALSE(param_.Out_grad);
-    return true;
-  }
-
-  bool InferShape() const override {
-    param_.X_grad->Resize(param_.Out_grad->dims());
-    param_.Y_grad->Resize(param_.Y->dims());
-    return true;
-  }
-
-  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
-    CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL);
-    auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
-    auto X_name = opdesc.Output(framework::GradVarName("X")).front();
-    auto Y_name = opdesc.Output(framework::GradVarName("Y")).front();
-
-    param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
-    param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_name);
-    param_.Y_grad = GetMutableVar<lite::Tensor>(scope, Y_name);
-    param_.axis = opdesc.GetAttr<int>("axis");
-
-    return true;
-  }
-
-  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
-
-  std::string DebugString() const override {
-    return "elementwise_grad_explicit_op";
-  }
-
- private:
-  mutable operators::ElementwiseGradParam param_;
-};
+bool ElementwiseGradExplicitOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.Y);
+  CHECK_OR_FALSE(param_.X_grad);
+  CHECK_OR_FALSE(param_.Y_grad);
+  CHECK_OR_FALSE(param_.Out_grad);
+  return true;
+}
+
+bool ElementwiseGradExplicitOp::InferShape() const {
+  param_.X_grad->Resize(param_.Out_grad->dims());
+  param_.Y_grad->Resize(param_.Y->dims());
+  return true;
+}
+
+bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc,
+                                           lite::Scope* scope) {
+  CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL);
+  auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
+  auto X_name = opdesc.Output(framework::GradVarName("X")).front();
+  auto Y_name = opdesc.Output(framework::GradVarName("Y")).front();
+
+  param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
+  param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_name);
+  param_.Y_grad = GetMutableVar<lite::Tensor>(scope, Y_name);
+  param_.axis = opdesc.GetAttr<int>("axis");
+
+  return true;
+}
 #endif

 }  // namespace operators
diff --git a/paddle/fluid/lite/operators/elementwise_ops.h b/paddle/fluid/lite/operators/elementwise_ops.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e427f708fcab5a74052a5ea13776709d7f4f72e
--- /dev/null
+++ b/paddle/fluid/lite/operators/elementwise_ops.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class ElementwiseOp : public OpLite {
+ public:
+  explicit ElementwiseOp(const std::string& op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
+
+  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "elementwise_op"; }
+
+ private:
+  mutable operators::ElementwiseParam param_;
+};
+
+#ifdef LITE_WITH_X86
+class ElementwiseGradExplicitOp : public OpLite {
+ public:
+  explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
+
+  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "elementwise_grad_explicit_op";
+  }
+
+ private:
+  mutable operators::ElementwiseGradParam param_;
+};
+#endif
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
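Editor's note: moving the class declarations out of elementwise_ops.cc into the new header is what lets the fused op below subclass ElementwiseOp and chain into its AttachImpl. A hedged sketch of that chaining pattern in isolation (toy desc type and class names, not the Lite classes):

// Toy base/derived AttachImpl chaining. Illustrative names only.
#include <cassert>
#include <map>
#include <string>

using ToyDesc = std::map<std::string, std::string>;  // stand-in for an op desc

class ToyElementwiseOp {
 public:
  virtual ~ToyElementwiseOp() = default;
  virtual bool Attach(const ToyDesc& desc) {
    axis_ = desc.at("axis");  // base parses the shared attributes
    return true;
  }

 protected:
  std::string axis_;
};

class ToyFusedOp : public ToyElementwiseOp {
 public:
  bool Attach(const ToyDesc& desc) override {
    ToyElementwiseOp::Attach(desc);  // reuse base parsing, then extend it
    act_type_ = desc.at("act_type");
    return act_type_ == "relu";      // mirrors the CHECK in the diff below
  }

 private:
  std::string act_type_;
};

int main() {
  ToyFusedOp op;
  assert(op.Attach({{"axis", "1"}, {"act_type", "relu"}}));
  return 0;
}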
diff --git a/paddle/fluid/lite/operators/fusion_elementwise_activation_ops.cc b/paddle/fluid/lite/operators/fusion_elementwise_activation_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c7c57810fe6f6b4c1ed04883ec736eca6abc297d
--- /dev/null
+++ b/paddle/fluid/lite/operators/fusion_elementwise_activation_ops.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
+#include <string>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
+                                               lite::Scope* scope) {
+  ElementwiseOp::AttachImpl(opdesc, scope);
+  param_.act_type = opdesc.GetAttr<std::string>("act_type");
+  // TODO(sangoly): support more activation types.
+  CHECK(param_.act_type == "relu") << "Only relu activation is supported now";
+
+  return true;
+}
+
+#ifdef LITE_WITH_X86
+bool FusionElementwiseActivationGradExplicitOp::AttachImpl(
+    const cpp::OpDesc& opdesc, lite::Scope* scope) {
+  ElementwiseGradExplicitOp::AttachImpl(opdesc, scope);
+  param_.act_type = opdesc.GetAttr<std::string>("act_type");
+  // TODO(sangoly): support more activation types.
+  CHECK(param_.act_type == "relu") << "Only relu activation is supported now";
+
+  return true;
+}
+#endif
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(fusion_elementwise_sub_activation,
+                 paddle::lite::operators::FusionElementwiseActivationOp);
+#ifdef LITE_WITH_X86
+REGISTER_LITE_OP(
+    fusion_elementwise_sub_activation_grad,
+    paddle::lite::operators::FusionElementwiseActivationGradExplicitOp);
+#endif
+REGISTER_LITE_OP(fusion_elementwise_add_activation,
+                 paddle::lite::operators::FusionElementwiseActivationOp);
diff --git a/paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h b/paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h
new file mode 100644
index 0000000000000000000000000000000000000000..78ec419925f3d23d5eac0a9a62d82588e52e0d2c
--- /dev/null
+++ b/paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/operators/elementwise_ops.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FusionElementwiseActivationOp : public ElementwiseOp {
+ public:
+  explicit FusionElementwiseActivationOp(const std::string& type)
+      : ElementwiseOp(type) {}
+
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
+
+  std::string DebugString() const override {
+    return "fusion_elementwise_activation_op";
+  }
+
+ private:
+  mutable operators::FusionElementwiseActivationParam param_;
+};
+
+#ifdef LITE_WITH_X86
+class FusionElementwiseActivationGradExplicitOp
+    : public ElementwiseGradExplicitOp {
+ public:
+  explicit FusionElementwiseActivationGradExplicitOp(const std::string& type)
+      : ElementwiseGradExplicitOp(type) {}
+
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
+
+  std::string DebugString() const override {
+    return "fusion_elementwise_activation_grad_explicit_op";
+  }
+
+ private:
+  mutable operators::FusionElementwiseActivationGradParam param_;
+};
+#endif
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/fusion_elementwise_activation_ops_test.cc b/paddle/fluid/lite/operators/fusion_elementwise_activation_ops_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..07566e25fc1133bc09c62a97d2cfcb4c823164a0
--- /dev/null
+++ b/paddle/fluid/lite/operators/fusion_elementwise_activation_ops_test.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
+#include <gtest/gtest.h>
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+TEST(fusion_elementwise_activation_op_lite, test) {
+  // prepare variables
+  lite::Scope scope;
+  auto* x = scope.Var("x")->GetMutable<lite::Tensor>();
+  auto* y = scope.Var("y")->GetMutable<lite::Tensor>();
+  auto* out = scope.Var("out")->GetMutable<lite::Tensor>();
+  x->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
+  y->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
+  out->Resize(lite::DDim(std::vector<int64_t>{10, 20}));
+
+  // set data
+  for (int i = 0; i < 10 * 20; i++) {
+    x->mutable_data<float>()[i] = i;
+  }
+  for (int i = 0; i < 10 * 20; i++) {
+    y->mutable_data<float>()[i] = i;
+  }
+  for (int i = 0; i < 10 * 20; i++) {
+    out->mutable_data<float>()[i] = 0.;
+  }
+
+  // prepare op desc
+  cpp::OpDesc desc;
+  desc.SetType("fusion_elementwise_add_activation");
+  desc.SetInput("X", {"x"});
+  desc.SetInput("Y", {"y"});
+  desc.SetOutput("Out", {"out"});
+  desc.SetAttr("axis", static_cast<int>(1));
+  desc.SetAttr("act_type", std::string("relu"));
+
+  FusionElementwiseActivationOp fuse_op("fusion_elementwise_add_activation");
+
+  fuse_op.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
+  fuse_op.Attach(desc, &scope);
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index 91a6067959854f608e31a6151a4e63e26df7eb64..25d27f03664558654987601adb525e063c668b7c 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -219,6 +219,14 @@ struct ElementwiseGradParam {
   int axis{-1};  // for broadcasting.
 };

+struct FusionElementwiseActivationParam : public ElementwiseParam {
+  std::string act_type;
+};
+
+struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
+  std::string act_type;
+};
+
 /// ----------------------- activation operators ----------------------
 struct ActivationParam {
   const lite::Tensor* X{};
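Editor's note: the new param structs carry act_type alongside the inherited elementwise fields, so a kernel consuming fusion_elementwise_add_activation can branch on the string at run time. A hedged sketch of what such a kernel's inner loop might look like (toy param struct and tensors; Lite's real kernels are not part of this diff):

// Toy kernel dispatch on act_type. All types here are illustrative.
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

struct ToyFusionParam {
  const std::vector<float>* X{};
  const std::vector<float>* Y{};
  std::vector<float>* Out{};
  int axis{-1};          // inherited broadcasting axis (unused in this toy)
  std::string act_type;  // the new field added in op_params.h
};

void RunFusedAddActivation(const ToyFusionParam& p) {
  for (size_t i = 0; i < p.X->size(); ++i) {
    float sum = (*p.X)[i] + (*p.Y)[i];
    // Only relu is accepted by the op today; this branch marks where other
    // activations would slot in once supported.
    (*p.Out)[i] = (p.act_type == "relu") ? std::max(sum, 0.f) : sum;
  }
}

int main() {
  std::vector<float> x{-1.f, 2.f}, y{0.5f, -3.f}, out(2);
  ToyFusionParam p{&x, &y, &out, 1, "relu"};
  RunFusedAddActivation(p);
  assert(out[0] == 0.f && out[1] == 0.f);  // relu(-0.5) == relu(-1) == 0
  return 0;
}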