Commit 4c0f0f6f authored by lidanqing, committed by GitHub

mkldnn quant aware demo and document (#198)

* Add mkldnn quantization demo and document
Parent 676ab9b2
CMAKE_MINIMUM_REQUIRED(VERSION 3.2)
project(mkldnn_quantaware_demo CXX C)
set(DEMO_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(DEMO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
option(USE_GPU "Compile the inference code with CUDA GPU support" OFF)
option(USE_PROFILER "Whether to enable Paddle's profiler." OFF)
set(USE_SHARED OFF)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
if(NOT PADDLE_ROOT)
set(PADDLE_ROOT ${DEMO_SOURCE_DIR}/fluid_inference)
endif()
find_package(Fluid)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -std=c++11")
if(USE_PROFILER)
find_package(Gperftools REQUIRED)
include_directories(${GPERFTOOLS_INCLUDE_DIR})
add_definitions(-DWITH_GPERFTOOLS)
endif()
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if(PADDLE_FOUND)
add_executable(inference sample_tester.cc)
target_link_libraries(inference
${PADDLE_LIBRARIES}
${PADDLE_THIRD_PARTY_LIBRARIES}
rt dl pthread)
if (mklml_FOUND)
target_link_libraries(inference "-L${THIRD_PARTY_ROOT}/install/mklml/lib -liomp5 -Wl,--as-needed")
endif()
else()
message(FATAL_ERROR "Cannot find PaddlePaddle Fluid under ${PADDLE_ROOT}")
endif()
# Optimized CPU Deployment and Inference of INT8 Image Classification Models
## Overview
This document describes how to convert, deploy, and run quantized models produced by PaddleSlim on CPU. On an Intel(R) Xeon(R) Gold 6271 machine, the quantized INT8 model runs 3-4x faster than the optimized FP32 model, with only a minimal drop in accuracy.
The workflow consists of the following steps:
- Produce the quantized model: train a quantized model with PaddleSlim. Note that the model weights should lie within the INT8 range but are still stored as floats.
- Convert the quantized model on CPU: use DNNL on CPU to convert the quantized model into a true INT8 model.
- Deploy and run on CPU: deploy the demo application on CPU and run inference.
## 1. Preparation
#### Install and build PaddleSlim
To install PaddleSlim, follow the [official installation guide](https://paddlepaddle.github.io/PaddleSlim/install.html):
```
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python setup.py install
```
#### Usage in code
In your own test script, import Paddle and PaddleSlim as follows:
```
import paddle
import paddle.fluid as fluid
import paddleslim as slim
import numpy as np
```
## 2. Produce a quantized model with PaddleSlim
Use PaddleSlim to produce either a quant-aware trained model or a post-training (offline) quantized model.
#### 2.1 Quant-aware training
For the quant-aware training workflow, see the [quant-aware training tutorial for classification models](https://paddlepaddle.github.io/PaddleSlim/tutorials/quant_aware_demo/).
**Note the following config parameters for quant-aware training** (a minimal sketch follows this list):
- **quantize_op_types:** On CPU, quantization currently supports `depthwise_conv2d`, `mul`, `conv2d`, `matmul`, `transpose2`, `reshape2`, `pool2d`, and `scale`. During training, however, fake quantize/dequantize ops only need to be inserted around the first four op types; `matmul`, `transpose2`, `reshape2`, and `pool2d` keep the same scale between input and output and take their scales from the neighbouring ops, so `quantize_op_types` only needs to contain `depthwise_conv2d`, `mul`, `conv2d`, and `matmul`.
- **Other parameters:** see the [PaddleSlim quant_aware API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/#quant_aware).
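The sketch below shows one way these config parameters might be passed to `paddleslim.quant.quant_aware`; the toy network, the quantize-type values, and the variable names are illustrative assumptions rather than part of this demo, so treat it as a rough outline and consult the API docs above for the authoritative parameter list.
```
import paddle.fluid as fluid
import paddleslim as slim

# Build a toy program only so that there is something to transform (illustrative).
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    image = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
    conv = fluid.layers.conv2d(image, num_filters=8, filter_size=3)
    fc = fluid.layers.fc(conv, size=10)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)

config = {
    # Only these four op types need fake quantize/dequantize ops during training;
    # matmul, transpose2, reshape2 and pool2d reuse the scales of neighbouring ops.
    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d', 'matmul'],
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
}
quant_program = slim.quant.quant_aware(train_program, place, config, for_test=False)
```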
#### 2.2 Post-training (offline) quantization
To produce a post-training quantized model, see the [post-training quantization tutorial for classification models](https://paddlepaddle.github.io/PaddleSlim/tutorials/quant_post_demo/#_1). A rough sketch of the call is shown below.
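As a minimal sketch, a post-training quantization call might look like the following; the paths, the random calibration reader, and the batch settings are placeholders, and the exact `quant_post` arguments should be checked against the PaddleSlim version you use.
```
import numpy as np
import paddle.fluid as fluid
import paddleslim as slim

place = fluid.CPUPlace()
exe = fluid.Executor(place)

def sample_generator():
    # Placeholder calibration reader: yields single samples shaped like the model input.
    # Replace it with a reader over real calibration images.
    for _ in range(320):
        yield (np.random.random([3, 224, 224]).astype('float32'),)

slim.quant.quant_post(
    executor=exe,
    model_dir='/PATH/TO/FLOAT32/MODEL',              # saved FP32 inference model
    quantize_model_path='/PATH/TO/SAVE/QUANT/MODEL', # where the quantized model is written
    sample_generator=sample_generator,
    batch_size=16,
    batch_nums=10)                                   # number of calibration batches
```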
## 3. Convert the quantized model into a DNNL-optimized INT8 model
For CPU deployment, the saved quant model is run through a conversion script that removes the fake quantize/dequantize ops, fuses some ops, and converts the model into a true INT8 model. Run the script below from your Paddle source tree; upstream it lives at [save_qat_model.py](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/tests/save_qat_model.py). Copy the script into the demo directory (`/PATH_TO_PaddleSlim/demo/mkldnn_quant/quant_aware/`) and run:
```
python save_qat_model.py --qat_model_path=/PATH/TO/SAVE/FLOAT32/QAT/MODEL --int8_model_save_path=/PATH/TO/SAVE/INT8/MODEL --ops_to_quantize="conv2d,pool2d"
```
**Parameters:**
- **qat_model_path:** Required input. The quant model produced by quant-aware training.
- **int8_model_save_path:** Path where the final INT8 model, optimized and quantized by DNNL, is saved. Note: `qat_model_path` must point to a quant model produced by quant-aware training that still contains the fake quantize/dequantize ops.
- **ops_to_quantize:** Required; must not be left unset. The list of op types to quantize in the final INT8 model. For image classification models, set `--ops_to_quantize="conv2d,pool2d"`. For NLP models such as Ernie, set `--ops_to_quantize="fc,reshape2,transpose2,matmul"`. It must be set manually because quantizing every quantizable op does not necessarily give the best speed.
Notes:
- The op types currently supported by DNNL quantization are `conv2d`, `depthwise_conv2d`, `mul`, `fc`, `matmul`, `pool2d`, `reshape2`, `transpose2`, and `concat`; choose only from this list.
- Quantizing every quantizable op is not necessarily fastest, which is why the list must be supplied by the user. For example, if an op would become an isolated INT8 op that cannot be fused with the ops before or after it, quantizing it requires inserting a quantize op before it and a dequantize op after it, which may end up slower than keeping the op in FP32. Since the user's model is unknown, no default is provided; the recommended settings for image classification and NLP tasks are given above.
- An effective way to find the best configuration is to check which quantizable ops the model actually uses and then compare runs with different `ops_to_quantize` combinations; the sketch after this list shows one way to list those ops.
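For example, the small sketch below loads a saved inference model and prints which of the DNNL-quantizable op types it actually contains; the model path is a placeholder, and it assumes the model was saved with a `__model__` file (pass `model_filename`/`params_filename` to `load_inference_model` otherwise).
```
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    '/PATH/TO/SAVE/FLOAT32/QAT/MODEL', exe)

# Op types that the DNNL INT8 pass can currently quantize (see the list above).
quantizable = {'conv2d', 'depthwise_conv2d', 'mul', 'fc', 'matmul',
               'pool2d', 'reshape2', 'transpose2', 'concat'}
used = {op.type for block in program.blocks for op in block.ops}
print('Quantizable op types used by this model:', sorted(used & quantizable))
```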
## 4. Inference
### 4.1 Data preprocessing and conversion
For both accuracy and performance runs, the data must first be converted into a binary file. The script below converts the full ILSVRC2012 val dataset; pass `--local` to convert your own data instead. Run it from your Paddle source tree; upstream it lives at [full_ILSVRC2012_val_preprocess.py](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py).
```
python Paddle/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py --local --data_dir=/PATH/TO/USER/DATASET/ --output_file=/PATH/TO/SAVE/BINARY/FILE
```
Optional arguments:
- With no arguments, the script downloads the ILSVRC2012_img_val dataset and converts it into a binary file.
- **local:** If set, the user provides their own data.
- **data_dir:** Directory of the user's own data.
- **label_list:** A file listing image paths and their class labels, similar to `val_list.txt`.
- **output_file:** Path of the generated binary file.
- **data_dim:** Height and width to which images are preprocessed. Default: 224.
The user's dataset directory should be structured as follows:
```
imagenet_user
├── val
│   ├── ILSVRC2012_val_00000001.jpg
│   ├── ILSVRC2012_val_00000002.jpg
│   ├── ...
└── val_list.txt
```
The contents of `val_list.txt` should look like this:
```
val/ILSVRC2012_val_00000001.jpg 0
val/ILSVRC2012_val_00000002.jpg 0
```
Note:
- Why convert the dataset to a binary file? Data preprocessing in Paddle (resize, crop, etc.) uses Python's PIL/Image module, and the trained model is based on images preprocessed in Python; however, we found that testing in Python has a large overhead that degrades the measured inference performance. To get good performance numbers, the quantized-model inference stage therefore uses a C++ test, and C++ only supports libraries such as OpenCV, which Paddle prefers not to depend on. So the images are preprocessed in Python, written into a binary file, and read back in the C++ test; the layout of that file is sketched below. You can, of course, modify the C++ test to read and preprocess the data directly with OpenCV; accuracy will not drop much. We also provide a Python test, `sample_tester.py`, for reference; comparing it with the C++ test `sample_tester.cc` shows the larger overhead of testing in Python.
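For reference, the sketch below reads a single image/label pair back out of the binary file, following the layout used by the readers in this demo (an int64 image count, then all float32 3x224x224 images, then all int64 labels); the helper name is purely illustrative.
```
import struct
import numpy as np

def read_sample(data_file, index, img_shape=(3, 224, 224)):
    """Read the index-th preprocessed image and its label from the binary file."""
    img_bytes = int(np.prod(img_shape)) * 4        # float32 pixels
    with open(data_file, 'rb') as fp:
        num = struct.unpack('q', fp.read(8))[0]    # total number of images
        assert 0 <= index < num
        fp.seek(8 + index * img_bytes)
        img = np.frombuffer(fp.read(img_bytes), dtype='float32').reshape(img_shape)
        fp.seek(8 + num * img_bytes + index * 8)
        label = struct.unpack('q', fp.read(8))[0]
    return img, label
```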
### 4.2 Deployment and inference
#### Deployment prerequisites
- Performance gains are only obtained on servers with AVX512-capable CPUs. Run `lscpu` on the command line to check which instruction sets your machine supports.
- On CPU servers that support `avx512_vnni`, INT8 accuracy is highest and the speedup is largest.
#### Prepare the inference library
You can either build the Paddle inference library from source or download a prebuilt one.
- To build the Paddle inference library from source, see [build from source](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id12) and use release/2.0 or later.
- Alternatively, download a released [inference library](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html) from the Paddle website. Choose the latest release, or the develop build, of `ubuntu14.04_cpu_avx_mkl`.
Extract the inference library, rename it to fluid_inference, and place it in the current directory (`/PATH_TO_PaddleSlim/demo/mkldnn_quant/quant_aware/`), or point to the Paddle inference library by setting PADDLE_ROOT when running cmake.
#### Build the application
The sample lives under `demo/mkldnn_quant/quant_aware/` in PaddleSlim; the sample `sample_tester.cc` and the `cmake` folder needed for building are both in this directory.
```
cd /PATH/TO/PaddleSlim
cd demo/mkldnn_quant/quant_aware
mkdir build
cd build
cmake -DPADDLE_ROOT=/PATH/TO/fluid_inference ..
make -j
```
If you downloaded and extracted the [inference library](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html) into the current directory, `-DPADDLE_ROOT` can be omitted here, because PADDLE_ROOT defaults to `demo/mkldnn_quant/quant_aware/fluid_inference`.
#### Run the test
```
# Bind threads to cores
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1
# Turbo Boost can be turned off with the following command
echo 1 | sudo tee /sys/devices/system/cpu/intel_pstate/no_turbo
# In the file run.sh, set `MODEL_DIR` to `/PATH/TO/FLOAT32/MODEL` or `/PATH/TO/SAVE/INT8/MODEL`
# In the file run.sh, set `DATA_FILE` to `/PATH/TO/SAVE/BINARY/FILE`
# For 1 thread performance:
./run.sh
# For 20 thread performance:
./run.sh -1 20
```
The following parameters are configured at run time:
- **infer_model:** Directory containing the model. Note that the model parameters must currently be saved as separate files. Set it to `PATH/TO/SAVE/INT8/MODEL` or `PATH/TO/SAVE/FLOAT32/MODEL`. No default.
- **infer_data:** Path to the test data file. It must be a binary file produced by `full_ILSVRC2012_val_preprocess`.
- **batch_size:** Inference batch size. Default: 50.
- **iterations:** Number of batches to run. Default: 0, which means all batches in infer_data (image count / batch size).
- **num_threads:** Number of CPU threads used for inference. Default: 1 (a single thread on one core).
- **with_accuracy_layer:** This generic image classification test can run both FP32 and INT8 models, and the model may or may not contain an accuracy (label) layer; set this flag accordingly.
- **optimize_fp32_model:** Whether to optimize the FP32 model before testing. The sample can test a saved INT8 model directly, or optimize (fuse, etc.) and then test an FP32 model. Default: false, meaning a converted INT8 model is tested and no further optimization is needed.
- **use_profile:** Provided by the Paddle inference library; set it to enable profiling. Default: false.
You can simply edit MODEL_DIR and DATA_FILE in `run.sh` under `/PATH_TO_PaddleSlim/demo/mkldnn_quant/quant_aware/` and then run `./run.sh` for CPU inference.
### 4.3 Writing your own test
If you write your own test:
1. Testing an INT8 model
To test a converted INT8 model, `paddle::NativeConfig` is sufficient. In the demo, set `optimize_fp32_model` to false.
2. Testing an FP32 model
To test an FP32 model, use AnalysisConfig to optimize the original FP32 model (fuses, etc.) before testing. Configure AnalysisConfig as follows:
```
static void SetConfig(paddle::AnalysisConfig *cfg) {
  cfg->SetModel(FLAGS_infer_model);  // Required. The model to be tested.
  cfg->DisableGpu();                 // Required. Inference runs on CPU, so the GPU must be disabled.
  cfg->EnableMKLDNN();               // Required. Use MKL-DNN operators, which are faster than the native ones.
  cfg->SwitchIrOptim();              // When an original FP32 model is passed in, setting this to true optimizes and speeds it up (fuses, etc.).
  cfg->SetCpuMathLibraryNumThreads(FLAGS_num_threads);  // Defaults to 1. Enables multi-threaded execution.
  if (FLAGS_use_profile) {
    cfg->EnableProfile();            // Optional. If use_profile is set, per-operator timings are reported at the end of the run.
  }
}
```
In the sample we provide, simply set `optimize_fp32_model` to true and pass the original FP32 model as `infer_model`; the AnalysisConfig settings above are then applied and the FP32 model is optimized and accelerated by DNNL (including fuses, etc.).
If `infer_model` is an INT8 model, `optimize_fp32_model` has no effect, because the INT8 model has already been optimized and quantized.
If `infer_model` is a quant model produced by PaddleSlim, `optimize_fp32_model` also has no effect, because the quant model contains fake quantize/dequantize ops and can neither be fused nor optimized.
## 5. Accuracy and performance data
For INT8 model accuracy and performance results, see [Accuracy and performance of INT8 models deployed on CPU](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/docs/zh_cn/tutorials/image_classification_mkldnn_quant_aware_tutorial.md).
## FAQ
- For deploying and running NLP models on CPU, see the sample [ERNIE QAT INT8 accuracy and performance reproduction](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie/mkldnn).
- For details of the DNNL optimizations, see [SLIM QAT for INT8 DNNL](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md).
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
set(PADDLE_FOUND OFF)
if(NOT PADDLE_ROOT)
set(PADDLE_ROOT $ENV{PADDLE_ROOT} CACHE PATH "Paddle Path")
endif()
if(NOT PADDLE_ROOT)
message(FATAL_ERROR "Set PADDLE_ROOT as your root directory installed PaddlePaddle")
endif()
set(THIRD_PARTY_ROOT ${PADDLE_ROOT}/third_party)
if(USE_GPU)
set(CUDA_ROOT $ENV{CUDA_ROOT} CACHE PATH "CUDA root Path")
set(CUDNN_ROOT $ENV{CUDNN_ROOT} CACHE PATH "CUDNN root Path")
endif()
# Supported directory organizations
find_path(PADDLE_INC_DIR NAMES paddle_inference_api.h PATHS ${PADDLE_ROOT}/paddle/include)
if(PADDLE_INC_DIR)
set(LIB_PATH "paddle/lib")
else()
find_path(PADDLE_INC_DIR NAMES paddle/fluid/inference/paddle_inference_api.h PATHS ${PADDLE_ROOT})
if(PADDLE_INC_DIR)
include_directories(${PADDLE_ROOT}/paddle/fluid/inference)
endif()
set(LIB_PATH "paddle/fluid/inference")
endif()
include_directories(${PADDLE_INC_DIR})
find_library(PADDLE_FLUID_SHARED_LIB NAMES "libpaddle_fluid.so" PATHS
${PADDLE_ROOT}/${LIB_PATH})
find_library(PADDLE_FLUID_STATIC_LIB NAMES "libpaddle_fluid.a" PATHS
${PADDLE_ROOT}/${LIB_PATH})
if(USE_SHARED AND PADDLE_INC_DIR AND PADDLE_FLUID_SHARED_LIB)
set(PADDLE_FOUND ON)
add_library(paddle_fluid_shared SHARED IMPORTED)
set_target_properties(paddle_fluid_shared PROPERTIES IMPORTED_LOCATION
${PADDLE_FLUID_SHARED_LIB})
set(PADDLE_LIBRARIES paddle_fluid_shared)
message(STATUS "Found PaddlePaddle Fluid (include: ${PADDLE_INC_DIR}; "
"library: ${PADDLE_FLUID_SHARED_LIB}")
elseif(PADDLE_INC_DIR AND PADDLE_FLUID_STATIC_LIB)
set(PADDLE_FOUND ON)
add_library(paddle_fluid_static STATIC IMPORTED)
set_target_properties(paddle_fluid_static PROPERTIES IMPORTED_LOCATION
${PADDLE_FLUID_STATIC_LIB})
set(PADDLE_LIBRARIES paddle_fluid_static)
message(STATUS "Found PaddlePaddle Fluid (include: ${PADDLE_INC_DIR}; "
"library: ${PADDLE_FLUID_STATIC_LIB}")
else()
set(PADDLE_FOUND OFF)
message(WARNING "Cannot find PaddlePaddle Fluid under ${PADDLE_ROOT}")
return()
endif()
# including directory of third_party libraries
set(PADDLE_THIRD_PARTY_INC_DIRS)
function(third_party_include TARGET_NAME HEADER_NAME TARGET_DIRNAME)
find_path(PADDLE_${TARGET_NAME}_INC_DIR NAMES ${HEADER_NAME} PATHS
${TARGET_DIRNAME}
NO_DEFAULT_PATH)
if(PADDLE_${TARGET_NAME}_INC_DIR)
message(STATUS "Found PaddlePaddle third_party including directory: " ${PADDLE_${TARGET_NAME}_INC_DIR})
set(PADDLE_THIRD_PARTY_INC_DIRS ${PADDLE_THIRD_PARTY_INC_DIRS} ${PADDLE_${TARGET_NAME}_INC_DIR} PARENT_SCOPE)
endif()
endfunction()
third_party_include(glog glog/logging.h ${THIRD_PARTY_ROOT}/install/glog/include)
third_party_include(protobuf google/protobuf/message.h ${THIRD_PARTY_ROOT}/install/protobuf/include)
third_party_include(gflags gflags/gflags.h ${THIRD_PARTY_ROOT}/install/gflags/include)
third_party_include(eigen unsupported/Eigen/CXX11/Tensor ${THIRD_PARTY_ROOT}/eigen3)
third_party_include(boost boost/config.hpp ${THIRD_PARTY_ROOT}/boost)
if(USE_GPU)
third_party_include(cuda cuda.h ${CUDA_ROOT}/include)
third_party_include(cudnn cudnn.h ${CUDNN_ROOT}/include)
endif()
message(STATUS "PaddlePaddle need to include these third party directories: ${PADDLE_THIRD_PARTY_INC_DIRS}")
include_directories(${PADDLE_THIRD_PARTY_INC_DIRS})
set(PADDLE_THIRD_PARTY_LIBRARIES)
function(third_party_library TARGET_NAME TARGET_DIRNAME)
set(library_names ${ARGN})
set(local_third_party_libraries)
foreach(lib ${library_names})
string(REGEX REPLACE "^lib" "" lib_noprefix ${lib})
if(${lib} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$")
set(libtype STATIC)
string(REGEX REPLACE "${CMAKE_STATIC_LIBRARY_SUFFIX}$" "" libname ${lib_noprefix})
elseif(${lib} MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}(\\.[0-9]+)?$")
set(libtype SHARED)
string(REGEX REPLACE "${CMAKE_SHARED_LIBRARY_SUFFIX}(\\.[0-9]+)?$" "" libname ${lib_noprefix})
else()
message(FATAL_ERROR "Unknown library type: ${lib}")
endif()
#message(STATUS "libname: ${libname}")
find_library(${libname}_LIBRARY NAMES "${lib}" PATHS
${TARGET_DIRNAME}
NO_DEFAULT_PATH)
if(${libname}_LIBRARY)
set(${TARGET_NAME}_FOUND ON PARENT_SCOPE)
add_library(${libname} ${libtype} IMPORTED)
set_target_properties(${libname} PROPERTIES IMPORTED_LOCATION ${${libname}_LIBRARY})
set(local_third_party_libraries ${local_third_party_libraries} ${libname})
message(STATUS "Found PaddlePaddle third_party library: " ${${libname}_LIBRARY})
else()
set(${TARGET_NAME}_FOUND OFF PARENT_SCOPE)
message(WARNING "Cannot find ${lib} under ${THIRD_PARTY_ROOT}")
endif()
endforeach()
set(PADDLE_THIRD_PARTY_LIBRARIES ${PADDLE_THIRD_PARTY_LIBRARIES} ${local_third_party_libraries} PARENT_SCOPE)
endfunction()
third_party_library(mklml ${THIRD_PARTY_ROOT}/install/mklml/lib libiomp5.so libmklml_intel.so)
third_party_library(mkldnn ${THIRD_PARTY_ROOT}/install/mkldnn/lib libmkldnn.so)
if(NOT mkldnn_FOUND)
third_party_library(mkldnn ${THIRD_PARTY_ROOT}/install/mkldnn/lib libmkldnn.so.0)
endif()
if(NOT USE_SHARED)
third_party_library(glog ${THIRD_PARTY_ROOT}/install/glog/lib libglog.a)
third_party_library(protobuf ${THIRD_PARTY_ROOT}/install/protobuf/lib libprotobuf.a)
third_party_library(gflags ${THIRD_PARTY_ROOT}/install/gflags/lib libgflags.a)
if(NOT mklml_FOUND)
third_party_library(openblas ${THIRD_PARTY_ROOT}/install/openblas/lib libopenblas.a)
endif()
third_party_library(zlib ${THIRD_PARTY_ROOT}/install/zlib/lib libz.a)
third_party_library(snappystream ${THIRD_PARTY_ROOT}/install/snappystream/lib libsnappystream.a)
third_party_library(snappy ${THIRD_PARTY_ROOT}/install/snappy/lib libsnappy.a)
third_party_library(xxhash ${THIRD_PARTY_ROOT}/install/xxhash/lib libxxhash.a)
if(USE_GPU)
third_party_library(cudart ${CUDA_ROOT}/lib64 libcudart.so)
endif()
endif()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
# Tries to find Gperftools.
#
# Usage of this module as follows:
#
# find_package(Gperftools)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# Gperftools_ROOT_DIR Set this variable to the root installation of
# Gperftools if the module has problems finding
# the proper installation path.
#
# Variables defined by this module:
#
# GPERFTOOLS_FOUND System has Gperftools libs/headers
# GPERFTOOLS_LIBRARIES The Gperftools libraries (tcmalloc & profiler)
# GPERFTOOLS_INCLUDE_DIR The location of Gperftools headers
find_library(GPERFTOOLS_TCMALLOC
NAMES tcmalloc
HINTS ${Gperftools_ROOT_DIR}/lib)
find_library(GPERFTOOLS_PROFILER
NAMES profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
find_library(GPERFTOOLS_TCMALLOC_AND_PROFILER
NAMES tcmalloc_and_profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
find_path(GPERFTOOLS_INCLUDE_DIR
NAMES gperftools/heap-profiler.h
HINTS ${Gperftools_ROOT_DIR}/include)
set(GPERFTOOLS_LIBRARIES ${GPERFTOOLS_TCMALLOC_AND_PROFILER})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
Gperftools
DEFAULT_MSG
GPERFTOOLS_LIBRARIES
GPERFTOOLS_INCLUDE_DIR)
mark_as_advanced(
Gperftools_ROOT_DIR
GPERFTOOLS_TCMALLOC
GPERFTOOLS_PROFILER
GPERFTOOLS_TCMALLOC_AND_PROFILER
GPERFTOOLS_LIBRARIES
GPERFTOOLS_INCLUDE_DIR)
# create IMPORTED targets
if (Gperftools_FOUND AND NOT TARGET gperftools::tcmalloc)
add_library(gperftools::tcmalloc UNKNOWN IMPORTED)
set_target_properties(gperftools::tcmalloc PROPERTIES
IMPORTED_LOCATION ${GPERFTOOLS_TCMALLOC}
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
add_library(gperftools::profiler UNKNOWN IMPORTED)
set_target_properties(gperftools::profiler PROPERTIES
IMPORTED_LOCATION ${GPERFTOOLS_PROFILER}
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
endif()
#!/bin/bash
MODEL_DIR=$HOME/repo/Paddle/resnet50_quant_int8
DATA_FILE=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin
num_threads=10
with_accuracy_layer=false
use_profile=true
ITERATIONS=0
./build/inference --logtostderr=1 \
--infer_model=${MODEL_DIR} \
--infer_data=${DATA_FILE} \
--batch_size=1 \
--num_threads=${num_threads} \
--iterations=${ITERATIONS} \
--with_accuracy_layer=${with_accuracy_layer} \
--use_profile=${use_profile} \
--optimize_fp32_model=false
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <paddle_inference_api.h>
#include <algorithm>
#include <chrono>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#ifdef WITH_GPERFTOOLS
#include <gperftools/profiler.h>
#include <paddle/fluid/platform/profiler.h>
#endif
DEFINE_string(infer_model, "", "path to the model");
DEFINE_string(infer_data, "", "path to the input data");
DEFINE_int32(batch_size, 50, "inference batch size");
DEFINE_int32(iterations,
0,
"number of batches to process. 0 means testing whole dataset");
DEFINE_int32(num_threads, 1, "num of threads to run in parallel");
DEFINE_bool(with_accuracy_layer,
true,
"Set with_accuracy_layer to true if provided model has accuracy layer and requires label input");
DEFINE_bool(use_profile, false, "Set use_profile to true to get profile information");
DEFINE_bool(optimize_fp32_model, false, "If optimize_fp32_model is set to true, fp32 model will be optimized");
struct Timer {
std::chrono::high_resolution_clock::time_point start;
std::chrono::high_resolution_clock::time_point startu;
void tic() { start = std::chrono::high_resolution_clock::now(); }
double toc() {
startu = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span =
std::chrono::duration_cast<std::chrono::duration<double>>(startu -
start);
double used_time_ms = static_cast<double>(time_span.count()) * 1000.0;
return used_time_ms;
}
};
template <typename T>
constexpr paddle::PaddleDType GetPaddleDType();
template <>
constexpr paddle::PaddleDType GetPaddleDType<int64_t>() {
return paddle::PaddleDType::INT64;
}
template <>
constexpr paddle::PaddleDType GetPaddleDType<float>() {
return paddle::PaddleDType::FLOAT32;
}
template <typename T>
class TensorReader {
public:
TensorReader(std::ifstream &file,
size_t beginning_offset,
std::vector<int> shape,
std::string name)
: file_(file), position_(beginning_offset), shape_(shape), name_(name) {
numel_ = std::accumulate(
shape_.begin(), shape_.end(), size_t{1}, std::multiplies<size_t>());
}
paddle::PaddleTensor NextBatch() {
paddle::PaddleTensor tensor;
tensor.name = name_;
tensor.shape = shape_;
tensor.dtype = GetPaddleDType<T>();
tensor.data.Resize(numel_ * sizeof(T));
file_.seekg(position_);
file_.read(static_cast<char *>(tensor.data.data()), numel_ * sizeof(T));
position_ = file_.tellg();
if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream";
if (file_.bad()) LOG(ERROR) << name_ << "ERROR: badbit is true";
if (file_.fail())
throw std::runtime_error(name_ + ": failed reading file.");
return tensor;
}
protected:
std::ifstream &file_;
size_t position_;
std::vector<int> shape_;
std::string name_;
size_t numel_;
};
void SetInput(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
std::vector<paddle::PaddleTensor> *labels_gt,
bool with_accuracy_layer = FLAGS_with_accuracy_layer,
int32_t batch_size = FLAGS_batch_size) {
std::ifstream file(FLAGS_infer_data, std::ios::binary);
if (!file) {
throw std::runtime_error("Couldn't open file: " + FLAGS_infer_data);
}
int64_t total_images{0};
file.seekg(0, std::ios::beg);
file.read(reinterpret_cast<char *>(&total_images), sizeof(total_images));
LOG(INFO) << "Total images in file: " << total_images;
std::vector<int> image_batch_shape{batch_size, 3, 224, 224};
std::vector<int> label_batch_shape{batch_size, 1};
auto images_offset_in_file = static_cast<size_t>(file.tellg());
TensorReader<float> image_reader(
file, images_offset_in_file, image_batch_shape, "image");
auto iterations_max = total_images / batch_size;
auto iterations = iterations_max;
if (FLAGS_iterations > 0 && FLAGS_iterations < iterations_max) {
iterations = FLAGS_iterations;
}
auto labels_offset_in_file =
images_offset_in_file + sizeof(float) * total_images * 3 * 224 * 224;
TensorReader<int64_t> label_reader(
file, labels_offset_in_file, label_batch_shape, "label");
for (auto i = 0; i < iterations; i++) {
auto images = image_reader.NextBatch();
std::vector<paddle::PaddleTensor> tmp_vec;
tmp_vec.push_back(std::move(images));
auto labels = label_reader.NextBatch();
if (with_accuracy_layer) {
tmp_vec.push_back(std::move(labels));
} else {
labels_gt->push_back(std::move(labels));
}
inputs->push_back(std::move(tmp_vec));
}
}
static void PrintTime(int batch_size,
int num_threads,
double batch_latency,
int epoch = 1) {
double sample_latency = batch_latency / batch_size;
LOG(INFO) <<"Model: "<<FLAGS_infer_model;
LOG(INFO) << "====== num of threads: " << num_threads << " ======";
LOG(INFO) << "====== batch size: " << batch_size << ", iterations: " << epoch;
LOG(INFO) << "====== batch latency: " << batch_latency
<< "ms, number of samples: " << batch_size * epoch;
LOG(INFO) << ", sample latency: " << sample_latency
<< "ms, fps: " << 1000.f / sample_latency << " ======";
}
void PredictionRun(paddle::PaddlePredictor *predictor,
const std::vector<std::vector<paddle::PaddleTensor>> &inputs,
std::vector<std::vector<paddle::PaddleTensor>> *outputs,
int num_threads,
float *sample_latency = nullptr) {
int iterations = inputs.size(); // process the whole dataset ...
if (FLAGS_iterations > 0 &&
FLAGS_iterations < static_cast<int64_t>(inputs.size()))
iterations =
FLAGS_iterations; // ... unless the number of iterations is set
outputs->resize(iterations);
Timer run_timer;
double elapsed_time = 0;
#ifdef WITH_GPERFTOOLS
ResetProfiler();
ProfilerStart("paddle_inference.prof");
#endif
int predicted_num = 0;
for (int i = 0; i < iterations; i++) {
run_timer.tic();
predictor->Run(inputs[i], &(*outputs)[i], FLAGS_batch_size);
elapsed_time += run_timer.toc();
predicted_num += FLAGS_batch_size;
if (predicted_num % 100 == 0) {
LOG(INFO) << "Infer " << predicted_num << " samples";
}
}
#ifdef WITH_GPERFTOOLS
ProfilerStop();
#endif
auto batch_latency = elapsed_time / iterations;
PrintTime(FLAGS_batch_size, num_threads, batch_latency, iterations);
if (sample_latency != nullptr)
*sample_latency = batch_latency / FLAGS_batch_size;
}
std::pair<float, float> CalculateAccuracy(
const std::vector<std::vector<paddle::PaddleTensor>> &outputs,
const std::vector<paddle::PaddleTensor> &labels_gt,
bool with_accuracy = FLAGS_with_accuracy_layer) {
LOG_IF(ERROR, !with_accuracy && labels_gt.size() == 0)
<< "if with_accuracy set to false, labels_gt must be not empty";
std::vector<float> acc1_ss;
std::vector<float> acc5_ss;
if (!with_accuracy) { // model with_accuracy_layer = false
float *result_array; // for one batch 50*1000
int64_t *batch_labels; // 50*1
LOG_IF(ERROR, outputs.size() != labels_gt.size())
<< "outputs first dimension must be equal to labels_gt first dimension";
for (auto i = 0; i < outputs.size();
++i) { // same as labels first dimension
result_array = static_cast<float *>(outputs[i][0].data.data());
batch_labels = static_cast<int64_t *>(labels_gt[i].data.data());
int correct_1 = 0, correct_5 = 0, total = FLAGS_batch_size;
for (auto j = 0; j < FLAGS_batch_size; j++) { // batch_size
std::vector<float> v(result_array + j * 1000,
result_array + (j + 1) * 1000);
std::vector<std::pair<float, int>> vx;
for (int k = 0; k < 1000; k++) {
vx.push_back(std::make_pair(v[k], k));
}
std::partial_sort(vx.begin(),
vx.begin() + 5,
vx.end(),
[](std::pair<float, int> a, std::pair<float, int> b) {
return a.first > b.first;
});
if (static_cast<int>(batch_labels[j]) == vx[0].second) correct_1 += 1;
if (std::find_if(vx.begin(),
vx.begin() + 5,
[batch_labels, j](std::pair<float, int> a) {
return static_cast<int>(batch_labels[j]) == a.second;
}) != vx.begin() + 5)
correct_5 += 1;
}
acc1_ss.push_back(static_cast<float>(correct_1) /
static_cast<float>(total));
acc5_ss.push_back(static_cast<float>(correct_5) /
static_cast<float>(total));
}
} else { // model with_accuracy_layer = true
for (auto i = 0; i < outputs.size(); ++i) {
LOG_IF(ERROR, outputs[i].size() < 3UL) << "To get top1 and top5 "
"accuracy, output[i] size must "
"be bigger than or equal to 3";
acc1_ss.push_back(
*static_cast<float *>(outputs[i][1].data.data())); // 1 is top1 acc
acc5_ss.push_back(*static_cast<float *>(
outputs[i][2].data.data())); // 2 is top5 acc or mAP
}
}
auto acc1_ss_avg =
std::accumulate(acc1_ss.begin(), acc1_ss.end(), 0.0) / acc1_ss.size();
auto acc5_ss_avg =
std::accumulate(acc5_ss.begin(), acc5_ss.end(), 0.0) / acc5_ss.size();
return std::make_pair(acc1_ss_avg, acc5_ss_avg);
}
static void SetIrOptimConfig(paddle::AnalysisConfig *cfg) {
cfg->DisableGpu();
cfg->SwitchIrOptim();
cfg->EnableMKLDNN();
if(FLAGS_use_profile){
cfg->EnableProfile();
}
}
std::unique_ptr<paddle::PaddlePredictor> CreatePredictor(
const paddle::PaddlePredictor::Config *config, bool use_analysis = true) {
const auto *analysis_config =
reinterpret_cast<const paddle::AnalysisConfig *>(config);
if (use_analysis) {
return paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(
*analysis_config);
}
auto native_config = analysis_config->ToNativeConfig();
return paddle::CreatePaddlePredictor<paddle::NativeConfig>(native_config);
}
int main(int argc, char *argv[]) {
// InitFLAGS(argc, argv);
google::InitGoogleLogging(*argv);
gflags::ParseCommandLineFlags(&argc, &argv, true);
paddle::AnalysisConfig cfg;
cfg.SetModel(FLAGS_infer_model);
cfg.SetCpuMathLibraryNumThreads(FLAGS_num_threads);
if (FLAGS_optimize_fp32_model){
SetIrOptimConfig(&cfg);
}
std::vector<std::vector<paddle::PaddleTensor>> input_slots_all;
std::vector<std::vector<paddle::PaddleTensor>> outputs;
std::vector<paddle::PaddleTensor> labels_gt; // optional
SetInput(&input_slots_all, &labels_gt); // iterations*batch_size
auto predictor = CreatePredictor(reinterpret_cast<paddle::PaddlePredictor::Config *>(&cfg), FLAGS_optimize_fp32_model);
PredictionRun(predictor.get(), input_slots_all, &outputs, FLAGS_num_threads);
auto acc_pair = CalculateAccuracy(outputs, labels_gt);
LOG(INFO) <<"Top1 accuracy: " << std::fixed << std::setw(6)
<<std::setprecision(4) << acc_pair.first;
LOG(INFO) <<"Top5 accuracy: " << std::fixed << std::setw(6)
<<std::setprecision(4) << acc_pair.second;
}
# copyright (c) 2020 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import struct
import six
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--batch_size', type=int, default=1, help='Batch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--infer_model',
type=str,
default='',
help='A path to an Inference model.')
parser.add_argument(
'--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
type=int,
default=0,
help='Number of batches to process. 0 or less means whole dataset. Default: 0.'
)
parser.add_argument(
'--with_accuracy_layer',
type=bool,
default=False,
help='The model is with accuracy or without accuracy layer')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class SampleTester(unittest.TestCase):
def _reader_creator(self, data_file='data.bin'):
def reader():
with open(data_file, 'rb') as fp:
num = fp.read(8)
num = struct.unpack('q', num)[0]
imgs_offset = 8
img_ch = 3
img_w = 224
img_h = 224
img_pixel_size = 4
img_size = img_ch * img_h * img_w * img_pixel_size
label_size = 8
labels_offset = imgs_offset + num * img_size
step = 0
while step < num:
fp.seek(imgs_offset + img_size * step)
img = fp.read(img_size)
img = struct.unpack_from(
'{}f'.format(img_ch * img_w * img_h), img)
img = np.array(img)
img.shape = (img_ch, img_w, img_h)
fp.seek(labels_offset + label_size * step)
label = fp.read(label_size)
label = struct.unpack('q', label)[0]
yield img, int(label)
step += 1
return reader
def _get_batch_accuracy(self, batch_output=None, labels=None):
total = 0
correct = 0
correct_5 = 0
for n, result in enumerate(batch_output):
index = result.argsort()
top_1_index = index[-1]
top_5_index = index[-5:]
total += 1
if top_1_index == labels[n]:
correct += 1
if labels[n] in top_5_index:
correct_5 += 1
acc1 = float(correct) / float(total)
acc5 = float(correct_5) / float(total)
return acc1, acc5
def _prepare_for_fp32_mkldnn(self, graph):
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in ['depthwise_conv2d']:
input_var_node = graph._find_node_by_name(
op_node.inputs, op_node.input("Input")[0])
weight_var_node = graph._find_node_by_name(
op_node.inputs, op_node.input("Filter")[0])
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), op_node.output("Output")[0])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
}
conv_op_node = graph.create_op_node(
op_type='conv2d',
attrs=attrs,
inputs={
'Input': input_var_node,
'Filter': weight_var_node
},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, conv_op_node)
graph.link_to(weight_var_node, conv_op_node)
graph.link_to(conv_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
return graph
def _predict(self,
test_reader=None,
model_path=None,
with_accuracy_layer=False,
batch_size=1,
batch_num=1,
skip_batch_num=0):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names, fetch_targets
] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
graph = self._prepare_for_fp32_mkldnn(graph)
inference_program = graph.to_program()
dshape = [3, 224, 224]
outputs = []
infer_accs1 = []
infer_accs5 = []
batch_acc1 = 0.0
batch_acc5 = 0.0
fpses = []
batch_times = []
batch_time = 0.0
total_samples = 0
iters = 0
infer_start_time = time.time()
for data in test_reader():
if batch_num > 0 and iters >= batch_num:
break
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
if six.PY2:
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype('int64')
if (with_accuracy_layer == False):
# models that do not have accuracy measuring layers
start = time.time()
out = exe.run(inference_program,
feed={feed_target_names[0]: images},
fetch_list=fetch_targets)
batch_time = (time.time() - start) * 1000  # in milliseconds
outputs.append(out[0])
# Calculate accuracy result
batch_acc1, batch_acc5 = self._get_batch_accuracy(out[0],
labels)
else:
# models have accuracy measuring layers
labels = labels.reshape([-1, 1])
start = time.time()
out = exe.run(inference_program,
feed={
feed_target_names[0]: images,
feed_target_names[1]: labels
},
fetch_list=fetch_targets)
batch_time = (time.time() - start) * 1000  # in milliseconds
batch_acc1, batch_acc5 = out[1][0], out[2][0]
outputs.append(batch_acc1)
infer_accs1.append(batch_acc1)
infer_accs5.append(batch_acc5)
samples = len(data)
total_samples += samples
batch_times.append(batch_time)
fps = samples / batch_time * 1000
fpses.append(fps)
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
_logger.info('batch {0}{5}, acc1: {1:.4f}, acc5: {2:.4f}, '
'latency: {3:.4f} ms, fps: {4:.2f}'.format(
iters, batch_acc1, batch_acc5, batch_time /
batch_size, fps, appx))
# Postprocess benchmark data
batch_latencies = batch_times[skip_batch_num:]
batch_latency_avg = np.average(batch_latencies)
latency_avg = batch_latency_avg / batch_size
fpses = fpses[skip_batch_num:]
fps_avg = np.average(fpses)
infer_total_time = time.time() - infer_start_time
acc1_avg = np.mean(infer_accs1)
acc5_avg = np.mean(infer_accs5)
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
infer_model_path = test_case_args.infer_model
assert infer_model_path, 'The model path cannot be empty. Please, use the --infer_model option.'
data_path = test_case_args.infer_data
assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
batch_size = test_case_args.batch_size
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
with_accuracy_layer = test_case_args.with_accuracy_layer
_logger.info('Inference model: {0}'.format(infer_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('--- Inference prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict(
val_reader, infer_model_path, with_accuracy_layer, batch_size,
batch_num, skip_batch_num)
_logger.info(
'Inference: avg top1 accuracy: {0:.4f}, avg top5 accuracy: {1:.4f}'.
format(fp32_acc1, fp32_acc5))
_logger.info('Inference: avg fps: {0:.2f}, avg latency: {1:.4f} ms'.
format(fp32_fps, fp32_lat))
if __name__ == '__main__':
global test_case_args
test_case_args, remaining_args = parse_args()
unittest.main(argv=remaining_args)
# Accuracy and Performance of INT8 Models Deployed on CPU
On an Intel(R) Xeon(R) Gold 6271 machine, after quantization and DNNL acceleration, the INT8 model is 3-4x faster than the original FP32 model on a single thread; on an Intel(R) Xeon(R) Gold 6148, single-thread performance is 1.5x that of the original FP32 model, with only a minimal drop in accuracy. For the image classification quantization tutorial, see [Optimized CPU deployment and inference of INT8 image classification models](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/mkldnn_quant/quant_aware/PaddleCV_mkldnn_quantaware_tutorial_cn.md). For quantization of NLP models, see [ERNIE INT8 accuracy and performance reproduction](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie/mkldnn).
## Accuracy and performance of INT8 image classification models on Xeon(R) 6271
>**Accuracy of INT8 image classification models on Intel(R) Xeon(R) Gold 6271**
| Model | FP32 Top1 Accuracy | INT8 Top1 Accuracy | Top1 Diff | FP32 Top5 Accuracy | INT8 Top5 Accuracy | Top5 Diff |
| :----------: | :----------------: | :--------------------: | :-------: | :----------------: | :--------------------: | :-------: |
| MobileNet-V1 | 70.78% | 70.71% | -0.07% | 89.69% | 89.41% | -0.28% |
| MobileNet-V2 | 71.90% | 72.11% | +0.21% | 90.56% | 90.62% | +0.06% |
| ResNet101 | 77.50% | 77.64% | +0.14% | 93.58% | 93.58% | 0.00% |
| ResNet50 | 76.63% | 76.47% | -0.16% | 93.10% | 92.98% | -0.12% |
| VGG16 | 72.08% | 71.73% | -0.35% | 90.63% | 89.71% | -0.92% |
| VGG19 | 72.57% | 72.12% | -0.45% | 90.84% | 90.15% | -0.69% |
>**Single-core performance of INT8 image classification models on Intel(R) Xeon(R) Gold 6271**
| Model | FP32 (images/s) | INT8 (images/s) | Ratio (INT8/FP32) |
| :----------: | :-------------: | :-----------------: | :---------------: |
| MobileNet-V1 | 74.05 | 196.98 | 2.66 |
| MobileNet-V2 | 88.60 | 187.67 | 2.12 |
| ResNet101 | 7.20 | 26.43 | 3.67 |
| ResNet50 | 13.23 | 47.44 | 3.59 |
| VGG16 | 3.47 | 10.20 | 2.94 |
| VGG19 | 2.83 | 8.67 | 3.06 |
## Accuracy and performance of INT8 NLP models on Xeon(R) 6271
>**I. Ernie INT8 DNNL accuracy on Intel(R) Xeon(R) Gold 6271**
| Model | FP32 Accuracy | INT8 Accuracy | Accuracy Diff |
|:------------:|:----------------------:|:----------------------:|:---------:|
| Ernie | 80.20% | 79.44% | -0.76% |
>**II. Ernie INT8 DNNL per-sample latency on Intel(R) Xeon(R) Gold 6271**
| Threads | FP32 Latency (ms) | INT8 Latency (ms) | Ratio (FP32/INT8) |
|:------------:|:----------------------:|:-------------------:|:-----------------:|
| 1 thread | 237.21 | 79.26 | 2.99X |
| 20 threads | 22.08 | 12.57 | 1.76X |