提交 8462e2b8 编写于 作者: M Michał Gallus 提交者: Tao Luo

Disable MKLDNN FC in Resnet50 test (#18030)

上级 78e93286
......@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif()
......
......@@ -33,14 +33,12 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
--paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=2)
endfunction()
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
if(mkl_debug)
set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
endif()
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name disable_fc)
download_model(${install_dir} ${model_name})
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${install_dir}/model)
ARGS --infer_model=${install_dir}/model
--disable_mkldnn_fc=${disable_fc})
endfunction()
function(inference_analysis_api_test_with_refer_result target install_dir filename)
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <iostream>
#include "paddle/fluid/inference/tests/api/tester_helper.h"
DEFINE_bool(disable_mkldnn_fc, false, "Disable usage of MKL-DNN's FC op");
namespace paddle {
namespace inference {
namespace analysis {
......@@ -48,7 +50,8 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
if (!FLAGS_disable_mkldnn_fc)
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> outputs;
......@@ -80,7 +83,8 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
if (!FLAGS_disable_mkldnn_fc)
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册