Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
b6bf8994
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b6bf8994
编写于
6月 23, 2022
作者:
C
ccrrong
提交者:
GitHub
6月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cast trt converter (#43447)
* add cast trt converter
上级
8902a414
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
331 addition
and
67 deletion
+331
-67
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+94
-49
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/convert/cast_op.cc
paddle/fluid/inference/tensorrt/convert/cast_op.cc
+69
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+47
-18
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py
...uid/tests/unittests/ir/inference/test_trt_convert_cast.py
+120
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
b6bf8994
...
@@ -104,7 +104,8 @@ bool IsPersistable(const framework::VarDesc *var) {
...
@@ -104,7 +104,8 @@ bool IsPersistable(const framework::VarDesc *var) {
}
}
}
// namespace
}
// namespace
bool
PaddleTensorToLoDTensor
(
const
PaddleTensor
&
pt
,
framework
::
LoDTensor
*
t
,
bool
PaddleTensorToLoDTensor
(
const
PaddleTensor
&
pt
,
framework
::
LoDTensor
*
t
,
const
platform
::
Place
&
place
)
{
const
platform
::
Place
&
place
)
{
framework
::
DDim
ddim
=
phi
::
make_ddim
(
pt
.
shape
);
framework
::
DDim
ddim
=
phi
::
make_ddim
(
pt
.
shape
);
void
*
input_ptr
;
void
*
input_ptr
;
...
@@ -132,18 +133,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
...
@@ -132,18 +133,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
if
(
platform
::
is_cpu_place
(
place
))
{
if
(
platform
::
is_cpu_place
(
place
))
{
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
std
::
memcpy
(
pt
.
data
.
length
());
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
pt
.
data
.
length
());
}
else
if
(
platform
::
is_ipu_place
(
place
))
{
}
else
if
(
platform
::
is_ipu_place
(
place
))
{
#ifdef PADDLE_WITH_IPU
#ifdef PADDLE_WITH_IPU
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
std
::
memcpy
(
pt
.
data
.
length
());
static_cast
<
void
*>
(
input_ptr
),
pt
.
data
.
data
(),
pt
.
data
.
length
());
#else
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"Not compile with WITH_IPU, should not reach here."
));
"Not compile with WITH_IPU, should not reach here."
));
#endif
#endif
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
place
),
false
,
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
place
),
false
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Only one choice can be made between CPU and XPU."
));
"Only one choice can be made between CPU and XPU."
));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...
@@ -151,8 +153,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
...
@@ -151,8 +153,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
auto
*
dev_ctx
=
auto
*
dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
pool
.
Get
(
place
));
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
pool
.
Get
(
place
));
auto
dst_gpu_place
=
place
;
auto
dst_gpu_place
=
place
;
memory
::
Copy
(
dst_gpu_place
,
static_cast
<
void
*>
(
input_ptr
),
memory
::
Copy
(
dst_gpu_place
,
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
(),
static_cast
<
void
*>
(
input_ptr
),
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
(),
dev_ctx
->
stream
());
dev_ctx
->
stream
());
#else
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
...
@@ -161,8 +166,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
...
@@ -161,8 +166,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
}
else
if
(
platform
::
is_xpu_place
(
place
))
{
}
else
if
(
platform
::
is_xpu_place
(
place
))
{
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
auto
dst_xpu_place
=
place
;
auto
dst_xpu_place
=
place
;
memory
::
Copy
(
dst_xpu_place
,
static_cast
<
void
*>
(
input_ptr
),
memory
::
Copy
(
dst_xpu_place
,
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
());
static_cast
<
void
*>
(
input_ptr
),
platform
::
CPUPlace
(),
pt
.
data
.
data
(),
pt
.
data
.
length
());
#else
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"Not compile with XPU, should not reach here."
));
"Not compile with XPU, should not reach here."
));
...
@@ -245,7 +253,8 @@ bool AnalysisPredictor::Init(
...
@@ -245,7 +253,8 @@ bool AnalysisPredictor::Init(
void
AnalysisPredictor
::
InitPlace
()
{
void
AnalysisPredictor
::
InitPlace
()
{
if
(
config_
.
use_gpu
())
{
if
(
config_
.
use_gpu
())
{
PADDLE_ENFORCE_EQ
(
config_
.
use_xpu
(),
false
,
PADDLE_ENFORCE_EQ
(
config_
.
use_xpu
(),
false
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Only one choice can be made between CPU and XPU."
));
"Only one choice can be made between CPU and XPU."
));
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
gpu_device_id
());
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
gpu_device_id
());
...
@@ -502,7 +511,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
...
@@ -502,7 +511,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
}
}
static
void
DisablePrepareDataOpt
(
static
void
DisablePrepareDataOpt
(
std
::
shared_ptr
<
framework
::
ProgramDesc
>
inference_program
,
int
block
,
std
::
shared_ptr
<
framework
::
ProgramDesc
>
inference_program
,
int
block
,
bool
pre_disable_opt
)
{
bool
pre_disable_opt
)
{
bool
disable_opt
=
false
;
bool
disable_opt
=
false
;
auto
&
infer_block
=
inference_program
->
Block
(
block
);
auto
&
infer_block
=
inference_program
->
Block
(
block
);
...
@@ -512,8 +522,8 @@ static void DisablePrepareDataOpt(
...
@@ -512,8 +522,8 @@ static void DisablePrepareDataOpt(
}
}
if
(
op
->
HasAttr
(
"sub_block"
))
{
if
(
op
->
HasAttr
(
"sub_block"
))
{
int
blockID
=
op
->
GetBlockAttrId
(
"sub_block"
);
int
blockID
=
op
->
GetBlockAttrId
(
"sub_block"
);
DisablePrepareDataOpt
(
inference_program
,
blockID
,
DisablePrepareDataOpt
(
disable_opt
||
pre_disable_opt
);
inference_program
,
blockID
,
disable_opt
||
pre_disable_opt
);
}
}
// disable prepare data if unfriendly op is found
// disable prepare data if unfriendly op is found
if
(
!
disable_opt
)
{
if
(
!
disable_opt
)
{
...
@@ -531,8 +541,8 @@ bool AnalysisPredictor::PrepareExecutor() {
...
@@ -531,8 +541,8 @@ bool AnalysisPredictor::PrepareExecutor() {
#endif
#endif
DisablePrepareDataOpt
(
inference_program_
,
0
,
false
);
DisablePrepareDataOpt
(
inference_program_
,
0
,
false
);
executor_
->
Prepare
(
sub_scope_
,
*
inference_program_
,
0
,
executor_
->
Prepare
(
config_
.
use_feed_fetch_ops_
);
sub_scope_
,
*
inference_program_
,
0
,
config_
.
use_feed_fetch_ops_
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scope_
,
PADDLE_ENFORCE_NOT_NULL
(
sub_scope_
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
...
@@ -578,8 +588,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
...
@@ -578,8 +588,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
feed_fetch_vars
.
emplace_back
(
pair
.
second
);
feed_fetch_vars
.
emplace_back
(
pair
.
second
);
}
}
fleet_exe_
->
Init
(
config_
.
dist_config
().
carrier_id
(),
fleet_exe_
->
Init
(
config_
.
dist_config
().
carrier_id
(),
*
(
inference_program_
.
get
()),
scope_
.
get
(),
place_
,
1
,
*
(
inference_program_
.
get
()),
{
task_node_
.
get
()},
id_to_rank
,
feed_fetch_vars
);
scope_
.
get
(),
place_
,
1
,
{
task_node_
.
get
()},
id_to_rank
,
feed_fetch_vars
);
return
true
;
return
true
;
}
}
...
@@ -616,8 +631,12 @@ bool AnalysisPredictor::CommInit() {
...
@@ -616,8 +631,12 @@ bool AnalysisPredictor::CommInit() {
peer_endpoints
.
emplace_back
(
peer_endpoints
.
emplace_back
(
config_
.
dist_config
().
trainer_endpoints
()[
rank
]);
config_
.
dist_config
().
trainer_endpoints
()[
rank
]);
}
}
InsertCommOp
(
var_name_base
+
std
::
to_string
(
order
),
ranks_in_group
,
InsertCommOp
(
var_name_base
+
std
::
to_string
(
order
),
rank_in_group
,
peer_endpoints
,
comm_init_block
,
ring_id
);
ranks_in_group
,
rank_in_group
,
peer_endpoints
,
comm_init_block
,
ring_id
);
order
+=
1
;
order
+=
1
;
}
}
framework
::
NaiveExecutor
e
(
place_
);
framework
::
NaiveExecutor
e
(
place_
);
...
@@ -629,8 +648,11 @@ bool AnalysisPredictor::CommInit() {
...
@@ -629,8 +648,11 @@ bool AnalysisPredictor::CommInit() {
}
}
void
AnalysisPredictor
::
InsertCommOp
(
void
AnalysisPredictor
::
InsertCommOp
(
std
::
string
tmp_var_name
,
int
nranks
,
int
rank
,
std
::
string
tmp_var_name
,
const
std
::
vector
<
std
::
string
>
&
peer_endpoints
,
framework
::
BlockDesc
*
block
,
int
nranks
,
int
rank
,
const
std
::
vector
<
std
::
string
>
&
peer_endpoints
,
framework
::
BlockDesc
*
block
,
int
ring_id
)
{
int
ring_id
)
{
/*
/*
* tmp_var_name: the var name for var comm_id
* tmp_var_name: the var name for var comm_id
...
@@ -687,7 +709,8 @@ bool AnalysisPredictor::LoadConverterConfig(
...
@@ -687,7 +709,8 @@ bool AnalysisPredictor::LoadConverterConfig(
<<
config_
.
dist_config
().
comm_init_config
()
<<
"
\n
"
;
<<
config_
.
dist_config
().
comm_init_config
()
<<
"
\n
"
;
std
::
ifstream
fin
(
config_
.
dist_config
().
comm_init_config
(),
std
::
ios
::
in
);
std
::
ifstream
fin
(
config_
.
dist_config
().
comm_init_config
(),
std
::
ios
::
in
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Cannot open file %s, please confirm whether the file is normal."
,
"Cannot open file %s, please confirm whether the file is normal."
,
config_
.
dist_config
().
comm_init_config
()));
config_
.
dist_config
().
comm_init_config
()));
...
@@ -831,8 +854,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
...
@@ -831,8 +854,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
timer
.
tic
();
timer
.
tic
();
// set feed variable
// set feed variable
framework
::
Scope
*
scope
=
sub_scope_
?
sub_scope_
:
scope_
.
get
();
framework
::
Scope
*
scope
=
sub_scope_
?
sub_scope_
:
scope_
.
get
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
"The scope should not be nullptr."
));
scope
,
platform
::
errors
::
PreconditionNotMet
(
"The scope should not be nullptr."
));
if
(
!
SetFeed
(
inputs
,
scope
))
{
if
(
!
SetFeed
(
inputs
,
scope
))
{
LOG
(
ERROR
)
<<
"fail to set feed"
;
LOG
(
ERROR
)
<<
"fail to set feed"
;
return
false
;
return
false
;
...
@@ -935,9 +959,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
...
@@ -935,9 +959,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
for
(
size_t
i
=
0
;
i
<
fetches_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
fetches_
.
size
();
++
i
)
{
int
idx
=
BOOST_GET_CONST
(
int
,
fetches_
[
i
]
->
GetAttr
(
"col"
));
int
idx
=
BOOST_GET_CONST
(
int
,
fetches_
[
i
]
->
GetAttr
(
"col"
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
static_cast
<
size_t
>
(
idx
),
i
,
static_cast
<
size_t
>
(
idx
),
i
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Fetch op's col attr(%d) should be equal to the index(%d)"
,
idx
,
"Fetch op's col attr(%d) should be equal to the index(%d)"
,
idx
,
i
));
i
));
framework
::
FetchType
&
fetch_var
=
framework
::
FetchType
&
fetch_var
=
framework
::
GetFetchVariable
(
*
scope
,
"fetch"
,
idx
);
framework
::
GetFetchVariable
(
*
scope
,
"fetch"
,
idx
);
...
@@ -978,7 +1004,8 @@ void AnalysisPredictor::PrepareArgument() {
...
@@ -978,7 +1004,8 @@ void AnalysisPredictor::PrepareArgument() {
if
(
!
config_
.
model_dir
().
empty
())
{
if
(
!
config_
.
model_dir
().
empty
())
{
argument_
.
SetModelDir
(
config_
.
model_dir
());
argument_
.
SetModelDir
(
config_
.
model_dir
());
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
config_
.
prog_file
().
empty
(),
false
,
PADDLE_ENFORCE_EQ
(
config_
.
prog_file
().
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Either model_dir or prog_file should be set."
));
"Either model_dir or prog_file should be set."
));
std
::
string
dir
=
inference
::
analysis
::
GetDirRoot
(
config_
.
prog_file
());
std
::
string
dir
=
inference
::
analysis
::
GetDirRoot
(
config_
.
prog_file
());
...
@@ -1123,7 +1150,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
...
@@ -1123,7 +1150,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
Analyzer
().
Run
(
&
argument_
);
Analyzer
().
Run
(
&
argument_
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
argument_
.
scope_valid
(),
true
,
argument_
.
scope_valid
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The argument scope should be valid."
));
platform
::
errors
::
InvalidArgument
(
"The argument scope should be valid."
));
VLOG
(
5
)
<<
"to prepare executor"
;
VLOG
(
5
)
<<
"to prepare executor"
;
ARGUMENT_CHECK_FIELD
((
&
argument_
),
ir_analyzed_program
);
ARGUMENT_CHECK_FIELD
((
&
argument_
),
ir_analyzed_program
);
...
@@ -1173,7 +1201,8 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
...
@@ -1173,7 +1201,8 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
}
}
VLOG
(
3
)
<<
"create AnalysisConfig"
;
VLOG
(
3
)
<<
"create AnalysisConfig"
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
config
.
is_valid
(),
true
,
config
.
is_valid
(),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Note: Each config can only be used for one predictor."
));
"Note: Each config can only be used for one predictor."
));
...
@@ -1190,11 +1219,13 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
...
@@ -1190,11 +1219,13 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
std
::
call_once
(
gflags_initialized
,
[
&
]()
{
std
::
call_once
(
gflags_initialized
,
[
&
]()
{
std
::
vector
<
std
::
string
>
gflags
;
std
::
vector
<
std
::
string
>
gflags
;
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
config
.
memory_pool_init_size_mb
(),
0.
f
,
config
.
memory_pool_init_size_mb
(),
0.
f
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The size of memory pool should be greater than 0."
));
"The size of memory pool should be greater than 0."
));
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
config
.
gpu_device_id
(),
0
,
config
.
gpu_device_id
(),
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Invalid device id (%d). The device id should be greater than 0."
,
"Invalid device id (%d). The device id should be greater than 0."
,
config
.
gpu_device_id
()));
config
.
gpu_device_id
()));
...
@@ -1303,8 +1334,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
...
@@ -1303,8 +1334,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
}
}
void
AnalysisPredictor
::
CreateFeedFetchVar
(
framework
::
Scope
*
scope
)
{
void
AnalysisPredictor
::
CreateFeedFetchVar
(
framework
::
Scope
*
scope
)
{
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_NOT_NULL
(
"The scope should not be nullptr."
));
scope
,
platform
::
errors
::
InvalidArgument
(
"The scope should not be nullptr."
));
auto
*
var
=
scope
->
Var
(
"feed"
);
auto
*
var
=
scope
->
Var
(
"feed"
);
var
->
GetMutable
<
framework
::
FeedList
>
();
var
->
GetMutable
<
framework
::
FeedList
>
();
var
=
scope
->
Var
(
"fetch"
);
var
=
scope
->
Var
(
"fetch"
);
...
@@ -1325,8 +1357,9 @@ AnalysisPredictor::GetInputTensorShape() {
...
@@ -1325,8 +1357,9 @@ AnalysisPredictor::GetInputTensorShape() {
std
::
vector
<
std
::
string
>
names
=
GetInputNames
();
std
::
vector
<
std
::
string
>
names
=
GetInputNames
();
for
(
std
::
string
name
:
names
)
{
for
(
std
::
string
name
:
names
)
{
auto
*
var
=
inference_program_
->
Block
(
0
).
FindVar
(
name
);
auto
*
var
=
inference_program_
->
Block
(
0
).
FindVar
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
"Input %s does not exist."
,
name
));
var
,
platform
::
errors
::
PreconditionNotMet
(
"Input %s does not exist."
,
name
));
input_shapes
[
name
]
=
var
->
GetShape
();
input_shapes
[
name
]
=
var
->
GetShape
();
}
}
return
input_shapes
;
return
input_shapes
;
...
@@ -1565,7 +1598,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
...
@@ -1565,7 +1598,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
std
::
vector
<
std
::
pair
<
int32_t
,
int32_t
>>
counter
;
std
::
vector
<
std
::
pair
<
int32_t
,
int32_t
>>
counter
;
for
(
auto
&
it
:
m
)
counter
.
push_back
(
it
);
for
(
auto
&
it
:
m
)
counter
.
push_back
(
it
);
std
::
sort
(
std
::
sort
(
counter
.
begin
(),
counter
.
end
(),
counter
.
begin
(),
counter
.
end
(),
[](
std
::
pair
<
int32_t
,
int32_t
>
&
a
,
std
::
pair
<
int32_t
,
int32_t
>
&
b
)
{
[](
std
::
pair
<
int32_t
,
int32_t
>
&
a
,
std
::
pair
<
int32_t
,
int32_t
>
&
b
)
{
return
a
.
second
>
b
.
second
;
return
a
.
second
>
b
.
second
;
});
});
...
@@ -1587,8 +1621,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
...
@@ -1587,8 +1621,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
opt_shapes
[
name
]
=
opt_shape
;
opt_shapes
[
name
]
=
opt_shape
;
}
}
inference
::
SerializeShapeRangeInfo
(
config_
.
shape_range_info_path
(),
inference
::
SerializeShapeRangeInfo
(
min_shapes
,
max_shapes
,
opt_shapes
);
config_
.
shape_range_info_path
(),
min_shapes
,
max_shapes
,
opt_shapes
);
}
}
bool
AnalysisPredictor
::
LoadProgramDesc
()
{
bool
AnalysisPredictor
::
LoadProgramDesc
()
{
...
@@ -1608,7 +1642,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
...
@@ -1608,7 +1642,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
return
false
;
return
false
;
}
}
LOG
(
ERROR
)
<<
string
::
Sprintf
(
LOG
(
ERROR
)
<<
string
::
Sprintf
(
"not valid model path '%s' or program path '%s'."
,
config_
.
model_dir
(),
"not valid model path '%s' or program path '%s'."
,
config_
.
model_dir
(),
config_
.
params_file
());
config_
.
params_file
());
return
false
;
return
false
;
}
}
...
@@ -1620,7 +1655,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
...
@@ -1620,7 +1655,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
// Read binary
// Read binary
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
static_cast
<
bool
>
(
fin
.
is_open
()),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Cannot open file %s, please confirm whether the file is normal."
,
"Cannot open file %s, please confirm whether the file is normal."
,
filename
));
filename
));
...
@@ -1722,7 +1758,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
...
@@ -1722,7 +1758,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
#if PADDLE_WITH_TENSORRT
#if PADDLE_WITH_TENSORRT
bool
AnalysisPredictor
::
SaveTrtCalibToDisk
()
{
bool
AnalysisPredictor
::
SaveTrtCalibToDisk
()
{
PADDLE_ENFORCE_EQ
(
config_
.
tensorrt_engine_enabled
(),
true
,
PADDLE_ENFORCE_EQ
(
config_
.
tensorrt_engine_enabled
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"This func can be invoked only in trt mode"
));
"This func can be invoked only in trt mode"
));
auto
&
block
=
inference_program_
->
Block
(
0
);
auto
&
block
=
inference_program_
->
Block
(
0
);
...
@@ -1963,6 +2000,7 @@ USE_TRT_CONVERTER(c_allreduce_sum)
...
@@ -1963,6 +2000,7 @@ USE_TRT_CONVERTER(c_allreduce_sum)
USE_TRT_CONVERTER
(
roll
)
USE_TRT_CONVERTER
(
roll
)
USE_TRT_CONVERTER
(
strided_slice
)
USE_TRT_CONVERTER
(
strided_slice
)
USE_TRT_CONVERTER
(
transformer_input_convert
)
USE_TRT_CONVERTER
(
transformer_input_convert
)
USE_TRT_CONVERTER
(
cast
)
USE_TRT_CONVERTER
(
recover_padding
)
USE_TRT_CONVERTER
(
recover_padding
)
USE_TRT_CONVERTER
(
remove_padding
)
USE_TRT_CONVERTER
(
remove_padding
)
USE_TRT_CONVERTER
(
top_k
)
USE_TRT_CONVERTER
(
top_k
)
...
@@ -1990,8 +2028,10 @@ Predictor::Predictor(const Config &config) {
...
@@ -1990,8 +2028,10 @@ Predictor::Predictor(const Config &config) {
<<
"Paddle2ONNX do't support convert the Model, fall back to using "
<<
"Paddle2ONNX do't support convert the Model, fall back to using "
"Paddle Inference."
;
"Paddle Inference."
;
}
else
{
}
else
{
predictor_
=
paddle
::
CreatePaddlePredictor
<
predictor_
=
Config
,
paddle
::
PaddleEngineKind
::
kONNXRuntime
>
(
config
);
paddle
::
CreatePaddlePredictor
<
Config
,
paddle
::
PaddleEngineKind
::
kONNXRuntime
>
(
config
);
return
;
return
;
}
}
#else
#else
...
@@ -2001,8 +2041,10 @@ Predictor::Predictor(const Config &config) {
...
@@ -2001,8 +2041,10 @@ Predictor::Predictor(const Config &config) {
"fall back to using Paddle Inference."
;
"fall back to using Paddle Inference."
;
#endif
#endif
}
}
predictor_
=
paddle
::
CreatePaddlePredictor
<
predictor_
=
Config
,
paddle
::
PaddleEngineKind
::
kAnalysis
>
(
config
);
paddle
::
CreatePaddlePredictor
<
Config
,
paddle
::
PaddleEngineKind
::
kAnalysis
>
(
config
);
}
}
std
::
vector
<
std
::
string
>
Predictor
::
GetInputNames
()
{
std
::
vector
<
std
::
string
>
Predictor
::
GetInputNames
()
{
...
@@ -2086,7 +2128,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
...
@@ -2086,7 +2128,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
namespace
services
{
namespace
services
{
PredictorPool
::
PredictorPool
(
const
Config
&
config
,
size_t
size
)
{
PredictorPool
::
PredictorPool
(
const
Config
&
config
,
size_t
size
)
{
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
size
,
1UL
,
size
,
1UL
,
paddle
::
platform
::
errors
::
InvalidArgument
(
paddle
::
platform
::
errors
::
InvalidArgument
(
"The predictor pool size should be greater than 1, but it's (%d)"
,
"The predictor pool size should be greater than 1, but it's (%d)"
,
size
));
size
));
...
@@ -2105,9 +2148,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
...
@@ -2105,9 +2148,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
Predictor
*
PredictorPool
::
Retrive
(
size_t
idx
)
{
Predictor
*
PredictorPool
::
Retrive
(
size_t
idx
)
{
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
idx
,
preds_
.
size
()
+
1
,
idx
,
preds_
.
size
()
+
1
,
paddle
::
platform
::
errors
::
InvalidArgument
(
paddle
::
platform
::
errors
::
InvalidArgument
(
"There are (%d) predictors in the pool, but the idx is (%d)"
,
idx
,
"There are (%d) predictors in the pool, but the idx is (%d)"
,
idx
,
preds_
.
size
()
+
1
));
preds_
.
size
()
+
1
));
if
(
idx
==
0
)
{
if
(
idx
==
0
)
{
return
main_pred_
.
get
();
return
main_pred_
.
get
();
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
b6bf8994
...
@@ -60,6 +60,7 @@ list(
...
@@ -60,6 +60,7 @@ list(
preln_skip_layernorm.cc
preln_skip_layernorm.cc
roll_op.cc
roll_op.cc
transformer_input_convert_op.cc
transformer_input_convert_op.cc
cast_op.cc
remove_padding_op.cc
remove_padding_op.cc
recover_padding_op.cc
recover_padding_op.cc
preln_residual_bias.cc
preln_residual_bias.cc
...
...
paddle/fluid/inference/tensorrt/convert/cast_op.cc
0 → 100644
浏览文件 @
b6bf8994
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
proto
{
class
OpDesc
;
}
// namespace proto
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
CastOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a cast op to tensorrt"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
out_dtype
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"out_dtype"
));
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Identity
,
*
input
);
switch
(
out_dtype
)
{
case
2
:
// INT32 = 2
layer
->
getOutput
(
0
)
->
setType
(
nvinfer1
::
DataType
::
kINT32
);
break
;
case
4
:
// FP16 = 4
layer
->
getOutput
(
0
)
->
setType
(
nvinfer1
::
DataType
::
kHALF
);
break
;
case
5
:
// FP32 = 5
layer
->
getOutput
(
0
)
->
setType
(
nvinfer1
::
DataType
::
kFLOAT
);
break
;
default:
LOG
(
ERROR
)
<<
"Unable to convert a fluid data type("
<<
out_dtype
<<
") to a nvinfer DataType"
;
break
;
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"cast"
,
{
output_name
},
test_mode
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
cast
,
CastOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
b6bf8994
...
@@ -55,7 +55,8 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -55,7 +55,8 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif
#endif
}
}
bool
operator
()(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
,
bool
operator
()(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
,
bool
use_no_calib_int8
)
override
{
bool
use_no_calib_int8
)
override
{
if
(
use_no_calib_int8
)
{
if
(
use_no_calib_int8
)
{
return
int8_teller_set
.
count
(
op_type
);
return
int8_teller_set
.
count
(
op_type
);
...
@@ -162,6 +163,7 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -162,6 +163,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"c_allreduce_max"
,
"c_allreduce_max"
,
"c_allreduce_prod"
,
"c_allreduce_prod"
,
"roll"
,
"roll"
,
"cast"
,
"preln_skip_layernorm"
,
"preln_skip_layernorm"
,
"transformer_input_convert"
,
"transformer_input_convert"
,
"recover_padding"
,
"recover_padding"
,
...
@@ -265,6 +267,7 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -265,6 +267,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"c_allreduce_max"
,
"c_allreduce_max"
,
"c_allreduce_prod"
,
"c_allreduce_prod"
,
"roll"
,
"roll"
,
"cast"
,
"multiclass_nms3"
,
"multiclass_nms3"
,
"transformer_input_convert"
,
"transformer_input_convert"
,
"recover_padding"
,
"recover_padding"
,
...
@@ -273,7 +276,8 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -273,7 +276,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"unsqueeze2"
};
"unsqueeze2"
};
};
};
bool
OpTeller
::
Tell
(
const
framework
::
ir
::
Node
*
node
,
bool
use_no_calib_int8
,
bool
OpTeller
::
Tell
(
const
framework
::
ir
::
Node
*
node
,
bool
use_no_calib_int8
,
bool
with_dynamic_shape
)
{
bool
with_dynamic_shape
)
{
const
std
::
string
op_type
=
node
->
Op
()
->
Type
();
const
std
::
string
op_type
=
node
->
Op
()
->
Type
();
const
framework
::
OpDesc
desc
=
*
node
->
Op
();
const
framework
::
OpDesc
desc
=
*
node
->
Op
();
...
@@ -818,8 +822,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -818,8 +822,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if
(
op_type
==
"nearest_interp"
)
{
if
(
op_type
==
"nearest_interp"
)
{
std
::
vector
<
std
::
string
>
attrs
{
"interp_method"
,
"align_corners"
,
"scale"
,
std
::
vector
<
std
::
string
>
attrs
{
"out_h"
,
"out_w"
};
"interp_method"
,
"align_corners"
,
"scale"
,
"out_h"
,
"out_w"
};
for
(
auto
const
attr
:
attrs
)
{
for
(
auto
const
attr
:
attrs
)
{
if
(
!
desc
.
HasAttr
(
attr
))
return
false
;
if
(
!
desc
.
HasAttr
(
attr
))
return
false
;
}
}
...
@@ -859,9 +863,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -859,9 +863,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if
(
op_type
==
"nearest_interp_v2"
)
{
if
(
op_type
==
"nearest_interp_v2"
)
{
std
::
vector
<
std
::
string
>
attrs
{
"data_layout"
,
"interp_method"
,
std
::
vector
<
std
::
string
>
attrs
{
"data_layout"
,
"align_corners"
,
"scale"
,
"interp_method"
,
"out_h"
,
"out_w"
};
"align_corners"
,
"scale"
,
"out_h"
,
"out_w"
};
for
(
auto
const
attr
:
attrs
)
{
for
(
auto
const
attr
:
attrs
)
{
if
(
!
desc
.
HasAttr
(
attr
))
return
false
;
if
(
!
desc
.
HasAttr
(
attr
))
return
false
;
}
}
...
@@ -887,9 +894,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -887,9 +894,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if
(
op_type
==
"bilinear_interp_v2"
)
{
if
(
op_type
==
"bilinear_interp_v2"
)
{
std
::
vector
<
std
::
string
>
attrs
{
"data_layout"
,
"interp_method"
,
std
::
vector
<
std
::
string
>
attrs
{
"data_layout"
,
"align_corners"
,
"scale"
,
"interp_method"
,
"out_h"
,
"out_w"
};
"align_corners"
,
"scale"
,
"out_h"
,
"out_w"
};
for
(
auto
const
attr
:
attrs
)
{
for
(
auto
const
attr
:
attrs
)
{
if
(
!
desc
.
HasAttr
(
attr
))
{
if
(
!
desc
.
HasAttr
(
attr
))
{
VLOG
(
3
)
<<
"The op_type "
<<
op_type
<<
" doesn't have the attr "
VLOG
(
3
)
<<
"The op_type "
<<
op_type
<<
" doesn't have the attr "
...
@@ -1032,8 +1042,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -1032,8 +1042,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if
(
op_type
==
"batch_norm"
)
{
if
(
op_type
==
"batch_norm"
)
{
const
std
::
vector
<
std
::
string
>
bn_inputs
=
{
"X"
,
"Bias"
,
"Mean"
,
"Scale"
,
const
std
::
vector
<
std
::
string
>
bn_inputs
=
{
"Variance"
};
"X"
,
"Bias"
,
"Mean"
,
"Scale"
,
"Variance"
};
for
(
unsigned
int
i
=
0
;
i
<
bn_inputs
.
size
();
i
++
)
{
for
(
unsigned
int
i
=
0
;
i
<
bn_inputs
.
size
();
i
++
)
{
if
(
desc
.
Input
(
bn_inputs
[
i
]).
size
()
!=
1
)
{
if
(
desc
.
Input
(
bn_inputs
[
i
]).
size
()
!=
1
)
{
VLOG
(
3
)
<<
"Invalid "
<<
bn_inputs
[
i
]
VLOG
(
3
)
<<
"Invalid "
<<
bn_inputs
[
i
]
...
@@ -1585,8 +1595,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -1585,8 +1595,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
"the roi_align will change the batch size."
;
"the roi_align will change the batch size."
;
return
false
;
return
false
;
}
}
std
::
vector
<
std
::
string
>
attrs
{
"pooled_height"
,
"pooled_width"
,
std
::
vector
<
std
::
string
>
attrs
{
"pooled_height"
,
"spatial_scale"
,
"sampling_ratio"
,
"pooled_width"
,
"spatial_scale"
,
"sampling_ratio"
,
"aligned"
};
"aligned"
};
for
(
auto
const
attr
:
attrs
)
{
for
(
auto
const
attr
:
attrs
)
{
if
(
!
desc
.
HasAttr
(
attr
))
return
false
;
if
(
!
desc
.
HasAttr
(
attr
))
return
false
;
...
@@ -1771,10 +1783,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -1771,10 +1783,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
auto
*
x_var_desc
=
block
->
FindVar
(
x_var_name
);
auto
*
x_var_desc
=
block
->
FindVar
(
x_var_name
);
const
auto
x_shape
=
x_var_desc
->
GetShape
();
const
auto
x_shape
=
x_var_desc
->
GetShape
();
int
input_num
=
std
::
accumulate
(
x_shape
.
begin
()
+
1
,
x_shape
.
end
(),
1
,
int
input_num
=
std
::
accumulate
(
std
::
multiplies
<
int
>
());
x_shape
.
begin
()
+
1
,
x_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
int
shape_num
=
std
::
accumulate
(
shape
.
begin
()
+
1
,
shape
.
end
(),
1
,
int
shape_num
=
std
::
accumulate
(
std
::
multiplies
<
int
>
());
shape
.
begin
()
+
1
,
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
if
(
input_num
==
shape_num
)
{
if
(
input_num
==
shape_num
)
{
return
true
;
return
true
;
}
}
...
@@ -1960,6 +1972,23 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -1960,6 +1972,23 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
}
}
if
(
op_type
==
"cast"
)
{
int
in_dtype
=
BOOST_GET_CONST
(
int
,
desc
.
GetAttr
(
"in_dtype"
));
int
out_dtype
=
BOOST_GET_CONST
(
int
,
desc
.
GetAttr
(
"out_dtype"
));
if
((
in_dtype
==
4
||
in_dtype
==
5
)
&&
out_dtype
==
4
)
{
VLOG
(
3
)
<<
"unsupport data type conversion"
;
return
false
;
}
if
(
!
((
in_dtype
==
5
||
in_dtype
==
4
||
in_dtype
==
2
||
in_dtype
==
0
)
&&
(
out_dtype
==
5
||
out_dtype
==
4
||
out_dtype
==
2
)))
{
VLOG
(
3
)
<<
"only valid conversions are: "
"(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)"
;
return
false
;
}
}
if
(
op_type
==
"top_k_v2"
||
op_type
==
"top_k"
)
{
if
(
op_type
==
"top_k_v2"
||
op_type
==
"top_k"
)
{
auto
*
block
=
desc
.
Block
();
auto
*
block
=
desc
.
Block
();
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py
0 → 100644
浏览文件 @
b6bf8994
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
trt_layer_auto_scan_test
import
TrtLayerAutoScanTest
,
SkipReasons
from
program_config
import
TensorConfig
,
ProgramConfig
import
unittest
import
numpy
as
np
import
paddle.inference
as
paddle_infer
from
functools
import
partial
from
typing
import
Optional
,
List
,
Callable
,
Dict
,
Any
,
Set
class
TrtConvertCastTest
(
TrtLayerAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
if
attrs
[
0
][
'in_dtype'
]
in
[
4
,
5
]
and
attrs
[
0
][
'out_dtype'
]
==
4
:
return
False
if
attrs
[
0
][
'in_dtype'
]
not
in
[
0
,
2
,
4
,
5
]
or
attrs
[
0
][
'out_dtype'
]
not
in
[
2
,
4
,
5
]:
return
False
return
True
def
sample_program_configs
(
self
):
def
generate_input
(
type
):
if
type
==
0
:
return
np
.
ones
([
1
,
3
,
64
,
64
]).
astype
(
np
.
bool
)
elif
type
==
2
:
return
np
.
ones
([
1
,
3
,
64
,
64
]).
astype
(
np
.
int32
)
elif
type
==
4
:
return
np
.
ones
([
1
,
3
,
64
,
64
]).
astype
(
np
.
float16
)
else
:
return
np
.
ones
([
1
,
3
,
64
,
64
]).
astype
(
np
.
float32
)
for
in_dtype
in
[
0
,
2
,
4
,
5
,
6
]:
for
out_dtype
in
[
0
,
2
,
4
,
5
,
6
]:
dics
=
[{
"in_dtype"
:
in_dtype
,
"out_dtype"
:
out_dtype
}]
ops_config
=
[{
"op_type"
:
"cast"
,
"op_inputs"
:
{
"X"
:
[
"input_data"
]
},
"op_outputs"
:
{
"Out"
:
[
"cast_output_data"
]
},
"op_attrs"
:
dics
[
0
]
}]
ops
=
self
.
generate_op_config
(
ops_config
)
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"input_data"
:
TensorConfig
(
data_gen
=
partial
(
generate_input
,
in_dtype
))
},
outputs
=
[
"cast_output_data"
])
yield
program_config
def
sample_predictor_configs
(
self
,
program_config
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
3
,
64
,
64
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
4
,
3
,
64
,
64
]}
self
.
dynamic_shape
.
opt_input_shape
=
{
"input_data"
:
[
1
,
3
,
64
,
64
]}
def
clear_dynamic_shape
():
self
.
dynamic_shape
.
min_input_shape
=
{}
self
.
dynamic_shape
.
max_input_shape
=
{}
self
.
dynamic_shape
.
opt_input_shape
=
{}
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
return
1
,
2
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
# for static_shape
clear_dynamic_shape
()
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
False
),
1e-5
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
False
),
1e-2
# for dynamic_shape
generate_dynamic_shape
(
attrs
)
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
True
),
1e-5
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
True
),
1e-2
def
test
(
self
):
self
.
run_test
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录