Commit 72b65d6b (unverified)
Authored by heliqi on Jul 28, 2022; committed via GitHub on Jul 28, 2022.
clone ort_predictor reuse session (#44703)
Parent: bd813d35
Showing 4 changed files with 79 additions and 59 deletions (+79, -59):
paddle/fluid/inference/api/details/zero_copy_tensor.cc   +0  -4
paddle/fluid/inference/api/onnxruntime_predictor.cc      +54 -48
paddle/fluid/inference/api/onnxruntime_predictor.h       +25 -4
paddle/fluid/inference/api/paddle_tensor.h               +0  -3
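Net effect of the patch: ONNXRuntimePredictor's Ort::Env and Ort::Session become std::shared_ptr members, per-predictor I/O-binding setup moves out of Init() into a new InitBinding(), and Clone() hands the existing env/session to the new instance instead of rebuilding the session from the serialized model. A minimal sketch of that ownership layout, using a hypothetical Predictor stand-in rather than Paddle's actual class:

```cpp
// Minimal sketch of the ownership layout after this patch. "Predictor" is
// a hypothetical stand-in, not Paddle's ONNXRuntimePredictor; only the
// shared_ptr structure mirrors the diff.
#include <memory>
#include <onnxruntime_cxx_api.h>

struct Predictor {
  std::shared_ptr<Ort::Env> env;            // shared across clones
  std::shared_ptr<Ort::Session> session;    // shared across clones
  std::shared_ptr<Ort::IoBinding> binding;  // rebuilt per clone

  Predictor Clone() const {
    // Reuse env/session; only the per-predictor binding is recreated.
    return Predictor{env, session, std::make_shared<Ort::IoBinding>(*session)};
  }
};
```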
paddle/fluid/inference/api/details/zero_copy_tensor.cc
```diff
@@ -720,10 +720,6 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
   binding_ = binding;
 }
 
-void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
-  buffer_ = buffer;
-}
-
 template <typename T>
 void Tensor::ORTCopyToCpu(T *data) const {
   auto binding = binding_.lock();
```
paddle/fluid/inference/api/onnxruntime_predictor.cc
```diff
@@ -86,9 +86,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
   }
 }
 
-bool ONNXRuntimePredictor::Init() {
-  VLOG(3) << "ONNXRuntime Predictor::init()";
-
+bool ONNXRuntimePredictor::InitBinding() {
   // Now ONNXRuntime only support CPU
   const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
   if (config_.use_gpu()) {
```
```diff
@@ -98,6 +96,53 @@ bool ONNXRuntimePredictor::Init() {
   }
 
   scope_.reset(new paddle::framework::Scope());
+  binding_ = std::make_shared<Ort::IoBinding>(*session_);
+  Ort::MemoryInfo memory_info(device_name,
+                              OrtDeviceAllocator,
+                              place_.GetDeviceId(),
+                              OrtMemTypeDefault);
+  Ort::Allocator allocator(*session_, memory_info);
+
+  size_t n_inputs = session_->GetInputCount();
+  framework::proto::VarType::Type proto_type =
+      framework::proto::VarType::LOD_TENSOR;
+  for (size_t i = 0; i < n_inputs; ++i) {
+    auto input_name = session_->GetInputName(i, allocator);
+    auto type_info = session_->GetInputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
+
+    auto *ptr = scope_->Var(input_name);
+    framework::InitializeVariable(ptr, proto_type);
+
+    allocator.Free(input_name);
+  }
+  size_t n_outputs = session_->GetOutputCount();
+  for (size_t i = 0; i < n_outputs; ++i) {
+    auto output_name = session_->GetOutputName(i, allocator);
+    auto type_info = session_->GetOutputTypeInfo(i);
+    std::vector<int64_t> shape =
+        type_info.GetTensorTypeAndShapeInfo().GetShape();
+    ONNXTensorElementDataType data_type =
+        type_info.GetTensorTypeAndShapeInfo().GetElementType();
+    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
+
+    Ort::MemoryInfo out_memory_info(device_name,
+                                    OrtDeviceAllocator,
+                                    place_.GetDeviceId(),
+                                    OrtMemTypeDefault);
+    binding_->BindOutput(output_name, out_memory_info);
+
+    allocator.Free(output_name);
+  }
+  return true;
+}
+
+bool ONNXRuntimePredictor::Init() {
+  VLOG(3) << "ONNXRuntime Predictor::init()";
+
   char *onnx_proto = nullptr;
   int out_size;
   if (config_.model_from_memory()) {
```
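For reference, the input/output enumeration in InitBinding() follows the standard ONNX Runtime C++ pattern of this era, where GetInputName()/GetOutputName() return allocator-owned strings that must be freed. A self-contained sketch under that assumption ("model.onnx" is a placeholder path, not from the patch):

```cpp
// Sketch of the enumeration pattern used above, against the ONNX Runtime
// C++ API of this era: GetInputName() returns an allocator-owned string
// that the caller must free. "model.onnx" is a placeholder.
#include <cstdio>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
  Ort::SessionOptions opts;
  Ort::Session session(env, "model.onnx", opts);

  Ort::AllocatorWithDefaultOptions allocator;
  for (size_t i = 0; i < session.GetInputCount(); ++i) {
    char *name = session.GetInputName(i, allocator);
    // Keep the TypeInfo alive while reading its shape view.
    Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
    auto shape = type_info.GetTensorTypeAndShapeInfo().GetShape();
    std::printf("input %zu: %s (rank %zu)\n", i, name, shape.size());
    allocator.Free(name);
  }
  return 0;
}
```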
```diff
@@ -139,49 +184,10 @@ bool ONNXRuntimePredictor::Init() {
                  "will be "
                  "generated.";
   }
-  session_ = {env_, onnx_proto, static_cast<size_t>(out_size), session_options};
-  binding_ = std::make_shared<Ort::IoBinding>(session_);
-  Ort::MemoryInfo memory_info(device_name,
-                              OrtDeviceAllocator,
-                              place_.GetDeviceId(),
-                              OrtMemTypeDefault);
-  Ort::Allocator allocator(session_, memory_info);
-
-  size_t n_inputs = session_.GetInputCount();
-  framework::proto::VarType::Type proto_type =
-      framework::proto::VarType::LOD_TENSOR;
-  for (size_t i = 0; i < n_inputs; ++i) {
-    auto input_name = session_.GetInputName(i, allocator);
-    auto type_info = session_.GetInputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
-
-    auto *ptr = scope_->Var(input_name);
-    framework::InitializeVariable(ptr, proto_type);
-
-    allocator.Free(input_name);
-  }
-  size_t n_outputs = session_.GetOutputCount();
-  for (size_t i = 0; i < n_outputs; ++i) {
-    auto output_name = session_.GetOutputName(i, allocator);
-    auto type_info = session_.GetOutputTypeInfo(i);
-    std::vector<int64_t> shape =
-        type_info.GetTensorTypeAndShapeInfo().GetShape();
-    ONNXTensorElementDataType data_type =
-        type_info.GetTensorTypeAndShapeInfo().GetElementType();
-    output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
-
-    Ort::MemoryInfo out_memory_info(device_name,
-                                    OrtDeviceAllocator,
-                                    place_.GetDeviceId(),
-                                    OrtMemTypeDefault);
-    binding_->BindOutput(output_name, out_memory_info);
-
-    allocator.Free(output_name);
-  }
+  session_ = std::make_shared<Ort::Session>(
+      *env_, onnx_proto, static_cast<size_t>(out_size), session_options);
+  InitBinding();
+
   delete onnx_proto;
   onnx_proto = nullptr;
   return true;
```
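The new Init() thus builds the Ort::Session exactly once from the in-memory ONNX proto and then delegates all binding work to InitBinding(). A hedged sketch of that session-construction step (MakeSession and model_bytes are illustrative names, not from the patch):

```cpp
// Hedged sketch: constructing a shared Ort::Session from a serialized
// model held in memory, as the new Init() does. Names are illustrative.
#include <memory>
#include <string>
#include <onnxruntime_cxx_api.h>

std::shared_ptr<Ort::Session> MakeSession(const std::shared_ptr<Ort::Env> &env,
                                          const std::string &model_bytes) {
  Ort::SessionOptions opts;
  // Ort::Session has an (env, model_data, model_data_length, options)
  // overload for models that never touch the filesystem.
  return std::make_shared<Ort::Session>(*env, model_bytes.data(),
                                        model_bytes.size(), opts);
}
```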
```diff
@@ -343,7 +349,7 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
                                       OrtMemTypeDefault);
       binding_->BindOutput(output.name.c_str(), out_memory_info);
     }
-    session_.Run({}, *(binding_.get()));
+    session_->Run({}, *(binding_.get()));
   } catch (const std::exception &e) {
     LOG(ERROR) << e.what();
     return false;
```
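ZeroCopyRun() itself only changes call syntax (session_. becomes session_->). For context, a hedged sketch of the bind-then-run pattern it relies on; the tensor names "x" and "y" are placeholders:

```cpp
// Hedged sketch of the IoBinding run pattern behind ZeroCopyRun():
// bind inputs/outputs once, then Run() with empty run options, matching
// session_->Run({}, *(binding_.get())). "x"/"y" are placeholder names.
#include <vector>
#include <onnxruntime_cxx_api.h>

std::vector<Ort::Value> RunBound(Ort::Session &session, Ort::Value &input) {
  Ort::MemoryInfo cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  Ort::IoBinding binding(session);
  binding.BindInput("x", input);      // placeholder input name
  binding.BindOutput("y", cpu_info);  // let ORT allocate the output buffer
  session.Run(Ort::RunOptions{}, binding);
  return binding.GetOutputValues();
}
```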
```diff
@@ -354,8 +360,8 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
 std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone(void *stream) {
   std::lock_guard<std::mutex> lk(clone_mutex_);
-  auto *x = new ONNXRuntimePredictor(config_);
-  x->Init();
+  auto *x = new ONNXRuntimePredictor(config_, env_, session_);
+  x->InitBinding();
   return std::unique_ptr<PaddlePredictor>(x);
 }
```
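With the new three-argument constructor, Clone() no longer re-parses and re-optimizes the model; two predictors can serve different threads while sharing one session, each with its own binding. A hedged usage sketch against Paddle's public C++ inference API of roughly this release (model paths are placeholders):

```cpp
// Hedged usage sketch: what the change buys at the API level. Paths are
// placeholders; Clone() now reuses the parent's Ort::Env/Ort::Session and
// only rebuilds its own IoBinding.
#include <memory>
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("inference.pdmodel", "inference.pdiparams");
  config.EnableONNXRuntime();  // route execution through ONNX Runtime

  auto predictor = paddle_infer::CreatePredictor(config);
  auto worker = predictor->Clone();  // cheap: shares the existing session
  // Run `predictor` and `worker` from separate threads, each with its own
  // input/output binding.
  return 0;
}
```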
paddle/fluid/inference/api/onnxruntime_predictor.h
```diff
@@ -92,7 +92,22 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   /// \param[in] AnalysisConfig config
   ///
   explicit ONNXRuntimePredictor(const AnalysisConfig &config)
-      : env_(ORT_LOGGING_LEVEL_WARNING, "onnx"), config_(config) {
+      : env_(std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
+                                        "paddle-ort")),
+        session_(nullptr),
+        binding_(nullptr),
+        config_(config) {
     predictor_id_ = inference::GetUniqueId();
   }
+  ///
+  /// \brief Clone a ONNXRuntime Predictor object
+  ///
+  /// \param[in] AnalysisConfig config
+  ///
+  explicit ONNXRuntimePredictor(const AnalysisConfig &config,
+                                std::shared_ptr<Ort::Env> env,
+                                std::shared_ptr<Ort::Session> session)
+      : env_(env), session_(session), binding_(nullptr), config_(config) {
+    predictor_id_ = inference::GetUniqueId();
+  }
   ///
```
```diff
@@ -100,6 +115,13 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   ///
   ~ONNXRuntimePredictor();
 
+  ///
+  /// \brief Initialize ORT Binding
+  ///
+  /// \return Whether the init function executed successfully
+  ///
+  bool InitBinding();
+
   ///
   /// \brief Initialize predictor
   ///
```
```diff
@@ -203,8 +225,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
  private:
   // ONNXRuntime
-  Ort::Env env_;
-  Ort::Session session_{nullptr};
+  std::shared_ptr<Ort::Env> env_;
+  std::shared_ptr<Ort::Session> session_{nullptr};
   std::shared_ptr<Ort::IoBinding> binding_;
   AnalysisConfig config_;
```
```diff
@@ -212,7 +234,6 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   platform::Place place_;
   std::vector<ONNXDesc> input_desc_;
   std::vector<ONNXDesc> output_desc_;
-  std::map<std::string, std::shared_ptr<std::vector<int8_t>>> input_buffers_;
   int predictor_id_;
 
   // Some more detailed tests, they are made the friends of the predictor,
```
paddle/fluid/inference/api/paddle_tensor.h
```diff
@@ -191,7 +191,6 @@ class PD_INFER_DECL Tensor {
 #ifdef PADDLE_WITH_ONNXRUNTIME
   bool is_ort_tensor_{false};
   std::vector<int64_t> shape_;
-  std::weak_ptr<std::vector<int8_t>> buffer_;
   std::weak_ptr<Ort::IoBinding> binding_;
   int idx_{-1};
```
```diff
@@ -199,8 +198,6 @@ class PD_INFER_DECL Tensor {
   void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
-  void SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer);
-
   template <typename T>
   void ORTCopyFromCpu(const T *data);
```