Unverified commit 46bc06b5, authored by Hui Zhang, committed by GitHub

add delete dropout op pass to the jit PE engine (#45439)

* add delete dropout op pass to the jit PE engine

* add delete dropout test
Parent c2942144
......@@ -385,7 +385,8 @@ set(IR_PASS_DEPS
runtime_context_cache_pass
graph_to_program_pass
fix_op_run_order_pass
fuse_gemm_epilogue_pass)
fuse_gemm_epilogue_pass
delete_dropout_op_pass)
if(WITH_CINN)
set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass)
......
......@@ -169,6 +169,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
void AppendOpFusePasses() {
// 1. inference pass, if enabled.
AppendPassWithCheck(strategy_.inference_ && strategy_.del_dropout_,
"delete_dropout_op_x_pass");
// 2. training passes
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
......@@ -509,6 +514,7 @@ USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(add_reader_dependency_pass);
USE_PASS(delete_dropout_op_x_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
#endif
......
......@@ -147,6 +147,10 @@ struct BuildStrategy {
bool allow_cuda_graph_capture_{false};
// Inference pass
bool inference_{false}; // switch for inference passes
bool del_dropout_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy don't tell
// whether it's a distributed model.
......
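With del_dropout_ added next to inference_, turning the new pass on from C++ only takes setting both flags on the BuildStrategy before the graph is built; the PE engine change below does exactly this. A minimal sketch, assuming the usual Paddle build and executor setup around it (MakeInferenceBuildStrategy is a hypothetical helper used only for illustration):

#include "paddle/fluid/framework/details/build_strategy.h"

// Sketch: a strategy whose pass builder will append
// "delete_dropout_op_x_pass" in AppendOpFusePasses, since both switches are on.
paddle::framework::details::BuildStrategy MakeInferenceBuildStrategy() {
  paddle::framework::details::BuildStrategy build_strategy;
  build_strategy.inference_ = true;    // enable inference-only passes
  build_strategy.del_dropout_ = true;  // request dropout deletion
  return build_strategy;
}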
......@@ -369,6 +369,10 @@ cc_test(
test_generate_pass_cc
SRCS generate_pass_tester.cc
DEPS generate_pass pass_desc_proto)
cc_test(
test_delete_dropout_pass_cc
SRCS delete_dropout_op_pass_test.cc
DEPS delete_dropout_op_pass)
if(WITH_GPU OR WITH_ROCM)
cc_test(
test_embedding_eltwise_layernorm_fuse_pass
......
......@@ -15,6 +15,8 @@
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace phi {
class DenseTensor;
} // namespace phi
......@@ -47,6 +49,7 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
std::string any_op_out_name = any_op_out->Var()->Name();
std::string dropout_op_out_name = dropout_op_out->Var()->Name();
// any_op2
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
......@@ -80,6 +83,7 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
}
}
any_op2_desc->Flush();
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph,
{dropout_op, dropout_op_out, dropout_op_outmask});
......@@ -88,9 +92,197 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
gpd(graph, handler);
}
DeleteDropoutOpXPass::DeleteDropoutOpXPass() {
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsNumGE(0.f)
.IsNumLE(1.f)
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale")
.IsNumEQ(true)
.End();
}
void DeleteDropoutOpXPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "delte dropout op.";
std::unordered_set<const Node*> del_node_set;
for (Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()) {
if (n->Op()->Type() == "dropout") {
DelDropout(graph, n, &del_node_set);
}
}
}
GraphSafeRemoveNodes(graph, del_node_set);
}
bool DeleteDropoutOpXPass::DelDropout(
Graph* graph,
Node* n,
std::unordered_set<const Node*>* del_node_set) const {
OpDesc* dropout_op_desc = n->Op();
Node* dropout_x = GetInputVar(n, dropout_op_desc->Input("X")[0]);
Node* dropout_out = GetOutputVar(n, dropout_op_desc->Output("Out")[0]);
bool upscale_in_train = false;
// The dropout_implementation attribute used to be BOOLEAN but is now STRING,
// so handle both types.
if (dropout_op_desc->HasAttr("dropout_implementation")) {
if (dropout_op_desc->GetAttrType("dropout_implementation") ==
proto::AttrType::BOOLEAN) {
upscale_in_train = PADDLE_GET_CONST(
bool, dropout_op_desc->GetAttr("dropout_implementation"));
} else if (dropout_op_desc->GetAttrType("dropout_implementation") ==
proto::AttrType::STRING) {
upscale_in_train =
PADDLE_GET_CONST(std::string,
dropout_op_desc->GetAttr(
"dropout_implementation")) == "upscale_in_train";
}
}
VLOG(3) << "upscale_in_train: " << upscale_in_train;
if (upscale_in_train) {
// upscale_in_train: dropout is the identity at inference,
// so the dropout_op can be deleted.
// dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
// |
// \|/
// dropout_x -> next_op -> next_out
// Check whether dropout_x is some next_op's output
bool dropout_x_is_reused_as_output = false;
for (auto* next_op : dropout_out->outputs) {
for (auto* next_out : next_op->outputs) {
if (next_out == dropout_x ||
next_out->Var()->Name() == dropout_x->Var()->Name()) {
dropout_x_is_reused_as_output = true;
break;
}
}
if (dropout_x_is_reused_as_output) {
break;
}
}
if (dropout_x_is_reused_as_output) {
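// In the in-place case dropout_x is also the output of a downstream op.
// Removing dropout would leave that op reading and writing a variable with
// the same name, so first point the original producer and the other readers
// of dropout_x at a renamed copy.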
VarDesc new_var_desc(*dropout_x->Var());
new_var_desc.SetName("delete_dropout_x_pass_" + dropout_x->Name());
auto* new_var_node = graph->CreateVarNode(&new_var_desc);
for (auto* out_op : dropout_x->outputs) {
if (out_op != n) {
ReplaceInputVar(out_op, dropout_x, new_var_node);
}
}
for (auto* in_op : dropout_x->inputs) {
ReplaceOutputVar(in_op, dropout_x, new_var_node);
}
dropout_x = new_var_node;
}
for (auto* next_op : dropout_out->outputs) {
ReplaceInputVar(next_op, dropout_out, dropout_x);
}
del_node_set->insert(dropout_out);
} else {
// downgrade_in_infer: keep the scaling effect of dropout
// by replacing the dropout_op with a scale_op.
// dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
// |
// \|/
// dropout_x -> scale_op -> dropout_out -> next_op -> next_out
float scale = 1.0f - PADDLE_GET_CONST(
float, dropout_op_desc->GetAttr("dropout_prob"));
framework::OpDesc new_op_desc(dropout_op_desc->Block());
new_op_desc.SetType("scale");
new_op_desc.SetInput("X", {dropout_x->Name()});
new_op_desc.SetOutput("Out", {dropout_out->Name()});
new_op_desc.SetAttr("scale", scale);
new_op_desc.SetAttr("bias", static_cast<float>(0));
new_op_desc.SetAttr("bias_after_scale", true);
if (!IsCompat(new_op_desc)) {
LOG(WARNING) << "The scale op generated by delete_dropout_op_x_pass "
"failed the op compat check.";
return false;
}
auto* scale_op_node = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(dropout_x, scale_op_node);
IR_NODE_LINK_TO(scale_op_node, dropout_out);
}
del_node_set->insert(n);
return true;
}
Node* DeleteDropoutOpXPass::GetInputVar(Node* n,
const std::string& name) const {
for (auto* in : n->inputs) {
if (in->Name() == name) {
return in;
}
}
return nullptr;
}
Node* DeleteDropoutOpXPass::GetOutputVar(Node* n,
const std::string& name) const {
for (auto* out : n->outputs) {
if (out->Name() == name) {
return out;
}
}
return nullptr;
}
void DeleteDropoutOpXPass::ReplaceInputVar(Node* op,
Node* old_var,
Node* new_var) const {
if (op->IsOp() && op->Op()) {
new_var->outputs.push_back(op);
for (size_t i = 0; i < op->inputs.size(); ++i) {
if (op->inputs[i] == old_var) {
op->inputs[i] = new_var;
op->Op()->RenameInput(old_var->Name(), new_var->Name());
}
}
}
}
void DeleteDropoutOpXPass::ReplaceOutputVar(Node* op,
Node* old_var,
Node* new_var) const {
if (op->IsOp() && op->Op()) {
new_var->inputs.push_back(op);
for (size_t i = 0; i < op->outputs.size(); ++i) {
if (op->outputs[i] == old_var) {
op->outputs[i] = new_var;
op->Op()->RenameOutput(old_var->Name(), new_var->Name());
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_dropout_op_pass,
paddle::framework::ir::DeleteDropoutOpPass);
REGISTER_PASS(delete_dropout_op_x_pass,
paddle::framework::ir::DeleteDropoutOpXPass);
REGISTER_PASS_CAPABILITY(delete_dropout_op_x_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"scale", 0));
......@@ -13,10 +13,13 @@
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
namespace paddle {
namespace framework {
......@@ -32,6 +35,24 @@ class DeleteDropoutOpPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;
};
class DeleteDropoutOpXPass : public OpCompatSensiblePass {
public:
DeleteDropoutOpXPass();
virtual ~DeleteDropoutOpXPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
bool DelDropout(Graph* graph,
Node* n,
std::unordered_set<const Node*>* del_node_set) const;
Node* GetInputVar(Node* n, const std::string& name) const;
Node* GetOutputVar(Node* n, const std::string& name) const;
void ReplaceInputVar(Node* op, Node* old_var, Node* new_var) const;
void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(DeleteDropoutOpsPass, dropout) {
for (std::string dropout_implementation :
{"downgrade_in_infer", "upscale_in_train"}) {
for (auto inplace : {false, true}) {
if (dropout_implementation == "downgrade_in_infer" && inplace == true) {
continue;
}
LOG(INFO) << "dropout_implementation: " << dropout_implementation
<< ", inplace: " << inplace;
Layers layers;
// (x, y) -> mul -> tmp_0
// (tmp_0) -> dropout -> (tmp_1)
// (tmp_1, z) -> elementwise_add -> (tmp_2)
// or
// (tmp_1, z) -> elementwise_add -> (tmp_0)
auto* x = layers.data("x");
auto* y = layers.data("y");
auto* z = layers.data("z");
auto* mul_out = layers.mul(x, y);
auto* dropout_out = layers.dropout(mul_out, 0.5f, dropout_implementation);
if (inplace) {
layers.elementwise_add(dropout_out, z, mul_out);
} else {
layers.elementwise_add(dropout_out, z);
}
std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("delete_dropout_op_x_pass");
int num_dropout_nodes_before = GetNumOpNodes(graph, "dropout");
int num_scale_nodes_before = GetNumOpNodes(graph, "scale");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_dropout_nodes_after = GetNumOpNodes(graph, "dropout");
int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_dropout_nodes_after,
0,
platform::errors::InvalidArgument("num_dropout_nodes_after = %d.",
num_dropout_nodes_after));
if (dropout_implementation == "downgrade_in_infer") {
PADDLE_ENFORCE_EQ(
num_dropout_nodes_before,
num_scale_nodes_after - num_scale_nodes_before,
platform::errors::InvalidArgument(
"num_dropout_nodes_before = %d, num_scale_nodes_after = %d, "
"num_scale_nodes_before = %d.",
num_dropout_nodes_before,
num_scale_nodes_after,
num_scale_nodes_before));
} else {
PADDLE_ENFORCE_EQ(
num_scale_nodes_after - num_scale_nodes_before,
0,
platform::errors::InvalidArgument(
"num_scale_nodes_after = %d, num_scale_nodes_before = %d.",
num_scale_nodes_after,
num_scale_nodes_before));
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_dropout_op_x_pass);
......@@ -74,6 +74,8 @@ PEEngine::PEEngine(const std::shared_ptr<FunctionInfo> &info,
void PEEngine::CreateGraphAndPE() {
framework::details::BuildStrategy build_strategy;
build_strategy.inference_ = true;
build_strategy.del_dropout_ = true;
auto execution_strategy = GetExecutionStrategy(place_);
auto &program_desc = info_->ProgramDesc();
......
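With both switches set here, ParallelExecutorPassBuilder::AppendOpFusePasses (see the earlier hunk) appends delete_dropout_op_x_pass, so any dropout op in the jit-loaded inference program is removed, or rewritten into a scale op, before the PE runs the graph.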