Commit 492a00f5 authored by joanna.wozna.intel, committed by Tao Luo

Add conv requantize squash (#18754)

* Add requantize squash

test=develop

* Add more precise tests
test=develop

* Rename and refactor tester

test=develop
Parent: c548e370
@@ -1275,6 +1275,23 @@ PDNode *patterns::ConvConcatReLU::operator()() {
return relu_out;
}
PDNode *patterns::ConvRequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto requant_op =
pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto requant_out = pattern->NewNode(requant_out_repr())
->AsOutput()
->assert_is_op_output("requantize", "Output");
conv_op->LinksTo({conv_out});
requant_op->LinksFrom({conv_out}).LinksTo({requant_out});
return requant_out;
}
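// Note: the pattern itself only matches the conv2d -> requantize chain;
// checking that conv_out has no other consumers is deferred to the pass
// handler that consumes this match (ConvRequantSquash).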
PDNode *patterns::PriorBox::operator()() {
auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
@@ -796,6 +796,23 @@ struct ConvConcatReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out);
};
// Conv + Requant
// named nodes:
// conv_op, conv_out
// requant_op, requant_out
struct ConvRequant : public PatternBase {
ConvRequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_requant") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(requant_op);
PATTERN_DECL_NODE(requant_out);
};
// PriorBox operator
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image
@@ -49,14 +49,14 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
AddStatis(found_count);
}
-void CPUQuantizeSquashPass::Squash(
void CPUQuantizeSquashPass::DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
GraphPatternDetector gpd;
patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
squash_pattern();
-  int found_squash_count = 0;
int found_dequant_quant_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
    VLOG(4) << "squash dequantize-quantize ops pair";
@@ -96,7 +96,7 @@ void CPUQuantizeSquashPass::Squash(
IR_NODE_LINK_TO(dequant_in, next_op);
-      found_squash_count++;
found_dequant_quant_count++;
} else {
// squash dequantize-quantize to requantize op
OpDesc desc;
@@ -116,13 +116,48 @@ void CPUQuantizeSquashPass::Squash(
IR_NODE_LINK_TO(dequant_in, requant_op);
IR_NODE_LINK_TO(requant_op, quant_out);
-      found_squash_count++;
found_dequant_quant_count++;
}
};
gpd(graph, handler);
-  AddStatis(found_squash_count);
AddStatis(found_dequant_quant_count);
PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
-                  found_squash_count);
found_dequant_quant_count);
}
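// The net effect of the two branches above (a sketch; S, S1, S2 stand for
// the dequantize/quantize scale attributes):
//   prev_op -> dequantize(S)  -> quantize(S)  -> next_op  ==>  prev_op -> next_op
//   prev_op -> dequantize(S1) -> quantize(S2) -> next_op  ==>
//       prev_op -> requantize(Scale_in=S1, Scale_out=S2) -> next_op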
void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ConvRequant conv_requant_pattern{gpd.mutable_pattern(),
"conv_requant"};
conv_requant_pattern();
int found_requant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash conv-requantize ops pair";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern);
    // squash only if the conv2d output feeds the requantize op alone
if (conv_out->outputs.size() == 1) {
float requant_scale_out =
boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
conv_op->Op()->SetAttr("Scale_out", requant_scale_out);
conv_op->Op()->SetOutput("Output",
std::vector<std::string>({requant_out->Name()}));
IR_NODE_LINK_TO(conv_op, requant_out);
GraphSafeRemoveNodes(graph, {conv_out, requant_op});
found_requant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_requant_squash_count);
PrettyLogDetail("--- squashed %d requantize with convs",
found_requant_squash_count);
}
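// The net effect of the handler above (a sketch; S1, S2 stand for the
// Scale_out attributes of conv2d and requantize):
//   conv2d(Scale_out=S1) -> conv_out -> requantize(Scale_out=S2) -> requant_out
//     ==>  conv2d(Scale_out=S2) -> requant_out
// with conv_out and the requantize op removed from the graph.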
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
@@ -131,7 +166,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
std::unordered_map<const Node*, int> nodes_keep_counter;
FindNodesToKeep(graph, &nodes_keep_counter);
-  Squash(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
ConvRequantSquash(graph);
}
} // namespace ir
@@ -46,8 +46,14 @@ class CPUQuantizeSquashPass : public FusePassBase {
/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
-  void Squash(Graph* graph,
-              std::unordered_map<const Node*, int>* nodes_keep_counter) const;
void DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
/*
   * Squash requantize op into conv, taking the requantize Scale_out as the
   * conv's Scale_out
*/
void ConvRequantSquash(Graph* graph) const;
const std::string name_scope_{"squash"};
};
@@ -30,6 +30,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetAttr("Scale_out", scale);
op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
@@ -42,14 +43,22 @@
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale);
} else if (type == "requantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale_out", scale);
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
}
}
// (a,w1,b1)->Conv1->d
-// d->Dequant->e
-// e->Quant->f
// d->Dequant(scale1)->e
// e->Quant(scale2)->f
// (f,w2,b2)->Conv2->i
-ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
ProgramDesc prog;
for (auto& v : std::initializer_list<std::string>(
{"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
@@ -59,42 +68,96 @@ ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
}
}
SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn,
scale_out);
SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn,
scale_out);
return prog;
}
static const std::initializer_list<std::string> variable_names{
"a", "b", "c", "d", "e", "f", "g", "h"};
// a->Conv1->b
-// b->Dequant->c
-//
-// c->Quant1->d and d->Conv2->e
-//
// b->Dequant(scale1)->c
// c->Quant1(scale2)->d and d->Conv2->e
// c->Conv3->f
-//
-// c->Quant2->g and g->Conv4->h
-//
-ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2,
-                              float scale3) {
// c->Quant2(scale3)->g and g->Conv4->h
ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2,
float scale3) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn, scale_out);
SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn, scale_out);
return prog;
}
// a->Conv1->b->Requant(scale1)->c
// d->Conv2->e->Requant(scale2)->f
// {c,f}->Concat
ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2);
SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn);
return prog;
}
// a->Concat->b
// b->Dequant(scale1)->c
// c->Quant(scale2)->d
// d->Conv->e
ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "concat", "Concat", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
return prog;
}
// a->Conv1->b
// b->Requant1(Scale1)->c
// b->Requant2(Scale2)->d
ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn, scale2);
return prog;
}
@@ -105,10 +168,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
tensor->mutable_data(place, proto::VarType::FP32, 1);
}
-void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
-  // Init scope, as it is used in pass
void PrepareGraph(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog) {
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
@@ -117,58 +177,172 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
for (auto& v : variable_names) {
InitTensorHolder(&scope, place, v.c_str());
}
(*graph)->SetNotOwned(kParamScopeAttr, &scope);
}
-  graph->SetNotOwned(kParamScopeAttr, &scope);
void RegisterPass(std::unique_ptr<ir::Graph>* graph) {
auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
graph->reset(pass->Apply(graph->release()));
}
-  int original_nodes_num = graph->Nodes().size();
-  graph.reset(pass->Apply(graph.release()));
// check number of nodes
void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
PrepareGraph(&graph, prog);
int original_nodes_num = graph->Nodes().size();
RegisterPass(&graph);
int current_nodes_num = graph->Nodes().size();
EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
}
// check op->scale_out
void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
PrepareGraph(&graph, prog);
RegisterPass(&graph);
for (auto* node : graph->Nodes()) {
if (node->IsOp() &&
boost::get<std::string>(node->Op()->GetAttr("name")) == name) {
float scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out"));
EXPECT_EQ(scale_out, scale);
}
}
}
// check requant_op scales
void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
float scale_out) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
PrepareGraph(&graph, prog);
RegisterPass(&graph);
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "requantize") {
float op_scale_in = boost::get<float>(node->Op()->GetAttr("Scale_in"));
EXPECT_EQ(op_scale_in, scale_in);
float op_scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out"));
EXPECT_EQ(op_scale_out, scale_out);
}
}
}
// From Conv1->d->Dequant->e->Quant->f->Conv2
// To Conv1->d->Conv2
TEST(CpuQuantizeSquashPass, equal_scales) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, f
auto remove_nodes = 4;
-  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
-  use_mkldnn = !use_mkldnn;
-  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
remove_nodes);
}
-TEST(CpuQuantizeSquashPass, inequal_scales) {
// From Conv1->d->Dequant->e->Quant->f->Conv2
// First change to Conv1->d->Requant->f->Conv2
// Then Conv1->f->Conv2
TEST(CpuQuantizeSquashPass, unequal_scales) {
auto scale_out = 1.0f;
auto scale1 = 1.2345f;
auto scale2 = 21.0f;
auto use_mkldnn = true;
-  // Remove 3 nodes: Dequant, Quant, e
-  // Insert 1 node: requantize
// Remove 4 nodes: Dequant, Quant, e, d
auto remove_nodes = 4;
CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
remove_nodes);
EqualScaleOutTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
"Conv1", scale2);
}
// from
// a->Conv1->b->Dequant(Scale1)->c
// c->Quant1(Scale1)->d and d->Conv2->e
// c->Quant2(Scale2)->g and g->Conv4->h
// c->Conv3->f
// to
// a->Conv1->b
// b->Conv2->e
// b->Requant(Scale_in = Scale1; Scale_out = Scale2)->g->Conv4->h
// b->Dequant(Scale1)->c->Conv3->f
TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto scale2 = 21.0f;
auto use_mkldnn = true;
// Remove 3 nodes: Quant1, c, Quant2,
// Insert 1 node: Requant
auto remove_nodes = 2;
-  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale,
scale, scale2),
remove_nodes);
CheckRequantScalesTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out,
scale, scale, scale2),
scale, scale2);
}
-  use_mkldnn = !use_mkldnn;
-  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
// a->Conv1->b->Requant->c
// d->Conv2->e->Requant->f
// {c,f}->Concat
TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
  // Delete both requantize ops
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto use_mkldnn = true;
// Remove 4 nodes: b, Requant1, e, Requant2
auto remove_nodes = 4;
CountNodeTest(
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
remove_nodes);
  // check that each conv's Scale_out matches its requantize Scale_out
EqualScaleOutTest(
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
"Conv1", scale);
EqualScaleOutTest(
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
"Conv2", scale);
}
-TEST(CpuQuantizeSquashPass, branch_to_equal_inequal_and_fp32) {
-  // Delete both quantize ops,
-  // bypass dequantize in both branches,
-  // insert requantize on one branch
// a->Concat->b->Dequant->c->Quant->d->Conv->e
// to a->Concat->b->Requant->d->Conv->e
TEST(CpuQuantizeSquashPass,
unequal_scales_squash_dequantize_quantize_into_requantize) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto scale2 = 21.0f;
auto use_mkldnn = true;
-  // Remove 3 nodes: Quant1, Quant2, g
-  // Insert 1 node: requantize
  // Remove 3 nodes: Dequant, c, Quant
// Insert 1 node: Requant
auto remove_nodes = 2;
-  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
CountNodeTest(
BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
remove_nodes);
CheckRequantScalesTest(
BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
scale, scale2);
}
-  use_mkldnn = !use_mkldnn;
-  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
// a->Conv1->b
// b->Requant1(Scale1)->c
// b->Requant2(Scale2)->d
TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto scale2 = 21.0f;
auto use_mkldnn = true;
  // nothing changes
auto remove_nodes = 0;
CountNodeTest(
BuildConvMultiRequantProgramDesc(use_mkldnn, scale_out, scale, scale2),
remove_nodes);
}
} // namespace ir