diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index f91bbdf0a29c80cbd5d14b517e80ef9f4e75843f..2814020b94f2ce2e0487a09886d8d99656706c06 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1508,6 +1508,27 @@ PDNode *patterns::OpRequant::operator()() { return requant_out; } +PDNode *patterns::RequantOp::operator()() { + auto requant_in = pattern->NewNode(requant_in_repr()) + ->assert_is_op_input("requantize", "Input"); + auto requant_op = + pattern->NewNode(requant_op_repr())->assert_is_op("requantize"); + auto requant_out = pattern->NewNode(requant_out_repr()) + ->AsOutput() + ->assert_is_op_output("requantize", "Output"); + auto any_op = pattern->NewNode(any_op_repr()) + ->assert_is_op() + ->assert_more([&](Node *node) { + return (node->Op()->HasAttr("Scale_in") || + node->Op()->HasAttr("Scale_x") || + node->Op()->HasAttr("Scale_y")); + }); + + requant_op->LinksFrom({requant_in}).LinksTo({requant_out}); + any_op->LinksFrom({requant_out}); + return any_op; +} + PDNode *patterns::ConvDequant::operator()() { // Create Operators auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 4d7a4e283d353b508b42f5c03f42cd8f842ae40a..58248aec709d5e24d804f6bc19aa36ca392aad7f 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -913,6 +913,22 @@ struct OpRequant : public PatternBase { PATTERN_DECL_NODE(requant_out); }; +// Requant + Op +// named nodes: +// requant_in, requant_op, +// requant_out, any_op +struct RequantOp : public PatternBase { + RequantOp(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "requant_op") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(any_op); + PATTERN_DECL_NODE(requant_in); + PATTERN_DECL_NODE(requant_op); + PATTERN_DECL_NODE(requant_out); +}; + // Conv + Dequant // named nodes: // conv_op, conv_out diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 92396306934a682f4316e6fb6ed4ca4ba07ae0e7..ed57e329443ea6dcf75d70a82361955928dbe559 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -152,7 +152,7 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const { PADDLE_ENFORCE_NE( any_op_output_name.empty(), true, platform::errors::NotFound("Operator before requantize operator " - "should has requantize input as output")); + "should have requantize input as output")); float requant_scale_out = boost::get(requant_op->Op()->GetAttr("Scale_out")); @@ -170,6 +170,59 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const { found_requant_squash_count); } +// requant-op squash if op has Scale_in, Scale_x, Scale_y attr +// conv2d, fc, matmul +void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::RequantOp requant_op_pattern{gpd.mutable_pattern(), "requant_op"}; + requant_op_pattern(); + + int found_requant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash requantize-op ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, requant_op_pattern); + GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, requant_op_pattern); + GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, requant_op_pattern); + GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern); + + if (requant_out->outputs.size() == 1) { + std::string any_op_input_name; + for (auto name : any_op->Op()->InputNames()) + for (auto input_name : any_op->Op()->Input(name)) + if (input_name == requant_out->Name()) any_op_input_name = name; + + PADDLE_ENFORCE_NE( + any_op_input_name.empty(), true, + platform::errors::NotFound("The operator after requantize operator " + "should have requantize output as input")); + float requant_scale_in = + boost::get(requant_op->Op()->GetAttr("Scale_in")); + + auto scale_name = "Scale_in"; + if (any_op->Op()->Type() == "matmul") + scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y"; + + PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists("Scale_out"), + any_op->Op()->GetAttrIfExists(scale_name), + platform::errors::InvalidArgument( + "The operator after requantize should have input " + "scale equal to requantize output scale")); + any_op->Op()->SetAttr(scale_name, requant_scale_in); + any_op->Op()->SetInput(any_op_input_name, + std::vector({requant_in->Name()})); + IR_NODE_LINK_TO(requant_in, any_op); + GraphSafeRemoveNodes(graph, {requant_op, requant_out}); + found_requant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_requant_squash_count); + PrettyLogDetail("--- squashed %d requantize ops", + found_requant_squash_count); +} + void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const { GraphPatternDetector gpd; patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(), @@ -379,6 +432,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { FindNodesToKeep(graph, &nodes_keep_counter); DequantQuantSquash(graph, &nodes_keep_counter); OpRequantSquash(graph); + RequantOpSquash(graph); ConvDequantSquash(graph); FcDequantSquash(graph); MultipleQuantizeSquash(graph); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index be770fb0b2b569adefe6f3062dc7a907eaeac716..8ce5c1858559c6083df17d66630931818858d6f8 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase { */ void OpRequantSquash(Graph* graph) const; + /* + * Squash requantize op if the next operator's input scale can be updated + */ + void RequantOpSquash(Graph* graph) const; + /* * Squash conv2d with dequant when dequant is the only op after conv2d */ diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 163497d80c42c16d91b0100b036051047dc51306..d10494aabfbc49bf16b92d719b7e58e58b4fcd5f 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -24,13 +24,14 @@ namespace ir { void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, const std::vector& inputs, const std::vector& outputs, bool use_mkldnn, - float scale = 0, float bias = 0.0) { + const std::vector scale = {}, float bias = 0.0) { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("name", name); if (type == "conv2d") { - op->SetAttr("Scale_out", scale); + if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]); + if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]); op->SetInput("Input", {inputs[0]}); if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]}); if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); @@ -38,15 +39,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, } else if (type == "quantize") { op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); - op->SetAttr("Scale", scale); + op->SetAttr("Scale", scale[0]); } else if (type == "dequantize") { op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); - op->SetAttr("Scale", scale); + op->SetAttr("Scale", scale[0]); } else if (type == "requantize") { op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); - op->SetAttr("Scale_out", scale); + op->SetAttr("Scale_in", scale[0]); + op->SetAttr("Scale_out", scale[1]); } else if (type == "concat") { op->SetInput("X", inputs); op->SetOutput("Out", outputs); @@ -59,17 +61,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, inputs.size())); op->SetInput("W", {inputs[1]}); op->SetOutput("Out", outputs); - op->SetAttr("Scale_out", scale); + if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]); + if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]); } else if (type == "scale") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); - op->SetAttr("scale", scale); + op->SetAttr("scale", scale[0]); op->SetAttr("bias", bias); } else if (type == "matmul") { op->SetInput("X", {inputs[0]}); op->SetInput("Y", {inputs[1]}); op->SetOutput("Out", {outputs[0]}); - op->SetAttr("Scale_out", scale); + if (scale.size() > 0) op->SetAttr("Scale_x", scale[0]); + if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]); } } @@ -78,7 +82,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, // e->Quant(scale2)->f // (f,w2,b2)->Conv2->i ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out, - float scale1, float scale2) { + float scale_in) { ProgramDesc prog; for (auto& v : std::initializer_list( {"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) { @@ -89,22 +93,22 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out, } SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn, - scale_out); - SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1); - SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2); + {1.23f, scale_out}); + SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, {scale_out}); + SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, {scale_in}); SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn, - scale_out); + {scale_in, 2.34f}); return prog; } static const std::initializer_list variable_names{ - "a", "b", "c", "d", "e", "f", "g", "h", "x", "y", "w1"}; + "a", "b", "c", "d", "e", "f", "g", "h", "i", "x", "y", "w1", "w2"}; -// a->Conv1->b +// a->Conv1(scale1)->b // b->Dequant(scale1)->c -// c->Quant1(scale2)->d and d->Conv2->e +// c->Quant1(scale2)->d and d->(scale2)Conv2->e // c->Conv3->f -// c->Quant2(scale3)->g and g->Conv4->h +// c->Quant2(scale3)->g and g->Concat->h ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out, float scale1, float scale2, float scale3) { @@ -113,16 +117,17 @@ ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out, prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); - SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1); + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale1}); + SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, {scale1}); - SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2); - SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out); + SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, {scale2}); + SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, + {scale2, scale_out}); - SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn, scale_out); + SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn); - SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3); - SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn, scale_out); + SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, {scale3}); + SetOp(&prog, "concat", "Concat", {"g"}, {"h"}, use_mkldnn); return prog; } @@ -141,16 +146,17 @@ ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale, prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, conv_scale); + SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, {1.23f, conv_scale}); SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, - requant_scale1); - SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, fc_scale); + {conv_scale, requant_scale1}); + SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, {1.23f, fc_scale}); SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, - requant_scale2); - SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn, matmul_scale); + {fc_scale, requant_scale2}); + SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn, + {1.23f, matmul_scale}); SetOp(&prog, "requantize", "Requant3", {"g"}, {"h"}, use_mkldnn, - requant_scale3); - SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, use_mkldnn); + {matmul_scale, requant_scale3}); + SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, {use_mkldnn}); return prog; } @@ -158,7 +164,8 @@ ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale, // a->Concat->b // b->Dequant(scale1)->c // c->Quant(scale2)->d -// d->Conv->e +// d->Conv1->e +// d->Conv2->f ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out, float scale1, float scale2) { ProgramDesc prog; @@ -167,9 +174,12 @@ ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out, } SetOp(&prog, "concat", "Concat", {"a"}, {"b"}, use_mkldnn); - SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1); - SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, scale2); - SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out); + SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, {scale1}); + SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, {scale2}); + SetOp(&prog, "conv2d", "Conv1", {"d"}, {"e"}, use_mkldnn, + {scale2, scale_out}); + SetOp(&prog, "conv2d", "Conv2", {"d"}, {"f"}, use_mkldnn, + {scale2, scale_out}); return prog; } @@ -182,9 +192,11 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out, for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); - SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1); - SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn, scale2); + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out}); + SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, + {scale_out, scale1}); + SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn, + {scale_out, scale2}); return prog; } @@ -197,8 +209,8 @@ ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out, for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); - SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out}); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale}); SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn); return prog; } @@ -212,24 +224,24 @@ ProgramDesc BuildFcDequantConcatProgramDesc(bool use_mkldnn, float scale_out, for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out); - SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); + SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, {1.23f, scale_out}); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale}); SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn); return prog; } // a->fc->b // b->Dequant1->c -// b->concat->d +// b->fc->d ProgramDesc BuildFcDequantFcProgramDesc(bool use_mkldnn, float scale_out, float scale) { ProgramDesc prog; for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out); - SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); - SetOp(&prog, "concat", "Concat1", {"b"}, {"d"}, use_mkldnn); + SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, {1.23f, scale_out}); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale}); + SetOp(&prog, "fc", "Fc2", {"b", "w2"}, {"d"}, use_mkldnn, {scale_out, 2.34f}); return prog; } @@ -242,18 +254,16 @@ ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out, for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); - SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out}); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale}); SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn); return prog; } // a->concat->b -// b->Quant1(Scale1)->c -// b->Quant2(Scale2)->d +// b->Quant1(Scale1)->c->fc->f +// b->Quant2(Scale2)->d->fc->g // b->concat->e -// c->fc->f -// d->fc->g ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale, float second_scale) { ProgramDesc prog; @@ -261,11 +271,15 @@ ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale, prog.MutableBlock(0)->Var(v); } SetOp(&prog, "concat", "Concat1", {"a"}, {"b"}, use_mkldnn); - SetOp(&prog, "quantize", "Quantize1", {"b"}, {"c"}, use_mkldnn, first_scale); - SetOp(&prog, "quantize", "Quantize2", {"b"}, {"d"}, use_mkldnn, second_scale); + SetOp(&prog, "quantize", "Quantize1", {"b"}, {"c"}, use_mkldnn, + {first_scale}); + SetOp(&prog, "quantize", "Quantize2", {"b"}, {"d"}, use_mkldnn, + {second_scale}); SetOp(&prog, "concat", "Concat2", {"b"}, {"e"}, use_mkldnn); - SetOp(&prog, "fc", "Fc1", {"c", "w1"}, {"f"}, use_mkldnn, first_scale); - SetOp(&prog, "fc", "Fc2", {"d", "w2"}, {"g"}, use_mkldnn, second_scale); + SetOp(&prog, "fc", "Fc1", {"c", "w1"}, {"f"}, use_mkldnn, + {first_scale, 1.23f}); + SetOp(&prog, "fc", "Fc2", {"d", "w2"}, {"g"}, use_mkldnn, + {second_scale, 2.34f}); return prog; } @@ -279,8 +293,8 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale, prog.MutableBlock(0)->Var(v); } SetOp(&prog, "dequantize", "Dequant", {"a"}, {"b"}, use_mkldnn, - dequant_scale); - SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, scale_scale, bias); + {dequant_scale}); + SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, {scale_scale}, bias); return prog; } @@ -295,7 +309,34 @@ ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn, } SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn); SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, - dequant_scale); + {dequant_scale}); + + return prog; +} + +// a->Requant1->x->Matmul->b +// c->Requant2->d->Fc->e +// f->Requant3->g->Conv->h +// {b,e,h}->Concat->i +ProgramDesc BuildRequantOpProgramDesc(bool use_mkldnn, float requant_scale_in, + float op_scale_in, float op_scale_out) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "requantize", "Requant1", {"a"}, {"x"}, use_mkldnn, + {requant_scale_in, op_scale_in}); + SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn, + {op_scale_in, op_scale_out}); + SetOp(&prog, "requantize", "Requant2", {"c"}, {"d"}, use_mkldnn, + {requant_scale_in, op_scale_in}); + SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, + {op_scale_in, op_scale_out}); + SetOp(&prog, "requantize", "Requant3", {"f"}, {"g"}, use_mkldnn, + {requant_scale_in, op_scale_in}); + SetOp(&prog, "conv2d", "Conv", {"g"}, {"h"}, use_mkldnn, + {op_scale_in, op_scale_out}); + SetOp(&prog, "concat", "Concat", {"b", "e", "h"}, {"i"}, {use_mkldnn}); return prog; } @@ -390,35 +431,31 @@ void IsForceFp32OutputTest(const ProgramDesc& prog, std::string op_type, // From Conv1->d->Dequant->e->Quant->f->Conv2 // To Conv1->d->Conv2 TEST(CpuQuantizeSquashPass, equal_scales) { - auto scale_out = 1.0f; - auto scale = 1.2345f; + auto scale_out = 1.234f; + auto scale = 2.345f; auto use_mkldnn = true; // Remove 4 nodes: Dequant, Quant, e, f auto remove_nodes = 4; - CountNodeTest( - BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale), - remove_nodes); + CountNodeTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale), + remove_nodes); } // From Conv1->d->Dequant->e->Quant->f->Conv2 // First change to Conv1->d->Requant->f->Conv2 // Then Conv1->f->Conv2 TEST(CpuQuantizeSquashPass, unequal_scales) { - auto scale_out = 1.0f; - auto scale1 = 1.2345f; - auto scale2 = 21.0f; + auto scale_out = 1.230f; + auto scale_in = 2.34f; auto use_mkldnn = true; // Remove 4 nodes: Dequant, Quant, e, d auto remove_nodes = 4; - CountNodeTest( - BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), - remove_nodes); + CountNodeTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale_in), + remove_nodes); - EqualScaleTest( - BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), - "Conv1", "Scale_out", scale2); + EqualScaleTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale_in), + "Conv1", "Scale_out", scale_in); } // a->Conv->b->Requant->c @@ -635,6 +672,20 @@ TEST(CpuQuantizeSquashPass, matmul_with_dequant) { IsForceFp32OutputTest( BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale), "matmul", true); } + +TEST(CpuQuantizeSquashPass, requantize_with_matmul_fc_conv) { + auto use_mkldnn = true; + auto requant_scale_in = 1.2f, op_scale_in = 2.3f, op_scale_out = 3.4f; + // remove: 3 requant ops + 3 requant outs + auto remove_nodes = 6; + auto program_desc = BuildRequantOpProgramDesc(use_mkldnn, requant_scale_in, + op_scale_in, op_scale_out); + CountNodeTest(program_desc, remove_nodes); + EqualScaleTest(program_desc, "Matmul", "Scale_x", requant_scale_in); + EqualScaleTest(program_desc, "Fc", "Scale_in", requant_scale_in); + EqualScaleTest(program_desc, "Conv", "Scale_in", requant_scale_in); +} + } // namespace ir } // namespace framework } // namespace paddle