Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
72b65d6b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
72b65d6b
编写于
7月 28, 2022
作者:
H
heliqi
提交者:
GitHub
7月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clone ort_predictor reuse session (#44703)
上级
bd813d35
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
79 addition
and
59 deletion
+79
-59
paddle/fluid/inference/api/details/zero_copy_tensor.cc
paddle/fluid/inference/api/details/zero_copy_tensor.cc
+0
-4
paddle/fluid/inference/api/onnxruntime_predictor.cc
paddle/fluid/inference/api/onnxruntime_predictor.cc
+54
-48
paddle/fluid/inference/api/onnxruntime_predictor.h
paddle/fluid/inference/api/onnxruntime_predictor.h
+25
-4
paddle/fluid/inference/api/paddle_tensor.h
paddle/fluid/inference/api/paddle_tensor.h
+0
-3
未找到文件。
paddle/fluid/inference/api/details/zero_copy_tensor.cc
浏览文件 @
72b65d6b
...
...
@@ -720,10 +720,6 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
binding_
=
binding
;
}
void
Tensor
::
SetOrtBuffer
(
const
std
::
shared_ptr
<
std
::
vector
<
int8_t
>>
buffer
)
{
buffer_
=
buffer
;
}
template
<
typename
T
>
void
Tensor
::
ORTCopyToCpu
(
T
*
data
)
const
{
auto
binding
=
binding_
.
lock
();
...
...
paddle/fluid/inference/api/onnxruntime_predictor.cc
浏览文件 @
72b65d6b
...
...
@@ -86,9 +86,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
}
}
bool
ONNXRuntimePredictor
::
Init
()
{
VLOG
(
3
)
<<
"ONNXRuntime Predictor::init()"
;
bool
ONNXRuntimePredictor
::
InitBinding
()
{
// Now ONNXRuntime only support CPU
const
char
*
device_name
=
config_
.
use_gpu
()
?
"Cuda"
:
"Cpu"
;
if
(
config_
.
use_gpu
())
{
...
...
@@ -98,6 +96,53 @@ bool ONNXRuntimePredictor::Init() {
}
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
binding_
=
std
::
make_shared
<
Ort
::
IoBinding
>
(
*
session_
);
Ort
::
MemoryInfo
memory_info
(
device_name
,
OrtDeviceAllocator
,
place_
.
GetDeviceId
(),
OrtMemTypeDefault
);
Ort
::
Allocator
allocator
(
*
session_
,
memory_info
);
size_t
n_inputs
=
session_
->
GetInputCount
();
framework
::
proto
::
VarType
::
Type
proto_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
for
(
size_t
i
=
0
;
i
<
n_inputs
;
++
i
)
{
auto
input_name
=
session_
->
GetInputName
(
i
,
allocator
);
auto
type_info
=
session_
->
GetInputTypeInfo
(
i
);
std
::
vector
<
int64_t
>
shape
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetShape
();
ONNXTensorElementDataType
data_type
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetElementType
();
input_desc_
.
emplace_back
(
ONNXDesc
{
input_name
,
shape
,
data_type
});
auto
*
ptr
=
scope_
->
Var
(
input_name
);
framework
::
InitializeVariable
(
ptr
,
proto_type
);
allocator
.
Free
(
input_name
);
}
size_t
n_outputs
=
session_
->
GetOutputCount
();
for
(
size_t
i
=
0
;
i
<
n_outputs
;
++
i
)
{
auto
output_name
=
session_
->
GetOutputName
(
i
,
allocator
);
auto
type_info
=
session_
->
GetOutputTypeInfo
(
i
);
std
::
vector
<
int64_t
>
shape
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetShape
();
ONNXTensorElementDataType
data_type
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetElementType
();
output_desc_
.
emplace_back
(
ONNXDesc
{
output_name
,
shape
,
data_type
});
Ort
::
MemoryInfo
out_memory_info
(
device_name
,
OrtDeviceAllocator
,
place_
.
GetDeviceId
(),
OrtMemTypeDefault
);
binding_
->
BindOutput
(
output_name
,
out_memory_info
);
allocator
.
Free
(
output_name
);
}
return
true
;
}
bool
ONNXRuntimePredictor
::
Init
()
{
VLOG
(
3
)
<<
"ONNXRuntime Predictor::init()"
;
char
*
onnx_proto
=
nullptr
;
int
out_size
;
if
(
config_
.
model_from_memory
())
{
...
...
@@ -139,49 +184,10 @@ bool ONNXRuntimePredictor::Init() {
"will be "
"generated."
;
}
session_
=
{
env_
,
onnx_proto
,
static_cast
<
size_t
>
(
out_size
),
session_options
};
binding_
=
std
::
make_shared
<
Ort
::
IoBinding
>
(
session_
);
Ort
::
MemoryInfo
memory_info
(
device_name
,
OrtDeviceAllocator
,
place_
.
GetDeviceId
(),
OrtMemTypeDefault
);
Ort
::
Allocator
allocator
(
session_
,
memory_info
);
size_t
n_inputs
=
session_
.
GetInputCount
();
framework
::
proto
::
VarType
::
Type
proto_type
=
framework
::
proto
::
VarType
::
LOD_TENSOR
;
for
(
size_t
i
=
0
;
i
<
n_inputs
;
++
i
)
{
auto
input_name
=
session_
.
GetInputName
(
i
,
allocator
);
auto
type_info
=
session_
.
GetInputTypeInfo
(
i
);
std
::
vector
<
int64_t
>
shape
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetShape
();
ONNXTensorElementDataType
data_type
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetElementType
();
input_desc_
.
emplace_back
(
ONNXDesc
{
input_name
,
shape
,
data_type
});
auto
*
ptr
=
scope_
->
Var
(
input_name
);
framework
::
InitializeVariable
(
ptr
,
proto_type
);
session_
=
std
::
make_shared
<
Ort
::
Session
>
(
*
env_
,
onnx_proto
,
static_cast
<
size_t
>
(
out_size
),
session_options
);
InitBinding
();
allocator
.
Free
(
input_name
);
}
size_t
n_outputs
=
session_
.
GetOutputCount
();
for
(
size_t
i
=
0
;
i
<
n_outputs
;
++
i
)
{
auto
output_name
=
session_
.
GetOutputName
(
i
,
allocator
);
auto
type_info
=
session_
.
GetOutputTypeInfo
(
i
);
std
::
vector
<
int64_t
>
shape
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetShape
();
ONNXTensorElementDataType
data_type
=
type_info
.
GetTensorTypeAndShapeInfo
().
GetElementType
();
output_desc_
.
emplace_back
(
ONNXDesc
{
output_name
,
shape
,
data_type
});
Ort
::
MemoryInfo
out_memory_info
(
device_name
,
OrtDeviceAllocator
,
place_
.
GetDeviceId
(),
OrtMemTypeDefault
);
binding_
->
BindOutput
(
output_name
,
out_memory_info
);
allocator
.
Free
(
output_name
);
}
delete
onnx_proto
;
onnx_proto
=
nullptr
;
return
true
;
...
...
@@ -343,7 +349,7 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
OrtMemTypeDefault
);
binding_
->
BindOutput
(
output
.
name
.
c_str
(),
out_memory_info
);
}
session_
.
Run
({},
*
(
binding_
.
get
()));
session_
->
Run
({},
*
(
binding_
.
get
()));
}
catch
(
const
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
e
.
what
();
return
false
;
...
...
@@ -354,8 +360,8 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
std
::
unique_ptr
<
PaddlePredictor
>
ONNXRuntimePredictor
::
Clone
(
void
*
stream
)
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
clone_mutex_
);
auto
*
x
=
new
ONNXRuntimePredictor
(
config_
);
x
->
Init
();
auto
*
x
=
new
ONNXRuntimePredictor
(
config_
,
env_
,
session_
);
x
->
Init
Binding
();
return
std
::
unique_ptr
<
PaddlePredictor
>
(
x
);
}
...
...
paddle/fluid/inference/api/onnxruntime_predictor.h
浏览文件 @
72b65d6b
...
...
@@ -92,7 +92,22 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// \param[in] AnalysisConfig config
///
explicit
ONNXRuntimePredictor
(
const
AnalysisConfig
&
config
)
:
env_
(
ORT_LOGGING_LEVEL_WARNING
,
"onnx"
),
config_
(
config
)
{
:
env_
(
std
::
make_shared
<
Ort
::
Env
>
(
ORT_LOGGING_LEVEL_WARNING
,
"paddle-ort"
)),
session_
(
nullptr
),
binding_
(
nullptr
),
config_
(
config
)
{
predictor_id_
=
inference
::
GetUniqueId
();
}
///
/// \brief Clone a ONNXRuntime Predictor object
///
/// \param[in] AnalysisConfig config
///
explicit
ONNXRuntimePredictor
(
const
AnalysisConfig
&
config
,
std
::
shared_ptr
<
Ort
::
Env
>
env
,
std
::
shared_ptr
<
Ort
::
Session
>
session
)
:
env_
(
env
),
session_
(
session
),
binding_
(
nullptr
),
config_
(
config
)
{
predictor_id_
=
inference
::
GetUniqueId
();
}
///
...
...
@@ -100,6 +115,13 @@ class ONNXRuntimePredictor : public PaddlePredictor {
///
~
ONNXRuntimePredictor
();
///
/// \brief Initialize ORT Binding
///
/// \return Whether the init function executed successfully
///
bool
InitBinding
();
///
/// \brief Initialize predictor
///
...
...
@@ -203,8 +225,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
private:
// ONNXRuntime
Ort
::
Env
env_
;
Ort
::
Session
session_
{
nullptr
};
std
::
shared_ptr
<
Ort
::
Env
>
env_
;
std
::
shared_ptr
<
Ort
::
Session
>
session_
{
nullptr
};
std
::
shared_ptr
<
Ort
::
IoBinding
>
binding_
;
AnalysisConfig
config_
;
...
...
@@ -212,7 +234,6 @@ class ONNXRuntimePredictor : public PaddlePredictor {
platform
::
Place
place_
;
std
::
vector
<
ONNXDesc
>
input_desc_
;
std
::
vector
<
ONNXDesc
>
output_desc_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
std
::
vector
<
int8_t
>>>
input_buffers_
;
int
predictor_id_
;
// Some more detailed tests, they are made the friends of the predictor, so that
...
...
paddle/fluid/inference/api/paddle_tensor.h
浏览文件 @
72b65d6b
...
...
@@ -191,7 +191,6 @@ class PD_INFER_DECL Tensor {
#ifdef PADDLE_WITH_ONNXRUNTIME
bool
is_ort_tensor_
{
false
};
std
::
vector
<
int64_t
>
shape_
;
std
::
weak_ptr
<
std
::
vector
<
int8_t
>>
buffer_
;
std
::
weak_ptr
<
Ort
::
IoBinding
>
binding_
;
int
idx_
{
-
1
};
...
...
@@ -199,8 +198,6 @@ class PD_INFER_DECL Tensor {
void
SetOrtBinding
(
const
std
::
shared_ptr
<
Ort
::
IoBinding
>
binding
);
void
SetOrtBuffer
(
const
std
::
shared_ptr
<
std
::
vector
<
int8_t
>>
buffer
);
template
<
typename
T
>
void
ORTCopyFromCpu
(
const
T
*
data
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录