diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 0ca78c679aecaa396b59c7d50471baee239ba622..ab69170322ce3ec4eaa8e46b53e490b634df64b7 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -86,6 +86,7 @@ pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_filter_op_pass inference) +pass_library(delete_dropout_op_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base) diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..09962239a01b1839bea93846ca3ffe9ded3cca4e --- /dev/null +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h" + +namespace paddle { +namespace framework { +class LoDTensor; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(any_op_out); \ + GET_IR_NODE(dropout_op); \ + GET_IR_NODE(dropout_op_out); \ + GET_IR_NODE(dropout_op_outmask); \ + GET_IR_NODE(any_op2); + +void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "delete_dropout_op_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + + patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name); + pattern(); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + IR_NODE_LINK_TO(any_op_out, any_op2); + std::string any_op_out_name = any_op_out->Var()->Name(); + std::string dropout_op_out_name = dropout_op_out->Var()->Name(); + + auto* any_op2_desc = any_op2->Op(); + auto var_map = any_op2_desc->Inputs(); + std::string arg_name = ""; + for (auto& name_m : var_map) { + if (std::find(name_m.second.begin(), name_m.second.end(), + dropout_op_out_name) != name_m.second.end()) { + arg_name = name_m.first; + } + } + if (arg_name.size() == 0) { + LOG(INFO) << "Delete dropout op pass: can not find the input " + << dropout_op_out_name; + return; + } + + // modify the any_op2's inputs + for (auto& name_m : var_map) { + if (std::find(name_m.second.begin(), name_m.second.end(), + dropout_op_out_name) != name_m.second.end()) { + std::vector new_inputs; + for (auto& i_n : name_m.second) { + if (i_n != dropout_op_out_name) { + new_inputs.push_back(i_n); + } + } + new_inputs.push_back(any_op_out_name); + any_op2_desc->SetInput(name_m.first, new_inputs); + any_op2_desc->Flush(); + } + } + any_op2_desc->Flush(); + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph, + {dropout_op, dropout_op_out, dropout_op_outmask}); + }; + + gpd(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_dropout_op_pass, + paddle::framework::ir::DeleteDropoutOpPass); diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.h b/paddle/fluid/framework/ir/delete_dropout_op_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..c49abf3c871ced474bc47e28ec32d29bc9ccf750 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +class DeleteDropoutOpPass : public FusePassBase { + public: + virtual ~DeleteDropoutOpPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index d74e8e5f65cd2020433e9658ee9520d51c13387a..064da3d941602ee0e4f868fb0dbda305102da32b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2439,6 +2439,29 @@ PDNode *patterns::TransposeFlattenConcat::operator()( return concat_out; } +void patterns::DeleteDropoutOpPattern::operator()() { + auto any_op_out = pattern->NewNode(any_op_out_repr()) + ->assert_is_op_input("dropout", "X") + ->AsInput(); + + auto dropout_op = + pattern->NewNode(dropout_op_repr())->assert_is_op("dropout"); + + auto dropout_op_out = pattern->NewNode(dropout_op_out_repr()) + ->assert_is_op_output("dropout", "Out") + ->AsIntermediate(); + + auto dropout_op_outmask = pattern->NewNode(dropout_op_outmask_repr()) + ->assert_is_op_output("dropout", "Mask") + ->AsOutput(); + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + dropout_op->LinksFrom({any_op_out}); + dropout_op_out->LinksFrom({dropout_op}); + dropout_op_outmask->LinksFrom({dropout_op}); + any_op2->LinksFrom({dropout_op_out}); +} + void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node, const std::string &quant_type) { auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node")) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index cfac01ec9dedc83af4bfdce30678f933d9a8e921..13f65859954d58ce446ab3b9de488833f6220dee 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1464,6 +1464,19 @@ struct ShuffleChannelPattern : public PatternBase { PATTERN_DECL_NODE(reshape2_out); }; +struct DeleteDropoutOpPattern : public PatternBase { + DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(any_op_out); + PATTERN_DECL_NODE(dropout_op); + PATTERN_DECL_NODE(dropout_op_out); + PATTERN_DECL_NODE(dropout_op_outmask); + PATTERN_DECL_NODE(any_op2); +}; + struct DeleteQuantDequantOpPattern : public PatternBase { DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {} diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 2b7333edae0dae1f0313bf71fc824c922e20b84d..b2e3de63691c555b24eb6f1e1fb9ffcc35d400f9 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -112,6 +112,7 @@ const std::vector kTRTSubgraphPasses({ const std::vector kDlnneSubgraphPasses({ "is_test_pass", // + "delete_dropout_op_pass" // "simplify_with_basic_ops_pass", // "conv_bn_fuse_pass", // "depthwise_conv_bn_fuse_pass", //