Unverified commit 5ee099ca authored by joanna.wozna.intel, committed by GitHub

Op-requant squash (#23665)

* Op-requant squash

test=develop

* Add matmul to op-requant test

test=develop
Parent 9bc1e0a1
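
In short, this change generalizes the earlier conv2d+requantize squash so that it covers any operator carrying a Scale_out attribute (the tests below exercise conv2d, fc, and matmul). In the arrow notation used by those tests, with an assumed scale s for illustration, the rewrite is:

    a->Op->b->Requant(s)->c   ==>   a->Op->c   (Op's Scale_out becomes s)

The intermediate tensor b and the requantize op are removed; the pass only fires when b feeds nothing but the requantize op.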
@@ -1490,20 +1490,22 @@ PDNode *patterns::ConvConcatReLU::operator()() {
   return relu_out;
 }
 
-PDNode *patterns::ConvRequant::operator()() {
-  // Create Operators
-  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+PDNode *patterns::OpRequant::operator()() {
+  auto any_op = pattern->NewNode(any_op_repr())
+                    ->assert_is_op()
+                    ->assert_more([&](Node *node) {
+                      return node->Op()->HasAttr("Scale_out") ? true : false;
+                    });
+  auto requant_in = pattern->NewNode(requant_in_repr())
+                        ->assert_is_op_input("requantize", "Input");
   auto requant_op =
       pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
-
-  auto conv_out = pattern->NewNode(conv_out_repr())
-                      ->assert_is_op_output("conv2d", "Output");
   auto requant_out = pattern->NewNode(requant_out_repr())
                          ->AsOutput()
                          ->assert_is_op_output("requantize", "Output");
 
-  conv_op->LinksTo({conv_out});
-  requant_op->LinksFrom({conv_out}).LinksTo({requant_out});
+  any_op->LinksTo({requant_in});
+  requant_op->LinksFrom({requant_in}).LinksTo({requant_out});
   return requant_out;
 }
@@ -897,19 +897,18 @@ struct ConvConcatReLU : public PatternBase {
   PATTERN_DECL_NODE(relu_out);
 };
 
-// Conv + Requant
+// Op + Requant
 // named nodes:
-// conv_op, conv_out
+// any_op, any_out
 // requant_op, requant_out
-struct ConvRequant : public PatternBase {
-  ConvRequant(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_requant") {}
+struct OpRequant : public PatternBase {
+  OpRequant(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "op_requant") {}
 
   PDNode* operator()();
 
-  PATTERN_DECL_NODE(conv_op);
-  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(any_op);
+  PATTERN_DECL_NODE(requant_in);
   PATTERN_DECL_NODE(requant_op);
   PATTERN_DECL_NODE(requant_out);
 };
@@ -126,38 +126,47 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
                   found_dequant_quant_count);
 }
 
-void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
+// op+requant squash if op has Scale_out attr
+// conv2d and fc
+void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
   GraphPatternDetector gpd;
-  patterns::ConvRequant conv_requant_pattern{gpd.mutable_pattern(),
-                                             "conv_requant"};
-  conv_requant_pattern();
+  patterns::OpRequant op_requant_pattern{gpd.mutable_pattern(), "op_requant"};
+  op_requant_pattern();
 
   int found_requant_squash_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "squash conv-requantize ops pair";
+    VLOG(4) << "squash op-requantize ops pair";
 
-    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, op_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, op_requant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, op_requant_pattern);
 
-    // if conv2d has one output squash
-    if (conv_out->outputs.size() == 1) {
+    if (requant_in->outputs.size() == 1) {
+      std::string any_op_output_name;
+      for (auto name : any_op->Op()->OutputNames())
+        for (auto output_name : any_op->Op()->Output(name))
+          if (output_name == requant_in->Name()) any_op_output_name = name;
+
+      PADDLE_ENFORCE_NE(
+          any_op_output_name.empty(), true,
+          platform::errors::NotFound("Operator before requantize operator "
+                                     "should has requantize input as output"));
+
       float requant_scale_out =
          boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
-      conv_op->Op()->SetAttr("Scale_out", requant_scale_out);
-      conv_op->Op()->SetOutput("Output",
-                               std::vector<std::string>({requant_out->Name()}));
-      IR_NODE_LINK_TO(conv_op, requant_out);
-      GraphSafeRemoveNodes(graph, {conv_out, requant_op});
+      any_op->Op()->SetAttr("Scale_out", requant_scale_out);
+      any_op->Op()->SetOutput(any_op_output_name,
+                              std::vector<std::string>({requant_out->Name()}));
+      IR_NODE_LINK_TO(any_op, requant_out);
+      GraphSafeRemoveNodes(graph, {requant_in, requant_op});
       found_requant_squash_count++;
     }
  };
  gpd(graph, handler);
  AddStatis(found_requant_squash_count);
-  PrettyLogDetail("--- squashed %d requantize with convs",
+  PrettyLogDetail("--- squashed %d requantize ops",
                   found_requant_squash_count);
 }
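
For intuition, here is a minimal standalone sketch of the squash step performed by the handler above. The ToyOp struct and every name in it are invented for illustration; this is not the Paddle ir::Graph/OpDesc API.

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an IR op node; illustrative only, not Paddle's API.
struct ToyOp {
  std::string type;
  std::map<std::string, float> attrs;        // e.g. {"Scale_out", 1.0f}
  std::vector<std::string> inputs, outputs;  // tensor names
};

// Squash op->requant: copy the requantize output scale onto the producing
// op and rewire the op to write requant's output tensor directly. Returns
// false when the guards used by the pass do not hold.
bool SquashOpRequant(ToyOp* op, const ToyOp& requant, int intermediate_uses) {
  if (op->attrs.count("Scale_out") == 0) return false;  // pattern guard
  if (intermediate_uses != 1) return false;  // tensor must feed only requant
  op->attrs["Scale_out"] = requant.attrs.at("Scale_out");
  op->outputs = requant.outputs;  // skip the intermediate tensor
  return true;  // caller then deletes the requant op and the old tensor
}

int main() {
  // a->Conv->b->Requant->c  should become  a->Conv->c
  ToyOp conv{"conv2d", {{"Scale_out", 1.0f}}, {"a"}, {"b"}};
  ToyOp requant{"requantize", {{"Scale_out", 1.5f}}, {"b"}, {"c"}};
  bool ok = SquashOpRequant(&conv, requant, /*intermediate_uses=*/1);
  assert(ok);
  assert(conv.attrs["Scale_out"] == 1.5f);
  assert(conv.outputs == std::vector<std::string>{"c"});
  return 0;
}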
@@ -369,7 +378,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_map<const Node*, int> nodes_keep_counter;
   FindNodesToKeep(graph, &nodes_keep_counter);
   DequantQuantSquash(graph, &nodes_keep_counter);
-  ConvRequantSquash(graph);
+  OpRequantSquash(graph);
   ConvDequantSquash(graph);
   FcDequantSquash(graph);
   MultipleQuantizeSquash(graph);
@@ -53,7 +53,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
   /*
    * Squash requantize op into conv with scale_out like requantize scale_out
    */
-  void ConvRequantSquash(Graph* graph) const;
+  void OpRequantSquash(Graph* graph) const;
 
   /*
    * Squash conv2d with dequant when dequant is the only op after conv2d
@@ -59,6 +59,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
                           inputs.size()));
     op->SetInput("W", {inputs[1]});
     op->SetOutput("Out", outputs);
+    op->SetAttr("Scale_out", scale);
   } else if (type == "scale") {
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
@@ -68,6 +69,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetInput("X", {inputs[0]});
     op->SetInput("Y", {inputs[1]});
     op->SetOutput("Out", {outputs[0]});
+    op->SetAttr("Scale_out", scale);
   }
 }
@@ -96,7 +98,7 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
 }
 
 static const std::initializer_list<std::string> variable_names{
-    "a", "b", "c", "d", "e", "f", "g", "h", "x", "y"};
+    "a", "b", "c", "d", "e", "f", "g", "h", "x", "y", "w1"};
 
 // a->Conv1->b
 // b->Dequant(scale1)->c
@@ -125,23 +127,30 @@ ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
   return prog;
 }
 
-// a->Conv1->b->Requant(scale1)->c
-// d->Conv2->e->Requant(scale2)->f
-// {c,f}->Concat
-ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out,
-                                               float scale1, float scale2) {
+// a->Conv->b->Requant(scale1)->c
+// d->Fc->e->Requant(scale2)->f
+// {x,y}->Matmul->g->Requant(scale3)->h
+// {c,f,h}->Concat
+ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale,
+                                      float fc_scale, float matmul_scale,
+                                      float requant_scale1,
+                                      float requant_scale2,
+                                      float requant_scale3) {
   ProgramDesc prog;
   for (auto& v : variable_names) {
     prog.MutableBlock(0)->Var(v);
   }
-  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
-  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
-  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
-  SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2);
-  SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn);
+  SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, conv_scale);
+  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn,
+        requant_scale1);
+  SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, fc_scale);
+  SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn,
+        requant_scale2);
+  SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn, matmul_scale);
+  SetOp(&prog, "requantize", "Requant3", {"g"}, {"h"}, use_mkldnn,
+        requant_scale3);
+  SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, use_mkldnn);
 
   return prog;
 }
@@ -412,27 +421,28 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
                  "Conv1", "Scale_out", scale2);
 }
 
-// a->Conv1->b->Requant->c
-// d->Conv2->e->Requant->f
-// {c,f}->Concat
-TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
-  // Delete both requantize op
-  auto scale_out = 1.0f;
-  auto scale = 1.2345f;
+// a->Conv->b->Requant->c
+// d->Fc->e->Requant->f
+// {x,y}->Matmul->g->Requant->h
+// {c,f,h}->Concat
+TEST(CpuQuantizeSquashPass, op_requantize_squash) {
+  // Delete all requantize ops
+  auto conv_scale = 0.234f;
+  auto fc_scale = 1.234f;
+  auto matmul_scale = 2.234f;
+  auto requant_scale1 = 2.234f;
+  auto requant_scale2 = 3.234f;
+  auto requant_scale3 = 4.234f;
   auto use_mkldnn = true;
-  // Remove 4 nodes: b, Requant1, e, Requant2
-  auto remove_nodes = 4;
-  CountNodeTest(
-      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      remove_nodes);
-
-  // check equal scale conv->scale_out and requant->scale_out
-  EqualScaleTest(
-      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv1", "Scale_out", scale);
-  EqualScaleTest(
-      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv2", "Scale_out", scale);
+  // Remove 6 nodes: b, Requant1, e, Requant2, g, Requant3
+  auto remove_nodes = 6;
+  auto program_desc =
+      BuildOpRequantProgramDesc(use_mkldnn, conv_scale, fc_scale, matmul_scale,
+                                requant_scale1, requant_scale2, requant_scale3);
+  CountNodeTest(program_desc, remove_nodes);
+  EqualScaleTest(program_desc, "Conv", "Scale_out", requant_scale1);
+  EqualScaleTest(program_desc, "Fc", "Scale_out", requant_scale2);
+  EqualScaleTest(program_desc, "Matmul", "Scale_out", requant_scale3);
 }
 
 // from