diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e55783637a6e08578ef7717ba9768f7eece7ca8f --- /dev/null +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(reshape1_op); \ + GET_IR_NODE(reshape1_out); \ + GET_IR_NODE(transpose_op); \ + GET_IR_NODE(transpose_out); \ + GET_IR_NODE(reshape2_op); \ + GET_IR_NODE(reshape2_out); + +void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "shufflechannel_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + + patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name); + pattern(x); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + + PADDLE_ENFORCE(subgraph.count(x)); + auto* input_node = subgraph.at(x); + auto reshape1_desc = reshape1_op->Op(); + auto reshape2_desc = reshape2_op->Op(); + std::string input_name = input_node->Name(); + std::string output_name = reshape2_out->Name(); + + auto reshape1_shape = + boost::get>(reshape1_desc->GetAttr("shape")); + auto reshape2_shape = + boost::get>(reshape2_desc->GetAttr("shape")); + + int i_c = reshape1_shape[2]; + int o_c = reshape2_shape[1]; + int group = o_c / i_c; + + framework::OpDesc new_op_desc; + new_op_desc.SetType("shuffle_channel"); + new_op_desc.SetInput("X", {input_name}); + new_op_desc.SetOutput("Out", {output_name}); + + new_op_desc.SetAttr("group", group); + new_op_desc.Flush(); + + // Create a new node for the fused op. + auto* new_op = graph->CreateOpNode(&new_op_desc); + + IR_NODE_LINK_TO(input_node, new_op); + IR_NODE_LINK_TO(new_op, reshape2_out); + + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op, + transpose_out, reshape2_op}); + }; + + gpd(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(shuffle_channel_detect_pass, + paddle::framework::ir::ShuffleChannelDetectPass); diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..008f8013efd28b3cdc5a846662653e07e45e3985 --- /dev/null +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class ShuffleChannelDetectPass : public FusePassBase { + public: + virtual ~ShuffleChannelDetectPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/anakin/convert/CMakeLists.txt b/paddle/fluid/inference/anakin/convert/CMakeLists.txt index 6546d3b855fbc1a3243b56f3ee2f8f21625c2a93..5d85525a652a6016694e012853c95aca086b3fd9 100644 --- a/paddle/fluid/inference/anakin/convert/CMakeLists.txt +++ b/paddle/fluid/inference/anakin/convert/CMakeLists.txt @@ -2,8 +2,8 @@ cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc -roi_align.cc helper.cc DEPS anakin_engine framework_proto scope op_registry -gtest) +roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto +scope op_registry gtest) cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL) cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL) diff --git a/paddle/fluid/inference/anakin/convert/shuffle_channel.cc b/paddle/fluid/inference/anakin/convert/shuffle_channel.cc new file mode 100644 index 0000000000000000000000000000000000000000..fdd2e3182e34992205d7707b83efbc3c6421076c --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/shuffle_channel.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/anakin/convert/shuffle_channel.h" +#include +#include +#include + +using anakin::PTuple; + +namespace paddle { +namespace inference { +namespace anakin { + +template +void ShuffleChannelOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); + + auto input = op_desc.Input("X").front(); + auto output = op_desc.Output("Out").front(); + auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); + this->engine_->AddOp(op_name, "ShuffleChannel", {input}, {output}); + + auto group = boost::get(op_desc.GetAttr("group")); + this->engine_->AddOpAttr(op_name, "group", group); +} + +} // namespace anakin +} // namespace inference +} // namespace paddle + +REGISTER_ANAKIN_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/shuffle_channel.h b/paddle/fluid/inference/anakin/convert/shuffle_channel.h new file mode 100644 index 0000000000000000000000000000000000000000..457a14865a91bd6cfa763513f01cda72e34186e8 --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/shuffle_channel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/inference/anakin/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace anakin { + +template +class ShuffleChannelOpConverter + : public AnakinOpConverter { + public: + ShuffleChannelOpConverter() = default; + + virtual void operator()(const framework::proto::OpDesc &op, + const framework::BlockDesc &block_desc, + const framework::Scope &scope, + bool test_mode) override; + virtual ~ShuffleChannelOpConverter() {} +}; + +} // namespace anakin +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/anakin/op_teller.cc b/paddle/fluid/inference/anakin/op_teller.cc index 6cad00f8ecfe872924ed3804847cb22b8932b91d..67b771226c4999a361a818e32e8caedd81723c03 100644 --- a/paddle/fluid/inference/anakin/op_teller.cc +++ b/paddle/fluid/inference/anakin/op_teller.cc @@ -48,6 +48,7 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("affine_channel"); teller_set.insert("relu6"); teller_set.insert("swish"); + teller_set.insert("shuffle_channel"); } bool operator()(const std::string& op_type, diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 677f5bf130f124db90ac99155d2ec336604ec17e..a9d8113a7721cfce123e618538f79ac75b9637fe 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -896,4 +896,5 @@ USE_ANAKIN_CONVERTER(leaky_relu); USE_ANAKIN_CONVERTER(affine_channel); USE_ANAKIN_CONVERTER(relu6); USE_ANAKIN_CONVERTER(swish); +USE_ANAKIN_CONVERTER(shuffle_channel); #endif diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index ab347b85885fe3bc6a46c60e572e3a03185b5f44..48a140aa1135b6bf45f229c30b4028e99ed9dd96 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -79,11 +79,8 @@ const std::vector kAnakinSubgraphPasses({ "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // "fc_gru_fuse_pass", // - "graph_viz_pass", // "shuffle_channel_detect_pass", // - "graph_viz_pass", // "anakin_subgraph_pass", // - "graph_viz_pass", // "fc_gru_fuse_pass", // });