未验证 提交 e4d475ea 编写于 作者: J joanna.wozna.intel 提交者: GitHub

[Bug fix] Fixed handling of one of the cases in the quantization process (#39342)

* Fix quantization next op findings

* Corrections according to the review
上级 f884edb9
......@@ -1592,11 +1592,8 @@ PDNode *patterns::Transpose::operator()() {
->AsOutput()
->assert_is_op_output("transpose2", "Out");
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
prev_op->LinksTo({transpose_in});
transpose_op->LinksFrom({transpose_in}).LinksTo({transpose_out});
next_op->LinksFrom({transpose_out});
return transpose_out;
}
......@@ -1613,11 +1610,8 @@ PDNode *patterns::Reshape::operator()() {
->AsOutput()
->assert_is_op_output("reshape2", "Out");
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
prev_op->LinksTo({reshape_in});
reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out});
next_op->LinksFrom({reshape_out});
return reshape_out;
}
......@@ -1633,11 +1627,8 @@ PDNode *patterns::Slice::operator()() {
->AsOutput()
->assert_is_op_output("slice", "Out");
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
prev_op->LinksTo({slice_in});
slice_op->LinksFrom({slice_in}).LinksTo({slice_out});
next_op->LinksFrom({slice_out});
return slice_out;
}
......@@ -1658,12 +1649,9 @@ PDNode *patterns::NearestInterp::operator()() {
->assert_is_ops_output({"nearest_interp", "nearest_interp_v2"},
"Out");
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
prev_op->LinksTo({nearest_interp_in});
nearest_interp_op->LinksFrom({nearest_interp_in})
.LinksTo({nearest_interp_out});
next_op->LinksFrom({nearest_interp_out});
return nearest_interp_out;
}
......
......@@ -963,7 +963,6 @@ struct Transpose : public PatternBase {
PATTERN_DECL_NODE(transpose_in);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(next_op);
};
// Reshape op
......@@ -978,7 +977,6 @@ struct Reshape : public PatternBase {
PATTERN_DECL_NODE(reshape_in);
PATTERN_DECL_NODE(reshape_op);
PATTERN_DECL_NODE(reshape_out);
PATTERN_DECL_NODE(next_op);
};
// Slice op
// Forward pass for slice.
......@@ -992,7 +990,6 @@ struct Slice : public PatternBase {
PATTERN_DECL_NODE(slice_in);
PATTERN_DECL_NODE(slice_op);
PATTERN_DECL_NODE(slice_out);
PATTERN_DECL_NODE(next_op);
};
// Nearest Interp op
......@@ -1007,7 +1004,6 @@ struct NearestInterp : public PatternBase {
PATTERN_DECL_NODE(nearest_interp_in);
PATTERN_DECL_NODE(nearest_interp_op);
PATTERN_DECL_NODE(nearest_interp_out);
PATTERN_DECL_NODE(next_op);
};
// Matmul op
......
......@@ -274,8 +274,12 @@ bool CPUQuantizePass::IsOpDequantized(const Node* node) const {
}
// Returns true only if ALL of the node's outputs are op nodes and each of
// them is either a "quantize" op or an op with int8 mkldnn data type.
// A single non-quantized consumer disqualifies the node (so callers pass the
// output variable node of the candidate op and check all of its consumers).
bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
  // NOTE: the old single-node check (`node is quantize or has int8 dtype`)
  // was removed; leaving it in front of std::all_of made the check below
  // unreachable dead code.
  return std::all_of(
      node->outputs.begin(), node->outputs.end(), [](Node* output) {
        return output->IsOp() && (output->Op()->Type() == "quantize" ||
                                  platform::HasOpINT8DataType(output->Op()));
      });
}
void CPUQuantizePass::QuantizeConv(Graph* graph,
......@@ -314,7 +318,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
conv_pattern);
if (!AreScalesPresentForNodes(
{conv_input, conv_filter, conv_residual_data})) {
LogCannotQuantizeOp(conv_op);
LogCannotQuantizeOp(conv_op, "No scale available for the operator");
return;
}
......@@ -326,7 +330,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
residual_scale, is_residual_unsigned, "Scale_in_eltwise");
} else {
if (!AreScalesPresentForNodes({conv_input, conv_filter})) {
LogCannotQuantizeOp(conv_op);
LogCannotQuantizeOp(conv_op, "No scale available for the operator");
return;
}
}
......@@ -401,6 +405,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
return;
}
if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
LogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
return;
}
......@@ -409,7 +414,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
if (!AreScalesPresentForNodes({input, weights})) {
LogCannotQuantizeOp(fc);
LogCannotQuantizeOp(fc, "No scale available for the operator");
return;
}
......@@ -471,7 +476,7 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);
if (!AreScalesPresentForNodes({pool_input, pool_output})) {
LogCannotQuantizeOp(pool_op);
LogCannotQuantizeOp(pool_op, "No scale available for the operator");
return;
}
......@@ -514,7 +519,7 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
if (!AreScalesPresentForNodes({concat_out})) {
LogCannotQuantizeOp(concat_op);
LogCannotQuantizeOp(concat_op, "No scale available for the operator");
return;
}
......@@ -560,7 +565,7 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
prior_box_pattern);
if (!AreScalesPresentForNodes({prior_box_input})) {
LogCannotQuantizeOp(prior_box_op);
LogCannotQuantizeOp(prior_box_op, "No scale available for the operator");
return;
}
......@@ -598,17 +603,18 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern);
// skip if prev op and next op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(next_op))) {
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(transpose_out))) {
LogCannotQuantizeOp(transpose_op,
"No other quantizable operators nearby");
return;
}
GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern);
if (!AreScalesPresentForNodes({transpose_in, transpose_out})) {
LogCannotQuantizeOp(transpose_op);
LogCannotQuantizeOp(transpose_op, "No scale available for the operator");
return;
}
......@@ -651,18 +657,17 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, reshape_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, reshape_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, reshape_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, reshape_pattern);
// skip if prev op and next op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(next_op))) {
// skip if prev op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(reshape_out))) {
LogCannotQuantizeOp(reshape_op, "No other quantizable operators nearby");
return;
}
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, reshape_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, reshape_pattern);
if (!AreScalesPresentForNodes({reshape_in, reshape_out})) {
LogCannotQuantizeOp(reshape_op);
LogCannotQuantizeOp(reshape_op, "No scale available for the operator");
return;
}
......@@ -703,17 +708,17 @@ void CPUQuantizePass::QuantizeSlice(Graph* graph) const {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, slice_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, slice_pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice_in, slice_in, slice_pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice_out, slice_out, slice_pattern);
// skip if prev op and next op is not quantized
if (!IsOpDequantized(prev_op) && !IsOpQuantized(next_op)) {
if (!IsOpDequantized(prev_op) && !IsOpQuantized(slice_out)) {
LogCannotQuantizeOp(slice_op, "No other quantizable operators nearby");
return;
}
GET_IR_NODE_FROM_SUBGRAPH(slice_in, slice_in, slice_pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice_out, slice_out, slice_pattern);
if (!AreScalesPresentForNodes({slice_out})) {
LogCannotQuantizeOp(slice_op);
LogCannotQuantizeOp(slice_op, "No scale available for the operator");
return;
}
......@@ -758,6 +763,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
// skip if prev ops are not quantized
if (!IsOpDequantized(prev_op_x) || !IsOpDequantized(prev_op_y)) {
LogCannotQuantizeOp(matmul_op, "No other quantizable operators nearby");
return;
}
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
......@@ -765,7 +771,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
LogCannotQuantizeOp(matmul_op);
LogCannotQuantizeOp(matmul_op, "No scale available for the operator");
return;
}
......@@ -832,7 +838,8 @@ void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const {
if (!AreScalesPresentForNodes(
{elementwise_add_x, elementwise_add_y, elementwise_add_out})) {
LogCannotQuantizeOp(elementwise_add_op);
LogCannotQuantizeOp(elementwise_add_op,
"No scale available for the operator");
return;
}
......@@ -893,7 +900,7 @@ void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);
if (!AreScalesPresentForNodes({x, weight_x})) {
LogCannotQuantizeOp(op);
LogCannotQuantizeOp(op, "No scale available for the operator");
return;
}
......@@ -950,7 +957,7 @@ void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
auto wx_names = gru->Op()->Input("WeightX");
if (!AreScalesPresentForNodes({x}) ||
!AreScalesPresentForVarNames(wx_names)) {
LogCannotQuantizeOp(gru);
LogCannotQuantizeOp(gru, "No scale available for the operator");
return;
}
......@@ -1029,7 +1036,7 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
// Starting from here there maybe issues
if (!AreScalesPresentForNodes({x, weight_x})) {
LogCannotQuantizeOp(op);
LogCannotQuantizeOp(op, "No scale available for the operator");
return;
}
......@@ -1081,23 +1088,21 @@ void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, nearest_interp_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, nearest_interp_pattern);
GET_IR_NODE_FROM_SUBGRAPH(nearest_interp_in, nearest_interp_in,
nearest_interp_pattern);
GET_IR_NODE_FROM_SUBGRAPH(nearest_interp_out, nearest_interp_out,
nearest_interp_pattern);
// skip if prev op and next op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(next_op))) {
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(nearest_interp_out))) {
LogCannotQuantizeOp(nearest_interp_op,
"There are no other quantized operators nearby, so "
"quantization is not recommended.");
"No other quantizable operators nearby");
return;
}
GET_IR_NODE_FROM_SUBGRAPH(nearest_interp_in, nearest_interp_in,
nearest_interp_pattern);
GET_IR_NODE_FROM_SUBGRAPH(nearest_interp_out, nearest_interp_out,
nearest_interp_pattern);
if (!AreScalesPresentForNodes({nearest_interp_in, nearest_interp_out})) {
LogCannotQuantizeOp(nearest_interp_op);
LogCannotQuantizeOp(nearest_interp_op,
"No scale available for the operator");
return;
}
......
......@@ -36,7 +36,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type != "dropout" || type != "quantize" || type != "dequantize") {
if (type != "dropout" && type != "quantize" && type != "dequantize") {
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
}
......@@ -371,7 +371,7 @@ TEST(CpuQuantizePass, fusion_lstm) {
}
// Variable names used to build the immutable-op test graphs: inputs/outputs
// a..g plus the weight name "w1". "e", "f", "g" serve the many-outputs case.
static const std::initializer_list<std::string> variable_names_immutable_ops = {
    "a", "w1", "b", "c", "d", "e", "f", "g"};
// a->Dequantize->b
// b->Tested Op->c
......@@ -417,36 +417,88 @@ void TestImmutableOpBetweenNonQuantizedOp(const std::string tested_op) {
SCALE * S8_MAX);
}
// a->Dropout1->b
// b->TestedOp1(won't be quantized)->c
// c->Dropout2->d
// c->TestedOp2(will be quantized)->e
// e->Pool2d1(will be quantized)->f
// e->Pool2d2(will be quantized)->g
void TestImmutableOpWithManyOutputs(const std::string tested_op) {
ProgramDesc prog;
for (auto& v : variable_names_immutable_ops) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, true, "float32");
SetOp(&prog, tested_op, std::string(tested_op + "1"), {"b"}, {"c"}, true,
"int8");
SetOp(&prog, "dropout", "Dropout2", {"c"}, {"d"}, true, "float32");
SetOp(&prog, tested_op, std::string(tested_op + "2"), {"c"}, {"e"}, true,
"int8");
SetOp(&prog, "pool2d", "Pool2d1", {"e"}, {"f"}, true, "int8");
SetOp(&prog, "pool2d", "Pool2d2", {"e"}, {"g"}, true, "int8");
// 3 Quant + 3 IN + 3 DeQuant + 3 OUT
int added_nodes = 12;
std::unordered_map<std::string, int> expected_operators = {{tested_op, 2},
{"dropout", 2},
{"pool2d", 2},
{"quantize", 3},
{"dequantize", 3}};
MainTest(prog, variable_names_immutable_ops, expected_operators, added_nodes,
SCALE * S8_MAX);
}
// Each immutable op gets three scenarios: the basic quantization case,
// the op placed between non-quantized ops (must stay fp32), and the op
// whose output has several consumers (many-outputs case).
TEST(CpuQuantizePass, reshape2) { TestImmutableOp("reshape2"); }
TEST(CpuQuantizePass, reshape2BetweenNonQuantizedOp) {
  TestImmutableOpBetweenNonQuantizedOp("reshape2");
}
TEST(CpuQuantizePass, reshape2WithManyOutputs) {
  TestImmutableOpWithManyOutputs("reshape2");
}
TEST(CpuQuantizePass, transpose2) { TestImmutableOp("transpose2"); }
TEST(CpuQuantizePass, transpose2BetweenNonQuantizedOp) {
  TestImmutableOpBetweenNonQuantizedOp("transpose2");
}
TEST(CpuQuantizePass, transpose2WithManyOutputs) {
  TestImmutableOpWithManyOutputs("transpose2");
}
TEST(CpuQuantizePass, slice) { TestImmutableOp("slice"); }
TEST(CpuQuantizePass, sliceBetweenNonQuantizedOp) {
  TestImmutableOpBetweenNonQuantizedOp("slice");
}
TEST(CpuQuantizePass, sliceWithManyOutputs) {
  TestImmutableOpWithManyOutputs("slice");
}
TEST(CpuQuantizePass, nearestInterp) { TestImmutableOp("nearest_interp"); }
TEST(CpuQuantizePass, nearestInterpBetweenNonQuantizedOp) {
  TestImmutableOpBetweenNonQuantizedOp("nearest_interp");
}
TEST(CpuQuantizePass, nearestInterpWithManyOutputs) {
  TestImmutableOpWithManyOutputs("nearest_interp");
}
TEST(CpuQuantizePass, nearestInterpV2) { TestImmutableOp("nearest_interp_v2"); }
TEST(CpuQuantizePass, nearestInterpV2BetweenNonQuantizedOp) {
  TestImmutableOpBetweenNonQuantizedOp("nearest_interp_v2");
}
TEST(CpuQuantizePass, nearestInterpV2WithManyOutputs) {
  TestImmutableOpWithManyOutputs("nearest_interp_v2");
}
// Variable names used to build the matmul quantization test graphs.
static const std::initializer_list<std::string> variable_names_matmul = {
    "a", "b", "c", "d", "e", "f"};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册