Commit 68c1b089, authored by D dongshuilong

update slim for new trainer

Parent 5d1ab55a
@@ -66,12 +66,11 @@ cd PaddleClas
The following takes the CPU as an example; to run on a GPU, change `cpu` to `gpu` in the command.
```bash
-python3.7 deploy/slim/slim.py -m train -c ppcls/configs/slim/ResNet50_vd_quantization.yaml -o Global.device=cpu
+python3.7 tools/train.py -c ppcls/configs/slim/ResNet50_vd_quantization.yaml -o Global.device=cpu
```
The `yaml` file is explained in detail in the [reference documentation](../../docs/zh_CN/tutorials/config_description.md). To preserve accuracy, the `yaml` file already specifies a `pretrained model`.
-`-m`: the mode supported by `slim.py`; one of the four modes `train`, `eval`, `infer`, `export`, meaning training, evaluation, dygraph inference, and exporting an `inference model`, respectively.
* Launching with multiple GPUs on one machine / multiple machines
@@ -79,8 +78,7 @@ python3.7 deploy/slim/slim.py -m train -c ppcls/configs/slim/ResNet50_vd_quantiz
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
-    deploy/slim/slim.py \
-    -m train \
+    tools/train.py \
-c ppcls/configs/slim/ResNet50_vd_quantization.yaml
```
@@ -109,7 +107,7 @@ python3.7 deploy/slim/quant_post_static.py -c ppcls/configs/ImageNet/ResNet/ResN
The following takes the CPU as an example; to run on a GPU, change `cpu` to `gpu` in the command.
```bash
-python3.7 deploy/slim/slim.py -m train -c ppcls/configs/slim/ResNet50_vd_prune.yaml -o Global.device=cpu
+python3.7 tools/train.py -c ppcls/configs/slim/ResNet50_vd_prune.yaml -o Global.device=cpu
```
- Launching with a single GPU, multiple GPUs on one machine, or multiple machines
@@ -118,8 +116,7 @@ python3.7 deploy/slim/slim.py -m train -c ppcls/configs/slim/ResNet50_vd_prune.y
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
-    deploy/slim/slim.py \
-    -m train \
+    tools/train.py \
-c ppcls/configs/slim/ResNet50_vd_prune.yaml
```
@@ -128,9 +125,9 @@ python3.7 -m paddle.distributed.launch \
After obtaining a model saved by quantization-aware training or pruning, it can be exported as an inference model for deployment. Taking pruning as an example:
```bash
-python3.7 deploy/slim/slim.py \
-    -m export \
+python3.7 tools/export.py \
-c ppcls/configs/slim/ResNet50_vd_prune.yaml \
+    -o Global.pretrained_model=./output/ResNet50_vd/best_model \
-o Global.save_inference_dir=./inference
```
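Once exported, the inference model can be loaded with the Paddle Inference API. Below is a minimal sketch, assuming the default file names that `paddle.jit.save` writes under `Global.save_inference_dir` (`inference.pdmodel` / `inference.pdiparams`):

```python
import numpy as np
import paddle.inference as paddle_infer

# Assumes the default names written under Global.save_inference_dir
config = paddle_infer.Config("./inference/inference.pdmodel",
                             "./inference/inference.pdiparams")
predictor = paddle_infer.create_predictor(config)

# Feed one dummy 224x224 RGB image and fetch the class scores
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.random.rand(1, 3, 224, 224).astype("float32"))
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
print(output_handle.copy_to_cpu().shape)
```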
@@ -67,12 +67,11 @@ The training command is as follow:
If using a GPU, change `cpu` to `gpu` in the following command.
```bash
-python3.7 deploy/slim/slim.py -m train -c ppcls/configs/slim/ResNet50_vd_quantization.yaml -o Global.device=cpu
+python3.7 tools/train.py -c ppcls/configs/slim/ResNet50_vd_quantization.yaml -o Global.device=cpu
```
The description of the `yaml` file can be found in this [doc](../../docs/en/tutorials/config_en.md). To get better accuracy, a `pretrained model` is used in the `yaml` file.
-`-m`: the mode supported by `slim.py`, one of `train`, `eval`, `infer`, `export`, meaning training the model, evaluating the model, inferring images with the dygraph model, and exporting the inference model for deployment, respectively.
* Distributed training
@@ -80,7 +79,6 @@ The description of `yaml` file can be found in this [doc](../../docs/en/tutoria
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
-    deploy/slim/slim.py \
-    -m train \
+    tools/train.py \
-c ppcls/configs/slim/ResNet50_vd_quantization.yaml
```
@@ -108,7 +107,7 @@ If run successfully, the directory `quant_post_static_model` is generated in `Gl
If using a GPU, change `cpu` to `gpu` in the following command.
```bash
-python3.7 deploy/slim/slim.py -m train -c ppcls/configs/slim/ResNet50_vd_prune.yaml -o Global.device=cpu
+python3.7 tools/train.py -c ppcls/configs/slim/ResNet50_vd_prune.yaml -o Global.device=cpu
```
- Distributed training
@@ -117,8 +116,7 @@ python3.7 deploy/slim/slim.py -m train -c ppcls/configs/slim/ResNet50_vd_prune.y
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
-    deploy/slim/slim.py \
-    -m train \
+    tools/train.py \
-c ppcls/configs/slim/ResNet50_vd_prune.yaml
```
@@ -129,9 +127,9 @@ python3.7 -m paddle.distributed.launch \
After getting the compressed model, we can export it as an inference model for deployment. Using the pruned model as an example:
```bash
-python3.7 deploy/slim/slim.py \
-    -m export \
+python3.7 tools/export.py \
-c ppcls/configs/slim/ResNet50_vd_prune.yaml \
+    -o Global.pretrained_model=./output/ResNet50_vd/best_model \
-o Global.save_inference_dir=./inference
```
@@ -162,7 +162,7 @@ class MobileNetV3(TheseusLayer):
if_act=True,
act="hardswish")
-        self.blocks = nn.Sequential(*[
+        self.blocks = nn.Sequential(* [
ResidualUnit(
in_c=_make_divisible(self.inplanes * self.scale if i == 0 else
self.cfg[i - 1][2] * self.scale),
@@ -333,6 +333,8 @@ class SEModule(TheseusLayer):
stride=1,
padding=0)
self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5)
+        self.conv1.skip_quant = True
+        self.conv2.skip_quant = True
def forward(self, x):
identity = x
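The two `skip_quant` flags added above exclude the SE module's 1x1 convolutions from quantization: PaddleSlim's QAT wrapper leaves a layer un-quantized when its `skip_quant` attribute is true. A minimal sketch of the same pattern on a toy model (assuming, as in the `quant.py` module below, that unspecified config keys fall back to PaddleSlim's defaults):

```python
import paddle.nn as nn
from paddleslim.dygraph.quant import QAT

model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.ReLU(), nn.Conv2D(8, 8, 1))
model[2].skip_quant = True  # keep the 1x1 conv in float precision

quanter = QAT(config={"activation_preprocess_type": "PACT"})
quanter.quantize(model)  # wraps quantizable layers in place, except model[2]
```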
@@ -42,6 +42,7 @@ from ppcls.data import create_operators
from ppcls.engine.train import train_epoch
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
+from ppcls.engine.slim import get_pruner, get_quaner
class Engine(object):
@@ -170,6 +171,8 @@ class Engine(object):
self.model, self.config["Global"]["pretrained_model"])
# for slim
+        self.pruner = get_pruner(self.config, self.model)
+        self.quanter = get_quaner(self.config, self.model)
# build optimizer
if self.mode == 'train':
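Note the ordering here: the pruner and quanter are created before the optimizer is built. Both rewrite the model's layers and parameters in place, so the optimizer has to be constructed over the already-pruned/quantized parameter list.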
@@ -334,18 +337,26 @@ class Engine(object):
self.config["Global"]["pretrained_model"])
model.eval()
-        model = paddle.jit.to_static(
-            model,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None] + self.config["Global"]["image_shape"],
-                    dtype='float32')
-            ])
-        paddle.jit.save(
-            model,
-            os.path.join(self.config["Global"]["save_inference_dir"],
-                         "inference"))
+        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
+                                 "inference")
+        if self.quanter:
+            self.quanter.save_quantized_model(
+                model,
+                save_path,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None] + self.config["Global"]["image_shape"],
+                        dtype='float32')
+                ])
+        else:
+            model = paddle.jit.to_static(
+                model,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None] + self.config["Global"]["image_shape"],
+                        dtype='float32')
+                ])
+            paddle.jit.save(model, save_path)
class ExportModel(nn.Layer):
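The new export path branches on `self.quanter`: a QAT model is saved through `save_quantized_model`, which takes the same `input_spec` but is also responsible for converting the fake-quantization training layers into a deployable quantized inference program, whereas the plain `paddle.jit.to_static` + `paddle.jit.save` path would serialize them as-is.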
# ppcls/engine/slim/__init__.py (new file; path inferred from the import in engine.py above)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ppcls.engine.slim.prune import get_pruner
from ppcls.engine.slim.quant import get_quaner
# ppcls/engine/slim/prune.py (new file; path inferred from the imports above)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import paddle
from ppcls.utils import logger
def get_pruner(config, model):
if config.get("Slim", False) and config["Slim"].get("prune", False):
import paddleslim
prune_method_name = config["Slim"]["prune"]["name"].lower()
assert prune_method_name in [
"fpgm", "l1_norm"
], "The prune methods only support 'fpgm' and 'l1_norm'"
if prune_method_name == "fpgm":
pruner = paddleslim.dygraph.FPGMFilterPruner(
model, [1] + config["Global"]["image_shape"])
else:
pruner = paddleslim.dygraph.L1NormFilterPruner(
model, [1] + config["Global"]["image_shape"])
# prune model
_prune_model(pruner, config, model)
else:
pruner = None
return pruner
def _prune_model(pruner, config, model):
from paddleslim.analysis import dygraph_flops as flops
logger.info("FLOPs before pruning: {}GFLOPs".format(
flops(model, [1] + config["Global"]["image_shape"]) / 1e9))
model.eval()
params = []
for sublayer in model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(sublayer, paddle.nn.Conv2D):
params.append(param.name)
ratios = {}
for param in params:
ratios[param] = config["Slim"]["prune"]["pruned_ratio"]
plan = pruner.prune_vars(ratios, [0])
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
flops(model, [1] + config["Global"]["image_shape"]) / 1e9,
plan.pruned_flops))
for param in model.parameters():
if "conv2d" in param.name:
logger.info("{}\t{}".format(param.name, param.shape))
model.train()
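A hypothetical usage sketch of `get_pruner` outside the `Engine`; the dict keys mirror the yaml schema this module reads (`Global.image_shape`, `Slim.prune.name`, `Slim.prune.pruned_ratio`), and the `paddle.vision` model is just a stand-in for a PaddleClas backbone:

```python
import paddle
from ppcls.engine.slim import get_pruner

config = {
    "Global": {"image_shape": [3, 224, 224]},
    "Slim": {"prune": {"name": "fpgm", "pruned_ratio": 0.3}},
}
model = paddle.vision.models.resnet18()  # stand-in backbone
pruner = get_pruner(config, model)  # prunes conv filters in place, logs FLOPs
```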
# ppcls/engine/slim/quant.py (new file; path inferred from the imports above)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import paddle
from ppcls.utils import logger
QUANT_CONFIG = {
# weight preprocess type, default is None and no preprocessing is performed.
'weight_preprocess_type': None,
# activation preprocess type, default is None and no preprocessing is performed.
'activation_preprocess_type': None,
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
def get_quaner(config, model):
if config.get("Slim", False) and config["Slim"].get("quant", False):
from paddleslim.dygraph.quant import QAT
assert config["Slim"]["quant"]["name"].lower(
) == 'pact', 'Only PACT quantization method is supported now'
QUANT_CONFIG["activation_preprocess_type"] = "PACT"
quanter = QAT(config=QUANT_CONFIG)
quanter.quantize(model)
logger.info("QAT model summary:")
paddle.summary(model, (1, 3, 224, 224))
else:
quanter = None
return quanter
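And a matching sketch for `get_quaner` (the function name as spelled in this commit). Only `pact` passes the assert above; all other knobs come from `QUANT_CONFIG`:

```python
import paddle
from ppcls.engine.slim import get_quaner

config = {"Slim": {"quant": {"name": "pact"}}}
model = paddle.vision.models.resnet18()  # stand-in backbone
quanter = get_quaner(config, model)  # wraps Conv2D/Linear layers for QAT
```

The returned `quanter` is needed again at export time (see the `Engine` diff above) to call `save_quantized_model`.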