未验证 提交 b43b46e6 编写于 作者: J joanna.wozna.intel 提交者: GitHub

[INT8] Add requant-op squash (#24143)

上级 e8869a90
......@@ -1508,6 +1508,27 @@ PDNode *patterns::OpRequant::operator()() {
return requant_out;
}
PDNode *patterns::RequantOp::operator()() {
auto requant_in = pattern->NewNode(requant_in_repr())
->assert_is_op_input("requantize", "Input");
auto requant_op =
pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
auto requant_out = pattern->NewNode(requant_out_repr())
->AsOutput()
->assert_is_op_output("requantize", "Output");
auto any_op = pattern->NewNode(any_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return (node->Op()->HasAttr("Scale_in") ||
node->Op()->HasAttr("Scale_x") ||
node->Op()->HasAttr("Scale_y"));
});
requant_op->LinksFrom({requant_in}).LinksTo({requant_out});
any_op->LinksFrom({requant_out});
return any_op;
}
PDNode *patterns::ConvDequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
......
......@@ -913,6 +913,22 @@ struct OpRequant : public PatternBase {
PATTERN_DECL_NODE(requant_out);
};
// Requant + Op
// named nodes:
// requant_in, requant_op,
// requant_out, any_op
struct RequantOp : public PatternBase {
RequantOp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "requant_op") {}
PDNode* operator()();
PATTERN_DECL_NODE(any_op);
PATTERN_DECL_NODE(requant_in);
PATTERN_DECL_NODE(requant_op);
PATTERN_DECL_NODE(requant_out);
};
// Conv + Dequant
// named nodes:
// conv_op, conv_out
......
......@@ -152,7 +152,7 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
PADDLE_ENFORCE_NE(
any_op_output_name.empty(), true,
platform::errors::NotFound("Operator before requantize operator "
"should has requantize input as output"));
"should have requantize input as output"));
float requant_scale_out =
boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
......@@ -170,6 +170,59 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
found_requant_squash_count);
}
// requant-op squash if op has Scale_in, Scale_x, Scale_y attr
// conv2d, fc, matmul
void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::RequantOp requant_op_pattern{gpd.mutable_pattern(), "requant_op"};
requant_op_pattern();
int found_requant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash requantize-op ops pair";
GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern);
if (requant_out->outputs.size() == 1) {
std::string any_op_input_name;
for (auto name : any_op->Op()->InputNames())
for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name;
PADDLE_ENFORCE_NE(
any_op_input_name.empty(), true,
platform::errors::NotFound("The operator after requantize operator "
"should have requantize output as input"));
float requant_scale_in =
boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
auto scale_name = "Scale_in";
if (any_op->Op()->Type() == "matmul")
scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
any_op->Op()->GetAttrIfExists<float>(scale_name),
platform::errors::InvalidArgument(
"The operator after requantize should have input "
"scale equal to requantize output scale"));
any_op->Op()->SetAttr(scale_name, requant_scale_in);
any_op->Op()->SetInput(any_op_input_name,
std::vector<std::string>({requant_in->Name()}));
IR_NODE_LINK_TO(requant_in, any_op);
GraphSafeRemoveNodes(graph, {requant_op, requant_out});
found_requant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_requant_squash_count);
PrettyLogDetail("--- squashed %d requantize ops",
found_requant_squash_count);
}
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
......@@ -379,6 +432,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FindNodesToKeep(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
OpRequantSquash(graph);
RequantOpSquash(graph);
ConvDequantSquash(graph);
FcDequantSquash(graph);
MultipleQuantizeSquash(graph);
......
......@@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void OpRequantSquash(Graph* graph) const;
/*
* Squash requantize op if the next operator's input scale can be updated
*/
void RequantOpSquash(Graph* graph) const;
/*
* Squash conv2d with dequant when dequant is the only op after conv2d
*/
......
......@@ -24,13 +24,14 @@ namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
float scale = 0, float bias = 0.0) {
const std::vector<float> scale = {}, float bias = 0.0) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetAttr("Scale_out", scale);
if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
......@@ -38,15 +39,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
} else if (type == "quantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale);
op->SetAttr("Scale", scale[0]);
} else if (type == "dequantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale);
op->SetAttr("Scale", scale[0]);
} else if (type == "requantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale_out", scale);
op->SetAttr("Scale_in", scale[0]);
op->SetAttr("Scale_out", scale[1]);
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
......@@ -59,17 +61,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
inputs.size()));
op->SetInput("W", {inputs[1]});
op->SetOutput("Out", outputs);
op->SetAttr("Scale_out", scale);
if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
} else if (type == "scale") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("scale", scale);
op->SetAttr("scale", scale[0]);
op->SetAttr("bias", bias);
} else if (type == "matmul") {
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("Scale_out", scale);
if (scale.size() > 0) op->SetAttr("Scale_x", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
}
}
......@@ -78,7 +82,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
// e->Quant(scale2)->f
// (f,w2,b2)->Conv2->i
ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
float scale_in) {
ProgramDesc prog;
for (auto& v : std::initializer_list<std::string>(
{"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
......@@ -89,22 +93,22 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
}
SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn,
scale_out);
SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
{1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, {scale_out});
SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, {scale_in});
SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn,
scale_out);
{scale_in, 2.34f});
return prog;
}
static const std::initializer_list<std::string> variable_names{
"a", "b", "c", "d", "e", "f", "g", "h", "x", "y", "w1"};
"a", "b", "c", "d", "e", "f", "g", "h", "i", "x", "y", "w1", "w2"};
// a->Conv1->b
// a->Conv1(scale1)->b
// b->Dequant(scale1)->c
// c->Quant1(scale2)->d and d->Conv2->e
// c->Quant1(scale2)->d and d->(scale2)Conv2->e
// c->Conv3->f
// c->Quant2(scale3)->g and g->Conv4->h
// c->Quant2(scale3)->g and g->Concat->h
ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2,
float scale3) {
......@@ -113,16 +117,17 @@ ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale1});
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, {scale1});
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, {scale2});
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn,
{scale2, scale_out});
SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn, scale_out);
SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);
SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn, scale_out);
SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, {scale3});
SetOp(&prog, "concat", "Concat", {"g"}, {"h"}, use_mkldnn);
return prog;
}
......@@ -141,16 +146,17 @@ ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, conv_scale);
SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, {1.23f, conv_scale});
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn,
requant_scale1);
SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, fc_scale);
{conv_scale, requant_scale1});
SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, {1.23f, fc_scale});
SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn,
requant_scale2);
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn, matmul_scale);
{fc_scale, requant_scale2});
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn,
{1.23f, matmul_scale});
SetOp(&prog, "requantize", "Requant3", {"g"}, {"h"}, use_mkldnn,
requant_scale3);
SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, use_mkldnn);
{matmul_scale, requant_scale3});
SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, {use_mkldnn});
return prog;
}
......@@ -158,7 +164,8 @@ ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale,
// a->Concat->b
// b->Dequant(scale1)->c
// c->Quant(scale2)->d
// d->Conv->e
// d->Conv1->e
// d->Conv2->f
ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
ProgramDesc prog;
......@@ -167,9 +174,12 @@ ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out,
}
SetOp(&prog, "concat", "Concat", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, {scale1});
SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, {scale2});
SetOp(&prog, "conv2d", "Conv1", {"d"}, {"e"}, use_mkldnn,
{scale2, scale_out});
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"f"}, use_mkldnn,
{scale2, scale_out});
return prog;
}
......@@ -182,9 +192,11 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn,
{scale_out, scale1});
SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn,
{scale_out, scale2});
return prog;
}
......@@ -197,8 +209,8 @@ ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
return prog;
}
......@@ -212,24 +224,24 @@ ProgramDesc BuildFcDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
return prog;
}
// a->fc->b
// b->Dequant1->c
// b->concat->d
// b->fc->d
ProgramDesc BuildFcDequantFcProgramDesc(bool use_mkldnn, float scale_out,
float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
SetOp(&prog, "concat", "Concat1", {"b"}, {"d"}, use_mkldnn);
SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "fc", "Fc2", {"b", "w2"}, {"d"}, use_mkldnn, {scale_out, 2.34f});
return prog;
}
......@@ -242,18 +254,16 @@ ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, {1.23f, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, {scale});
SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);
return prog;
}
// a->concat->b
// b->Quant1(Scale1)->c
// b->Quant2(Scale2)->d
// b->Quant1(Scale1)->c->fc->f
// b->Quant2(Scale2)->d->fc->g
// b->concat->e
// c->fc->f
// d->fc->g
ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
float second_scale) {
ProgramDesc prog;
......@@ -261,11 +271,15 @@ ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "concat", "Concat1", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "quantize", "Quantize1", {"b"}, {"c"}, use_mkldnn, first_scale);
SetOp(&prog, "quantize", "Quantize2", {"b"}, {"d"}, use_mkldnn, second_scale);
SetOp(&prog, "quantize", "Quantize1", {"b"}, {"c"}, use_mkldnn,
{first_scale});
SetOp(&prog, "quantize", "Quantize2", {"b"}, {"d"}, use_mkldnn,
{second_scale});
SetOp(&prog, "concat", "Concat2", {"b"}, {"e"}, use_mkldnn);
SetOp(&prog, "fc", "Fc1", {"c", "w1"}, {"f"}, use_mkldnn, first_scale);
SetOp(&prog, "fc", "Fc2", {"d", "w2"}, {"g"}, use_mkldnn, second_scale);
SetOp(&prog, "fc", "Fc1", {"c", "w1"}, {"f"}, use_mkldnn,
{first_scale, 1.23f});
SetOp(&prog, "fc", "Fc2", {"d", "w2"}, {"g"}, use_mkldnn,
{second_scale, 2.34f});
return prog;
}
......@@ -279,8 +293,8 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dequantize", "Dequant", {"a"}, {"b"}, use_mkldnn,
dequant_scale);
SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, scale_scale, bias);
{dequant_scale});
SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, {scale_scale}, bias);
return prog;
}
......@@ -295,7 +309,34 @@ ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn,
}
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn,
dequant_scale);
{dequant_scale});
return prog;
}
// a->Requant1->x->Matmul->b
// c->Requant2->d->Fc->e
// f->Requant3->g->Conv->h
// {b,e,h}->Concat->i
ProgramDesc BuildRequantOpProgramDesc(bool use_mkldnn, float requant_scale_in,
float op_scale_in, float op_scale_out) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "requantize", "Requant1", {"a"}, {"x"}, use_mkldnn,
{requant_scale_in, op_scale_in});
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn,
{op_scale_in, op_scale_out});
SetOp(&prog, "requantize", "Requant2", {"c"}, {"d"}, use_mkldnn,
{requant_scale_in, op_scale_in});
SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn,
{op_scale_in, op_scale_out});
SetOp(&prog, "requantize", "Requant3", {"f"}, {"g"}, use_mkldnn,
{requant_scale_in, op_scale_in});
SetOp(&prog, "conv2d", "Conv", {"g"}, {"h"}, use_mkldnn,
{op_scale_in, op_scale_out});
SetOp(&prog, "concat", "Concat", {"b", "e", "h"}, {"i"}, {use_mkldnn});
return prog;
}
......@@ -390,14 +431,13 @@ void IsForceFp32OutputTest(const ProgramDesc& prog, std::string op_type,
// From Conv1->d->Dequant->e->Quant->f->Conv2
// To Conv1->d->Conv2
TEST(CpuQuantizeSquashPass, equal_scales) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
auto scale_out = 1.234f;
auto scale = 2.345f;
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, f
auto remove_nodes = 4;
CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
CountNodeTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale),
remove_nodes);
}
......@@ -405,20 +445,17 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
// First change to Conv1->d->Requant->f->Conv2
// Then Conv1->f->Conv2
TEST(CpuQuantizeSquashPass, unequal_scales) {
auto scale_out = 1.0f;
auto scale1 = 1.2345f;
auto scale2 = 21.0f;
auto scale_out = 1.230f;
auto scale_in = 2.34f;
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, d
auto remove_nodes = 4;
CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
CountNodeTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale_in),
remove_nodes);
EqualScaleTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
"Conv1", "Scale_out", scale2);
EqualScaleTest(BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale_in),
"Conv1", "Scale_out", scale_in);
}
// a->Conv->b->Requant->c
......@@ -635,6 +672,20 @@ TEST(CpuQuantizeSquashPass, matmul_with_dequant) {
IsForceFp32OutputTest(
BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale), "matmul", true);
}
TEST(CpuQuantizeSquashPass, requantize_with_matmul_fc_conv) {
auto use_mkldnn = true;
auto requant_scale_in = 1.2f, op_scale_in = 2.3f, op_scale_out = 3.4f;
// remove: 3 requant ops + 3 requant outs
auto remove_nodes = 6;
auto program_desc = BuildRequantOpProgramDesc(use_mkldnn, requant_scale_in,
op_scale_in, op_scale_out);
CountNodeTest(program_desc, remove_nodes);
EqualScaleTest(program_desc, "Matmul", "Scale_x", requant_scale_in);
EqualScaleTest(program_desc, "Fc", "Scale_in", requant_scale_in);
EqualScaleTest(program_desc, "Conv", "Scale_in", requant_scale_in);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册