From 72b65d6b2a978b6ff2e898c286b42ce88a2b3ce6 Mon Sep 17 00:00:00 2001
From: heliqi <1101791222@qq.com>
Date: Thu, 28 Jul 2022 08:23:05 -0500
Subject: [PATCH] clone ort_predictor reuse session (#44703)

---
 .../inference/api/details/zero_copy_tensor.cc |   4 -
 .../inference/api/onnxruntime_predictor.cc    | 102 +++++++++---------
 .../inference/api/onnxruntime_predictor.h     |  29 ++++-
 paddle/fluid/inference/api/paddle_tensor.h    |   3 -
 4 files changed, 79 insertions(+), 59 deletions(-)

diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
index 7bb384b273..81c34ae29c 100644
--- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc
+++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -720,10 +720,6 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
   binding_ = binding;
 }
 
-void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
-  buffer_ = buffer;
-}
-
 template <typename T>
 void Tensor::ORTCopyToCpu(T *data) const {
   auto binding = binding_.lock();
diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.cc b/paddle/fluid/inference/api/onnxruntime_predictor.cc
index 5313db6442..340d61b480 100644
--- a/paddle/fluid/inference/api/onnxruntime_predictor.cc
+++ b/paddle/fluid/inference/api/onnxruntime_predictor.cc
@@ -86,9 +86,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
   }
 }
 
-bool ONNXRuntimePredictor::Init() {
-  VLOG(3) << "ONNXRuntime Predictor::init()";
-
+bool ONNXRuntimePredictor::InitBinding() {
   // Now ONNXRuntime only support CPU
   const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
   if (config_.use_gpu()) {
@@ -98,6 +96,53 @@ bool ONNXRuntimePredictor::Init() {
   }
   scope_.reset(new paddle::framework::Scope());
 
+  binding_ = std::make_shared<Ort::IoBinding>(*session_);
+  Ort::MemoryInfo memory_info(
+      device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
+  Ort::Allocator allocator(*session_, memory_info);
+
+  size_t n_inputs = session_->GetInputCount();
+  framework::proto::VarType::Type proto_type =
+      framework::proto::VarType::LOD_TENSOR;
+  for (size_t i = 0; i < n_inputs; ++i) {
+    auto input_name = session_->GetInputName(i, allocator);
+    auto type_info = session_->GetInputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
+
+    auto *ptr = scope_->Var(input_name);
+    framework::InitializeVariable(ptr, proto_type);
+
+    allocator.Free(input_name);
+  }
+
+  size_t n_outputs = session_->GetOutputCount();
+  for (size_t i = 0; i < n_outputs; ++i) {
+    auto output_name = session_->GetOutputName(i, allocator);
+    auto type_info = session_->GetOutputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
+
+    Ort::MemoryInfo out_memory_info(device_name,
+                                    OrtDeviceAllocator,
+                                    place_.GetDeviceId(),
+                                    OrtMemTypeDefault);
+    binding_->BindOutput(output_name, out_memory_info);
+
+    allocator.Free(output_name);
+  }
+  return true;
+}
+
+bool ONNXRuntimePredictor::Init() {
+  VLOG(3) << "ONNXRuntime Predictor::init()";
+
   char *onnx_proto = nullptr;
   int out_size;
   if (config_.model_from_memory()) {
@@ -139,49 +184,10 @@ bool ONNXRuntimePredictor::Init() {
         "will be "
         "generated.";
   }
-  session_ = {env_, onnx_proto,
-              static_cast<size_t>(out_size), session_options};
-  binding_ = std::make_shared<Ort::IoBinding>(session_);
-
-  Ort::MemoryInfo memory_info(
-      device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
-  Ort::Allocator allocator(session_, memory_info);
-
-  size_t n_inputs = session_.GetInputCount();
-  framework::proto::VarType::Type proto_type =
-      framework::proto::VarType::LOD_TENSOR;
-  for (size_t i = 0; i < n_inputs; ++i) {
-    auto input_name = session_.GetInputName(i, allocator);
-    auto type_info = session_.GetInputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
-
-    auto *ptr = scope_->Var(input_name);
-    framework::InitializeVariable(ptr, proto_type);
+  session_ = std::make_shared<Ort::Session>(
+      *env_, onnx_proto, static_cast<size_t>(out_size), session_options);
+  InitBinding();
 
-    allocator.Free(input_name);
-  }
-
-  size_t n_outputs = session_.GetOutputCount();
-  for (size_t i = 0; i < n_outputs; ++i) {
-    auto output_name = session_.GetOutputName(i, allocator);
-    auto type_info = session_.GetOutputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
-
-    Ort::MemoryInfo out_memory_info(device_name,
-                                    OrtDeviceAllocator,
-                                    place_.GetDeviceId(),
-                                    OrtMemTypeDefault);
-    binding_->BindOutput(output_name, out_memory_info);
-
-    allocator.Free(output_name);
-  }
   delete onnx_proto;
   onnx_proto = nullptr;
   return true;
@@ -343,7 +349,7 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
                                       OrtMemTypeDefault);
       binding_->BindOutput(output.name.c_str(), out_memory_info);
     }
-    session_.Run({}, *(binding_.get()));
+    session_->Run({}, *(binding_.get()));
   } catch (const std::exception &e) {
     LOG(ERROR) << e.what();
     return false;
@@ -354,8 +360,8 @@ std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone(void *stream) {
   std::lock_guard<std::mutex> lk(clone_mutex_);
-  auto *x = new ONNXRuntimePredictor(config_);
-  x->Init();
+  auto *x = new ONNXRuntimePredictor(config_, env_, session_);
+  x->InitBinding();
   return std::unique_ptr<PaddlePredictor>(x);
 }
diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.h b/paddle/fluid/inference/api/onnxruntime_predictor.h
index b8f0ad0a52..971632c4b3 100644
--- a/paddle/fluid/inference/api/onnxruntime_predictor.h
+++ b/paddle/fluid/inference/api/onnxruntime_predictor.h
@@ -92,7 +92,22 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   /// \param[in] AnalysisConfig config
   ///
   explicit ONNXRuntimePredictor(const AnalysisConfig &config)
-      : env_(ORT_LOGGING_LEVEL_WARNING, "onnx"), config_(config) {
+      : env_(std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
+                                        "paddle-ort")),
+        session_(nullptr),
+        binding_(nullptr),
+        config_(config) {
     predictor_id_ = inference::GetUniqueId();
   }
+  ///
+  /// \brief Clone an ONNXRuntime Predictor object
+  ///
+  /// \param[in] AnalysisConfig config
+  ///
+  explicit ONNXRuntimePredictor(const AnalysisConfig &config,
+                                std::shared_ptr<Ort::Env> env,
+                                std::shared_ptr<Ort::Session> session)
+      : env_(env), session_(session), binding_(nullptr), config_(config) {
+    predictor_id_ = inference::GetUniqueId();
+  }
   ///
   /// \brief Destroy the ONNXRuntime Predictor object
   ///
   ~ONNXRuntimePredictor();
 
+  ///
+  /// \brief Initialize ORT Binding
+  ///
+  /// \return Whether the init function executed successfully
+  ///
+  bool InitBinding();
+
   ///
   /// \brief Initialize predictor
   ///
@@ -203,8 +225,8 @@
  private:
   // ONNXRuntime
-  Ort::Env env_;
-  Ort::Session session_{nullptr};
+  std::shared_ptr<Ort::Env> env_;
+  std::shared_ptr<Ort::Session> session_{nullptr};
   std::shared_ptr<Ort::IoBinding> binding_;
 
   AnalysisConfig config_;
@@ -212,7 +234,6 @@
   platform::Place place_;
   std::vector<ONNXDesc> input_desc_;
   std::vector<ONNXDesc> output_desc_;
-  std::map<std::string, std::shared_ptr<std::vector<int8_t>>> input_buffers_;
   int predictor_id_;
 
   // Some more detailed tests, they are made the friends of the predictor, so that
diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h
index c0396713bb..d96148abd3 100644
--- a/paddle/fluid/inference/api/paddle_tensor.h
+++ b/paddle/fluid/inference/api/paddle_tensor.h
@@ -191,7 +191,6 @@ class PD_INFER_DECL Tensor {
 #ifdef PADDLE_WITH_ONNXRUNTIME
   bool is_ort_tensor_{false};
   std::vector<int64_t> shape_;
-  std::weak_ptr<std::vector<int8_t>> buffer_;
   std::weak_ptr<Ort::IoBinding> binding_;
   int idx_{-1};
 
@@ -199,8 +198,6 @@
   void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
 
-  void SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer);
-
   template <typename T>
   void ORTCopyFromCpu(const T* data);
-- 
GitLab
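A minimal usage sketch of what this patch enables: Clone() now hands the original predictor's Ort::Env and Ort::Session to the new three-argument constructor and only rebuilds the per-predictor Ort::IoBinding, so creating one clone per thread no longer re-creates the ONNX Runtime session. The model paths, thread count, and input-binding step below are illustrative assumptions, not part of the change:

// Sketch only: assumes Paddle Inference built with ONNX Runtime support;
// "model.pdmodel"/"model.pdiparams" are placeholder paths.
#include <thread>
#include <vector>

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableONNXRuntime();  // route inference through ONNXRuntimePredictor

  // The first predictor creates the Ort::Env and Ort::Session.
  auto predictor = paddle_infer::CreatePredictor(config);

  // After this patch each Clone() reuses that session and only runs
  // InitBinding(), so per-thread predictors are cheap to construct.
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&predictor] {
      auto clone = predictor->Clone();
      // ... bind inputs via clone->GetInputHandle(...), then:
      clone->Run();
    });
  }
  for (auto &t : workers) t.join();
  return 0;
}

Because env_ and session_ are now shared_ptrs, the underlying ONNX Runtime objects stay alive until the last predictor holding them is destroyed, whichever order the clones and the original are torn down in.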