未验证 提交 17f2c089 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add dequant-scale squash (#22409)

* Add dequant scale squash

test=develop

* Correct dequant-scale squash test

test=develop
上级 9c4deedb
...@@ -1522,6 +1522,25 @@ PDNode *patterns::FcDequant::operator()() { ...@@ -1522,6 +1522,25 @@ PDNode *patterns::FcDequant::operator()() {
return dequant_out; return dequant_out;
} }
PDNode *patterns::DequantScale::operator()() {
// Create Operators
auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
auto scale_out = pattern->NewNode(scale_out_repr())
->AsOutput()
->assert_is_op_output("scale", "Out");
dequant_op->LinksTo({dequant_out});
scale_op->LinksFrom({dequant_out}).LinksTo({scale_out});
return scale_out;
}
PDNode *patterns::PriorBox::operator()() { PDNode *patterns::PriorBox::operator()() {
auto prior_box_op = auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
......
...@@ -929,6 +929,20 @@ struct FcDequant : public PatternBase { ...@@ -929,6 +929,20 @@ struct FcDequant : public PatternBase {
PATTERN_DECL_NODE(dequant_out); PATTERN_DECL_NODE(dequant_out);
}; };
// Dequantize + Scale
struct DequantScale : public PatternBase {
DequantScale(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "dequant_scale") {}
PDNode* operator()();
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(scale_out);
};
// PriorBox operator // PriorBox operator
// operator: prior_box_op // operator: prior_box_op
// inputs: prior_box_input, prior_box_image // inputs: prior_box_input, prior_box_image
......
...@@ -284,6 +284,49 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -284,6 +284,49 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
PrettyLogDetail("--- squashed %d quantize op", removed_quantize); PrettyLogDetail("--- squashed %d quantize op", removed_quantize);
} }
// squash scale with dequant
void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::DequantScale dequant_scale_pattern{gpd.mutable_pattern(),
"dequant_scale"};
dequant_scale_pattern();
int found_dequant_scale_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash dequant-scale ops pair";
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, dequant_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, dequant_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, dequant_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, dequant_scale_pattern);
if (dequant_out->outputs.size() == 1 &&
scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
auto dequant_scale = dequant_op->Op()->GetAttrIfExists<float>("Scale");
auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
platform::errors::InvalidArgument(
"Dequantize scale should have positive value"));
PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument(
"Scale of scale op should have positive value"));
dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
dequant_op->Op()->SetOutput(
"Output", std::vector<std::string>({scale_out->Name()}));
IR_NODE_LINK_TO(dequant_op, scale_out);
GraphSafeRemoveNodes(graph, {dequant_out, scale_op});
found_dequant_scale_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_dequant_scale_squash_count);
PrettyLogDetail("--- squashed %d scale with dequant",
found_dequant_scale_squash_count);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, graph,
...@@ -298,6 +341,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { ...@@ -298,6 +341,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
ConvDequantSquash(graph); ConvDequantSquash(graph);
FcDequantSquash(graph); FcDequantSquash(graph);
MultipleQuantizeSquash(graph); MultipleQuantizeSquash(graph);
DequantScaleSquash(graph);
} }
} // namespace ir } // namespace ir
......
...@@ -70,6 +70,11 @@ class CPUQuantizeSquashPass : public FusePassBase { ...@@ -70,6 +70,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/ */
void MultipleQuantizeSquash(Graph* graph) const; void MultipleQuantizeSquash(Graph* graph) const;
/*
* Squash scale if dequantize is before scale
*/
void DequantScaleSquash(Graph* graph) const;
const std::string name_scope_{"squash"}; const std::string name_scope_{"squash"};
}; };
......
...@@ -24,7 +24,7 @@ namespace ir { ...@@ -24,7 +24,7 @@ namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn, const std::vector<std::string>& outputs, bool use_mkldnn,
float scale = 0) { float scale = 0, float bias = 0.0) {
auto* op = prog->MutableBlock(0)->AppendOp(); auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type); op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("use_mkldnn", use_mkldnn);
...@@ -59,6 +59,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -59,6 +59,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
inputs.size())); inputs.size()));
op->SetInput("W", {inputs[1]}); op->SetInput("W", {inputs[1]});
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
} else if (type == "scale") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("scale", scale);
op->SetAttr("bias", bias);
} }
} }
...@@ -252,6 +257,21 @@ ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale, ...@@ -252,6 +257,21 @@ ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
return prog; return prog;
} }
// a->Dequant->b
// b->Scale->c
ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
float scale_scale, float bias) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dequantize", "Dequant", {"a"}, {"b"}, use_mkldnn,
dequant_scale);
SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, scale_scale, bias);
return prog;
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) { const char* var_name) {
auto x = scope->Var(var_name); auto x = scope->Var(var_name);
...@@ -289,17 +309,17 @@ void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) { ...@@ -289,17 +309,17 @@ void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
} }
// check op->scale_out // check op->scale_out
void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name, void EqualScaleTest(const ProgramDesc& prog, const std::string& op_name,
float scale) { const std::string& scale_name, float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
PrepareGraph(&graph, prog); PrepareGraph(&graph, prog);
RegisterPass(&graph); RegisterPass(&graph);
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsOp() && if (node->IsOp() &&
boost::get<std::string>(node->Op()->GetAttr("name")) == name) { boost::get<std::string>(node->Op()->GetAttr("name")) == op_name) {
float scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out")); float op_scale = boost::get<float>(node->Op()->GetAttr(scale_name));
EXPECT_EQ(scale_out, scale); EXPECT_EQ(op_scale, scale);
} }
} }
} }
...@@ -368,9 +388,9 @@ TEST(CpuQuantizeSquashPass, unequal_scales) { ...@@ -368,9 +388,9 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
remove_nodes); remove_nodes);
EqualScaleOutTest( EqualScaleTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
"Conv1", scale2); "Conv1", "Scale_out", scale2);
} }
// a->Conv1->b->Requant->c // a->Conv1->b->Requant->c
...@@ -388,12 +408,12 @@ TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) { ...@@ -388,12 +408,12 @@ TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
remove_nodes); remove_nodes);
// check equal scale conv->scale_out and requant->scale_out // check equal scale conv->scale_out and requant->scale_out
EqualScaleOutTest( EqualScaleTest(
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
"Conv1", scale); "Conv1", "Scale_out", scale);
EqualScaleOutTest( EqualScaleTest(
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
"Conv2", scale); "Conv2", "Scale_out", scale);
} }
// from // from
...@@ -544,6 +564,37 @@ TEST(CpuQuantizeSquashPass, quatize_with_different_scale) { ...@@ -544,6 +564,37 @@ TEST(CpuQuantizeSquashPass, quatize_with_different_scale) {
remove_nodes); remove_nodes);
} }
// if scale has no bias
TEST(CpuQuantizeSquashPass, dequantize_scale_with_no_bias) {
auto dequant_scale = 1.2345f;
auto scale_scale = 1.5432f;
auto bias = 0.0f;
auto use_mkldnn = true;
// remove: dequant out, scale op
auto remove_nodes = 2;
CountNodeTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
scale_scale, bias),
remove_nodes);
EqualScaleTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
scale_scale, bias),
"Dequant", "Scale", dequant_scale / scale_scale);
}
// if scale has bias
TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) {
auto dequant_scale = 1.2345f;
auto scale_scale = 1.5432f;
auto bias = 1.0f;
auto use_mkldnn = true;
// nothing change
auto remove_nodes = 0;
CountNodeTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
scale_scale, bias),
remove_nodes);
EqualScaleTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
scale_scale, bias),
"Dequant", "Scale", dequant_scale);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部