未验证 提交 78d4f0cc 编写于 作者: W Wojciech Uss 提交者: GitHub

add option to exclude ops by id from quantization (#24689)

上级 62b4ff7d
......@@ -25,6 +25,9 @@ namespace paddle {
namespace framework {
namespace ir {
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
using string::PrettyLogDetail;
namespace {
void UnlinkNodes(ir::Node* a, ir::Node* b) {
......@@ -34,13 +37,24 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) {
b->inputs.end());
}
void LogCannotQuantizeOp(Node* op) {
std::stringstream msg_ss;
msg_ss << "Cannot quantize operator " << op->Name()
<< " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
PrettyLogDetail(msg_ss.str().c_str());
}
void LogScaleIsMissingForVar(Node* var) {
std::stringstream msg_ss;
msg_ss << "Quantization scale for the variable " << var->Name()
<< " is missing.";
PrettyLogDetail(msg_ss.str().c_str());
}
} // namespace
enum { U8_MAX = 255, S8_MAX = 127 };
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
using string::PrettyLogDetail;
void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
std::string input_name, double scale_to_one,
bool is_unsigned,
......@@ -177,17 +191,8 @@ bool CPUQuantizePass::AreScalesPresentForNodes(
for (auto node : nodes) {
if (scales.count(node->Name()) == 0) {
present = false;
std::stringstream msg_ss;
msg_ss << "Quantization scale for the variable " << node->Name()
<< " is missing.";
PrettyLogDetail(msg_ss.str().c_str());
}
LogScaleIsMissingForVar(node);
}
if (!present) {
std::stringstream msg_ss;
msg_ss << "Cannot quantize operator " << op_node->Name()
<< " (type: " << op_node->Op()->Type() << ").";
PrettyLogDetail(msg_ss.str().c_str());
}
return present;
}
......@@ -243,9 +248,11 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
if (with_residual_data) {
GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
conv_pattern);
if (!AreScalesPresentForNodes(conv_op, {conv_input, conv_filter,
conv_residual_data, conv_output}))
if (!AreScalesPresentForNodes(
conv_op, {conv_input, conv_filter, conv_residual_data})) {
LogCannotQuantizeOp(conv_op);
return;
}
bool is_residual_unsigned{false};
auto residual_scale =
......@@ -254,10 +261,11 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
QuantizeInput(g, conv_op, conv_residual_data, "ResidualData",
residual_scale, is_residual_unsigned, "Scale_in_eltwise");
} else {
if (!AreScalesPresentForNodes(conv_op,
{conv_input, conv_filter, conv_output}))
if (!AreScalesPresentForNodes(conv_op, {conv_input, conv_filter})) {
LogCannotQuantizeOp(conv_op);
return;
}
}
bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
......@@ -274,10 +282,16 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
conv_op->Op()->SetAttr("Scale_weights", filter_scale);
// if quantization scale is missing for output tensor, return fp32 data
if (AreScalesPresentForNodes(conv_op, {conv_output})) {
bool is_output_unsigned{false};
auto output_scale = GetScaleValueForNode(conv_output, &is_output_unsigned);
auto output_scale =
GetScaleValueForNode(conv_output, &is_output_unsigned);
DequantizeOutput(g, conv_op, conv_output, "Output", output_scale,
is_output_unsigned, "Scale_out");
} else {
conv_op->Op()->SetAttr("force_fp32_output", true);
}
// change threshold in bounded ReLu
if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
......@@ -327,7 +341,10 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
if (!AreScalesPresentForNodes(fc, {input, weights, output})) return;
if (!AreScalesPresentForNodes(fc, {input, weights})) {
LogCannotQuantizeOp(fc);
return;
}
bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
......@@ -344,10 +361,15 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
fc->Op()->SetAttr("Scale_weights", filter_scale);
// if quantization scale is missing for output tensor, return fp32 data
if (AreScalesPresentForNodes(fc, {output})) {
bool is_output_unsigned{false};
auto output_scale = GetScaleValueForNode(output, &is_output_unsigned);
DequantizeOutput(g, fc, output, "Out", output_scale, is_output_unsigned,
"Scale_out");
} else {
fc->Op()->SetAttr("force_fp32_output", true);
}
++quantize_fc_count;
};
......@@ -379,7 +401,10 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);
if (!AreScalesPresentForNodes(pool_op, {pool_input, pool_output})) return;
if (!AreScalesPresentForNodes(pool_op, {pool_input, pool_output})) {
LogCannotQuantizeOp(pool_op);
return;
}
bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(pool_input, &is_input_unsigned);
......@@ -417,7 +442,10 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
if (!AreScalesPresentForNodes(concat_op, {concat_out})) return;
if (!AreScalesPresentForNodes(concat_op, {concat_out})) {
LogCannotQuantizeOp(concat_op);
return;
}
// if all inputs were unsigned, then the output was set to unsigned
// during the scale calculation step
......@@ -458,7 +486,10 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(prior_box_input, prior_box_input,
prior_box_pattern);
if (!AreScalesPresentForNodes(prior_box_op, {prior_box_input})) return;
if (!AreScalesPresentForNodes(prior_box_op, {prior_box_input})) {
LogCannotQuantizeOp(prior_box_op);
return;
}
bool is_input_unsigned{false};
auto input_scale =
......@@ -503,8 +534,11 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern);
if (!AreScalesPresentForNodes(transpose_op, {transpose_in, transpose_out}))
if (!AreScalesPresentForNodes(transpose_op,
{transpose_in, transpose_out})) {
LogCannotQuantizeOp(transpose_op);
return;
}
bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(transpose_in, &is_input_unsigned);
......@@ -555,8 +589,10 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, reshape_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, reshape_pattern);
if (!AreScalesPresentForNodes(reshape_op, {reshape_in, reshape_out}))
if (!AreScalesPresentForNodes(reshape_op, {reshape_in, reshape_out})) {
LogCannotQuantizeOp(reshape_op);
return;
}
bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(reshape_in, &is_input_unsigned);
......@@ -605,9 +641,10 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
if (!AreScalesPresentForNodes(matmul_op,
{matmul_in_x, matmul_in_y, matmul_out}))
if (!AreScalesPresentForNodes(matmul_op, {matmul_in_x, matmul_in_y})) {
LogCannotQuantizeOp(matmul_op);
return;
}
bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
......@@ -621,10 +658,15 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
QuantizeInput(g, matmul_op, matmul_in_y, "Y", input_y_scale, is_y_unsigned,
"Scale_y");
// if quantization scale is missing for output tensor, return fp32 data
if (AreScalesPresentForNodes(matmul_op, {matmul_out})) {
bool is_output_unsigned{false};
auto output_scale = GetScaleValueForNode(matmul_out, &is_output_unsigned);
DequantizeOutput(g, matmul_op, matmul_out, "Out", output_scale,
is_output_unsigned, "Scale_out");
} else {
matmul_op->Op()->SetAttr("force_fp32_output", true);
}
++quantize_matmul_count;
};
......
......@@ -37,6 +37,7 @@ class Qat2Int8MkldnnPass(object):
def __init__(self,
_ops_to_quantize,
_op_ids_to_skip=None,
_scope=None,
_place=None,
_core=None,
......@@ -54,6 +55,8 @@ class Qat2Int8MkldnnPass(object):
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
self._ops_to_quantize = _ops_to_quantize
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip != None else set(
[-1])
self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'scale'
]
......@@ -61,6 +64,7 @@ class Qat2Int8MkldnnPass(object):
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
self._fc_ops = ['fc']
self._relu_ops = ['relu', 'relu6']
self._matmul_ops = ['matmul']
self._weight_scales = {}
# Collect the Input and Output sclaes from Fake QAT models
......@@ -81,7 +85,6 @@ class Qat2Int8MkldnnPass(object):
graph = self._compute_weight_scales(graph)
graph = self._update_relu_output_scales(graph)
graph = self._propagate_scales(graph)
graph = self._set_dummy_out_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._optimize_int8_graph(graph)
graph = self._cleanup(graph)
......@@ -91,6 +94,9 @@ class Qat2Int8MkldnnPass(object):
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
graph = self._gather_weight_scales_from_fake(graph)
graph = self._remove_fake_ops(graph)
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._cleanup(graph)
return graph
......@@ -100,12 +106,23 @@ class Qat2Int8MkldnnPass(object):
tensor.set(scale, core.CPUPlace())
return tensor
def _is_conv_quantized(self):
return any(op_type in self._ops_to_quantize
for op_type in self._conv_ops)
def _is_quantizing_all_ops(self):
return len(self._ops_to_quantize) == 0
def _is_any_of_op_types_in_graph(self, op_types, graph):
return any(op.name() in op_types for op in graph.all_op_nodes())
def _is_fc_quantized(self):
return 'fc' in self._ops_to_quantize
def _is_any_of_op_types_quantized(self, op_types, graph):
return self._is_any_of_op_types_in_graph(
op_types, graph) and (self._is_quantizing_all_ops() or
any(op_type in self._ops_to_quantize
for op_type in op_types))
def _is_conv_quantized(self, graph):
return self._is_any_of_op_types_quantized(self._conv_ops, graph)
def _is_fc_quantized(self, graph):
return self._is_any_of_op_types_quantized(self._fc_ops, graph)
def _gather_input_scales_from_fake(self, graph):
def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
......@@ -209,32 +226,6 @@ class Qat2Int8MkldnnPass(object):
return graph
def _set_dummy_out_scales(self, graph):
'''
For the output tensors of fc, conv2d and matmul ops that do not have an assigned scale,
assign a dummy scale (same scale as input), so that the quantize pass
won't fail. In the end these scales aren't used, since the ops that
have an unassigend output scale will have a force_fp32_output attr
set to True.
'''
def _set_scale(op, op_types, input_names, output_name):
scales = self._var_quant_scales
should_set = op.name() in op_types \
and op.output(output_name)[0] not in scales \
and all(op.input(input_name)[0] in scales for input_name in input_names)
if should_set:
output_var_name = op.output(output_name)[0]
input_var_name = op.input(input_names[0])[0]
scales[output_var_name] = scales[input_var_name]
for op in graph.all_op_nodes():
_set_scale(op, self._conv_ops, ["Input"], "Output")
_set_scale(op, self._fc_ops, ["Input"], "Out")
_set_scale(op, self._matmul_ops, ["X", "Y"], "Out")
return graph
def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
......@@ -353,7 +344,7 @@ class Qat2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass',
['use_gpu', 'use_fc_padding'], [False, False])
if self._is_fc_quantized:
if self._is_fc_quantized(graph):
graph = self._apply_pass(graph, 'fc_mkldnn_pass')
graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
return graph
......@@ -435,15 +426,14 @@ class Qat2Int8MkldnnPass(object):
return graph
def _find_avg_pooling_ids(self, graph):
ids = []
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
if op.op().attr("pooling_type") == "avg":
ids.append(op.id())
return set(ids) if len(ids) else set([-1])
self._op_ids_to_skip.add(op.id())
return self._op_ids_to_skip
def _update_relu_output_scales(self, graph):
def _update_scale(graph, ops, op_out_name, predicate):
def _set_unsigned_scale(graph, ops, op_out_name, predicate):
'''
Sets the type of an output scale of a passed op type(s) to 'unsigned int8' if the
predicate applied on op passes. Typically, the predicate checks if op's
......@@ -458,20 +448,21 @@ class Qat2Int8MkldnnPass(object):
self._var_quant_scales[out_name] = (True, tensor)
return graph
if self._is_conv_quantized():
conv_predicate = lambda op: op.attr("fuse_activation") == 'relu' and \
conv_predicate = lambda op: op.attr("fuse_activation") in self._relu_ops and \
op.attr("fuse_residual_connection") == False
graph = _update_scale(graph, self._conv_ops, "Output",
graph = _set_unsigned_scale(graph, self._conv_ops, "Output",
conv_predicate)
if self._is_fc_quantized():
fc_predicate = lambda op: op.attr("activation_type") == 'relu'
graph = _update_scale(graph, self._fc_ops, "Out", fc_predicate)
fc_predicate = lambda op: op.attr("activation_type") in self._relu_ops
graph = _set_unsigned_scale(graph, self._fc_ops, "Out", fc_predicate)
graph = _set_unsigned_scale(graph, self._relu_ops, 'Out',
lambda op: True)
return graph
def _get_data_layout(self):
return 'NHWC' if self._is_conv_quantized() else 'NCHW'
def _get_data_layout(self, graph):
return 'NHWC' if self._is_conv_quantized(graph) else 'NCHW'
def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
......@@ -488,6 +479,6 @@ class Qat2Int8MkldnnPass(object):
'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, self._get_data_layout()])
[self._var_quant_scales, self._get_data_layout(graph)])
graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
return graph
......@@ -237,11 +237,13 @@ You can use the `qat2_int8_image_classification_comparison.py` script to reprodu
* `--fp32_model` - a path to an FP32 model whose accuracy will be measured and compared to the accuracy of the INT8 model.
* `--infer_data` - a path to the validation dataset.
The following option is also accepted:
The following options are also accepted:
* `--ops_to_quantize` - a comma-separated list of operator types to quantize. If the option is not used, an attempt to quantize all quantizable operators will be made, and in that case only quantizable operators which have quantization scales provided in the QAT model will be quantized. When deciding which operators to put on the list, the following have to be considered:
* Only operators which support quantization will be taken into account.
* All the quantizable operators from the list, which are present in the model, must have quantization scales provided in the model. Otherwise, quantization of the operator will be skipped with a message saying which variable is missing a quantization scale.
* Sometimes it may be suboptimal to quantize all quantizable operators in the model (cf. *Notes* in the **Gathering scales** section above). To find the optimal configuration for this option, user can run benchmark a few times with different lists of quantized operators present in the model and compare the results. For Image Classification models mentioned above the list usually comprises of `conv2d` and `pool2d` operators.
* `--op_ids_to_skip` - a comma-separated list of operator ids to skip in quantization. To get an id of a particular operator run the script with the `--debug` option first (see below for the description of the option), and having opened the generated file `qat_int8_cpu_quantize_placement_pass.dot` find the id number written in parentheses next to the name of the operator.
* `--debug` - add this option to generate a series of `*.dot` files containing the model graphs after each step of the transformation. For a description of the DOT format see [DOT]( https://graphviz.gitlab.io/_pages/doc/info/lang.html). The files will be saved in the current location. To open the `*.dot` files use any of the Graphviz tools available on your system (e.g. `xdot` tool on Linux or `dot` tool on Windows, for documentation see [Graphviz](http://www.graphviz.org/documentation/)).
```bash
cd /PATH/TO/PADDLE
......
......@@ -41,10 +41,6 @@ def parse_args():
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
......@@ -67,6 +63,15 @@ def parse_args():
default='',
help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
)
parser.add_argument(
'--op_ids_to_skip',
type=str,
default='',
help='A comma separated list of operator ids to skip in quantization.')
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
......@@ -181,6 +186,7 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
if (transform_to_int8):
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_op_ids_to_skip=self._op_ids_to_skip,
_scope=inference_scope,
_place=place,
_core=core,
......@@ -306,18 +312,29 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set()
if len(test_case_args.ops_to_quantize) > 0:
self._quantized_ops = set(test_case_args.ops_to_quantize.split(','))
self._quantized_ops = set(
op.strip() for op in test_case_args.ops_to_quantize.split(','))
self._op_ids_to_skip = set([-1])
if len(test_case_args.op_ids_to_skip) > 0:
self._op_ids_to_skip = set(
map(int, test_case_args.op_ids_to_skip.split(',')))
_logger.info('FP32 & QAT INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('FP32 model: {0}'.format(fp32_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('QAT model: {}'.format(qat_model_path))
_logger.info('FP32 model: {}'.format(fp32_model_path))
_logger.info('Dataset: {}'.format(data_path))
_logger.info('Batch size: {}'.format(batch_size))
_logger.info('Batch number: {}'.format(batch_num))
_logger.info('Accuracy drop threshold: {}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {}.'.format(','.join(
self._quantized_ops) if self._quantized_ops else 'all quantizable'))
_logger.info('Op ids to skip quantization: {}.'.format(','.join(
map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
else 'none'))
_logger.info('--- FP32 prediction start ---')
val_reader = paddle.batch(
......
......@@ -41,10 +41,6 @@ def parse_args():
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
......@@ -73,6 +69,15 @@ def parse_args():
default='',
help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
)
parser.add_argument(
'--op_ids_to_skip',
type=str,
default='',
help='A comma separated list of operator ids to skip in quantization.')
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
test_args, args = parser.parse_known_args(namespace=unittest)
......@@ -157,6 +162,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
if (transform_to_int8):
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_op_ids_to_skip=self._op_ids_to_skip,
_scope=inference_scope,
_place=place,
_core=core,
......@@ -253,19 +259,30 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set()
if len(test_case_args.ops_to_quantize) > 0:
self._quantized_ops = set(test_case_args.ops_to_quantize.split(','))
if test_case_args.ops_to_quantize:
self._quantized_ops = set(
op.strip() for op in test_case_args.ops_to_quantize.split(','))
self._op_ids_to_skip = set([-1])
if test_case_args.op_ids_to_skip:
self._op_ids_to_skip = set(
map(int, test_case_args.op_ids_to_skip.split(',')))
_logger.info('FP32 & QAT INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('FP32 model: {0}'.format(fp32_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Labels: {0}'.format(labels_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('QAT model: {}'.format(qat_model_path))
_logger.info('FP32 model: {}'.format(fp32_model_path))
_logger.info('Dataset: {}'.format(data_path))
_logger.info('Labels: {}'.format(labels_path))
_logger.info('Batch size: {}'.format(batch_size))
_logger.info('Batch number: {}'.format(batch_num))
_logger.info('Accuracy drop threshold: {}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {}.'.format(','.join(
self._quantized_ops) if self._quantized_ops else 'all quantizable'))
_logger.info('Op ids to skip quantization: {}.'.format(','.join(
map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
else 'none'))
_logger.info('--- FP32 prediction start ---')
val_reader = paddle.batch(
......
......@@ -48,6 +48,15 @@ def parse_args():
default='',
help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
)
parser.add_argument(
'--op_ids_to_skip',
type=str,
default='',
help='A comma separated list of operator ids to skip in quantization.')
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
......@@ -70,8 +79,20 @@ def transform_and_save_model(original_path, save_path, save_type):
if len(test_args.ops_to_quantize) > 0:
ops_to_quantize = set(test_args.ops_to_quantize.split(','))
op_ids_to_skip = set([-1])
if len(test_args.op_ids_to_skip) > 0:
op_ids_to_skip = set(map(int, test_args.op_ids_to_skip.split(',')))
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (test_args.debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
ops_to_quantize, _scope=inference_scope, _place=place, _core=core)
ops_to_quantize,
_op_ids_to_skip=op_ids_to_skip,
_scope=inference_scope,
_place=place,
_core=core,
_debug=test_args.debug)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if save_type == 'FP32':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册