Unverified commit 46bc06b5, authored by Hui Zhang, committed by GitHub

add del dropout op pass to jit pe engine (#45439)

* add del dropout op pass to jit pe engine

* add delete dropout test
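
For context: at inference time, dropout with dropout_implementation == "downgrade_in_infer" scales its input by (1 - dropout_prob), while "upscale_in_train" makes it the identity, since the 1/(1 - p) rescaling already happened during training. The new pass relies on exactly this: the first case is rewritten into a scale op, the second is deleted outright. A minimal standalone sketch of that arithmetic (illustrative names, not Paddle API):

#include <cassert>

// Inference-time dropout semantics that the pass below relies on.
float dropout_infer(float x, float dropout_prob, bool upscale_in_train) {
  // upscale_in_train: training already divided activations by (1 - p),
  // so inference is the identity and the op can be removed from the graph.
  if (upscale_in_train) return x;
  // downgrade_in_infer: inference scales by (1 - p), so the op can be
  // replaced by scale(scale = 1 - p, bias = 0).
  return (1.0f - dropout_prob) * x;
}

int main() {
  assert(dropout_infer(2.0f, 0.5f, true) == 2.0f);   // identity
  assert(dropout_infer(2.0f, 0.5f, false) == 1.0f);  // 0.5 * 2.0
  return 0;
}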
Parent c2942144
@@ -385,7 +385,8 @@ set(IR_PASS_DEPS
runtime_context_cache_pass
graph_to_program_pass
fix_op_run_order_pass
fuse_gemm_epilogue_pass
delete_dropout_op_pass)
if(WITH_CINN)
set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass)
......
@@ -169,6 +169,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
void AppendOpFusePasses() {
// 1. inference pass, if enabled.
AppendPassWithCheck(strategy_.inference_ && strategy_.del_dropout_,
"delete_dropout_op_x_pass");
// 2. training passes
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
@@ -509,6 +514,7 @@ USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(add_reader_dependency_pass);
USE_PASS(delete_dropout_op_x_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
#endif
......
@@ -147,6 +147,10 @@ struct BuildStrategy {
bool allow_cuda_graph_capture_{false};
// Inference passes
bool inference_{false};  // switch for inference-only passes
bool del_dropout_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy don't tell
// whether it's a distributed model.
......
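
The two flags above are meant to be set together by inference consumers. A minimal sketch of the intended usage, mirroring the jit PE engine change at the end of this diff (assumes the Paddle build tree; not a standalone program):

// Enable the inference-only dropout deletion when building a graph.
framework::details::BuildStrategy build_strategy;
build_strategy.inference_ = true;    // opt in to inference-only passes
build_strategy.del_dropout_ = true;  // append delete_dropout_op_x_pass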
@@ -369,6 +369,10 @@ cc_test(
test_generate_pass_cc
SRCS generate_pass_tester.cc
DEPS generate_pass pass_desc_proto)
cc_test(
test_delete_dropout_pass_cc
SRCS delete_dropout_op_pass_test.cc
DEPS delete_dropout_op_pass)
if(WITH_GPU OR WITH_ROCM)
cc_test(
test_embedding_eltwise_layernorm_fuse_pass
......
@@ -15,6 +15,8 @@
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace phi {
class DenseTensor;
}  // namespace phi
@@ -47,6 +49,7 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
std::string any_op_out_name = any_op_out->Var()->Name();
std::string dropout_op_out_name = dropout_op_out->Var()->Name();
// any_op2
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
@@ -80,6 +83,7 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
}
}
any_op2_desc->Flush();
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph,
{dropout_op, dropout_op_out, dropout_op_outmask});
@@ -88,9 +92,197 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
gpd(graph, handler);
}
DeleteDropoutOpXPass::DeleteDropoutOpXPass() {
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsNumGE(0.f)
.IsNumLE(1.f)
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale")
.IsNumEQ(true)
.End();
}
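// Note: the compat check above pins down the replacement op's semantics.
// Paddle's scale op computes out = scale * x + bias (bias applied after
// scaling when bias_after_scale is true), so scale in [0, 1], bias == 0,
// and bias_after_scale == true together guarantee the rewrite below yields
// out = (1 - dropout_prob) * x, matching dropout's downgrade_in_infer
// behavior at inference time.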
void DeleteDropoutOpXPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "delete dropout op.";
std::unordered_set<const Node*> del_node_set;
for (Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()) {
if (n->Op()->Type() == "dropout") {
DelDropout(graph, n, &del_node_set);
}
}
}
GraphSafeRemoveNodes(graph, del_node_set);
}
bool DeleteDropoutOpXPass::DelDropout(
Graph* graph,
Node* n,
std::unordered_set<const Node*>* del_node_set) const {
OpDesc* dropout_op_desc = n->Op();
Node* dropout_x = GetInputVar(n, dropout_op_desc->Input("X")[0]);
Node* dropout_out = GetOutputVar(n, dropout_op_desc->Output("Out")[0]);
bool upscale_in_train = false;
// The dropout_implementation attribute was originally BOOLEAN but is now
// STRING, so handle both representations.
if (dropout_op_desc->HasAttr("dropout_implementation")) {
if (dropout_op_desc->GetAttrType("dropout_implementation") ==
proto::AttrType::BOOLEAN) {
upscale_in_train = PADDLE_GET_CONST(
bool, dropout_op_desc->GetAttr("dropout_implementation"));
} else if (dropout_op_desc->GetAttrType("dropout_implementation") ==
proto::AttrType::STRING) {
upscale_in_train =
PADDLE_GET_CONST(std::string,
dropout_op_desc->GetAttr(
"dropout_implementation")) == "upscale_in_train";
}
}
VLOG(3) << "upscale_in_train: " << upscale_in_train;
if (upscale_in_train) {
// upscale_in_train: dropout is an identity at inference time, so
// dropout_op can be deleted.
// dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
// |
// \|/
// dropout_x -> next_op -> next_out
// Check whether dropout_x is some next_op's output
bool dropout_x_is_reused_as_output = false;
for (auto* next_op : dropout_out->outputs) {
for (auto* next_out : next_op->outputs) {
if (next_out == dropout_x ||
next_out->Var()->Name() == dropout_x->Var()->Name()) {
dropout_x_is_reused_as_output = true;
break;
}
}
if (dropout_x_is_reused_as_output) {
break;
}
}
if (dropout_x_is_reused_as_output) {
VarDesc new_var_desc(*dropout_x->Var());
new_var_desc.SetName("delete_dropout_x_pass_" + dropout_x->Name());
auto* new_var_node = graph->CreateVarNode(&new_var_desc);
for (auto* out_op : dropout_x->outputs) {
if (out_op != n) {
ReplaceInputVar(out_op, dropout_x, new_var_node);
}
}
for (auto* in_op : dropout_x->inputs) {
ReplaceOutputVar(in_op, dropout_x, new_var_node);
}
dropout_x = new_var_node;
}
for (auto* next_op : dropout_out->outputs) {
ReplaceInputVar(next_op, dropout_out, dropout_x);
}
del_node_set->insert(dropout_out);
} else {
// downgrade_in_infer: outputs must be scaled by (1 - dropout_prob) at
// inference time, so replace the dropout_op with a scale_op.
// dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
// |
// \|/
// dropout_x -> scale_op -> dropout_out -> next_op -> next_out
float scale = 1.0f - PADDLE_GET_CONST(
float, dropout_op_desc->GetAttr("dropout_prob"));
framework::OpDesc new_op_desc(dropout_op_desc->Block());
new_op_desc.SetType("scale");
new_op_desc.SetInput("X", {dropout_x->Name()});
new_op_desc.SetOutput("Out", {dropout_out->Name()});
new_op_desc.SetAttr("scale", scale);
new_op_desc.SetAttr("bias", static_cast<float>(0));
new_op_desc.SetAttr("bias_after_scale", true);
if (!IsCompat(new_op_desc)) {
LOG(WARNING) << "delete_dropout_op_x_pass: the generated scale op "
"failed the op compat check.";
return false;
}
auto* scale_op_node = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(dropout_x, scale_op_node);
IR_NODE_LINK_TO(scale_op_node, dropout_out);
}
del_node_set->insert(n);
return true;
}
Node* DeleteDropoutOpXPass::GetInputVar(Node* n,
const std::string& name) const {
for (auto* in : n->inputs) {
if (in->Name() == name) {
return in;
}
}
return nullptr;
}
Node* DeleteDropoutOpXPass::GetOutputVar(Node* n,
const std::string& name) const {
for (auto* out : n->outputs) {
if (out->Name() == name) {
return out;
}
}
return nullptr;
}
void DeleteDropoutOpXPass::ReplaceInputVar(Node* op,
Node* old_var,
Node* new_var) const {
if (op->IsOp() && op->Op()) {
new_var->outputs.push_back(op);
for (size_t i = 0; i < op->inputs.size(); ++i) {
if (op->inputs[i] == old_var) {
op->inputs[i] = new_var;
op->Op()->RenameInput(old_var->Name(), new_var->Name());
}
}
}
}
void DeleteDropoutOpXPass::ReplaceOutputVar(Node* op,
Node* old_var,
Node* new_var) const {
if (op->IsOp() && op->Op()) {
new_var->inputs.push_back(op);
for (size_t i = 0; i < op->outputs.size(); ++i) {
if (op->outputs[i] == old_var) {
op->outputs[i] = new_var;
op->Op()->RenameOutput(old_var->Name(), new_var->Name());
}
}
}
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
REGISTER_PASS(delete_dropout_op_pass,
paddle::framework::ir::DeleteDropoutOpPass);
REGISTER_PASS(delete_dropout_op_x_pass,
paddle::framework::ir::DeleteDropoutOpXPass);
REGISTER_PASS_CAPABILITY(delete_dropout_op_x_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"scale", 0));
@@ -13,10 +13,13 @@
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
namespace paddle {
namespace framework {
@@ -32,6 +35,24 @@ class DeleteDropoutOpPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;
};
class DeleteDropoutOpXPass : public OpCompatSensiblePass {
public:
DeleteDropoutOpXPass();
virtual ~DeleteDropoutOpXPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
bool DelDropout(Graph* graph,
Node* n,
std::unordered_set<const Node*>* del_node_set) const;
Node* GetInputVar(Node* n, const std::string& name) const;
Node* GetOutputVar(Node* n, const std::string& name) const;
void ReplaceInputVar(Node* op, Node* old_var, Node* new_var) const;
void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) const;
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(DeleteDropoutOpsPass, dropout) {
for (std::string dropout_implementation :
{"downgrade_in_infer", "upscale_in_train"}) {
for (auto inplace : {false, true}) {
if (dropout_implementation == "downgrade_in_infer" && inplace == true) {
continue;
}
LOG(INFO) << "dropout_implementation: " << dropout_implementation
<< ", inplace: " << inplace;
Layers layers;
// (x, y) -> mul -> tmp_0
// (tmp_0) -> dropout -> (tmp_1)
// (tmp_1, z) -> elementwise_add -> (tmp_2)
// or
// (tmp_1, z) -> elementwise_add -> (tmp_0)
auto* x = layers.data("x");
auto* y = layers.data("y");
auto* z = layers.data("z");
auto* mul_out = layers.mul(x, y);
auto* dropout_out = layers.dropout(mul_out, 0.5f, dropout_implementation);
if (inplace) {
layers.elementwise_add(dropout_out, z, mul_out);
} else {
layers.elementwise_add(dropout_out, z);
}
std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("delete_dropout_op_x_pass");
int num_dropout_nodes_before = GetNumOpNodes(graph, "dropout");
int num_scale_nodes_before = GetNumOpNodes(graph, "scale");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_dropout_nodes_after = GetNumOpNodes(graph, "dropout");
int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_dropout_nodes_after,
0,
platform::errors::InvalidArgument("num_dropout_nodes_after = %d.",
num_dropout_nodes_after));
if (dropout_implementation == "downgrade_in_infer") {
PADDLE_ENFORCE_EQ(
num_dropout_nodes_before,
num_scale_nodes_after - num_scale_nodes_before,
platform::errors::InvalidArgument(
"num_dropout_nodes_before = %d, num_scale_nodes_after = %d, "
"num_scale_nodes_before = %d.",
num_dropout_nodes_before,
num_scale_nodes_after,
num_scale_nodes_before));
} else {
PADDLE_ENFORCE_EQ(
num_scale_nodes_after - num_scale_nodes_before,
0,
platform::errors::InvalidArgument(
"num_scale_nodes_after = %d, num_scale_nodes_before = %d.",
num_scale_nodes_after,
num_scale_nodes_before));
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_dropout_op_x_pass);
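
In summary, the test exercises both rewrite modes: after the pass runs, no dropout ops remain in either case; with downgrade_in_infer each deleted dropout is replaced by exactly one new scale op, while with upscale_in_train the scale count is unchanged because the dropout is removed outright.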
@@ -74,6 +74,8 @@ PEEngine::PEEngine(const std::shared_ptr<FunctionInfo> &info,
void PEEngine::CreateGraphAndPE() {
framework::details::BuildStrategy build_strategy;
build_strategy.inference_ = true;
build_strategy.del_dropout_ = true;
auto execution_strategy = GetExecutionStrategy(place_);
auto &program_desc = info_->ProgramDesc();
......