PaddlePaddle / PaddleDetection · Commit 9b6a0296
Authored Apr 02, 2019 by Wojciech Uss; committed by Tao Luo, Apr 02, 2019
fix dataset reading and add support for full dataset (#16559)
Parent: 220190d5
Showing 19 changed files with 193 additions and 131 deletions (+193 −131)
paddle/fluid/inference/api/helper.h (+11 −10)
paddle/fluid/inference/tests/api/CMakeLists.txt (+9 −5)
paddle/fluid/inference/tests/api/analyzer_bert_tester.cc (+1 −1)
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc (+5 −3)
paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc (+29 −24)
paddle/fluid/inference/tests/api/analyzer_lac_tester.cc (+6 −4)
paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc (+4 −3)
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc (+6 −4)
paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc (+6 −4)
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc (+1 −1)
paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc (+2 −2)
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc (+5 −3)
paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc (+6 −4)
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc (+1 −1)
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc (+4 −3)
paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc (+1 −1)
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc (+3 −2)
paddle/fluid/inference/tests/api/tester_helper.h (+92 −55)
paddle/fluid/inference/tests/api/trt_models_tester.cc (+1 −1)
paddle/fluid/inference/api/helper.h

@@ -27,6 +27,7 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/port.h"
 #include "paddle/fluid/string/printf.h"

@@ -266,17 +267,17 @@ static std::string DescribeZeroCopyTensor(const ZeroCopyTensor &tensor) {
 }

-static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
-                      double latency, int epoch = 1) {
-  LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat
-            << ", threads: " << num_threads << ", thread id: " << tid
-            << ", latency: " << latency << "ms, fps: "
-            << 1 / (latency / 1000.f) << " ======";
-  if (epoch > 1) {
-    int samples = batch_size * epoch;
-    LOG(INFO) << "====== sample number: " << samples
-              << ", average latency of each sample: " << latency / samples
-              << "ms ======";
-  }
+static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
+                      double batch_latency, int epoch = 1) {
+  PADDLE_ENFORCE(batch_size > 0, "Non-positive batch size.");
+  double sample_latency = batch_latency / batch_size;
+  LOG(INFO) << "====== threads: " << num_threads << ", thread id: " << tid
+            << " ======";
+  LOG(INFO) << "====== batch_size: " << batch_size << ", iterations: " << epoch
+            << ", repetitions: " << repeat << " ======";
+  LOG(INFO) << "====== batch latency: " << batch_latency
+            << "ms, number of samples: " << batch_size * epoch
+            << ", sample latency: " << sample_latency
+            << "ms, fps: " << 1000.f / sample_latency << " ======";
 }

 static bool IsFileExists(const std::string &path) {
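Note on the arithmetic above: PrintTime now receives the average latency of one batch and derives per-sample latency and throughput from it. A minimal standalone sketch of the same computation (the values are illustrative, not taken from the tests):

    #include <iostream>

    int main() {
      int batch_size = 50;           // FLAGS_batch_size in the int8 tests
      double batch_latency = 125.0;  // ms per batch: elapsed_time / (iterations * repeat)
      double sample_latency = batch_latency / batch_size;  // ms per sample
      double fps = 1000.0 / sample_latency;                // samples per second
      std::cout << "sample latency: " << sample_latency << " ms, fps: " << fps
                << "\n";
      return 0;
    }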
paddle/fluid/inference/tests/api/CMakeLists.txt

@@ -26,7 +26,11 @@ endfunction()
 function(inference_analysis_api_int8_test target model_dir data_dir filename)
   inference_analysis_test(${target} SRCS ${filename}
       EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
-      ARGS --infer_model=${model_dir}/model --infer_data=${data_dir}/data.bin --batch_size=100)
+      ARGS --infer_model=${model_dir}/model
+           --infer_data=${data_dir}/data.bin
+           --warmup_batch_size=100
+           --batch_size=50
+           --iterations=2)
 endfunction()

 function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)

@@ -146,22 +150,22 @@ inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_con
 # int8 image classification tests
 if(WITH_MKLDNN)
-  set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8")
+  set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
   if (NOT EXISTS ${INT8_DATA_DIR})
-    inference_download_and_uncompress(${INT8_DATA_DIR} ${INFERENCE_URL} "/int8" "imagenet_val_100.tar.gz")
+    inference_download_and_uncompress(${INT8_DATA_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
   endif()

   #resnet50 int8
   set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
   if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR})
-    inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} ${INFERENCE_URL} "/int8" "resnet50_int8_model.tar.gz")
+    inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "${INFERENCE_URL}/int8" "resnet50_int8_model.tar.gz")
   endif()
   inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL)

   #mobilenet int8
   set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet")
   if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR})
-    inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} ${INFERENCE_URL} "/int8" "mobilenetv1_int8_model.tar.gz")
+    inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenetv1_int8_model.tar.gz")
   endif()
   inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL)
 endif()
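For context: the ARGS above are consumed by the test binaries through gflags, via the DEFINE_* declarations in tester_helper.h further down. A simplified, hedged sketch of that pattern (not the full tester; only the flag names match the diff):

    #include <gflags/gflags.h>
    #include <iostream>

    // Same flag names as in tester_helper.h below; the defaults here are
    // illustrative.
    DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup");
    DEFINE_int32(batch_size, 1, "batch size");
    DEFINE_int32(iterations, 0, "number of batches to process (0 = whole dataset)");

    int main(int argc, char **argv) {
      gflags::ParseCommandLineFlags(&argc, &argv, true);
      std::cout << "warmup_batch_size=" << FLAGS_warmup_batch_size
                << " batch_size=" << FLAGS_batch_size
                << " iterations=" << FLAGS_iterations << "\n";
      return 0;
    }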
paddle/fluid/inference/tests/api/analyzer_bert_tester.cc

@@ -154,7 +154,7 @@ void profile(bool use_mkldnn = false) {
     config.EnableMKLDNN();
   }

-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;
   std::vector<std::vector<PaddleTensor>> inputs;
   LoadInputData(&inputs);
   TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&config),
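The one-line change above recurs in nearly every tester in this commit: outputs becomes a vector of per-batch output vectors, so a full-dataset run keeps the outputs of every processed batch, and assertions move from outputs[0] to outputs.back()[0]. A shape-only sketch (PaddleTensor is stubbed out here purely for illustration):

    #include <vector>

    struct PaddleTensor {};  // stand-in for paddle::PaddleTensor, shape only

    int main() {
      // Before: one flat vector, overwritten on every run.
      std::vector<PaddleTensor> outputs_old;
      // After: one inner vector per processed batch.
      std::vector<std::vector<PaddleTensor>> outputs;
      outputs.resize(3);  // e.g. iterations == 3
      const std::vector<PaddleTensor> &last_batch = outputs.back();
      (void)outputs_old;
      (void)last_batch;  // tests now index into outputs.back()
      return 0;
    }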
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc

@@ -197,7 +197,7 @@ void profile(bool use_mkldnn = false) {
     cfg.SetMKLDNNOp(op_list);
   }

-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -206,9 +206,11 @@ void profile(bool use_mkldnn = false) {
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     PADDLE_ENFORCE_GT(outputs.size(), 0);
-    size_t size = GetSize(outputs[0]);
+    auto output = outputs.back();
+    PADDLE_ENFORCE_GT(output.size(), 0);
+    size_t size = GetSize(output[0]);
     PADDLE_ENFORCE_GT(size, 0);
-    float *result = static_cast<float *>(outputs[0].data.data());
+    float *result = static_cast<float *>(output[0].data.data());
     for (size_t i = 0; i < size; i++) {
       EXPECT_NEAR(result[i], result_data[i], 1e-3);
     }
paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc

@@ -17,8 +17,6 @@ limitations under the License. */
 #include "paddle/fluid/inference/api/paddle_analysis_config.h"
 #include "paddle/fluid/inference/tests/api/tester_helper.h"

-DEFINE_int32(iterations, 0, "Number of iterations");
-
 namespace paddle {
 namespace inference {
 namespace analysis {

@@ -30,8 +28,13 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->SwitchIrOptim();
   cfg->SwitchSpecifyInputNames(false);
   cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
   cfg->EnableMKLDNN();
+  cfg->pass_builder()->SetPasses(
+      {"infer_clean_graph_pass", "mkldnn_placement_pass",
+       "depthwise_conv_mkldnn_pass", "conv_bn_fuse_pass",
+       "conv_eltwiseadd_bn_fuse_pass", "conv_bias_mkldnn_fuse_pass",
+       "conv_elementwise_add_mkldnn_fuse_pass", "conv_relu_mkldnn_fuse_pass",
+       "fc_fuse_pass", "is_test_pass"});
 }

@@ -40,8 +43,8 @@ class TensorReader {
   TensorReader(std::ifstream &file, size_t beginning_offset,
                std::vector<int> shape, std::string name)
       : file_(file), position(beginning_offset), shape_(shape), name_(name) {
-    numel =
-        std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<T>());
+    numel = std::accumulate(shape_.begin(), shape_.end(), size_t{1},
+                            std::multiplies<size_t>());
   }

   PaddleTensor NextBatch() {

@@ -71,10 +74,14 @@ class TensorReader {
 };

 std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
-    const std::vector<std::vector<PaddleTensor>> &test_data, int num_images) {
+    const std::vector<std::vector<PaddleTensor>> &test_data,
+    int num_images = FLAGS_warmup_batch_size) {
   int test_data_batch_size = test_data[0][0].shape[0];
-  CHECK_LE(static_cast<size_t>(num_images),
-           test_data.size() * test_data_batch_size);
+  auto iterations_max = test_data.size();
+  PADDLE_ENFORCE(
+      static_cast<size_t>(num_images) <= iterations_max * test_data_batch_size,
+      "The requested quantization warmup data size " +
+          std::to_string(num_images) + " is bigger than all test data size.");

   PaddleTensor images;
   images.name = "input";

@@ -120,20 +127,17 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
   std::vector<int> image_batch_shape{batch_size, 3, 224, 224};
   std::vector<int> label_batch_shape{batch_size, 1};
+  auto images_offset_in_file = static_cast<size_t>(file.tellg());
   auto labels_offset_in_file =
-      static_cast<size_t>(file.tellg()) +
-      sizeof(float) * total_images * 3 * 224 * 224;
+      images_offset_in_file +
+      sizeof(float) * total_images *
+          std::accumulate(image_batch_shape.begin() + 1,
+                          image_batch_shape.end(), 1, std::multiplies<int>());

-  TensorReader<float> image_reader(file, 0, image_batch_shape, "input");
+  TensorReader<float> image_reader(file, images_offset_in_file,
+                                   image_batch_shape, "input");
   TensorReader<int64_t> label_reader(file, labels_offset_in_file,
                                      label_batch_shape, "label");

-  auto iterations = total_images / batch_size;
-  if (FLAGS_iterations > 0 && FLAGS_iterations < iterations)
-    iterations = FLAGS_iterations;
-  for (auto i = 0; i < iterations; i++) {
+  auto iterations_max = total_images / batch_size;
+  for (auto i = 0; i < iterations_max; i++) {
     auto images = image_reader.NextBatch();
     auto labels = label_reader.NextBatch();
     inputs->emplace_back(
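The offset fix above relies on data.bin packing all float images first and all int64 labels second: the image reader now starts at the current read position instead of 0, and the label offset is derived from the size of the image section. A sketch of that layout arithmetic (the function name is illustrative, not part of the tester):

    #include <cstddef>

    // data.bin layout: [header][total_images float images][total_images int64 labels]
    size_t LabelsOffset(size_t images_offset_in_file, size_t total_images) {
      const size_t image_numel = 3 * 224 * 224;  // one 3x224x224 float image
      return images_offset_in_file + sizeof(float) * total_images * image_numel;
    }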
@@ -148,20 +152,21 @@ TEST(Analyzer_int8_resnet50, quantization) {
   AnalysisConfig q_cfg;
   SetConfig(&q_cfg);

+  // read data from file and prepare batches with test data
   std::vector<std::vector<PaddleTensor>> input_slots_all;
-  SetInput(&input_slots_all, 100);
+  SetInput(&input_slots_all);

+  // prepare warmup batch from input data read earlier
+  // warmup batch size can be different than batch size
   std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
-      GetWarmupData(input_slots_all, 100);
+      GetWarmupData(input_slots_all);

+  // configure quantizer
   q_cfg.EnableMkldnnQuantizer();
   q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
-  q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100);
+  q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size);

-  CompareQuantizedAndAnalysis(
-      reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
-      reinterpret_cast<const PaddlePredictor::Config *>(&q_cfg),
-      input_slots_all);
+  CompareQuantizedAndAnalysis(&cfg, &q_cfg, input_slots_all);
 }

 }  // namespace analysis
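GetWarmupData now defaults its size to FLAGS_warmup_batch_size and rejects requests larger than the loaded test data. A sketch of just that guard, mirroring the PADDLE_ENFORCE in the diff (names are illustrative):

    #include <stdexcept>
    #include <string>

    // Reject a warmup request that exceeds the total loaded test data.
    void CheckWarmupFits(size_t num_images, size_t iterations_max,
                         size_t test_data_batch_size) {
      if (num_images > iterations_max * test_data_batch_size)
        throw std::invalid_argument(
            "The requested quantization warmup data size " +
            std::to_string(num_images) + " is bigger than all test data size.");
    }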
paddle/fluid/inference/tests/api/analyzer_lac_tester.cc

@@ -124,7 +124,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 TEST(Analyzer_LAC, profile) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -137,11 +137,13 @@ TEST(Analyzer_LAC, profile) {
       24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, 25, 25, 25, 25,
       44, 24, 25, 25, 25, 36, 42, 43, 44, 14, 15, 44, 14, 15, 44, 14,
       15, 44, 38, 39, 14, 15, 44, 22, 23, 23, 23, 23, 23, 23, 23};
-  PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
-  size_t size = GetSize(outputs[0]);
+  PADDLE_ENFORCE_GT(outputs.size(), 0);
+  auto output = outputs.back();
+  PADDLE_ENFORCE_EQ(output.size(), 1UL);
+  size_t size = GetSize(output[0]);
   size_t batch1_size = sizeof(lac_ref_data) / sizeof(int64_t);
   PADDLE_ENFORCE_GE(size, batch1_size);
-  int64_t *pdata = static_cast<int64_t *>(outputs[0].data.data());
+  int64_t *pdata = static_cast<int64_t *>(output[0].data.data());
   for (size_t i = 0; i < batch1_size; ++i) {
     EXPECT_EQ(pdata[i], lac_ref_data[i]);
   }
paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc

@@ -96,7 +96,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 void profile(bool use_mkldnn = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;
   if (use_mkldnn) {
     cfg.EnableMKLDNN();

@@ -108,8 +108,9 @@ void profile(bool use_mkldnn = false) {
                  input_slots_all, &outputs, FLAGS_num_threads);

   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
-    PADDLE_ENFORCE_EQ(outputs.size(), 2UL);
-    for (auto &output : outputs) {
+    PADDLE_ENFORCE_GT(outputs.size(), 0);
+    PADDLE_ENFORCE_EQ(outputs.back().size(), 2UL);
+    for (auto &output : outputs.back()) {
       size_t size = GetSize(output);
       PADDLE_ENFORCE_GT(size, 0);
       float *result = static_cast<float *>(output.data.data());
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc

@@ -106,7 +106,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 void profile(bool memory_load = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg, memory_load);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -117,10 +117,12 @@ void profile(bool memory_load = false) {
   // the first inference result
   const int chinese_ner_result_data[] = {30, 45, 41, 48, 17, 26,
                                          48, 39, 38, 16, 25};
-  PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
-  size_t size = GetSize(outputs[0]);
+  PADDLE_ENFORCE_GT(outputs.size(), 0);
+  auto output = outputs.back();
+  PADDLE_ENFORCE_EQ(output.size(), 1UL);
+  size_t size = GetSize(output[0]);
   PADDLE_ENFORCE_GT(size, 0);
-  int64_t *result = static_cast<int64_t *>(outputs[0].data.data());
+  int64_t *result = static_cast<int64_t *>(output[0].data.data());
   for (size_t i = 0; i < std::min(11UL, size); i++) {
     EXPECT_EQ(result[i], chinese_ner_result_data[i]);
   }
paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc

@@ -127,7 +127,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 TEST(Analyzer_Pyramid_DNN, profile) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -135,10 +135,12 @@ TEST(Analyzer_Pyramid_DNN, profile) {
                  input_slots_all, &outputs, FLAGS_num_threads);

   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data && !FLAGS_zero_copy) {
-    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
-    size_t size = GetSize(outputs[0]);
+    PADDLE_ENFORCE_GT(outputs.size(), 0);
+    auto output = outputs.back();
+    PADDLE_ENFORCE_EQ(output.size(), 1UL);
+    size_t size = GetSize(output[0]);
     PADDLE_ENFORCE_GT(size, 0);
-    float *result = static_cast<float *>(outputs[0].data.data());
+    float *result = static_cast<float *>(output[0].data.data());
     // output is probability, which is in (0, 1).
     for (size_t i = 0; i < size; i++) {
       EXPECT_GT(result[i], 0);
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc

@@ -40,7 +40,7 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
   }
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc

@@ -229,7 +229,7 @@ TEST(Analyzer_rnn1, profile) {
   SetConfig(&cfg);
   cfg.DisableGpu();
   cfg.SwitchIrDebug();
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -280,7 +280,7 @@ TEST(Analyzer_rnn1, compare_determine) {
 TEST(Analyzer_rnn1, multi_thread) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc

@@ -126,7 +126,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 TEST(Analyzer_rnn2, profile) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -136,9 +136,11 @@ TEST(Analyzer_rnn2, profile) {
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     // the first inference result
     PADDLE_ENFORCE_GT(outputs.size(), 0);
-    size_t size = GetSize(outputs[0]);
+    auto output = outputs.back();
+    PADDLE_ENFORCE_GT(output.size(), 0);
+    size_t size = GetSize(output[0]);
     PADDLE_ENFORCE_GT(size, 0);
-    float *result = static_cast<float *>(outputs[0].data.data());
+    float *result = static_cast<float *>(output[0].data.data());
     for (size_t i = 0; i < size; i++) {
       EXPECT_NEAR(result[i], result_data[i], 1e-3);
     }
paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc

@@ -110,7 +110,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 TEST(Analyzer_seq_conv1, profile) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -119,10 +119,12 @@ TEST(Analyzer_seq_conv1, profile) {
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     // the first inference result
-    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
-    size_t size = GetSize(outputs[0]);
+    PADDLE_ENFORCE_GT(outputs.size(), 0);
+    auto output = outputs.back();
+    PADDLE_ENFORCE_EQ(output.size(), 1UL);
+    size_t size = GetSize(output[0]);
     PADDLE_ENFORCE_GT(size, 0);
-    float *result = static_cast<float *>(outputs[0].data.data());
+    float *result = static_cast<float *>(output[0].data.data());
     // output is probability, which is in (0, 1).
     for (size_t i = 0; i < size; i++) {
       EXPECT_GT(result[i], 0);
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc

@@ -156,7 +156,7 @@ void profile(bool use_mkldnn = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg, use_mkldnn);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
   TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc

@@ -70,7 +70,7 @@ TEST(Analyzer_Text_Classification, profile) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
   cfg.SwitchIrDebug();
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -79,8 +79,9 @@ TEST(Analyzer_Text_Classification, profile) {
   if (FLAGS_num_threads == 1) {
     // Get output
-    LOG(INFO) << "get outputs " << outputs.size();
-    for (auto &output : outputs) {
+    PADDLE_ENFORCE_GT(outputs.size(), 0);
+    LOG(INFO) << "get outputs " << outputs.back().size();
+    for (auto &output : outputs.back()) {
       LOG(INFO) << "output.shape: " << to_string(output.shape);
       // no lod ?
       CHECK_EQ(output.lod.size(), 0UL);
paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc

@@ -186,7 +186,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 void profile(bool use_mkldnn = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
   }
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc

@@ -87,7 +87,7 @@ void profile(bool use_mkldnn = false) {
     cfg.EnableMKLDNN();
   }
   // cfg.pass_builder()->TurnOnDebug();
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;

   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

@@ -100,7 +100,8 @@ void profile(bool use_mkldnn = false) {
   auto refer = ProcessALine(line);
   file.close();

-  auto &output = outputs.front();
+  PADDLE_ENFORCE_GT(outputs.size(), 0);
+  auto &output = outputs.back().front();
   size_t numel = output.data.length() / PaddleDtypeSize(output.dtype);
   CHECK_EQ(numel, refer.data.size());
   for (size_t i = 0; i < numel; ++i) {
paddle/fluid/inference/tests/api/tester_helper.h

@@ -41,7 +41,10 @@ DEFINE_string(model_name, "", "model name");
 DEFINE_string(infer_model, "", "model path");
 DEFINE_string(infer_data, "", "data file");
 DEFINE_string(refer_result, "", "reference result for comparison");
-DEFINE_int32(batch_size, 1, "batch size.");
+DEFINE_int32(batch_size, 1, "batch size");
+DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup");
+// setting iterations to 0 means processing the whole dataset
+DEFINE_int32(iterations, 0, "number of batches to process");
 DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
 DEFINE_bool(test_all_data, false, "Test the all dataset in data file.");
 DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
@@ -239,7 +242,7 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
   }
   input.shape = shape;
   input.dtype = PaddleDType::FLOAT32;
-  size_t len = std::accumulate(shape.begin(), shape.end(), 1,
+  size_t len = std::accumulate(shape.begin(), shape.end(), size_t{1},
                                [](int a, int b) { return a * b; });
   input.data.Resize(len * sizeof(float));
   input.lod.assign({{0, static_cast<size_t>(FLAGS_batch_size)}});
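The seed change above (1 to size_t{1}) matters because std::accumulate's accumulator takes the seed's type: with an int seed the whole product is computed in int, which cannot hold large tensor sizes. A self-contained sketch with an illustrative shape (mirrors the TensorReader fix earlier in this commit):

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<int> shape{1024, 3, 1024, 1024};  // illustrative large shape
      // Seeding with size_t{1} makes the accumulator a size_t, so the product
      // (3221225472 here) does not have to fit in a 32-bit int.
      size_t numel = std::accumulate(shape.begin(), shape.end(), size_t{1},
                                     std::multiplies<size_t>());
      std::cout << "numel = " << numel << "\n";
      return 0;
    }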
@@ -286,17 +289,18 @@ void ConvertPaddleTensorToZeroCopyTensor(
 void PredictionWarmUp(PaddlePredictor *predictor,
                       const std::vector<std::vector<PaddleTensor>> &inputs,
-                      std::vector<PaddleTensor> *outputs, int num_threads,
-                      int tid) {
+                      std::vector<std::vector<PaddleTensor>> *outputs,
+                      int num_threads, int tid) {
   int batch_size = FLAGS_batch_size;
   LOG(INFO) << "Running thread " << tid << ", warm up run...";
   if (FLAGS_zero_copy) {
     ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]);
   }
+  outputs->resize(1);
   Timer warmup_timer;
   warmup_timer.tic();
   if (!FLAGS_zero_copy) {
-    predictor->Run(inputs[0], outputs, batch_size);
+    predictor->Run(inputs[0], &(*outputs)[0], batch_size);
   } else {
     predictor->ZeroCopyRun();
   }
@@ -308,11 +312,16 @@ void PredictionWarmUp(PaddlePredictor *predictor,
 void PredictionRun(PaddlePredictor *predictor,
                    const std::vector<std::vector<PaddleTensor>> &inputs,
-                   std::vector<PaddleTensor> *outputs, int num_threads,
-                   int tid) {
-  int batch_size = FLAGS_batch_size;
+                   std::vector<std::vector<PaddleTensor>> *outputs,
+                   int num_threads, int tid) {
   int num_times = FLAGS_repeat;
-  LOG(INFO) << "Thread " << tid << " run " << num_times << " times...";
+  int iterations = inputs.size();  // process the whole dataset ...
+  if (FLAGS_iterations > 0 && FLAGS_iterations < inputs.size())
+    iterations = FLAGS_iterations;  // ... unless the number of iterations is set
+  outputs->resize(iterations);
+  LOG(INFO) << "Thread " << tid << ", number of threads " << num_threads
+            << ", run " << num_times << " times...";
   Timer run_timer;
   double elapsed_time = 0;
 #ifdef WITH_GPERFTOOLS
@@ -320,14 +329,14 @@ void PredictionRun(PaddlePredictor *predictor,
 #endif
   if (!FLAGS_zero_copy) {
     run_timer.tic();
-    for (size_t i = 0; i < inputs.size(); i++) {
+    for (size_t i = 0; i < iterations; i++) {
       for (int j = 0; j < num_times; j++) {
-        predictor->Run(inputs[i], outputs, batch_size);
+        predictor->Run(inputs[i], &(*outputs)[i], FLAGS_batch_size);
       }
     }
     elapsed_time = run_timer.toc();
   } else {
-    for (size_t i = 0; i < inputs.size(); i++) {
+    for (size_t i = 0; i < iterations; i++) {
       ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]);
       run_timer.tic();
       for (int j = 0; j < num_times; j++) {
@@ -340,13 +349,14 @@ void PredictionRun(PaddlePredictor *predictor,
   ProfilerStop();
 #endif

-  PrintTime(batch_size, num_times, num_threads, tid, elapsed_time / num_times,
-            inputs.size());
+  auto batch_latency = elapsed_time / (iterations * num_times);
+  PrintTime(FLAGS_batch_size, num_times, num_threads, tid, batch_latency,
+            iterations);
   if (FLAGS_record_benchmark) {
     Benchmark benchmark;
     benchmark.SetName(FLAGS_model_name);
-    benchmark.SetBatchSize(batch_size);
-    benchmark.SetLatency(elapsed_time / num_times);
+    benchmark.SetBatchSize(FLAGS_batch_size);
+    benchmark.SetLatency(batch_latency);
     benchmark.PersistToFile("benchmark_record.txt");
   }
 }
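Taken together, the PredictionRun changes implement the "full dataset" behavior from the commit title: every loaded batch is processed unless --iterations caps the count, and the reported latency is normalized per batch. A sketch of the iteration-count rule in isolation (the function name is illustrative, not part of the tester):

    #include <cstddef>

    // Default: process the whole dataset; a positive --iterations smaller
    // than the dataset caps the number of batches.
    size_t EffectiveIterations(size_t loaded_batches, int flag_iterations) {
      size_t iterations = loaded_batches;
      if (flag_iterations > 0 &&
          static_cast<size_t>(flag_iterations) < loaded_batches)
        iterations = flag_iterations;
      return iterations;
    }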
@@ -354,16 +364,17 @@ void PredictionRun(PaddlePredictor *predictor,
 void TestOneThreadPrediction(
     const PaddlePredictor::Config *config,
     const std::vector<std::vector<PaddleTensor>> &inputs,
-    std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
+    std::vector<std::vector<PaddleTensor>> *outputs,
+    bool use_analysis = true) {
   auto predictor = CreateTestPredictor(config, use_analysis);
-  PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0);
-  PredictionRun(predictor.get(), inputs, outputs, 1, 0);
+  PredictionWarmUp(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads,
+                   0);
+  PredictionRun(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads, 0);
 }

 void TestMultiThreadPrediction(
     const PaddlePredictor::Config *config,
     const std::vector<std::vector<PaddleTensor>> &inputs,
-    std::vector<PaddleTensor> *outputs, int num_threads,
+    std::vector<std::vector<PaddleTensor>> *outputs, int num_threads,
     bool use_analysis = true) {
   std::vector<std::thread> threads;
   std::vector<std::unique_ptr<PaddlePredictor>> predictors;

@@ -376,7 +387,7 @@ void TestMultiThreadPrediction(
     threads.emplace_back([&, tid]() {
       // Each thread should have local inputs and outputs.
       // The inputs of each thread are all the same.
-      std::vector<PaddleTensor> outputs_tid;
+      std::vector<std::vector<PaddleTensor>> outputs_tid;
       auto &predictor = predictors[tid];
 #ifdef PADDLE_WITH_MKLDNN
       if (use_analysis) {

@@ -384,8 +395,8 @@ void TestMultiThreadPrediction(
             ->SetMkldnnThreadID(static_cast<int>(tid) + 1);
       }
 #endif
-      PredictionWarmUp(predictor.get(), inputs, outputs, num_threads, tid);
-      PredictionRun(predictor.get(), inputs, outputs, num_threads, tid);
+      PredictionWarmUp(predictor.get(), inputs, &outputs_tid, num_threads, tid);
+      PredictionRun(predictor.get(), inputs, &outputs_tid, num_threads, tid);
     });
   }
   for (int i = 0; i < num_threads; ++i) {
@@ -395,8 +406,8 @@ void TestMultiThreadPrediction(
 void TestPrediction(const PaddlePredictor::Config *config,
                     const std::vector<std::vector<PaddleTensor>> &inputs,
-                    std::vector<PaddleTensor> *outputs, int num_threads,
-                    bool use_analysis = FLAGS_use_analysis) {
+                    std::vector<std::vector<PaddleTensor>> *outputs,
+                    int num_threads, bool use_analysis = FLAGS_use_analysis) {
   PrintConfig(config, use_analysis);
   if (num_threads == 1) {
     TestOneThreadPrediction(config, inputs, outputs, use_analysis);
@@ -406,30 +417,41 @@ void TestPrediction(const PaddlePredictor::Config *config,
   }
 }

-void CompareTopAccuracy(const std::vector<PaddleTensor> &output_slots1,
-                        const std::vector<PaddleTensor> &output_slots2) {
-  // first output: avg_cost
-  if (output_slots1.size() == 0 || output_slots2.size() == 0)
+void CompareTopAccuracy(
+    const std::vector<std::vector<PaddleTensor>> &output_slots_quant,
+    const std::vector<std::vector<PaddleTensor>> &output_slots_ref) {
+  if (output_slots_quant.size() == 0 || output_slots_ref.size() == 0)
     throw std::invalid_argument(
         "CompareTopAccuracy: output_slots vector is empty.");
-  PADDLE_ENFORCE(output_slots1.size() >= 2UL);
-  PADDLE_ENFORCE(output_slots2.size() >= 2UL);

-  // second output: acc_top1
-  if (output_slots1[1].lod.size() > 0 || output_slots2[1].lod.size() > 0)
-    throw std::invalid_argument(
-        "CompareTopAccuracy: top1 accuracy output has nonempty LoD.");
-  if (output_slots1[1].dtype != paddle::PaddleDType::FLOAT32 ||
-      output_slots2[1].dtype != paddle::PaddleDType::FLOAT32)
-    throw std::invalid_argument(
-        "CompareTopAccuracy: top1 accuracy output is of a wrong type.");
-  float *top1_quantized = static_cast<float *>(output_slots1[1].data.data());
-  float *top1_reference = static_cast<float *>(output_slots2[1].data.data());
-  LOG(INFO) << "top1 INT8 accuracy: " << *top1_quantized;
-  LOG(INFO) << "top1 FP32 accuracy: " << *top1_reference;
+  float total_accs1_quant{0};
+  float total_accs1_ref{0};
+  for (size_t i = 0; i < output_slots_quant.size(); ++i) {
+    PADDLE_ENFORCE(output_slots_quant[i].size() >= 2UL);
+    PADDLE_ENFORCE(output_slots_ref[i].size() >= 2UL);
+    // second output: acc_top1
+    if (output_slots_quant[i][1].lod.size() > 0 ||
+        output_slots_ref[i][1].lod.size() > 0)
+      throw std::invalid_argument(
+          "CompareTopAccuracy: top1 accuracy output has nonempty LoD.");
+    if (output_slots_quant[i][1].dtype != paddle::PaddleDType::FLOAT32 ||
+        output_slots_ref[i][1].dtype != paddle::PaddleDType::FLOAT32)
+      throw std::invalid_argument(
+          "CompareTopAccuracy: top1 accuracy output is of a wrong type.");
+    total_accs1_quant +=
+        *static_cast<float *>(output_slots_quant[i][1].data.data());
+    total_accs1_ref +=
+        *static_cast<float *>(output_slots_ref[i][1].data.data());
+  }
+  float avg_acc1_quant = total_accs1_quant / output_slots_quant.size();
+  float avg_acc1_ref = total_accs1_ref / output_slots_ref.size();

+  LOG(INFO) << "Avg top1 INT8 accuracy: " << std::fixed << std::setw(6)
+            << std::setprecision(4) << avg_acc1_quant;
+  LOG(INFO) << "Avg top1 FP32 accuracy: " << std::fixed << std::setw(6)
+            << std::setprecision(4) << avg_acc1_ref;
   LOG(INFO) << "Accepted accuracy drop threshold: " << FLAGS_quantized_accuracy;
-  CHECK_LE(std::abs(*top1_quantized - *top1_reference),
-           FLAGS_quantized_accuracy);
+  CHECK_LE(std::abs(avg_acc1_quant - avg_acc1_ref), FLAGS_quantized_accuracy);
 }

 void CompareDeterministic(
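With per-batch outputs available, CompareTopAccuracy switches from comparing a single batch's top-1 accuracy to comparing averages over all processed batches. The averaging step in isolation (a hedged sketch, not the Paddle function):

    #include <vector>

    // Average per-batch top-1 accuracies (one acc_top1 value per batch),
    // as CompareTopAccuracy now does before checking the INT8/FP32 gap.
    float AverageTop1(const std::vector<float> &per_batch_acc) {
      float total = 0.0f;
      for (float acc : per_batch_acc) total += acc;
      return total / per_batch_acc.size();
    }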
@@ -455,20 +477,35 @@ void CompareNativeAndAnalysis(
     const PaddlePredictor::Config *config,
     const std::vector<std::vector<PaddleTensor>> &inputs) {
   PrintConfig(config, true);
-  std::vector<PaddleTensor> native_outputs, analysis_outputs;
+  std::vector<std::vector<PaddleTensor>> native_outputs, analysis_outputs;
   TestOneThreadPrediction(config, inputs, &native_outputs, false);
   TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
-  CompareResult(analysis_outputs, native_outputs);
+  PADDLE_ENFORCE(native_outputs.size() > 0, "Native output is empty.");
+  PADDLE_ENFORCE(analysis_outputs.size() > 0, "Analysis output is empty.");
+  CompareResult(analysis_outputs.back(), native_outputs.back());
 }

 void CompareQuantizedAndAnalysis(
-    const PaddlePredictor::Config *config,
-    const PaddlePredictor::Config *qconfig,
+    const AnalysisConfig *config, const AnalysisConfig *qconfig,
     const std::vector<std::vector<PaddleTensor>> &inputs) {
-  PrintConfig(config, true);
-  std::vector<PaddleTensor> analysis_outputs, quantized_outputs;
-  TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
-  TestOneThreadPrediction(qconfig, inputs, &quantized_outputs, true);
+  PADDLE_ENFORCE_EQ(inputs[0][0].shape[0], FLAGS_batch_size,
+                    "Input data has to be packed batch by batch.");
+  LOG(INFO) << "FP32 & INT8 prediction run: batch_size " << FLAGS_batch_size
+            << ", warmup batch size " << FLAGS_warmup_batch_size << ".";
+
+  LOG(INFO) << "--- FP32 prediction start ---";
+  auto *cfg = reinterpret_cast<const PaddlePredictor::Config *>(config);
+  PrintConfig(cfg, true);
+  std::vector<std::vector<PaddleTensor>> analysis_outputs;
+  TestOneThreadPrediction(cfg, inputs, &analysis_outputs, true);
+
+  LOG(INFO) << "--- INT8 prediction start ---";
+  auto *qcfg = reinterpret_cast<const PaddlePredictor::Config *>(qconfig);
+  PrintConfig(qcfg, true);
+  std::vector<std::vector<PaddleTensor>> quantized_outputs;
+  TestOneThreadPrediction(qcfg, inputs, &quantized_outputs, true);
+
+  LOG(INFO) << "--- comparing outputs --- ";
   CompareTopAccuracy(quantized_outputs, analysis_outputs);
 }
@@ -578,9 +615,9 @@ static bool CompareTensorData(const framework::LoDTensor &a,
                               const framework::LoDTensor &b) {
   auto a_shape = framework::vectorize(a.dims());
   auto b_shape = framework::vectorize(b.dims());
-  size_t a_size = std::accumulate(a_shape.begin(), a_shape.end(), 1,
+  size_t a_size = std::accumulate(a_shape.begin(), a_shape.end(), size_t{1},
                                   [](int a, int b) { return a * b; });
-  size_t b_size = std::accumulate(b_shape.begin(), b_shape.end(), 1,
+  size_t b_size = std::accumulate(b_shape.begin(), b_shape.end(), size_t{1},
                                   [](int a, int b) { return a * b; });
   if (a_size != b_size) {
     LOG(ERROR) << string::Sprintf("tensor data size not match, %d != %d",
paddle/fluid/inference/tests/api/trt_models_tester.cc

@@ -74,7 +74,7 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
     SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
   }

-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;
   if (use_analysis || use_tensorrt) {
     AnalysisConfig config;
     config.EnableUseGpu(100, 0);