未验证 提交 356f5ee2 编写于 作者: J joanna.wozna.intel 提交者: GitHub

[Refactoring] Unify op-dequant squashes (#24277)

上级 ac9a7eee
...@@ -1529,39 +1529,24 @@ PDNode *patterns::RequantOp::operator()() { ...@@ -1529,39 +1529,24 @@ PDNode *patterns::RequantOp::operator()() {
return any_op; return any_op;
} }
PDNode *patterns::ConvDequant::operator()() { PDNode *patterns::OpDequant::operator()() {
// Create Operators auto any_op = pattern->NewNode(any_op_repr())
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); ->assert_is_op()
auto dequant_op = ->assert_more([&](Node *node) {
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); return (node->Op()->Type() == "matmul" ||
node->Op()->Type() == "conv2d" ||
auto conv_out = pattern->NewNode(conv_out_repr()) node->Op()->Type() == "fc");
->assert_is_op_output("conv2d", "Output"); });
auto dequant_out = pattern->NewNode(dequant_out_repr()) auto dequant_in = pattern->NewNode(dequant_in_repr())
->AsOutput() ->assert_is_op_input("dequantize", "Input");
->assert_is_op_output("dequantize", "Output");
conv_op->LinksTo({conv_out});
dequant_op->LinksFrom({conv_out}).LinksTo({dequant_out});
return dequant_out;
}
PDNode *patterns::FcDequant::operator()() {
// Create Operators
auto fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
auto dequant_op = auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto fc_out =
pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
auto dequant_out = pattern->NewNode(dequant_out_repr()) auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("dequantize", "Output"); ->assert_is_op_output("dequantize", "Output");
fc_op->LinksTo({fc_out}); any_op->LinksTo({dequant_in});
dequant_op->LinksFrom({fc_out}).LinksTo({dequant_out}); dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out});
return dequant_out; return dequant_out;
} }
...@@ -1584,23 +1569,6 @@ PDNode *patterns::DequantScale::operator()() { ...@@ -1584,23 +1569,6 @@ PDNode *patterns::DequantScale::operator()() {
return scale_out; return scale_out;
} }
PDNode *patterns::MatmulDequant::operator()() {
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
matmul_op->LinksTo({matmul_out});
dequant_op->LinksFrom({matmul_out}).LinksTo({dequant_out});
return dequant_out;
}
PDNode *patterns::ScaleMatmul::operator()() { PDNode *patterns::ScaleMatmul::operator()() {
auto scale_in = pattern->NewNode(scale_in_repr()) auto scale_in = pattern->NewNode(scale_in_repr())
->AsInput() ->AsInput()
......
...@@ -929,33 +929,18 @@ struct RequantOp : public PatternBase { ...@@ -929,33 +929,18 @@ struct RequantOp : public PatternBase {
PATTERN_DECL_NODE(requant_out); PATTERN_DECL_NODE(requant_out);
}; };
// Conv + Dequant // Op + Dequant
// named nodes: // named nodes:
// conv_op, conv_out // any_op, dequant_in
// dequant_op, dequant_out // dequant_op, dequant_out
struct ConvDequant : public PatternBase { struct OpDequant : public PatternBase {
ConvDequant(PDPattern* pattern, const std::string& name_scope) OpDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_dequant") {} : PatternBase(pattern, name_scope, "op_dequant") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};
// Fc + Dequant
struct FcDequant : public PatternBase {
FcDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_dequant") {}
PDNode* operator()(); PDNode* operator()();
PATTERN_DECL_NODE(fc_op); PATTERN_DECL_NODE(any_op);
PATTERN_DECL_NODE(fc_out); PATTERN_DECL_NODE(dequant_in);
PATTERN_DECL_NODE(dequant_op); PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out); PATTERN_DECL_NODE(dequant_out);
}; };
...@@ -974,20 +959,6 @@ struct DequantScale : public PatternBase { ...@@ -974,20 +959,6 @@ struct DequantScale : public PatternBase {
PATTERN_DECL_NODE(scale_out); PATTERN_DECL_NODE(scale_out);
}; };
// Matmul + Dequantize
struct MatmulDequant : public PatternBase {
MatmulDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_dequant") {}
PDNode* operator()();
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};
// Scale + Matmul // Scale + Matmul
struct ScaleMatmul : public PatternBase { struct ScaleMatmul : public PatternBase {
ScaleMatmul(PDPattern* pattern, const std::string& name_scope) ScaleMatmul(PDPattern* pattern, const std::string& name_scope)
......
...@@ -223,71 +223,44 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -223,71 +223,44 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
found_requant_squash_count); found_requant_squash_count);
} }
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const { // squash dequant with previous op if that op has force_fp32_output attr
GraphPatternDetector gpd; // conv2d, fc, matmul
patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(), void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
"conv_dequant"};
conv_dequant_pattern();
int found_conv_dequant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash conv-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);
// if conv2d has one output
// and there is no fuse residual connection
// because residual fusion does not support force output with fp32
if (conv_out->outputs.size() == 1 &&
!(conv_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))) {
conv_op->Op()->SetAttr("force_fp32_output", true);
conv_op->Op()->SetOutput("Output",
std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(conv_op, dequant_out);
GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
found_conv_dequant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_conv_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with convs",
found_conv_dequant_squash_count);
}
// squash fc with dequant
void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::FcDequant fc_dequant_pattern{gpd.mutable_pattern(), "fc_dequant"}; patterns::OpDequant op_dequant_pattern{gpd.mutable_pattern(), "op_dequant"};
fc_dequant_pattern(); op_dequant_pattern();
int found_fc_dequant_squash_count = 0; int found_op_dequant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "squash fc-dequant ops pair"; VLOG(4) << "squash op-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc_dequant_pattern); GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_dequant_pattern); GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, op_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, fc_dequant_pattern); GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, op_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, fc_dequant_pattern); GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern);
// if fc has force_fp32_output attribute if (dequant_in->outputs.size() == 1) {
if (fc_out->outputs.size() == 1) { auto output_name = "Out";
fc_op->Op()->SetAttr("force_fp32_output", true); if (any_op->Op()->Type() == "conv2d") {
fc_op->Op()->SetOutput("Out", // do not squash if fuse residual connection is true
std::vector<std::string>({dequant_out->Name()})); // because residual fusion does not support force output with fp32
IR_NODE_LINK_TO(fc_op, dequant_out); if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))
GraphSafeRemoveNodes(graph, {fc_out, dequant_op}); return;
found_fc_dequant_squash_count++; output_name = "Output";
}
any_op->Op()->SetAttr("force_fp32_output", true);
any_op->Op()->SetOutput(output_name,
std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(any_op, dequant_out);
GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
found_op_dequant_squash_count++;
} }
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(found_fc_dequant_squash_count); AddStatis(found_op_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with fcs", PrettyLogDetail("--- squashed %d dequant with ops",
found_fc_dequant_squash_count); found_op_dequant_squash_count);
} }
void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
...@@ -389,38 +362,6 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { ...@@ -389,38 +362,6 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
found_dequant_scale_squash_count); found_dequant_scale_squash_count);
} }
// squash matmul with dequant
void CPUQuantizeSquashPass::MatmulDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::MatmulDequant matmul_dequant_pattern{gpd.mutable_pattern(),
"matmul_dequant"};
matmul_dequant_pattern();
int found_matmul_dequant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash matmul-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, matmul_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, matmul_dequant_pattern);
if (matmul_out->outputs.size() == 1) {
matmul_op->Op()->SetAttr("force_fp32_output", true);
matmul_op->Op()->SetOutput(
"Out", std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(matmul_op, dequant_out);
GraphSafeRemoveNodes(graph, {matmul_out, dequant_op});
found_matmul_dequant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_matmul_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with matmul",
found_matmul_dequant_squash_count);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, graph,
...@@ -433,11 +374,9 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { ...@@ -433,11 +374,9 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
DequantQuantSquash(graph, &nodes_keep_counter); DequantQuantSquash(graph, &nodes_keep_counter);
OpRequantSquash(graph); OpRequantSquash(graph);
RequantOpSquash(graph); RequantOpSquash(graph);
ConvDequantSquash(graph); OpDequantSquash(graph);
FcDequantSquash(graph);
MultipleQuantizeSquash(graph); MultipleQuantizeSquash(graph);
DequantScaleSquash(graph); DequantScaleSquash(graph);
MatmulDequantSquash(graph);
} }
} // namespace ir } // namespace ir
......
...@@ -61,29 +61,19 @@ class CPUQuantizeSquashPass : public FusePassBase { ...@@ -61,29 +61,19 @@ class CPUQuantizeSquashPass : public FusePassBase {
void RequantOpSquash(Graph* graph) const; void RequantOpSquash(Graph* graph) const;
/* /*
* Squash conv2d with dequant when dequant is the only op after conv2d * Squash dequant if the previous operator has force_fp32_output attribute
*/ */
void ConvDequantSquash(Graph* graph) const; void OpDequantSquash(Graph* graph) const;
/*
* Squash fc with dequant when dequant is the next op after fc
*/
void FcDequantSquash(Graph* graph) const;
/* /*
* Squash quantize if several quantize ops have the same scale * Squash quantize if several quantize ops have the same scale
*/ */
void MultipleQuantizeSquash(Graph* graph) const; void MultipleQuantizeSquash(Graph* graph) const;
/* /*
* Squash scale if dequantize is before scale * Squash scale if dequantize is before scale
*/
void DequantScaleSquash(Graph* graph) const;
/*
* Squash dequantize if it is after matmul
*/ */
void MatmulDequantSquash(Graph* graph) const; void DequantScaleSquash(Graph* graph) const;
const std::string name_scope_{"squash"}; const std::string name_scope_{"squash"};
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册