Unverified commit db052009, authored by Wojciech Uss, committed by GitHub

Enabled quantize all and skip missing in QAT (#24281)

* Enabled quantize all and skip missing in QAT
Parent 5f65d9d5
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
 #include <limits>
+#include <sstream>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
@@ -169,13 +170,31 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
   if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
 }
 
+bool CPUQuantizePass::AreScalesPresentForNodes(
+    const Node* op_node, std::initializer_list<Node*> nodes) const {
+  auto& scales = Get<VarQuantScale>("quant_var_scales");
+  bool present = true;
+  for (auto node : nodes) {
+    if (scales.count(node->Name()) == 0) {
+      present = false;
+      std::stringstream msg_ss;
+      msg_ss << "Quantization scale for the variable " << node->Name()
+             << " is missing.";
+      PrettyLogDetail(msg_ss.str().c_str());
+    }
+  }
+  if (!present) {
+    std::stringstream msg_ss;
+    msg_ss << "Cannot quantize operator " << op_node->Name()
+           << " (type: " << op_node->Op()->Type() << ").";
+    PrettyLogDetail(msg_ss.str().c_str());
+  }
+  return present;
+}
+
 std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataForNode(
     const Node* node) const {
   auto& scales = Get<VarQuantScale>("quant_var_scales");
-  PADDLE_ENFORCE_EQ(
-      scales.count(node->Name()), 1,
-      platform::errors::InvalidArgument(
-          "Quantization scale for the variable %s is missing.", node->Name()));
   return scales[node->Name()];
 }
@@ -221,6 +240,25 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
 
+    if (with_residual_data) {
+      GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
+                                conv_pattern);
+      if (!AreScalesPresentForNodes(conv_op, {conv_input, conv_filter,
+                                              conv_residual_data, conv_output}))
+        return;
+
+      bool is_residual_unsigned{false};
+      auto residual_scale =
+          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);
+
+      QuantizeInput(g, conv_op, conv_residual_data, "ResidualData",
+                    residual_scale, is_residual_unsigned, "Scale_in_eltwise");
+    } else {
+      if (!AreScalesPresentForNodes(conv_op,
+                                    {conv_input, conv_filter, conv_output}))
+        return;
+    }
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
     QuantizeInput(g, conv_op, conv_input, "Input", input_scale,
@@ -236,17 +274,6 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
     conv_op->Op()->SetAttr("Scale_weights", filter_scale);
 
-    if (with_residual_data) {
-      GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
-                                conv_pattern);
-      bool is_residual_unsigned{false};
-      auto residual_scale =
-          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);
-      QuantizeInput(g, conv_op, conv_residual_data, "ResidualData",
-                    residual_scale, is_residual_unsigned, "Scale_in_eltwise");
-    }
-
     bool is_output_unsigned{false};
     auto output_scale = GetScaleValueForNode(conv_output, &is_output_unsigned);
     DequantizeOutput(g, conv_op, conv_output, "Output", output_scale,
@@ -298,6 +325,8 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
 
+    if (!AreScalesPresentForNodes(fc, {input, weights, output})) return;
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
     QuantizeInput(g, fc, input, "Input", input_scale, is_input_unsigned,
@@ -348,6 +377,8 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);
 
+    if (!AreScalesPresentForNodes(pool_op, {pool_input, pool_output})) return;
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(pool_input, &is_input_unsigned);
     QuantizeInput(g, pool_op, pool_input, "X", input_scale, is_input_unsigned);
@@ -384,6 +415,8 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
 
+    if (!AreScalesPresentForNodes(concat_op, {concat_out})) return;
+
     // if all inputs were unsigned, then the output was set to unsigned
     // during the scale calculation step
     bool are_all_inputs_unsigned{false};
@@ -423,6 +456,8 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(prior_box_input, prior_box_input,
                               prior_box_pattern);
 
+    if (!AreScalesPresentForNodes(prior_box_op, {prior_box_input})) return;
+
     bool is_input_unsigned{false};
     auto input_scale =
         GetScaleValueForNode(prior_box_input, &is_input_unsigned);
@@ -466,6 +501,9 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern);
 
+    if (!AreScalesPresentForNodes(transpose_op, {transpose_in, transpose_out}))
+      return;
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(transpose_in, &is_input_unsigned);
     QuantizeInput(g, transpose_op, transpose_in, "X", input_scale,
@@ -515,6 +553,9 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, reshape_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, reshape_pattern);
 
+    if (!AreScalesPresentForNodes(reshape_op, {reshape_in, reshape_out}))
+      return;
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(reshape_in, &is_input_unsigned);
     QuantizeInput(g, reshape_op, reshape_in, "X", input_scale,
@@ -562,6 +603,10 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
 
+    if (!AreScalesPresentForNodes(matmul_op,
+                                  {matmul_in_x, matmul_in_y, matmul_out}))
+      return;
+
     bool is_x_unsigned{false}, is_y_unsigned{false};
     auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
     auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
......
@@ -74,6 +74,8 @@ class CPUQuantizePass : public FusePassBase {
                      bool is_unsigned,
                      std::string scale_attr_name = "") const;
 
+  bool AreScalesPresentForNodes(const Node* op_node,
+                                std::initializer_list<Node*> nodes) const;
+
   std::pair<bool, LoDTensor> GetScaleDataForNode(const Node* node) const;
   LoDTensor GetScaleTensorForNode(const Node* node) const;
   double GetScaleValueForNode(const Node* node,
......
@@ -486,41 +486,6 @@ TEST(CpuQuantizePass, reshapeBetweenNonQuantizedOp) {
                 added_nodes_count, 2.0f * 127);
 }
 
-void MainTestCheckScales(
-    const ProgramDesc& prog,
-    const std::initializer_list<std::string> variable_names,
-    const std::string& var_without_scale) {
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
-  std::stringstream error_msg_ss;
-  error_msg_ss << "Quantization scale for the variable " << var_without_scale
-               << " is missing.";
-  bool caught_exception = false;
-  try {
-    int original_nodes_num, current_nodes_num;
-    PreparePass(&graph, prog, variable_names, &original_nodes_num,
-                &current_nodes_num, var_without_scale);
-  } catch (paddle::platform::EnforceNotMet& error) {
-    caught_exception = true;
-    std::string ex_msg = error.what();
-    EXPECT_NE(ex_msg.find(error_msg_ss.str()), std::string::npos);
-  }
-  EXPECT_TRUE(caught_exception);
-}
-
-// (a, w)->Conv->o
-ProgramDesc BuildProgramDescCheckScalesConv() {
-  ProgramDesc prog;
-  SetOp(&prog, "conv2d", "Conv", {"a", "w"}, {"o"}, true, true);
-  return prog;
-}
-
-// Check if an exception with a proper message is thrown when quantization scale
-// is missing for a variable
-TEST(CPUQuantizePass, check_scales) {
-  const std::initializer_list<std::string> var_names = {"a", "w", "o"};
-  MainTestCheckScales(BuildProgramDescCheckScalesConv(), var_names, "a");
-}
-
 static const std::initializer_list<std::string> variable_names_matmul = {
     "a", "b", "c", "d", "e", "f"};
......
@@ -36,7 +36,7 @@ class Qat2Int8MkldnnPass(object):
     """
 
     def __init__(self,
-                 _quantized_ops,
+                 _ops_to_quantize,
                  _scope=None,
                  _place=None,
                  _core=None,
@@ -53,7 +53,7 @@ class Qat2Int8MkldnnPass(object):
         self._fake_dequantize_types = [
             'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
         ]
-        self._quantized_ops = _quantized_ops
+        self._ops_to_quantize = _ops_to_quantize
         self._scale_immutable_ops = [
            'transpose2', 'reshape2', 'pool2d', 'scale'
        ]
@@ -101,10 +101,11 @@ class Qat2Int8MkldnnPass(object):
         return tensor
 
     def _is_conv_quantized(self):
-        return any(op_type in self._quantized_ops for op_type in self._conv_ops)
+        return any(op_type in self._ops_to_quantize
+                   for op_type in self._conv_ops)
 
     def _is_fc_quantized(self):
-        return 'fc' in self._quantized_ops
+        return 'fc' in self._ops_to_quantize
 
     def _gather_input_scales_from_fake(self, graph):
         def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
@@ -238,27 +239,13 @@ class Qat2Int8MkldnnPass(object):
         return np.array(scope.find_var(param_name).get_tensor())
 
     def _remove_fake_ops(self, graph):
-        '''
-        When FC isn't quantized:
-            Remove fake (de)quantize ops that do not surround mul.
-        When FC is quantized:
-            Remove all fake (de)quantize ops.
-        '''
-        is_fc_quantized = self._is_fc_quantized()
         for op in graph.all_op_nodes():
             if op.name() in self._fake_quantize_types:
-                op_out = graph._find_node_by_name(op.outputs,
-                                                  op.output("Out")[0])
-                next_op = op_out.outputs[0]
-                if next_op.name() not in self._mul_ops or is_fc_quantized:
-                    self._remove_fake_quantize(graph, op)
+                self._remove_fake_quantize(graph, op)
 
         for op in graph.all_op_nodes():
             if op.name() in self._fake_dequantize_types:
-                op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
-                prev_op = op_in.inputs[0]
-                if prev_op.name() not in self._mul_ops or is_fc_quantized:
-                    self._remove_fake_dequantize(graph, op)
+                self._remove_fake_dequantize(graph, op)
 
         return graph
@@ -305,7 +292,7 @@ class Qat2Int8MkldnnPass(object):
         for op in graph.all_op_nodes():
             if op.name() in self._conv_ops:
                 self._dequantize_op_weights(graph, op, "Filter", "Output")
-            elif self._is_fc_quantized() and op.name() in self._mul_ops:
+            elif op.name() in self._mul_ops:
                 self._dequantize_op_weights(graph, op, "Y", "Out")
 
         return graph
@@ -357,19 +344,16 @@ class Qat2Int8MkldnnPass(object):
         graph = self._remove_ctrl_vars(graph)
         graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                  ['mkldnn_enabled_op_types'], [set()])
-        if self._is_conv_quantized():
-            graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
-            graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
-            graph = self._apply_pass(graph,
-                                     'conv_elementwise_add_mkldnn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
-        if self._is_fc_quantized():
-            graph = self._apply_pass(graph, 'fc_fuse_pass',
-                                     ['use_gpu', 'use_fc_padding'],
-                                     [False, False])
+        graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
+        graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'fc_fuse_pass',
+                                 ['use_gpu', 'use_fc_padding'], [False, False])
+        if self._is_fc_quantized:
             graph = self._apply_pass(graph, 'fc_mkldnn_pass')
         graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
         return graph
@@ -492,7 +476,7 @@ class Qat2Int8MkldnnPass(object):
     def _quantize_fp32_graph(self, graph):
         ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
         cpp_graph = graph.graph
-        ir_pass.set('quantize_enabled_op_types', self._quantized_ops)
+        ir_pass.set('quantize_enabled_op_types', self._ops_to_quantize)
         ir_pass.set('quantize_excluded_op_ids',
                     self._find_avg_pooling_ids(graph))
         ir_pass.apply(cpp_graph)
......
@@ -57,7 +57,7 @@ endfunction()
 
 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
-function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path quantized_ops)
+function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path ops_to_quantize)
   py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_image_classification_comparison.py"
           ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -68,11 +68,11 @@ function(inference_qat2_int8_image_classification_test target qat_model_dir fp32
                --batch_size 10
                --batch_num 2
                --acc_diff_threshold 0.1
-               --quantized_ops ${quantized_ops})
+               --ops_to_quantize ${ops_to_quantize})
 endfunction()
 
 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
-function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path quantized_ops)
+function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path)
   py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_nlp_comparison.py"
           ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
               OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -83,8 +83,7 @@ function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir datase
                --labels ${labels_path}
                --batch_size 10
                --batch_num 2
-               --acc_diff_threshold 0.1
-               --quantized_ops ${quantized_ops})
+               --acc_diff_threshold 0.1)
 endfunction()
 
 function(download_qat_data install_dir data_file)
@@ -99,12 +98,19 @@ function(download_qat_model install_dir data_file)
   endif()
 endfunction()
 
-function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops)
+function(save_qat_ic_model_test target qat_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize)
   py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
           ARGS --qat_model_path ${qat_model_dir}
                --fp32_model_save_path ${fp32_model_save_path}
                --int8_model_save_path ${int8_model_save_path}
-               --quantized_ops ${quantized_ops})
+               --ops_to_quantize ${ops_to_quantize})
+endfunction()
+
+function(save_qat_nlp_model_test target qat_model_dir fp32_model_save_path int8_model_save_path)
+  py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
+          ARGS --qat_model_path ${qat_model_dir}
+               --fp32_model_save_path ${fp32_model_save_path}
+               --int8_model_save_path ${int8_model_save_path})
 endfunction()
 
 if(WIN32)
@@ -213,7 +219,7 @@ if(LINUX AND WITH_MKLDNN)
 
   ### QATv2 for image classification
 
-  set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")
+  set(QAT2_IC_OPS_TO_QUANTIZE "conv2d,pool2d")
 
   # QAT2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
   # with weight scales in `fake_dequantize_max_abs` operators
@@ -221,33 +227,31 @@ if(LINUX AND WITH_MKLDNN)
   set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
   set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
   download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
-  inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+  inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
   # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
   # with weight scales in `fake_dequantize_max_abs` operators
   set(QAT2_RESNET50_RANGE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_range")
   set(QAT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
   download_qat_model(${QAT2_RESNET50_RANGE_MODEL_DIR} ${QAT2_RESNET50_RANGE_MODEL_ARCHIVE})
-  inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+  inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
   # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
   # with weight scales in `fake_channel_wise_dequantize_max_abs` operators
   set(QAT2_RESNET50_CHANNELWISE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_channelwise")
   set(QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
   download_qat_model(${QAT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
-  inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+  inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
   # QAT2 MobileNetV1
   set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
   set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
   set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
   download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
-  inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+  inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
   ### QATv2 for NLP
 
-  set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2,matmul")
-
   set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
   set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
   set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
@@ -261,17 +265,17 @@ if(LINUX AND WITH_MKLDNN)
   set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
   set(FP32_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_float")
   download_qat_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
-  inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QAT2_NLP_QUANTIZED_OPS})
+  inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH})
 
   ### Save QAT2 FP32 model or QAT2 INT8 model
 
   set(QAT2_INT8_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_int8")
   set(QAT2_FP32_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_fp32")
-  save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_QUANTIZED_OPS})
+  save_qat_ic_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
   set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
   set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
-  save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS})
+  save_qat_nlp_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH})
 
 endif()
......
@@ -270,15 +270,17 @@ You can use the `qat2_int8_image_classification_comparison.py` script to reproduce
 
 * `--qat_model` - a path to a QAT model that will be transformed into INT8 model.
 * `--fp32_model` - a path to an FP32 model whose accuracy will be measured and compared to the accuracy of the INT8 model.
-* `--quantized_ops` - a comma-separated list of names of operators to be quantized. When deciding which operators to put on the list, the following have to be considered:
-  * Only operators which support quantization will be taken into account.
-  * All the quantizable operators from the list, which are present in the model, must have quantization scales provided in the model. Otherwise, the quantization procedure will fail with a message saying which variable is missing a quantization scale.
-  * Sometimes it may be suboptimal to quantize all quantizable operators in the model (cf. *Notes* in the **Gathering scales** section above). To find the optimal configuration for this option, user can run benchmark a few times with different lists of quantized operators present in the model and compare the results. For Image Classification models mentioned above the list comprises of `conv2d` and `pool2d` operators.
 * `--infer_data` - a path to the validation dataset.
 
+The following option is also accepted:
+* `--ops_to_quantize` - a comma-separated list of operator types to quantize. If the option is not used, an attempt to quantize all quantizable operators will be made, and in that case only quantizable operators which have quantization scales provided in the QAT model will be quantized. When deciding which operators to put on the list, the following have to be considered:
+  * Only operators which support quantization will be taken into account.
+  * All the quantizable operators from the list, which are present in the model, must have quantization scales provided in the model. Otherwise, quantization of the operator will be skipped with a message saying which variable is missing a quantization scale.
+  * Sometimes it may be suboptimal to quantize all quantizable operators in the model (cf. *Notes* in the **Gathering scales** section above). To find the optimal configuration for this option, user can run benchmark a few times with different lists of quantized operators present in the model and compare the results. For Image Classification models mentioned above the list usually comprises of `conv2d` and `pool2d` operators.
+
 ```bash
 cd /PATH/TO/PADDLE
-OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py --qat_model=/PATH/TO/DOWNLOADED/QAT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --quantized_ops="conv2d,pool2d"
+OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py --qat_model=/PATH/TO/DOWNLOADED/QAT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --ops_to_quantize="conv2d,pool2d"
 ```
 
 > Notes: Due to a large amount of images in the `int8_full_val.bin` dataset (50 000), the accuracy benchmark may last long. To accelerate accuracy measuring, it is recommended to set `OMP_NUM_THREADS` to the maximum number of physical cores available on the server.
@@ -287,11 +289,11 @@ OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim
 
 To reproduce the performance results, the environment variable `OMP_NUM_THREADS=1` and `--batch_size=1` option should be set.
 
-1. Transform the QAT model into INT8 model by applying the `Qat2Int8MkldnnPass` pass and save the result. You can use the script `save_qat_model.py` for this purpose. It also requires the option `--quantized_ops` with a list of operators to be quantized.
+1. Transform the QAT model into INT8 model by applying the `Qat2Int8MkldnnPass` pass and save the result. You can use the script `save_qat_model.py` for this purpose. It also accepts the option `--ops_to_quantize` with a list of operators to quantize.
 
 ```bash
 cd /PATH/TO/PADDLE/build
-python ../python/paddle/fluid/contrib/slim/tests/save_qat_model.py --qat_model_path=/PATH/TO/DOWNLOADED/QAT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QAT/INT8/MODEL --quantized_ops="conv2d,pool2d"
+python ../python/paddle/fluid/contrib/slim/tests/save_qat_model.py --qat_model_path=/PATH/TO/DOWNLOADED/QAT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QAT/INT8/MODEL --ops_to_quantize="conv2d,pool2d"
 ```
 
 2. Run the C-API test for performance benchmark.
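For reference, a minimal Python sketch of the transformation that `save_qat_model.py` performs is shown below. The import location of `Qat2Int8MkldnnPass` and its `apply()` entry point are assumptions inferred from this PR's test scripts, not shown verbatim in the diff above; the model paths are placeholders.

```python
# Minimal sketch, assuming the import path and apply() entry point of the pass.
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()

# Load a QAT model previously saved with fluid.io.save_inference_model.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    '/PATH/TO/DOWNLOADED/QAT/MODEL', exe, 'model', 'params')

# Explicit list of operator types to quantize; with an empty set the pass
# attempts to quantize every quantizable operator and skips operators whose
# variables have no quantization scales (the behaviour added by this PR).
ops_to_quantize = {'conv2d', 'pool2d'}

mkldnn_int8_pass = Qat2Int8MkldnnPass(
    ops_to_quantize, _scope=scope, _place=place, _core=core)
graph = IrGraph(core.Graph(program.desc), for_test=True)
graph = mkldnn_int8_pass.apply(graph)  # returns the optimized INT8 graph
```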
......
@@ -62,10 +62,11 @@ def parse_args():
         default=0.01,
         help='Accepted accuracy difference threshold.')
     parser.add_argument(
-        '--quantized_ops',
+        '--ops_to_quantize',
         type=str,
         default='',
-        help='A comma separated list of quantized operators.')
+        help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
+    )
     test_args, args = parser.parse_known_args(namespace=unittest)
     return test_args, sys.argv[:1] + args
@@ -305,7 +306,9 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
         skip_batch_num = test_case_args.skip_batch_num
         acc_diff_threshold = test_case_args.acc_diff_threshold
         self._debug = test_case_args.debug
-        self._quantized_ops = set(test_case_args.quantized_ops.split(','))
+        self._quantized_ops = set()
+        if len(test_case_args.ops_to_quantize) > 0:
+            self._quantized_ops = set(test_case_args.ops_to_quantize.split(','))
 
         _logger.info('FP32 & QAT INT8 prediction run.')
         _logger.info('QAT model: {0}'.format(qat_model_path))
......
@@ -68,10 +68,11 @@ def parse_args():
         default=0.01,
         help='Accepted accuracy difference threshold.')
     parser.add_argument(
-        '--quantized_ops',
+        '--ops_to_quantize',
         type=str,
         default='',
-        help='A comma separated list of quantized operators.')
+        help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
+    )
     test_args, args = parser.parse_known_args(namespace=unittest)
@@ -252,7 +253,9 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
         skip_batch_num = test_case_args.skip_batch_num
         acc_diff_threshold = test_case_args.acc_diff_threshold
         self._debug = test_case_args.debug
-        self._quantized_ops = set(test_case_args.quantized_ops.split(','))
+        self._quantized_ops = set()
+        if len(test_case_args.ops_to_quantize) > 0:
+            self._quantized_ops = set(test_case_args.ops_to_quantize.split(','))
 
         _logger.info('FP32 & QAT INT8 prediction run.')
         _logger.info('QAT model: {0}'.format(qat_model_path))
......
@@ -43,10 +43,11 @@ def parse_args():
         default='',
         help='Saved optimized and quantized INT8 model')
     parser.add_argument(
-        '--quantized_ops',
+        '--ops_to_quantize',
         type=str,
         default='',
-        help='A comma separated list of quantized operators.')
+        help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
+    )
     test_args, args = parser.parse_known_args(namespace=unittest)
     return test_args, sys.argv[:1] + args
@@ -65,9 +66,12 @@ def transform_and_save_model(original_path, save_path, save_type):
      fetch_targets] = fluid.io.load_inference_model(original_path, exe,
                                                     'model', 'params')
 
-    quantized_ops = set(test_args.quantized_ops.split(','))
+    ops_to_quantize = set()
+    if len(test_args.ops_to_quantize) > 0:
+        ops_to_quantize = set(test_args.ops_to_quantize.split(','))
+
     transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
-        quantized_ops, _scope=inference_scope, _place=place, _core=core)
+        ops_to_quantize, _scope=inference_scope, _place=place, _core=core)
     graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
     if save_type == 'FP32':
......