diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ba1d7379c56d953a0f37d03deed6c47e46cbf129..a26732926c2c6e376079c610078c80c6a8afd452 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base) pass_library(runtime_context_cache_pass base) pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(fillconstant_elementwisemul_fuse inference) +pass_library(shuffle_channel_detect_pass inference) if(ANAKIN_FOUND) pass_library(simplify_anakin_priorbox_detection_out_pass inference) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 77f50e914b668ebfeb2fcaf5de8f91a74f0c0d3b..0dcf064902d1c1c6cb034421cedea0387b6e0505 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input, } } +void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) { + auto reshape1_op = + pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2"); + + auto reshape1_out = pattern->NewNode(reshape1_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2") + ->AsIntermediate(); + + auto transpose_op = + pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); + + auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("reshape2") + ->AsIntermediate(); + + auto reshape2_op = + pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2"); + auto reshape2_out = pattern->NewNode(reshape2_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->AsOutput(); + + reshape1_op->LinksFrom({reshape1_in}); + reshape1_out->LinksFrom({reshape1_op}); + transpose_op->LinksFrom({reshape1_out}); + transpose_out->LinksFrom({transpose_op}); + reshape2_op->LinksFrom({transpose_out}); + reshape2_out->LinksFrom({reshape2_op}); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 525987e0072cb05ad3df4d09a17ac172e48ce133..907371b56b06dcd66297adedea6c17b61d9b5e38 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase { } }; +struct ShuffleChannelPattern : public PatternBase { + ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "shufflechannel_pattern") {} + + void operator()(PDNode* reshape1_in); + + PATTERN_DECL_NODE(reshape1_op); + PATTERN_DECL_NODE(reshape1_out); + + PATTERN_DECL_NODE(transpose_op); + PATTERN_DECL_NODE(transpose_out); + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(reshape2_out); +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e55783637a6e08578ef7717ba9768f7eece7ca8f --- /dev/null +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(reshape1_op); \ + GET_IR_NODE(reshape1_out); \ + GET_IR_NODE(transpose_op); \ + GET_IR_NODE(transpose_out); \ + GET_IR_NODE(reshape2_op); \ + GET_IR_NODE(reshape2_out); + +void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "shufflechannel_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + + patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name); + pattern(x); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + + PADDLE_ENFORCE(subgraph.count(x)); + auto* input_node = subgraph.at(x); + auto reshape1_desc = reshape1_op->Op(); + auto reshape2_desc = reshape2_op->Op(); + std::string input_name = input_node->Name(); + std::string output_name = reshape2_out->Name(); + + auto reshape1_shape = + boost::get>(reshape1_desc->GetAttr("shape")); + auto reshape2_shape = + boost::get>(reshape2_desc->GetAttr("shape")); + + int i_c = reshape1_shape[2]; + int o_c = reshape2_shape[1]; + int group = o_c / i_c; + + framework::OpDesc new_op_desc; + new_op_desc.SetType("shuffle_channel"); + new_op_desc.SetInput("X", {input_name}); + new_op_desc.SetOutput("Out", {output_name}); + + new_op_desc.SetAttr("group", group); + new_op_desc.Flush(); + + // Create a new node for the fused op. + auto* new_op = graph->CreateOpNode(&new_op_desc); + + IR_NODE_LINK_TO(input_node, new_op); + IR_NODE_LINK_TO(new_op, reshape2_out); + + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op, + transpose_out, reshape2_op}); + }; + + gpd(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(shuffle_channel_detect_pass, + paddle::framework::ir::ShuffleChannelDetectPass); diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..008f8013efd28b3cdc5a846662653e07e45e3985 --- /dev/null +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class ShuffleChannelDetectPass : public FusePassBase { + public: + virtual ~ShuffleChannelDetectPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/anakin/convert/CMakeLists.txt b/paddle/fluid/inference/anakin/convert/CMakeLists.txt index 6546d3b855fbc1a3243b56f3ee2f8f21625c2a93..5d85525a652a6016694e012853c95aca086b3fd9 100644 --- a/paddle/fluid/inference/anakin/convert/CMakeLists.txt +++ b/paddle/fluid/inference/anakin/convert/CMakeLists.txt @@ -2,8 +2,8 @@ cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc -roi_align.cc helper.cc DEPS anakin_engine framework_proto scope op_registry -gtest) +roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto +scope op_registry gtest) cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL) cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL) diff --git a/paddle/fluid/inference/anakin/convert/shuffle_channel.cc b/paddle/fluid/inference/anakin/convert/shuffle_channel.cc new file mode 100644 index 0000000000000000000000000000000000000000..fdd2e3182e34992205d7707b83efbc3c6421076c --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/shuffle_channel.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/anakin/convert/shuffle_channel.h" +#include +#include +#include + +using anakin::PTuple; + +namespace paddle { +namespace inference { +namespace anakin { + +template +void ShuffleChannelOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); + + auto input = op_desc.Input("X").front(); + auto output = op_desc.Output("Out").front(); + auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); + this->engine_->AddOp(op_name, "ShuffleChannel", {input}, {output}); + + auto group = boost::get(op_desc.GetAttr("group")); + this->engine_->AddOpAttr(op_name, "group", group); +} + +} // namespace anakin +} // namespace inference +} // namespace paddle + +REGISTER_ANAKIN_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/shuffle_channel.h b/paddle/fluid/inference/anakin/convert/shuffle_channel.h new file mode 100644 index 0000000000000000000000000000000000000000..457a14865a91bd6cfa763513f01cda72e34186e8 --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/shuffle_channel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/inference/anakin/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace anakin { + +template +class ShuffleChannelOpConverter + : public AnakinOpConverter { + public: + ShuffleChannelOpConverter() = default; + + virtual void operator()(const framework::proto::OpDesc &op, + const framework::BlockDesc &block_desc, + const framework::Scope &scope, + bool test_mode) override; + virtual ~ShuffleChannelOpConverter() {} +}; + +} // namespace anakin +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/anakin/op_teller.cc b/paddle/fluid/inference/anakin/op_teller.cc index 6cad00f8ecfe872924ed3804847cb22b8932b91d..67b771226c4999a361a818e32e8caedd81723c03 100644 --- a/paddle/fluid/inference/anakin/op_teller.cc +++ b/paddle/fluid/inference/anakin/op_teller.cc @@ -48,6 +48,7 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("affine_channel"); teller_set.insert("relu6"); teller_set.insert("swish"); + teller_set.insert("shuffle_channel"); } bool operator()(const std::string& op_type, diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 677f5bf130f124db90ac99155d2ec336604ec17e..a9d8113a7721cfce123e618538f79ac75b9637fe 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -896,4 +896,5 @@ USE_ANAKIN_CONVERTER(leaky_relu); USE_ANAKIN_CONVERTER(affine_channel); USE_ANAKIN_CONVERTER(relu6); USE_ANAKIN_CONVERTER(swish); +USE_ANAKIN_CONVERTER(shuffle_channel); #endif diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index fea291c5528a11fd18b1069a5d57e456c8cc84fc..48a140aa1135b6bf45f229c30b4028e99ed9dd96 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -79,6 +79,7 @@ const std::vector kAnakinSubgraphPasses({ "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // "fc_gru_fuse_pass", // + "shuffle_channel_detect_pass", // "anakin_subgraph_pass", // "fc_gru_fuse_pass", // });