Unverified commit 66b5348e authored by Sylwester Fraczek, committed by GitHub

[Bug fix] prevent squashing pair u8 dequantize -> s8 quantize (#39346)

* prevent squashing pair u8 dequantize -> s8 quantize

* add relu op to check for uint8

* fix ptq fc attr name fuse_activation->activation_type

* fix

* add unit test

* remove unused variable

* test fix unsuccessful

* fix test and logic

* multiline comment

* remove cout

* Revert "fix ptq fc attr name fuse_activation->activation_type"

This reverts commit ffd023353a5e9b0fd15e81b9e9f9fe1794035017.

* fix ptq fc attr name fuse_activation->activation_type
Parent commit: 463e31f4
......@@ -192,7 +192,6 @@ void MainTest(const ProgramDesc& prog,
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num,
&current_nodes_num, var_without_scale, var_signed);
std::unordered_map<std::string, int> actual_operators;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
......
......@@ -104,6 +104,34 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
AddStatis(found_count);
}
/*
 * Returns true when the tensor feeding the dequantize op holds uint8 data.
 * That is the case when the producing op is a standalone relu, or a
 * conv2d/fc whose fused activation is relu/relu6 — those emit unsigned
 * quantized output. Such a dequantize must not be squashed with a signed
 * (s8) quantize.
 */
bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
    const Node* dequant_in) const {
  PADDLE_ENFORCE_EQ(
      dequant_in->inputs.size(), 1,
      platform::errors::InvalidArgument(
          // id() is an integer node id — use %d, not %f.
          "Dequantize (id: %d) should have only one input.",
          dequant_in->id()));
  if (!dequant_in->inputs[0]->IsOp()) return false;

  auto* prev_op = dequant_in->inputs[0]->Op();
  // A plain relu always produces unsigned output.
  if (prev_op->Type() == "relu") return true;

  // conv2d and fc store their fused activation under different attr names.
  std::string act_name;
  if (prev_op->Type() == "conv2d") {
    act_name = "fuse_activation";
  } else if (prev_op->Type() == "fc") {
    act_name = "activation_type";
  }
  if (!act_name.empty()) {
    auto act = prev_op->GetAttrIfExists<std::string>(act_name);
    if (act == "relu" || act == "relu6") return true;
  }
  return false;
}
void CPUQuantizeSquashPass::DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
......@@ -123,6 +151,12 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
// Don't squash if e.g. just one concat input is unsigned
if (IsDequantizeInputUint8(dequant_in) &&
!quant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
return;
}
auto* next_op_desc = next_op->Op();
float dequant_scale =
BOOST_GET_CONST(float, dequant_op->Op()->GetAttr("Scale"));
......
......@@ -43,6 +43,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
/*
* Check if input to dequantize is uint8
*/
bool IsDequantizeInputUint8(const Node* dequant_in) const;
/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
......
......@@ -26,12 +26,26 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& outputs, bool use_mkldnn,
const std::vector<float> scale = {}, float bias = 0.0,
const std::string& mkldnn_data_type = "float32",
bool bias_after_scale = false, int groups = 1) {
bool bias_after_scale = false, int groups = 1,
bool is_negative_input = true) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
if (type != "dropout" && type != "quantize" && type != "dequantize") {
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
}
if (type == "pool2d") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
} else if (type == "relu") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
} else if (type == "conv2d") {
if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
op->SetInput("Input", {inputs[0]});
......@@ -48,11 +62,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("force_fp32_output", false);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "quantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale[0]);
op->SetAttr("is_negative_input", is_negative_input);
} else if (type == "dequantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
......@@ -121,7 +135,8 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
}
static const std::initializer_list<std::string> variable_names{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "x", "y", "w1", "w2"};
"a", "b", "c", "d", "e", "f", "g", "h",
"i", "j", "k", "l", "x", "y", "w1", "w2"};
// a->Conv1(scale1)->b
// b->Dequant(scale1)->c
......@@ -219,6 +234,35 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}
/* Builds three quantized branches that meet in one concat:
 *   a->pool2d->b->Dequant->c(s8)->Quant->d-\
 *   e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
 *   i->pool2d->j->Dequant->k(s8)->Quant->l-/
 * Quant2 is created with is_negative_input=false (the u8 branch).
 */
ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
  ProgramDesc prog;
  for (const auto& var_name : variable_names) {
    prog.MutableBlock(0)->Var(var_name);
  }

  const std::vector<float> scales{scale, scale_out};
  SetOp(&prog, "pool2d", "Pool2d1", {"a"}, {"b"}, true, scales);
  SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, scales);
  SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, scales);

  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, scales);
  SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, scales);
  SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, scales);

  SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, scales);
  // u8 input: mark the quantize as non-negative.
  SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, scales, 0.0,
        "float32", false, 1, false);
  SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, scales);

  SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
  return prog;
}
// a->Conv1->b
// b->Dequant1(Scale1)->c
// c->Concat
......@@ -426,6 +470,31 @@ void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
}
// Runs the squash pass on `prog`, checks that exactly `removed_nodes_num`
// nodes disappeared, and that every op type listed in `expected_operators`
// occurs in the resulting graph exactly the given number of times.
void CheckNodesTest(const ProgramDesc& prog,
                    std::unordered_map<std::string, int> expected_operators,
                    const int removed_nodes_num) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  PrepareGraph(&graph, prog);

  const int original_nodes_num = graph->Nodes().size();
  RegisterPass(&graph);
  const int current_nodes_num = graph->Nodes().size();

  EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);

  // Decrement the counter for each op node found; every counter must end
  // at exactly zero, otherwise an op type occurred too often/rarely.
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    auto it = expected_operators.find(node->Op()->Type());
    if (it != expected_operators.end()) --it->second;
  }
  for (auto const& entry : expected_operators) {
    EXPECT_EQ(entry.second, 0) << " " << entry.first;
  }
}
// check op->scale_out
void EqualScaleTest(const ProgramDesc& prog, const std::string& op_name,
const std::string& scale_name, float scale) {
......@@ -764,6 +833,18 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
remove_nodes);
}
// A u8 dequantize followed by an s8 quantize must NOT be squashed; only the
// two signed dequant->quant pairs are removed (2 pairs x 4 nodes each:
// dequantize op, dequantize out, quantize op, quantize out).
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) {
  const int remove_nodes = 2 * 4;
  std::unordered_map<std::string, int> expected_operators = {
      {"concat", 1}, {"quantize", 1}, {"dequantize", 1},
      {"relu", 1},   {"pool2d", 2}};
  CheckNodesTest(BuildConvS8U8S8ConcatProgramDesc(1.2f, 1.2f),
                 expected_operators, remove_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -116,11 +116,15 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
// force unsigned type if already know it
bool is_unsigned = false;
bool compute_scale = true;
if (op->Type() == "conv2d" || op->Type() == "fc") {
if (op->Type() == "conv2d") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->GetAttrIfExists<std::string>("fuse_activation");
is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
} else if (op->Type() == "fc") {
std::string activation_type =
op->GetAttrIfExists<std::string>("activation_type");
is_unsigned = (activation_type == "relu" || activation_type == "relu6");
} else if (op->Type() == "relu") {
is_unsigned = true;
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册