未验证 提交 753964a2 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Correct MultipleQuantizeSquash (#40717)

* Correct MultipleQuantizeSquash

* Correct logging
上级 99541895
...@@ -39,12 +39,13 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) { ...@@ -39,12 +39,13 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) {
b->inputs.end()); b->inputs.end());
} }
void LogCannotQuantizeOp(Node* op, const char* details = nullptr) { void MarkAndLogCannotQuantizeOp(Node* op, const char* details = nullptr) {
std::stringstream msg_ss; std::stringstream msg_ss;
msg_ss << "Cannot quantize operator " << op->Name() msg_ss << "Cannot quantize operator " << op->Name()
<< " (type: " << op->Op()->Type() << ", id: " << op->id() << ")."; << " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
if (details) msg_ss << " " << details; if (details) msg_ss << " " << details;
PrettyLogDetail(msg_ss.str().c_str()); VLOG(2) << msg_ss.str().c_str();
op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
} }
void LogScaleIsMissingForVarName(const std::string& name) { void LogScaleIsMissingForVarName(const std::string& name) {
...@@ -56,12 +57,19 @@ void LogScaleIsMissingForVarNode(Node* node) { ...@@ -56,12 +57,19 @@ void LogScaleIsMissingForVarNode(Node* node) {
} }
void LogQuantizationDisabled(Node* op) { void LogQuantizationDisabled(Node* op) {
std::stringstream msg_ss; VLOG(2) << "Quantization skipped for operator " << op->Name()
VLOG(4) << "Qantization skipped for operator " << op->Name()
<< " (type: " << op->Op()->Type() << ", id: " << op->id() << " (type: " << op->Op()->Type() << ", id: " << op->id()
<< "). Attribute mkldnn_data_type != \"int8\"."; << "). Attribute mkldnn_data_type != \"int8\".";
} }
// Logs a one-line summary of how many operators of a given type were
// quantized, formatted as "--- quantized <counter> <type> ops[ <details>]".
//
// type:    operator type name inserted into the message.
// counter: number of operators to report.
// details: optional suffix appended after a single space. A null pointer or
//          an empty string appends nothing, so a caller that passes "" does
//          not leave a trailing space at the end of the log line.
void LogQuantizedOpsCounter(const std::string& type, const int counter,
                            const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "--- quantized " << counter << " " << type << " ops";
  if (details != nullptr && details[0] != '\0') msg_ss << " " << details;
  // PrettyLogDetail is called elsewhere with printf-style format strings, so
  // hand the finished message over as an argument instead of as the format:
  // a '%' inside `type` or `details` must not be treated as a conversion.
  PrettyLogDetail("%s", msg_ss.str().c_str());
}
} // namespace } // namespace
enum { U8_MAX = 255, S8_MAX = 127 }; enum { U8_MAX = 255, S8_MAX = 127 };
...@@ -307,9 +315,10 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -307,9 +315,10 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
auto has_output_scale = AreScalesPresentForNodes({conv_output}); auto has_output_scale = AreScalesPresentForNodes({conv_output});
if (with_residual_data && !has_output_scale) { if (with_residual_data && !has_output_scale) {
LogCannotQuantizeOp(conv_op, MarkAndLogCannotQuantizeOp(
"Conv op with ResidualData input cannot be quantized " conv_op,
"without output scale."); "Conv op with ResidualData input cannot be quantized "
"without output scale.");
return; return;
} }
...@@ -318,7 +327,8 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -318,7 +327,8 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
conv_pattern); conv_pattern);
if (!AreScalesPresentForNodes( if (!AreScalesPresentForNodes(
{conv_input, conv_filter, conv_residual_data})) { {conv_input, conv_filter, conv_residual_data})) {
LogCannotQuantizeOp(conv_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(conv_op,
"No scale available for the operator");
return; return;
} }
...@@ -330,7 +340,8 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -330,7 +340,8 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
residual_scale, is_residual_unsigned, "Scale_in_eltwise"); residual_scale, is_residual_unsigned, "Scale_in_eltwise");
} else { } else {
if (!AreScalesPresentForNodes({conv_input, conv_filter})) { if (!AreScalesPresentForNodes({conv_input, conv_filter})) {
LogCannotQuantizeOp(conv_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(conv_op,
"No scale available for the operator");
return; return;
} }
} }
...@@ -377,10 +388,9 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -377,10 +388,9 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_conv_count); AddStatis(quantize_conv_count);
std::stringstream msg_ss; LogQuantizedOpsCounter(
msg_ss << "--- quantized " << quantize_conv_count << " conv2d ops"; "conv2d", quantize_conv_count,
if (with_residual_data) msg_ss << " with residual connection"; ((with_residual_data) ? "with residual connection" : ""));
PrettyLogDetail(msg_ss.str().c_str());
} }
void CPUQuantizePass::QuantizeFc(Graph* graph) const { void CPUQuantizePass::QuantizeFc(Graph* graph) const {
...@@ -405,7 +415,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -405,7 +415,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
return; return;
} }
if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) { if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
LogCannotQuantizeOp(fc, "use_mkldnn attribute set to false"); MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
return; return;
} }
...@@ -414,7 +424,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -414,7 +424,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
if (!AreScalesPresentForNodes({input, weights})) { if (!AreScalesPresentForNodes({input, weights})) {
LogCannotQuantizeOp(fc, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
return; return;
} }
...@@ -448,10 +458,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -448,10 +458,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_fc_count); AddStatis(quantize_fc_count);
LogQuantizedOpsCounter("fc", quantize_fc_count);
std::stringstream msg_ss;
msg_ss << "--- quantized " << quantize_fc_count << " fc ops";
PrettyLogDetail(msg_ss.str().c_str());
} }
void CPUQuantizePass::QuantizePool(Graph* graph) const { void CPUQuantizePass::QuantizePool(Graph* graph) const {
...@@ -476,7 +483,8 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const { ...@@ -476,7 +483,8 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern); GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);
if (!AreScalesPresentForNodes({pool_input, pool_output})) { if (!AreScalesPresentForNodes({pool_input, pool_output})) {
LogCannotQuantizeOp(pool_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(pool_op,
"No scale available for the operator");
return; return;
} }
...@@ -494,8 +502,7 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const { ...@@ -494,8 +502,7 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_pool_count); AddStatis(quantize_pool_count);
LogQuantizedOpsCounter("pool2d", quantize_pool_count);
PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
} }
void CPUQuantizePass::QuantizeConcat(Graph* graph) const { void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
...@@ -519,7 +526,8 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { ...@@ -519,7 +526,8 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern); GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
if (!AreScalesPresentForNodes({concat_out})) { if (!AreScalesPresentForNodes({concat_out})) {
LogCannotQuantizeOp(concat_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(concat_op,
"No scale available for the operator");
return; return;
} }
...@@ -539,8 +547,7 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { ...@@ -539,8 +547,7 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_concat_count); AddStatis(quantize_concat_count);
LogQuantizedOpsCounter("concat", quantize_concat_count);
PrettyLogDetail("--- quantized %d concat ops", quantize_concat_count);
} }
void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
...@@ -565,7 +572,8 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { ...@@ -565,7 +572,8 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
prior_box_pattern); prior_box_pattern);
if (!AreScalesPresentForNodes({prior_box_input})) { if (!AreScalesPresentForNodes({prior_box_input})) {
LogCannotQuantizeOp(prior_box_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(prior_box_op,
"No scale available for the operator");
return; return;
} }
...@@ -580,9 +588,7 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { ...@@ -580,9 +588,7 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_prior_box_count); AddStatis(quantize_prior_box_count);
LogQuantizedOpsCounter("prior_box", quantize_prior_box_count);
PrettyLogDetail("--- quantized %d prior_box ops",
quantize_prior_box_count);
} }
void CPUQuantizePass::QuantizeTranspose(Graph* graph) const { void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
...@@ -608,13 +614,14 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const { ...@@ -608,13 +614,14 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
// skip if prev op and next op is not quantized // skip if prev op and next op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(transpose_out))) { if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(transpose_out))) {
LogCannotQuantizeOp(transpose_op, MarkAndLogCannotQuantizeOp(transpose_op,
"No other quantizable operators nearby"); "No other quantizable operators nearby");
return; return;
} }
if (!AreScalesPresentForNodes({transpose_in, transpose_out})) { if (!AreScalesPresentForNodes({transpose_in, transpose_out})) {
LogCannotQuantizeOp(transpose_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(transpose_op,
"No scale available for the operator");
return; return;
} }
...@@ -634,9 +641,7 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const { ...@@ -634,9 +641,7 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_transpose_count); AddStatis(quantize_transpose_count);
LogQuantizedOpsCounter("transpose2", quantize_transpose_count);
PrettyLogDetail("--- quantized %d transpose ops",
quantize_transpose_count);
} }
void CPUQuantizePass::QuantizeReshape(Graph* graph) const { void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
...@@ -662,12 +667,14 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const { ...@@ -662,12 +667,14 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
// skip if prev op is not quantized // skip if prev op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(reshape_out))) { if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(reshape_out))) {
LogCannotQuantizeOp(reshape_op, "No other quantizable operators nearby"); MarkAndLogCannotQuantizeOp(reshape_op,
"No other quantizable operators nearby");
return; return;
} }
if (!AreScalesPresentForNodes({reshape_in, reshape_out})) { if (!AreScalesPresentForNodes({reshape_in, reshape_out})) {
LogCannotQuantizeOp(reshape_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(reshape_op,
"No scale available for the operator");
return; return;
} }
...@@ -686,8 +693,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const { ...@@ -686,8 +693,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_reshape_count); AddStatis(quantize_reshape_count);
LogQuantizedOpsCounter("reshape2", quantize_reshape_count);
PrettyLogDetail("--- quantized %d reshape ops", quantize_reshape_count);
} }
void CPUQuantizePass::QuantizeSlice(Graph* graph) const { void CPUQuantizePass::QuantizeSlice(Graph* graph) const {
...@@ -713,12 +719,14 @@ void CPUQuantizePass::QuantizeSlice(Graph* graph) const { ...@@ -713,12 +719,14 @@ void CPUQuantizePass::QuantizeSlice(Graph* graph) const {
// skip if prev op and next op is not quantized // skip if prev op and next op is not quantized
if (!IsOpDequantized(prev_op) && !IsOpQuantized(slice_out)) { if (!IsOpDequantized(prev_op) && !IsOpQuantized(slice_out)) {
LogCannotQuantizeOp(slice_op, "No other quantizable operators nearby"); MarkAndLogCannotQuantizeOp(slice_op,
"No other quantizable operators nearby");
return; return;
} }
if (!AreScalesPresentForNodes({slice_out})) { if (!AreScalesPresentForNodes({slice_out})) {
LogCannotQuantizeOp(slice_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(slice_op,
"No scale available for the operator");
return; return;
} }
...@@ -737,8 +745,7 @@ void CPUQuantizePass::QuantizeSlice(Graph* graph) const { ...@@ -737,8 +745,7 @@ void CPUQuantizePass::QuantizeSlice(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_slice_count); AddStatis(quantize_slice_count);
LogQuantizedOpsCounter("slice", quantize_slice_count);
PrettyLogDetail("--- quantized %d slice ops", quantize_slice_count);
} }
void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
...@@ -763,7 +770,8 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -763,7 +770,8 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
// skip if prev ops are not quantized // skip if prev ops are not quantized
if (!IsOpDequantized(prev_op_x) || !IsOpDequantized(prev_op_y)) { if (!IsOpDequantized(prev_op_x) || !IsOpDequantized(prev_op_y)) {
LogCannotQuantizeOp(matmul_op, "No other quantizable operators nearby"); MarkAndLogCannotQuantizeOp(matmul_op,
"No other quantizable operators nearby");
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
...@@ -771,7 +779,8 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -771,7 +779,8 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) { if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
LogCannotQuantizeOp(matmul_op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(matmul_op,
"No scale available for the operator");
return; return;
} }
...@@ -803,8 +812,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -803,8 +812,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_matmul_count); AddStatis(quantize_matmul_count);
LogQuantizedOpsCounter("matmul", quantize_matmul_count);
PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count);
} }
void CPUQuantizePass::QuantizeElementwise( void CPUQuantizePass::QuantizeElementwise(
...@@ -840,8 +848,8 @@ void CPUQuantizePass::QuantizeElementwise( ...@@ -840,8 +848,8 @@ void CPUQuantizePass::QuantizeElementwise(
if (!AreScalesPresentForNodes( if (!AreScalesPresentForNodes(
{elementwise_x, elementwise_y, elementwise_out})) { {elementwise_x, elementwise_y, elementwise_out})) {
LogCannotQuantizeOp(elementwise_op, MarkAndLogCannotQuantizeOp(elementwise_op,
"No scale available for the operator"); "No scale available for the operator");
return; return;
} }
...@@ -851,8 +859,8 @@ void CPUQuantizePass::QuantizeElementwise( ...@@ -851,8 +859,8 @@ void CPUQuantizePass::QuantizeElementwise(
// TODO(sfraczek): add support for different signness // TODO(sfraczek): add support for different signness
if (is_x_unsigned != is_y_unsigned) { if (is_x_unsigned != is_y_unsigned) {
LogCannotQuantizeOp(elementwise_op, MarkAndLogCannotQuantizeOp(
"Elementwise inputs must be of the same type."); elementwise_op, "Elementwise inputs must be of the same type.");
return; return;
} }
...@@ -872,9 +880,7 @@ void CPUQuantizePass::QuantizeElementwise( ...@@ -872,9 +880,7 @@ void CPUQuantizePass::QuantizeElementwise(
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_elementwise_count); AddStatis(quantize_elementwise_count);
LogQuantizedOpsCounter(elementwise_type, quantize_elementwise_count);
PrettyLogDetail("--- quantized %d %s ops", quantize_elementwise_count,
elementwise_type);
} }
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const { void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
...@@ -900,7 +906,7 @@ void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const { ...@@ -900,7 +906,7 @@ void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern); GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);
if (!AreScalesPresentForNodes({x, weight_x})) { if (!AreScalesPresentForNodes({x, weight_x})) {
LogCannotQuantizeOp(op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
return; return;
} }
...@@ -929,8 +935,7 @@ void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const { ...@@ -929,8 +935,7 @@ void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_count); AddStatis(quantize_count);
LogQuantizedOpsCounter("fusion_gru", quantize_count);
PrettyLogDetail("--- quantized %d fusion_gru ops", quantize_count);
} }
void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const { void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
...@@ -957,7 +962,7 @@ void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const { ...@@ -957,7 +962,7 @@ void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
auto wx_names = gru->Op()->Input("WeightX"); auto wx_names = gru->Op()->Input("WeightX");
if (!AreScalesPresentForNodes({x}) || if (!AreScalesPresentForNodes({x}) ||
!AreScalesPresentForVarNames(wx_names)) { !AreScalesPresentForVarNames(wx_names)) {
LogCannotQuantizeOp(gru, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(gru, "No scale available for the operator");
return; return;
} }
...@@ -1007,8 +1012,7 @@ void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const { ...@@ -1007,8 +1012,7 @@ void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_count); AddStatis(quantize_count);
LogQuantizedOpsCounter("multi_gru", quantize_count);
PrettyLogDetail("--- quantized %d multi_gru ops", quantize_count);
} }
void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const { void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
...@@ -1036,7 +1040,7 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const { ...@@ -1036,7 +1040,7 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
// Starting from here there maybe issues // Starting from here there maybe issues
if (!AreScalesPresentForNodes({x, weight_x})) { if (!AreScalesPresentForNodes({x, weight_x})) {
LogCannotQuantizeOp(op, "No scale available for the operator"); MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
return; return;
} }
...@@ -1065,8 +1069,7 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const { ...@@ -1065,8 +1069,7 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_count); AddStatis(quantize_count);
LogQuantizedOpsCounter("fusion_lstm", quantize_count);
PrettyLogDetail("--- quantized %d fusion_lstm ops", quantize_count);
} }
void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const { void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const {
...@@ -1095,14 +1098,14 @@ void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const { ...@@ -1095,14 +1098,14 @@ void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const {
// skip if prev op and next op is not quantized // skip if prev op and next op is not quantized
if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(nearest_interp_out))) { if (!(IsOpDequantized(prev_op)) && !(IsOpQuantized(nearest_interp_out))) {
LogCannotQuantizeOp(nearest_interp_op, MarkAndLogCannotQuantizeOp(nearest_interp_op,
"No other quantizable operators nearby"); "No other quantizable operators nearby");
return; return;
} }
if (!AreScalesPresentForNodes({nearest_interp_in, nearest_interp_out})) { if (!AreScalesPresentForNodes({nearest_interp_in, nearest_interp_out})) {
LogCannotQuantizeOp(nearest_interp_op, MarkAndLogCannotQuantizeOp(nearest_interp_op,
"No scale available for the operator"); "No scale available for the operator");
return; return;
} }
...@@ -1123,9 +1126,7 @@ void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const { ...@@ -1123,9 +1126,7 @@ void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_nearest_interp_count); AddStatis(quantize_nearest_interp_count);
LogQuantizedOpsCounter("nearest_interp", quantize_nearest_interp_count);
PrettyLogDetail("--- quantized %d nearest_interp ops",
quantize_nearest_interp_count);
} }
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
......
...@@ -434,9 +434,17 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -434,9 +434,17 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
platform::errors::NotFound("Operator after quantize operator(%s) " platform::errors::NotFound("Operator after quantize operator(%s) "
"should has quantize output as input.", "should has quantize output as input.",
quant_out->Name())); quant_out->Name()));
last_op->Op()->SetInput(
last_op_input_name, // update the next operator input,
std::vector<std::string>({first_quant_out->Name()})); // by replacing quant_out with first_quant_out
auto last_op_names = last_op->Op()->Input(last_op_input_name);
last_op_names.erase(std::remove(last_op_names.begin(),
last_op_names.end(), quant_out->Name()),
last_op_names.end());
last_op_names.push_back(first_quant_out->Name());
last_op->Op()->SetInput(last_op_input_name,
std::vector<std::string>(last_op_names));
IR_NODE_LINK_TO(first_quant_out, last_op); IR_NODE_LINK_TO(first_quant_out, last_op);
GraphSafeRemoveNodes(graph, {quant_op, quant_out}); GraphSafeRemoveNodes(graph, {quant_op, quant_out});
removed_quantize++; removed_quantize++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册