Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
dc13f7c5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dc13f7c5
编写于
1月 03, 2023
作者:
Y
Yuanle Liu
提交者:
GitHub
1月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference] enhance paddle_infer::Tensor data type (#49388)
上级
72597c3e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
82 addition
and
16 deletion
+82
-16
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+1
-1
paddle/fluid/inference/api/details/zero_copy_tensor.cc
paddle/fluid/inference/api/details/zero_copy_tensor.cc
+30
-0
paddle/fluid/inference/api/paddle_analysis_config.h
paddle/fluid/inference/api/paddle_analysis_config.h
+1
-1
paddle/fluid/inference/api/paddle_tensor.h
paddle/fluid/inference/api/paddle_tensor.h
+3
-2
paddle/fluid/pybind/inference_api.cc
paddle/fluid/pybind/inference_api.cc
+47
-12
未找到文件。
paddle/fluid/inference/api/analysis_predictor.h
浏览文件 @
dc13f7c5
...
...
@@ -109,7 +109,7 @@ class AnalysisPredictor : public PaddlePredictor {
// negative sharing_identifier directly. In the future, this may affect
// the meaning of negative predictor id.
predictor_id_
=
-
trt_identifier
;
LOG
(
WARNING
)
LOG
_FIRST_N
(
WARNING
,
1
)
<<
"Since the engine context memory of multiple predictors "
"is enabled in Paddle-TRT, we set the id of current predictor to "
"negative sharing_identifier you specified."
;
...
...
paddle/fluid/inference/api/details/zero_copy_tensor.cc
浏览文件 @
dc13f7c5
...
...
@@ -176,6 +176,8 @@ DataType Tensor::type() const {
return
DataType
::
UINT8
;
}
else
if
(
type
==
paddle
::
framework
::
proto
::
VarType
::
INT8
)
{
return
DataType
::
INT8
;
}
else
if
(
type
==
paddle
::
framework
::
proto
::
VarType
::
BOOL
)
{
return
DataType
::
BOOL
;
}
return
DataType
::
FLOAT32
;
}
...
...
@@ -279,6 +281,11 @@ void Tensor::CopyFromCpu(const T *data) {
template
<
typename
T
>
struct
DataTypeInfo
;
template
<
>
struct
DataTypeInfo
<
bool
>
{
paddle
::
experimental
::
DataType
TYPE
=
paddle
::
experimental
::
DataType
::
BOOL
;
};
template
<
>
struct
DataTypeInfo
<
float
>
{
paddle
::
experimental
::
DataType
TYPE
=
paddle
::
experimental
::
DataType
::
FLOAT32
;
...
...
@@ -513,6 +520,7 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template
PD_INFER_DECL
void
Tensor
::
CopyFromCpu
<
uint8_t
>(
const
uint8_t
*
data
);
template
PD_INFER_DECL
void
Tensor
::
CopyFromCpu
<
int8_t
>(
const
int8_t
*
data
);
template
PD_INFER_DECL
void
Tensor
::
CopyFromCpu
<
float16
>(
const
float16
*
data
);
template
PD_INFER_DECL
void
Tensor
::
CopyFromCpu
<
bool
>(
const
bool
*
data
);
template
PD_INFER_DECL
void
Tensor
::
ShareExternalData
<
float
>(
const
float
*
data
,
...
...
@@ -544,6 +552,11 @@ template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
const
std
::
vector
<
int
>
&
shape
,
PlaceType
place
,
DataLayout
layout
);
template
PD_INFER_DECL
void
Tensor
::
ShareExternalData
<
bool
>(
const
bool
*
data
,
const
std
::
vector
<
int
>
&
shape
,
PlaceType
place
,
DataLayout
layout
);
template
PD_INFER_DECL
void
Tensor
::
CopyToCpu
<
float
>(
float
*
data
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpu
<
int64_t
>(
int64_t
*
data
)
const
;
...
...
@@ -551,6 +564,7 @@ template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpu
<
uint8_t
>(
uint8_t
*
data
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpu
<
int8_t
>(
int8_t
*
data
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpu
<
float16
>(
float16
*
data
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpu
<
bool
>(
bool
*
data
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuImpl
<
float
>(
float
*
data
,
void
*
exec_stream
,
...
...
@@ -566,6 +580,10 @@ template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
int8_t
*
data
,
void
*
exec_stream
,
CallbackFunc
cb
,
void
*
cb_params
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuImpl
<
float16
>(
float16
*
data
,
void
*
exec_stream
,
CallbackFunc
cb
,
void
*
cb_params
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuImpl
<
bool
>(
bool
*
data
,
void
*
exec_stream
,
CallbackFunc
cb
,
void
*
cb_params
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuAsync
<
float
>(
float
*
data
,
void
*
exec_stream
)
const
;
...
...
@@ -579,6 +597,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
int8_t
*
data
,
void
*
exec_stream
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuAsync
<
float16
>(
float16
*
data
,
void
*
exec_stream
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuAsync
<
bool
>(
bool
*
data
,
void
*
exec_stream
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuAsync
<
float
>(
float
*
data
,
CallbackFunc
cb
,
void
*
cb_params
)
const
;
...
...
@@ -592,6 +612,9 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
int8_t
*
data
,
CallbackFunc
cb
,
void
*
cb_params
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuAsync
<
float16
>(
float16
*
data
,
CallbackFunc
cb
,
void
*
cb_params
)
const
;
template
PD_INFER_DECL
void
Tensor
::
CopyToCpuAsync
<
bool
>(
bool
*
data
,
CallbackFunc
cb
,
void
*
cb_params
)
const
;
template
PD_INFER_DECL
float
*
Tensor
::
data
<
float
>(
PlaceType
*
place
,
int
*
size
)
const
;
...
...
@@ -605,6 +628,8 @@ template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
int
*
size
)
const
;
template
PD_INFER_DECL
float16
*
Tensor
::
data
<
float16
>(
PlaceType
*
place
,
int
*
size
)
const
;
template
PD_INFER_DECL
bool
*
Tensor
::
data
<
bool
>(
PlaceType
*
place
,
int
*
size
)
const
;
template
PD_INFER_DECL
float
*
Tensor
::
mutable_data
<
float
>(
PlaceType
place
);
template
PD_INFER_DECL
int64_t
*
Tensor
::
mutable_data
<
int64_t
>(
PlaceType
place
);
...
...
@@ -612,6 +637,7 @@ template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template
PD_INFER_DECL
uint8_t
*
Tensor
::
mutable_data
<
uint8_t
>(
PlaceType
place
);
template
PD_INFER_DECL
int8_t
*
Tensor
::
mutable_data
<
int8_t
>(
PlaceType
place
);
template
PD_INFER_DECL
float16
*
Tensor
::
mutable_data
<
float16
>(
PlaceType
place
);
template
PD_INFER_DECL
bool
*
Tensor
::
mutable_data
<
bool
>(
PlaceType
place
);
Tensor
::
Tensor
(
void
*
scope
,
const
void
*
device_contexts
)
:
scope_
{
scope
},
device_contexs_
(
device_contexts
)
{}
...
...
@@ -895,6 +921,8 @@ template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
paddle_infer
::
Tensor
*
t
,
const
int8_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
float16
>(
paddle_infer
::
Tensor
*
t
,
const
float16
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
bool
>(
paddle_infer
::
Tensor
*
t
,
const
bool
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
float
>(
paddle_infer
::
Tensor
*
t
,
float
*
data
,
cudaStream_t
stream
);
...
...
@@ -908,6 +936,8 @@ template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
paddle_infer
::
Tensor
*
t
,
int8_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
float16
>(
paddle_infer
::
Tensor
*
t
,
float16
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
bool
>(
paddle_infer
::
Tensor
*
t
,
bool
*
data
,
cudaStream_t
stream
);
}
// namespace experimental
...
...
paddle/fluid/inference/api/paddle_analysis_config.h
浏览文件 @
dc13f7c5
...
...
@@ -161,7 +161,7 @@ struct PD_INFER_DECL AnalysisConfig {
explicit
AnalysisConfig
(
const
std
::
string
&
prog_file
,
const
std
::
string
&
params_file
);
///
/// \brief Precision of inference
in TensorRT
.
/// \brief Precision of inference.
///
enum
class
Precision
{
kFloat32
=
0
,
///< fp32
...
...
paddle/fluid/inference/api/paddle_tensor.h
浏览文件 @
dc13f7c5
...
...
@@ -52,13 +52,14 @@ class InternalUtils;
/// \brief Paddle data type.
enum
DataType
{
FLOAT32
,
INT64
,
INT32
,
UINT8
,
INT8
,
FLOAT32
,
FLOAT16
,
// TODO(Superjomn) support more data types if needed.
BOOL
,
// TODO(Inference): support more data types if needed.
};
enum
class
PlaceType
{
kUNK
=
-
1
,
kCPU
,
kGPU
,
kXPU
,
kNPU
,
kIPU
,
kCUSTOM
};
...
...
paddle/fluid/pybind/inference_api.cc
浏览文件 @
dc13f7c5
...
...
@@ -175,16 +175,22 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
case
PaddleDType
::
FLOAT32
:
dt
=
py
::
dtype
::
of
<
float
>
();
break
;
case
PaddleDType
::
FLOAT16
:
dt
=
py
::
dtype
::
of
<
paddle_infer
::
float16
>
();
break
;
case
PaddleDType
::
UINT8
:
dt
=
py
::
dtype
::
of
<
uint8_t
>
();
break
;
case
PaddleDType
::
FLOAT16
:
dt
=
py
::
dtype
::
of
<
paddle_infer
::
float16
>
();
case
PaddleDType
::
INT8
:
dt
=
py
::
dtype
::
of
<
int8_t
>
();
break
;
case
PaddleDType
::
BOOL
:
dt
=
py
::
dtype
::
of
<
bool
>
();
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data type. Now only supports INT32, INT64,
UINT8 and
"
"FLOAT
32
."
));
"Unsupported data type. Now only supports INT32, INT64,
FLOAT32,
"
"FLOAT
16, INT8, UINT8 and BOOL
."
));
}
return
dt
;
...
...
@@ -282,10 +288,22 @@ size_t PaddleGetDTypeSize(PaddleDType dt) {
case
PaddleDType
::
FLOAT32
:
size
=
sizeof
(
float
);
break
;
case
PaddleDType
::
FLOAT16
:
size
=
sizeof
(
paddle_infer
::
float16
);
break
;
case
PaddleDType
::
INT8
:
size
=
sizeof
(
int8_t
);
break
;
case
PaddleDType
::
UINT8
:
size
=
sizeof
(
uint8_t
);
break
;
case
PaddleDType
::
BOOL
:
size
=
sizeof
(
bool
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data t
ype. Now only supports INT32, INT64 and
"
"FLOAT
32
."
));
"Unsupported data t
ype. Now only supports INT32, INT64, FLOAT32,
"
"FLOAT
16, INT8, UINT8 and BOOL
."
));
}
return
size
;
}
...
...
@@ -316,10 +334,13 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
case
PaddleDType
::
INT8
:
tensor
.
copy_to_cpu
<
int8_t
>
(
static_cast
<
int8_t
*>
(
array
.
mutable_data
()));
break
;
case
PaddleDType
::
BOOL
:
tensor
.
copy_to_cpu
<
bool
>
(
static_cast
<
bool
*>
(
array
.
mutable_data
()));
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data type. Now only supports INT32, INT64,
UINT8 and
"
"FLOAT
32
."
));
"Unsupported data type. Now only supports INT32, INT64,
FLOAT32,
"
"FLOAT
16, INT8, UINT8 and BOOL
."
));
}
return
array
;
}
...
...
@@ -350,10 +371,13 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT
case
PaddleDType
::
INT8
:
tensor
.
CopyToCpu
(
static_cast
<
int8_t
*>
(
array
.
mutable_data
()));
break
;
case
PaddleDType
::
BOOL
:
tensor
.
CopyToCpu
(
static_cast
<
bool
*>
(
array
.
mutable_data
()));
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data t
ype. Now only supports INT32, INT64 and
"
"FLOAT
32
."
));
"Unsupported data t
ype. Now only supports INT32, INT64, FLOAT32,
"
"FLOAT
16, INT8, UINT8 and BOOL
."
));
}
return
array
;
}
...
...
@@ -433,8 +457,12 @@ namespace {
void
BindPaddleDType
(
py
::
module
*
m
)
{
py
::
enum_
<
PaddleDType
>
(
*
m
,
"PaddleDType"
)
.
value
(
"FLOAT32"
,
PaddleDType
::
FLOAT32
)
.
value
(
"FLOAT16"
,
PaddleDType
::
FLOAT16
)
.
value
(
"INT64"
,
PaddleDType
::
INT64
)
.
value
(
"INT32"
,
PaddleDType
::
INT32
);
.
value
(
"INT32"
,
PaddleDType
::
INT32
)
.
value
(
"UINT8"
,
PaddleDType
::
UINT8
)
.
value
(
"INT8"
,
PaddleDType
::
INT8
)
.
value
(
"BOOL"
,
PaddleDType
::
BOOL
);
}
void
BindPaddleDataLayout
(
py
::
module
*
m
)
{
...
...
@@ -538,7 +566,8 @@ void BindPaddlePlace(py::module *m) {
.
value
(
"CPU"
,
PaddlePlace
::
kCPU
)
.
value
(
"GPU"
,
PaddlePlace
::
kGPU
)
.
value
(
"XPU"
,
PaddlePlace
::
kXPU
)
.
value
(
"NPU"
,
PaddlePlace
::
kNPU
);
.
value
(
"NPU"
,
PaddlePlace
::
kNPU
)
.
value
(
"CUSTOM"
,
PaddlePlace
::
kCUSTOM
);
}
void
BindPaddlePredictor
(
py
::
module
*
m
)
{
...
...
@@ -990,10 +1019,13 @@ void BindZeroCopyTensor(py::module *m) {
.
def
(
"reshape"
,
py
::
overload_cast
<
const
std
::
size_t
&>
(
&
paddle_infer
::
Tensor
::
ReshapeStrings
))
.
def
(
"copy_from_cpu"
,
&
ZeroCopyTensorCreate
<
int8_t
>
)
.
def
(
"copy_from_cpu"
,
&
ZeroCopyTensorCreate
<
uint8_t
>
)
.
def
(
"copy_from_cpu"
,
&
ZeroCopyTensorCreate
<
int32_t
>
)
.
def
(
"copy_from_cpu"
,
&
ZeroCopyTensorCreate
<
int64_t
>
)
.
def
(
"copy_from_cpu"
,
&
ZeroCopyTensorCreate
<
float
>
)
.
def
(
"copy_from_cpu"
,
&
ZeroCopyTensorCreate
<
paddle_infer
::
float16
>
)
.
def
(
"copy_from_cpu"
,
&
ZeroCopyTensorCreate
<
bool
>
)
.
def
(
"copy_from_cpu"
,
&
ZeroCopyStringTensorCreate
)
.
def
(
"copy_to_cpu"
,
&
ZeroCopyTensorToNumpy
)
.
def
(
"shape"
,
&
ZeroCopyTensor
::
shape
)
...
...
@@ -1010,11 +1042,14 @@ void BindPaddleInferTensor(py::module *m) {
.
def
(
"reshape"
,
py
::
overload_cast
<
const
std
::
size_t
&>
(
&
paddle_infer
::
Tensor
::
ReshapeStrings
))
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferTensorCreate
<
int8_t
>
)
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferTensorCreate
<
uint8_t
>
)
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferTensorCreate
<
int32_t
>
)
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferTensorCreate
<
int64_t
>
)
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferTensorCreate
<
float
>
)
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferTensorCreate
<
paddle_infer
::
float16
>
)
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferTensorCreate
<
bool
>
)
.
def
(
"copy_from_cpu_bind"
,
&
PaddleInferStringTensorCreate
)
.
def
(
"share_external_data_bind"
,
&
PaddleInferShareExternalData
)
.
def
(
"copy_to_cpu"
,
&
PaddleInferTensorToNumpy
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录