From cb5065792831331d0c9fd9e79853e0ee917ebe9a Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Fri, 30 Apr 2021 09:21:29 +0800 Subject: [PATCH] Nne integration (#32604) (#32658) * Add dlnne engine runtime * Remove and remove unrelated modify with dlnne, +clang-format * Add copyright message * Add some paddlepaddle_pass to support more networks * Add delete dropout_op pass Co-authored-by: denglin-github <82362191+denglin-github@users.noreply.github.com> --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/delete_dropout_op_pass.cc | 96 +++++++++++++++++++ .../framework/ir/delete_dropout_op_pass.h | 37 +++++++ .../framework/ir/graph_pattern_detector.cc | 23 +++++ .../framework/ir/graph_pattern_detector.h | 13 +++ .../inference/api/paddle_pass_builder.cc | 1 + 6 files changed, 171 insertions(+) create mode 100644 paddle/fluid/framework/ir/delete_dropout_op_pass.cc create mode 100644 paddle/fluid/framework/ir/delete_dropout_op_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 0ca78c679a..ab69170322 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -86,6 +86,7 @@ pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_filter_op_pass inference) +pass_library(delete_dropout_op_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base) diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc new file mode 100644 index 0000000000..09962239a0 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h" + +namespace paddle { +namespace framework { +class LoDTensor; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(any_op_out); \ + GET_IR_NODE(dropout_op); \ + GET_IR_NODE(dropout_op_out); \ + GET_IR_NODE(dropout_op_outmask); \ + GET_IR_NODE(any_op2); + +void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "delete_dropout_op_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + + patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name); + pattern(); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + IR_NODE_LINK_TO(any_op_out, any_op2); + std::string any_op_out_name = any_op_out->Var()->Name(); + std::string dropout_op_out_name = dropout_op_out->Var()->Name(); + + auto* any_op2_desc = any_op2->Op(); + auto var_map = any_op2_desc->Inputs(); + std::string arg_name = ""; + for (auto& name_m : var_map) { + if (std::find(name_m.second.begin(), name_m.second.end(), + dropout_op_out_name) != name_m.second.end()) { + arg_name = name_m.first; + } + } + if (arg_name.size() == 0) { + LOG(INFO) << "Delete dropout op pass: can not find the input " + << dropout_op_out_name; + return; + } + + // modify the any_op2's inputs + for (auto& name_m : var_map) { + if (std::find(name_m.second.begin(), name_m.second.end(), + dropout_op_out_name) != name_m.second.end()) { + std::vector new_inputs; + for (auto& i_n : name_m.second) { + if (i_n != dropout_op_out_name) { + new_inputs.push_back(i_n); + } + } + new_inputs.push_back(any_op_out_name); + any_op2_desc->SetInput(name_m.first, new_inputs); + any_op2_desc->Flush(); + } + } + any_op2_desc->Flush(); + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph, + {dropout_op, dropout_op_out, dropout_op_outmask}); + }; + + gpd(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_dropout_op_pass, + paddle::framework::ir::DeleteDropoutOpPass); diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.h b/paddle/fluid/framework/ir/delete_dropout_op_pass.h new file mode 100644 index 0000000000..c49abf3c87 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +class DeleteDropoutOpPass : public FusePassBase { + public: + virtual ~DeleteDropoutOpPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index d74e8e5f65..064da3d941 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2439,6 +2439,29 @@ PDNode *patterns::TransposeFlattenConcat::operator()( return concat_out; } +void patterns::DeleteDropoutOpPattern::operator()() { + auto any_op_out = pattern->NewNode(any_op_out_repr()) + ->assert_is_op_input("dropout", "X") + ->AsInput(); + + auto dropout_op = + pattern->NewNode(dropout_op_repr())->assert_is_op("dropout"); + + auto dropout_op_out = pattern->NewNode(dropout_op_out_repr()) + ->assert_is_op_output("dropout", "Out") + ->AsIntermediate(); + + auto dropout_op_outmask = pattern->NewNode(dropout_op_outmask_repr()) + ->assert_is_op_output("dropout", "Mask") + ->AsOutput(); + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + dropout_op->LinksFrom({any_op_out}); + dropout_op_out->LinksFrom({dropout_op}); + dropout_op_outmask->LinksFrom({dropout_op}); + any_op2->LinksFrom({dropout_op_out}); +} + void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node, const std::string &quant_type) { auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node")) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index cfac01ec9d..13f6585995 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1464,6 +1464,19 @@ struct ShuffleChannelPattern : public PatternBase { PATTERN_DECL_NODE(reshape2_out); }; +struct DeleteDropoutOpPattern : public PatternBase { + DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(any_op_out); + PATTERN_DECL_NODE(dropout_op); + PATTERN_DECL_NODE(dropout_op_out); + PATTERN_DECL_NODE(dropout_op_outmask); + PATTERN_DECL_NODE(any_op2); +}; + struct DeleteQuantDequantOpPattern : public PatternBase { DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {} diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 2b7333edae..b2e3de6369 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -112,6 +112,7 @@ const std::vector kTRTSubgraphPasses({ const std::vector kDlnneSubgraphPasses({ "is_test_pass", // + "delete_dropout_op_pass" // "simplify_with_basic_ops_pass", // "conv_bn_fuse_pass", // "depthwise_conv_bn_fuse_pass", // -- GitLab