Unverified commit 72b65d6b authored by heliqi, committed by GitHub

clone ort_predictor reuse session (#44703)

Parent: bd813d35
@@ -720,10 +720,6 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
   binding_ = binding;
 }
 
-void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
-  buffer_ = buffer;
-}
-
 template <typename T>
 void Tensor::ORTCopyToCpu(T *data) const {
   auto binding = binding_.lock();
...
@@ -86,9 +86,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
   }
 }
 
-bool ONNXRuntimePredictor::Init() {
-  VLOG(3) << "ONNXRuntime Predictor::init()";
-
+bool ONNXRuntimePredictor::InitBinding() {
   // Now ONNXRuntime only support CPU
   const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
   if (config_.use_gpu()) {
@@ -98,6 +96,53 @@ bool ONNXRuntimePredictor::Init() {
   }
 
   scope_.reset(new paddle::framework::Scope());
+
+  binding_ = std::make_shared<Ort::IoBinding>(*session_);
+  Ort::MemoryInfo memory_info(
+      device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
+  Ort::Allocator allocator(*session_, memory_info);
+
+  size_t n_inputs = session_->GetInputCount();
+  framework::proto::VarType::Type proto_type =
+      framework::proto::VarType::LOD_TENSOR;
+  for (size_t i = 0; i < n_inputs; ++i) {
+    auto input_name = session_->GetInputName(i, allocator);
+    auto type_info = session_->GetInputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
+
+    auto *ptr = scope_->Var(input_name);
+    framework::InitializeVariable(ptr, proto_type);
+
+    allocator.Free(input_name);
+  }
+
+  size_t n_outputs = session_->GetOutputCount();
+  for (size_t i = 0; i < n_outputs; ++i) {
+    auto output_name = session_->GetOutputName(i, allocator);
+    auto type_info = session_->GetOutputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
+
+    Ort::MemoryInfo out_memory_info(device_name,
+                                    OrtDeviceAllocator,
+                                    place_.GetDeviceId(),
+                                    OrtMemTypeDefault);
+    binding_->BindOutput(output_name, out_memory_info);
+
+    allocator.Free(output_name);
+  }
+  return true;
+}
+
+bool ONNXRuntimePredictor::Init() {
+  VLOG(3) << "ONNXRuntime Predictor::init()";
   char *onnx_proto = nullptr;
   int out_size;
   if (config_.model_from_memory()) {
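The InitBinding() extracted above builds only per-predictor state — the IoBinding, the input/output descriptors, and the scope variables — against a session it does not necessarily own. A minimal standalone sketch of this one-session/many-bindings pattern, using the plain ONNX Runtime C++ API (the model path and tensor name are placeholders, not taken from this PR):

    #include <memory>
    #include <onnxruntime_cxx_api.h>

    int main() {
      auto env = std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "demo");
      Ort::SessionOptions opts;
      // Load and optimize the model once; "model.onnx" is a placeholder.
      auto session = std::make_shared<Ort::Session>(*env, "model.onnx", opts);

      // Each predictor (original or clone) owns a binding of its own.
      Ort::IoBinding binding_main(*session);
      Ort::IoBinding binding_clone(*session);

      Ort::MemoryInfo cpu_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
      binding_main.BindOutput("output", cpu_info);   // placeholder name
      binding_clone.BindOutput("output", cpu_info);
      // ... BindInput(...) on each binding, then session->Run({}, binding);
      return 0;
    }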
@@ -139,49 +184,10 @@ bool ONNXRuntimePredictor::Init() {
         "will be "
         "generated.";
   }
-  session_ = {env_, onnx_proto, static_cast<size_t>(out_size), session_options};
-  binding_ = std::make_shared<Ort::IoBinding>(session_);
-
-  Ort::MemoryInfo memory_info(
-      device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
-  Ort::Allocator allocator(session_, memory_info);
-
-  size_t n_inputs = session_.GetInputCount();
-  framework::proto::VarType::Type proto_type =
-      framework::proto::VarType::LOD_TENSOR;
-  for (size_t i = 0; i < n_inputs; ++i) {
-    auto input_name = session_.GetInputName(i, allocator);
-    auto type_info = session_.GetInputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
-
-    auto *ptr = scope_->Var(input_name);
-    framework::InitializeVariable(ptr, proto_type);
-
-    allocator.Free(input_name);
-  }
-
-  size_t n_outputs = session_.GetOutputCount();
-  for (size_t i = 0; i < n_outputs; ++i) {
-    auto output_name = session_.GetOutputName(i, allocator);
-    auto type_info = session_.GetOutputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
-
-    Ort::MemoryInfo out_memory_info(device_name,
-                                    OrtDeviceAllocator,
-                                    place_.GetDeviceId(),
-                                    OrtMemTypeDefault);
-    binding_->BindOutput(output_name, out_memory_info);
-
-    allocator.Free(output_name);
-  }
+  session_ = std::make_shared<Ort::Session>(
+      *env_, onnx_proto, static_cast<size_t>(out_size), session_options);
+  InitBinding();
+
   delete onnx_proto;
   onnx_proto = nullptr;
   return true;
@@ -343,7 +349,7 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
                                       OrtMemTypeDefault);
       binding_->BindOutput(output.name.c_str(), out_memory_info);
     }
-    session_.Run({}, *(binding_.get()));
+    session_->Run({}, *(binding_.get()));
   } catch (const std::exception &e) {
     LOG(ERROR) << e.what();
     return false;
@@ -354,8 +360,8 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
 
 std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone(void *stream) {
   std::lock_guard<std::mutex> lk(clone_mutex_);
-  auto *x = new ONNXRuntimePredictor(config_);
-  x->Init();
+  auto *x = new ONNXRuntimePredictor(config_, env_, session_);
+  x->InitBinding();
   return std::unique_ptr<PaddlePredictor>(x);
 }
...
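The net effect of the .cc changes: Clone() no longer re-parses and re-optimizes the ONNX model — it hands the parent's env_ and session_ to the new predictor and only runs InitBinding(). A hedged usage sketch through the public paddle_infer API (Config, CreatePredictor, and Clone are the standard entry points; the model path is a placeholder, and the exact include path depends on the install layout):

    #include "paddle_inference_api.h"

    int main() {
      paddle_infer::Config config;
      config.SetModel("./mobilenetv1");   // placeholder model directory
      config.EnableONNXRuntime();         // route execution through ORT

      auto predictor = paddle_infer::CreatePredictor(config);
      // Cheap after this PR: the clone reuses the parent's Ort::Session
      // and only builds its own IoBinding, scope, and descriptors.
      auto cloned = predictor->Clone();
      return 0;
    }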
@@ -92,7 +92,22 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   /// \param[in] AnalysisConfig config
   ///
   explicit ONNXRuntimePredictor(const AnalysisConfig &config)
-      : env_(ORT_LOGGING_LEVEL_WARNING, "onnx"), config_(config) {
+      : env_(std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
+                                        "paddle-ort")),
+        session_(nullptr),
+        binding_(nullptr),
+        config_(config) {
+    predictor_id_ = inference::GetUniqueId();
+  }
+
+  ///
+  /// \brief Clone a ONNXRuntime Predictor object
+  ///
+  /// \param[in] AnalysisConfig config
+  ///
+  explicit ONNXRuntimePredictor(const AnalysisConfig &config,
+                                std::shared_ptr<Ort::Env> env,
+                                std::shared_ptr<Ort::Session> session)
+      : env_(env), session_(session), binding_(nullptr), config_(config) {
     predictor_id_ = inference::GetUniqueId();
   }
 
   ///
@@ -100,6 +115,13 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   ///
   ~ONNXRuntimePredictor();
 
+  ///
+  /// \brief Initialize ORT Binding
+  ///
+  /// \return Whether the init function executed successfully
+  ///
+  bool InitBinding();
+
   ///
   /// \brief Initialize predictor
   ///
@@ -203,8 +225,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
  private:
   // ONNXRuntime
-  Ort::Env env_;
-  Ort::Session session_{nullptr};
+  std::shared_ptr<Ort::Env> env_;
+  std::shared_ptr<Ort::Session> session_{nullptr};
   std::shared_ptr<Ort::IoBinding> binding_;
   AnalysisConfig config_;
@@ -212,7 +234,6 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   platform::Place place_;
   std::vector<ONNXDesc> input_desc_;
   std::vector<ONNXDesc> output_desc_;
-  std::map<std::string, std::shared_ptr<std::vector<int8_t>>> input_buffers_;
   int predictor_id_;
 
   // Some more detailed tests, they are made the friends of the predictor, so that
...
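Switching env_ and session_ from by-value members to std::shared_ptr also decouples predictor lifetimes: the Ort::Session is released only when the last predictor referencing it is destroyed, and since ONNX Runtime documents Session::Run as thread-safe, clones can run concurrently as long as each keeps its own IoBinding. A hedged sketch of the lifetime behavior this buys (paddle_infer names as in the previous sketch; input setup omitted, and lifetime_demo is a hypothetical helper, not from this PR):

    #include "paddle_inference_api.h"

    void lifetime_demo(const paddle_infer::Config &config) {
      auto parent = paddle_infer::CreatePredictor(config);
      auto clone = parent->Clone();  // shares the parent's Ort::Session
      parent.reset();                // destroy the parent predictor...
      clone->Run();                  // ...the clone still runs: its
                                     // shared_ptr keeps the session alive
    }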
@@ -191,7 +191,6 @@ class PD_INFER_DECL Tensor {
 #ifdef PADDLE_WITH_ONNXRUNTIME
   bool is_ort_tensor_{false};
   std::vector<int64_t> shape_;
-  std::weak_ptr<std::vector<int8_t>> buffer_;
   std::weak_ptr<Ort::IoBinding> binding_;
   int idx_{-1};
@@ -199,8 +198,6 @@ class PD_INFER_DECL Tensor {
   void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
 
-  void SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer);
-
   template <typename T>
   void ORTCopyFromCpu(const T* data);
...