Commit 492a00f5 authored by joanna.wozna.intel, committed by Tao Luo

Add conv requantize squash (#18754)

* Add requantize squash

test=develop

* Add more precise tests
test=develop

* Rename and refactor tester

test=develop
Parent c548e370
@@ -1275,6 +1275,23 @@ PDNode *patterns::ConvConcatReLU::operator()() {
   return relu_out;
 }
 
+PDNode *patterns::ConvRequant::operator()() {
+  // Create Operators
+  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+  auto requant_op =
+      pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
+
+  auto conv_out = pattern->NewNode(conv_out_repr())
+                      ->assert_is_op_output("conv2d", "Output");
+  auto requant_out = pattern->NewNode(requant_out_repr())
+                         ->AsOutput()
+                         ->assert_is_op_output("requantize", "Output");
+
+  conv_op->LinksTo({conv_out});
+  requant_op->LinksFrom({conv_out}).LinksTo({requant_out});
+
+  return requant_out;
+}
+
 PDNode *patterns::PriorBox::operator()() {
   auto prior_box_op =
       pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
...
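For orientation, the subgraph this new pattern matches is the two-op chain sketched below (an annotation derived from the LinksTo/LinksFrom calls above, not part of the commit; note that the single-consumer restriction on conv_out is enforced later by the pass itself, not by the pattern):

    conv2d --"Output"--> conv_out --> requantize --"Output"--> requant_out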
@@ -796,6 +796,23 @@ struct ConvConcatReLU : public PatternBase {
   PATTERN_DECL_NODE(relu_out);
 };
 
+// Conv + Requant
+// named nodes:
+// conv_op, conv_out
+// requant_op, requant_out
+struct ConvRequant : public PatternBase {
+  ConvRequant(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "conv_requant") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(conv_op);
+  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(requant_op);
+  PATTERN_DECL_NODE(requant_out);
+};
+
 // PriorBox operator
 // operator: prior_box_op
 // inputs: prior_box_input, prior_box_image
...
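The four PATTERN_DECL_NODE declarations above define the handles that the pass retrieves from each matched subgraph via GET_IR_NODE_FROM_SUBGRAPH, as seen in the pass source below, e.g.:

    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern);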
@@ -49,14 +49,14 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
   AddStatis(found_count);
 }
 
-void CPUQuantizeSquashPass::Squash(
+void CPUQuantizeSquashPass::DequantQuantSquash(
     Graph* graph,
     std::unordered_map<const Node*, int>* nodes_keep_counter) const {
   GraphPatternDetector gpd;
   patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
   squash_pattern();
 
-  int found_squash_count = 0;
+  int found_dequant_quant_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "squash requantize-quantize ops pair";
@@ -96,7 +96,7 @@ void CPUQuantizeSquashPass::Squash(
       IR_NODE_LINK_TO(dequant_in, next_op);
-      found_squash_count++;
+      found_dequant_quant_count++;
     } else {
       // squash dequantize-quantize to requantize op
       OpDesc desc;
@@ -116,13 +116,48 @@ void CPUQuantizeSquashPass::Squash(
       IR_NODE_LINK_TO(dequant_in, requant_op);
       IR_NODE_LINK_TO(requant_op, quant_out);
-      found_squash_count++;
+      found_dequant_quant_count++;
     }
   };
   gpd(graph, handler);
-  AddStatis(found_squash_count);
+  AddStatis(found_dequant_quant_count);
   PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
-                  found_squash_count);
+                  found_dequant_quant_count);
+}
+
+void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::ConvRequant conv_requant_pattern{gpd.mutable_pattern(),
+                                             "conv_requant"};
+  conv_requant_pattern();
+
+  int found_requant_squash_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "squash conv-requantize ops pair";
+
+    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern);
+
+    // if conv2d has one output, squash
+    if (conv_out->outputs.size() == 1) {
+      float requant_scale_out =
+          boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
+      conv_op->Op()->SetAttr("Scale_out", requant_scale_out);
+      conv_op->Op()->SetOutput(
+          "Output", std::vector<std::string>({requant_out->Name()}));
+      IR_NODE_LINK_TO(conv_op, requant_out);
+      GraphSafeRemoveNodes(graph, {conv_out, requant_op});
+      found_requant_squash_count++;
+    }
+  };
+  gpd(graph, handler);
+  AddStatis(found_requant_squash_count);
+  PrettyLogDetail("--- squashed %d requantize with convs",
+                  found_requant_squash_count);
 }
 
 void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
@@ -131,7 +166,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_map<const Node*, int> nodes_keep_counter;
   FindNodesToKeep(graph, &nodes_keep_counter);
-  Squash(graph, &nodes_keep_counter);
+  DequantQuantSquash(graph, &nodes_keep_counter);
+  ConvRequantSquash(graph);
 }
 
 }  // namespace ir
...
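Why the scale fold is sound: assuming the requantize op rescales its input by Scale_out / Scale_in (a semantics not spelled out in this diff), and that the INT8 conv2d kernel already multiplies its raw result by its own Scale_out, replacing conv2d(Scale_out = s1) -> requantize -> output with conv2d(Scale_out = s2) produces the same values. A minimal standalone sketch with hypothetical helper names, illustration only:

    // Illustration only: assumes requantize computes
    // out = in * scale_out / scale_in.
    #include <cassert>
    #include <cmath>

    float Requantize(float in, float scale_in, float scale_out) {
      return in * scale_out / scale_in;
    }

    int main() {
      const float raw = 0.5f;  // conv accumulator before output scaling
      const float s1 = 1.0f;   // conv2d Scale_out before the squash
      const float s2 = 21.0f;  // requantize Scale_out (as in the tests below)

      float chained = Requantize(raw * s1, s1, s2);  // conv2d -> requantize
      float squashed = raw * s2;                     // conv2d, folded scale
      assert(std::fabs(chained - squashed) < 1e-6f);
      return 0;
    }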
@@ -46,8 +46,14 @@ class CPUQuantizeSquashPass : public FusePassBase {
   /*
    * Squash dequantize-quantize ops pairs into requantize or nothing
    */
-  void Squash(Graph* graph,
-              std::unordered_map<const Node*, int>* nodes_keep_counter) const;
+  void DequantQuantSquash(
+      Graph* graph,
+      std::unordered_map<const Node*, int>* nodes_keep_counter) const;
+
+  /*
+   * Squash requantize op into conv with scale_out like requantize scale_out
+   */
+  void ConvRequantSquash(Graph* graph) const;
 
   const std::string name_scope_{"squash"};
 };
...
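Concretely, ConvRequantSquash rewires the graph as sketched below (scales illustrative): conv_out and requant_op are deleted, and conv2d adopts both the requantize output variable and its Scale_out:

    before: conv2d(Scale_out = s1) -> conv_out -> requantize(Scale_out = s2) -> requant_out
    after:  conv2d(Scale_out = s2) -> requant_out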
@@ -30,6 +30,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
   op->SetAttr("use_mkldnn", use_mkldnn);
   op->SetAttr("name", name);
   if (type == "conv2d") {
+    op->SetAttr("Scale_out", scale);
     op->SetInput("Input", {inputs[0]});
     if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
     if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
@@ -42,14 +43,22 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Output", {outputs[0]});
     op->SetAttr("Scale", scale);
+  } else if (type == "requantize") {
+    op->SetInput("Input", {inputs[0]});
+    op->SetOutput("Output", {outputs[0]});
+    op->SetAttr("Scale_out", scale);
+  } else if (type == "concat") {
+    op->SetInput("X", inputs);
+    op->SetOutput("Out", outputs);
   }
 }
 
 // (a,w1,b1)->Conv1->d
-// d->Dequant->e
-// e->Quant->f
+// d->Dequant(scale1)->e
+// e->Quant(scale2)->f
 // (f,w2,b2)->Conv2->i
-ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
+ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
+                                        float scale1, float scale2) {
   ProgramDesc prog;
   for (auto& v : std::initializer_list<std::string>(
            {"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
@@ -59,42 +68,96 @@ ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
     }
   }
 
-  SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn);
+  SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn,
+        scale_out);
   SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
   SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
-  SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn);
+  SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn,
+        scale_out);
 
   return prog;
 }
 
 static const std::initializer_list<std::string> variable_names{
     "a", "b", "c", "d", "e", "f", "g", "h"};
 // a->Conv1->b
-// b->Dequant->c
-//
-// c->Quant1->d and d->Conv2->e
-//
+// b->Dequant(scale1)->c
+// c->Quant1(scale2)->d and d->Conv2->e
 // c->Conv3->f
-//
-// c->Quant2->g and g->Conv4->h
-//
-ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2,
-                              float scale3) {
+// c->Quant2(scale3)->g and g->Conv4->h
+ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
+                                            float scale1, float scale2,
+                                            float scale3) {
   ProgramDesc prog;
   for (auto& v : variable_names) {
     prog.MutableBlock(0)->Var(v);
   }
 
-  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn);
+  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
   SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
   SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
-  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn);
-  SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);
+  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
+  SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn, scale_out);
   SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
-  SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn);
+  SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn, scale_out);
+
+  return prog;
+}
+
+// a->Conv1->b->Requant(scale1)->c
+// d->Conv2->e->Requant(scale2)->f
+// {c,f}->Concat
+ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out,
+                                               float scale1, float scale2) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+
+  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
+  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
+  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
+  SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2);
+  SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn);
+
+  return prog;
+}
+
+// a->Concat->b
+// b->Dequant(scale1)->c
+// c->Quant(scale2)->d
+// d->Conv->e
+ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out,
+                                               float scale1, float scale2) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+
+  SetOp(&prog, "concat", "Concat", {"a"}, {"b"}, use_mkldnn);
+  SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
+  SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, scale2);
+  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
+
+  return prog;
+}
+
+// a->Conv1->b
+// b->Requant1(Scale1)->c
+// b->Requant2(Scale2)->d
+ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
+                                             float scale1, float scale2) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+
+  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
+  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
+  SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn, scale2);
 
   return prog;
 }
@@ -105,10 +168,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
   tensor->mutable_data(place, proto::VarType::FP32, 1);
 }
 
-void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
-
-  // Init scope, as it is used in pass
+void PrepareGraph(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog) {
   auto place = paddle::platform::CPUPlace();
   NaiveExecutor exe{place};
   Scope scope;
@@ -117,58 +177,172 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
   for (auto& v : variable_names) {
     InitTensorHolder(&scope, place, v.c_str());
   }
+  (*graph)->SetNotOwned(kParamScopeAttr, &scope);
+}
 
-  graph->SetNotOwned(kParamScopeAttr, &scope);
-
+void RegisterPass(std::unique_ptr<ir::Graph>* graph) {
   auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
-
-  int original_nodes_num = graph->Nodes().size();
-
-  graph.reset(pass->Apply(graph.release()));
+  graph->reset(pass->Apply(graph->release()));
+}
 
+// check number of nodes
+void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  PrepareGraph(&graph, prog);
+
+  int original_nodes_num = graph->Nodes().size();
+  RegisterPass(&graph);
   int current_nodes_num = graph->Nodes().size();
+
   EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
 }
+
+// check op->scale_out
+void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
+                       float scale) {
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  PrepareGraph(&graph, prog);
+  RegisterPass(&graph);
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() &&
+        boost::get<std::string>(node->Op()->GetAttr("name")) == name) {
+      float scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out"));
+      EXPECT_EQ(scale_out, scale);
+    }
+  }
+}
+
+// check requant_op scales
+void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
+                            float scale_out) {
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  PrepareGraph(&graph, prog);
+  RegisterPass(&graph);
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "requantize") {
+      float op_scale_in = boost::get<float>(node->Op()->GetAttr("Scale_in"));
+      EXPECT_EQ(op_scale_in, scale_in);
+      float op_scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out"));
+      EXPECT_EQ(op_scale_out, scale_out);
+    }
+  }
+}
+// From Conv1->d->Dequant->e->Quant->f->Conv2
+// To Conv1->d->Conv2
 TEST(CpuQuantizeSquashPass, equal_scales) {
+  auto scale_out = 1.0f;
   auto scale = 1.2345f;
   auto use_mkldnn = true;
   // Remove 4 nodes: Dequant, Quant, e, f
   auto remove_nodes = 4;
-  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
-
-  use_mkldnn = !use_mkldnn;
-  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
+  CountNodeTest(
+      BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
+      remove_nodes);
 }
 
-TEST(CpuQuantizeSquashPass, inequal_scales) {
+// From Conv1->d->Dequant->e->Quant->f->Conv2
+// First change to Conv1->d->Requant->f->Conv2
+// Then Conv1->f->Conv2
+TEST(CpuQuantizeSquashPass, unequal_scales) {
+  auto scale_out = 1.0f;
   auto scale1 = 1.2345f;
   auto scale2 = 21.0f;
   auto use_mkldnn = true;
-  // Remove 3 nodes: Dequant, Quant, e
-  // Insert 1 node: requantize
-  auto remove_nodes = 2;
-  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
-
-  use_mkldnn = !use_mkldnn;
-  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
+  // Remove 4 nodes: Dequant, Quant, e, d
+  auto remove_nodes = 4;
+  CountNodeTest(
+      BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
+      remove_nodes);
+  EqualScaleOutTest(
+      BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
+      "Conv1", scale2);
 }
 
-TEST(CpuQuantizeSquashPass, branch_to_equal_inequal_and_fp32) {
-  // Delete both quantize ops,
-  // bypass dequantize in both branches,
-  // insert requantize on one branch
+// from
+// a->Conv1->b->Dequant(Scale1)->c
+// c->Quant1(Scale1)->d and d->Conv2->e
+// c->Quant2(Scale2)->g and g->Conv4->h
+// c->Conv3->f
+// to
+// a->Conv1->b
+// b->Conv2->e
+// b->Requant(Scale_in = Scale1; Scale_out = Scale2)->g->Conv4->h
+// b->Dequant(Scale1)->c->Conv3->f
+TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
+  auto scale_out = 1.0f;
   auto scale = 1.2345f;
   auto scale2 = 21.0f;
   auto use_mkldnn = true;
-  // Remove 3 nodes: Quant1, Quant2, g
-  // Insert 1 node: requantize
+  // Remove 3 nodes: Quant1, c, Quant2,
+  // Insert 1 node: Requant
   auto remove_nodes = 2;
-  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
-
-  use_mkldnn = !use_mkldnn;
-  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
+  CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale,
+                                                scale, scale2),
+                remove_nodes);
+  CheckRequantScalesTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out,
+                                                         scale, scale, scale2),
+                         scale, scale2);
 }
 
+// a->Conv1->b->Requant->c
+// d->Conv2->e->Requant->f
+// {c,f}->Concat
+TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
+  // Delete both requantize op
+  auto scale_out = 1.0f;
+  auto scale = 1.2345f;
+  auto use_mkldnn = true;
+  // Remove 4 nodes: b, Requant1, e, Requant2
+  auto remove_nodes = 4;
+  CountNodeTest(
+      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
+      remove_nodes);
+
+  // check equal scale conv->scale_out and requant->scale_out
+  EqualScaleOutTest(
+      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
+      "Conv1", scale);
+  EqualScaleOutTest(
+      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
+      "Conv2", scale);
+}
+
+// a->Concat->b->Dequant->c->Quant->d->Conv->e
+// to a->Concat->b->Requant->d->Conv->e
+TEST(CpuQuantizeSquashPass,
+     unequal_scales_squash_dequantize_quantize_into_requantize) {
+  auto scale_out = 1.0f;
+  auto scale = 1.2345f;
+  auto scale2 = 21.0f;
+  auto use_mkldnn = true;
+  // Remove 3 nodes: Dequant1, c, Quant
+  // Insert 1 node: Requant
+  auto remove_nodes = 2;
+  CountNodeTest(
+      BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
+      remove_nodes);
+  CheckRequantScalesTest(
+      BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
+      scale, scale2);
+}
+
+// a->Conv1->b
+// b->Requant1(Scale1)->c
+// b->Requant2(Scale2)->d
+TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
+  auto scale_out = 1.0f;
+  auto scale = 1.2345f;
+  auto scale2 = 21.0f;
+  auto use_mkldnn = true;
+  // nothing change
+  auto remove_nodes = 0;
+  CountNodeTest(
+      BuildConvMultiRequantProgramDesc(use_mkldnn, scale_out, scale, scale2),
+      remove_nodes);
 }
 
 }  // namespace ir
...