From 492a00f53ea0b265bcac29a21c8c33d7284dd289 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Tue, 13 Aug 2019 09:58:44 +0200 Subject: [PATCH] Add conv reqantize squash (#18754) * Add requantize squash test=develop * Add more precise tests test=develop * REname and REfactor tester test=develop --- .../framework/ir/graph_pattern_detector.cc | 17 ++ .../framework/ir/graph_pattern_detector.h | 17 ++ .../ir/mkldnn/cpu_quantize_squash_pass.cc | 50 +++- .../ir/mkldnn/cpu_quantize_squash_pass.h | 10 +- .../mkldnn/cpu_quantize_squash_pass_tester.cc | 266 +++++++++++++++--- 5 files changed, 305 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index c54e805e26..2670c12911 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1275,6 +1275,23 @@ PDNode *patterns::ConvConcatReLU::operator()() { return relu_out; } +PDNode *patterns::ConvRequant::operator()() { + // Create Operators + auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); + auto requant_op = + pattern->NewNode(requant_op_repr())->assert_is_op("requantize"); + auto conv_out = pattern->NewNode(conv_out_repr()) + ->assert_is_op_output("conv2d", "Output"); + auto requant_out = pattern->NewNode(requant_out_repr()) + ->AsOutput() + ->assert_is_op_output("requantize", "Output"); + + conv_op->LinksTo({conv_out}); + requant_op->LinksFrom({conv_out}).LinksTo({requant_out}); + + return requant_out; +} + PDNode *patterns::PriorBox::operator()() { auto prior_box_op = pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index c53e4e5e25..d2ad12fca0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -796,6 +796,23 @@ struct ConvConcatReLU : public PatternBase { PATTERN_DECL_NODE(relu_out); }; +// Conv + Requant +// named nodes: +// conv_op, conv_out +// requant_op, requant_out +struct ConvRequant : public PatternBase { + ConvRequant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_requant") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(conv_op); + PATTERN_DECL_NODE(conv_out); + + PATTERN_DECL_NODE(requant_op); + PATTERN_DECL_NODE(requant_out); +}; + // PriorBox operator // operator: prior_box_op // inputs: prior_box_input, prior_box_image diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 2270e2b5cc..6277df1432 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -49,14 +49,14 @@ void CPUQuantizeSquashPass::FindNodesToKeep( AddStatis(found_count); } -void CPUQuantizeSquashPass::Squash( +void CPUQuantizeSquashPass::DequantQuantSquash( Graph* graph, std::unordered_map* nodes_keep_counter) const { GraphPatternDetector gpd; patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"}; squash_pattern(); - int found_squash_count = 0; + int found_dequant_quant_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "squash requantize-quantize ops pair"; @@ -96,7 +96,7 @@ void CPUQuantizeSquashPass::Squash( IR_NODE_LINK_TO(dequant_in, next_op); - found_squash_count++; + found_dequant_quant_count++; } else { // squash dequantize-quantize to requantize op OpDesc desc; @@ -116,13 +116,48 @@ void CPUQuantizeSquashPass::Squash( IR_NODE_LINK_TO(dequant_in, requant_op); IR_NODE_LINK_TO(requant_op, quant_out); - found_squash_count++; + found_dequant_quant_count++; } }; gpd(graph, handler); - AddStatis(found_squash_count); + AddStatis(found_dequant_quant_count); PrettyLogDetail("--- squashed %d dequantize-quantize pairs", - found_squash_count); + found_dequant_quant_count); +} + +void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::ConvRequant conv_requant_pattern{gpd.mutable_pattern(), + "conv_requant"}; + conv_requant_pattern(); + + int found_requant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash conv-requantize ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern); + + // if conv2d has one output squash + if (conv_out->outputs.size() == 1) { + float requant_scale_out = + boost::get(requant_op->Op()->GetAttr("Scale_out")); + conv_op->Op()->SetAttr("Scale_out", requant_scale_out); + conv_op->Op()->SetOutput("Output", + std::vector({requant_out->Name()})); + IR_NODE_LINK_TO(conv_op, requant_out); + GraphSafeRemoveNodes(graph, {conv_out, requant_op}); + + found_requant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_requant_squash_count); + PrettyLogDetail("--- squashed %d requantize with convs", + found_requant_squash_count); } void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { @@ -131,7 +166,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { std::unordered_map nodes_keep_counter; FindNodesToKeep(graph, &nodes_keep_counter); - Squash(graph, &nodes_keep_counter); + DequantQuantSquash(graph, &nodes_keep_counter); + ConvRequantSquash(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index e873994c57..52acdb0390 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -46,8 +46,14 @@ class CPUQuantizeSquashPass : public FusePassBase { /* * Squash dequantize-quantize ops pairs into requantize or nothing */ - void Squash(Graph* graph, - std::unordered_map* nodes_keep_counter) const; + void DequantQuantSquash( + Graph* graph, + std::unordered_map* nodes_keep_counter) const; + + /* + * Squash requantize op into conv with scale_out like requantize scale_out + */ + void ConvRequantSquash(Graph* graph) const; const std::string name_scope_{"squash"}; }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 057a790ccb..2d76400971 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -30,6 +30,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("name", name); if (type == "conv2d") { + op->SetAttr("Scale_out", scale); op->SetInput("Input", {inputs[0]}); if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]}); if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); @@ -42,14 +43,22 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); op->SetAttr("Scale", scale); + } else if (type == "requantize") { + op->SetInput("Input", {inputs[0]}); + op->SetOutput("Output", {outputs[0]}); + op->SetAttr("Scale_out", scale); + } else if (type == "concat") { + op->SetInput("X", inputs); + op->SetOutput("Out", outputs); } } // (a,w1,b1)->Conv1->d -// d->Dequant->e -// e->Quant->f +// d->Dequant(scale1)->e +// e->Quant(scale2)->f // (f,w2,b2)->Conv2->i -ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) { +ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out, + float scale1, float scale2) { ProgramDesc prog; for (auto& v : std::initializer_list( {"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) { @@ -59,42 +68,96 @@ ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) { } } - SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn, + scale_out); SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1); SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2); - SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn, + scale_out); return prog; } static const std::initializer_list variable_names{ "a", "b", "c", "d", "e", "f", "g", "h"}; + // a->Conv1->b -// b->Dequant->c -// -// c->Quant1->d and d->Conv2->e -// +// b->Dequant(scale1)->c +// c->Quant1(scale2)->d and d->Conv2->e // c->Conv3->f -// -// c->Quant2->g and g->Conv4->h -// -ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2, - float scale3) { +// c->Quant2(scale3)->g and g->Conv4->h +ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out, + float scale1, float scale2, + float scale3) { ProgramDesc prog; for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1); SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2); - SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out); - SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn, scale_out); SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3); - SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn, scale_out); + + return prog; +} + +// a->Conv1->b->Requant(scale1)->c +// d->Conv2->e->Requant(scale2)->f +// {c,f}->Concat +ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out, + float scale1, float scale2) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); + SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1); + + SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out); + SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2); + + SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn); + + return prog; +} + +// a->Concat->b +// b->Dequant(scale1)->c +// c->Quant(scale2)->d +// d->Conv->e +ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out, + float scale1, float scale2) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "concat", "Concat", {"a"}, {"b"}, use_mkldnn); + SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1); + SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, scale2); + SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out); + return prog; +} + +// a->Conv1->b +// b->Requant1(Scale1)->c +// b->Requant2(Scale2)->d +ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out, + float scale1, float scale2) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); + SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1); + SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn, scale2); return prog; } @@ -105,10 +168,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, tensor->mutable_data(place, proto::VarType::FP32, 1); } -void MainTest(const ProgramDesc& prog, int removed_nodes_num) { - std::unique_ptr graph(new ir::Graph(prog)); - - // Init scope, as it is used in pass +void PrepareGraph(std::unique_ptr* graph, const ProgramDesc& prog) { auto place = paddle::platform::CPUPlace(); NaiveExecutor exe{place}; Scope scope; @@ -117,58 +177,172 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) { for (auto& v : variable_names) { InitTensorHolder(&scope, place, v.c_str()); } + (*graph)->SetNotOwned(kParamScopeAttr, &scope); +} - graph->SetNotOwned(kParamScopeAttr, &scope); - +void RegisterPass(std::unique_ptr* graph) { auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass"); + graph->reset(pass->Apply(graph->release())); +} - int original_nodes_num = graph->Nodes().size(); - - graph.reset(pass->Apply(graph.release())); +// check number of nodes +void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) { + std::unique_ptr graph(new ir::Graph(prog)); + PrepareGraph(&graph, prog); + int original_nodes_num = graph->Nodes().size(); + RegisterPass(&graph); int current_nodes_num = graph->Nodes().size(); EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num); } +// check op->scale_out +void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name, + float scale) { + std::unique_ptr graph(new ir::Graph(prog)); + PrepareGraph(&graph, prog); + RegisterPass(&graph); + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && + boost::get(node->Op()->GetAttr("name")) == name) { + float scale_out = boost::get(node->Op()->GetAttr("Scale_out")); + EXPECT_EQ(scale_out, scale); + } + } +} + +// check requant_op scales +void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in, + float scale_out) { + std::unique_ptr graph(new ir::Graph(prog)); + PrepareGraph(&graph, prog); + RegisterPass(&graph); + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "requantize") { + float op_scale_in = boost::get(node->Op()->GetAttr("Scale_in")); + EXPECT_EQ(op_scale_in, scale_in); + float op_scale_out = boost::get(node->Op()->GetAttr("Scale_out")); + EXPECT_EQ(op_scale_out, scale_out); + } + } +} + +// From Conv1->d->Dequant->e->Quant->f->Conv2 +// To Conv1->d->Conv2 TEST(CpuQuantizeSquashPass, equal_scales) { + auto scale_out = 1.0f; auto scale = 1.2345f; auto use_mkldnn = true; // Remove 4 nodes: Dequant, Quant, e, f auto remove_nodes = 4; - MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes); - - use_mkldnn = !use_mkldnn; - MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes); + CountNodeTest( + BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale), + remove_nodes); } -TEST(CpuQuantizeSquashPass, inequal_scales) { +// From Conv1->d->Dequant->e->Quant->f->Conv2 +// First change to Conv1->d->Requant->f->Conv2 +// Then Conv1->f->Conv2 +TEST(CpuQuantizeSquashPass, unequal_scales) { + auto scale_out = 1.0f; auto scale1 = 1.2345f; auto scale2 = 21.0f; auto use_mkldnn = true; - // Remove 3 nodes: Dequant, Quant, e - // Insert 1 node: requantize + // Remove 4 nodes: Dequant, Quant, e, d + auto remove_nodes = 4; + CountNodeTest( + BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), + remove_nodes); + + EqualScaleOutTest( + BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), + "Conv1", scale2); +} + +// from +// a->Conv1->b->Dequant(Scale1)->c +// c->Quant1(Scale1)->d and d->Conv2->e +// c->Quant2(Scale2)->g and g->Conv4->h +// c->Conv3->f +// to +// a->Conv1->b +// b->Conv2->e +// b->Requant(Scale_in = Scale1; Scale_out = Scale2)->g->Conv4->h +// b->Dequant(Scale1)->c->Conv3->f +TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) { + auto scale_out = 1.0f; + auto scale = 1.2345f; + auto scale2 = 21.0f; + auto use_mkldnn = true; + // Remove 3 nodes: Quant1, c, Quant2, + // Insert 1 node: Requant auto remove_nodes = 2; - MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes); + CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale, + scale, scale2), + remove_nodes); + CheckRequantScalesTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, + scale, scale, scale2), + scale, scale2); +} - use_mkldnn = !use_mkldnn; - MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes); +// a->Conv1->b->Requant->c +// d->Conv2->e->Requant->f +// {c,f}->Concat +TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) { + // Delete both requantize op + auto scale_out = 1.0f; + auto scale = 1.2345f; + auto use_mkldnn = true; + // Remove 4 nodes: b, Requant1, e, Requant2 + auto remove_nodes = 4; + CountNodeTest( + BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), + remove_nodes); + + // check equal scale conv->scale_out and requant->scale_out + EqualScaleOutTest( + BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), + "Conv1", scale); + EqualScaleOutTest( + BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), + "Conv2", scale); } -TEST(CpuQuantizeSquashPass, branch_to_equal_inequal_and_fp32) { - // Delete both quantize ops, - // bypass dequantize in both branches, - // insert requantize on one branch +// a->Concat->b->Dequant->c->Quant->d->Conv->e +// to a->Concat->b->Requant->d->Conv->e +TEST(CpuQuantizeSquashPass, + unequal_scales_squash_dequantize_quantize_into_requantize) { + auto scale_out = 1.0f; auto scale = 1.2345f; auto scale2 = 21.0f; auto use_mkldnn = true; - // Remove 3 nodes: Quant1, Quant2, g - // Insert 1 node: requantize + // Remove 3 nodes: Dequant1, c, Quant + // Insert 1 node: Requant auto remove_nodes = 2; - MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes); + CountNodeTest( + BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2), + remove_nodes); + CheckRequantScalesTest( + BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2), + scale, scale2); +} - use_mkldnn = !use_mkldnn; - MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes); +// a->Conv1->b +// b->Requant1(Scale1)->c +// b->Requant2(Scale2)->d +TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) { + auto scale_out = 1.0f; + auto scale = 1.2345f; + auto scale2 = 21.0f; + auto use_mkldnn = true; + // nothing change + auto remove_nodes = 0; + CountNodeTest( + BuildConvMultiRequantProgramDesc(use_mkldnn, scale_out, scale, scale2), + remove_nodes); } } // namespace ir -- GitLab