diff --git a/pdseg/models/modeling/deeplab.py b/pdseg/models/modeling/deeplab.py
index e7ed9604b2227bb498c2eb0b863804fbe0159333..186e2406d90d291de43133550875072d790a805f 100644
--- a/pdseg/models/modeling/deeplab.py
+++ b/pdseg/models/modeling/deeplab.py
@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv
 from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone
 from models.backbone.xception import Xception as xception_backbone
 
+
 def encoder(input):
     # Encoder configuration using the ASPP architecture: image pooling + 1x1 conv + three parallel dilated convolutions at different rates, concatenated and fused by a 1x1 conv
     # ASPP_WITH_SEP_CONV: defaults to True, meaning depthwise separable convolutions are used; otherwise plain convolutions
@@ -47,8 +48,7 @@ def encoder(input):
     with scope('encoder'):
         channel = 256
         with scope("image_pool"):
-            image_avg = fluid.layers.reduce_mean(
-                input, [2, 3], keep_dim=True)
+            image_avg = fluid.layers.reduce_mean(input, [2, 3], keep_dim=True)
             image_avg = bn_relu(
                 conv(
                     image_avg,
@@ -250,14 +250,15 @@ def deeplabv3p(img, num_classes):
             regularization_coeff=0.0),
         initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
     with scope('logit'):
-        logit = conv(
-            data,
-            num_classes,
-            1,
-            stride=1,
-            padding=0,
-            bias_attr=True,
-            param_attr=param_attr)
+        with fluid.name_scope('last_conv'):  # scope name matched by --not_quant_pattern
+            logit = conv(
+                data,
+                num_classes,
+                1,
+                stride=1,
+                padding=0,
+                bias_attr=True,
+                param_attr=param_attr)
         logit = fluid.layers.resize_bilinear(logit, img.shape[2:])
         return logit
 
diff --git a/slim/quantization/README.md b/slim/quantization/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9af04033b3a9af84d4b1fdf081f156be6f8dc0c2
--- /dev/null
+++ b/slim/quantization/README.md
@@ -0,0 +1,142 @@
+> Before running this example, please install Paddle 1.6 (or later) and PaddleSlim.
+
+# Quantization Compression Example for Segmentation Models
+
+## Overview
+
+This example uses the [quantization API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/) provided by PaddleSlim to compress a segmentation model.
+Before reading this example, we recommend that you first learn about:
+
+- [The standard training workflow for segmentation models](../../docs/usage.md)
+- [The PaddleSlim documentation](https://paddlepaddle.github.io/PaddleSlim/)
+
+
+## Install PaddleSlim
+PaddleSlim can be installed by following the steps in the [PaddleSlim documentation](https://paddlepaddle.github.io/PaddleSlim/).
+
+
+## Training
+
+
+### Dataset
+Please download the dataset and place it in the expected location by following the segmentation library's tutorial.
+
+### Download a trained segmentation model
+
+Run the following commands from the root directory of the segmentation library:
+```bash
+mkdir pretrain
+cd pretrain
+wget https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz
+tar xf mobilenet_cityscapes.tgz
+```
+
+### Define the quantization config
+```
+config = {
+    'weight_quantize_type': 'channel_wise_abs_max',
+    'activation_quantize_type': 'moving_average_abs_max',
+    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+    'not_quant_pattern': ['last_conv']
+}
+```
+
+For the available options and their meanings, please refer to the [PaddleSlim quantization API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/).
+
+### Insert quantization/dequantization ops
+Use the [PaddleSlim quant_aware API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/#quant_aware) to insert quantization and dequantization ops into the Program.
+```
+compiled_train_prog = quant_aware(train_prog, place, config, for_test=False)
+```
+
+### Disable certain training strategies
+
+Because quantization modifies the Program, training strategies that also modify the Program must be disabled. ``sync_batch_norm`` fails when used together with multi-GPU quantization training, so it needs to be turned off as well.
+```
+build_strategy.fuse_all_reduce_ops = False
+build_strategy.sync_batch_norm = False
+```
+
+### Start training
+
+
+step1: Select the GPU card
+```
+export CUDA_VISIBLE_DEVICES=0
+```
+step2: Add the ``pdseg`` folder to the system path
+
+Run the following command from the root directory of the segmentation library:
+```
+export PYTHONPATH=$PYTHONPATH:./pdseg
+```
+
+step3: Start training
+
+
+Run the following command from the root directory of the segmentation library to train:
+```
+python -u ./slim/quantization/train_quant.py --log_steps 10 --not_quant_pattern last_conv --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml --use_gpu --use_mpio --do_eval \
+TRAIN.PRETRAINED_MODEL_DIR "./pretrain/mobilenet_cityscapes/" \
+TRAIN.MODEL_SAVE_DIR "./snapshots/mobilenetv2_quant" \
+MODEL.DEEPLAB.ENCODER_WITH_ASPP False \
+MODEL.DEEPLAB.ENABLE_DECODER False \
+TRAIN.SYNC_BATCH_NORM False \
+SOLVER.LR 0.0001 \
+TRAIN.SNAPSHOT_EPOCH 1 \
+SOLVER.NUM_EPOCHS 30 \
+BATCH_SIZE 16
+```
+
+
+### Model structure during training
+The [PaddleSlim quantization API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/) documentation introduces the two interfaces ``paddleslim.quant.quant_aware`` and ``paddleslim.quant.convert``.
+``paddleslim.quant.quant_aware`` inserts a pair of consecutive quantization and dequantization ops before each input of operators such as conv2d, depthwise_conv2d, and mul, and rewires certain inputs of the corresponding backward operators. The result is illustrated below (a toy numerical sketch follows Figure 1):
+
+<p align="center">
+<img src="./images/TransformPass.png" /><br />
+Figure 1: The result of applying paddleslim.quant.quant_aware
+</p>
+
+
+### Evaluation during training
+
+The test accuracy reported while training in this script is measured on the network structure shown in Figure 1.
+
+## Evaluation
+
+### Final evaluation model
+
+``paddleslim.quant.convert`` is mainly used to change the order of the quantization and dequantization ops in a Program, i.e. to rearrange the quant/dequant layout of Figure 1 into the layout of Figure 2. In addition, ``paddleslim.quant.convert`` converts the parameters of the `conv2d`, `depthwise_conv2d`, and `mul` operators into quantized values within the int8_t range (although their data type remains float32), as illustrated in Figure 2 (a toy sketch follows the figure):
+
+<p align="center">
+<img src="./images/FreezePass.png" /><br />
+Figure 2: The result of paddleslim.quant.convert
+</p>
+
+The final quantized model is therefore only obtained after calling ``paddleslim.quant.convert``. That model can be loaded and run with Paddle-Lite; see the tutorial [How Paddle-Lite loads and runs quantized models](https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization).
+
+### Evaluation script
+Use the script [slim/quantization/eval_quant.py](./eval_quant.py) for evaluation. It proceeds in three steps (a condensed sketch follows the command below):
+
+- Define the config. Use the same quantization config as in the training script so that the resulting model matches the one produced by quantization-aware training.
+- Use ``paddleslim.quant.quant_aware`` to insert the quantization and dequantization ops.
+- Use ``paddleslim.quant.convert`` to reorder the ops and obtain the final quantized model for evaluation.
+
+Evaluation command:
+
+Run from the root directory of the segmentation library:
+```
+python -u ./slim/quantization/eval_quant.py --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml --use_gpu --not_quant_pattern last_conv --use_mpio --convert \
+TEST.TEST_MODEL "./snapshots/mobilenetv2_quant/best_model" \
+MODEL.DEEPLAB.ENCODER_WITH_ASPP False \
+MODEL.DEEPLAB.ENABLE_DECODER False \
+TRAIN.SYNC_BATCH_NORM False \
+BATCH_SIZE 16
+```
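+A condensed sketch of those three steps (the one-layer stand-in network and the checkpoint path are illustrative; [eval_quant.py](./eval_quant.py) below is the real script):
+
+```python
+import paddle.fluid as fluid
+from paddleslim.quant import quant_aware, convert
+
+# Must match the config used for quantization-aware training.
+config = {
+    'weight_quantize_type': 'channel_wise_abs_max',
+    'activation_quantize_type': 'moving_average_abs_max',
+    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+    'not_quant_pattern': ['last_conv'],
+}
+
+startup_prog = fluid.Program()
+test_prog = fluid.Program()
+with fluid.program_guard(test_prog, startup_prog):
+    image = fluid.layers.data(name='image', shape=[3, 512, 512], dtype='float32')
+    logit = fluid.layers.conv2d(image, num_filters=19, filter_size=1)  # stand-in network
+
+place = fluid.CUDAPlace(0)
+exe = fluid.Executor(place)
+exe.run(startup_prog)
+
+test_prog = test_prog.clone(for_test=True)
+# 1) Rewrite the Program exactly as quantization-aware training did.
+test_prog = quant_aware(test_prog, place, config, for_test=True)
+# 2) Load the checkpoint saved by train_quant.py.
+fluid.io.load_persistables(exe, './snapshots/mobilenetv2_quant/best_model',
+                           main_program=test_prog)
+# 3) Reorder the quant/dequant ops to get the final quantized model.
+test_prog = convert(test_prog, place, config)
+```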
+
+
+
+## Quantization results
+
+
+
+## FAQ
diff --git a/slim/quantization/eval_quant.py b/slim/quantization/eval_quant.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40021df10ac5cabee789ca4de04b7489b37f182
--- /dev/null
+++ b/slim/quantization/eval_quant.py
@@ -0,0 +1,203 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import sys
+import time
+import argparse
+import functools
+import pprint
+import cv2
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+
+from utils.config import cfg
+from utils.timer import Timer, calculate_eta
+from models.model_builder import build_model
+from models.model_builder import ModelPhase
+from reader import SegDataset
+from metrics import ConfusionMatrix
+from paddleslim.quant import quant_aware, convert
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='PaddleSeg model evaluation')
+    parser.add_argument(
+        '--cfg',
+        dest='cfg_file',
+        help='Config file for training (and optionally testing)',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--use_gpu',
+        dest='use_gpu',
+        help='Use GPU or CPU',
+        action='store_true',
+        default=False)
+    parser.add_argument(
+        '--use_mpio',
+        dest='use_mpio',
+        help='Use multiprocess I/O or not',
+        action='store_true',
+        default=False)
+    parser.add_argument(
+        'opts',
+        help='See utils/config.py for all options',
+        default=None,
+        nargs=argparse.REMAINDER)
+    parser.add_argument(
+        '--convert',
+        dest='convert',
+        help='Whether to run paddleslim.quant.convert before evaluation',
+        action='store_true',
+        default=False)
+    parser.add_argument(
+        "--not_quant_pattern",
+        nargs='+',
+        type=str,
+        help=
+        "Layers whose name_scope contains a string in not_quant_pattern will not be quantized"
+    )
+
+    if len(sys.argv) == 1:
+        parser.print_help()
+        sys.exit(1)
+    return parser.parse_args()
+
+
+def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
+    np.set_printoptions(precision=5, suppress=True)
+
+    startup_prog = fluid.Program()
+    test_prog = fluid.Program()
+    dataset = SegDataset(
+        file_list=cfg.DATASET.VAL_FILE_LIST,
+        mode=ModelPhase.EVAL,
+        data_dir=cfg.DATASET.DATA_DIR)
+
+    def data_generator():
+        # TODO: check whether the batch reader is compatible with Windows
+        if use_mpio:
+            data_gen = dataset.multiprocess_generator(
+                num_processes=cfg.DATALOADER.NUM_WORKERS,
+                max_queue_size=cfg.DATALOADER.BUF_SIZE)
+        else:
+            data_gen = dataset.generator()
+
+        for b in data_gen:
+            yield b[0], b[1], b[2]
+
+    py_reader, avg_loss, pred, grts, masks = build_model(
+        test_prog, startup_prog, phase=ModelPhase.EVAL)
+
+    py_reader.decorate_sample_generator(
+        data_generator, drop_last=False, batch_size=cfg.BATCH_SIZE)
+
+    # Get device environment
+    places = fluid.cuda_places() if use_gpu else fluid.cpu_places()
+    place = places[0]
+    dev_count = len(places)
+    print("#Device count: {}".format(dev_count))
+
+    exe = fluid.Executor(place)
+    exe.run(startup_prog)
+
+    test_prog = test_prog.clone(for_test=True)
+    not_quant_pattern_list = []
+    if kwargs['not_quant_pattern'] is not None:
+        not_quant_pattern_list = kwargs['not_quant_pattern']
+    config = {
+        'weight_quantize_type': 'channel_wise_abs_max',
+        'activation_quantize_type': 'moving_average_abs_max',
+        'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+        'not_quant_pattern': not_quant_pattern_list
+    }
+    # Rewrite the eval Program with quant/dequant ops, matching the training setup
+    test_prog = quant_aware(test_prog, place, config, for_test=True)
+
+    ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
+
+    if not os.path.exists(ckpt_dir):
+        raise ValueError('The TEST.TEST_MODEL {} is not found'.format(ckpt_dir))
+
+    if ckpt_dir is not None:
+        print('load test model:', ckpt_dir)
+        fluid.io.load_persistables(exe, ckpt_dir, main_program=test_prog)
+    if kwargs['convert']:
+        # Reorder quant/dequant ops and freeze weights into the int8 range
+        test_prog = convert(test_prog, place, config)
+    # Use streaming confusion matrix to calculate mean_iou
+    np.set_printoptions(
+        precision=4, suppress=True, linewidth=160, floatmode="fixed")
+    conf_mat = ConfusionMatrix(cfg.DATASET.NUM_CLASSES, streaming=True)
+    fetch_list = [avg_loss.name, pred.name, grts.name, masks.name]
+    num_images = 0
+    step = 0
+    all_step = cfg.DATASET.TEST_TOTAL_IMAGES // cfg.BATCH_SIZE + 1
+    timer = Timer()
+    timer.start()
+    py_reader.start()
+    while True:
+        try:
+            step += 1
+            loss, pred, grts, masks = exe.run(
+                test_prog, fetch_list=fetch_list, return_numpy=True)
+
+            loss = np.mean(np.array(loss))
+
+            num_images += pred.shape[0]
+            conf_mat.calculate(pred, grts, masks)
+            _, iou = conf_mat.mean_iou()
+            _, acc = conf_mat.accuracy()
+
+            speed = 1.0 / timer.elapsed_time()
+
+            print(
+                "[EVAL]step={} loss={:.5f} acc={:.4f} IoU={:.4f} step/sec={:.2f} | ETA {}"
+                .format(step, loss, acc, iou, speed,
+                        calculate_eta(all_step - step, speed)))
+            timer.restart()
+            sys.stdout.flush()
+        except fluid.core.EOFException:
+            break
+
+    category_iou, avg_iou = conf_mat.mean_iou()
+    category_acc, avg_acc = conf_mat.accuracy()
+    print("[EVAL]#image={} acc={:.4f} IoU={:.4f}".format(
+        num_images, avg_acc, avg_iou))
+    print("[EVAL]Category IoU:", category_iou)
+    print("[EVAL]Category Acc:", category_acc)
+    print("[EVAL]Kappa:{:.4f}".format(conf_mat.kappa()))
+
+    return category_iou, avg_iou, category_acc, avg_acc
+
+
+def main():
+    args = parse_args()
+    if args.cfg_file is not None:
+        cfg.update_from_file(args.cfg_file)
+    if args.opts:
+        cfg.update_from_list(args.opts)
+    cfg.check_and_infer()
+    print(pprint.pformat(cfg))
+    evaluate(cfg, **args.__dict__)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/slim/quantization/images/ConvertToInt8Pass.png b/slim/quantization/images/ConvertToInt8Pass.png
new file mode 100644
index 0000000000000000000000000000000000000000..8b5849819c0bc8e592dc8f864d8945330df85ab1
Binary files /dev/null and b/slim/quantization/images/ConvertToInt8Pass.png differ
diff --git a/slim/quantization/images/FreezePass.png b/slim/quantization/images/FreezePass.png
new file mode 100644
index 0000000000000000000000000000000000000000..acd2b0a890a8af85bec6eecdb22e47ad386a178c
Binary files /dev/null and b/slim/quantization/images/FreezePass.png differ
diff --git a/slim/quantization/images/TransformForMobilePass.png b/slim/quantization/images/TransformForMobilePass.png
new file mode 100644
index 0000000000000000000000000000000000000000..4104cacc67af0be1c7bc152696e2ae544127aace
Binary files /dev/null and b/slim/quantization/images/TransformForMobilePass.png differ
diff --git a/slim/quantization/images/TransformPass.png b/slim/quantization/images/TransformPass.png
new file mode 100644
index 0000000000000000000000000000000000000000..f29ab62753e0e6ddf28d0c1dda7139705fc24b18
Binary files /dev/null and b/slim/quantization/images/TransformPass.png differ
diff --git a/slim/quantization/train_quant.py b/slim/quantization/train_quant.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a29dccdbaeda54b06c11299fb37e979cec6e401
--- /dev/null
+++ b/slim/quantization/train_quant.py
@@ -0,0 +1,388 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import sys
+import argparse
+import pprint
+import random
+import shutil
+import functools
+
+import paddle
+import numpy as np
+import paddle.fluid as fluid
+
+from utils.config import cfg
+from utils.timer import Timer, calculate_eta
+from metrics import ConfusionMatrix
+from reader import SegDataset
+from models.model_builder import build_model
+from models.model_builder import ModelPhase
+from models.model_builder import parse_shape_from_file
+from eval_quant import evaluate
+from vis import visualize
+from utils import dist_utils
+from train import save_vars, save_checkpoint, load_checkpoint, update_best_model, print_info
+
+from paddleslim.quant import quant_aware
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='PaddleSeg training')
+    parser.add_argument(
+        '--cfg',
+        dest='cfg_file',
+        help='Config file for training (and optionally testing)',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--use_gpu',
+        dest='use_gpu',
+        help='Use GPU or CPU',
+        action='store_true',
+        default=False)
+    parser.add_argument(
+        '--use_mpio',
+        dest='use_mpio',
+        help='Use multiprocess I/O or not',
+        action='store_true',
+        default=False)
+    parser.add_argument(
+        '--log_steps',
+        dest='log_steps',
+        help='Display logging information every log_steps steps',
+        default=10,
+        type=int)
+    parser.add_argument(
+        '--debug',
+        dest='debug',
+        help='Debug mode: display detailed information during training',
+        action='store_true')
+    parser.add_argument(
+        '--do_eval',
+        dest='do_eval',
+        help='Evaluate the model on every new checkpoint',
+        action='store_true')
+    parser.add_argument(
+        'opts',
+        help='See utils/config.py for all options',
+        default=None,
+        nargs=argparse.REMAINDER)
+    parser.add_argument(
+        '--enable_ce',
+        dest='enable_ce',
+        help='If set True, enable continuous evaluation job.'
+        'This flag is only used for internal test.',
+        action='store_true')
+    parser.add_argument(
+        "--not_quant_pattern",
+        nargs='+',
+        type=str,
+        help=
+        "Layers whose name_scope contains a string in not_quant_pattern will not be quantized"
+    )
+
+    return parser.parse_args()
+
+
+def train_quant(cfg):
+    startup_prog = fluid.Program()
+    train_prog = fluid.Program()
+    if args.enable_ce:
+        startup_prog.random_seed = 1000
+        train_prog.random_seed = 1000
+    drop_last = True
+
+    dataset = SegDataset(
+        file_list=cfg.DATASET.TRAIN_FILE_LIST,
+        mode=ModelPhase.TRAIN,
+        shuffle=True,
+        data_dir=cfg.DATASET.DATA_DIR)
+
+    def data_generator():
+        if args.use_mpio:
+            data_gen = dataset.multiprocess_generator(
+                num_processes=cfg.DATALOADER.NUM_WORKERS,
+                max_queue_size=cfg.DATALOADER.BUF_SIZE)
+        else:
+            data_gen = dataset.generator()
+
+        batch_data = []
+        for b in data_gen:
+            batch_data.append(b)
+            if len(batch_data) == (cfg.BATCH_SIZE // cfg.NUM_TRAINERS):
+                for item in batch_data:
+                    yield item[0], item[1], item[2]
+                batch_data = []
+        # If the sync batch norm strategy is used, drop the last batch when it
+        # holds fewer than cfg.BATCH_SIZE samples to avoid NCCL hang issues
+        if not cfg.TRAIN.SYNC_BATCH_NORM:
+            for item in batch_data:
+                yield item[0], item[1], item[2]
+
+    # Get device environment
+    # places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+    # place = places[0]
+    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+
+    # Get number of GPUs
+    dev_count = cfg.NUM_TRAINERS if cfg.NUM_TRAINERS > 1 else len(places)
+    print_info("#Device count: {}".format(dev_count))
+
+    # Make sure BATCH_SIZE is divisible by the number of GPU cards
+    assert cfg.BATCH_SIZE % dev_count == 0, (
+        'BATCH_SIZE:{} not divisible by number of GPUs:{}'.format(
+            cfg.BATCH_SIZE, dev_count))
+    # In multi-GPU training mode, batch data is evenly allocated to each GPU
+    batch_size_per_dev = cfg.BATCH_SIZE // dev_count
+    print_info("batch_size_per_dev: {}".format(batch_size_per_dev))
+
+    py_reader, avg_loss, lr, pred, grts, masks = build_model(
+        train_prog, startup_prog, phase=ModelPhase.TRAIN)
+    py_reader.decorate_sample_generator(
+        data_generator, batch_size=batch_size_per_dev, drop_last=drop_last)
+
+    exe = fluid.Executor(place)
+    exe.run(startup_prog)
+
+    exec_strategy = fluid.ExecutionStrategy()
+    # Clear temporary variables every 100 iterations
+    if args.use_gpu:
+        exec_strategy.num_threads = fluid.core.get_cuda_device_count()
+    exec_strategy.num_iteration_per_drop_scope = 100
+    build_strategy = fluid.BuildStrategy()
+
+    if cfg.NUM_TRAINERS > 1 and args.use_gpu:
+        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
+        exec_strategy.num_threads = 1
+
+    # Resume training
+    begin_epoch = cfg.SOLVER.BEGIN_EPOCH
+    if cfg.TRAIN.RESUME_MODEL_DIR:
+        begin_epoch = load_checkpoint(exe, train_prog)
+    # Load pretrained model
+    elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
+        print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR)
+        load_vars = []
+        load_fail_vars = []
+
+        def var_shape_matched(var, shape):
+            """
+            Check whether the persistable variable's shape matches the current network
+            """
+            var_exist = os.path.exists(
+                os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
+            if var_exist:
+                var_shape = parse_shape_from_file(
+                    os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
+                return var_shape == shape
+            return False
+
+        for x in train_prog.list_vars():
+            if isinstance(x, fluid.framework.Parameter):
+                shape = tuple(fluid.global_scope().find_var(
+                    x.name).get_tensor().shape())
+                if var_shape_matched(x, shape):
+                    load_vars.append(x)
+                else:
+                    load_fail_vars.append(x)
+
+        fluid.io.load_vars(
+            exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
+        for var in load_vars:
+            print_info("Parameter[{}] loaded successfully!".format(var.name))
+        for var in load_fail_vars:
+            print_info(
+                "Parameter[{}] doesn't exist or its shape does not match the"
+                " current network, skipping it.".format(var.name))
+        print_info("{}/{} pretrained parameters loaded successfully!".format(
+            len(load_vars),
+            len(load_vars) + len(load_fail_vars)))
+    else:
+        print_info(
+            'Pretrained model dir {} does not exist, training from scratch...'.
+            format(cfg.TRAIN.PRETRAINED_MODEL_DIR))
+
+    fetch_list = [avg_loss.name, lr.name]
+    if args.debug:
+        # Fetch more variable info and use a streaming confusion matrix to
+        # calculate IoU results in debug mode
+        np.set_printoptions(
+            precision=4, suppress=True, linewidth=160, floatmode="fixed")
+        fetch_list.extend([pred.name, grts.name, masks.name])
+        cm = ConfusionMatrix(cfg.DATASET.NUM_CLASSES, streaming=True)
+
+    not_quant_pattern = []
+    if args.not_quant_pattern:
+        not_quant_pattern = args.not_quant_pattern
+    config = {
+        'weight_quantize_type': 'channel_wise_abs_max',
+        'activation_quantize_type': 'moving_average_abs_max',
+        'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+        'not_quant_pattern': not_quant_pattern
+    }
+    # Insert fake-quant/dequant ops; also derive a test-mode Program used for
+    # saving checkpoints and for evaluation
+    compiled_train_prog = quant_aware(train_prog, place, config, for_test=False)
+    eval_prog = quant_aware(train_prog, place, config, for_test=True)
+    # These strategies rewrite the Program and conflict with quantization
+    build_strategy.fuse_all_reduce_ops = False
+    build_strategy.sync_batch_norm = False
+    compiled_train_prog = compiled_train_prog.with_data_parallel(
+        loss_name=avg_loss.name,
+        exec_strategy=exec_strategy,
+        build_strategy=build_strategy)
+
+    # trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
+    # num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
+    global_step = 0
+    all_step = cfg.DATASET.TRAIN_TOTAL_IMAGES // cfg.BATCH_SIZE
+    if cfg.DATASET.TRAIN_TOTAL_IMAGES % cfg.BATCH_SIZE and drop_last != True:
+        all_step += 1
+    all_step *= (cfg.SOLVER.NUM_EPOCHS - begin_epoch + 1)
+
+    avg_loss = 0.0
+    best_mIoU = 0.0
+
+    timer = Timer()
+    timer.start()
+    if begin_epoch > cfg.SOLVER.NUM_EPOCHS:
+        raise ValueError(
+            ("begin epoch[{}] is larger than cfg.SOLVER.NUM_EPOCHS[{}]").format(
+                begin_epoch, cfg.SOLVER.NUM_EPOCHS))
+
+    if args.use_mpio:
+        print_info("Use multiprocess reader")
+    else:
+        print_info("Use multi-thread reader")
+
+    for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1):
+        py_reader.start()
+        while True:
+            try:
+                if args.debug:
+                    # Print category IoU and accuracy to check whether the
+                    # training process behaves as expected
+                    loss, lr, pred, grts, masks = exe.run(
+                        program=compiled_train_prog,
+                        fetch_list=fetch_list,
+                        return_numpy=True)
+                    cm.calculate(pred, grts, masks)
+                    avg_loss += np.mean(np.array(loss))
+                    global_step += 1
+
+                    if global_step % args.log_steps == 0:
+                        speed = args.log_steps / timer.elapsed_time()
+                        avg_loss /= args.log_steps
+                        category_acc, mean_acc = cm.accuracy()
+                        category_iou, mean_iou = cm.mean_iou()
+
+                        print_info((
+                            "epoch={} step={} lr={:.5f} loss={:.4f} acc={:.5f} mIoU={:.5f} step/sec={:.3f} | ETA {}"
+                        ).format(epoch, global_step, lr[0], avg_loss, mean_acc,
+                                 mean_iou, speed,
+                                 calculate_eta(all_step - global_step, speed)))
+                        print_info("Category IoU: ", category_iou)
+                        print_info("Category Acc: ", category_acc)
+                        sys.stdout.flush()
+                        avg_loss = 0.0
+                        cm.zero_matrix()
+                        timer.restart()
+                else:
+                    # Outside debug mode, avoid unnecessary logging and computation
+                    loss, lr = exe.run(
+                        program=compiled_train_prog,
+                        fetch_list=fetch_list,
+                        return_numpy=True)
+                    avg_loss += np.mean(np.array(loss))
+                    global_step += 1
+
+                    if global_step % args.log_steps == 0 and cfg.TRAINER_ID == 0:
+                        avg_loss /= args.log_steps
+                        speed = args.log_steps / timer.elapsed_time()
+                        print((
+                            "epoch={} step={} lr={:.5f} loss={:.4f} step/sec={:.3f} | ETA {}"
+                        ).format(epoch, global_step, lr[0], avg_loss, speed,
+                                 calculate_eta(all_step - global_step, speed)))
+                        sys.stdout.flush()
+                        avg_loss = 0.0
+                        timer.restart()
+
+            except fluid.core.EOFException:
+                py_reader.reset()
+                break
+            except Exception as e:
+                print(e)
+
+        if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0
+                or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0:
+            ckpt_dir = save_checkpoint(exe, eval_prog, epoch)
+
+            if args.do_eval:
+                print("Evaluation start")
+                _, mean_iou, _, mean_acc = evaluate(
+                    cfg=cfg,
+                    ckpt_dir=ckpt_dir,
+                    use_gpu=args.use_gpu,
+                    use_mpio=args.use_mpio,
+                    not_quant_pattern=args.not_quant_pattern,
+                    convert=False)
+
+                if mean_iou > best_mIoU:
+                    best_mIoU = mean_iou
+                    update_best_model(ckpt_dir)
+                    print_info("Save best model {} to {}, mIoU = {:.4f}".format(
+                        ckpt_dir,
+                        os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model'),
+                        mean_iou))
+
+    # Save the final model
+    if cfg.TRAINER_ID == 0:
+        save_checkpoint(exe, eval_prog, 'final')
+
+
+def main(args):
+    if args.cfg_file is not None:
+        cfg.update_from_file(args.cfg_file)
+    if args.opts:
+        cfg.update_from_list(args.opts)
+    if args.enable_ce:
+        random.seed(0)
+        np.random.seed(0)
+
+    cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0))
+    cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
+
+    cfg.check_and_infer()
+    print_info(pprint.pformat(cfg))
+    train_quant(cfg)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    if fluid.core.is_compiled_with_cuda() != True and args.use_gpu == True:
+        print(
+            "You cannot set use_gpu=True because you are using a CPU-only build of PaddlePaddle."
+        )
+        print(
+            "Please: 1. Install paddlepaddle-gpu to run your models on GPU, or 2. Set use_gpu=False to run models on CPU."
+        )
+        sys.exit(1)
+    main(args)