From 36d1bc7b2924f26936a8f768fba97415899bc3b3 Mon Sep 17 00:00:00 2001 From: Liufang Sang Date: Tue, 14 Jan 2020 13:49:16 +0800 Subject: [PATCH] change to new quantization api in quantization demo (#127) * add slim quantization demo * add quantization demo * add skip quant * fix details in doc * fix details * fix details * fix details * fix details * fix details * fix details * add model zoo link in doc * remove result in doc --- ppdet/modeling/anchor_heads/yolo_head.py | 30 +- slim/quantization/README.md | 241 +++++---------- slim/quantization/compress.py | 270 ---------------- slim/quantization/eval.py | 177 +++++++++++ slim/quantization/export_model.py | 120 ++++++++ slim/quantization/infer.py | 201 ++++++++++++ slim/quantization/train.py | 291 ++++++++++++++++++ .../yolov3_mobilenet_v1_slim.yaml | 20 -- 8 files changed, 888 insertions(+), 462 deletions(-) delete mode 100644 slim/quantization/compress.py create mode 100644 slim/quantization/eval.py create mode 100644 slim/quantization/export_model.py create mode 100644 slim/quantization/infer.py create mode 100644 slim/quantization/train.py delete mode 100644 slim/quantization/yolov3_mobilenet_v1_slim.yaml diff --git a/ppdet/modeling/anchor_heads/yolo_head.py b/ppdet/modeling/anchor_heads/yolo_head.py index 647647596..b8142d66d 100644 --- a/ppdet/modeling/anchor_heads/yolo_head.py +++ b/ppdet/modeling/anchor_heads/yolo_head.py @@ -221,20 +221,22 @@ class YOLOv3Head(object): # out channel number = mask_num * (5 + class_num) num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) - block_out = fluid.layers.conv2d( - input=tip, - num_filters=num_filters, - filter_size=1, - stride=1, - padding=0, - act=None, - param_attr=ParamAttr(name=self.prefix_name + - "yolo_output.{}.conv.weights".format(i)), - bias_attr=ParamAttr( - regularizer=L2Decay(0.), - name=self.prefix_name + - "yolo_output.{}.conv.bias".format(i))) - outputs.append(block_out) + with fluid.name_scope('yolo_output'): + block_out = fluid.layers.conv2d( + input=tip, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + act=None, + param_attr=ParamAttr( + name=self.prefix_name + + "yolo_output.{}.conv.weights".format(i)), + bias_attr=ParamAttr( + regularizer=L2Decay(0.), + name=self.prefix_name + + "yolo_output.{}.conv.bias".format(i))) + outputs.append(block_out) if i < len(blocks) - 1: # do not perform upsample in the last detection_block diff --git a/slim/quantization/README.md b/slim/quantization/README.md index 159b7a7f8..fb002eb4b 100644 --- a/slim/quantization/README.md +++ b/slim/quantization/README.md @@ -1,241 +1,166 @@ ->运行该示例前请安装Paddle1.6或更高版本 +>运行该示例前请安装Paddle1.6或更高版本和PaddleSlim # 检测模型量化压缩示例 ## 概述 -该示例使用PaddleSlim提供的[量化压缩策略](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#1-quantization-aware-training%E9%87%8F%E5%8C%96%E4%BB%8B%E7%BB%8D)对分类模型进行压缩。 +该示例使用PaddleSlim提供的[量化压缩API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)对检测模型进行压缩。 在阅读该示例前,建议您先了解以下内容: - [检测模型的常规训练方法](https://github.com/PaddlePaddle/PaddleDetection) -- [PaddleSlim使用文档](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md) +- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/) -## 配置文件说明 +## 安装PaddleSlim +可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim。 -关于配置文件如何编写您可以参考: -- [PaddleSlim配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8) -- 
[量化策略配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#21-%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83)

+## 训练
+
+根据 [tools/train.py](../../tools/train.py) 编写压缩脚本train.py。脚本中量化的步骤如下。
+
+### 定义量化配置
+
+```
+config = {
+    'weight_quantize_type': 'channel_wise_abs_max',
+    'activation_quantize_type': 'moving_average_abs_max',
+    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+    'not_quant_pattern': ['yolo_output']
+}
+```

-其中save_out_nodes需要得到检测结果的Variable的名称,下面介绍如何确定save_out_nodes的参数
-以MobileNet V1为例,可在compress.py中构建好网络之后,直接打印Variable得到Variable的名称信息。
-代码示例:
-```
-    eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
-                                                     extra_keys)
-    # print(eval_values)
-```
-根据运行结果可看到Variable的名字为:`multiclass_nms_0.tmp_0`。
+
+如何配置以及各配置项的含义请参考[PaddleSlim 量化API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)。
+
+### 插入量化反量化OP
+使用[PaddleSlim quant_aware API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/#quant_aware)在Program中插入量化和反量化OP。
+```
+train_prog = quant_aware(train_prog, place, config, for_test=False)
+```

-## 训练
-
-根据 [tools/train.py](https://github.com/PaddlePaddle/PaddleDetection/tree/master/tools/train.py) 编写压缩脚本compress.py。
-在该脚本中定义了Compressor对象,用于执行压缩任务。
-
-通过`python compress.py --help`查看可配置参数,简述如下:
-
-- config: 检测库的配置,其中配置了训练超参数、数据集信息等。
-- slim_file: PaddleSlim的配置文件,参见[配置文件说明](#配置文件说明)。
+
+### 关闭一些训练策略
+
+因为量化要对Program做修改,所以一些会修改Program的训练策略需要关闭。``sync_batch_norm`` 和量化多卡训练同时使用时会出错,原因暂时未知,因此也需要将其关闭。
+```
+build_strategy.fuse_all_reduce_ops = False
+build_strategy.sync_batch_norm = False
+```
+
+### 开始训练
+
-您可以通过运行以下命令运行该示例。
+您可以通过运行以下命令运行该示例。(该示例是在COCO数据集上训练yolov3-mobilenetv1,替换模型和数据集的方法和检测库类似,直接替换相应的配置文件即可)

 step1: 设置gpu卡
 ```
 export CUDA_VISIBLE_DEVICES=0
 ```
 step2: 开始训练

-使用PaddleDetection提供的配置文件在用8卡进行训练:
+请在PaddleDetection根目录下运行:
+```
-python compress.py \
-    -s yolov3_mobilenet_v1_slim.yaml \
-    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
-    -d "../../dataset/voc" \
-    -o max_iters=258 \
+python slim/quantization/train.py \
+    --eval \
+    -c ./configs/yolov3_mobilenet_v1.yml \
+    -o max_iters=30000 \
+    save_dir=./output/mobilenetv1 \
     LearningRate.base_lr=0.0001 \
-    LearningRate.schedulers="[!PiecewiseDecay {gamma: 0.1, milestones: [258, 516]}]" \
-    pretrain_weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar \
-    YoloTrainFeed.batch_size=64
+    LearningRate.schedulers='[!PiecewiseDecay {gamma: 0.1, milestones: [10000]}]' \
+    pretrain_weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar
+```

->通过命令行覆盖设置max_iters选项,因为PaddleDetection中训练是以`batch`为单位迭代的,并没有涉及`epoch`的概念,但是PaddleSlim需要知道当前训练进行到第几个`epoch`, 所以需要将`max_iters`设置为一个`epoch`内的`batch`的数量。
-
-如果要调整训练卡数,需要调整配置文件`yolov3_mobilenet_v1_voc.yml`中的以下参数:
+>通过命令行覆盖设置max_iters选项,因为量化训练所需的轮次比正常训练少很多,所以需要修改此选项。
+
+如果要调整训练卡数,可根据需要调整配置文件`yolov3_mobilenet_v1_voc.yml`中的以下参数:

-- **max_iters:** 一个`epoch`中batch的数量,需要设置为`total_num / batch_size`, 其中`total_num`为训练样本总数量,`batch_size`为多卡上总的batch size. 
-- **YoloTrainFeed.batch_size:** 当使用DataLoader时,表示单张卡上的batch size; 当使用普通reader时,则表示多卡上总的batch_size。batch_size受限于显存大小。
+- **max_iters:** 训练的总迭代次数。
- **LearningRate.base_lr:** 根据多卡的总`batch_size`调整`base_lr`,两者大小正相关,可以简单地按比例进行调整。
- **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其进行调整。
-- **LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:** 请根据batch size的变化对其进行调整。
-
-以下为4卡训练示例,通过命令行覆盖`yolov3_mobilenet_v1_voc.yml`中的参数:
-
-```
-python compress.py \
-    -s yolov3_mobilenet_v1_slim.yaml \
-    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
-    -d "../../dataset/voc" \
-    -o max_iters=258 \
-    LearningRate.base_lr=0.0001 \
-    LearningRate.schedulers="[!PiecewiseDecay {gamma: 0.1, milestones: [258, 516]}]" \
-    pretrain_weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar \
-    YoloTrainFeed.batch_size=64
-```
-
-以下为2卡训练示例,受显存所限,单卡`batch_size`不变, 总`batch_size`减小,`base_lr`减小,一个epoch内batch数量增加,同时需要调整学习率相关参数,如下:
-
-```
-python compress.py \
-    -s yolov3_mobilenet_v1_slim.yaml \
-    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
-    -d "../../dataset/voc" \
-    -o max_iters=516 \
-    LearningRate.base_lr=0.00005 \
-    LearningRate.schedulers="[!PiecewiseDecay {gamma: 0.1, milestones: [516, 1012]}]" \
-    pretrain_weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar \
-    YoloTrainFeed.batch_size=32
-```
-
-通过`python compress.py --help`查看可配置参数。
-通过`python ../../tools/configure.py ${option_name} help`查看如何通过命令行覆盖配置文件`yolov3_mobilenet_v1_voc.yml`中的参数。
+通过`python slim/quantization/train.py --help`查看可配置参数。
+通过`python tools/configure.py ${option_name} help`查看如何通过命令行覆盖配置文件中的参数。

 ### 训练时的模型结构

-这部分介绍来源于[量化low-level API介绍](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api#1-%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83low-level-apis%E4%BB%8B%E7%BB%8D)。
-
-PaddlePaddle框架中和量化相关的IrPass, 分别有QuantizationTransformPass、QuantizationFreezePass、ConvertToInt8Pass。在训练时,对网络应用了QuantizationTransformPass,作用是在网络中的conv2d、depthwise_conv2d、mul等算子的各个输入前插入连续的量化op和反量化op,并改变相应反向算子的某些输入。示例图如下:
+[PaddleSlim 量化API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)文档中介绍了``paddleslim.quant.quant_aware``和``paddleslim.quant.convert``两个接口。
+``paddleslim.quant.quant_aware`` 作用是在网络中的conv2d、depthwise_conv2d、mul等算子的各个输入前插入连续的量化op和反量化op,并改变相应反向算子的某些输入。示例图如下:

-图1:应用QuantizationTransformPass后的结果
+图1:应用 paddleslim.quant.quant_aware 后的结果

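+下面给出一段示意代码(仅为草稿,其中的小网络为随意构造,并非本库的真实模型),说明 ``not_quant_pattern`` 与 ``fluid.name_scope`` 的配合方式:yolo_head.py 中用 ``fluid.name_scope('yolo_output')`` 包裹了输出卷积层,因此配置 ``'not_quant_pattern': ['yolo_output']`` 后,这些输出层不会被插入量化op:
+
+```
+import paddle.fluid as fluid
+from paddleslim.quant import quant_aware
+
+place = fluid.CPUPlace()
+train_prog, startup_prog = fluid.Program(), fluid.Program()
+with fluid.program_guard(train_prog, startup_prog):
+    image = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
+    body = fluid.layers.conv2d(image, num_filters=8, filter_size=3)  # 该层会被量化
+    with fluid.name_scope('yolo_output'):  # 与 not_quant_pattern 中的字符串对应
+        out = fluid.layers.conv2d(body, num_filters=8, filter_size=1)  # 该层跳过量化
+
+config = {
+    'weight_quantize_type': 'channel_wise_abs_max',
+    'activation_quantize_type': 'moving_average_abs_max',
+    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+    'not_quant_pattern': ['yolo_output']
+}
+exe = fluid.Executor(place)
+exe.run(startup_prog)
+train_prog = quant_aware(train_prog, place, config, for_test=False)
+```
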
### 保存断点(checkpoint) +在脚本中使用保存checkpoint的代码为: +``` +# insert quantize op in eval_prog +eval_prog = quant_aware(eval_prog, place, config, for_test=True) +checkpoint.save(exe, eval_prog, os.path.join(save_dir, save_name)) +``` -如果在配置文件中设置了`checkpoint_path`, 则在压缩任务执行过程中会自动保存断点,当任务异常中断时, -重启任务会自动从`checkpoint_path`路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复, -需要修改配置文件中的`checkpoint_path`,或者将`checkpoint_path`路径下文件清空。 - ->注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。 - +### 边训练边测试 -### 保存评估和预测模型 +在脚本中边训练边测试得到的测试精度是基于图1中的网络结构进行的。 -如果在配置文件的量化策略中设置了`float_model_save_path`, `int8_model_save_path` 在训练结束后,会保存模型量化压缩之后用于预测的模型。接下来介绍这2种预测模型的区别。 +## 评估 -#### FP32模型 -在介绍量化训练时的模型结构时介绍了PaddlePaddle框架中和量化相关的IrPass, 分别是QuantizationTransformPass、QuantizationFreezePass、ConvertToInt8Pass。FP32模型是在应用QuantizationFreezePass并删除eval_program中多余的operators之后,保存的模型。 +### 最终评估模型 -QuantizationFreezePass主要用于改变IrGraph中量化op和反量化op的顺序,即将类似图1中的量化op和反量化op顺序改变为图2中的布局。除此之外,QuantizationFreezePass还会将`conv2d`、`depthwise_conv2d`、`mul`等算子的权重离线量化为int8_t范围内的值(但数据类型仍为float32),以减少预测过程中对权重的量化操作,示例如图2: +``paddleslim.quant.convert`` 主要用于改变Program中量化op和反量化op的顺序,即将类似图1中的量化op和反量化op顺序改变为图2中的布局。除此之外,``paddleslim.quant.convert`` 还会将`conv2d`、`depthwise_conv2d`、`mul`等算子参数变为量化后的int8_t范围内的值(但数据类型仍为float32),示例如图2:

-图2:应用QuantizationFreezePass后的结果
+图2:应用 paddleslim.quant.convert 后的结果

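+结合图2,下面给出得到最终评估模型的极简示意(仅为草稿:小网络为随意构造,被注释掉的参数加载路径为假设的占位,完整逻辑请以 eval.py 为准):
+
+```
+import paddle.fluid as fluid
+from paddleslim.quant import quant_aware, convert
+
+place = fluid.CPUPlace()
+exe = fluid.Executor(place)
+eval_prog, startup_prog = fluid.Program(), fluid.Program()
+with fluid.program_guard(eval_prog, startup_prog):
+    image = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
+    pred = fluid.layers.conv2d(image, num_filters=8, filter_size=3)
+eval_prog = eval_prog.clone(for_test=True)
+
+config = {
+    'weight_quantize_type': 'channel_wise_abs_max',
+    'activation_quantize_type': 'moving_average_abs_max',
+    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d']
+}
+# 训练中保存的checkpoint是图1的结构,因此先插入量化/反量化op,再加载参数
+eval_prog = quant_aware(eval_prog, place, config, for_test=True)
+exe.run(startup_prog)
+# fluid.io.load_persistables(exe, './output/best_model', main_program=eval_prog)
+# convert 之后才是图2的最终量化模型,可直接用于评估
+eval_prog = convert(eval_prog, place, config, save_int8=False)
+```
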
-#### 8-bit模型 -在对训练网络进行QuantizationFreezePass之后,执行ConvertToInt8Pass, -其主要目的是将执行完QuantizationFreezePass后输出的权重类型由`FP32`更改为`INT8`。换言之,用户可以选择将量化后的权重保存为float32类型(不执行ConvertToInt8Pass)或者int8_t类型(执行ConvertToInt8Pass),示例如图3: - -

-图3:应用ConvertToInt8Pass后的结果

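+旧版文档用 ConvertToInt8Pass 区分 FP32 与 8-bit 两种保存结果;在新API中与之对应的是 ``paddleslim.quant.convert`` 的 ``save_int8`` 参数。下面给出一段示意草稿(小网络与保存路径均为假设的示例,真实用法请参考 export_model.py):
+
+```
+import paddle.fluid as fluid
+from paddleslim.quant import quant_aware, convert
+
+place = fluid.CPUPlace()
+exe = fluid.Executor(place)
+infer_prog, startup_prog = fluid.Program(), fluid.Program()
+with fluid.program_guard(infer_prog, startup_prog):
+    image = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
+    out = fluid.layers.conv2d(image, num_filters=8, filter_size=3)
+infer_prog = infer_prog.clone(for_test=True)
+
+config = {
+    'weight_quantize_type': 'channel_wise_abs_max',
+    'activation_quantize_type': 'moving_average_abs_max',
+    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d']
+}
+infer_prog = quant_aware(infer_prog, place, config, for_test=True)
+exe.run(startup_prog)
+
+# save_int8=True 时,convert 额外返回一个权重为int8类型的Program(对应图3的结构),
+# 第一个返回值则是权重仍为float32的图2结构
+float_prog, int8_prog = convert(infer_prog, place, config, save_int8=True)
+float_out = float_prog.global_block().var(out.name)
+fluid.io.save_inference_model('./quant_model/float', ['image'], [float_out], exe,
+                              main_program=float_prog)
+int8_out = int8_prog.global_block().var(out.name)
+fluid.io.save_inference_model('./quant_model/int8', ['image'], [int8_out], exe,
+                              main_program=int8_prog)
+```
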

-> 综上,可得在量化过程中有以下几种模型结构:
-
-1. 原始模型
-2. 经QuantizationTransformPass之后得到的适用于训练的量化模型结构,在${checkpoint_path}下保存的`eval_model`是这种结构,在训练过程中每个epoch结束时也使用这个网络结构进行评估,虽然这个模型结构不是最终想要的模型结构,但是每个epoch的评估结果可用来挑选模型。
-3. 经QuantizationFreezePass之后得到的FP32模型结构,具体结构已在上面进行介绍。本文档中列出的数据集的评估结果是对FP32模型结构进行评估得到的结果。这种模型结构在训练过程中只会保存一次,也就是在量化配置文件中设置的`end_epoch`结束时进行保存,如果想将其他epoch的训练结果转化成FP32模型,可使用脚本 PaddleSlim/classification/quantization/freeze.py进行转化,具体使用方法在[评估](#评估)中介绍。
-4. 经ConvertToInt8Pass之后得到的8-bit模型结构,具体结构已在上面进行介绍。这种模型结构在训练过程中只会保存一次,也就是在量化配置文件中设置的`end_epoch`结束时进行保存,如果想将其他epoch的训练结果转化成8-bit模型,可使用脚本 slim/quantization/freeze.py进行转化,具体使用方法在[评估](#评估)中介绍。
+所以在调用 ``paddleslim.quant.convert`` 之后,才得到最终的量化模型。此模型可使用PaddleLite进行加载预测,可参见教程[Paddle-Lite如何加载运行量化模型](https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization)。
+
+### 评估脚本
+
+使用脚本[slim/quantization/eval.py](./eval.py)进行评估,步骤如下:

-## 评估
-
-### 每个epoch保存的评估模型
-因为量化的最终模型只有在end_epoch时保存一次,不能保证保存的模型是最好的,因此
-如果在配置文件中设置了`checkpoint_path`,则每个epoch会保存一个量化后的用于评估的模型,
-该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__`和`__params__`两个文件。
-其中,`__model__`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。模型结构和训练时一样。
+- 定义配置。使用和训练脚本中一样的量化配置,以得到和量化训练时同样的模型。
+- 使用 ``paddleslim.quant.quant_aware`` 插入量化和反量化op。
+- 使用 ``paddleslim.quant.convert`` 改变op顺序,得到最终量化模型进行评估。

-如果不需要保存评估模型,可以在定义Compressor对象时,将`save_eval_model`选项设置为False(默认为True)。
+评估命令:

-脚本slim/eval.py中为使用该模型在评估数据集上做评估的示例。
-运行命令为:
 ```
-python ../eval.py \
-    --model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
-    --model_name __model__ \
-    --params_name __params__ \
-    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
-    -d "../../dataset/voc"
+python slim/quantization/eval.py -c ./configs/yolov3_mobilenet_v1.yml \
+    -o weights=./output/mobilenetv1/yolov3_mobilenet_v1/best_model
 ```

-在评估之后,选取效果最好的epoch的模型,可使用脚本 slim/quantization/freeze.py将该模型转化为以上介绍的2种模型:FP32模型,int8模型,需要配置的参数为:
+## 导出模型

-- model_path, 加载的模型路径,为`${checkpoint_path}/${epoch_id}/eval_model/`
-- weight_quant_type 模型参数的量化方式,和配置文件中的类型保持一致
-- save_path `FP32`, `8-bit` 模型的保存路径,分别为 `${save_path}/float/`, `${save_path}/int8/`
+使用脚本[slim/quantization/export_model.py](./export_model.py)导出模型,步骤如下:

-运行命令示例:
-```
-python freeze.py \
-    --model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
-    --weight_quant_type ${weight_quant_type} \
-    --save_path ${any path you want} \
-    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
-    -d "../../dataset/voc"
-```
+- 定义配置。使用和训练脚本中一样的量化配置,以得到和量化训练时同样的模型。
+- 使用 ``paddleslim.quant.quant_aware`` 插入量化和反量化op。
+- 使用 ``paddleslim.quant.convert`` 改变op顺序,得到最终量化模型后再导出。
+
+导出模型命令:

-### 最终评估模型
-最终使用的评估模型是FP32模型,使用脚本slim/eval.py中为使用该模型在评估数据集上做评估的示例。
-运行命令为:
 ```
-python ../eval.py \
-    --model_path ${float_model_path} \
-    --model_name model \
-    --params_name weights \
-    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
-    -d "../../dataset/voc"
+python slim/quantization/export_model.py -c ./configs/yolov3_mobilenet_v1.yml --output_dir ${save_path} \
+    -o weights=./output/mobilenetv1/yolov3_mobilenet_v1/best_model
 ```

 ## 预测

 ### python预测

-FP32模型可直接使用原生PaddlePaddle Fluid预测方法进行预测。
-在脚本slim/infer.py中展示了如何使用fluid python API加载使用预测模型进行预测。
+在脚本slim/quantization/infer.py中展示了如何使用fluid python API加载预测模型进行预测。
 运行命令示例:

 ```
-python ../infer.py \
-    --model_path ${save_path}/float \
-    --model_name model \
-    --params_name weights \
-    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
-    --infer_dir ../../demo
+python slim/quantization/infer.py \
+    -c ./configs/yolov3_mobilenet_v1.yml \
+    --infer_dir ./demo \
+    -o weights=./output/mobilenetv1/yolov3_mobilenet_v1/best_model
 ```

### 
PaddleLite预测 -FP32模型可使用PaddleLite进行加载预测,可参见教程[Paddle-Lite如何加载运行量化模型](https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization) - - -## 示例结果 - ->当前release的结果并非超参调优后的最好结果,仅做示例参考,后续我们会优化当前结果。 +导出模型步骤中导出的FP32模型可使用PaddleLite进行加载预测,可参见教程[Paddle-Lite如何加载运行量化模型](https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization) -### MobileNetV1-YOLO-V3 -| weight量化方式 | activation量化方式| Box ap |Paddle Fluid inference time(ms)| Paddle Lite inference time(ms)| -|---|---|---|---|---| -|baseline|- |76.2%|- |-| -|abs_max|abs_max|- |- |-| -|abs_max|moving_average_abs_max|- |- |-| -|channel_wise_abs_max|abs_max|- |- |-| +## 量化结果 ## FAQ diff --git a/slim/quantization/compress.py b/slim/quantization/compress.py deleted file mode 100644 index 6ec072b28..000000000 --- a/slim/quantization/compress.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import time -import multiprocessing -import numpy as np -import datetime -from collections import deque -import sys -sys.path.append("../../") -from paddle.fluid.contrib.slim import Compressor -from paddle.fluid.framework import IrGraph -from paddle.fluid import core - - -def set_paddle_flags(**kwargs): - for key, value in kwargs.items(): - if os.environ.get(key, None) is None: - os.environ[key] = str(value) - - -# NOTE(paddle-dev): All of these flags should be set before -# `import paddle`. Otherwise, it would not take any effect. -set_paddle_flags( - FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory -) - -from paddle import fluid - -from ppdet.core.workspace import load_config, merge_config, create -from ppdet.data.data_feed import create_reader - -from ppdet.utils.eval_utils import parse_fetches, eval_results -from ppdet.utils.stats import TrainingStats -from ppdet.utils.cli import ArgsParser, print_total_cfg -from ppdet.utils.check import check_gpu, check_version -import ppdet.utils.checkpoint as checkpoint -from ppdet.modeling.model_input import create_feed - -import logging -FORMAT = '%(asctime)s-%(levelname)s: %(message)s' -logging.basicConfig(level=logging.INFO, format=FORMAT) -logger = logging.getLogger(__name__) - - -def eval_run(exe, compile_program, reader, keys, values, cls, test_feed, cfg): - """ - Run evaluation program, return program outputs. 
- """ - iter_id = 0 - results = [] - if len(cls) != 0: - values = [] - for i in range(len(cls)): - _, accum_map = cls[i].get_map_var() - cls[i].reset(exe) - values.append(accum_map) - - images_num = 0 - start_time = time.time() - has_bbox = 'bbox' in keys - for data in reader(): - data = test_feed.feed(data) - feed_data = {'image': data['image'], 'im_size': data['im_size']} - outs = exe.run(compile_program, - feed=feed_data, - fetch_list=[values[0]], - return_numpy=False) - if cfg.metric == 'VOC': - outs.append(data['gt_box']) - outs.append(data['gt_label']) - outs.append(data['is_difficult']) - elif cfg.metric == 'COCO': - outs.append(data['im_id']) - res = { - k: (np.array(v), v.recursive_sequence_lengths()) - for k, v in zip(keys, outs) - } - results.append(res) - if iter_id % 100 == 0: - logger.info('Test iter {}'.format(iter_id)) - iter_id += 1 - images_num += len(res['bbox'][1][0]) if has_bbox else 1 - logger.info('Test finish iter {}'.format(iter_id)) - - end_time = time.time() - fps = images_num / (end_time - start_time) - if has_bbox: - logger.info('Total number of images: {}, inference time: {} fps.'. - format(images_num, fps)) - else: - logger.info('Total iteration: {}, inference time: {} batch/s.'.format( - images_num, fps)) - - return results - - -def main(): - cfg = load_config(FLAGS.config) - if 'architecture' in cfg: - main_arch = cfg.architecture - else: - raise ValueError("'architecture' not specified in config file.") - - merge_config(FLAGS.opt) - if 'log_iter' not in cfg: - cfg.log_iter = 20 - - # check if set use_gpu=True in paddlepaddle cpu version - check_gpu(cfg.use_gpu) - # print_total_cfg(cfg) - #check_version() - if cfg.use_gpu: - devices_num = fluid.core.get_cuda_device_count() - else: - devices_num = int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - - if 'train_feed' not in cfg: - train_feed = create(main_arch + 'TrainFeed') - else: - train_feed = create(cfg.train_feed) - - if 'eval_feed' not in cfg: - eval_feed = create(main_arch + 'EvalFeed') - else: - eval_feed = create(cfg.eval_feed) - - place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() - exe = fluid.Executor(place) - - lr_builder = create('LearningRate') - optim_builder = create('OptimizerBuilder') - - # build program - startup_prog = fluid.Program() - train_prog = fluid.Program() - with fluid.program_guard(train_prog, startup_prog): - with fluid.unique_name.guard(): - model = create(main_arch) - _, feed_vars = create_feed(train_feed, True) - train_fetches = model.train(feed_vars) - loss = train_fetches['loss'] - lr = lr_builder() - optimizer = optim_builder(lr) - optimizer.minimize(loss) - - train_reader = create_reader(train_feed, cfg.max_iters, FLAGS.dataset_dir) - - # parse train fetches - train_keys, train_values, _ = parse_fetches(train_fetches) - train_values.append(lr) - - train_fetch_list = [] - for k, v in zip(train_keys, train_values): - train_fetch_list.append((k, v)) - print("train_fetch_list: {}".format(train_fetch_list)) - - eval_prog = fluid.Program() - with fluid.program_guard(eval_prog, startup_prog): - with fluid.unique_name.guard(): - model = create(main_arch) - _, test_feed_vars = create_feed(eval_feed, True) - fetches = model.eval(test_feed_vars) - eval_prog = eval_prog.clone(True) - - eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) - #eval_pyreader.decorate_sample_list_generator(eval_reader, place) - test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place) - - # parse eval fetches - extra_keys = [] - if cfg.metric == 'COCO': - 
extra_keys = ['im_info', 'im_id', 'im_shape'] - if cfg.metric == 'VOC': - extra_keys = ['gt_box', 'gt_label', 'is_difficult'] - eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, - extra_keys) - # print(eval_values) - - eval_fetch_list = [] - for k, v in zip(eval_keys, eval_values): - eval_fetch_list.append((k, v)) - - exe.run(startup_prog) - - start_iter = 0 - - checkpoint.load_params(exe, train_prog, cfg.pretrain_weights) - - best_box_ap_list = [] - - def eval_func(program, scope): - - #place = fluid.CPUPlace() - #exe = fluid.Executor(place) - results = eval_run(exe, program, eval_reader, eval_keys, eval_values, - eval_cls, test_data_feed, cfg) - - resolution = None - if 'mask' in results[0]: - resolution = model.mask_head.resolution - box_ap_stats = eval_results(results, eval_feed, cfg.metric, - cfg.num_classes, resolution, False, - FLAGS.output_eval) - if len(best_box_ap_list) == 0: - best_box_ap_list.append(box_ap_stats[0]) - elif box_ap_stats[0] > best_box_ap_list[0]: - best_box_ap_list[0] = box_ap_stats[0] - logger.info("Best test box ap: {}".format(best_box_ap_list[0])) - return best_box_ap_list[0] - - test_feed = [('image', test_feed_vars['image'].name), - ('im_size', test_feed_vars['im_size'].name)] - - com = Compressor( - place, - fluid.global_scope(), - train_prog, - train_reader=train_reader, - train_feed_list=[(key, value.name) for key, value in feed_vars.items()], - train_fetch_list=train_fetch_list, - eval_program=eval_prog, - eval_reader=eval_reader, - eval_feed_list=test_feed, - eval_func={'map': eval_func}, - eval_fetch_list=[eval_fetch_list[0]], - prune_infer_model=[["image", "im_size"], ["multiclass_nms_0.tmp_0"]], - train_optimizer=None) - com.config(FLAGS.slim_file) - com.run() - - -if __name__ == '__main__': - parser = ArgsParser() - parser.add_argument( - "-s", - "--slim_file", - default=None, - type=str, - help="Config file of PaddleSlim.") - parser.add_argument( - "--output_eval", - default=None, - type=str, - help="Evaluation directory, default is current directory.") - parser.add_argument( - "-d", - "--dataset_dir", - default=None, - type=str, - help="Dataset path, same as DataFeed.dataset.dataset_dir") - FLAGS = parser.parse_args() - main() diff --git a/slim/quantization/eval.py b/slim/quantization/eval.py new file mode 100644 index 000000000..8812a8f4a --- /dev/null +++ b/slim/quantization/eval.py @@ -0,0 +1,177 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
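
# Evaluation entry for quantization-aware trained models: rebuild the eval
# program with the same quantization config as training, insert fake
# quant/dequant ops via quant_aware, load the trained checkpoint, then
# convert the program to its final quantized form before running evaluation.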

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import paddle.fluid as fluid

from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results, json_eval_results
import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.check import check_gpu, check_version

from ppdet.data.reader import create_reader

from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.cli import ArgsParser

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

from paddleslim.quant import quant_aware, convert


def main():
    """
    Main evaluate function
    """
    cfg = load_config(FLAGS.config)
    if 'architecture' in cfg:
        main_arch = cfg.architecture
    else:
        raise ValueError("'architecture' not specified in config file.")

    merge_config(FLAGS.opt)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    # define executor
    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # build program
    model = create(main_arch)
    startup_prog = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup_prog):
        with fluid.unique_name.guard():
            inputs_def = cfg['EvalReader']['inputs_def']
            test_feed_vars, loader = model.build_inputs(**inputs_def)
            test_fetches = model.eval(test_feed_vars)
    eval_prog = eval_prog.clone(True)

    reader = create_reader(cfg.EvalReader)
    loader.set_sample_list_generator(reader, place)

    # dataset must be resolved before the json_eval branch below uses it
    dataset = cfg['EvalReader']['dataset']

    # eval already exists json file
    if FLAGS.json_eval:
        logger.info(
            "In json_eval mode, PaddleDetection will evaluate json files in "
            "output_eval directly. And proposal.json, bbox.json and mask.json "
            "will be detected by default.")
        json_eval_results(
            cfg.metric, json_directory=FLAGS.output_eval, dataset=dataset)
        return

    assert cfg.metric != 'OID', "eval process of OID dataset \
                  is not supported."
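
    # metric sanity checks: this quantized evaluation script supports COCO
    # and VOC only; WIDERFACE is redirected to the dedicated face eval tool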
+ + if cfg.metric == "WIDERFACE": + raise ValueError("metric type {} does not support in tools/eval.py, " + "please use tools/face_eval.py".format(cfg.metric)) + assert cfg.metric in ['COCO', 'VOC'], \ + "unknown metric type {}".format(cfg.metric) + extra_keys = [] + + if cfg.metric == 'COCO': + extra_keys = ['im_info', 'im_id', 'im_shape'] + if cfg.metric == 'VOC': + extra_keys = ['gt_bbox', 'gt_class', 'is_difficult'] + + keys, values, cls = parse_fetches(test_fetches, eval_prog, extra_keys) + + # whether output bbox is normalized in model output layer + is_bbox_normalized = False + if hasattr(model, 'is_bbox_normalized') and \ + callable(model.is_bbox_normalized): + is_bbox_normalized = model.is_bbox_normalized() + + dataset = cfg['EvalReader']['dataset'] + + sub_eval_prog = None + sub_keys = None + sub_values = None + + not_quant_pattern = [] + if FLAGS.not_quant_pattern: + not_quant_pattern = FLAGS.not_quant_pattern + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'not_quant_pattern': not_quant_pattern + } + + eval_prog = quant_aware(eval_prog, place, config, for_test=True) + + # load model + exe.run(startup_prog) + if 'weights' in cfg: + checkpoint.load_params(exe, eval_prog, cfg.weights) + eval_prog = convert(eval_prog, place, config, save_int8=False) + + compile_program = fluid.compiler.CompiledProgram( + eval_prog).with_data_parallel() + + results = eval_run(exe, compile_program, loader, keys, values, cls, cfg, + sub_eval_prog, sub_keys, sub_values) + + # evaluation + resolution = None + if 'mask' in results[0]: + resolution = model.mask_head.resolution + # if map_type not set, use default 11point, only use in VOC eval + map_type = cfg.map_type if 'map_type' in cfg else '11point' + eval_results( + results, + cfg.metric, + cfg.num_classes, + resolution, + is_bbox_normalized, + FLAGS.output_eval, + map_type, + dataset=dataset) + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "--json_eval", + action='store_true', + default=False, + help="Whether to re eval with already exists bbox.json or mask.json") + parser.add_argument( + "-f", + "--output_eval", + default=None, + type=str, + help="Evaluation file directory, default is current directory.") + parser.add_argument( + "--not_quant_pattern", + nargs='+', + type=str, + help="Layers which name_scope contains string in not_quant_pattern will not be quantized" + ) + + FLAGS = parser.parse_args() + main() diff --git a/slim/quantization/export_model.py b/slim/quantization/export_model.py new file mode 100644 index 000000000..15a30407f --- /dev/null +++ b/slim/quantization/export_model.py @@ -0,0 +1,120 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
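
# Export entry for quantization-aware trained models: after quant_aware and
# loading the trained weights, convert(save_int8=True) yields both a float32
# program and an int8-weight program, saved under ${output_dir}/float and
# ${output_dir}/int respectively.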
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +from paddle import fluid + +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.modeling.model_input import create_feed +from ppdet.utils.cli import ArgsParser +import ppdet.utils.checkpoint as checkpoint +from tools.export_model import prune_feed_vars + +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) +from paddleslim.quant import quant_aware, convert + + +def save_infer_model(save_dir, exe, feed_vars, test_fetches, infer_prog): + feed_var_names = [var.name for var in feed_vars.values()] + target_vars = list(test_fetches.values()) + feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog) + logger.info("Export inference model to {}, input: {}, output: " + "{}...".format(save_dir, feed_var_names, + [str(var.name) for var in target_vars])) + fluid.io.save_inference_model( + save_dir, + feeded_var_names=feed_var_names, + target_vars=target_vars, + executor=exe, + main_program=infer_prog, + params_filename="__params__") + + +def main(): + cfg = load_config(FLAGS.config) + + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + + # Use CPU for exporting inference model instead of GPU + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + model = create(main_arch) + + startup_prog = fluid.Program() + infer_prog = fluid.Program() + with fluid.program_guard(infer_prog, startup_prog): + with fluid.unique_name.guard(): + inputs_def = cfg['TestReader']['inputs_def'] + inputs_def['use_dataloader'] = False + feed_vars, _ = model.build_inputs(**inputs_def) + test_fetches = model.test(feed_vars) + infer_prog = infer_prog.clone(True) + + not_quant_pattern = [] + if FLAGS.not_quant_pattern: + not_quant_pattern = FLAGS.not_quant_pattern + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'not_quant_pattern': not_quant_pattern + } + + infer_prog = quant_aware(infer_prog, place, config, for_test=True) + + exe.run(startup_prog) + checkpoint.load_params(exe, infer_prog, cfg.weights) + + infer_prog, int8_program = convert( + infer_prog, place, config, save_int8=True) + + save_infer_model( + os.path.join(FLAGS.output_dir, 'float'), exe, feed_vars, test_fetches, + infer_prog) + + save_infer_model( + os.path.join(FLAGS.output_dir, 'int'), exe, feed_vars, test_fetches, + int8_program) + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory for storing the output model files.") + parser.add_argument( + "--not_quant_pattern", + nargs='+', + type=str, + help="Layers which name_scope contains string in not_quant_pattern will not be quantized" + ) + + FLAGS = parser.parse_args() + main() diff --git a/slim/quantization/infer.py b/slim/quantization/infer.py new file mode 100644 index 000000000..ed675e2f1 --- /dev/null +++ b/slim/quantization/infer.py @@ -0,0 +1,201 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import glob +import sys + +import numpy as np +from PIL import Image + +from paddle import fluid + +from ppdet.core.workspace import load_config, merge_config, create + +from ppdet.utils.eval_utils import parse_fetches +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu, check_version +from ppdet.utils.visualizer import visualize_results +import ppdet.utils.checkpoint as checkpoint + +from ppdet.data.reader import create_reader +from tools.infer import get_test_images, get_save_image_name +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) +from paddleslim.quant import quant_aware, convert + + +def main(): + cfg = load_config(FLAGS.config) + + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + # check if paddlepaddle version is satisfied + check_version() + + dataset = cfg.TestReader['dataset'] + + test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) + dataset.set_images(test_images) + + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + model = create(main_arch) + + startup_prog = fluid.Program() + infer_prog = fluid.Program() + with fluid.program_guard(infer_prog, startup_prog): + with fluid.unique_name.guard(): + inputs_def = cfg['TestReader']['inputs_def'] + feed_vars, loader = model.build_inputs(**inputs_def) + test_fetches = model.test(feed_vars) + infer_prog = infer_prog.clone(True) + + reader = create_reader(cfg.TestReader) + loader.set_sample_list_generator(reader, place) + not_quant_pattern = [] + if FLAGS.not_quant_pattern: + not_quant_pattern = FLAGS.not_quant_pattern + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'not_quant_pattern': not_quant_pattern + } + + infer_prog = quant_aware(infer_prog, place, config, for_test=True) + + exe.run(startup_prog) + + if cfg.weights: + checkpoint.load_params(exe, infer_prog, cfg.weights) + infer_prog = convert(infer_prog, place, config, save_int8=False) + + # parse infer fetches + assert cfg.metric in ['COCO', 'VOC', 'OID', 'WIDERFACE'], \ + "unknown metric type {}".format(cfg.metric) + extra_keys = [] + if cfg['metric'] in ['COCO', 'OID']: + extra_keys = ['im_info', 'im_id', 'im_shape'] + if cfg['metric'] == 'VOC' or cfg['metric'] == 'WIDERFACE': + extra_keys = ['im_id', 'im_shape'] + keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys) + + # parse dataset category + if cfg.metric == 'COCO': + from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info + if cfg.metric == 'OID': + from ppdet.utils.oid_eval import bbox2out, get_category_info + if cfg.metric == 
"VOC": + from ppdet.utils.voc_eval import bbox2out, get_category_info + if cfg.metric == "WIDERFACE": + from ppdet.utils.widerface_eval_utils import bbox2out, get_category_info + + anno_file = dataset.get_anno() + with_background = dataset.with_background + use_default_label = dataset.use_default_label + + clsid2catid, catid2name = get_category_info(anno_file, with_background, + use_default_label) + + # whether output bbox is normalized in model output layer + is_bbox_normalized = False + if hasattr(model, 'is_bbox_normalized') and \ + callable(model.is_bbox_normalized): + is_bbox_normalized = model.is_bbox_normalized() + + imid2path = dataset.get_imid2path() + iter_id = 0 + try: + loader.start() + while True: + outs = exe.run(infer_prog, fetch_list=values, return_numpy=False) + res = { + k: (np.array(v), v.recursive_sequence_lengths()) + for k, v in zip(keys, outs) + } + logger.info('Infer iter {}'.format(iter_id)) + iter_id += 1 + bbox_results = None + mask_results = None + if 'bbox' in res: + bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized) + if 'mask' in res: + mask_results = mask2out([res], clsid2catid, + model.mask_head.resolution) + + # visualize result + im_ids = res['im_id'][0] + for im_id in im_ids: + image_path = imid2path[int(im_id)] + image = Image.open(image_path).convert('RGB') + + image = visualize_results(image, + int(im_id), catid2name, + FLAGS.draw_threshold, bbox_results, + mask_results) + + save_name = get_save_image_name(FLAGS.output_dir, image_path) + logger.info("Detection bbox results save in {}".format( + save_name)) + image.save(save_name, quality=95) + except (StopIteration, fluid.core.EOFException): + loader.reset() + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "--infer_dir", + type=str, + default=None, + help="Directory for images to perform inference on.") + parser.add_argument( + "--infer_img", + type=str, + default=None, + help="Image path, has higher priority over --infer_dir") + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory for storing the output visualization files.") + parser.add_argument( + "--draw_threshold", + type=float, + default=0.5, + help="Threshold to reserve the result for visualization.") + parser.add_argument( + "--not_quant_pattern", + nargs='+', + type=str, + help="Layers which name_scope contains string in not_quant_pattern will not be quantized" + ) + + FLAGS = parser.parse_args() + main() diff --git a/slim/quantization/train.py b/slim/quantization/train.py new file mode 100644 index 000000000..caab040d4 --- /dev/null +++ b/slim/quantization/train.py @@ -0,0 +1,291 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import time +import numpy as np +import datetime +from collections import deque + +from paddle import fluid + +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.data.reader import create_reader + +from ppdet.utils.cli import print_total_cfg +from ppdet.utils import dist_utils +from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results +from ppdet.utils.stats import TrainingStats +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu, check_version +import ppdet.utils.checkpoint as checkpoint +from paddleslim.quant import quant_aware, convert +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def main(): + env = os.environ + FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env + if FLAGS.dist: + trainer_id = int(env['PADDLE_TRAINER_ID']) + import random + local_seed = (99 + trainer_id) + random.seed(local_seed) + np.random.seed(local_seed) + + cfg = load_config(FLAGS.config) + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + + if 'log_iter' not in cfg: + cfg.log_iter = 20 + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + # check if paddlepaddle version is satisfied + check_version() + if not FLAGS.dist or trainer_id == 0: + print_total_cfg(cfg) + + if cfg.use_gpu: + devices_num = fluid.core.get_cuda_device_count() + else: + devices_num = int(os.environ.get('CPU_NUM', 1)) + + if 'FLAGS_selected_gpus' in env: + device_id = int(env['FLAGS_selected_gpus']) + else: + device_id = 0 + place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + lr_builder = create('LearningRate') + optim_builder = create('OptimizerBuilder') + + # build program + startup_prog = fluid.Program() + train_prog = fluid.Program() + with fluid.program_guard(train_prog, startup_prog): + with fluid.unique_name.guard(): + model = create(main_arch) + + inputs_def = cfg['TrainReader']['inputs_def'] + feed_vars, train_loader = model.build_inputs(**inputs_def) + train_fetches = model.train(feed_vars) + loss = train_fetches['loss'] + lr = lr_builder() + optimizer = optim_builder(lr) + optimizer.minimize(loss) + + # parse train fetches + train_keys, train_values, _ = parse_fetches(train_fetches) + train_values.append(lr) + + if FLAGS.eval: + eval_prog = fluid.Program() + with fluid.program_guard(eval_prog, startup_prog): + with fluid.unique_name.guard(): + model = create(main_arch) + inputs_def = cfg['EvalReader']['inputs_def'] + feed_vars, eval_loader = model.build_inputs(**inputs_def) + fetches = model.eval(feed_vars) + eval_prog = eval_prog.clone(True) + + eval_reader = create_reader(cfg.EvalReader) + eval_loader.set_sample_list_generator(eval_reader, place) + + # parse eval fetches + extra_keys = [] + if cfg.metric == 'COCO': + extra_keys = ['im_info', 'im_id', 'im_shape'] + if cfg.metric == 'VOC': + extra_keys = ['gt_bbox', 'gt_class', 'is_difficult'] + if cfg.metric == 'WIDERFACE': + extra_keys = ['im_id', 'im_shape', 'gt_bbox'] + eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, + extra_keys) + + # compile program for multi-devices + build_strategy = fluid.BuildStrategy() + 
    build_strategy.fuse_all_optimizer_ops = False
    build_strategy.fuse_elewise_add_act_ops = True
    # Quantization rewrites the program, so strategies that would also modify
    # it are disabled; sync_batch_norm breaks quantized multi-card training
    # for reasons not yet understood (see README), so it is forced off too.
    build_strategy.fuse_all_reduce_ops = False
    build_strategy.sync_batch_norm = False

    exec_strategy = fluid.ExecutionStrategy()
    # iteration number when CompiledProgram tries to drop local execution scopes.
    # Set it to be 1 to save memory usages, so that unused variables in
    # local execution scopes can be deleted after each iteration.
    exec_strategy.num_iteration_per_drop_scope = 1
    if FLAGS.dist:
        dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
                                             train_prog)
        exec_strategy.num_threads = 1

    exe.run(startup_prog)

    not_quant_pattern = []
    if FLAGS.not_quant_pattern:
        not_quant_pattern = FLAGS.not_quant_pattern
    config = {
        'weight_quantize_type': 'channel_wise_abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
        'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
        'not_quant_pattern': not_quant_pattern
    }

    ignore_params = cfg.finetune_exclude_pretrained_params \
        if 'finetune_exclude_pretrained_params' in cfg else []

    fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'

    # start from iter 0 unless a resumed checkpoint provides a global step
    start_iter = 0
    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()
    elif cfg.pretrain_weights and fuse_bn and not ignore_params:
        checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
    elif cfg.pretrain_weights:
        checkpoint.load_params(
            exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)
    # insert quantize op in train_prog, return type is CompiledProgram
    train_prog = quant_aware(train_prog, place, config, for_test=False)

    compiled_train_prog = train_prog.with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    if FLAGS.eval:
        # insert quantize op in eval_prog
        # NOTE: the snapshots saved below write eval_prog, so --eval must be
        # enabled during training
        eval_prog = quant_aware(eval_prog, place, config, for_test=True)

        compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)

    train_reader = create_reader(cfg.TrainReader,
                                 (cfg.max_iters - start_iter) * devices_num)
    train_loader.set_sample_list_generator(train_reader, place)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # if map_type not set, use default 11point, only use in VOC eval
    map_type = cfg.map_type if 'map_type' in cfg else '11point'

    train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
    train_loader.start()
    start_time = time.time()
    end_time = time.time()

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    time_stat = deque(maxlen=cfg.log_smooth_window)
    best_box_ap_list = [0.0, 0]  #[map, iter]

    for it in range(start_iter, cfg.max_iters):
        start_time = end_time
        end_time = time.time()
        time_stat.append(end_time - start_time)
        time_cost = np.mean(time_stat)
        eta_sec = (cfg.max_iters - it) * time_cost
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        outs = exe.run(compiled_train_prog, fetch_list=train_values)
        stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}

        train_stats.update(stats)
        logs = 
train_stats.log() + if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0): + strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format( + it, np.mean(outs[-1]), logs, time_cost, eta) + logger.info(strs) + + if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \ + and (not FLAGS.dist or trainer_id == 0): + save_name = str(it) if it != cfg.max_iters - 1 else "model_final" + checkpoint.save(exe, eval_prog, os.path.join(save_dir, save_name)) + + if FLAGS.eval: + # evaluation + results = eval_run(exe, compiled_eval_prog, eval_loader, + eval_keys, eval_values, eval_cls) + resolution = None + if 'mask' in results[0]: + resolution = model.mask_head.resolution + box_ap_stats = eval_results( + results, cfg.metric, cfg.num_classes, resolution, + is_bbox_normalized, FLAGS.output_eval, map_type, + cfg['EvalReader']['dataset']) + + if box_ap_stats[0] > best_box_ap_list[0]: + best_box_ap_list[0] = box_ap_stats[0] + best_box_ap_list[1] = it + checkpoint.save(exe, eval_prog, + os.path.join(save_dir, "best_model")) + logger.info("Best test box ap: {}, in iter: {}".format( + best_box_ap_list[0], best_box_ap_list[1])) + + train_loader.reset() + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "-r", + "--resume_checkpoint", + default=None, + type=str, + help="Checkpoint path for resuming training.") + parser.add_argument( + "--loss_scale", + default=8., + type=float, + help="Mixed precision training loss scale.") + parser.add_argument( + "--eval", + action='store_true', + default=False, + help="Whether to perform evaluation in train") + parser.add_argument( + "--output_eval", + default=None, + type=str, + help="Evaluation directory, default is current directory.") + parser.add_argument( + "--not_quant_pattern", + nargs='+', + type=str, + help="Layers which name_scope contains string in not_quant_pattern will not be quantized" + ) + FLAGS = parser.parse_args() + main() diff --git a/slim/quantization/yolov3_mobilenet_v1_slim.yaml b/slim/quantization/yolov3_mobilenet_v1_slim.yaml deleted file mode 100644 index 60a66f656..000000000 --- a/slim/quantization/yolov3_mobilenet_v1_slim.yaml +++ /dev/null @@ -1,20 +0,0 @@ -version: 1.0 -strategies: - quantization_strategy: - class: 'QuantizationStrategy' - start_epoch: 0 - end_epoch: 4 - float_model_save_path: './output/yolov3/float' - mobile_model_save_path: './output/yolov3/mobile' - int8_model_save_path: './output/yolov3/int8' - weight_bits: 8 - activation_bits: 8 - weight_quantize_type: 'abs_max' - activation_quantize_type: 'moving_average_abs_max' - save_in_nodes: ['image', 'im_size'] - save_out_nodes: ['multiclass_nms_0.tmp_0'] -compressor: - epoch: 5 - checkpoint_path: './checkpoints/yolov3/' - strategies: - - quantization_strategy -- GitLab