提交 5b2e98aa 编写于 作者: J joanna.wozna.intel 提交者: Tao Luo

Add multiple quantize operators fuse (#22062)

上级 96980c22
......@@ -1659,6 +1659,21 @@ PDNode *patterns::DequantAny::operator()() {
return dequant_out;
}
PDNode *patterns::MultipleQuantize::operator()() {
auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
// find nodes that are inputs to quantize operators
prev_out->assert_more([&](Node *node) {
int counter = std::count_if(
node->outputs.begin(), node->outputs.end(), [&](Node const *iter) {
return iter && iter->IsOp() && iter->Op()->Type() == "quantize";
});
return (counter > 1);
});
return prev_out;
}
// a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a
// b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b
// ...
......
......@@ -1004,6 +1004,16 @@ struct DequantAny : public PatternBase {
PATTERN_DECL_NODE(next_op);
};
// anyOp + more then one quantize op
// This pattern is used for squashing multiple quantize with the same scale.
struct MultipleQuantize : public PatternBase {
MultipleQuantize(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multiple_quantize") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_out);
};
struct TransposeFlattenConcat : public PatternBase {
TransposeFlattenConcat(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose_flatten_concat") {}
......
......@@ -228,6 +228,62 @@ void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const {
found_fc_dequant_squash_count);
}
void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::MultipleQuantize multiple_quantize_pattern{gpd.mutable_pattern(),
"multiple_quantize"};
multiple_quantize_pattern();
int found_multiple_quantize_squash_count = 0;
int removed_quantize = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse multiple quantize ops";
GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, multiple_quantize_pattern);
auto* first_quant_op = *(std::find_if(
prev_out->outputs.begin(), prev_out->outputs.end(), [&](Node* node) {
return (node->IsOp() && node->Op()->Type() == "quantize");
}));
auto* first_quant_out = first_quant_op->outputs[0];
float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale");
PADDLE_ENFORCE_NE(scale, 0, platform::errors::InvalidArgument(
"Quantize scale should not be equal 0"));
for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) {
auto quant_op = prev_out->outputs[iter];
if (quant_op->IsOp() && quant_op->Op()->Type() == "quantize" &&
quant_op->id() != first_quant_op->id() &&
quant_op->Op()->GetAttrIfExists<float>("Scale") == scale) {
auto quant_out = quant_op->outputs[0];
auto last_op = quant_out->outputs[0];
std::string last_op_input_name;
for (auto name : last_op->Op()->InputNames())
for (auto input_name : last_op->Op()->Input(name))
if (input_name == quant_out->Name()) last_op_input_name = name;
PADDLE_ENFORCE_NE(
last_op_input_name.empty(), true,
platform::errors::NotFound("Operator after quantize operator "
"should has quantize output as input"));
last_op->Op()->SetInput(
last_op_input_name,
std::vector<std::string>({first_quant_out->Name()}));
IR_NODE_LINK_TO(first_quant_out, last_op);
GraphSafeRemoveNodes(graph, {quant_op, quant_out});
removed_quantize++;
}
}
found_multiple_quantize_squash_count++;
};
gpd(graph, handler);
AddStatis(found_multiple_quantize_squash_count);
PrettyLogDetail("--- squashed %d quantize op", removed_quantize);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
......@@ -240,6 +296,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
DequantQuantSquash(graph, &nodes_keep_counter);
ConvDequantSquash(graph);
FcDequantSquash(graph);
MultipleQuantizeSquash(graph);
}
} // namespace ir
......
......@@ -65,6 +65,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void FcDequantSquash(Graph* graph) const;
/*
* Squash quantize if several quatize ops have the same scale
*/
void MultipleQuantizeSquash(Graph* graph) const;
const std::string name_scope_{"squash"};
};
......
......@@ -230,6 +230,28 @@ ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}
// a->concat->b
// b->Quant1(Scale1)->c
// b->Quant2(Scale2)->d
// b->concat->e
// c->fc->f
// d->fc->g
ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
float second_scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "concat", "Concat1", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "quantize", "Quantize1", {"b"}, {"c"}, use_mkldnn, first_scale);
SetOp(&prog, "quantize", "Quantize2", {"b"}, {"d"}, use_mkldnn, second_scale);
SetOp(&prog, "concat", "Concat2", {"b"}, {"e"}, use_mkldnn);
SetOp(&prog, "fc", "Fc1", {"c", "w1"}, {"f"}, use_mkldnn, first_scale);
SetOp(&prog, "fc", "Fc2", {"d", "w2"}, {"g"}, use_mkldnn, second_scale);
return prog;
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) {
auto x = scope->Var(var_name);
......@@ -467,6 +489,34 @@ TEST(CpuQuantizeSquashPass, fc_dequant_more_than_one_op_after_dequant) {
BuildFcDequantFcProgramDesc(use_mkldnn, scale_out, scale), "fc", false);
}
// a->Concat1->b
// b->Concat2
// b->Quatize1(Scale)->c
// c->Fc1
// c->Fc2
TEST(CpuQuantizeSquashPass, quatize_with_same_scale) {
auto first_scale = 1.2345f;
auto second_scale = 1.2345f;
auto use_mkldnn = true;
// remove nodes: Quantize2 + d
auto remove_nodes = 1 + 1;
CountNodeTest(
BuildMultipleQuantizeProgramDesc(use_mkldnn, first_scale, second_scale),
remove_nodes);
}
// if scales are not the same, do not fuse
TEST(CpuQuantizeSquashPass, quatize_with_different_scale) {
auto first_scale = 1.2345f;
auto second_scale = 1.5432f;
auto use_mkldnn = true;
// nothing change
auto remove_nodes = 0;
CountNodeTest(
BuildMultipleQuantizeProgramDesc(use_mkldnn, first_scale, second_scale),
remove_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册