未验证 提交 3cb5623d 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add matmul dequant squash (#23505)

test=develop
上级 3d5d2170
......@@ -1562,6 +1562,23 @@ PDNode *patterns::DequantScale::operator()() {
return scale_out;
}
// Builds the pattern: matmul -> (matmul Out var) -> dequantize -> (Output var).
// Returns the terminal dequantize output node.
PDNode *patterns::MatmulDequant::operator()() {
  auto *matmul = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
  auto *dequantize =
      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
  auto *matmul_result = pattern->NewNode(matmul_out_repr())
                            ->AsOutput()
                            ->assert_is_op_output("matmul", "Out");
  auto *dequantize_result = pattern->NewNode(dequant_out_repr())
                                ->AsOutput()
                                ->assert_is_op_output("dequantize", "Output");

  // matmul produces the intermediate var, which is the sole input of the
  // dequantize op in this pattern.
  matmul->LinksTo({matmul_result});
  dequantize->LinksFrom({matmul_result});
  dequantize->LinksTo({dequantize_result});
  return dequantize_result;
}
PDNode *patterns::PriorBox::operator()() {
auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
......
......@@ -959,6 +959,20 @@ struct DequantScale : public PatternBase {
PATTERN_DECL_NODE(scale_out);
};
// Matmul + Dequantize
//
// Declares the matched nodes for the pattern
//   matmul -> matmul_out -> dequantize -> dequant_out
// (the pattern graph itself is built in operator(), defined in the .cc file).
struct MatmulDequant : public PatternBase {
  MatmulDequant(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "matmul_dequant") {}

  // Registers the pattern nodes and returns the final dequantize output node.
  PDNode* operator()();

  // Accessors for the nodes of a matched subgraph.
  PATTERN_DECL_NODE(matmul_op);    // the matmul operator
  PATTERN_DECL_NODE(matmul_out);   // matmul "Out" variable
  PATTERN_DECL_NODE(dequant_op);   // the dequantize operator
  PATTERN_DECL_NODE(dequant_out);  // dequantize "Output" variable
};
// PriorBox operator
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image
......
......@@ -327,6 +327,38 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
found_dequant_scale_squash_count);
}
// Squash a dequantize op that directly follows a matmul. When the matmul
// output feeds only the dequantize op, the matmul is marked with
// force_fp32_output (so its kernel presumably produces the final fp32 result
// itself — confirm against the matmul kernel) and is rewired to write straight
// into the dequantize output variable; the intermediate variable and the
// dequantize op are then removed from the graph.
void CPUQuantizeSquashPass::MatmulDequantSquash(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::MatmulDequant matmul_dequant_pattern{gpd.mutable_pattern(),
                                                 "matmul_dequant"};
  matmul_dequant_pattern();

  int found_matmul_dequant_squash_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "squash matmul-dequant ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, matmul_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, matmul_dequant_pattern);

    // Only squash when the dequantize op is the sole consumer of the matmul
    // output; other consumers would still need the original value.
    if (matmul_out->outputs.size() == 1) {
      matmul_op->Op()->SetAttr("force_fp32_output", true);
      // Redirect matmul's "Out" to the dequantize output variable.
      matmul_op->Op()->SetOutput(
          "Out", std::vector<std::string>({dequant_out->Name()}));
      IR_NODE_LINK_TO(matmul_op, dequant_out);
      GraphSafeRemoveNodes(graph, {matmul_out, dequant_op});
      found_matmul_dequant_squash_count++;
    }
  };
  gpd(graph, handler);
  AddStatis(found_matmul_dequant_squash_count);
  PrettyLogDetail("--- squashed %d dequant with matmul",
                  found_matmul_dequant_squash_count);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
......@@ -342,6 +374,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FcDequantSquash(graph);
MultipleQuantizeSquash(graph);
DequantScaleSquash(graph);
MatmulDequantSquash(graph);
}
} // namespace ir
......
......@@ -75,6 +75,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void DequantScaleSquash(Graph* graph) const;
/*
* Squash dequantize if it is after matmul
*/
void MatmulDequantSquash(Graph* graph) const;
const std::string name_scope_{"squash"};
};
......
......@@ -64,6 +64,10 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetOutput("Out", {outputs[0]});
op->SetAttr("scale", scale);
op->SetAttr("bias", bias);
} else if (type == "matmul") {
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
}
}
......@@ -92,7 +96,7 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
}
static const std::initializer_list<std::string> variable_names{
"a", "b", "c", "d", "e", "f", "g", "h"};
"a", "b", "c", "d", "e", "f", "g", "h", "x", "y"};
// a->Conv1->b
// b->Dequant(scale1)->c
......@@ -272,6 +276,21 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
return prog;
}
// Builds a program with the chain:
//   {x,y}->Matmul->b
//   b->Dequant->c
ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn,
                                          float dequant_scale) {
  ProgramDesc prog;
  // Declare every known variable in block 0 before adding the ops.
  for (const auto& var_name : variable_names) {
    prog.MutableBlock(0)->Var(var_name);
  }
  SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn);
  SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn,
        dequant_scale);
  return prog;
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) {
auto x = scope->Var(var_name);
......@@ -595,6 +614,17 @@ TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) {
scale_scale, bias),
"Dequant", "Scale", dequant_scale);
}
// Verifies that a matmul followed by a single dequantize is squashed: the
// intermediate variable and the dequantize op disappear, and the matmul is
// marked with force_fp32_output.
TEST(CpuQuantizeSquashPass, matmul_with_dequant) {
  const float dequant_scale = 1.2345f;
  const bool use_mkldnn = true;
  // Two nodes are expected to be removed: matmul_out and dequant_op.
  const int expected_removed_nodes = 2;
  CountNodeTest(BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale),
                expected_removed_nodes);
  IsForceFp32OutputTest(
      BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale), "matmul", true);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册