diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 0b4b18c94b439c5cc17fe423fb414fee1569df74..919364541e4eee27a5970da12ffd818124699d50 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1522,6 +1522,25 @@ PDNode *patterns::FcDequant::operator()() {
   return dequant_out;
 }
 
+PDNode *patterns::DequantScale::operator()() {
+  // Create Operators
+  auto dequant_op =
+      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
+  auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
+
+  auto dequant_out = pattern->NewNode(dequant_out_repr())
+                         ->AsOutput()
+                         ->assert_is_op_output("dequantize", "Output");
+  auto scale_out = pattern->NewNode(scale_out_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("scale", "Out");
+
+  dequant_op->LinksTo({dequant_out});
+  scale_op->LinksFrom({dequant_out}).LinksTo({scale_out});
+
+  return scale_out;
+}
+
 PDNode *patterns::PriorBox::operator()() {
   auto prior_box_op =
       pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index db58c9e8fdd0f95c3c56876acbfba4d3dd85f495..dcdf4318c883851ec97208cabb8d5e9a6af8a611 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -929,6 +929,20 @@ struct FcDequant : public PatternBase {
   PATTERN_DECL_NODE(dequant_out);
 };
 
+// Dequantize + Scale
+struct DequantScale : public PatternBase {
+  DequantScale(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "dequant_scale") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(dequant_op);
+  PATTERN_DECL_NODE(dequant_out);
+
+  PATTERN_DECL_NODE(scale_op);
+  PATTERN_DECL_NODE(scale_out);
+};
+
 // PriorBox operator
 // operator: prior_box_op
 // inputs: prior_box_input, prior_box_image
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
index eff9b294f7064ebad7e18fc5f24b791574d62d0c..66556c7cc86640827fa72c72d7b754e1d80b8e27 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -284,6 +284,49 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
   PrettyLogDetail("--- squashed %d quantize op", removed_quantize);
 }
 
+// squash scale with dequant
+void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::DequantScale dequant_scale_pattern{gpd.mutable_pattern(),
+                                               "dequant_scale"};
+  dequant_scale_pattern();
+
+  int found_dequant_scale_squash_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "squash dequant-scale ops pair";
+
+    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, dequant_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, dequant_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, dequant_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, dequant_scale_pattern);
+
+    if (dequant_out->outputs.size() == 1 &&
+        scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
+      auto dequant_scale = dequant_op->Op()->GetAttrIfExists<float>("Scale");
+      auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
+
+      PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
+                        platform::errors::InvalidArgument(
+                            "Dequantize scale should have positive value"));
+      PADDLE_ENFORCE_GT(scale_scale, 0.0f,
+                        platform::errors::InvalidArgument(
+                            "Scale of scale op should have positive value"));
+
+      dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
+      dequant_op->Op()->SetOutput(
+          "Output", std::vector<std::string>({scale_out->Name()}));
+      IR_NODE_LINK_TO(dequant_op, scale_out);
+      GraphSafeRemoveNodes(graph, {dequant_out, scale_op});
+      found_dequant_scale_squash_count++;
+    }
+  };
+  gpd(graph, handler);
+  AddStatis(found_dequant_scale_squash_count);
+  PrettyLogDetail("--- squashed %d scale with dequant",
+                  found_dequant_scale_squash_count);
+}
+
 void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph,
@@ -298,6 +341,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   ConvDequantSquash(graph);
   FcDequantSquash(graph);
   MultipleQuantizeSquash(graph);
+  DequantScaleSquash(graph);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
index af8a66c929bfdc0c5c410a6f16ca973904e2f326..41c5323ba5ce074ccf229878d5eaafb757340b46 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
@@ -70,6 +70,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
    */
   void MultipleQuantizeSquash(Graph* graph) const;
 
+  /*
+   * Squash scale if dequantize is before scale
+   */
+  void DequantScaleSquash(Graph* graph) const;
+
   const std::string name_scope_{"squash"};
 };
 
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
index 5a364aab1c59169b95756c55bae94ba4e91fc86c..1ce7fc9a72b64beb81c1103abdb1b7ff158480c8 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
@@ -24,7 +24,7 @@ namespace ir {
 void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs, bool use_mkldnn,
-           float scale = 0) {
+           float scale = 0, float bias = 0.0) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   op->SetAttr("use_mkldnn", use_mkldnn);
@@ -59,6 +59,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
                             inputs.size()));
     op->SetInput("W", {inputs[1]});
     op->SetOutput("Out", outputs);
+  } else if (type == "scale") {
+    op->SetInput("X", {inputs[0]});
+    op->SetOutput("Out", {outputs[0]});
+    op->SetAttr("scale", scale);
+    op->SetAttr("bias", bias);
   }
 }
 
@@ -252,6 +257,21 @@ ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
   return prog;
 }
 
+// a->Dequant->b
+// b->Scale->c
+ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
+                                         float scale_scale, float bias) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "dequantize", "Dequant", {"a"}, {"b"}, use_mkldnn,
+        dequant_scale);
+  SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, scale_scale, bias);
+
+  return prog;
+}
+
 void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
                       const char* var_name) {
   auto x = scope->Var(var_name);
@@ -289,17 +309,17 @@ void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
 }
 
 // check op->scale_out
-void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
-                       float scale) {
+void EqualScaleTest(const ProgramDesc& prog, const std::string& op_name,
+                    const std::string& scale_name, float scale) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   PrepareGraph(&graph, prog);
   RegisterPass(&graph);
 
   for (auto* node : graph->Nodes()) {
     if (node->IsOp() &&
-        boost::get<std::string>(node->Op()->GetAttr("name")) == name) {
-      float scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out"));
-      EXPECT_EQ(scale_out, scale);
+        boost::get<std::string>(node->Op()->GetAttr("name")) == op_name) {
+      float op_scale = boost::get<float>(node->Op()->GetAttr(scale_name));
+      EXPECT_EQ(op_scale, scale);
     }
   }
 }
@@ -368,9 +388,9 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
       BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
       remove_nodes);
 
-  EqualScaleOutTest(
+  EqualScaleTest(
       BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
-      "Conv1", scale2);
+      "Conv1", "Scale_out", scale2);
 }
 
 // a->Conv1->b->Requant->c
@@ -388,12 +408,12 @@ TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
       remove_nodes);
 
   // check equal scale conv->scale_out and requant->scale_out
-  EqualScaleOutTest(
+  EqualScaleTest(
       BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv1", scale);
-  EqualScaleOutTest(
+      "Conv1", "Scale_out", scale);
+  EqualScaleTest(
       BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv2", scale);
+      "Conv2", "Scale_out", scale);
 }
 
 // from
@@ -544,6 +564,37 @@ TEST(CpuQuantizeSquashPass, quatize_with_different_scale) {
       remove_nodes);
 }
 
+// if scale has no bias
+TEST(CpuQuantizeSquashPass, dequantize_scale_with_no_bias) {
+  auto dequant_scale = 1.2345f;
+  auto scale_scale = 1.5432f;
+  auto bias = 0.0f;
+  auto use_mkldnn = true;
+  // remove: dequant out, scale op
+  auto remove_nodes = 2;
+  CountNodeTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                             scale_scale, bias),
+                remove_nodes);
+  EqualScaleTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                              scale_scale, bias),
+                 "Dequant", "Scale", dequant_scale / scale_scale);
+}
+
+// if scale has bias
+TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) {
+  auto dequant_scale = 1.2345f;
+  auto scale_scale = 1.5432f;
+  auto bias = 1.0f;
+  auto use_mkldnn = true;
+  // nothing change
+  auto remove_nodes = 0;
+  CountNodeTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                             scale_scale, bias),
+                remove_nodes);
+  EqualScaleTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                              scale_scale, bias),
+                 "Dequant", "Scale", dequant_scale);
+}
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
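
Note on the squashed attribute (not part of the patch): the update `Scale = dequant_scale / scale_scale` in DequantScaleSquash follows from a simple identity, assuming the dequantize kernel computes x / Scale and the scale op computes x * scale when its bias is zero (the semantics implied by the pass and its tests). The standalone C++ sketch below checks that identity numerically with the same constants used in the new tests; it is illustrative only and does not depend on Paddle.

#include <cassert>
#include <cmath>

// Standalone sketch, not Paddle code. Assumes dequantize computes x / Scale
// and the scale op computes x * scale when bias == 0.
int main() {
  const float x = 42.0f;                // an arbitrary quantized value
  const float dequant_scale = 1.2345f;  // "Scale" attribute of dequantize
  const float scale_scale = 1.5432f;    // "scale" attribute of the scale op

  // Original subgraph: dequantize, then scale.
  const float chained = (x / dequant_scale) * scale_scale;

  // Squashed subgraph: one dequantize with Scale = dequant_scale / scale_scale.
  const float squashed = x / (dequant_scale / scale_scale);

  // The two graphs produce the same value (up to rounding).
  assert(std::fabs(chained - squashed) < 1e-4f);
  return 0;
}

This is also why the pass only fires when bias == 0.0: a non-zero bias adds a constant after dequantization, which cannot be folded into a single dequantize Scale, so the dequantize_scale_with_bias test expects the graph to stay unchanged.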