提交 cb506579 编写于 作者: S Shang Zhizhou 提交者: GitHub

Nne integration (#32604) (#32658)

* Add dlnne engine runtime

* Remove <const_cast> and remove unrelated modify with dlnne, +clang-format

* Add copyright message

* Add some paddlepaddle_pass to support more networks

* Add delete dropout_op pass
Co-authored-by: Ndenglin-github <82362191+denglin-github@users.noreply.github.com>
上级 e7c81600
......@@ -86,6 +86,7 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
namespace paddle {
namespace framework {
class LoDTensor;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(dropout_op); \
GET_IR_NODE(dropout_op_out); \
GET_IR_NODE(dropout_op_outmask); \
GET_IR_NODE(any_op2);
void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_dropout_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string dropout_op_out_name = dropout_op_out->Var()->Name();
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
dropout_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
if (arg_name.size() == 0) {
LOG(INFO) << "Delete dropout op pass: can not find the input "
<< dropout_op_out_name;
return;
}
// modify the any_op2's inputs
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
dropout_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != dropout_op_out_name) {
new_inputs.push_back(i_n);
}
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
}
}
any_op2_desc->Flush();
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph,
{dropout_op, dropout_op_out, dropout_op_outmask});
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_dropout_op_pass,
paddle::framework::ir::DeleteDropoutOpPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class DeleteDropoutOpPass : public FusePassBase {
public:
virtual ~DeleteDropoutOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -2439,6 +2439,29 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
return concat_out;
}
void patterns::DeleteDropoutOpPattern::operator()() {
auto any_op_out = pattern->NewNode(any_op_out_repr())
->assert_is_op_input("dropout", "X")
->AsInput();
auto dropout_op =
pattern->NewNode(dropout_op_repr())->assert_is_op("dropout");
auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate();
auto dropout_op_outmask = pattern->NewNode(dropout_op_outmask_repr())
->assert_is_op_output("dropout", "Mask")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
dropout_op->LinksFrom({any_op_out});
dropout_op_out->LinksFrom({dropout_op});
dropout_op_outmask->LinksFrom({dropout_op});
any_op2->LinksFrom({dropout_op_out});
}
void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
const std::string &quant_type) {
auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node"))
......
......@@ -1464,6 +1464,19 @@ struct ShuffleChannelPattern : public PatternBase {
PATTERN_DECL_NODE(reshape2_out);
};
struct DeleteDropoutOpPattern : public PatternBase {
DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(dropout_op);
PATTERN_DECL_NODE(dropout_op_out);
PATTERN_DECL_NODE(dropout_op_outmask);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteQuantDequantOpPattern : public PatternBase {
DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {}
......
......@@ -112,6 +112,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
const std::vector<std::string> kDlnneSubgraphPasses({
"is_test_pass", //
"delete_dropout_op_pass" //
"simplify_with_basic_ops_pass", //
"conv_bn_fuse_pass", //
"depthwise_conv_bn_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册