未验证 提交 356f5ee2 编写于 作者: J joanna.wozna.intel 提交者: GitHub

[Refactoring] Unify op-dequant squashes (#24277)

上级 ac9a7eee
......@@ -1529,39 +1529,24 @@ PDNode *patterns::RequantOp::operator()() {
return any_op;
}
PDNode *patterns::ConvDequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
conv_op->LinksTo({conv_out});
dequant_op->LinksFrom({conv_out}).LinksTo({dequant_out});
return dequant_out;
}
PDNode *patterns::FcDequant::operator()() {
// Create Operators
auto fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
PDNode *patterns::OpDequant::operator()() {
auto any_op = pattern->NewNode(any_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return (node->Op()->Type() == "matmul" ||
node->Op()->Type() == "conv2d" ||
node->Op()->Type() == "fc");
});
auto dequant_in = pattern->NewNode(dequant_in_repr())
->assert_is_op_input("dequantize", "Input");
auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto fc_out =
pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
fc_op->LinksTo({fc_out});
dequant_op->LinksFrom({fc_out}).LinksTo({dequant_out});
any_op->LinksTo({dequant_in});
dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out});
return dequant_out;
}
......@@ -1584,23 +1569,6 @@ PDNode *patterns::DequantScale::operator()() {
return scale_out;
}
PDNode *patterns::MatmulDequant::operator()() {
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
matmul_op->LinksTo({matmul_out});
dequant_op->LinksFrom({matmul_out}).LinksTo({dequant_out});
return dequant_out;
}
PDNode *patterns::ScaleMatmul::operator()() {
auto scale_in = pattern->NewNode(scale_in_repr())
->AsInput()
......
......@@ -929,33 +929,18 @@ struct RequantOp : public PatternBase {
PATTERN_DECL_NODE(requant_out);
};
// Conv + Dequant
// Op + Dequant
// named nodes:
// conv_op, conv_out
// any_op, dequant_in
// dequant_op, dequant_out
struct ConvDequant : public PatternBase {
ConvDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_dequant") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};
// Fc + Dequant
struct FcDequant : public PatternBase {
FcDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_dequant") {}
struct OpDequant : public PatternBase {
OpDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "op_dequant") {}
PDNode* operator()();
PATTERN_DECL_NODE(fc_op);
PATTERN_DECL_NODE(fc_out);
PATTERN_DECL_NODE(any_op);
PATTERN_DECL_NODE(dequant_in);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};
......@@ -974,20 +959,6 @@ struct DequantScale : public PatternBase {
PATTERN_DECL_NODE(scale_out);
};
// Matmul + Dequantize
struct MatmulDequant : public PatternBase {
MatmulDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_dequant") {}
PDNode* operator()();
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};
// Scale + Matmul
struct ScaleMatmul : public PatternBase {
ScaleMatmul(PDPattern* pattern, const std::string& name_scope)
......
......@@ -223,71 +223,44 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
found_requant_squash_count);
}
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
"conv_dequant"};
conv_dequant_pattern();
int found_conv_dequant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash conv-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);
// if conv2d has one output
// and there is no fuse residual connection
// because residual fusion does not support force output with fp32
if (conv_out->outputs.size() == 1 &&
!(conv_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))) {
conv_op->Op()->SetAttr("force_fp32_output", true);
conv_op->Op()->SetOutput("Output",
std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(conv_op, dequant_out);
GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
found_conv_dequant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_conv_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with convs",
found_conv_dequant_squash_count);
}
// squash fc with dequant
void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const {
// squash dequant with previous op if that op has force_fp32_output attr
// conv2d, fc, matmul
void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::FcDequant fc_dequant_pattern{gpd.mutable_pattern(), "fc_dequant"};
fc_dequant_pattern();
patterns::OpDequant op_dequant_pattern{gpd.mutable_pattern(), "op_dequant"};
op_dequant_pattern();
int found_fc_dequant_squash_count = 0;
int found_op_dequant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash fc-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, fc_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, fc_dequant_pattern);
// if fc has force_fp32_output attribute
if (fc_out->outputs.size() == 1) {
fc_op->Op()->SetAttr("force_fp32_output", true);
fc_op->Op()->SetOutput("Out",
std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(fc_op, dequant_out);
GraphSafeRemoveNodes(graph, {fc_out, dequant_op});
found_fc_dequant_squash_count++;
VLOG(4) << "squash op-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, op_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, op_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern);
if (dequant_in->outputs.size() == 1) {
auto output_name = "Out";
if (any_op->Op()->Type() == "conv2d") {
// do not squash if fuse residual connection is true
// because residual fusion does not support force output with fp32
if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))
return;
output_name = "Output";
}
any_op->Op()->SetAttr("force_fp32_output", true);
any_op->Op()->SetOutput(output_name,
std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(any_op, dequant_out);
GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
found_op_dequant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_fc_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with fcs",
found_fc_dequant_squash_count);
AddStatis(found_op_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with ops",
found_op_dequant_squash_count);
}
void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
......@@ -389,38 +362,6 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
found_dequant_scale_squash_count);
}
// squash dequant with dequant
void CPUQuantizeSquashPass::MatmulDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::MatmulDequant matmul_dequant_pattern{gpd.mutable_pattern(),
"matmul_dequant"};
matmul_dequant_pattern();
int found_matmul_dequant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash matmul-dequant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, matmul_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, matmul_dequant_pattern);
if (matmul_out->outputs.size() == 1) {
matmul_op->Op()->SetAttr("force_fp32_output", true);
matmul_op->Op()->SetOutput(
"Out", std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(matmul_op, dequant_out);
GraphSafeRemoveNodes(graph, {matmul_out, dequant_op});
found_matmul_dequant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_matmul_dequant_squash_count);
PrettyLogDetail("--- squashed %d dequant with matmul",
found_matmul_dequant_squash_count);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
......@@ -433,11 +374,9 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
DequantQuantSquash(graph, &nodes_keep_counter);
OpRequantSquash(graph);
RequantOpSquash(graph);
ConvDequantSquash(graph);
FcDequantSquash(graph);
OpDequantSquash(graph);
MultipleQuantizeSquash(graph);
DequantScaleSquash(graph);
MatmulDequantSquash(graph);
}
} // namespace ir
......
......@@ -61,29 +61,19 @@ class CPUQuantizeSquashPass : public FusePassBase {
void RequantOpSquash(Graph* graph) const;
/*
* Squash conv2d with dequant when dequant is the only op after conv2d
*/
void ConvDequantSquash(Graph* graph) const;
/*
* Squash fc with dequant when dequant is the next op after fc
*/
void FcDequantSquash(Graph* graph) const;
* Squash dequant if the previous operator has force_fp32_output attribute
*/
void OpDequantSquash(Graph* graph) const;
/*
* Squash quantize if several quatize ops have the same scale
*/
* Squash quantize if several quatize ops have the same scale
*/
void MultipleQuantizeSquash(Graph* graph) const;
/*
* Squash scale if dequantize is before scale
*/
void DequantScaleSquash(Graph* graph) const;
/*
* Squash dequantize if it is after matmul
* Squash scale if dequantize is before scale
*/
void MatmulDequantSquash(Graph* graph) const;
void DequantScaleSquash(Graph* graph) const;
const std::string name_scope_{"squash"};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册