From f12b5260369aed9b458bf74f716d9525cba9c601 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Wed, 30 Mar 2022 11:05:36 +0800 Subject: [PATCH] Optimize the onnxruntime code (#41044) --- paddle/fluid/inference/api/details/zero_copy_tensor.cc | 3 ++- paddle/fluid/inference/api/onnxruntime_predictor.cc | 2 +- paddle/fluid/inference/api/onnxruntime_predictor_tester.cc | 4 ---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 66dec0157d..77ab6bd590 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -23,7 +23,8 @@ #include "paddle/fluid/platform/float16.h" #include "paddle/phi/core/allocator.h" #ifdef PADDLE_WITH_ONNXRUNTIME -#include "paddle/fluid/inference/api/onnxruntime_predictor.h" +#include "onnxruntime_c_api.h" // NOLINT +#include "onnxruntime_cxx_api.h" // NOLINT #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.cc b/paddle/fluid/inference/api/onnxruntime_predictor.cc index bd9de252a0..eb561667fe 100644 --- a/paddle/fluid/inference/api/onnxruntime_predictor.cc +++ b/paddle/fluid/inference/api/onnxruntime_predictor.cc @@ -81,7 +81,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) { bool ONNXRuntimePredictor::Init() { VLOG(3) << "ONNXRuntime Predictor::init()"; - // Now ONNXRuntime only suuport CPU + // Now ONNXRuntime only support CPU const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu"; if (config_.use_gpu()) { place_ = paddle::platform::CUDAPlace(config_.gpu_device_id()); diff --git a/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc b/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc index 2be2de9c60..4a702edacc 100644 --- a/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc +++ b/paddle/fluid/inference/api/onnxruntime_predictor_tester.cc @@ -49,10 +49,6 @@ TEST(ONNXRuntimePredictor, onnxruntime_on) { ASSERT_TRUE(predictor); ASSERT_TRUE(!predictor->Clone()); - ASSERT_TRUE(predictor->scope_); - ASSERT_TRUE(predictor->sub_scope_); - ASSERT_EQ(predictor->scope_->parent(), nullptr); - ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get()); // Dummy Input Data std::vector input_shape = {-1, 3, 224, 224}; std::vector input_data(1 * 3 * 224 * 224, 1.0); -- GitLab