From fcec365d29a69e81e95578e3720faaabccafbae7 Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Fri, 30 Aug 2019 10:02:10 +0800
Subject: [PATCH] Add a pass to replace dropout_op with scale_op when is_test
 is true (#19297)

* Add simplify_with_basic_ops_pass to replace dropout_op with scale_op when
  is_test is true. test=develop

* Delete dropout_op directly when upscale_in_train is true. test=develop

* Improve the debug string, adding the printing of op_desc information.

* Fix the case when dropout's input x is reused as the next op's output.

* Add the pass to inference. test=develop

* Change the log level. test=develop

* Add unittest for the inplace case.

* Add a comment to explain the pass.

* Apply the pass for CPU inference. test=develop

* Fix the typo. test=develop

* Add the check of AttrType. test=develop
---
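At inference time dropout is deterministic, so with dropout_prob = p the
rewrite preserves dropout's test-time semantics:

  upscale_in_train:   Out = X            (training already rescaled by 1 / (1 - p); the op is simply removed)
  downgrade_in_infer: Out = (1 - p) * X  (expressed as a scale_op with scale = 1 - p and bias = 0)

A minimal sketch of the function being preserved; this snippet is
illustrative only and not part of the patch:

// Test-time dropout semantics that the pass relies on.
float DropoutInference(float x, float dropout_prob, bool upscale_in_train) {
  // upscale_in_train rescales activations during training, so inference is
  // the identity; downgrade_in_infer keeps the expectation by scaling with
  // the keep probability (1 - dropout_prob).
  return upscale_in_train ? x : (1.0f - dropout_prob) * x;
}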
 paddle/fluid/framework/ir/CMakeLists.txt      |   2 +
 .../fluid/framework/ir/pass_tester_helper.h   | 206 ++++++++++++++++++
 .../ir/simplify_with_basic_ops_pass.cc        | 202 +++++++++++++++++
 .../ir/simplify_with_basic_ops_pass.h         |  42 ++++
 .../ir/simplify_with_basic_ops_pass_tester.cc |  78 +++++++
 .../inference/api/paddle_pass_builder.cc      |   7 +-
 6 files changed, 535 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/pass_tester_helper.h
 create mode 100644 paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc
 create mode 100644 paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h
 create mode 100644 paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 5c5eac54ce..48e95bc090 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -76,6 +76,7 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(fillconstant_elementwisemul_fuse inference)
 pass_library(shuffle_channel_detect_pass inference)
 pass_library(delete_quant_dequant_op_pass inference)
+pass_library(simplify_with_basic_ops_pass base)
 
 if(ANAKIN_SUBGRAPH)
 pass_library(simplify_anakin_priorbox_detection_out_pass inference)
@@ -119,6 +120,7 @@ cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framewor
 cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
 cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
 cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
+cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
 if(NOT WIN32)
 cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
 endif()
diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
new file mode 100644
index 0000000000..5564c01429
--- /dev/null
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -0,0 +1,206 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+struct Layers {
+ public:
+  const ProgramDesc& main_program() { return program_; }
+
+  VarDesc* data(std::string name) { return lod_tensor(name); }
+
+  VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
+    return binary_op("mul", x, y, out);
+  }
+
+  VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
+    return binary_op("elementwise_add", x, y, out);
+  }
+
+  VarDesc* dropout(VarDesc* x, float dropout_prob,
+                   std::string dropout_implementation) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("dropout");
+    op->SetInput("X", {x->Name()});
+    op->SetOutput("Out", {out->Name()});
+    op->SetAttr("is_test", true);
+    op->SetAttr("dropout_prob", dropout_prob);
+    op->SetAttr("dropout_implementation", dropout_implementation);
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+    return out;
+  }
+
+ private:
+  VarDesc* lod_tensor(std::string name) {
+    auto* var = program_.MutableBlock(0)->Var(name);
+    var->SetType(proto::VarType::LOD_TENSOR);
+    return var;
+  }
+
+  VarDesc* binary_op(std::string type, VarDesc* x, VarDesc* y,
+                     VarDesc* out = nullptr) {
+    if (!out) {
+      out = lod_tensor(unique_name());
+    }
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType(type);
+    op->SetInput("X", {x->Name()});
+    op->SetInput("Y", {y->Name()});
+    op->SetOutput("Out", {out->Name()});
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+    return out;
+  }
+
+  std::string unique_name() { return "tmp_" + std::to_string(idx_++); }
+
+ private:
+  ProgramDesc program_;
+  int idx_{0};
+};
+
+static std::string DebugString(OpDesc* op) {
+  std::ostringstream os;
+  os << "Op(" << op->Type() << "), inputs:{";
+  bool is_first = true;
+  for (auto& name : op->InputNames()) {
+    if (!is_first) {
+      os << ", ";
+    }
+    os << name << "[";
+    bool is_first_var_name = true;
+    for (auto& var_name : op->Input(name)) {
+      if (!is_first_var_name) {
+        os << ", ";
+      }
+      os << var_name;
+      is_first_var_name = false;
+    }
+    os << "]";
+    is_first = false;
+  }
+
+  os << "}, outputs:{";
+  is_first = true;
+  for (auto& name : op->OutputNames()) {
+    if (!is_first) {
+      os << ", ";
+    }
+    os << name << "[";
+    bool is_first_var_name = true;
+    for (auto& var_name : op->Output(name)) {
+      if (!is_first_var_name) {
+        os << ", ";
+      }
+      os << var_name;
+      is_first_var_name = false;
+    }
+    os << "]";
+    is_first = false;
+  }
+  os << "}";
+  return os.str();
+}
+
+static std::string DebugString(Node* node) {
+  std::ostringstream os;
+  if (node->IsOp() && node->Op()) {
+    OpDesc* op = node->Op();
+    os << "Node(" << DebugString(op) << "), inputs:{";
+    bool is_first = true;
+    for (auto* in : node->inputs) {
+      if (!is_first) {
+        os << ", ";
+      }
+      os << in->Name();
+      is_first = false;
+    }
+    os << "}, outputs:{";
+    is_first = true;
+    for (auto* out : node->outputs) {
+      if (!is_first) {
+        os << ", ";
+      }
+      os << out->Name();
+      is_first = false;
+    }
+    os << "}.";
+  } else if (node->IsVar() && node->Var()) {
+    os << "Node(" << node->Name() << "), inputs:{";
+    bool is_first = true;
+    for (auto* in : node->inputs) {
+      if (!is_first) {
+        os << ", ";
+      }
+      if (in->IsOp() && in->Op()) {
+        os << in->Op()->Type();
+      }
+      is_first = false;
+    }
+    os << "}, outputs:{";
+    is_first = true;
+    for (auto* out : node->outputs) {
+      if (!is_first) {
+        os << ", ";
+      }
+      if (out->IsOp() && out->Op()) {
+        os << out->Op()->Type();
+      }
+      is_first = false;
+    }
+    os << "}";
+  }
+  return os.str();
+}
+
+static std::string DebugString(const std::unique_ptr<Graph>& graph) {
+  std::ostringstream os;
+  os << "Graph: {\n";
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()) {
+      os << "  ";
+    } else if (node->IsVar() && node->Var()) {
+      os << "    ";
+    }
+    os << DebugString(node) << "\n";
+  }
+  os << "}\n";
+  return os.str();
+}
+
+static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
+                         std::string op_type) {
+  int num_nodes = 0;
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op() && node->Op()->Type() == op_type) {
+      num_nodes++;
+    }
+  }
+  return num_nodes;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
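The helpers above are meant to be combined as in the following sketch (the
function name is illustrative; the real usage is in
simplify_with_basic_ops_pass_tester.cc later in this patch):

// Build a tiny program, wrap it in a Graph, and inspect it.
void ExampleOfPassTesterHelpers() {
  Layers layers;
  auto* x = layers.data("x");
  auto* y = layers.data("y");
  layers.dropout(layers.mul(x, y), 0.5f, "downgrade_in_infer");
  std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
  LOG(INFO) << GetNumOpNodes(graph, "dropout");  // prints 1
  LOG(INFO) << DebugString(graph);  // dumps every op/var node with its edges
}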
<< "}, outputs:{"; + is_first = true; + for (auto* out : node->outputs) { + if (!is_first) { + os << ", "; + } + if (out->IsOp() && out->Op()) { + os << out->Op()->Type(); + } + is_first = false; + } + os << "}"; + } + return os.str(); +} + +static std::string DebugString(const std::unique_ptr& graph) { + std::ostringstream os; + os << "Graph: {\n"; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()) { + os << " "; + } else if (node->IsVar() && node->Var()) { + os << " "; + } + os << DebugString(node) << "\n"; + } + os << "}\n"; + return os.str(); +} + +static int GetNumOpNodes(const std::unique_ptr& graph, + std::string op_type) { + int num_nodes = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op() && node->Op()->Type() == op_type) { + num_nodes++; + } + } + return num_nodes; +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc new file mode 100644 index 0000000000..61784f8c66 --- /dev/null +++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc @@ -0,0 +1,202 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * This pass is to simplify the Grpah, it may contains: + * - replace comlicated op with basic op + * - remove some unnecessary op + * + * In the current implementation, it supports: + * - remove dropout_op (upscale_in_train) or + * replace dropout_op with scale_op (downgrade_in_infer) when is_test is true + */ +void SimplifyWithBasicOpsPass::ApplyImpl(Graph* graph) const { + VLOG(3) << "Simplify the Graph with basic ops."; + std::unordered_set del_node_set; + for (Node* n : graph->Nodes()) { + if (n->IsOp() && n->Op()) { + if (n->Op()->Type() == "dropout") { + SimplifyDropout(graph, n, &del_node_set); + } + } + } + + GraphSafeRemoveNodes(graph, del_node_set); +} + +bool SimplifyWithBasicOpsPass::SimplifyDropout( + Graph* graph, Node* n, + std::unordered_set* del_node_set) const { + OpDesc* dropout_op_desc = n->Op(); + bool is_test = false; + // In the model used in test_analyzer_bert, the is_test's AttrType of + // dropout_op is INT. + if (dropout_op_desc->HasAttr("is_test")) { + if (dropout_op_desc->GetAttrType("is_test") == proto::AttrType::BOOLEAN) { + is_test = boost::get(dropout_op_desc->GetAttr("is_test")); + } else if (dropout_op_desc->GetAttrType("is_test") == + proto::AttrType::INT) { + is_test = boost::get(dropout_op_desc->GetAttr("is_test")) == 0 + ? 
diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h
new file mode 100644
index 0000000000..f518562246
--- /dev/null
+++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <unordered_set>
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class SimplifyWithBasicOpsPass : public Pass {
+ protected:
+  void ApplyImpl(Graph* graph) const override;
+
+ private:
+  bool SimplifyDropout(Graph* graph, Node* n,
+                       std::unordered_set<const Node*>* del_node_set) const;
+
+  Node* GetInputVar(Node* n, const std::string& name) const;
+  Node* GetOutputVar(Node* n, const std::string& name) const;
+
+  void ReplaceInputVar(Node* op, Node* old_var, Node* new_var) const;
+  void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
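The REGISTER_PASS call at the end of the .cc file makes the pass retrievable
by name, so applying it takes two lines, as the tester below does:

// Sketch: fetch the registered pass and run it on an existing graph,
// where graph is a std::unique_ptr<Graph>.
auto pass = PassRegistry::Instance().Get("simplify_with_basic_ops_pass");
graph.reset(pass->Apply(graph.release()));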
diff --git a/paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc
new file mode 100644
index 0000000000..5c23d67cf1
--- /dev/null
+++ b/paddle/fluid/framework/ir/simplify_with_basic_ops_pass_tester.cc
@@ -0,0 +1,78 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+TEST(SimplifyWithBasicOpsPass, dropout) {
+  for (std::string dropout_implementation :
+       {"downgrade_in_infer", "upscale_in_train"}) {
+    for (auto inplace : {false, true}) {
+      if (dropout_implementation == "downgrade_in_infer" && inplace == true) {
+        continue;
+      }
+
+      LOG(INFO) << "dropout_implementation: " << dropout_implementation
+                << ", inplace: " << inplace;
+      Layers layers;
+      // (x, y) -> mul -> tmp_0
+      // (tmp_0) -> dropout -> (tmp_1)
+      // (tmp_1, z) -> elementwise_add -> (tmp_2)
+      // or
+      // (tmp_1, z) -> elementwise_add -> (tmp_0)
+      auto* x = layers.data("x");
+      auto* y = layers.data("y");
+      auto* z = layers.data("z");
+      auto* mul_out = layers.mul(x, y);
+      auto* dropout_out =
+          layers.dropout(mul_out, 0.5f, dropout_implementation);
+      if (inplace) {
+        layers.elementwise_add(dropout_out, z, mul_out);
+      } else {
+        layers.elementwise_add(dropout_out, z);
+      }
+
+      std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
+      auto pass = PassRegistry::Instance().Get("simplify_with_basic_ops_pass");
+      int num_dropout_nodes_before = GetNumOpNodes(graph, "dropout");
+      int num_scale_nodes_before = GetNumOpNodes(graph, "scale");
+      VLOG(3) << DebugString(graph);
+
+      graph.reset(pass->Apply(graph.release()));
+      int num_dropout_nodes_after = GetNumOpNodes(graph, "dropout");
+      int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
+      VLOG(3) << DebugString(graph);
+
+      PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0UL);
+      if (dropout_implementation == "downgrade_in_infer") {
+        PADDLE_ENFORCE_EQ(num_dropout_nodes_before,
+                          num_scale_nodes_after - num_scale_nodes_before);
+      } else {
+        PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0UL);
+      }
+    }
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(simplify_with_basic_ops_pass);
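On the pass-builder changes below: in the GPU strategy the pass is inserted
right after is_test_pass, which sets the is_test attribute to true for
inference, so the dropout simplification is guaranteed to fire; in the CPU
strategy it is placed at the very front so that dropout ops are already gone
before the larger fusion passes run.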
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index f48c280087..6deeb6fc89 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -104,7 +104,9 @@ const std::vector<std::string> kAnakinSubgraphPasses({
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       //   "identity_scale_op_clean_pass",           //
-      "conv_affine_channel_fuse_pass",               //
+      "is_test_pass",                                //
+      "simplify_with_basic_ops_pass",                //
+      "conv_affine_channel_fuse_pass",               //
       "conv_eltwiseadd_affine_channel_fuse_pass",    //
       "conv_bn_fuse_pass",                           //
       "conv_eltwiseadd_bn_fuse_pass",                //
@@ -138,7 +140,8 @@ void GpuPassStrategy::EnableNgraph() {
 CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
   // NOTE the large fusions should be located in the front, so that they will
   // not be damaged by smaller ones.
-  passes_.assign({"attention_lstm_fuse_pass",       //
+  passes_.assign({"simplify_with_basic_ops_pass",   //
+                  "attention_lstm_fuse_pass",       //
                   "seqconv_eltadd_relu_fuse_pass",  //
                   // "seqpool_concat_fuse_pass",    //
                   "seqpool_cvm_concat_fuse_pass",   //
-- 
GitLab