Unverified · Commit 9b41e455 · authored by Tao Luo · committed by GitHub

Merge pull request #15222 from luotao1/native_config

fix analyzer_test runs error in native_config
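In short: these test helpers used to obtain a NativeConfig view of a contrib::AnalysisConfig through reinterpret_cast. That kind of type punning only works while the two config types share a common layout; once AnalysisConfig stands alone, the cast reads unrelated memory and the native-config path of analyzer_test fails at runtime. The fix routes every such site through the explicit AnalysisConfig::ToNativeConfig() conversion instead. Below is a minimal self-contained sketch of the difference, using simplified stand-in types rather than the real Paddle headers:

#include <iostream>
#include <string>

struct NativeConfig {
  std::string model_dir;
  bool use_gpu = false;
};

// Stand-in for contrib::AnalysisConfig: no longer layout-compatible
// with NativeConfig, so pointer punning between the two is invalid.
struct AnalysisConfig {
  std::string model_dir;
  bool use_gpu = false;
  bool ir_optim = true;  // extra state changes the object layout

  // The safe replacement: build a NativeConfig explicitly, field by field.
  NativeConfig ToNativeConfig() const {
    NativeConfig native;
    native.model_dir = model_dir;
    native.use_gpu = use_gpu;
    return native;
  }
};

int main() {
  AnalysisConfig config;
  config.model_dir = "/tmp/model";

  // Undefined behavior: config is not a NativeConfig object.
  // const NativeConfig &bad = *reinterpret_cast<const NativeConfig *>(&config);

  // Well-defined: an explicit conversion, as the hunks below adopt.
  NativeConfig good = config.ToNativeConfig();
  std::cout << "model_dir: " << good.model_dir << "\n";
  return 0;
}

The same pattern repeats in each hunk below.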
@@ -62,7 +62,7 @@ std::ostream &operator<<(std::ostream &os,
                          const contrib::AnalysisConfig &config) {
   os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n";
   num_spaces++;
-  os << *reinterpret_cast<const NativeConfig *>(&config);
+  os << config.ToNativeConfig();
   if (!config.model_from_memory()) {
     os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file() << "\n";
     os << GenSpaces(num_spaces) << "param_file: " << config.params_file()
......
@@ -54,11 +54,13 @@ namespace paddle {
 namespace inference {
 
 void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
+  const auto *analysis_config =
+      reinterpret_cast<const contrib::AnalysisConfig *>(config);
   if (use_analysis) {
-    LOG(INFO) << *reinterpret_cast<const contrib::AnalysisConfig *>(config);
+    LOG(INFO) << *analysis_config;
     return;
   }
-  LOG(INFO) << *reinterpret_cast<const NativeConfig *>(config);
+  LOG(INFO) << analysis_config->ToNativeConfig();
 }
 
 void CompareResult(const std::vector<PaddleTensor> &outputs,
@@ -96,12 +98,13 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
 
 std::unique_ptr<PaddlePredictor> CreateTestPredictor(
     const PaddlePredictor::Config *config, bool use_analysis = true) {
+  const auto *analysis_config =
+      reinterpret_cast<const contrib::AnalysisConfig *>(config);
   if (use_analysis) {
-    return CreatePaddlePredictor<contrib::AnalysisConfig>(
-        *(reinterpret_cast<const contrib::AnalysisConfig *>(config)));
+    return CreatePaddlePredictor<contrib::AnalysisConfig>(*analysis_config);
   }
-  return CreatePaddlePredictor<NativeConfig>(
-      *(reinterpret_cast<const NativeConfig *>(config)));
+  auto native_config = analysis_config->ToNativeConfig();
+  return CreatePaddlePredictor<NativeConfig>(native_config);
 }
 
 size_t GetSize(const PaddleTensor &out) { return VecReduceToInt(out.shape); }
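A hedged usage sketch of the helper above: both predictors are now driven from a single AnalysisConfig, and the native branch derives its NativeConfig internally via ToNativeConfig(). The model path and the SetModel call are illustrative assumptions; the exact setters may differ in this revision of the API.

contrib::AnalysisConfig cfg;
cfg.SetModel("/path/to/model");  // hypothetical model location
const auto *base = reinterpret_cast<const PaddlePredictor::Config *>(&cfg);

// One config object serves both engines.
auto analysis_pred = CreateTestPredictor(base, /*use_analysis=*/true);
auto native_pred = CreateTestPredictor(base, /*use_analysis=*/false);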
@@ -328,10 +331,7 @@ void CompareNativeAndAnalysis(
     const std::vector<std::vector<PaddleTensor>> &inputs) {
   PrintConfig(config, true);
   std::vector<PaddleTensor> native_outputs, analysis_outputs;
-  const auto *analysis_config =
-      reinterpret_cast<const contrib::AnalysisConfig *>(config);
-  auto native_config = analysis_config->ToNativeConfig();
-  TestOneThreadPrediction(&native_config, inputs, &native_outputs, false);
+  TestOneThreadPrediction(config, inputs, &native_outputs, false);
   TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
   CompareResult(analysis_outputs, native_outputs);
 }
......
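Reassembled from the kept and added lines of the hunk above (the first parameter is reconstructed from its call sites), CompareNativeAndAnalysis now passes the same config pointer to both runs and leaves the native conversion to CreateTestPredictor:

void CompareNativeAndAnalysis(
    const PaddlePredictor::Config *config,
    const std::vector<std::vector<PaddleTensor>> &inputs) {
  PrintConfig(config, true);
  std::vector<PaddleTensor> native_outputs, analysis_outputs;
  TestOneThreadPrediction(config, inputs, &native_outputs, false);
  TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
  CompareResult(analysis_outputs, native_outputs);
}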
@@ -99,24 +99,12 @@ void compare(std::string model_dir, bool use_tensorrt) {
     SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
   }
 
-  std::vector<PaddleTensor> native_outputs;
-  NativeConfig native_config;
-  SetConfig<NativeConfig>(&native_config, model_dir, true, false,
-                          FLAGS_batch_size);
-  TestOneThreadPrediction(
-      reinterpret_cast<PaddlePredictor::Config*>(&native_config), inputs_all,
-      &native_outputs, false);
-
-  std::vector<PaddleTensor> analysis_outputs;
   contrib::AnalysisConfig analysis_config;
   analysis_config.EnableUseGpu(50, 0);
   SetConfig<contrib::AnalysisConfig>(&analysis_config, model_dir, true,
                                      use_tensorrt, FLAGS_batch_size);
-  TestOneThreadPrediction(
-      reinterpret_cast<PaddlePredictor::Config*>(&analysis_config), inputs_all,
-      &analysis_outputs, true);
-  CompareResult(native_outputs, analysis_outputs);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config*>(&analysis_config),
+      inputs_all);
 }
 
 TEST(TensorRT_mobilenet, compare) {
......
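Reassembled from the kept and added lines of the TensorRT hunk (the inputs_all declaration and the elided input setup are reconstructed for context), compare() now configures a single AnalysisConfig and delegates the duplicated native-vs-analysis runs to CompareNativeAndAnalysis:

void compare(std::string model_dir, bool use_tensorrt) {
  std::vector<std::vector<PaddleTensor>> inputs_all;
  // ... input setup elided; see the hunk above ...
  contrib::AnalysisConfig analysis_config;
  analysis_config.EnableUseGpu(50, 0);
  SetConfig<contrib::AnalysisConfig>(&analysis_config, model_dir, true,
                                     use_tensorrt, FLAGS_batch_size);
  CompareNativeAndAnalysis(
      reinterpret_cast<const PaddlePredictor::Config*>(&analysis_config),
      inputs_all);
}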