Unverified commit 72b65d6b, authored by heliqi, committed by GitHub

clone ort_predictor reuse session (#44703)

Parent bd813d35
@@ -720,10 +720,6 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
   binding_ = binding;
 }
 
-void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
-  buffer_ = buffer;
-}
-
 template <typename T>
 void Tensor::ORTCopyToCpu(T *data) const {
   auto binding = binding_.lock();
......
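The hunk above drops the Tensor-side buffer_ observer and its SetOrtBuffer() setter (see also the paddle_tensor.h hunks below); the tensor now watches only the predictor-owned IoBinding through the remaining weak_ptr. A minimal sketch of that weak_ptr pattern, using hypothetical FakeBinding/FakeTensor stand-ins rather than Paddle's real classes:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for the predictor-owned Ort::IoBinding state.
struct FakeBinding {
  std::vector<int8_t> bytes{1, 2, 3};
};

// The tensor holds a weak_ptr, like Tensor::binding_ above: it can read
// through the binding while the owner lives, but never extends its lifetime.
struct FakeTensor {
  std::weak_ptr<FakeBinding> binding_;

  void CopyToCpu(int8_t *dst) const {
    auto binding = binding_.lock();  // same pattern as Tensor::ORTCopyToCpu
    if (!binding) {
      std::cerr << "binding expired: owning predictor is gone\n";
      return;
    }
    std::copy(binding->bytes.begin(), binding->bytes.end(), dst);
  }
};

int main() {
  auto owner = std::make_shared<FakeBinding>();  // predictor-owned state
  FakeTensor tensor;
  tensor.binding_ = owner;

  int8_t out[3];
  tensor.CopyToCpu(out);  // succeeds while the owner is alive
  owner.reset();          // predictor destroyed
  tensor.CopyToCpu(out);  // expiry detected via lock() instead of a dangling read
  return 0;
}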
@@ -86,9 +86,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
   }
 }
 
-bool ONNXRuntimePredictor::Init() {
-  VLOG(3) << "ONNXRuntime Predictor::init()";
-
+bool ONNXRuntimePredictor::InitBinding() {
   // Now ONNXRuntime only support CPU
   const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
   if (config_.use_gpu()) {
@@ -98,6 +96,53 @@ bool ONNXRuntimePredictor::Init() {
   }
   scope_.reset(new paddle::framework::Scope());
 
+  binding_ = std::make_shared<Ort::IoBinding>(*session_);
+  Ort::MemoryInfo memory_info(
+      device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
+  Ort::Allocator allocator(*session_, memory_info);
+
+  size_t n_inputs = session_->GetInputCount();
+  framework::proto::VarType::Type proto_type =
+      framework::proto::VarType::LOD_TENSOR;
+  for (size_t i = 0; i < n_inputs; ++i) {
+    auto input_name = session_->GetInputName(i, allocator);
+    auto type_info = session_->GetInputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
+
+    auto *ptr = scope_->Var(input_name);
+    framework::InitializeVariable(ptr, proto_type);
+
+    allocator.Free(input_name);
+  }
+
+  size_t n_outputs = session_->GetOutputCount();
+  for (size_t i = 0; i < n_outputs; ++i) {
+    auto output_name = session_->GetOutputName(i, allocator);
+    auto type_info = session_->GetOutputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
+
+    Ort::MemoryInfo out_memory_info(device_name,
+                                    OrtDeviceAllocator,
+                                    place_.GetDeviceId(),
+                                    OrtMemTypeDefault);
+    binding_->BindOutput(output_name, out_memory_info);
+
+    allocator.Free(output_name);
+  }
+  return true;
+}
+
+bool ONNXRuntimePredictor::Init() {
+  VLOG(3) << "ONNXRuntime Predictor::init()";
+
   char *onnx_proto = nullptr;
   int out_size;
   if (config_.model_from_memory()) {
@@ -139,49 +184,10 @@ bool ONNXRuntimePredictor::Init() {
                     "will be "
                     "generated.";
   }
-  session_ = {env_, onnx_proto, static_cast<size_t>(out_size), session_options};
-  binding_ = std::make_shared<Ort::IoBinding>(session_);
-  Ort::MemoryInfo memory_info(
-      device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
-  Ort::Allocator allocator(session_, memory_info);
-
-  size_t n_inputs = session_.GetInputCount();
-  framework::proto::VarType::Type proto_type =
-      framework::proto::VarType::LOD_TENSOR;
-  for (size_t i = 0; i < n_inputs; ++i) {
-    auto input_name = session_.GetInputName(i, allocator);
-    auto type_info = session_.GetInputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
-
-    auto *ptr = scope_->Var(input_name);
-    framework::InitializeVariable(ptr, proto_type);
+  session_ = std::make_shared<Ort::Session>(
+      *env_, onnx_proto, static_cast<size_t>(out_size), session_options);
+  InitBinding();
 
-    allocator.Free(input_name);
-  }
-  size_t n_outputs = session_.GetOutputCount();
-  for (size_t i = 0; i < n_outputs; ++i) {
-    auto output_name = session_.GetOutputName(i, allocator);
-    auto type_info = session_.GetOutputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
-
-    Ort::MemoryInfo out_memory_info(device_name,
-                                    OrtDeviceAllocator,
-                                    place_.GetDeviceId(),
-                                    OrtMemTypeDefault);
-    binding_->BindOutput(output_name, out_memory_info);
-
-    allocator.Free(output_name);
-  }
   delete onnx_proto;
   onnx_proto = nullptr;
   return true;
@@ -343,7 +349,7 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
                                       OrtMemTypeDefault);
       binding_->BindOutput(output.name.c_str(), out_memory_info);
     }
-    session_.Run({}, *(binding_.get()));
+    session_->Run({}, *(binding_.get()));
   } catch (const std::exception &e) {
     LOG(ERROR) << e.what();
     return false;
@@ -354,8 +360,8 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
 std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone(void *stream) {
   std::lock_guard<std::mutex> lk(clone_mutex_);
-  auto *x = new ONNXRuntimePredictor(config_);
-  x->Init();
+  auto *x = new ONNXRuntimePredictor(config_, env_, session_);
+  x->InitBinding();
   return std::unique_ptr<PaddlePredictor>(x);
 }
......
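Taken together, the onnxruntime_predictor.cc hunks split the old monolithic Init() in two: Init() still parses the ONNX proto and builds the Ort::Session, while InitBinding() sets up the scope, the I/O descriptors, and the IoBinding. Clone() then constructs the new predictor from the existing env_/session_ and calls only InitBinding(), so the model is loaded once no matter how many clones are made. A hedged usage sketch; the header path, the CreatePaddlePredictor call, and the threading setup are illustrative assumptions, not part of this diff:

#include <memory>
#include <thread>
#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Spawn n worker threads that each run inference on a clone of one
// predictor. After this change, every clone shares the parent's
// Ort::Env/Ort::Session; only the per-clone IoBinding is rebuilt.
void RunWithClones(const paddle::AnalysisConfig &config, int n_threads) {
  auto main_predictor =
      paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(config);

  std::vector<std::thread> workers;
  for (int i = 0; i < n_threads; ++i) {
    // Clone() takes clone_mutex_, constructs the predictor from the shared
    // env/session, and calls InitBinding() instead of the full Init().
    std::shared_ptr<paddle::PaddlePredictor> predictor =
        main_predictor->Clone(/*stream=*/nullptr);
    workers.emplace_back([predictor] {
      // Feed inputs via predictor->GetInputTensor(...), then call
      // predictor->ZeroCopyRun() and read the outputs back.
    });
  }
  for (auto &w : workers) w.join();
}

Each clone still owns its own IoBinding and scope, which is what keeps concurrent runs on different clones from stepping on each other's bound inputs and outputs while the underlying session is shared.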
@@ -92,7 +92,22 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   /// \param[in] AnalysisConfig config
   ///
   explicit ONNXRuntimePredictor(const AnalysisConfig &config)
-      : env_(ORT_LOGGING_LEVEL_WARNING, "onnx"), config_(config) {
+      : env_(std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
+                                        "paddle-ort")),
+        session_(nullptr),
+        binding_(nullptr),
+        config_(config) {
     predictor_id_ = inference::GetUniqueId();
   }
+  ///
+  /// \brief Clone a ONNXRuntime Predictor object
+  ///
+  /// \param[in] AnalysisConfig config
+  ///
+  explicit ONNXRuntimePredictor(const AnalysisConfig &config,
+                                std::shared_ptr<Ort::Env> env,
+                                std::shared_ptr<Ort::Session> session)
+      : env_(env), session_(session), binding_(nullptr), config_(config) {
+    predictor_id_ = inference::GetUniqueId();
+  }
 
   ///
@@ -100,6 +115,13 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   ///
   ~ONNXRuntimePredictor();
 
+  ///
+  /// \brief Initialize ORT Binding
+  ///
+  /// \return Whether the init function executed successfully
+  ///
+  bool InitBinding();
+
   ///
   /// \brief Initialize predictor
   ///
@@ -203,8 +225,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
  private:
   // ONNXRuntime
-  Ort::Env env_;
-  Ort::Session session_{nullptr};
+  std::shared_ptr<Ort::Env> env_;
+  std::shared_ptr<Ort::Session> session_{nullptr};
   std::shared_ptr<Ort::IoBinding> binding_;
   AnalysisConfig config_;
@@ -212,7 +234,6 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   platform::Place place_;
   std::vector<ONNXDesc> input_desc_;
   std::vector<ONNXDesc> output_desc_;
-  std::map<std::string, std::shared_ptr<std::vector<int8_t>>> input_buffers_;
   int predictor_id_;
 
   // Some more detailed tests, they are made the friends of the predictor, so that
......
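The header hunks above switch env_ and session_ from value members to std::shared_ptr and add a second constructor for the clone path. A minimal sketch of that ownership scheme, with hypothetical Env/Session stand-ins in place of the real ONNX Runtime types:

#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-ins for Ort::Env and Ort::Session: the expensive,
// shareable state that should be built once.
struct Env {};
struct Session {
  Session(const Env & /*env*/, const std::string & /*model*/) {}
};

class Predictor {
 public:
  // Fresh predictor: builds the shared state (mirrors the one-argument
  // ONNXRuntimePredictor constructor plus Init()).
  explicit Predictor(const std::string &model)
      : env_(std::make_shared<Env>()),
        session_(std::make_shared<Session>(*env_, model)) {}

  // Clone path: alias the existing env/session (mirrors the new
  // three-argument constructor); only per-predictor state is rebuilt.
  Predictor(std::shared_ptr<Env> env, std::shared_ptr<Session> session)
      : env_(std::move(env)), session_(std::move(session)) {}

  Predictor Clone() const { return Predictor(env_, session_); }

  long SessionUseCount() const { return session_.use_count(); }

 private:
  std::shared_ptr<Env> env_;
  std::shared_ptr<Session> session_;
};

int main() {
  Predictor a("model.onnx");
  Predictor b = a.Clone();
  // Both predictors now point at the same Session, so use_count is 2.
  return b.SessionUseCount() == 2 ? 0 : 1;
}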
@@ -191,7 +191,6 @@ class PD_INFER_DECL Tensor {
 #ifdef PADDLE_WITH_ONNXRUNTIME
   bool is_ort_tensor_{false};
   std::vector<int64_t> shape_;
-  std::weak_ptr<std::vector<int8_t>> buffer_;
   std::weak_ptr<Ort::IoBinding> binding_;
   int idx_{-1};
@@ -199,8 +198,6 @@ class PD_INFER_DECL Tensor {
   void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
 
-  void SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer);
-
   template <typename T>
   void ORTCopyFromCpu(const T* data);
......