未验证 提交 c797e64d 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add pool avg to quantization and concat scales correction (#44186)

上级 015532b4
...@@ -390,6 +390,13 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( ...@@ -390,6 +390,13 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
} else if (out_iter != var_quant_scales->end()) { } else if (out_iter != var_quant_scales->end()) {
(*var_quant_scales)[input_name] = out_iter->second; (*var_quant_scales)[input_name] = out_iter->second;
} }
} else if (op_name == "concat") {
auto out_iter = var_quant_scales->find(op_node->Op()->Output("Out")[0]);
if (out_iter != var_quant_scales->end()) {
std::vector<std::string> input_names = op_node->Op()->Input("X");
for (auto input_name : input_names)
(*var_quant_scales)[input_name] = out_iter->second;
}
} else if (op_name == "scale") { } else if (op_name == "scale") {
const std::string output_name = op_node->Op()->Output("Out")[0]; const std::string output_name = op_node->Op()->Output("Out")[0];
auto out_iter = var_quant_scales->find(output_name); auto out_iter = var_quant_scales->find(output_name);
......
...@@ -55,23 +55,6 @@ void QuantDequantMkldnnPass::MarkSkipQuantizedOps( ...@@ -55,23 +55,6 @@ void QuantDequantMkldnnPass::MarkSkipQuantizedOps(
} }
} }
void QuantDequantMkldnnPass::MarkSkipQuantizedPool2d(ir::Graph* graph) const {
VLOG(3) << "mark avg pool2d as skip quantized op";
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
if (op_node->Name() == "pool2d") {
auto* op_desc = op_node->Op();
auto pool_type =
BOOST_GET_CONST(std::string, op_desc->GetAttr("pooling_type"));
if (pool_type == "avg") {
op_node->Op()->SetAttr("skip_quant", 1);
}
}
}
}
void QuantDequantMkldnnPass::CollectInfoFromFake( void QuantDequantMkldnnPass::CollectInfoFromFake(
ir::Graph* graph, ir::Graph* graph,
Scope* scope, Scope* scope,
...@@ -548,7 +531,6 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const { ...@@ -548,7 +531,6 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
auto* scope = param_scope(); auto* scope = param_scope();
MarkSkipQuantizedOps(graph, skip_ops); MarkSkipQuantizedOps(graph, skip_ops);
MarkSkipQuantizedPool2d(graph);
CollectInfoFromFake(graph, scope, fake_dequantize_types, &weight_thresholds); CollectInfoFromFake(graph, scope, fake_dequantize_types, &weight_thresholds);
CollectInputScalesFromFake( CollectInputScalesFromFake(
graph, scope, fake_quantize_types, &var_quant_scales); graph, scope, fake_quantize_types, &var_quant_scales);
......
...@@ -264,6 +264,14 @@ class Quant2Int8MkldnnPass(object): ...@@ -264,6 +264,14 @@ class Quant2Int8MkldnnPass(object):
elif output_name in self._var_quant_scales: elif output_name in self._var_quant_scales:
self._var_quant_scales[ self._var_quant_scales[
input_name] = self._var_quant_scales[output_name] input_name] = self._var_quant_scales[output_name]
elif op.name() == 'concat':
output_name = op.output("Out")[0]
if output_name in self._var_quant_scales:
input_names = op.input("X")
for input_name in input_names:
self._var_quant_scales[
input_name] = self._var_quant_scales[
output_name]
elif op.name() in self._scale_ops: elif op.name() in self._scale_ops:
input_name = op.input("X")[0] input_name = op.input("X")[0]
output_name = op.output("Out")[0] output_name = op.output("Out")[0]
...@@ -595,13 +603,6 @@ class Quant2Int8MkldnnPass(object): ...@@ -595,13 +603,6 @@ class Quant2Int8MkldnnPass(object):
_compute_lstm_weight_scales("WeightX", "WeightH") _compute_lstm_weight_scales("WeightX", "WeightH")
return graph return graph
def _find_avg_pooling_ids(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
if op.op().attr("pooling_type") == "avg":
self._op_ids_to_skip.add(op.id())
return self._op_ids_to_skip
def _update_relu_output_scales(self, graph): def _update_relu_output_scales(self, graph):
def _set_unsigned_scale(graph, ops, op_out_name, predicate): def _set_unsigned_scale(graph, ops, op_out_name, predicate):
...@@ -651,11 +652,9 @@ class Quant2Int8MkldnnPass(object): ...@@ -651,11 +652,9 @@ class Quant2Int8MkldnnPass(object):
'reshape_transpose_matmul_mkldnn_fuse_pass') 'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass( graph = self._apply_pass(
graph, 'reshape_transpose_matmul_v2_mkldnn_fuse_pass') graph, 'reshape_transpose_matmul_v2_mkldnn_fuse_pass')
graph = self._apply_pass( graph = self._apply_pass(graph, 'cpu_quantize_placement_pass',
graph, 'cpu_quantize_placement_pass', ['quantize_enabled_op_types'],
['quantize_enabled_op_types', 'quantize_excluded_op_ids'], [self._ops_to_quantize])
[self._ops_to_quantize,
self._find_avg_pooling_ids(graph)])
graph = self._apply_pass( graph = self._apply_pass(
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'], graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, [self._var_quant_scales,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册