未验证 提交 fcec365d 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add a pass to replace dropout_op with scale_op when is_test is true (#19297)

* Add simplify_with_basic_ops_pass to replace dropout_op with scale_op when is_test is true.
test=develop

* Delete dropout_op directly when upscale_in_train is true.
test=develop

* Improve the debug string, adding the print of op_desc information.

* Fix the case when dropout's input x is reused as the next op's output.

* Add the pass to inference.
test=develop

* Change the log level.
test=develop

* Add unittest for inplace case.

* Add comment to explain the pass.

* Apply the pass for CPU inference.
test=develop

* Fix the typo.
test=develop

* Add the check of AttrType.
test=develop
上级 e1695388
...@@ -76,6 +76,7 @@ pass_library(quant_conv2d_dequant_fuse_pass inference) ...@@ -76,6 +76,7 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference) pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference) pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
if(ANAKIN_SUBGRAPH) if(ANAKIN_SUBGRAPH)
pass_library(simplify_anakin_priorbox_detection_out_pass inference) pass_library(simplify_anakin_priorbox_detection_out_pass inference)
...@@ -119,6 +120,7 @@ cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framewor ...@@ -119,6 +120,7 @@ cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framewor
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto) cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
if(NOT WIN32) if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif() endif()
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
// Minimal program builder for pass unit tests. Each helper appends a variable
// or an operator to block 0 of an internally owned ProgramDesc and hands back
// the VarDesc of the produced output so that calls can be chained.
struct Layers {
 public:
  // Gives read access to the program assembled so far.
  const ProgramDesc& main_program() { return program_; }

  // Declares a LOD_TENSOR variable named `name` and returns its descriptor.
  VarDesc* data(std::string name) { return lod_tensor(name); }

