Commit 36b60e24 authored by Wojciech Uss, committed by Tao Luo

Enable MKL-DNN for slim FP32 vs. INT8 tests (#18214)

* Enable MKL-DNN for slim FP32 vs. INT8 tests

test=develop

* added test for MobileNetV1 with MKL-DNN

test=develop
Parent 976cf460
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
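# Shared helper (wrapped below for the FP32 and MKL-DNN variants): registers a
# Python inference test and forwards use_mkldnn into the test environment as
# FLAGS_use_mkldnn.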
function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn)
    py_test(${target} SRCS ${filename}
            ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                 FLAGS_use_mkldnn=${use_mkldnn}
            ARGS --infer_model ${model_dir}/model
                 --infer_data ${data_dir}/data.bin
                 --int8_model_save_path int8_models/${target}
@@ -11,6 +12,14 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
                 --batch_size 50)
endfunction()

function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False)
endfunction()

function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} True)
endfunction()

function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
    py_test(${target} SRCS ${test_script}
            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -44,6 +53,7 @@ if(LINUX AND WITH_MKLDNN)
  # mobilenet int8
  set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
  inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
  inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
  # temporarily adding the WITH_SLIM_MKLDNN_FULL_TEST flag for QA to test the following UTs locally,
  # since the following UTs cost too much time in CI.
......
@@ -128,3 +128,4 @@ python ./test_mkldnn_int8_quantization_strategy.py --infer_model /PATH/TO/DOWNLO
Notes:
* The above commands may take several hours in the prediction stage (both the INT8 and the FP32 prediction), since 50,000 images in `int8_full_val.bin` need to be predicted
* Running the above command with the environment variable `FLAGS_use_mkldnn=true` makes the FP32 part of the test run with MKL-DNN (the INT8 part uses MKL-DNN either way)
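  For example, the FP32 pass can be run with MKL-DNN enabled like this (a sketch; `/PATH/TO/MODEL` is a placeholder for the downloaded model path used in the command above):
  `FLAGS_use_mkldnn=true python ./test_mkldnn_int8_quantization_strategy.py --infer_model /PATH/TO/MODEL`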
@@ -23,6 +23,8 @@ import six
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.log_helper import get_logger
@@ -112,6 +114,41 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
    return new_config_path
def _transform_depthwise_conv(self, graph):
    '''
    Transform depthwise_conv2d ops into conv2d ops (used with MKL-DNN only).
    '''
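    # Assumption (not stated in the diff): MKL-DNN executes depthwise
    # convolution through its regular conv2d kernel, so rewriting the op type
    # is sufficient; all attributes (e.g. 'groups') are carried over as-is.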
    ops = graph.all_op_nodes()
    for op_node in ops:
        if op_node.name() == 'depthwise_conv2d':
            input_var_node = graph._find_node_by_name(
                op_node.inputs, op_node.input("Input")[0])
            weight_var_node = graph._find_node_by_name(
                op_node.inputs, op_node.input("Filter")[0])
            output_var_node = graph._find_node_by_name(
                graph.all_var_nodes(), op_node.output("Output")[0])
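            # Recreate the op as conv2d: copy every attribute and reuse the
            # original input/weight/output variable nodes.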
            attrs = {
                attr_name: op_node.op().attr(attr_name)
                for attr_name in op_node.op().attr_names()
            }
            conv_op_node = graph.create_op_node(
                op_type='conv2d',
                attrs=attrs,
                inputs={
                    'Input': input_var_node,
                    'Filter': weight_var_node
                },
                outputs={'Output': output_var_node})
            graph.link_to(input_var_node, conv_op_node)
            graph.link_to(weight_var_node, conv_op_node)
            graph.link_to(conv_op_node, output_var_node)
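            # The conv2d replacement is fully wired in; remove the original
            # depthwise_conv2d node.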
            graph.safe_remove_nodes(op_node)
    return graph
def _predict(self, test_reader=None, model_path=None):
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
@@ -125,6 +162,13 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
# os.getenv returns a string, and bool() of any non-empty string (even
# "False") is True, so parse the flag value explicitly
use_mkldnn = os.getenv("FLAGS_use_mkldnn", "false").lower() in ("true", "1")
if use_mkldnn:
    graph = IrGraph(
        core.Graph(inference_program.desc), for_test=True)
    graph = self._transform_depthwise_conv(graph)
    inference_program = graph.to_program()
dshape = [3, 224, 224]
top1 = 0.0
top5 = 0.0
......