Unverified commit 5ee099ca, authored by joanna.wozna.intel, committed by GitHub

Op-requant squash (#23665)

* Op-requant squash

test=develop

* Add matmul to op-requant test

test=develop
Parent 9bc1e0a1
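
The commit generalizes the old conv-requant squash: any operator that already carries a "Scale_out" attribute (conv2d, fc, matmul) can absorb a following requantize op by adopting the requantize output scale and writing directly to the requantize's output variable. A minimal standalone sketch of that idea, using toy STL types rather than the Paddle IR API (all names below are hypothetical, not from the patch):

#include <map>
#include <string>
#include <vector>

// Toy stand-in for an operator description: an attribute map plus an
// output map from slot name (e.g. "Out", "Output") to variable names.
struct ToyOp {
  std::map<std::string, float> attrs;
  std::map<std::string, std::vector<std::string>> outputs;
};

// Fold `requant` into `prev`: copy the requantize output scale onto the
// preceding op and rewire its output slot past the requantize op.
// Returns false when `prev` cannot carry an output scale.
bool SquashRequant(ToyOp* prev, const ToyOp& requant,
                   const std::string& requant_in,
                   const std::string& requant_out) {
  if (prev->attrs.count("Scale_out") == 0) return false;
  for (auto& slot : prev->outputs) {
    for (auto& var : slot.second) {
      if (var == requant_in) {
        prev->attrs["Scale_out"] = requant.attrs.at("Scale_out");
        slot.second = {requant_out};  // bypass the requantize op
        return true;
      }
    }
  }
  return false;  // prev does not produce the requantize input
}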
@@ -1490,20 +1490,22 @@ PDNode *patterns::ConvConcatReLU::operator()() {
   return relu_out;
 }
 
-PDNode *patterns::ConvRequant::operator()() {
-  // Create Operators
-  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+PDNode *patterns::OpRequant::operator()() {
+  auto any_op = pattern->NewNode(any_op_repr())
+                    ->assert_is_op()
+                    ->assert_more([&](Node *node) {
+                      return node->Op()->HasAttr("Scale_out") ? true : false;
+                    });
+  auto requant_in = pattern->NewNode(requant_in_repr())
+                        ->assert_is_op_input("requantize", "Input");
   auto requant_op =
       pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
-  auto conv_out = pattern->NewNode(conv_out_repr())
-                      ->assert_is_op_output("conv2d", "Output");
   auto requant_out = pattern->NewNode(requant_out_repr())
                          ->AsOutput()
                          ->assert_is_op_output("requantize", "Output");
 
-  conv_op->LinksTo({conv_out});
-  requant_op->LinksFrom({conv_out}).LinksTo({requant_out});
+  any_op->LinksTo({requant_in});
+  requant_op->LinksFrom({requant_in}).LinksTo({requant_out});
   return requant_out;
 }
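
The `assert_more` hook above is what makes the pattern operator-agnostic: it attaches an arbitrary predicate to the matched node. As a hedged illustration of the same mechanism (this helper is made up; only `NewNode`, `assert_is_op`, and `assert_more` come from the code above):

// Hypothetical helper: declare a PDNode matching any op that carries the
// given attribute, mirroring the assert_is_op()/assert_more() chain above.
PDNode *AnyOpWithAttr(PDPattern *pattern, const std::string &attr_name) {
  return pattern->NewNode("any_op_with_" + attr_name)
      ->assert_is_op()
      ->assert_more([attr_name](Node *node) {
        return node->Op()->HasAttr(attr_name);
      });
}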
@@ -897,19 +897,18 @@ struct ConvConcatReLU : public PatternBase {
   PATTERN_DECL_NODE(relu_out);
 };
 
-// Conv + Requant
+// Op + Requant
 // named nodes:
-// conv_op, conv_out
+// any_op, requant_in
 // requant_op, requant_out
-struct ConvRequant : public PatternBase {
-  ConvRequant(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_requant") {}
+struct OpRequant : public PatternBase {
+  OpRequant(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "op_requant") {}
 
   PDNode* operator()();
 
-  PATTERN_DECL_NODE(conv_op);
-  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(any_op);
+  PATTERN_DECL_NODE(requant_in);
   PATTERN_DECL_NODE(requant_op);
   PATTERN_DECL_NODE(requant_out);
 };
@@ -126,38 +126,47 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
                  found_dequant_quant_count);
 }
 
-void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
+// op+requant squash if op has Scale_out attr
+// conv2d and fc
+void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
   GraphPatternDetector gpd;
-  patterns::ConvRequant conv_requant_pattern{gpd.mutable_pattern(),
-                                             "conv_requant"};
-  conv_requant_pattern();
+  patterns::OpRequant op_requant_pattern{gpd.mutable_pattern(), "op_requant"};
+  op_requant_pattern();
 
   int found_requant_squash_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "squash conv-requantize ops pair";
+    VLOG(4) << "squash op-requantize ops pair";
 
-    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, op_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, op_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, op_requant_pattern);
 
-    // if conv2d has one output squash
-    if (conv_out->outputs.size() == 1) {
+    if (requant_in->outputs.size() == 1) {
+      std::string any_op_output_name;
+      for (auto name : any_op->Op()->OutputNames())
+        for (auto output_name : any_op->Op()->Output(name))
+          if (output_name == requant_in->Name()) any_op_output_name = name;
+
+      PADDLE_ENFORCE_NE(
+          any_op_output_name.empty(), true,
+          platform::errors::NotFound("Operator before requantize operator "
+                                     "should have requantize input as output"));
+
       float requant_scale_out =
           boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
-      conv_op->Op()->SetAttr("Scale_out", requant_scale_out);
-      conv_op->Op()->SetOutput("Output",
-                               std::vector<std::string>({requant_out->Name()}));
-      IR_NODE_LINK_TO(conv_op, requant_out);
-      GraphSafeRemoveNodes(graph, {conv_out, requant_op});
+      any_op->Op()->SetAttr("Scale_out", requant_scale_out);
+      any_op->Op()->SetOutput(any_op_output_name,
+                              std::vector<std::string>({requant_out->Name()}));
+      IR_NODE_LINK_TO(any_op, requant_out);
+      GraphSafeRemoveNodes(graph, {requant_in, requant_op});
      found_requant_squash_count++;
    }
  };
  gpd(graph, handler);
  AddStatis(found_requant_squash_count);
-  PrettyLogDetail("--- squashed %d requantize with convs",
+  PrettyLogDetail("--- squashed %d requantize ops",
                  found_requant_squash_count);
 }
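
The nested loop in the handler above answers one question: which output slot of `any_op` feeds the requantize input. Stripped of the IR types, the search reduces to a walk over a slot-name-to-variables map; a minimal sketch with plain STL containers standing in for `OpDesc` (the helper name is made up):

#include <map>
#include <string>
#include <vector>

// Return the output slot of an op whose variable list contains target_var,
// or "" when no slot matches (the pass turns that case into an enforce).
std::string FindOutputSlot(
    const std::map<std::string, std::vector<std::string>>& outputs,
    const std::string& target_var) {
  for (const auto& slot : outputs) {       // slot.first: slot name, e.g. "Out"
    for (const auto& var : slot.second) {  // variables bound to that slot
      if (var == target_var) return slot.first;
    }
  }
  return "";
}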
@@ -369,7 +378,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_map<const Node*, int> nodes_keep_counter;
   FindNodesToKeep(graph, &nodes_keep_counter);
   DequantQuantSquash(graph, &nodes_keep_counter);
-  ConvRequantSquash(graph);
+  OpRequantSquash(graph);
   ConvDequantSquash(graph);
   FcDequantSquash(graph);
   MultipleQuantizeSquash(graph);
@@ -53,7 +53,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
   /*
   * Squash requantize op into conv with scale_out like requantize scale_out
   */
-  void ConvRequantSquash(Graph* graph) const;
+  void OpRequantSquash(Graph* graph) const;
 
   /*
   * Squash conv2d with dequant when dequant is the only op after conv2d
@@ -59,6 +59,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
                           inputs.size()));
     op->SetInput("W", {inputs[1]});
     op->SetOutput("Out", outputs);
+    op->SetAttr("Scale_out", scale);
   } else if (type == "scale") {
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
@@ -68,6 +69,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetInput("X", {inputs[0]});
     op->SetInput("Y", {inputs[1]});
     op->SetOutput("Out", {outputs[0]});
+    op->SetAttr("Scale_out", scale);
   }
 }
@@ -96,7 +98,7 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
 }
 
 static const std::initializer_list<std::string> variable_names{
-    "a", "b", "c", "d", "e", "f", "g", "h", "x", "y"};
+    "a", "b", "c", "d", "e", "f", "g", "h", "x", "y", "w1"};
 
 // a->Conv1->b
 // b->Dequant(scale1)->c
@@ -125,23 +127,30 @@ ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
   return prog;
 }
 
-// a->Conv1->b->Requant(scale1)->c
-// d->Conv2->e->Requant(scale2)->f
-// {c,f}->Concat
-ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out,
-                                               float scale1, float scale2) {
+// a->Conv->b->Requant(scale1)->c
+// d->Fc->e->Requant(scale2)->f
+// {x,y}->Matmul->g->Requant(scale3)->h
+// {c,f,h}->Concat
+ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale,
+                                      float fc_scale, float matmul_scale,
+                                      float requant_scale1,
+                                      float requant_scale2,
+                                      float requant_scale3) {
   ProgramDesc prog;
   for (auto& v : variable_names) {
     prog.MutableBlock(0)->Var(v);
   }
-  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
-  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
-  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
-  SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2);
-  SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn);
+  SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, conv_scale);
+  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn,
+        requant_scale1);
+  SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, fc_scale);
+  SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn,
+        requant_scale2);
+  SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn, matmul_scale);
+  SetOp(&prog, "requantize", "Requant3", {"g"}, {"h"}, use_mkldnn,
+        requant_scale3);
+  SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, use_mkldnn);
   return prog;
 }
@@ -412,27 +421,28 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
                  "Conv1", "Scale_out", scale2);
 }
 
-// a->Conv1->b->Requant->c
-// d->Conv2->e->Requant->f
-// {c,f}->Concat
-TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
-  // Delete both requantize op
-  auto scale_out = 1.0f;
-  auto scale = 1.2345f;
+// a->Conv->b->Requant->c
+// d->Fc->e->Requant->f
+// {x,y}->Matmul->g->Requant->h
+// {c,f,h}->Concat
+TEST(CpuQuantizeSquashPass, op_requantize_squash) {
+  // Delete all requantize ops
+  auto conv_scale = 0.234f;
+  auto fc_scale = 1.234f;
+  auto matmul_scale = 2.234f;
+  auto requant_scale1 = 2.234f;
+  auto requant_scale2 = 3.234f;
+  auto requant_scale3 = 4.234f;
   auto use_mkldnn = true;
-  // Remove 4 nodes: b, Requant1, e, Requant2
-  auto remove_nodes = 4;
-  CountNodeTest(
-      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      remove_nodes);
-  // check equal scale conv->scale_out and requant->scale_out
-  EqualScaleTest(
-      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv1", "Scale_out", scale);
-  EqualScaleTest(
-      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv2", "Scale_out", scale);
+  // Remove 6 nodes: b, Requant1, e, Requant2, g, Requant3
+  auto remove_nodes = 6;
+  auto program_desc =
+      BuildOpRequantProgramDesc(use_mkldnn, conv_scale, fc_scale, matmul_scale,
+                                requant_scale1, requant_scale2, requant_scale3);
+  CountNodeTest(program_desc, remove_nodes);
+  EqualScaleTest(program_desc, "Conv", "Scale_out", requant_scale1);
+  EqualScaleTest(program_desc, "Fc", "Scale_out", requant_scale2);
+  EqualScaleTest(program_desc, "Matmul", "Scale_out", requant_scale3);
 }
 
 // from