  // Appends a "mul" op with inputs X = x, Y = y. When `out` is null a fresh
  // temporary variable is created to hold the result.
  VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
    return binary_op("mul", x, y, out);
  }

  // Appends an "elementwise_add" op with inputs X = x, Y = y. When `out` is
  // null a fresh temporary variable is created to hold the result.
  VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
    return binary_op("elementwise_add", x, y, out);
  }

  // Appends a "dropout" op reading `x` and writing a fresh temporary.
  // The op is always created with is_test = true; `dropout_prob` and
  // `dropout_implementation` are stored verbatim as attributes.
  VarDesc* dropout(VarDesc* x, float dropout_prob,
                   std::string dropout_implementation) {
    VarDesc* result = lod_tensor(unique_name());
    OpDesc* desc = program_.MutableBlock(0)->AppendOp();
    desc->SetType("dropout");
    desc->SetInput("X", {x->Name()});
    desc->SetOutput("Out", {result->Name()});
    desc->SetAttr("is_test", true);
    desc->SetAttr("dropout_prob", dropout_prob);
    desc->SetAttr("dropout_implementation", dropout_implementation);
    desc->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                  static_cast<int>(OpRole::kForward));
    return result;
  }

 private:
  // Creates (or fetches) variable `name` in block 0 and types it LOD_TENSOR.
  VarDesc* lod_tensor(std::string name) {
    auto* created = program_.MutableBlock(0)->Var(name);
    created->SetType(proto::VarType::LOD_TENSOR);
    return created;
  }

  // Shared implementation of the two-input ops: appends an op of `type` with
  // inputs X/Y and output Out, tagged with the forward op role.
  VarDesc* binary_op(std::string type, VarDesc* x, VarDesc* y,
                     VarDesc* out = nullptr) {
    if (out == nullptr) {
      out = lod_tensor(unique_name());
    }
    OpDesc* desc = program_.MutableBlock(0)->AppendOp();
    desc->SetType(type);
    desc->SetInput("X", {x->Name()});
    desc->SetInput("Y", {y->Name()});
    desc->SetOutput("Out", {out->Name()});
    desc->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                  static_cast<int>(OpRole::kForward));
    return out;
  }

  // Produces "tmp_0", "tmp_1", ... for automatically created outputs.
  std::string unique_name() { return "tmp_" + std::to_string(idx_++); }

  ProgramDesc program_;  // program under construction
  int idx_{0};           // next suffix handed out by unique_name()
};
// Renders an OpDesc as
//   Op(<type>), inputs:{slot[arg, ...], ...}, outputs:{slot[arg, ...], ...}
// where every slot lists the variable names bound to it.
static std::string DebugString(OpDesc* op) {
  std::ostringstream os;
  os << "Op(" << op->Type() << "), inputs:{";

  // `slot_sep` / `arg_sep` start empty and become ", " after the first item,
  // so separators only appear between entries.
  const char* slot_sep = "";
  for (const auto& slot : op->InputNames()) {
    os << slot_sep << slot << "[";
    const char* arg_sep = "";
    for (const auto& arg : op->Input(slot)) {
      os << arg_sep << arg;
      arg_sep = ", ";
    }
    os << "]";
    slot_sep = ", ";
  }

  os << "}, outputs:{";
  slot_sep = "";
  for (const auto& slot : op->OutputNames()) {
    os << slot_sep << slot << "[";
    const char* arg_sep = "";
    for (const auto& arg : op->Output(slot)) {
      os << arg_sep << arg;
      arg_sep = ", ";
    }
    os << "]";
    slot_sep = ", ";
  }

  os << "}";
  return os.str();
}
// Renders a graph node together with its neighbours.
//  - Op node with an OpDesc: the op's own debug string followed by the names
//    of its input and output nodes, terminated by "}.".
//  - Var node with a VarDesc: the variable name followed by the types of the
//    ops that feed / consume it. A neighbour that is not an op with an OpDesc
//    contributes an empty entry (its separator is still written).
//  - Anything else: the empty string.
static std::string DebugString(Node* node) {
  std::ostringstream os;

  if (node->IsOp() && node->Op()) {
    OpDesc* desc = node->Op();
    os << "Node(" << DebugString(desc) << "), inputs:{";
    const char* sep = "";
    for (auto* neighbour : node->inputs) {
      os << sep << neighbour->Name();
      sep = ", ";
    }
    os << "}, outputs:{";
    sep = "";
    for (auto* neighbour : node->outputs) {
      os << sep << neighbour->Name();
      sep = ", ";
    }
    os << "}.";
    return os.str();
  }

  if (node->IsVar() && node->Var()) {
    os << "Node(" << node->Name() << "), inputs:{";
    const char* sep = "";
    for (auto* neighbour : node->inputs) {
      os << sep;
      if (neighbour->IsOp() && neighbour->Op()) {
        os << neighbour->Op()->Type();
      }
      sep = ", ";
    }
    os << "}, outputs:{";
    sep = "";
    for (auto* neighbour : node->outputs) {
      os << sep;
      if (neighbour->IsOp() && neighbour->Op()) {
        os << neighbour->Op()->Type();
      }
      sep = ", ";
    }
    os << "}";
  }

  return os.str();
}
// Dumps every node of `graph`, one per line, wrapped in "Graph: {" ... "}".
// Nodes that carry a descriptor (op or var) are prefixed before their text;
// nodes without one produce an empty, unprefixed line.
static std::string DebugString(const std::unique_ptr<Graph>& graph) {
  std::ostringstream os;
  os << "Graph: {\n";
  for (auto* current : graph->Nodes()) {
    const bool described_op = current->IsOp() && current->Op();
    if (described_op) {
      os << " ";
    } else if (current->IsVar() && current->Var()) {
      os << " ";
    }
    os << DebugString(current);
    os << "\n";
  }
  os << "}\n";
  return os.str();
}
// Counts the op nodes in `graph` whose OpDesc type equals `op_type`.
// Nodes that are not ops, or ops without an OpDesc, are ignored.
static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
                         std::string op_type) {
  int matches = 0;
  for (auto* candidate : graph->Nodes()) {
    if (!candidate->IsOp() || !candidate->Op()) {
      continue;
    }
    if (candidate->Op()->Type() == op_type) {
      ++matches;
    }
  }
  return matches;
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
/*
 * This pass simplifies the Graph. It may:
 *   - replace a complicated op with a basic op
 *   - remove some unnecessary ops
 *
 * In the current implementation, it supports:
 *   - removing dropout_op (upscale_in_train), or
 *     replacing dropout_op with scale_op (downgrade_in_infer),
 *     when is_test is true
 */
// Entry point of the pass: rewrites every dropout op node of `graph` via
// SimplifyDropout and then removes all nodes that were scheduled for deletion.
void SimplifyWithBasicOpsPass::ApplyImpl(Graph* graph) const {
  VLOG(3) << "Simplify the Graph with basic ops.";

  // Iterate over a copy of the node collection. SimplifyDropout may call
  // graph->CreateVarNode / graph->CreateOpNode; if Nodes() exposes the graph's
  // live container (NOTE(review): presumably it does - confirm against
  // Graph::Nodes()), inserting into it during a range-for could invalidate
  // the iteration. The nodes created by the rewrite are never dropout ops, so
  // skipping them does not change the result.
  const auto nodes_snapshot = graph->Nodes();

  std::unordered_set<const Node*> del_node_set;
  for (Node* n : nodes_snapshot) {
    if (n->IsOp() && n->Op() && n->Op()->Type() == "dropout") {
      SimplifyDropout(graph, n, &del_node_set);
    }
  }

  // Deletion is deferred until all rewrites are done so that no node is freed
  // while it may still be referenced by a later rewrite.
  GraphSafeRemoveNodes(graph, del_node_set);
}
// Rewrites a single dropout op node `n` when its is_test attribute is set.
//
//  - dropout_implementation == "upscale_in_train" (or a true BOOLEAN attr):
//    the op is removed and its consumers read the dropout input directly.
//  - otherwise: the op is replaced by a scale op with
//    scale = 1 - dropout_prob, bias = 0.
//
// Nodes that become obsolete are added to `del_node_set`; the caller is
// responsible for actually removing them from `graph`.
//
// Returns true when the op was rewritten, false when it was left untouched
// (is_test not set, or the op's X / Out variable nodes cannot be located).
bool SimplifyWithBasicOpsPass::SimplifyDropout(
    Graph* graph, Node* n,
    std::unordered_set<const Node*>* del_node_set) const {
  OpDesc* dropout_op_desc = n->Op();
  bool is_test = false;
  // In the model used in test_analyzer_bert, the is_test's AttrType of
  // dropout_op is INT.
  if (dropout_op_desc->HasAttr("is_test")) {
    if (dropout_op_desc->GetAttrType("is_test") == proto::AttrType::BOOLEAN) {
      is_test = boost::get<bool>(dropout_op_desc->GetAttr("is_test"));
    } else if (dropout_op_desc->GetAttrType("is_test") ==
               proto::AttrType::INT) {
      is_test = boost::get<int>(dropout_op_desc->GetAttr("is_test")) != 0;
    }
  }
  if (!is_test) {
    return false;
  }

  // Locate the variable nodes bound to X and Out. GetInputVar / GetOutputVar
  // return nullptr when no linked node carries the requested name, so bail
  // out before anything is modified instead of dereferencing a null pointer.
  const auto& x_names = dropout_op_desc->Input("X");
  const auto& out_names = dropout_op_desc->Output("Out");
  if (x_names.empty() || out_names.empty()) {
    VLOG(3) << "Skip dropout op: its X or Out argument list is empty.";
    return false;
  }
  Node* dropout_x = GetInputVar(n, x_names[0]);
  Node* dropout_out = GetOutputVar(n, out_names[0]);
  if (dropout_x == nullptr || dropout_out == nullptr) {
    VLOG(3) << "Skip dropout op: its X or Out variable node is not linked.";
    return false;
  }

  bool upscale_in_train = false;
  // Once the dropout_implementation's AttrType is BOOLEAN, but now is STRING.
  if (dropout_op_desc->HasAttr("dropout_implementation")) {
    if (dropout_op_desc->GetAttrType("dropout_implementation") ==
        proto::AttrType::BOOLEAN) {
      upscale_in_train =
          boost::get<bool>(dropout_op_desc->GetAttr("dropout_implementation"));
    } else if (dropout_op_desc->GetAttrType("dropout_implementation") ==
               proto::AttrType::STRING) {
      upscale_in_train = boost::get<std::string>(dropout_op_desc->GetAttr(
                             "dropout_implementation")) == "upscale_in_train";
    }
  }

  if (upscale_in_train) {
    // dropout_op can be deleted.
    // dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
    //                     |
    //                    \|/
    // dropout_x -> next_op -> next_out

    // Check whether dropout_x is also written by one of the consumers of
    // dropout_out (in-place reuse, matched by node identity or by name).
    // Output nodes without a VarDesc are skipped rather than dereferenced.
    bool dropout_x_is_reused_as_output = false;
    for (auto* next_op : dropout_out->outputs) {
      for (auto* next_out : next_op->outputs) {
        if (next_out == dropout_x ||
            (next_out->IsVar() && next_out->Var() &&
             next_out->Var()->Name() == dropout_x->Var()->Name())) {
          dropout_x_is_reused_as_output = true;
          break;
        }
      }
      if (dropout_x_is_reused_as_output) {
        break;
      }
    }

    if (dropout_x_is_reused_as_output) {
      // Feeding dropout_x straight into a consumer that also writes a
      // variable of the same name would make that op read and write one
      // name. Give the value a new name and re-point its producers and its
      // other consumers to the renamed variable first.
      VarDesc new_var_desc(*dropout_x->Var());
      new_var_desc.SetName("simplify_with_basic_ops_" + dropout_x->Name());
      auto* new_var_node = graph->CreateVarNode(&new_var_desc);
      for (auto* out_op : dropout_x->outputs) {
        if (out_op != n) {
          ReplaceInputVar(out_op, dropout_x, new_var_node);
        }
      }
      for (auto* in_op : dropout_x->inputs) {
        ReplaceOutputVar(in_op, dropout_x, new_var_node);
      }
      dropout_x = new_var_node;
    }

    // Consumers of dropout_out now read dropout_x directly.
    for (auto* next_op : dropout_out->outputs) {
      ReplaceInputVar(next_op, dropout_out, dropout_x);
    }
    del_node_set->insert(dropout_out);
  } else {
    // Replace the dropout_op with a scale_op.
    // dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
    //                     |
    //                    \|/
    // dropout_x -> scale_op -> dropout_out -> next_op -> next_out
    float scale =
        1.0f - boost::get<float>(dropout_op_desc->GetAttr("dropout_prob"));

    framework::OpDesc new_op_desc;
    new_op_desc.SetType("scale");
    new_op_desc.SetInput("X", {dropout_x->Name()});
    new_op_desc.SetOutput("Out", {dropout_out->Name()});
    new_op_desc.SetAttr("scale", scale);
    new_op_desc.SetAttr("bias", static_cast<float>(0));
    new_op_desc.SetAttr("bias_after_scale", true);

    auto* scale_op_node = graph->CreateOpNode(&new_op_desc);
    IR_NODE_LINK_TO(dropout_x, scale_op_node);
    IR_NODE_LINK_TO(scale_op_node, dropout_out);
  }

  // In both cases the dropout op node itself is no longer needed.
  del_node_set->insert(n);
  return true;
}
// Returns the first node among n->inputs whose Name() equals `name`, or
// nullptr when there is none.
Node* SimplifyWithBasicOpsPass::GetInputVar(Node* n,
                                            const std::string& name) const {
  Node* match = nullptr;
  for (size_t i = 0; i < n->inputs.size() && match == nullptr; ++i) {
    Node* candidate = n->inputs[i];
    if (candidate->Name() == name) {
      match = candidate;
    }
  }
  return match;
}
// Returns the first node among n->outputs whose Name() equals `name`, or
// nullptr when there is none.
Node* SimplifyWithBasicOpsPass::GetOutputVar(Node* n,
                                             const std::string& name) const {
  Node* match = nullptr;
  for (size_t i = 0; i < n->outputs.size() && match == nullptr; ++i) {
    Node* candidate = n->outputs[i];
    if (candidate->Name() == name) {
      match = candidate;
    }
  }
  return match;
}
// Makes `op` read `new_var` wherever it currently reads `old_var`.
// `op` is registered as a consumer of `new_var`, every matching entry of
// op->inputs is swapped, and the input is renamed in the OpDesc as well.
// Nothing happens when `op` is not an op node with an OpDesc.
// Note: old_var->outputs is left as it is.
void SimplifyWithBasicOpsPass::ReplaceInputVar(Node* op, Node* old_var,
                                               Node* new_var) const {
  if (!op->IsOp() || !op->Op()) {
    return;
  }
  new_var->outputs.push_back(op);
  for (auto& input_slot : op->inputs) {
    if (input_slot != old_var) {
      continue;
    }
    input_slot = new_var;
    op->Op()->RenameInput(old_var->Name(), new_var->Name());
  }
}
// Makes `op` write `new_var` wherever it currently writes `old_var`.
// `op` is registered as a producer of `new_var`, every matching entry of
// op->outputs is swapped, and the output is renamed in the OpDesc as well.
// Nothing happens when `op` is not an op node with an OpDesc.
// Note: old_var->inputs is left as it is.
void SimplifyWithBasicOpsPass::ReplaceOutputVar(Node* op, Node* old_var,
                                                Node* new_var) const {
  if (!op->IsOp() || !op->Op()) {
    return;
  }
  new_var->inputs.push_back(op);
  for (auto& output_slot : op->outputs) {
    if (output_slot != old_var) {
      continue;
    }
    output_slot = new_var;
    op->Op()->RenameOutput(old_var->Name(), new_var->Name());
  }
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(simplify_with_basic_ops_pass,
paddle::framework::ir::SimplifyWithBasicOpsPass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
// Graph pass that simplifies a Graph with basic ops. The implementation in
// this change handles dropout ops whose is_test attribute is set: they are
// either removed or replaced by a scale op.
class SimplifyWithBasicOpsPass : public Pass {
protected:
// Applies the simplification to every dropout op node of `graph`.
void ApplyImpl(Graph* graph) const override;
private:
// Rewrites the dropout op node `n`; nodes to be removed are collected in
// `del_node_set`. Returns true when the op was rewritten.
bool SimplifyDropout(Graph* graph, Node* n,
std::unordered_set<const Node*>* del_node_set) const;
// Return the input / output node of `n` whose name is `name`, or nullptr
// when no such node is linked to `n`.
Node* GetInputVar(Node* n, const std::string& name) const;
Node* GetOutputVar(Node* n, const std::string& name) const;
// Re-point `op` from `old_var` to `new_var` on its input / output side,
// updating both the node links and the names stored in the OpDesc.
void ReplaceInputVar(Node* op, Node* old_var, Node* new_var) const;
void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
// Builds a small mul -> dropout -> elementwise_add program for each dropout
// implementation, runs simplify_with_basic_ops_pass on it and checks the op
// counts afterwards:
//  - no dropout op is left in either case;
//  - downgrade_in_infer: one scale op is added per removed dropout op;
//  - upscale_in_train: no scale op is added.
// The in-place variant (elementwise_add writing back into the dropout input)
// is only exercised for upscale_in_train.
TEST(SimplifyWithBasicOpsPass, dropout) {
  for (std::string dropout_implementation :
       {"downgrade_in_infer", "upscale_in_train"}) {
    for (auto inplace : {false, true}) {
      if (dropout_implementation == "downgrade_in_infer" && inplace) {
        continue;
      }

      LOG(INFO) << "dropout_implementation: " << dropout_implementation
                << ", inplace: " << inplace;
      Layers layers;
      // (x, y) -> mul -> tmp_0
      // (tmp_0) -> dropout -> (tmp_1)
      // (tmp_1, z) -> elementwise_add -> (tmp_2)
      // or
      // (tmp_1, z) -> elementwise_add -> (tmp_0)
      auto* x = layers.data("x");
      auto* y = layers.data("y");
      auto* z = layers.data("z");
      auto* mul_out = layers.mul(x, y);
      auto* dropout_out = layers.dropout(mul_out, 0.5f, dropout_implementation);
      if (inplace) {
        layers.elementwise_add(dropout_out, z, mul_out);
      } else {
        layers.elementwise_add(dropout_out, z);
      }

      std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
      auto pass = PassRegistry::Instance().Get("simplify_with_basic_ops_pass");
      int num_dropout_nodes_before = GetNumOpNodes(graph, "dropout");
      int num_scale_nodes_before = GetNumOpNodes(graph, "scale");
      VLOG(3) << DebugString(graph);

      graph.reset(pass->Apply(graph.release()));
      int num_dropout_nodes_after = GetNumOpNodes(graph, "dropout");
      int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
      VLOG(3) << DebugString(graph);

      // The counts are ints, so compare against int literals and report
      // failures through gtest.
      ASSERT_EQ(num_dropout_nodes_after, 0);
      if (dropout_implementation == "downgrade_in_infer") {
        ASSERT_EQ(num_dropout_nodes_before,
                  num_scale_nodes_after - num_scale_nodes_before);
      } else {
        ASSERT_EQ(num_scale_nodes_after - num_scale_nodes_before, 0);
      }
    }
  }
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(simplify_with_basic_ops_pass);
...@@ -104,6 +104,8 @@ const std::vector<std::string> kAnakinSubgraphPasses({ ...@@ -104,6 +104,8 @@ const std::vector<std::string> kAnakinSubgraphPasses({
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
// "identity_scale_op_clean_pass", // // "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"conv_affine_channel_fuse_pass", // "conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
...@@ -138,7 +140,8 @@ void GpuPassStrategy::EnableNgraph() { ...@@ -138,7 +140,8 @@ void GpuPassStrategy::EnableNgraph() {
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// NOTE the large fusions should be located in the front, so that they will // NOTE the large fusions should be located in the front, so that they will
// not be damaged by smaller ones. // not be damaged by smaller ones.
passes_.assign({"attention_lstm_fuse_pass", // passes_.assign({"simplify_with_basic_ops_pass", //
"attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", //
// "seqpool_concat_fuse_pass", // // "seqpool_concat_fuse_pass", //
"seqpool_cvm_concat_fuse_pass", // "seqpool_cvm_concat_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册