From d3a96632fad36b2dae2c37cdcaa317cce4f1819e Mon Sep 17 00:00:00 2001 From: lidanqing Date: Mon, 16 Dec 2019 02:55:47 +0100 Subject: [PATCH] Add fc-dequantize squash in cpu_quantize_squash_pass for ernie model (#21714) * fc-dequantize squash test=develop * change according to reviews test=develop * change PADDLE_ENFORCE test=develop * add second test when fc-dequant do not fuse test=develop * change all related PADDLE_ENFORCE test=develop --- .../framework/ir/graph_pattern_detector.cc | 63 ++++++++++--- .../framework/ir/graph_pattern_detector.h | 31 +++++-- .../ir/mkldnn/cpu_quantize_squash_pass.cc | 43 ++++++++- .../ir/mkldnn/cpu_quantize_squash_pass.h | 5 + .../mkldnn/cpu_quantize_squash_pass_tester.cc | 93 +++++++++++++++++++ 5 files changed, 213 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index fa4ee48bf6f..ab7e89946b5 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -41,9 +41,10 @@ size_t PDPattern::id_ = 0UL; PDNode *PDPattern::NewNode(const std::string &name) { if (!name.empty()) { - PADDLE_ENFORCE_EQ(node_map_.count(name), 0UL, - "PDNode's name should be unique, get duplicate [%s]", - name); + PADDLE_ENFORCE_EQ( + node_map_.count(name), 0UL, + platform::errors::PreconditionNotMet( + "PDNode's name should be unique, get duplicate [%s]", name)); } nodes_.emplace_back(new PDNode(this, name)); @@ -54,9 +55,10 @@ PDNode *PDPattern::NewNode(const std::string &name) { PDNode *PDPattern::NewNode(PDNode::teller_t &&teller, const std::string &name) { if (!name.empty()) { - PADDLE_ENFORCE_EQ(node_map_.count(name), 0UL, - "PDNode's name should be unique, get duplicate [%s]", - name); + PADDLE_ENFORCE_EQ( + node_map_.count(name), 0UL, + platform::errors::PreconditionNotMet( + "PDNode's name should be unique, get duplicate [%s]", name)); } nodes_.emplace_back(new PDNode(std::move(teller), this, name)); @@ -75,8 +77,10 @@ PDNode *PDPattern::RetrieveNode(const std::string &id) const { } void PDPattern::AddEdge(PDNode *a, PDNode *b) { - PADDLE_ENFORCE(a); - PADDLE_ENFORCE(b); + PADDLE_ENFORCE_NOT_NULL( + a, platform::errors::NotFound("PDNode %s is not found.", a->name())); + PADDLE_ENFORCE_NOT_NULL( + b, platform::errors::NotFound("PDNode %s is not found.", b->name())); PADDLE_ENFORCE_NE(a, b, platform::errors::PermissionDenied( "Cannot connect the same node in the graph.")); edges_.emplace_back(a, b); @@ -610,15 +614,24 @@ bool VarLinksToOp(Node *node, const std::string &op_type) { } bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) { - PADDLE_ENFORCE(var->IsVar()); - PADDLE_ENFORCE(op->IsOp()); + PADDLE_ENFORCE_EQ( + var->IsVar(), true, + platform::errors::InvalidArgument( + "First parameter of function IsNthInput must be Node::Var")); + PADDLE_ENFORCE_EQ( + op->IsOp(), true, + platform::errors::InvalidArgument( + "Second parameter of function IsNthInput must be Node::Op")); if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth) return false; return var->Name() == op->Op()->Input(argument)[nth]; } bool HasInput(Node *op, const std::string &argument) { - PADDLE_ENFORCE(op->IsOp()); + PADDLE_ENFORCE_EQ( + op->IsOp(), true, + platform::errors::InvalidArgument( + "First parameter of function HasInput must be Node::Op")); auto const &names = op->Op()->InputNames(); if (std::find(names.begin(), names.end(), argument) == names.end()) return false; @@ -626,8 +639,14 @@ bool HasInput(Node *op, const std::string &argument) { } bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { - PADDLE_ENFORCE(var->IsVar()); - PADDLE_ENFORCE(op->IsOp()); + PADDLE_ENFORCE_EQ( + var->IsVar(), true, + platform::errors::InvalidArgument( + "First parameter of function IsNthOutput must be Node::Var")); + PADDLE_ENFORCE_EQ( + op->IsOp(), true, + platform::errors::InvalidArgument( + "Second parameter of function IsNthOutput must be Node::Op")); if (op->Op()->Output(argument).size() <= nth) return false; return var->Name() == op->Op()->Output(argument)[nth]; } @@ -1344,6 +1363,24 @@ PDNode *patterns::ConvDequant::operator()() { return dequant_out; } +PDNode *patterns::FcDequant::operator()() { + // Create Operators + auto fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc"); + auto dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); + + auto fc_out = + pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out"); + auto dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize", "Output"); + + fc_op->LinksTo({fc_out}); + dequant_op->LinksFrom({fc_out}).LinksTo({dequant_out}); + + return dequant_out; +} + PDNode *patterns::PriorBox::operator()() { auto prior_box_op = pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 3b266215da3..06d72dda8cc 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -153,7 +153,9 @@ struct PDNode { pattern_(pattern), name_(name), type_(type) { - PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); + PADDLE_ENFORCE_NOT_NULL( + teller_, + platform::errors::NotFound("invalid teller is set, teller is null")); } PDNode(PDNode&& other) = default; @@ -370,11 +372,14 @@ static std::string UniqueKey(const std::string& repr) { // var: variable. // arg: the argument declared by PATTERN_DECL_NODE in a pattern definition. // pat: the pattern object. -#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \ - PADDLE_ENFORCE(subgraph.count(pat.arg##_n()), \ - "Node not found for PDNode %s", pat.arg##_repr()); \ - Node* var = subgraph.at(pat.arg##_n()); \ - PADDLE_ENFORCE(var, "node %s not exists in the sub-graph", #arg) +#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \ + PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()), 0UL, \ + platform::errors::NotFound("Node not found for PDNode %s", \ + pat.arg##_repr())); \ + Node* var = subgraph.at(pat.arg##_n()); \ + PADDLE_ENFORCE_NOT_NULL( \ + var, platform::errors::NotFound("node %s not exists in the sub-graph", \ + #arg)); // The base class of all the patterns. struct PatternBase { @@ -844,6 +849,20 @@ struct ConvDequant : public PatternBase { PATTERN_DECL_NODE(dequant_out); }; +// Fc + Dequant +struct FcDequant : public PatternBase { + FcDequant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "fc_dequant") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(fc_op); + PATTERN_DECL_NODE(fc_out); + + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); +}; + // PriorBox operator // operator: prior_box_op // inputs: prior_box_input, prior_box_image diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 1c09dc669da..2d98758985b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -71,8 +71,9 @@ void CPUQuantizeSquashPass::DequantQuantSquash( auto* next_op_desc = next_op->Op(); float dequant_scale = boost::get(dequant_op->Op()->GetAttr("Scale")); float quant_scale = boost::get(quant_op->Op()->GetAttr("Scale")); - PADDLE_ENFORCE(nodes_keep_counter->find(dequant_out) != - nodes_keep_counter->end()); + PADDLE_ENFORCE_NE( + nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(), + platform::errors::NotFound("The dequant output node is not found")); // check if dequantize op should be kept or removed, decrease the counter bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; @@ -195,14 +196,50 @@ void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const { found_conv_dequant_squash_count); } +// squash fc with dequant +void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::FcDequant fc_dequant_pattern{gpd.mutable_pattern(), "fc_dequant"}; + fc_dequant_pattern(); + + int found_fc_dequant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash fc-dequant ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, fc_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, fc_dequant_pattern); + + // if fc has force_fp32_output attribute + if (fc_out->outputs.size() == 1) { + fc_op->Op()->SetAttr("force_fp32_output", true); + fc_op->Op()->SetOutput("Out", + std::vector({dequant_out->Name()})); + IR_NODE_LINK_TO(fc_op, dequant_out); + GraphSafeRemoveNodes(graph, {fc_out, dequant_op}); + found_fc_dequant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_fc_dequant_squash_count); + PrettyLogDetail("--- squashed %d dequant with fcs", + found_fc_dequant_squash_count); +} + void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE(graph); + PADDLE_ENFORCE_NOT_NULL( + graph, + platform::errors::NotFound( + "The graph in function CPUQuantizeSquashPass::ApplyImpl is null")); FusePassBase::Init("cpu_quantize_squash_pass", graph); std::unordered_map nodes_keep_counter; FindNodesToKeep(graph, &nodes_keep_counter); DequantQuantSquash(graph, &nodes_keep_counter); ConvDequantSquash(graph); + FcDequantSquash(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index 7e9e92e3dac..b02e057948d 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -60,6 +60,11 @@ class CPUQuantizeSquashPass : public FusePassBase { */ void ConvDequantSquash(Graph* graph) const; + /* + * Squash fc with dequant when dequant is the next op after fc + */ + void FcDequantSquash(Graph* graph) const; + const std::string name_scope_{"squash"}; }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 0dfef76f8a0..9f5597c6dc8 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -50,6 +50,15 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, } else if (type == "concat") { op->SetInput("X", inputs); op->SetOutput("Out", outputs); + } else if (type == "fc") { + op->SetInput("Input", {inputs[0]}); + PADDLE_ENFORCE_EQ(inputs.size(), 2UL, + platform::errors::InvalidArgument( + "The fc inputs should contain input and weights, but " + "now the size of inputs is %d", + inputs.size())); + op->SetInput("W", {inputs[1]}); + op->SetOutput("Out", outputs); } } @@ -176,6 +185,36 @@ ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out, return prog; } +// a->fc->b +// b->Dequant1->c +// c->Concat1->d +ProgramDesc BuildFcDequantConcatProgramDesc(bool use_mkldnn, float scale_out, + float scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); + SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn); + return prog; +} + +// a->fc->b +// b->Dequant1->c +// b->concat->d +ProgramDesc BuildFcDequantFcProgramDesc(bool use_mkldnn, float scale_out, + float scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); + SetOp(&prog, "concat", "Concat1", {"b"}, {"d"}, use_mkldnn); + return prog; +} + // a->Conv1->b // b->Dequant1(Scale1)->c // b->Conv2->d @@ -261,6 +300,23 @@ void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in, } } +// check requant_op scales +void IsForceFp32OutputTest(const ProgramDesc& prog, std::string op_type, + bool target_is_force_fp32_output) { + std::unique_ptr graph(new ir::Graph(prog)); + + PrepareGraph(&graph, prog); + RegisterPass(&graph); + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == op_type) { + bool is_force_fp32_output = + node->Op()->GetAttrIfExists("force_fp32_output"); + EXPECT_EQ(is_force_fp32_output, target_is_force_fp32_output); + } + } +} + // From Conv1->d->Dequant->e->Quant->f->Conv2 // To Conv1->d->Conv2 TEST(CpuQuantizeSquashPass, equal_scales) { @@ -362,8 +418,12 @@ TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) { auto remove_nodes = 2; CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale), remove_nodes); + IsForceFp32OutputTest( + BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale), "conv2d", + true); } +// If there are more than one op after conv->dequantize, do not fuse TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) { auto scale_out = 1.0f; auto scale = 1.2345f; @@ -372,6 +432,39 @@ TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) { auto remove_nodes = 0; CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale), remove_nodes); + IsForceFp32OutputTest( + BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale), "conv2d", + false); +} + +// from +// a->fc->b->Dequant1->c->Concat1->d +// to +// a->fc->c->Concat->d +TEST(CpuQuantizeSquashPass, fc_dequant_only_one_output) { + auto scale_out = 1.0f; + auto scale = 1.2345f; + auto use_mkldnn = true; + // remove 2 nodes: b, Dequant1 + auto remove_nodes = 2; + CountNodeTest(BuildFcDequantConcatProgramDesc(use_mkldnn, scale_out, scale), + remove_nodes); + IsForceFp32OutputTest( + BuildFcDequantConcatProgramDesc(use_mkldnn, scale_out, scale), "fc", + true); +} + +// If there are more than one op after fc->dequantize, do not fuse +TEST(CpuQuantizeSquashPass, fc_dequant_more_than_one_op_after_dequant) { + auto scale_out = 1.0f; + auto scale = 1.2345f; + auto use_mkldnn = true; + // nothing change + auto remove_nodes = 0; + CountNodeTest(BuildFcDequantFcProgramDesc(use_mkldnn, scale_out, scale), + remove_nodes); + IsForceFp32OutputTest( + BuildFcDequantFcProgramDesc(use_mkldnn, scale_out, scale), "fc", false); } } // namespace ir -- GitLab