Unverified commit b43b46e6, authored by joanna.wozna.intel, committed by GitHub

[INT8] Add requant-op squash (#24143)

Parent e8869a90
...@@ -1508,6 +1508,27 @@ PDNode *patterns::OpRequant::operator()() {
return requant_out;
}
PDNode *patterns::RequantOp::operator()() {
auto requant_in = pattern->NewNode(requant_in_repr())
->assert_is_op_input("requantize", "Input");
auto requant_op =
pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
auto requant_out = pattern->NewNode(requant_out_repr())
->AsOutput()
->assert_is_op_output("requantize", "Output");
auto any_op = pattern->NewNode(any_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return (node->Op()->HasAttr("Scale_in") ||
node->Op()->HasAttr("Scale_x") ||
node->Op()->HasAttr("Scale_y"));
});
requant_op->LinksFrom({requant_in}).LinksTo({requant_out});
any_op->LinksFrom({requant_out});
return any_op;
}
PDNode *patterns::ConvDequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
...
...@@ -913,6 +913,22 @@ struct OpRequant : public PatternBase {
PATTERN_DECL_NODE(requant_out);
};
// Requant + Op
// named nodes:
// requant_in, requant_op,
// requant_out, any_op
struct RequantOp : public PatternBase {
RequantOp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "requant_op") {}
PDNode* operator()();
PATTERN_DECL_NODE(any_op);
PATTERN_DECL_NODE(requant_in);
PATTERN_DECL_NODE(requant_op);
PATTERN_DECL_NODE(requant_out);
};
// Conv + Dequant
// named nodes:
// conv_op, conv_out
...
...@@ -152,7 +152,7 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
PADDLE_ENFORCE_NE(
any_op_output_name.empty(), true,
platform::errors::NotFound("Operator before requantize operator "
"should have requantize input as output"));
float requant_scale_out =
boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
...@@ -170,6 +170,59 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
found_requant_squash_count);
}
// requant-op squash if op has Scale_in, Scale_x, Scale_y attr
// conv2d, fc, matmul
void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::RequantOp requant_op_pattern{gpd.mutable_pattern(), "requant_op"};
requant_op_pattern();
int found_requant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash requantize-op ops pair";
GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern);
if (requant_out->outputs.size() == 1) {
std::string any_op_input_name;
for (auto name : any_op->Op()->InputNames())
for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name;
PADDLE_ENFORCE_NE(
any_op_input_name.empty(), true,
platform::errors::NotFound("The operator after requantize operator "
"should have requantize output as input"));
float requant_scale_in =
boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
auto scale_name = "Scale_in";
if (any_op->Op()->Type() == "matmul")
scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
any_op->Op()->GetAttrIfExists<float>(scale_name),
platform::errors::InvalidArgument(
"The operator after requantize should have input "
"scale equal to requantize output scale"));
any_op->Op()->SetAttr(scale_name, requant_scale_in);
any_op->Op()->SetInput(any_op_input_name,
std::vector<std::string>({requant_in->Name()}));
IR_NODE_LINK_TO(requant_in, any_op);
GraphSafeRemoveNodes(graph, {requant_op, requant_out});
found_requant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_requant_squash_count);
PrettyLogDetail("--- squashed %d requantize ops",
found_requant_squash_count);
}
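For context on the pass above: RequantOpSquash matches a requantize op followed by a consumer carrying an input-scale attribute (Scale_in for conv2d/fc, Scale_x/Scale_y for matmul), checks that the requantize Scale_out already equals that attribute, then removes the requantize pair and writes the requantize Scale_in into the consumer. Below is a minimal standalone sketch of that scale update, using hypothetical struct names and no Paddle APIs; it illustrates the idea only and is not the pass itself.

// Standalone illustration (hypothetical types, not Paddle code) of the scale
// update performed by RequantOpSquash when a requantize op is folded into the
// consumer that follows it.
#include <cassert>
#include <cstdio>

struct Requantize { float scale_in; float scale_out; };
struct ConsumerOp { float scale_in; };  // stands in for Scale_in / Scale_x / Scale_y

ConsumerOp SquashRequantInto(const Requantize& requant, ConsumerOp op) {
  // Precondition the pass enforces: the consumer already expects the
  // requantize output scale on this input.
  assert(op.scale_in == requant.scale_out);
  // After the squash the consumer reads the requantize input directly,
  // so it takes over the requantize input scale.
  op.scale_in = requant.scale_in;
  return op;
}

int main() {
  Requantize requant{1.2f, 2.3f};  // Scale_in = 1.2, Scale_out = 2.3, as in the unit test below
  ConsumerOp fc{2.3f};             // consumer currently expects scale 2.3
  ConsumerOp squashed = SquashRequantInto(requant, fc);
  std::printf("consumer input scale after squash: %.2f\n", squashed.scale_in);  // prints 1.20
  return 0;
}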
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
...@@ -379,6 +432,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FindNodesToKeep(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
OpRequantSquash(graph);
RequantOpSquash(graph);
ConvDequantSquash(graph);
FcDequantSquash(graph);
MultipleQuantizeSquash(graph);
...
...@@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void OpRequantSquash(Graph* graph) const;
/*
* Squash requantize op if the next operator's input scale can be updated
*/
void RequantOpSquash(Graph* graph) const;
/*
* Squash conv2d with dequant when dequant is the only op after conv2d
*/
...
...@@ -24,13 +24,14 @@ namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
const std::vector<float> scale = {}, float bias = 0.0) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]}); if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
...@@ -38,15 +39,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -38,15 +39,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
} else if (type == "quantize") { } else if (type == "quantize") {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]}); op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale); op->SetAttr("Scale", scale[0]);
} else if (type == "dequantize") { } else if (type == "dequantize") {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]}); op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale); op->SetAttr("Scale", scale[0]);
} else if (type == "requantize") { } else if (type == "requantize") {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]}); op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale_out", scale); op->SetAttr("Scale_in", scale[0]);
op->SetAttr("Scale_out", scale[1]);
} else if (type == "concat") { } else if (type == "concat") {
op->SetInput("X", inputs); op->SetInput("X", inputs);
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
...@@ -59,17 +61,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -59,17 +61,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
inputs.size())); inputs.size()));
op->SetInput("W", {inputs[1]}); op->SetInput("W", {inputs[1]});
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
op->SetAttr("Scale_out", scale); if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
} else if (type == "scale") { } else if (type == "scale") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
op->SetAttr("scale", scale); op->SetAttr("scale", scale[0]);
op->SetAttr("bias", bias); op->SetAttr("bias", bias);
} else if (type == "matmul") { } else if (type == "matmul") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]}); op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
op->SetAttr("Scale_out", scale); if (scale.size() > 0) op->SetAttr("Scale_x", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
}
}
...@@ -78,7 +82,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
// e->Quant(scale2)->f
// (f,w2,b2)->Conv2->i
ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
float scale_in) {
ProgramDesc prog;
for (auto& v : std::initializer_list<std::string>(
{"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
...@@ -89,22 +93,22 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
}
SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn,
{1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, {scale_out});
SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, {scale_in});
SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn,
{scale_in, 2.34f});
return prog;
}
static const std::initializer_list<std::string> variable_names{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "x", "y", "w1", "w2"};
// a->Conv1(scale1)->b
// b->Dequant(scale1)->c
// c->Quant1(scale2)->d and d->(scale2)Conv2->e
// c->Conv3->f
// c->Quant2(scale3)->g and g->Concat->h
ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2,
float scale3) {
...@@ -113,16 +117,17 @@ ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale1});
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, {scale1});
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, {scale2});
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn,
{scale2, scale_out});
SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);
SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, {scale3});
SetOp(&prog, "concat", "Concat", {"g"}, {"h"}, use_mkldnn);
return prog;
}
...@@ -141,16 +146,17 @@ ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, {1.23f, conv_scale});
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn,
{conv_scale, requant_scale1});
SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, {1.23f, fc_scale});
SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn,
{fc_scale, requant_scale2});
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn,
{1.23f, matmul_scale});
SetOp(&prog, "requantize", "Requant3", {"g"}, {"h"}, use_mkldnn,
{matmul_scale, requant_scale3});
SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, {use_mkldnn});
return prog;
}
...@@ -158,7 +164,8 @@ ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale,
// a->Concat->b
// b->Dequant(scale1)->c
// c->Quant(scale2)->d
// d->Conv1->e
// d->Conv2->f
ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
ProgramDesc prog;
...@@ -167,9 +174,12 @@ ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out,
}
SetOp(&prog, "concat", "Concat", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, {scale1});
SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, {scale2});
SetOp(&prog, "conv2d", "Conv1", {"d"}, {"e"}, use_mkldnn,
{scale2, scale_out});
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"f"}, use_mkldnn,
{scale2, scale_out});
return prog;
}
...@@ -182,9 +192,11 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn,
{scale_out, scale1});
SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn,
{scale_out, scale2});
return prog;
}
...@@ -197,8 +209,8 @@ ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
return prog;
}
...@@ -212,24 +224,24 @@ ProgramDesc BuildFcDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
return prog;
}
// a->fc->b
// b->Dequant1->c
// b->fc->d
ProgramDesc BuildFcDequantFcProgramDesc(bool use_mkldnn, float scale_out,
float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "fc", "Fc2", {"b", "w2"}, {"d"}, use_mkldnn, {scale_out, 2.34f});
return prog;
}
...@@ -242,18 +254,16 @@ ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);
return prog;
}
// a->concat->b
// b->Quant1(Scale1)->c->fc->f
// b->Quant2(Scale2)->d->fc->g
// b->concat->e
ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
float second_scale) {
ProgramDesc prog;
...@@ -261,11 +271,15 @@ ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "concat", "Concat1", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "quantize", "Quantize1", {"b"}, {"c"}, use_mkldnn,
{first_scale});
SetOp(&prog, "quantize", "Quantize2", {"b"}, {"d"}, use_mkldnn,
{second_scale});
SetOp(&prog, "concat", "Concat2", {"b"}, {"e"}, use_mkldnn);
SetOp(&prog, "fc", "Fc1", {"c", "w1"}, {"f"}, use_mkldnn,
{first_scale, 1.23f});
SetOp(&prog, "fc", "Fc2", {"d", "w2"}, {"g"}, use_mkldnn,
{second_scale, 2.34f});
return prog;
}
...@@ -279,8 +293,8 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dequantize", "Dequant", {"a"}, {"b"}, use_mkldnn,
{dequant_scale});
SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, {scale_scale}, bias);
return prog;
}
...@@ -295,7 +309,34 @@ ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn,
}
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn,
{dequant_scale});
return prog;
}
// a->Requant1->x->Matmul->b
// c->Requant2->d->Fc->e
// f->Requant3->g->Conv->h
// {b,e,h}->Concat->i
ProgramDesc BuildRequantOpProgramDesc(bool use_mkldnn, float requant_scale_in,
float op_scale_in, float op_scale_out) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "requantize", "Requant1", {"a"}, {"x"}, use_mkldnn,
{requant_scale_in, op_scale_in});
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn,
{op_scale_in, op_scale_out});
SetOp(&prog, "requantize", "Requant2", {"c"}, {"d"}, use_mkldnn,
{requant_scale_in, op_scale_in});
SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn,
{op_scale_in, op_scale_out});
SetOp(&prog, "requantize", "Requant3", {"f"}, {"g"}, use_mkldnn,
{requant_scale_in, op_scale_in});
SetOp(&prog, "conv2d", "Conv", {"g"}, {"h"}, use_mkldnn,
{op_scale_in, op_scale_out});
SetOp(&prog, "concat", "Concat", {"b", "e", "h"}, {"i"}, {use_mkldnn});
return prog;
}
...@@ -390,14 +431,13 @@ void IsForceFp32OutputTest(const ProgramDesc& prog, std::string op_type,
// From Conv1->d->Dequant->e->Quant->f->Conv2
// To Conv1->d->Conv2
TEST(CpuQuantizeSquashPass, equal_scales) {
auto scale_out = 1.234f;
auto scale = 2.345f;
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, f
auto remove_nodes = 4;
CountNodeTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale),
remove_nodes);
}
...@@ -405,20 +445,17 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
// First change to Conv1->d->Requant->f->Conv2
// Then Conv1->f->Conv2
TEST(CpuQuantizeSquashPass, unequal_scales) {
auto scale_out = 1.230f;
auto scale_in = 2.34f;
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, d
auto remove_nodes = 4;
CountNodeTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale_in),
remove_nodes);
EqualScaleTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale_in),
"Conv1", "Scale_out", scale_in);
}
// a->Conv->b->Requant->c
...@@ -635,6 +672,20 @@ TEST(CpuQuantizeSquashPass, matmul_with_dequant) {
IsForceFp32OutputTest(
BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale), "matmul", true);
}
TEST(CpuQuantizeSquashPass, requantize_with_matmul_fc_conv) {
auto use_mkldnn = true;
auto requant_scale_in = 1.2f, op_scale_in = 2.3f, op_scale_out = 3.4f;
// remove: 3 requant ops + 3 requant outs
auto remove_nodes = 6;
auto program_desc = BuildRequantOpProgramDesc(use_mkldnn, requant_scale_in,
op_scale_in, op_scale_out);
CountNodeTest(program_desc, remove_nodes);
EqualScaleTest(program_desc, "Matmul", "Scale_x", requant_scale_in);
EqualScaleTest(program_desc, "Fc", "Scale_in", requant_scale_in);
EqualScaleTest(program_desc, "Conv", "Scale_in", requant_scale_in);
}
} // namespace ir
} // namespace framework
} // namespace paddle
...