From 46bc06b5378462c5d14c490f776726269c97dc07 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 31 Aug 2022 16:51:26 +0800 Subject: [PATCH] add del dropout op pass to jit pe enigne (#45439) * add del dropout op pass to jit pe enigne * add delete dropout test --- paddle/fluid/framework/details/CMakeLists.txt | 3 +- .../fluid/framework/details/build_strategy.cc | 6 + .../fluid/framework/details/build_strategy.h | 4 + paddle/fluid/framework/ir/CMakeLists.txt | 4 + .../framework/ir/delete_dropout_op_pass.cc | 192 ++++++++++++++++++ .../framework/ir/delete_dropout_op_pass.h | 21 ++ .../ir/delete_dropout_op_pass_test.cc | 96 +++++++++ paddle/fluid/jit/engine/pe_engine.cc | 2 + 8 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/ir/delete_dropout_op_pass_test.cc diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 8908cfe1d25..21518c4f831 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -385,7 +385,8 @@ set(IR_PASS_DEPS runtime_context_cache_pass graph_to_program_pass fix_op_run_order_pass - fuse_gemm_epilogue_pass) + fuse_gemm_epilogue_pass + delete_dropout_op_pass) if(WITH_CINN) set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 631194d5f23..33ec9b4f47a 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -169,6 +169,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { } void AppendOpFusePasses() { + // 1. infernce pass if enabled. + AppendPassWithCheck(strategy_.inference_ && strategy_.del_dropout_, + "delete_dropout_op_x_pass"); + + // 2. trainning pass AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_, "fuse_relu_depthwise_conv_pass"); AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass"); @@ -509,6 +514,7 @@ USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_all_reduce_op_pass); USE_PASS(runtime_context_cache_pass); USE_PASS(add_reader_dependency_pass); +USE_PASS(delete_dropout_op_x_pass); #ifdef PADDLE_WITH_CINN USE_PASS(build_cinn_pass); #endif diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 1e27e381500..0ef89ae1ecc 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -147,6 +147,10 @@ struct BuildStrategy { bool allow_cuda_graph_capture_{false}; + // Inference pass + bool inference_{false}; // switch for infernce pass + bool del_dropout_{false}; + // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, // num_trainers is 1, so the current fields of build_strategy doesn't tell if // it's distributed model. diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 33437bcd836..5c9841aef17 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -369,6 +369,10 @@ cc_test( test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto) +cc_test( + test_delete_dropout_pass_cc + SRCS delete_dropout_op_pass_test.cc + DEPS delete_dropout_op_pass) if(WITH_GPU OR WITH_ROCM) cc_test( test_embedding_eltwise_layernorm_fuse_pass diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc index 3f543f26b6b..7066ee5f6c0 100644 --- a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc @@ -15,6 +15,8 @@ #include +#include "paddle/fluid/framework/op_version_registry.h" + namespace phi { class DenseTensor; } // namespace phi @@ -47,6 +49,7 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const { std::string any_op_out_name = any_op_out->Var()->Name(); std::string dropout_op_out_name = dropout_op_out->Var()->Name(); + // any_op2 auto* any_op2_desc = any_op2->Op(); auto var_map = any_op2_desc->Inputs(); std::string arg_name = ""; @@ -80,6 +83,7 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const { } } any_op2_desc->Flush(); + // Delete the unneeded nodes. GraphSafeRemoveNodes(graph, {dropout_op, dropout_op_out, dropout_op_outmask}); @@ -88,9 +92,197 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const { gpd(graph, handler); } +DeleteDropoutOpXPass::DeleteDropoutOpXPass() { + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsNumGE(0.f) + .IsNumLE(1.f) + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") + .IsNumEQ(true) + .End(); +} + +void DeleteDropoutOpXPass::ApplyImpl(ir::Graph* graph) const { + VLOG(3) << "delte dropout op."; + std::unordered_set del_node_set; + for (Node* n : graph->Nodes()) { + if (n->IsOp() && n->Op()) { + if (n->Op()->Type() == "dropout") { + DelDropout(graph, n, &del_node_set); + } + } + } + + GraphSafeRemoveNodes(graph, del_node_set); +} + +bool DeleteDropoutOpXPass::DelDropout( + Graph* graph, + Node* n, + std::unordered_set* del_node_set) const { + OpDesc* dropout_op_desc = n->Op(); + + Node* dropout_x = GetInputVar(n, dropout_op_desc->Input("X")[0]); + Node* dropout_out = GetOutputVar(n, dropout_op_desc->Output("Out")[0]); + + bool upscale_in_train = false; + // Once the dropout_implementation's AttrType is BOOLEAN, but now is STRING. + if (dropout_op_desc->HasAttr("dropout_implementation")) { + if (dropout_op_desc->GetAttrType("dropout_implementation") == + proto::AttrType::BOOLEAN) { + upscale_in_train = PADDLE_GET_CONST( + bool, dropout_op_desc->GetAttr("dropout_implementation")); + } else if (dropout_op_desc->GetAttrType("dropout_implementation") == + proto::AttrType::STRING) { + upscale_in_train = + PADDLE_GET_CONST(std::string, + dropout_op_desc->GetAttr( + "dropout_implementation")) == "upscale_in_train"; + } + } + + VLOG(3) << "upscale_in_train: " << upscale_in_train; + + if (upscale_in_train) { + // delete dropout + // dropout_op can be deleted. + // dropout_x -> dropout_op -> dropout_out -> next_op -> next_out + // | + // \|/ + // dropout_x -> next_op -> next_out + // Check whether dropout_x is some next_op's output + bool dropout_x_is_reused_as_output = false; + for (auto* next_op : dropout_out->outputs) { + for (auto* next_out : next_op->outputs) { + if (next_out == dropout_x || + next_out->Var()->Name() == dropout_x->Var()->Name()) { + dropout_x_is_reused_as_output = true; + break; + } + } + if (dropout_x_is_reused_as_output) { + break; + } + } + if (dropout_x_is_reused_as_output) { + VarDesc new_var_desc(*dropout_x->Var()); + new_var_desc.SetName("delete_dropout_x_pass_" + dropout_x->Name()); + auto* new_var_node = graph->CreateVarNode(&new_var_desc); + for (auto* out_op : dropout_x->outputs) { + if (out_op != n) { + ReplaceInputVar(out_op, dropout_x, new_var_node); + } + } + for (auto* in_op : dropout_x->inputs) { + ReplaceOutputVar(in_op, dropout_x, new_var_node); + } + dropout_x = new_var_node; + } + for (auto* next_op : dropout_out->outputs) { + ReplaceInputVar(next_op, dropout_out, dropout_x); + } + + del_node_set->insert(dropout_out); + } else { + // keep dropout + // Use a scale_op replaces the dropout_op + // dropout_x -> dropout_op -> dropout_out -> next_op -> next_out + // | + // \|/ + // dropout_x -> scale_op -> dropout_out -> next_op -> next_out + float scale = 1.0f - PADDLE_GET_CONST( + float, dropout_op_desc->GetAttr("dropout_prob")); + + framework::OpDesc new_op_desc(dropout_op_desc->Block()); + new_op_desc.SetType("scale"); + new_op_desc.SetInput("X", {dropout_x->Name()}); + new_op_desc.SetOutput("Out", {dropout_out->Name()}); + new_op_desc.SetAttr("scale", scale); + new_op_desc.SetAttr("bias", static_cast(0)); + new_op_desc.SetAttr("bias_after_scale", true); + + if (!IsCompat(new_op_desc)) { + LOG(WARNING) << "Basic ops pass in scale op compat failed."; + return false; + } + + auto* scale_op_node = graph->CreateOpNode(&new_op_desc); + IR_NODE_LINK_TO(dropout_x, scale_op_node); + IR_NODE_LINK_TO(scale_op_node, dropout_out); + } + + del_node_set->insert(n); + return true; +} + +Node* DeleteDropoutOpXPass::GetInputVar(Node* n, + const std::string& name) const { + for (auto* in : n->inputs) { + if (in->Name() == name) { + return in; + } + } + return nullptr; +} + +Node* DeleteDropoutOpXPass::GetOutputVar(Node* n, + const std::string& name) const { + for (auto* out : n->outputs) { + if (out->Name() == name) { + return out; + } + } + return nullptr; +} + +void DeleteDropoutOpXPass::ReplaceInputVar(Node* op, + Node* old_var, + Node* new_var) const { + if (op->IsOp() && op->Op()) { + new_var->outputs.push_back(op); + for (size_t i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == old_var) { + op->inputs[i] = new_var; + op->Op()->RenameInput(old_var->Name(), new_var->Name()); + } + } + } +} + +void DeleteDropoutOpXPass::ReplaceOutputVar(Node* op, + Node* old_var, + Node* new_var) const { + if (op->IsOp() && op->Op()) { + new_var->inputs.push_back(op); + for (size_t i = 0; i < op->outputs.size(); ++i) { + if (op->outputs[i] == old_var) { + op->outputs[i] = new_var; + op->Op()->RenameOutput(old_var->Name(), new_var->Name()); + } + } + } +} + } // namespace ir } // namespace framework } // namespace paddle REGISTER_PASS(delete_dropout_op_pass, paddle::framework::ir::DeleteDropoutOpPass); + +REGISTER_PASS(delete_dropout_op_x_pass, + paddle::framework::ir::DeleteDropoutOpXPass); +REGISTER_PASS_CAPABILITY(delete_dropout_op_x_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "scale", 0)); diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.h b/paddle/fluid/framework/ir/delete_dropout_op_pass.h index c49abf3c871..829dbdf5824 100644 --- a/paddle/fluid/framework/ir/delete_dropout_op_pass.h +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.h @@ -13,10 +13,13 @@ // limitations under the License. #pragma once +#include +#include #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" namespace paddle { namespace framework { @@ -32,6 +35,24 @@ class DeleteDropoutOpPass : public FusePassBase { void ApplyImpl(ir::Graph* graph) const override; }; +class DeleteDropoutOpXPass : public OpCompatSensiblePass { + public: + DeleteDropoutOpXPass(); + virtual ~DeleteDropoutOpXPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + bool DelDropout(Graph* graph, + Node* n, + std::unordered_set* del_node_set) const; + Node* GetInputVar(Node* n, const std::string& name) const; + Node* GetOutputVar(Node* n, const std::string& name) const; + void ReplaceInputVar(Node* op, Node* old_var, Node* new_var) const; + void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) const; +}; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass_test.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass_test.cc new file mode 100644 index 00000000000..d8cc2210645 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass_test.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(DeleteDropoutOpsPass, dropout) { + for (std::string dropout_implementation : + {"downgrade_in_infer", "upscale_in_train"}) { + for (auto inplace : {false, true}) { + if (dropout_implementation == "downgrade_in_infer" && inplace == true) { + continue; + } + + LOG(INFO) << "dropout_implementation: " << dropout_implementation + << ", inplace: " << inplace; + Layers layers; + // (x, y) -> mul -> tmp_0 + // (tmp_0) -> dropout -> (tmp_1) + // (tmp_1, z) -> elementwise_add -> (tmp_2) + // or + // (tmp_1, z) -> elementwise_add -> (tmp_0) + auto* x = layers.data("x"); + auto* y = layers.data("y"); + auto* z = layers.data("z"); + auto* mul_out = layers.mul(x, y); + auto* dropout_out = layers.dropout(mul_out, 0.5f, dropout_implementation); + if (inplace) { + layers.elementwise_add(dropout_out, z, mul_out); + } else { + layers.elementwise_add(dropout_out, z); + } + + std::unique_ptr graph(new Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("delete_dropout_op_x_pass"); + int num_dropout_nodes_before = GetNumOpNodes(graph, "dropout"); + int num_scale_nodes_before = GetNumOpNodes(graph, "scale"); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_dropout_nodes_after = GetNumOpNodes(graph, "dropout"); + int num_scale_nodes_after = GetNumOpNodes(graph, "scale"); + + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ( + num_dropout_nodes_after, + 0, + platform::errors::InvalidArgument("num_dropout_nodes_after = %d.", + num_dropout_nodes_after)); + + if (dropout_implementation == "downgrade_in_infer") { + PADDLE_ENFORCE_EQ( + num_dropout_nodes_before, + num_scale_nodes_after - num_scale_nodes_before, + platform::errors::InvalidArgument( + "num_dropout_nodes_before = %d, num_scale_nodes_after = %d, " + "num_scale_nodes_before = %d.", + num_dropout_nodes_before, + num_scale_nodes_after, + num_scale_nodes_before)); + } else { + PADDLE_ENFORCE_EQ( + num_scale_nodes_after - num_scale_nodes_before, + 0, + platform::errors::InvalidArgument( + "num_scale_nodes_after = %d, num_scale_nodes_before = %d.", + num_scale_nodes_after, + num_scale_nodes_before)); + } + } + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(delete_dropout_op_x_pass); diff --git a/paddle/fluid/jit/engine/pe_engine.cc b/paddle/fluid/jit/engine/pe_engine.cc index a2c6d69d16e..35d7f87df74 100644 --- a/paddle/fluid/jit/engine/pe_engine.cc +++ b/paddle/fluid/jit/engine/pe_engine.cc @@ -74,6 +74,8 @@ PEEngine::PEEngine(const std::shared_ptr &info, void PEEngine::CreateGraphAndPE() { framework::details::BuildStrategy build_strategy; + build_strategy.inference_ = true; + build_strategy.del_dropout_ = true; auto execution_strategy = GetExecutionStrategy(place_); auto &program_desc = info_->ProgramDesc(); -- GitLab