提交 2e3ec66b 编写于 作者: J joanna.wozna.intel 提交者: Tao Luo

Add conv dequant squash for int8 (#18905)

上级 482ce818
...@@ -1267,6 +1267,24 @@ PDNode *patterns::ConvRequant::operator()() { ...@@ -1267,6 +1267,24 @@ PDNode *patterns::ConvRequant::operator()() {
return requant_out; return requant_out;
} }
PDNode *patterns::ConvDequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
conv_op->LinksTo({conv_out});
dequant_op->LinksFrom({conv_out}).LinksTo({dequant_out});
return dequant_out;
}
PDNode *patterns::PriorBox::operator()() { PDNode *patterns::PriorBox::operator()() {
auto prior_box_op = auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
......
...@@ -793,6 +793,23 @@ struct ConvRequant : public PatternBase { ...@@ -793,6 +793,23 @@ struct ConvRequant : public PatternBase {
PATTERN_DECL_NODE(requant_out); PATTERN_DECL_NODE(requant_out);
}; };
// Conv + Dequant
// named nodes:
// conv_op, conv_out
// dequant_op, dequant_out
struct ConvDequant : public PatternBase {
ConvDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_dequant") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};
// PriorBox operator // PriorBox operator
// operator: prior_box_op // operator: prior_box_op
// inputs: prior_box_input, prior_box_image // inputs: prior_box_input, prior_box_image
......
...@@ -160,6 +160,38 @@ void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const { ...@@ -160,6 +160,38 @@ void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
found_requant_squash_count); found_requant_squash_count);
} }
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
"conv_dequant"};
conv_dequant_pattern();
int found_conv_dequant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash conv-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);
// if conv2d has one output
if (conv_out->outputs.size() == 1) {
conv_op->Op()->SetAttr("force_fp32_output", true);
conv_op->Op()->SetOutput("Output",
std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(conv_op, dequant_out);
GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
found_conv_dequant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_conv_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with convs",
found_conv_dequant_squash_count);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
FusePassBase::Init("cpu_quantize_squash_pass", graph); FusePassBase::Init("cpu_quantize_squash_pass", graph);
...@@ -168,6 +200,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { ...@@ -168,6 +200,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FindNodesToKeep(graph, &nodes_keep_counter); FindNodesToKeep(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter); DequantQuantSquash(graph, &nodes_keep_counter);
ConvRequantSquash(graph); ConvRequantSquash(graph);
ConvDequantSquash(graph);
} }
} // namespace ir } // namespace ir
......
...@@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase { ...@@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/ */
void ConvRequantSquash(Graph* graph) const; void ConvRequantSquash(Graph* graph) const;
/*
* Squash conv2d with dequant when dequant is the only op after conv2d
*/
void ConvDequantSquash(Graph* graph) const;
const std::string name_scope_{"squash"}; const std::string name_scope_{"squash"};
}; };
......
...@@ -161,6 +161,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out, ...@@ -161,6 +161,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return prog; return prog;
} }
// a->Conv1->b
// b->Dequant1(Scale1)->c
// c->Concat
ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
return prog;
}
// a->Conv1->b
// b->Dequant1(Scale1)->c
// b->Conv2->d
ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);
return prog;
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) { const char* var_name) {
auto x = scope->Var(var_name); auto x = scope->Var(var_name);
...@@ -217,6 +247,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name, ...@@ -217,6 +247,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in, void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
float scale_out) { float scale_out) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
PrepareGraph(&graph, prog); PrepareGraph(&graph, prog);
RegisterPass(&graph); RegisterPass(&graph);
...@@ -238,6 +269,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) { ...@@ -238,6 +269,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
auto use_mkldnn = true; auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, f // Remove 4 nodes: Dequant, Quant, e, f
auto remove_nodes = 4; auto remove_nodes = 4;
CountNodeTest( CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale), BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
remove_nodes); remove_nodes);
...@@ -253,6 +285,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) { ...@@ -253,6 +285,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
auto use_mkldnn = true; auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, d // Remove 4 nodes: Dequant, Quant, e, d
auto remove_nodes = 4; auto remove_nodes = 4;
CountNodeTest( CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
remove_nodes); remove_nodes);
...@@ -280,6 +313,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) { ...@@ -280,6 +313,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
// Remove 3 nodes: Quant1, c, Quant2, // Remove 3 nodes: Quant1, c, Quant2,
// Insert 1 node: Requant // Insert 1 node: Requant
auto remove_nodes = 2; auto remove_nodes = 2;
CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale, CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale,
scale, scale2), scale, scale2),
remove_nodes); remove_nodes);
...@@ -322,6 +356,7 @@ TEST(CpuQuantizeSquashPass, ...@@ -322,6 +356,7 @@ TEST(CpuQuantizeSquashPass,
// Remove 3 nodes: Dequant1, c, Quant // Remove 3 nodes: Dequant1, c, Quant
// Insert 1 node: Requant // Insert 1 node: Requant
auto remove_nodes = 2; auto remove_nodes = 2;
CountNodeTest( CountNodeTest(
BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2), BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
remove_nodes); remove_nodes);
...@@ -345,6 +380,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) { ...@@ -345,6 +380,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
remove_nodes); remove_nodes);
} }
// a->Conv1->c->Concat
TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto use_mkldnn = true;
// remove 2 nodes: Dequant1, c
auto remove_nodes = 2;
CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
remove_nodes);
}
TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto use_mkldnn = true;
// nothing change
auto remove_nodes = 0;
CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale),
remove_nodes);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册