Unverified commit eae8f4e3, authored by iducn, committed by GitHub

[Inference] [unittest] Inference unit tests rely on dynamic libraries (#24743) (#26008)

Co-authored-by: Wilber <jiweibo@baidu.com>
Parent ac347fce
@@ -382,8 +382,7 @@ function(cc_test_run TARGET_NAME)
    set(multiValueArgs COMMAND ARGS)
    cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    add_test(NAME ${TARGET_NAME}
-            COMMAND ${cc_test_COMMAND}
-            ARGS ${cc_test_ARGS}
+            COMMAND ${cc_test_COMMAND} ${cc_test_ARGS}
             WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
......
@@ -63,10 +63,6 @@ if(WITH_TESTING)
  endif()
endif()

-if(NOT ON_INFER)
-  return()
-endif()

set(SHARED_INFERENCE_SRCS
    io.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../framework/data_feed.cc
......
@@ -45,10 +45,21 @@ cc_library(analysis_predictor SRCS analysis_predictor.cc ${mkldnn_quantizer_src}
cc_test(test_paddle_inference_api SRCS api_tester.cc DEPS paddle_inference_api)

if(WITH_TESTING)
+  if (NOT APPLE AND NOT WIN32)
+    inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS paddle_fluid_shared
+        ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${PYTHON_TESTS_DIR}/book)
+  else()
    inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS ${inference_deps}
        ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${PYTHON_TESTS_DIR}/book)
+  endif()
  set_tests_properties(test_api_impl PROPERTIES DEPENDS test_image_classification)
  set_tests_properties(test_api_impl PROPERTIES LABELS "RUN_TYPE=DIST")
endif()
-cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor benchmark ${inference_deps}
+if (NOT APPLE AND NOT WIN32)
+  cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS paddle_fluid_shared
      ARGS --dirname=${WORD2VEC_MODEL_DIR})
+else()
+  cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor benchmark ${inference_deps}
+      ARGS --dirname=${WORD2VEC_MODEL_DIR})
+endif()
-set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor benchmark)
+if (NOT APPLE AND NOT WIN32)
+  set(INFERENCE_EXTRA_DEPS paddle_fluid_shared)
+else()
+  set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor benchmark)
+endif()

if(WITH_GPU AND TENSORRT_FOUND)
-  set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor)
+  set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps})
endif()

function(download_data install_dir data_file)
@@ -33,13 +37,13 @@ endfunction()

function(inference_analysis_api_test target install_dir filename)
  inference_analysis_test(${target} SRCS ${filename}
-      EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
+      EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
      ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt --refer_result=${install_dir}/result.txt)
endfunction()

function(inference_analysis_api_test_build TARGET_NAME filename)
  inference_analysis_test_build(${TARGET_NAME} SRCS ${filename}
-      EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark)
+      EXTRA_DEPS ${INFERENCE_EXTRA_DEPS})
endfunction()

function(inference_analysis_api_int8_test_run TARGET_NAME test_binary model_dir data_path)
@@ -49,7 +53,7 @@ function(inference_analysis_api_int8_test_run TARGET_NAME test_binary model_dir
           --infer_data=${data_path}
           --warmup_batch_size=${WARMUP_BATCH_SIZE}
           --batch_size=50
-           --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
+           --cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
           --iterations=2)
endfunction()
@@ -65,7 +69,7 @@ function(inference_analysis_api_object_dection_int8_test_run TARGET_NAME test_bi
           --infer_data=${data_path}
           --warmup_batch_size=10
           --batch_size=300
-           --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
+           --cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
           --iterations=1)
endfunction()
@@ -88,7 +92,7 @@ function(inference_analysis_api_qat_test_run TARGET_NAME test_binary fp32_model_
           --int8_model=${int8_model_dir}
           --infer_data=${data_path}
           --batch_size=50
-           --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
+           --cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
           --with_accuracy_layer=false
           --iterations=2)
endfunction()
@@ -167,7 +171,7 @@ set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_Large")
download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_large_model.tar.gz" "Ernie_large_data.txt.tar.gz" "Ernie_large_result.txt.tar.gz")
download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz")
inference_analysis_test(test_analyzer_ernie_large SRCS analyzer_ernie_tester.cc
-    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
+    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
    ARGS --infer_model=${ERNIE_INSTALL_DIR}/model --infer_data=${ERNIE_INSTALL_DIR}/data.txt --refer_result=${ERNIE_INSTALL_DIR}/result.txt --ernie_large=true)

# text_classification
@@ -186,7 +190,7 @@ download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc
    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
    ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
-         --paddle_num_threads=${CPU_NUM_THREADS_ON_CI})
+         --cpu_num_threads=${CPU_NUM_THREADS_ON_CI})

# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include <fstream>
#include <iostream>
#include "paddle/fluid/inference/tests/api/tester_helper.h"

DEFINE_string(infer_shape, "", "data shape file");
DEFINE_int32(sample, 20, "number of sample");

namespace paddle {
namespace inference {
namespace analysis {

struct Record {
  std::vector<float> data;
  std::vector<int32_t> shape;
};

Record ProcessALine(const std::string &line, const std::string &shape_line) {
  VLOG(3) << "process a line";
  std::vector<std::string> columns;

  Record record;
  std::vector<std::string> data_strs;
  split(line, ' ', &data_strs);
  for (auto &d : data_strs) {
    record.data.push_back(std::stof(d));
  }

  std::vector<std::string> shape_strs;
  split(shape_line, ' ', &shape_strs);
  for (auto &s : shape_strs) {
    record.shape.push_back(std::stoi(s));
  }
  return record;
}

void SetConfig(AnalysisConfig *cfg) {
  cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
  cfg->DisableGpu();
  cfg->SwitchIrDebug();
  cfg->SwitchSpecifyInputNames(false);
  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
}

void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
              const std::string &line, const std::string &shape_line) {
  auto record = ProcessALine(line, shape_line);

  PaddleTensor input;
  input.shape = record.shape;
  input.dtype = PaddleDType::FLOAT32;
  size_t input_size = record.data.size() * sizeof(float);
  input.data.Resize(input_size);
  memcpy(input.data.data(), record.data.data(), input_size);
  std::vector<PaddleTensor> input_slots;
  input_slots.assign({input});
  (*inputs).emplace_back(input_slots);
}

void profile(int cache_capacity = 1) {
  AnalysisConfig cfg;
  SetConfig(&cfg);
  cfg.EnableMKLDNN();
  cfg.SetMkldnnCacheCapacity(cache_capacity);

  std::vector<std::vector<PaddleTensor>> outputs;
  std::vector<std::vector<PaddleTensor>> input_slots_all;

  Timer run_timer;
  double elapsed_time = 0;
  int num_times = FLAGS_repeat;
  int sample = FLAGS_sample;
  auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
  outputs.resize(sample);

  std::vector<std::thread> threads;

  std::ifstream file(FLAGS_infer_data);
  std::ifstream infer_file(FLAGS_infer_shape);
  std::string line;
  std::string shape_line;

  for (int i = 0; i < sample; i++) {
    threads.emplace_back([&, i]() {
      std::getline(file, line);
      std::getline(infer_file, shape_line);
      SetInput(&input_slots_all, line, shape_line);

      run_timer.tic();
      predictor->Run(input_slots_all[0], &outputs[0], FLAGS_batch_size);
      elapsed_time += run_timer.toc();
    });
    threads[0].join();
    threads.clear();
    std::vector<std::vector<PaddleTensor>>().swap(input_slots_all);
  }
  file.close();
  infer_file.close();

  auto batch_latency = elapsed_time / (sample * num_times);
  PrintTime(FLAGS_batch_size, num_times, FLAGS_num_threads, 0, batch_latency,
            sample, VarType::FP32);
}

#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_detect, profile_mkldnn) {
  profile(5 /* cache_capacity */);
  profile(10 /* cache_capacity */);
}
#endif

} // namespace analysis
} // namespace inference
} // namespace paddle
@@ -143,7 +143,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false,
  }
  cfg->SwitchSpecifyInputNames();
  cfg->SwitchIrOptim();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
}

void profile(bool use_mkldnn = false, bool use_gpu = false) {
......
@@ -27,7 +27,7 @@ void SetConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchIrOptim();
  cfg->SwitchSpecifyInputNames();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
}

void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
@@ -40,7 +40,7 @@ void SetOptimConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchIrOptim();
  cfg->SwitchSpecifyInputNames();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
}

// Easy for profiling independently.
......
@@ -26,7 +26,7 @@ void SetConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchIrOptim();
  cfg->SwitchSpecifyInputNames();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
  cfg->EnableMKLDNN();
}
......
@@ -27,7 +27,7 @@ void SetConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchIrOptim(true);
  cfg->SwitchSpecifyInputNames(false);
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
  cfg->EnableMKLDNN();
}
......
@@ -107,7 +107,7 @@ void SetConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchSpecifyInputNames();
  cfg->SwitchIrOptim();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
  if (FLAGS_zero_copy) {
    cfg->SwitchUseFeedFetchOps(false);
  }
......
@@ -26,7 +26,7 @@ void SetConfig(AnalysisConfig *cfg, std::string model_path) {
  cfg->DisableGpu();
  cfg->SwitchIrOptim(false);
  cfg->SwitchSpecifyInputNames();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
  cfg->EnableMKLDNN();
}
......
@@ -143,7 +143,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
  cfg->DisableGpu();
  cfg->SwitchSpecifyInputNames();
  cfg->SwitchIrDebug();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
  if (FLAGS_zero_copy) {
    cfg->SwitchUseFeedFetchOps(false);
  }
......
@@ -165,7 +165,7 @@ void SetConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchSpecifyInputNames();
  cfg->SwitchIrOptim();
-  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
}

void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
......
@@ -66,8 +66,8 @@ DEFINE_bool(warmup, false,
            "Use warmup to calculate elapsed_time more accurately. "
            "To reduce CI time, it sets false in default.");

-DECLARE_bool(profile);
-DECLARE_int32(paddle_num_threads);
+DEFINE_bool(enable_profile, false, "Turn on profiler for fluid");
+DEFINE_int32(cpu_num_threads, 1, "Number of threads for each paddle instance.");

namespace paddle {
namespace inference {
@@ -355,7 +355,7 @@ void PredictionWarmUp(PaddlePredictor *predictor,
      predictor->ZeroCopyRun();
    }
  PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1, data_type);
-  if (FLAGS_profile) {
+  if (FLAGS_enable_profile) {
    paddle::platform::ResetProfiler();
  }
}
......
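The DECLARE_* to DEFINE_* switch above follows from linking the tests against paddle_fluid_shared: a DECLARE only resolves when the translation unit that DEFINEs the flag is visible to the test binary, and the shared library's trimmed symbol table hides those definitions, so the test harness now owns these flags itself (under the new names enable_profile and cpu_num_threads). A minimal, self-contained gflags sketch of that pattern, not Paddle code:

```cpp
// Standalone sketch: the test binary DEFINEs (rather than DECLAREs) the flags
// it owns, so they resolve without any FLAGS_* symbols from libpaddle_fluid.so
// and are parsed by this binary's own gflags registry.
#include <gflags/gflags.h>

DEFINE_bool(enable_profile, false, "Turn on profiler for fluid");
DEFINE_int32(cpu_num_threads, 1, "Number of threads for each paddle instance.");

int main(int argc, char** argv) {
  // Parses only the flags registered in this binary.
  google::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  if (FLAGS_enable_profile) {
    // Harness-side profiling would be switched on here.
  }
  return FLAGS_cpu_num_threads > 0 ? 0 : 1;
}
```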
@@ -38,6 +38,16 @@ DEFINE_int32(multiple_of_cupti_buffer_size, 1,
             "Multiple of the CUPTI device buffer size. If the timestamps have "
             "been dropped when you are profiling, try increasing this value.");

+namespace paddle {
+namespace platform {
+
+void ParseCommandLineFlags(int argc, char **argv, bool remove) {
+  google::ParseCommandLineFlags(&argc, &argv, remove);
+}
+
+} // namespace platform
+} // namespace paddle
+
namespace paddle {
namespace framework {
......
@@ -19,6 +19,14 @@ limitations under the License. */

#include "gflags/gflags.h"
#include "glog/logging.h"

+namespace paddle {
+namespace platform {
+
+void ParseCommandLineFlags(int argc, char **argv, bool remove);
+
+} // namespace platform
+} // namespace paddle
+
namespace paddle {
namespace framework {
......
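The exported wrapper covers the opposite direction: flags DEFINEd inside libpaddle_fluid.so are invisible to the test binary's gflags registry, so the binary hands them to the parser that lives inside the library. A hedged caller sketch, assuming the binary links against the shared library that provides this symbol:

```cpp
// Caller-side sketch only: the forward declaration mirrors the one added in
// this diff; the definition comes from libpaddle_fluid.so at link time.
namespace paddle {
namespace platform {
void ParseCommandLineFlags(int argc, char **argv, bool remove);
} // namespace platform
} // namespace paddle

int main(int argc, char **argv) {
  // e.g. ./some_infer_test --cpu_deterministic=true
  // The flag is registered only inside the library, so the library's own
  // gflags registry must do the parsing.
  paddle::platform::ParseCommandLineFlags(argc, argv, /*remove=*/true);
  return 0;
}
```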
@@ -23,10 +23,41 @@ limitations under the License. */

int main(int argc, char** argv) {
  paddle::memory::allocation::UseAllocatorStrategyGFlag();
  testing::InitGoogleTest(&argc, argv);
-  std::vector<char*> new_argv;
-  std::string gflags_env;
+  // Because the dynamic library libpaddle_fluid.so clips the symbol table, the
+  // external program cannot recognize the flag inside the so, and the flag
+  // defined by the external program cannot be accessed inside the so.
+  // Therefore, the ParseCommandLine function needs to be called separately
+  // inside and outside.
+  std::vector<char*> external_argv;
+  std::vector<char*> internal_argv;
+
+  // ParseNewCommandLineFlags in gflags.cc starts processing
+  // commandline strings from idx 1.
+  // The reason is, it assumes that the first one (idx 0) is
+  // the filename of executable file.
+  external_argv.push_back(argv[0]);
+  internal_argv.push_back(argv[0]);
+
+  std::vector<google::CommandLineFlagInfo> all_flags;
+  std::vector<std::string> external_flags_name;
+  google::GetAllFlags(&all_flags);
+  for (size_t i = 0; i < all_flags.size(); ++i) {
+    external_flags_name.push_back(all_flags[i].name);
+  }
+
  for (int i = 0; i < argc; ++i) {
-    new_argv.push_back(argv[i]);
+    bool flag = true;
+    std::string tmp(argv[i]);
+    for (size_t j = 0; j < external_flags_name.size(); ++j) {
+      if (tmp.find(external_flags_name[j]) != std::string::npos) {
+        external_argv.push_back(argv[i]);
+        flag = false;
+        break;
+      }
+    }
+    if (flag) {
+      internal_argv.push_back(argv[i]);
+    }
  }

  std::vector<std::string> envs;
@@ -70,7 +101,7 @@ int main(int argc, char** argv) {
    }
    env_string = env_string.substr(0, env_string.length() - 1);
    env_str = strdup(env_string.c_str());
-    new_argv.push_back(env_str);
+    internal_argv.push_back(env_str);
    VLOG(1) << "gtest env_string:" << env_string;
  }
@@ -82,13 +113,17 @@ int main(int argc, char** argv) {
    }
    undefok_string = undefok_string.substr(0, undefok_string.length() - 1);
    undefok_str = strdup(undefok_string.c_str());
-    new_argv.push_back(undefok_str);
+    internal_argv.push_back(undefok_str);
    VLOG(1) << "gtest undefok_string:" << undefok_string;
  }

-  int new_argc = static_cast<int>(new_argv.size());
-  char** new_argv_address = new_argv.data();
-  google::ParseCommandLineFlags(&new_argc, &new_argv_address, false);
+  int new_argc = static_cast<int>(external_argv.size());
+  char** external_argv_address = external_argv.data();
+  google::ParseCommandLineFlags(&new_argc, &external_argv_address, false);
+  int internal_argc = internal_argv.size();
+  char** arr = internal_argv.data();
+  paddle::platform::ParseCommandLineFlags(internal_argc, arr, true);
  paddle::framework::InitDevices(true);

  int ret = RUN_ALL_TESTS();
......