Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
87e6149c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
87e6149c
编写于
5月 04, 2022
作者:
H
heliqi
提交者:
GitHub
5月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix paddle-ort python bug (#42464) (#42470)
* fix paddle-ort python bug * fix paddle-ort python bug
上级
544352de
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
34 addition
and
2 deletion
+34
-2
paddle/fluid/inference/api/details/zero_copy_tensor.cc
paddle/fluid/inference/api/details/zero_copy_tensor.cc
+33
-2
paddle/fluid/inference/api/paddle_tensor.h
paddle/fluid/inference/api/paddle_tensor.h
+1
-0
未找到文件。
paddle/fluid/inference/api/details/zero_copy_tensor.cc
浏览文件 @
87e6149c
...
...
@@ -674,8 +674,39 @@ void Tensor::ORTCopyFromCpu(const T *data) {
OrtMemTypeDefault
);
size_t
size
=
std
::
accumulate
(
begin
(
shape_
),
end
(
shape_
),
1UL
,
std
::
multiplies
<
size_t
>
());
auto
ort_value
=
GetOrtVaule
(
memory_info
,
const_cast
<
T
*>
(
data
),
size
,
shape_
.
data
(),
shape_
.
size
());
size_t
buffer_size
=
size
*
sizeof
(
T
);
if
(
buffer_size
>
buffer_
.
size
())
{
buffer_
.
resize
(
buffer_size
);
}
std
::
memcpy
(
static_cast
<
void
*>
(
buffer_
.
data
()),
data
,
buffer_size
);
auto
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
;
}
else
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
;
}
else
if
(
std
::
is_same
<
T
,
int64_t
>::
value
)
{
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
;
}
else
if
(
std
::
is_same
<
T
,
int32_t
>::
value
)
{
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
;
}
else
if
(
std
::
is_same
<
T
,
uint8_t
>::
value
)
{
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
;
}
else
if
(
std
::
is_same
<
T
,
int8_t
>::
value
)
{
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
;
}
else
if
(
std
::
is_same
<
T
,
float16
>::
value
)
{
onnx_dtype
=
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
;
}
if
(
onnx_dtype
==
ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
)
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
InvalidArgument
(
"Found undefined data type for onnxruntime, only supports "
"float16/float32/float64/int8/uint8/int32/int64."
));
}
auto
ort_value
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
buffer_
.
data
(),
buffer_size
,
shape_
.
data
(),
shape_
.
size
(),
onnx_dtype
);
binding
->
BindInput
(
name_
.
c_str
(),
ort_value
);
}
...
...
paddle/fluid/inference/api/paddle_tensor.h
浏览文件 @
87e6149c
...
...
@@ -183,6 +183,7 @@ class PD_INFER_DECL Tensor {
#ifdef PADDLE_WITH_ONNXRUNTIME
bool
is_ort_tensor_
{
false
};
std
::
vector
<
int64_t
>
shape_
;
std
::
vector
<
int8_t
>
buffer_
;
std
::
weak_ptr
<
Ort
::
IoBinding
>
binding_
;
int
idx_
{
-
1
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录