未验证 提交 fb516ca4 编写于 作者: G Guanghua Yu 提交者: GitHub

delete pact in eval,infer,export_model (#1584)

上级 625c9863
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
- [检测模型的常规训练方法](https://github.com/PaddlePaddle/PaddleDetection) - [检测模型的常规训练方法](https://github.com/PaddlePaddle/PaddleDetection)
- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/) - [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)
- [自定义量化PACT](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/pact_quant_aware)
已发布量化模型见[压缩模型库](../README.md) 已发布量化模型见[压缩模型库](../README.md)
...@@ -76,11 +77,24 @@ python slim/quantization/train.py --not_quant_pattern yolo_output \ ...@@ -76,11 +77,24 @@ python slim/quantization/train.py --not_quant_pattern yolo_output \
- **LeaningRate.base_lr:** 根据多卡的总`batch_size`调整`base_lr`,两者大小正相关,可以简单的按比例进行调整。 - **LeaningRate.base_lr:** 根据多卡的总`batch_size`调整`base_lr`,两者大小正相关,可以简单的按比例进行调整。
- **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其调整。 - **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其调整。
通过`python slim/quantization/train.py --help`查看可配置参数。 通过`python slim/quantization/train.py --help`查看可配置参数。
通过`python ./tools/configure.py help ${option_name}`查看如何通过命令行覆盖配置文件中的参数。 通过`python ./tools/configure.py help ${option_name}`查看如何通过命令行覆盖配置文件中的参数。
### PACT自定义量化
```
python slim/quantization/train.py \
--eval \
-c ./configs/yolov3_mobilenet_v3.yml \
-o max_iters=30000 \
save_dir=./output/mobilenetv3 \
LearningRate.base_lr=0.0001 \
LearningRate.schedulers="[!PiecewiseDecay {gamma: 0.1, milestones: [10000]}]" \
pretrain_weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams \
--use_pact=True
```
- 在量化训练时,将`--use_pact=True`,即可选择PACT自定义量化
### 训练时的模型结构 ### 训练时的模型结构
[PaddleSlim 量化API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)文档中介绍了``paddleslim.quant.quant_aware````paddleslim.quant.convert``两个接口。 [PaddleSlim 量化API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)文档中介绍了``paddleslim.quant.quant_aware````paddleslim.quant.convert``两个接口。
...@@ -144,6 +158,7 @@ python slim/quantization/eval.py --not_quant_pattern yolo_output -c ./configs/y ...@@ -144,6 +158,7 @@ python slim/quantization/eval.py --not_quant_pattern yolo_output -c ./configs/y
python slim/quantization/export_model.py --not_quant_pattern yolo_output -c ./configs/yolov3_mobilenet_v1.yml --output_dir ${save path} \ python slim/quantization/export_model.py --not_quant_pattern yolo_output -c ./configs/yolov3_mobilenet_v1.yml --output_dir ${save path} \
-o weights=./output/mobilenetv1/yolov3_mobilenet_v1/best_model -o weights=./output/mobilenetv1/yolov3_mobilenet_v1/best_model
``` ```
## 预测 ## 预测
### python预测 ### python预测
...@@ -158,7 +173,6 @@ python slim/quantization/infer.py --not_quant_pattern yolo_output \ ...@@ -158,7 +173,6 @@ python slim/quantization/infer.py --not_quant_pattern yolo_output \
-o weights=./output/mobilenetv1/yolov3_mobilenet_v1/best_model -o weights=./output/mobilenetv1/yolov3_mobilenet_v1/best_model
``` ```
### PaddleLite预测 ### PaddleLite预测
导出模型步骤中导出的FP32模型可使用PaddleLite进行加载预测,可参见教程[Paddle-Lite如何加载运行量化模型](https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization) 导出模型步骤中导出的FP32模型可使用PaddleLite进行加载预测,可参见教程[Paddle-Lite如何加载运行量化模型](https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
...@@ -39,6 +39,7 @@ from ppdet.utils.cli import ArgsParser ...@@ -39,6 +39,7 @@ from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.check import check_gpu, check_version, check_config
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
from paddleslim.quant import quant_aware, convert from paddleslim.quant import quant_aware, convert
from pact import pact, get_optimizer
import logging import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s' FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT) logging.basicConfig(level=logging.INFO, format=FORMAT)
...@@ -51,14 +52,6 @@ def save_checkpoint(exe, prog, path, train_prog): ...@@ -51,14 +52,6 @@ def save_checkpoint(exe, prog, path, train_prog):
logger.info('Save model to {}.'.format(path)) logger.info('Save model to {}.'.format(path))
fluid.io.save_persistables(exe, path, main_program=prog) fluid.io.save_persistables(exe, path, main_program=prog)
v = train_prog.global_block().var('@LR_DECAY_COUNTER@')
fluid.io.save_vars(exe, dirname=path, vars=[v])
def load_global_step(exe, prog, path):
v = prog.global_block().var('@LR_DECAY_COUNTER@')
fluid.io.load_vars(exe, path, prog, [v])
def main(): def main():
if FLAGS.eval is False: if FLAGS.eval is False:
...@@ -105,9 +98,10 @@ def main(): ...@@ -105,9 +98,10 @@ def main():
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = create(main_arch) model = create(main_arch)
inputs_def = cfg['TrainReader']['inputs_def'] inputs_def = cfg['TrainReader']['inputs_def']
feed_vars, train_loader = model.build_inputs(**inputs_def) feed_vars, train_loader = model.build_inputs(**inputs_def)
if FLAGS.use_pact:
feed_vars['image'].stop_gradient = False
train_fetches = model.train(feed_vars) train_fetches = model.train(feed_vars)
loss = train_fetches['loss'] loss = train_fetches['loss']
lr = lr_builder() lr = lr_builder()
...@@ -181,17 +175,30 @@ def main(): ...@@ -181,17 +175,30 @@ def main():
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel' fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
if not FLAGS.resume_checkpoint: if cfg.pretrain_weights and fuse_bn and not ignore_params:
if cfg.pretrain_weights and fuse_bn and not ignore_params: checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights) elif cfg.pretrain_weights:
elif cfg.pretrain_weights: checkpoint.load_params(
checkpoint.load_params( exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)
exe,
train_prog, if FLAGS.use_pact:
cfg.pretrain_weights, act_preprocess_func = pact
ignore_params=ignore_params) optimizer_func = get_optimizer
executor = exe
else:
act_preprocess_func = None
optimizer_func = None
executor = None
# insert quantize op in train_prog, return type is CompiledProgram # insert quantize op in train_prog, return type is CompiledProgram
train_prog_quant = quant_aware(train_prog, place, config, for_test=False) train_prog_quant = quant_aware(
train_prog,
place,
config,
scope=None,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
for_test=False)
compiled_train_prog = train_prog_quant.with_data_parallel( compiled_train_prog = train_prog_quant.with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
...@@ -200,14 +207,18 @@ def main(): ...@@ -200,14 +207,18 @@ def main():
if FLAGS.eval: if FLAGS.eval:
# insert quantize op in eval_prog # insert quantize op in eval_prog
eval_prog = quant_aware(eval_prog, place, config, for_test=True) eval_prog = quant_aware(
eval_prog,
place,
config,
scope=None,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
for_test=True)
compiled_eval_prog = fluid.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
start_iter = 0 start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, eval_prog, FLAGS.resume_checkpoint)
load_global_step(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
train_reader = create_reader(cfg.TrainReader, train_reader = create_reader(cfg.TrainReader,
(cfg.max_iters - start_iter) * devices_num) (cfg.max_iters - start_iter) * devices_num)
...@@ -253,8 +264,6 @@ def main(): ...@@ -253,8 +264,6 @@ def main():
if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \ if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
and (not FLAGS.dist or trainer_id == 0): and (not FLAGS.dist or trainer_id == 0):
save_name = str(it) if it != cfg.max_iters - 1 else "model_final" save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
save_checkpoint(exe, eval_prog,
os.path.join(save_dir, save_name), train_prog)
if FLAGS.eval: if FLAGS.eval:
# evaluation # evaluation
...@@ -288,12 +297,6 @@ def main(): ...@@ -288,12 +297,6 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser() parser = ArgsParser()
parser.add_argument(
"-r",
"--resume_checkpoint",
default=None,
type=str,
help="Checkpoint path for resuming training.")
parser.add_argument( parser.add_argument(
"--loss_scale", "--loss_scale",
default=8., default=8.,
...@@ -315,5 +318,7 @@ if __name__ == '__main__': ...@@ -315,5 +318,7 @@ if __name__ == '__main__':
type=str, type=str,
help="Layers which name_scope contains string in not_quant_pattern will not be quantized" help="Layers which name_scope contains string in not_quant_pattern will not be quantized"
) )
parser.add_argument(
"--use_pact", nargs='+', type=bool, help="Whether to use PACT or not.")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册