未验证 提交 a2021b56 编写于 作者: C cc 提交者: GitHub

Update doc: add_operation, model_quantization and support_operation_list (#2840)

* update doc:add_operation, model_quantization and support_operation_list, test=document_fix
上级 b65cb2af
请参考[PaddleLite文档开发规范](http://agroup.baidu.com/paddle-infer/md/article/2561104)
# 新增OP的方法
以下以添加argmax为例,详细说明新增op的方法。
## 1. 添加OpParam 结构体以传导 Op 的输入和输出
- 这里命名为 `ArgmaxParam`
-`paddlelite/lite/operators/op_params.h` 中添加 `ArgmaxParam` 结构体,代码如下:
```c++
struct ArgmaxParam {
lite::Tensor* X{};
lite::Tensor* Out{};
int Axis{0};
};
```
## 2. 添加 Argmax Op 并注册
- 在paddlelite/lite/operators/目录下新建argmax_op.h文件,主要代码如下:
```c++
class ArgmaxOpLite : public OpLite {
public:
ArgmaxOpLite() {}
explicit ArgmaxOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "argmax"; }
private:
mutable ArgmaxParam param_;
};
```
`ArgmaxOpLite` 继承 `OpLite` ,成员变量包括 `ArgmaxParam` 结构体,需要实现的接口包括 `CheckShape()``InferShape()``AttachImp()``AttachKernel()``DebugString()` 函数。`AttachKernel()``DebugString() `函数较为简单,此处直接实现;
-`paddlelite/lite/operators/` 目录下新建argmax_op.cc文件,需要具体实现`CheckShape()``InferShape()``AttachImp()`函数。`CheckShape()`函数检查输入是否符合要求,`InferShape()`函数基于输入推断得到输出的维度,`AttachImp()`函数绑定Op的输入输出。然后在argmax_op.cc文件中注册argmax,核心代码如下:
```c++
bool ArgmaxOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
CHECK_OR_FALSE(param_.Axis < (param_.X)->dims().size());
return true;
}
bool ArgmaxOpLite::InferShape() const {
auto x_dims = param_.X->dims();
int x_rank = x_dims.size();
int axis = param_.Axis;
if (axis < 0) axis += x_rank;
std::vector<int64_t> out_dims;
for (int64_t i = 0; i < axis; i++) {
out_dims.push_back(x_dims[i]);
}
for (int64_t i = axis + 1; i < x_rank; i++) {
out_dims.push_back(x_dims[i]);
}
// Set output dims
param_.Out->Resize(lite::DDim(out_dims));
return true;
}
bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.Axis = op_desc.GetAttr<int>("Axis");
return true;
}
REGISTER_LITE_OP(argmax, paddle::lite::operators::ArgmaxOpLite);
```
- 在paddlelite/lite/operators/CMakeLists.txt中添加```add_operator(argmax_op basic SRCS argmax_op.cc DEPS ${op_DEPS})```
## 3. 添加Argmax Kernel并绑定
以下以arm端argmax实现为例说明
- 在paddlelite/lite/kernels/arm/目录下新建argmax_compute.h文件,声明ArgmaxCompute类,并继承KernelLite,主要代码如下:
```c++
class ArgmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ArgmaxParam;
void Run() override;
virtual ~ArgmaxCompute() = default;
};
```
- 在paddlelite/lite/kernels/arm/目录下新建argmax_compute.cc文件,主要实现Run函数。`Run()`函数调用paddlelite/lite/bachends/arm/math/argmax.h中的`argmax_func()`函数,根据输入计算输出。最后在argmax_compute.cc文件中,我们绑定argmax的输入输出(为tensor的输入参数都需要绑定),代码如下:
```c++
void ArgmaxCompute::Run() {
auto& param = Param<operators::ArgmaxParam>();
lite::Tensor* input = param.X;
lite::Tensor* output = param.Out;
int axis = param.Axis;
lite::arm::math::argmax_func(input, axis, output);
return;
}
REGISTER_LITE_KERNEL(
argmax, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ArgmaxCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
```
- 在paddlelite/lite/kernels/arm/CMakeLists.txt中添加
```cmake
add_kernel(argmax_compute_arm ARM basic SRCS argmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
```
## 4. 添加Argmax实现
- 在paddlelite/lite/backends/arm/math/目录下新建argmax.h文件,声明`argmax_func()`函数,代码如下:
```c++
void argmax_func(const lite::Tensor* input, const int axis, lite::Tensor* output);
```
- 在paddlelite/lite/backends/arm/math/目录下新建argmax.cc文件,具体实现`argmax_func()`函数,代码如下:
```c++
void argmax_func(const lite::Tensor *input,
const int axis,
lite::Tensor *output) {
auto input_ddim = input->dims();
auto output_ddim = output->dims();
const int size = input_ddim[axis];
const int in_channel = input_ddim.count(axis, input_ddim.size());
const int out_channel = output_ddim.count(axis, output_ddim.size());
const int in_stride = input_ddim.count(axis + 1, input_ddim.size());
const int out_stride = input_ddim.count(0, axis);
for (int n = 0; n < out_stride; n++) {
for (int k = 0; k < in_stride; k++) {
const float *in_ptr = input->data<float>() + n * in_channel + k;
std::vector<std::pair<float, int>> vec;
vec.resize(size);
for (int i = 0; i < size; i++) {
vec[i] = std::make_pair(in_ptr[i * in_stride], i);
}
// sort
std::partial_sort(vec.begin(),
vec.begin() + 1,
vec.end(),
std::greater<std::pair<float, int>>());
// out
float *out_ptr = output->mutable_data<float>() + n * out_channel + k;
*out_ptr = vec[0].second;
}
}
}
```
- 在paddlelite/lite/backends/arm/math/CMakeFile.txt中的```math_arm library```中添加argmax.cc,在paddlelite/lite/backends/arm/math/funcs.h中添加```#include "lite/arm/math/argmax.h"```
## 5. 添加Argmax单测
- 在paddlelite/lite/tests/kernels目录下新建argmax_compute_test.cc文件,声明并实现ArgmaxComputeTester类;
- ArgmaxComputeTester类中主要包括PrepareOpDesc、PrepareData和RunBaseline函数。PrepareOpDesc函数设定单测op的类型和输入输出参数,PrepareData函数对输入tensor进行初始化,RunBaseline是基于输入计算得到输出,用于和框架计算的输出进行对比;
- 使用gtest添加单测,代码如下:
```c++
TEST(Argmax, precision) {
#ifdef LITE_WITH_ARM
LOG(INFO) << "test argmax arm";
Place place(TARGET(kARM));
for (int axis : {0, 1, 2, 3}) {
for (int n : {1, 3}) {
for (int c : {3, 6}) {
for (int h : {9, 18}) {
for (int w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(
new ArgmaxComputeTester(place, "def", axis, n, c, h, w));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
#endif
}
```
- 在paddlelite/lite/tests/kernels/CMakeLists.txt中添加
```cmake
lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
```
## 6. 编译运行
- 在paddlelite目录中,执行```./lite/tools/ci_build.sh build_test_arm```,该脚本会创建手机模拟器,并编译运行所有单测(花费时间较久)。如果运行无误,则表明添加argmax成功。
# 模型量化
本文主要介绍使用Paddle-Lite加载PaddlePaddle产出的量化模型,并进行推理执行。我们以MobileNetV1模型为示例,首先介绍准备量化模型,然后介绍部署执行。
## 准备量化模型
PaddlePaddle使用量化训练和训练后量化两种方法将FP32模型量化成Int8模型,下面分别介绍两种方法如何产出量化模型。
### 量化训练
目前,PaddlePaddle框架的量化训练主要针对卷积层(包括二维卷积和Depthwise卷积)、和全连接层,对应算子是conv2d、depthwise_conv2d和mul,更多量化训练的原理请参考[文档](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#1-quantization-aware-training%E9%87%8F%E5%8C%96%E4%BB%8B%E7%BB%8D)。Paddle-Lite支持运行PaddlePaddle框架量化训练产出的模型,可以进一步加快模型在移动端的执行速度。
温馨提示:如果您是初次接触PaddlePaddle框架,建议首先学习[新人入门](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/index_cn.html)[使用指南](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/index_cn.html)
您可以选择下载训练好的量化模型,或者使用PaddleSlim模型压缩工具训练得到量化模型。
#### 下载量化模型
官方发布了[MobileNetV1量化模型](https://paddle-inference-dist.bj.bcebos.com/int8%2Fpretrain%2Fmobilenet_v1_quant%2Ffloat.zip),直接下载到本地。
```bash
wget https://paddle-inference-dist.bj.bcebos.com/int8%2Fpretrain%2Fmobilenet_v1_quant%2Ffloat.zip
```
#### 使用PaddleSlim模型压缩工具训练量化模型
##### 安装PaddlePaddle
根据操作系统、安装方式、Python版本和CUDA版本,按照[官方说明](https://paddlepaddle.org.cn/start)安装PaddlePaddle。例如:
Ubuntu 16.04.4 LTS操作系统,CUDA9,cuDNN7,GPU版本安装:
```bash
pip install paddlepaddle-gpu==1.6.0.post97 -i https://mirrors.aliyun.com/pypi/simple/
```
Ubuntu 16.04.4 LTS操作系统,CPU版本安装:
```bash
pip install paddlepaddle==1.6.0 -i https://mirrors.aliyun.com/pypi/simple/
```
##### 克隆量化训练所需的代码库
克隆[PaddlePaddle/models](https://github.com/PaddlePaddle/models)到本地,并进入models/PaddleSlim路径。
```bash
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleSlim
```
##### 数据准备
###### 训练数据准备
参考[models/PaddleCV/image_classification](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#data-preparation)中的数据准备教程,下载训练数据,并且保存到PaddleSlim/data路径下。
###### 预训练模型准备
参考/models/PaddleSlim/run.sh脚本, 从[models/PaddleCV/image_classification](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification#supported-models-and-performances)下载MobileNetV1的预训练模型,并保存到PaddleSlim/pretrain路径下。
经过以上三步,PaddleSlim目录下的文件结构如下所示:
```bash
.
├── compress.py # 模型压缩任务主脚本,定义了压缩任务需要的模型相关信息
├── configs # 压缩任务的配置文件,包括:蒸馏、int8量化量化、filter剪切和组合策略的配置文件
├── data # 存放训练数据(需要用户自己创建)
│   └── ILSVRC2012
├── pretrain # 存放预训练模型参数,执行run.sh自动生成
│   ├── MobileNetV1_pretrained
│   ├── MobileNetV1_pretrained.tar
│   ├── ResNet50_pretrained
│   └── ResNet50_pretrained.tar
├── docs # 文档目录
├── light_nas
├── models # 模型网络结构的定义,如MobileNetV1
├── quant_low_level_api # 量化训练的底层API, 用于灵活定制量化训练的过程,适用于高阶用户
├── reader.py # 定义数据处理逻辑
├── README.md
├── run.sh # 模型压缩任务启动脚本
└── utility.py # 定义了常用的工具方法
```
##### 压缩脚本介绍
`compress.py`中定义了执行压缩任务需要的所有模型相关的信息,这里对几个关键的步骤进行简要介绍:
###### 目标网络的定义
compress.py的以下代码片段定义了train program, 这里train program只有前向计算操作。
```python
out = model.net(input=image, class_dim=args.class_dim)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
```
然后,通过clone方法得到eval_program, 用来在压缩过程中评估模型精度,如下:
```python
val_program = fluid.default_main_program().clone()
```
定义完目标网络结构,需要对其初始化,并根据需要加载预训练模型。
###### 定义feed_list和fetch_list
对于train program, 定义train_feed_list用于指定从train data reader中取的数据feed给哪些variable。定义train_fetch_list用于指定在训练时,需要在log中展示的结果。如果需要在训练过程中在log中打印accuracy信心,则将('acc_top1', acc_top1.name)添加到train_fetch_list中即可。
```python
train_feed_list = [('image', image.name), ('label', label.name)]
train_fetch_list = [('loss', avg_cost.name)]
```
> 注意: 在train_fetch_list里必须有loss这一项。
对于eval program. 同上定义eval_feed_list和train_fetch_list:
```python
val_feed_list = [('image', image.name), ('label', label.name)]
val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', acc_top5.name)]
```
###### Compressor和量化配置文件
`compress.py`主要使用Compressor和yaml文件完成对模型的量化训练工作。Compressor类的定义如下:
```python
class Compressor(object):
def __init__(self,
place,
scope,
train_program,
train_reader=None,
train_feed_list=None,
train_fetch_list=None,
eval_program=None,
eval_reader=None,
eval_feed_list=None,
eval_fetch_list=None,
teacher_programs=[],
checkpoint_path='./checkpoints',
train_optimizer=None,
distiller_optimizer=None):
```
在定义Compressor对象时,需要注意以下问题:
* train program如果带反向operators和优化更新相关的operators, 参数train_optimizer需要设置为None.
* eval_program中parameter的名称需要与train_program中的parameter的名称完全一致。
* 最终保存的量化模型是在eval_program网络基础上进行剪枝保存的。所以,如果用户希望最终保存的模型可以用于inference, 则eval program需要包含推理阶段需要的各种operators.
* checkpoint保存的是float数据类型的模型。
`configs/quantization.yaml`量化配置文件示例如下:
```python
version: 1.0
strategies:
quantization_strategy:
class: 'QuantizationStrategy'
start_epoch: 0
end_epoch: 9
float_model_save_path: './output/float'
mobile_model_save_path: './output/mobile'
int8_model_save_path: './output/int8'
weight_bits: 8
activation_bits: 8
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
save_in_nodes: ['image']
save_out_nodes: ['fc_0.tmp_2']
compressor:
epoch: 10
checkpoint_path: './checkpoints_quan/'
strategies:
- quantization_strategy
```
其中,可配置参数包括:
- **class:** 量化策略的类名称,目前仅支持`QuantizationStrategy`
- **start_epoch:** 在start_epoch开始之前,量化训练策略会往train_program和eval_program插入量化operators和反量化operators。 从start_epoch开始,进入量化训练阶段。
- **end_epoch:** 在end_epoch结束之后,会保存用户指定格式的模型。注意:end_epoch之后并不会停止量化训练,而是继续训练直到epoch数等于compressor.epoch值为止。举例来说,当start_epoch=0,end_epoch=0,compressor.epoch=2时,量化训练开始于epoch0,结束于epoch1,但保存的模型是epoch0结束时的参数状态。
- **float_model_save_path:** 保存float数据格式的模型路径,即该路径下的模型参数范围为int8范围但参数数据类型为float32。如果设置为None, 则不存储float格式的模型,默认为None。**注意:Paddle-Lite即使用该目录下的模型进行量化模型推理优化,详见本文[使用Paddle-Lite运行量化模型推理](#二使用Paddle-Lite运行量化模型推理)部分。**
- **int8_model_save_path:** 保存int8数据格式的模型路径,即该路径下的模型参数范围为int8范围且参数数据类型为int8。如果设置为None, 则不存储int8格式的模型,默认为None.
- **mobile_model_save_path:** 保存兼容paddle-mobile框架的模型路径。如果设置为None, 则不存储paddle-mobile格式的模型,默认为None。目前paddle-mobile已升级为Paddle-Lite。
- **weight_bits:** 量化weight的bit数,注意偏置(bias)参数不会被量化。
- **activation_bits:** 量化activation的bit数。
- **weight_quantize_type:** weight量化方式,目前量化训练支持`abs_max``channel_wise_abs_max`
- **activation_quantize_type:** activation量化方式,目前量化训练支持`range_abs_max``moving_average_abs_max`。PaddlePaddle中还支持 `abs_max` 方法对激活进行量化,但是该方法动态计算输入的量化scale,这会增加计算量、减慢模型推理速度,所以lite不支持 `abs_max`激活量化方式。
- **save_in_nodes:** variable名称列表。在保存量化后模型的时候,需要根据save_in_nodes对eval programg 网络进行前向遍历剪枝。默认为eval_feed_list内指定的variable的名称列表。
- **save_out_nodes:** varibale名称列表。在保存量化后模型的时候,需要根据save_out_nodes对eval programg 网络进行回溯剪枝。默认为eval_fetch_list内指定的variable的名称列表。
> **备注:**
>
> 1)`abs_max`意为在训练的每个step及inference阶段均动态计算量化scale值。`channel_wise_abs_max`与`abs_max`类似,不同点在于它会对卷积权重进行分channel求取量化scale。换言之,`abs_max`属于tensor-wise量化,而`channel_wise_abs_max`属于channel-wise量化,详细说明请猛戳[此处](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/quantization/training_quantization_model_format.md)。
>
> 2)`moving_average_abs_max`和`range_abs_max`意为在训练阶段计算出一个静态的量化scale值,并将其用于inference阶段。`moving_average_abs_max`使用窗口滑动平均的方法计算量化scale,而`range_abs_max`则使用窗口绝对值最大值的方式。
>
> 3)**目前,Paddle-Lite仅支持运行weight量化方式使用`abs_max`且activation量化方式使用`moving_average_abs_max`或`range_abs_max`产出的量化模型**。
##### 执行int8量化训练
修改run.sh,即注释掉`# enable GC strategy``# for sensitivity filter pruning`之间的内容并打开`#for quantization`相关的脚本命令(所需打开注释的命令如下所示)。
```bash
# for quantization
#---------------------------
export CUDA_VISIBLE_DEVICES=0
python compress.py \
--batch_size 64 \
--model "MobileNet" \
--pretrained_model ./pretrain/MobileNetV1_pretrained \
--compress_config ./configs/quantization.yaml \
--quant_only True
```
最后,运行`sh run.sh`命令开始int8量化训练。
上述量化训练过程完成后,若按照本文中所述`configs/quantization.yaml`文件内容配置的模型输出路径,则可在models/PaddleSlim/output目录下看到`float``int8``mobile`三个目录,其中:
* float目录: 参数范围为int8范围但参数数据类型为float32的量化模型。Paddle-Lite即使用该目录下的模型文件及参数进行量化模型的部署。
* int8目录: 参数范围为int8范围且参数数据类型为int8的量化模型。
* mobile目录:参数特点与int8目录相同且兼容paddle-mobile的量化模型(目前paddle-mobile已升级为Paddle-Lite)。
### 训练后量化
下面以MobileNetV1为例,介绍使用训练后量化方法产出量化模型。关于训练后量化的原理和详细使用方法,请参考[文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api)
> 该示例的代码放在[models/PaddleSlim/quant_low_level_api/](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api)目录下。如果需要执行该示例,首先clone下来[models](https://github.com/PaddlePaddle/models.git),安装具有训练后量化功能的PaddlePaddle。因为目前Lite支持支持对conv2d、depthwise_conv2d和mul量化,所以修改[run_post_training_quanzation.sh](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/quant_low_level_api/run_post_training_quanzation.sh) 脚本,设置is_full_quantize=False,然后执行该脚本;执行结束后,量化模型保存在`mobilenetv1_int8_model`目录下。下面介绍详细步骤。
1)**准备模型和校准数据**
安装PaddlePaddle的develop分支编译的whl包,准备已经训练好的FP32预测模型。
准备校准数据,文件结构如下。val文件夹中有100张图片,val_list.txt文件中包含图片的label。
```bash
samples_100
└──val
└──val_list.txt
```
2)**配置校准数据生成器**
MobileNetV1的输入是图片和标签,所以配置读取校准数据的sample_generator,每次返回一张图片和一个标签。详细代码在[models/PaddleSlim/reader.py](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/reader.py)
3)**调用训练后量化**
调用训练后量化的核心代码如下,详细代码在[post_training_quantization.py](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/quant_low_level_api/post_training_quantization.py)
``` python
place = fluid.CUDAPlace(0) if args.use_gpu == "True" else fluid.CPUPlace()
exe = fluid.Executor(place)
sample_generator = reader.val(data_dir=args.data_path)
ptq = PostTrainingQuantization(
executor=exe,
sample_generator=sample_generator,
model_dir=args.model_dir,
model_filename=args.model_filename,
params_filename=args.params_filename,
batch_size=args.batch_size,
batch_nums=args.batch_nums,
algo=args.algo,
is_full_quantize=args.is_full_quantize == "True")
quantized_program = ptq.quantize()
ptq.save_quantized_model(args.save_model_path)
```
## 使用Paddle-Lite运行量化模型推理
#### 使用模型优化工具对量化模型进行优化
接下来,使用原始的量化模型生成适合在移动端直接部署的模型。
参考[源码编译](../source_compile)配置编译环境,确保可以编译成功。参考[模型转化方法](../model_optimize_tool),首先编译model_optimize_tool工具,然后执行下面命令对量化训练的模型进行优化(注意,需要自行修改model_file、param_file和optimize_out)。
```bash
./model_optimize_tool \
--model_file=mobilenet_v1_quant/float/model \
--param_file=mobilenet_v1_quant/float/weights \
--optimize_out_type=naive_buffer \
--optimize_out=mobilenet_v1_quant_opt \
--valid_targets=arm \
--prefer_int8_kernel=true
```
如前所述,量化训练后,float目录下的模型参数范围为int8,但参数数据类型仍为float32类型,这样确实没有起到模型参数压缩的效果。但是,经过model\_optimize\_tool工具优化后对应的量化参数均会以int8类型重新存储达到参数压缩的效果,且模型结构也被优化(如进行了各种operator fuse操作)。
#### 在手机端准备量化模型文件
使用如下命令将mobilenet_v1_quant_opt目录下的量化模型文件导入到手机端:
```bash
adb push mobilenet_v1_quant_opt /data/local/tmp
```
#### 使用mobilenetv1\_light\_api运行优化后的量化模型
参考[源码编译](../source_compile)配置编译环境后,在Paddle-Lite执行如下命令获取轻量级API的demo:
```bash
cd /Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mobile_light
make clean && make -j
```
执行完上述命令后,可在`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mobile_light/`路径下看到`mobilenetv1_light_api`可执行文件。将`mobilenetv1_light_api`导入到手机端并运行量化模型推理。执行命令如下:
```bash
adb push Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mobile_light/mobilenetv1_light_api /data/local/tmp
adb shell chmod +x /data/local/tmp/mobilenetv1_light_api
adb shell /data/local/tmp/mobilenetv1_light_api \
--model_dir=/data/local/tmp/mobilenet_v1_quant_opt
```
**程序运行结果如下:**
```bash
Output dim: 1000
Output[0]: 0.000228
Output[100]: 0.000260
Output[200]: 0.000250
Output[300]: 0.000560
Output[400]: 0.000950
Output[500]: 0.000275
Output[600]: 0.005143
Output[700]: 0.002509
Output[800]: 0.000538
Output[900]: 0.000969
```
在C++中使用Paddle-Lite API的方法请猛戳[此处](../cpp_demo),用户也可参考[mobilenetv1_light_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc)的代码示例。
### FAQ
**问题**:Compiled with WITH_GPU, but no GPU found in runtime
**解答**:检查本机是否支持GPU训练,如果不支持请使用CPU训练。如果在docker进行GPU训练,请使用nvidia_docker启动容器。
**问题**:Inufficient GPU memory to allocation. at [/paddle/paddle/fluid/platform/gpu_info.cc:262]
**解答**:正确设置run.sh脚本中`CUDA_VISIBLE_DEVICES`,确保显卡剩余内存大于需要内存。
# 支持OP列表
## Ops
- affine_channel
- anchor_generator
- arg_max
- assign
- assign_value
- attention_padding_mask
- axpy
- batch_norm
- beam_search
- beam_search_decode
- bilinear_interp
- box_clip
- box_coder
- calib
- calib_once
- cast
- collect_fpn_proposals
- concat
- conditional_block
- conv2d
- conv2d_transpose
- crop
- decode_bboxes
- density_prior_box
- depthwise_conv2d
- distribute_fpn_proposals
- dropout
- elementwise_add
- elementwise_div
- elementwise_max
- elementwise_mul
- elementwise_sub
- equal
- exp
- expand
- fake_channel_wise_dequantize_max_abs
- fake_dequantize_max_abs
- fake_quantize_dequantize_moving_average_abs_max
- fake_quantize_moving_average_abs_max
- fake_quantize_range_abs_max
- fc
- feed
- fetch
- fill_constant
- fill_constant_batch_size_like
- flatten
- flatten2
- floor
- fusion_elementwise_add_activation
- fusion_elementwise_div_activation
- fusion_elementwise_max_activation
- fusion_elementwise_mul_activation
- fusion_elementwise_sub_activation
- gather
- generate_proposals
- graph_op
- greater_equal
- greater_than
- gru
- gru_unit
- hard_sigmoid
- im2sequence
- increment
- instance_norm
- io_copy
- io_copy_once
- is_empty
- layer_norm
- layout
- layout_once
- leaky_relu
- less_equal
- less_than
- lod_reset
- log
- logical_and
- logical_not
- logical_or
- logical_xor
- lookup_table
- lookup_table_v2
- lrn
- match_matrix_tensor
- matmul
- mean
- merge_lod_tensor
- mul
- multiclass_nms
- nearest_interp
- negative
- norm
- notequal
- pad2d
- pool2d
- power
- prelu
- prior_box
- range
- read_from_array
- reduce_max
- reduce_mean
- reduce_prod
- reduce_sum
- relu
- relu6
- relu_clipped
- reshape
- reshape2
- roi_align
- rsqrt
- scale
- search_aligned_mat_mul
- search_attention_padding_mask
- search_fc
- search_grnn
- search_group_padding
- search_seq_arithmetic
- search_seq_depadding
- search_seq_fc
- search_seq_softmax
- sequence_arithmetic
- sequence_concat
- sequence_expand
- sequence_expand_as
- sequence_pool
- sequence_reshape
- sequence_reverse
- sequence_softmax
- sequence_topk_avg_pooling
- shape
- shuffle_channel
- sigmoid
- slice
- softmax
- softsign
- split
- split_lod_tensor
- sqrt
- square
- squeeze
- squeeze2
- stack
- swish
- tanh
- top_k
- transpose
- transpose2
- uniform_random
- unsqueeze
- unsqueeze2
- var_conv_2d
- while
- write_to_array
- yolo_box
## Kernels
### Host kernels
- feed
- fetch
- flatten
- flatten2
- multiclass_nms
- reshape
- reshape2
### ARM kernels
- affine_channel
- anchor_generator
- arg_max
- assign
- assign_value
- axpy
- batch_norm
- beam_search
- beam_search_decode
- bilinear_interp
- box_clip
- box_coder
- cast
- collect_fpn_proposals
- concat
- conditional_block
- conv2d
- conv2d_transpose
- crop
- decode_bboxes
- density_prior_box
- depthwise_conv2d
- distribute_fpn_proposals
- dropout
- elementwise_add
- elementwise_div
- elementwise_max
- elementwise_mul
- elementwise_sub
- equal
- exp
- expand
- fc
- fill_constant
- fill_constant_batch_size_like
- floor
- fusion_elementwise_add_activation
- fusion_elementwise_div_activation
- fusion_elementwise_max_activation
- fusion_elementwise_mul_activation
- fusion_elementwise_sub_activation
- gather
- generate_proposals
- greater_equal
- greater_than
- gru
- gru_unit
- hard_sigmoid
- im2sequence
- increment
- instance_norm
- is_empty
- layer_norm
- layout
- layout_once
- leaky_relu
- less_equal
- less_than
- lod_reset
- log
- logical_and
- logical_not
- logical_or
- logical_xor
- lookup_table
- lookup_table_v2
- lrn
- matmul
- merge_lod_tensor
- mul
- nearest_interp
- negative
- norm
- not_equal
- pad2d
- pool2d
- power
- prelu
- prior_box
- range
- read_from_array
- reduce_max
- reduce_mean
- reduce_prod
- relu
- relu6
- relu_clipped
- roi_align
- rsqrt
- scale
- sequence_expand
- sequence_pool
- sequence_softmax
- shape
- shuffle_channel
- sigmoid
- slice
- softmax
- split
- split_lod_tensor
- squeeze
- squeeze2
- stack
- swish
- tanh
- top_k
- transpose
- transpose2
- unsqueeze
- unsqueeze2
- while
- write_to_array
- yolo_box
### X86 kernels
- batch_norm
- cast
- concat
- conv2d
- depthwise_conv2d
- dropout
- elementwise_add
- elementwise_sub
- fc
- fill_constant_batch_size_like
- gather
- gelu
- gru
- layer_norm
- match_matrix_tensor
- matmul
- mul
- pool2d
- reduce_sum
- relu
- reshape
- reshape2
- scale
- search_aligned_mat_mul
- search_attention_padding_mask
- search_fc
- search_grnn
- search_group_padding
- search_seq_arithmetic
- search_seq_depadding
- search_seq_fc
- search_seq_softmax
- sequence_arithmetic
- sequence_concat
- sequence_expand_as
- sequence_pool
- sequence_reverse
- sequence_topk_avg_pooling
- shape
- slice
- softmax
- softsign
- square
- squeeze
- squeeze2
- stack
- tanh
- transpose
- transpose2
- var_conv_2d
### CUDA kernels
- attention_padding_mask
- bilinear_interp
- calib
- concat
- conv
- dropout
- elementwise_add
- fusion_elementwise_add_activation
- fusion_elementwise_mul_activation
- elementwise_mul
- feed
- io_copy
- layout
- layout_once
- leaky_relu
- lookup_table
- match_matrix_tensor
- mul
- nearest_interp
- pool2d
- relu
- scale
- search_aligned_mat_mul
- search_fc
- search_grnn
- search_group_padding
- search_seq_depadding
- search_seq_fc
- sequence_arithmetic
- sequence_concat
- sequence_pool
- sequence_reverse
- sequence_topk_avg_pooling
- softmax
- transpose
- var_conv_2d
- yolo_box
### OpenCL kernels
- conv2d
- depthwise_conv2d
- elementwise_add
- fc
- fusion_elementwise_add_activation
- layout
- layout_once
- io_copy
- io_copy_once
- mul
- pool2d
- relu
......@@ -55,7 +55,8 @@
#### paddlepaddle model
骁龙855|armv7 | | |armv8 | | |
骁龙855|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4
mobilenet_v1 |32.19 |18.81 |10.90 |30.92 |18.31 |10.15
......@@ -64,7 +65,9 @@ shufflenet_v2 |4.67 |3.37 |2.65 |4.43 |3.15 |2.66
squeezenet_v1.1 |25.10 |15.93 |9.68 |23.28 |14.61 |8.71
mnasnet |21.84 |13.14 |7.96 |19.61 |11.88 |7.55
骁龙835|armv7 | | |armv8 | | |
骁龙835|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4
mobilenet_v1 |94.13 |52.17 |30.68 |88.28 |47.58 |26.64
......@@ -74,7 +77,7 @@ squeezenet_v1.1 |73.61 |42.25 |24.44 |64.87 |38.43 |23.06
mnasnet |58.22 |33.43 |20.44 |53.43 |30.20 |18.09
麒麟980|armv7 | | |armv8 | | |
麒麟980|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4
mobilenet_v1 |55.11 |28.24 |13.27 |34.24 |17.74 |12.41
......@@ -83,7 +86,7 @@ shufflenet_v2 |7.26 |4.94 |15.06 |5.32 |3.33 |2.82
squeezenet_v1.1 |42.73 |23.66 |57.39 |26.03 |14.53 |13.66
mnasnet |36.87 |20.15 |46.04 |21.85 |12.06 |8.68
麒麟970|armv7 | | |armv8 | | |
麒麟970|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4
mobilenet_v1 |97.80 |52.64 |34.46 |94.51 |49.36 |28.43
......@@ -94,32 +97,32 @@ mnasnet |61.86 |34.62 |22.68 |59.61 |32.79 |19.56
#### caffe model
骁龙855|armv7 | | |armv8 | | |
----| ---- | ---- | ---- | ---- |---- |----|
骁龙855|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4 |
mobilenet_v1 |32.42 |18.68 |10.86 |30.92 |18.35 |10.07 |
mobilenet_v2 |29.53 |17.76 |10.89 |27.19 |16.53 |9.75 |
shufflenet_v2 |4.61 |3.29 |2.61 |4.36 |3.11 |2.51 |
骁龙835|armv7 | | |armv8 | | |
----| ---- | ---- | ---- | ---- |---- |----|
骁龙835|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4 |
mobilenet_v1 |92.52 |52.34 |30.37 |88.31 |49.75 |27.29 |
mobilenet_v2 |79.50 |45.67 |28.79 |76.13 |44.01 |26.13 |
shufflenet_v2 |10.94 |7.08 |5.16 |10.64 |6.83 |5.01 |
麒麟980|armv7 | | |armv8 | | |
----| ---- | ---- | ---- | ---- |---- |----|
麒麟980|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4 |
mobilenet_v1 |55.36 |28.18 |13.31 |34.42 |17.93 |12.52 |
mobilenet_v2 |49.17 |26.10 |65.49 |30.50 |16.66 |11.72 |
shufflenet_v2 |8.45 |5.00 |15.65 |4.58 |3.14 |2.83 |
麒麟970|armv7 | | |armv8 | | |
----| ---- | ---- | ---- | ---- |---- |----|
麒麟970|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4 |
mobilenet_v1 |97.85 |53.38 |33.85 |94.29 |49.42 |28.29 |
mobilenet_v2 |87.40 |50.25 |31.85 |85.55 |48.11 |28.24 |
......@@ -127,21 +130,21 @@ shufflenet_v2 |12.16 |8.39 |6.21 |12.21 |8.33 |6.32 |
#### int8量化模型测试数据
骁龙855|armv7 | | |armv8 | | |
----| ---- | ---- | ---- | ---- |---- |----|
骁龙855|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4 |
mobilenet_v1 |36.80 |21.58 |11.12 | 14.01 |8.13 |4.32 |
mobilenet_v2 |28.72 |19.08 |12.49 | 17.24 |11.55 |7.82 |
骁龙835|armv7 | | |armv8 | | |
----| ---- | ---- | ---- | ---- |---- |----|
骁龙835|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4 |
mobilenet_v1 |60.76 |32.25 |16.66 |56.57 |29.84 |15.24 |
mobilenet_v2 |49.38 |31.10 |22.07 |47.52 |28.18 |19.24 |
麒麟970|armv7 | | |armv8 | | |
----| ---- | ---- | ---- | ---- |---- |----|
麒麟970|armv7 | armv7 | armv7 |armv8 | armv8 |armv8
----| ---- | ---- | ---- | ---- |---- |----
threads num|1 |2 |4 |1 |2 |4 |
mobilenet_v1 |65.95 |34.39 |18.68 |60.86 |30.98 |16.31 |
mobilenet_v2 |68.87 |39.39 |24.43 |65.57 |37.31 |20.87 |
......@@ -41,7 +41,7 @@ release = u''
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['recommonmark']
extensions = ['recommonmark', 'sphinx_markdown_tables']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
......
......@@ -38,6 +38,10 @@ Welcome to Paddle-Lite's documentation!
:maxdepth: 1
:caption: 进阶使用指南
advanced_user_guides/support_operation_list
advanced_user_guides/add_operation
advanced_user_guides/model_quantization
.. toctree::
:maxdepth: 1
:caption: 开发者文档
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册