From 759392571f315682906ddb4c615ead34ec996a92 Mon Sep 17 00:00:00 2001 From: chengxianbin Date: Wed, 12 Aug 2020 09:37:05 +0800 Subject: [PATCH] upload yolov3-darknet53 quant code --- mindspore/nn/layer/quant.py | 16 +- mindspore/train/quant/quant.py | 4 +- .../cv/yolov3_darknet53_quant/README.md | 143 +++++ .../cv/yolov3_darknet53_quant/eval.py | 336 ++++++++++ .../scripts/run_distribute_train.sh | 83 +++ .../scripts/run_eval.sh | 67 ++ .../scripts/run_standalone_train.sh | 74 +++ .../cv/yolov3_darknet53_quant/src/__init__.py | 0 .../cv/yolov3_darknet53_quant/src/config.py | 69 +++ .../cv/yolov3_darknet53_quant/src/darknet.py | 208 +++++++ .../src/distributed_sampler.py | 60 ++ .../yolov3_darknet53_quant/src/initializer.py | 179 ++++++ .../cv/yolov3_darknet53_quant/src/logger.py | 80 +++ .../cv/yolov3_darknet53_quant/src/loss.py | 70 +++ .../src/lr_scheduler.py | 143 +++++ .../yolov3_darknet53_quant/src/transforms.py | 577 ++++++++++++++++++ .../cv/yolov3_darknet53_quant/src/util.py | 177 ++++++ .../cv/yolov3_darknet53_quant/src/yolo.py | 437 +++++++++++++ .../src/yolo_dataset.py | 184 ++++++ .../cv/yolov3_darknet53_quant/train.py | 362 +++++++++++ 20 files changed, 3263 insertions(+), 6 deletions(-) create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/README.md create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/eval.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_distribute_train.sh create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_eval.sh create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_standalone_train.sh create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/__init__.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/config.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/darknet.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/distributed_sampler.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/initializer.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/logger.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/loss.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/lr_scheduler.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/transforms.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/util.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/yolo.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/src/yolo_dataset.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/train.py diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 0b94de155..760a43361 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -29,7 +29,7 @@ from mindspore._checkparam import Rel import mindspore.context as context from .normalization import BatchNorm2d, BatchNorm1d -from .activation import get_activation, ReLU +from .activation import get_activation, ReLU, LeakyReLU from ..cell import Cell from . import conv, basic from ..._checkparam import ParamValidator as validator @@ -115,7 +115,11 @@ class Conv2dBnAct(Cell): weight_init='normal', bias_init='zeros', has_bn=False, - activation=None): + momentum=0.9, + eps=1e-5, + activation=None, + alpha=0.2, + after_fake=True): super(Conv2dBnAct, self).__init__() if context.get_context('device_target') == "Ascend" and group > 1: @@ -145,9 +149,13 @@ class Conv2dBnAct(Cell): self.has_bn = validator.check_bool("has_bn", has_bn) self.has_act = activation is not None + self.after_fake = after_fake if has_bn: - self.batchnorm = BatchNorm2d(out_channels) - self.activation = get_activation(activation) + self.batchnorm = BatchNorm2d(out_channels, eps, momentum) + if activation == "leakyrelu": + self.activation = LeakyReLU(alpha) + else: + self.activation = get_activation(activation) def construct(self, x): x = self.conv(x) diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index b94781103..b67292f04 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -244,7 +244,7 @@ class ConvertToQuantNetwork: subcell.conv = conv_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) - else: + elif subcell.after_fake: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, @@ -274,7 +274,7 @@ class ConvertToQuantNetwork: subcell.dense = dense_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) - else: + elif subcell.after_fake: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/README.md b/model_zoo/official/cv/yolov3_darknet53_quant/README.md new file mode 100644 index 000000000..55942c49d --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/README.md @@ -0,0 +1,143 @@ +# YOLOV3-DarkNet53-Quant Example + +## Description + +This is an example of training YOLOV3-DarkNet53-Quant with COCO2014 dataset in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset COCO2014. + +> Unzip the COCO2014 dataset to any path you want, the folder should include train and eval dataset as follows: + +``` +. +└─dataset + ├─train2014 + ├─val2014 + └─annotations +``` + +## Structure + +```shell +. +└─yolov3_darknet53_quant + ├─README.md + ├─scripts + ├─run_standalone_train.sh # launch standalone training(1p) + ├─run_distribute_train.sh # launch distributed training(8p) + └─run_eval.sh # launch evaluating + ├─src + ├─__init__.py # python init file + ├─config.py # parameter configuration + ├─darknet.py # backbone of network + ├─distributed_sampler.py # iterator of dataset + ├─initializer.py # initializer of parameters + ├─logger.py # log function + ├─loss.py # loss function + ├─lr_scheduler.py # generate learning rate + ├─transforms.py # Preprocess data + ├─util.py # util function + ├─yolo.py # yolov3 network + ├─yolo_dataset.py # create dataset for YOLOV3 + ├─eval.py # eval net + └─train.py # train net +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training +sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH] + +# standalone training +sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3] +``` + +#### Launch + +```bash +# distributed training example(8p) +sh run_distribute_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt rank_table_8p.json + +# standalone training example(1p) +sh run_standalone_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt +``` + +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). + +#### Result + +Training result will be stored in the scripts path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in log.txt. + +``` +# distribute training result(8p) +epoch[0], iter[0], loss:483.341675, 0.31 imgs/sec, lr:0.0 +epoch[0], iter[100], loss:55.690952, 3.46 imgs/sec, lr:0.0 +epoch[0], iter[200], loss:54.045728, 126.54 imgs/sec, lr:0.0 +epoch[0], iter[300], loss:48.771608, 133.04 imgs/sec, lr:0.0 +epoch[0], iter[400], loss:48.486769, 139.69 imgs/sec, lr:0.0 +epoch[0], iter[500], loss:48.649275, 143.29 imgs/sec, lr:0.0 +epoch[0], iter[600], loss:44.731309, 144.03 imgs/sec, lr:0.0 +epoch[1], iter[700], loss:43.037023, 136.08 imgs/sec, lr:0.0 +epoch[1], iter[800], loss:41.514788, 132.94 imgs/sec, lr:0.0 + +… +epoch[133], iter[85700], loss:33.326716, 136.14 imgs/sec, lr:6.497331924038008e-06 +epoch[133], iter[85800], loss:34.968744, 136.76 imgs/sec, lr:6.497331924038008e-06 +epoch[134], iter[85900], loss:35.868543, 137.08 imgs/sec, lr:1.6245529650404933e-06 +epoch[134], iter[86000], loss:35.740817, 139.49 imgs/sec, lr:1.6245529650404933e-06 +epoch[134], iter[86100], loss:34.600463, 141.47 imgs/sec, lr:1.6245529650404933e-06 +epoch[134], iter[86200], loss:36.641916, 137.91 imgs/sec, lr:1.6245529650404933e-06 +epoch[134], iter[86300], loss:32.819769, 138.17 imgs/sec, lr:1.6245529650404933e-06 +epoch[134], iter[86400], loss:35.603033, 142.23 imgs/sec, lr:1.6245529650404933e-06 +epoch[134], iter[86500], loss:34.303755, 145.18 imgs/sec, lr:1.6245529650404933e-06 +... +``` + +### Infer + +#### Usage + +``` +# infer +sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +``` + +#### Launch + +```bash +# infer with checkpoint +sh run_eval.sh dataset/coco2014/ checkpoint/0-135.ckpt + +``` + +> checkpoint can be produced in training process. + + +#### Result + +Inference result will be stored in the scripts path, whose folder name is "eval". Under this, you can find result like the followings in log.txt. + +``` +=============coco eval reulst========= +Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.310 +Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.531 +Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.322 +Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.130 +Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326 +Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.425 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.260 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.402 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.429 +Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.232 +Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450 +Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.558 +``` diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/eval.py b/model_zoo/official/cv/yolov3_darknet53_quant/eval.py new file mode 100644 index 000000000..24260f6ee --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/eval.py @@ -0,0 +1,336 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""YoloV3 eval.""" +import os +import argparse +import datetime +import time +import sys +from collections import defaultdict + +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from mindspore import Tensor +from mindspore.train import ParallelMode +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore as ms +from mindspore.train.quant import quant + +from src.yolo import YOLOV3DarkNet53 +from src.logger import get_logger +from src.yolo_dataset import create_yolo_dataset +from src.config import ConfigYOLOV3DarkNet53 + +devid = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid) + + +class Redirct: + def __init__(self): + self.content = "" + + def write(self, content): + self.content += content + + def flush(self): + self.content = "" + + +class DetectionEngine: + """Detection engine.""" + def __init__(self, args): + self.ignore_threshold = args.ignore_threshold + self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', + 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', + 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', + 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] + self.num_classes = len(self.labels) + self.results = {} + self.file_path = '' + self.save_prefix = args.outputs_dir + self.annFile = args.annFile + self._coco = COCO(self.annFile) + self._img_ids = list(sorted(self._coco.imgs.keys())) + self.det_boxes = [] + self.nms_thresh = args.nms_thresh + self.coco_catIds = self._coco.getCatIds() + + def do_nms_for_results(self): + """Get result boxes.""" + for img_id in self.results: + for clsi in self.results[img_id]: + dets = self.results[img_id][clsi] + dets = np.array(dets) + keep_index = self._nms(dets, self.nms_thresh) + + keep_box = [{'image_id': int(img_id), + 'category_id': int(clsi), + 'bbox': list(dets[i][:4].astype(float)), + 'score': dets[i][4].astype(float)} + for i in keep_index] + self.det_boxes.extend(keep_box) + + def _nms(self, dets, thresh): + """Calculate NMS.""" + # conver xywh -> xmin ymin xmax ymax + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = x1 + dets[:, 2] + y2 = y1 + dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + return keep + + def write_result(self): + """Save result to file.""" + import json + t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S') + try: + self.file_path = self.save_prefix + '/predict' + t + '.json' + f = open(self.file_path, 'w') + json.dump(self.det_boxes, f) + except IOError as e: + raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e))) + else: + f.close() + return self.file_path + + def get_eval_result(self): + """Get eval result.""" + cocoGt = COCO(self.annFile) + cocoDt = cocoGt.loadRes(self.file_path) + cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') + cocoEval.evaluate() + cocoEval.accumulate() + rdct = Redirct() + stdout = sys.stdout + sys.stdout = rdct + cocoEval.summarize() + sys.stdout = stdout + return rdct.content + + def detect(self, outputs, batch, image_shape, image_id): + """Detect boxes.""" + outputs_num = len(outputs) + # output [|32, 52, 52, 3, 85| ] + for batch_id in range(batch): + for out_id in range(outputs_num): + # 32, 52, 52, 3, 85 + out_item = outputs[out_id] + # 52, 52, 3, 85 + out_item_single = out_item[batch_id, :] + # get number of items in one head, [B, gx, gy, anchors, 5+80] + dimensions = out_item_single.shape[:-1] + out_num = 1 + for d in dimensions: + out_num *= d + ori_w, ori_h = image_shape[batch_id] + img_id = int(image_id[batch_id]) + x = out_item_single[..., 0] * ori_w + y = out_item_single[..., 1] * ori_h + w = out_item_single[..., 2] * ori_w + h = out_item_single[..., 3] * ori_h + + conf = out_item_single[..., 4:5] + cls_emb = out_item_single[..., 5:] + + cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1) + x = x.reshape(-1) + y = y.reshape(-1) + w = w.reshape(-1) + h = h.reshape(-1) + cls_emb = cls_emb.reshape(-1, 80) + conf = conf.reshape(-1) + cls_argmax = cls_argmax.reshape(-1) + + x_top_left = x - w / 2. + y_top_left = y - h / 2. + # creat all False + flag = np.random.random(cls_emb.shape) > sys.maxsize + for i in range(flag.shape[0]): + c = cls_argmax[i] + flag[i, c] = True + confidence = cls_emb[flag] * conf + for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax): + if confi < self.ignore_threshold: + continue + if img_id not in self.results: + self.results[img_id] = defaultdict(list) + x_lefti = max(0, x_lefti) + y_lefti = max(0, y_lefti) + wi = min(wi, ori_w) + hi = min(hi, ori_h) + # transform catId to match coco + coco_clsi = self.coco_catIds[clsi] + self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi]) + + +def parse_args(): + """Parse arguments.""" + parser = argparse.ArgumentParser('mindspore coco testing') + + # dataset related + parser.add_argument('--data_dir', type=str, default='', help='train data dir') + parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu') + + # network related + parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load') + + # logging related + parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location') + + # detect_related + parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS') + parser.add_argument('--annFile', type=str, default='', help='path to annotation') + parser.add_argument('--testing_shape', type=str, default='', help='shape for test ') + parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes') + + args, _ = parser.parse_known_args() + + args.data_root = os.path.join(args.data_dir, 'val2014') + args.annFile = os.path.join(args.data_dir, 'annotations/instances_val2014.json') + + return args + + +def conver_testing_shape(args): + """Convert testing shape to list.""" + testing_shape = [int(args.testing_shape), int(args.testing_shape)] + return testing_shape + + +def test(): + """The function of eval.""" + start_time = time.time() + args = parse_args() + + # logger + args.outputs_dir = os.path.join(args.log_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + rank_id = int(os.environ.get('RANK_ID')) + args.logger = get_logger(args.outputs_dir, rank_id) + + context.reset_auto_parallel_context() + parallel_mode = ParallelMode.STAND_ALONE + context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=1) + + args.logger.info('Creating Network....') + network = YOLOV3DarkNet53(is_training=False) + + config = ConfigYOLOV3DarkNet53() + if args.testing_shape: + config.test_img_shape = conver_testing_shape(args) + + # convert fusion network to quantization aware network + if config.quantization_aware: + network = quant.convert_quant_network(network, + bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + + args.logger.info(args.pretrained) + if os.path.isfile(args.pretrained): + param_dict = load_checkpoint(args.pretrained) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.'): + continue + elif key.startswith('yolo_network.'): + param_dict_new[key[13:]] = values + else: + param_dict_new[key] = values + load_param_into_net(network, param_dict_new) + args.logger.info('load_model {} success'.format(args.pretrained)) + else: + args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained)) + assert FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained)) + exit(1) + + data_root = args.data_root + ann_file = args.annFile + + ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size, + max_epoch=1, device_num=1, rank=rank_id, shuffle=False, + config=config) + + args.logger.info('testing shape : {}'.format(config.test_img_shape)) + args.logger.info('totol {} images to eval'.format(data_size)) + + network.set_train(False) + + # init detection engine + detection = DetectionEngine(args) + + input_shape = Tensor(tuple(config.test_img_shape), ms.float32) + args.logger.info('Start inference....') + for i, data in enumerate(ds.create_dict_iterator()): + image = Tensor(data["image"]) + + image_shape = Tensor(data["image_shape"]) + image_id = Tensor(data["img_id"]) + + prediction = network(image, input_shape) + output_big, output_me, output_small = prediction + output_big = output_big.asnumpy() + output_me = output_me.asnumpy() + output_small = output_small.asnumpy() + image_id = image_id.asnumpy() + image_shape = image_shape.asnumpy() + + detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, image_id) + if i % 1000 == 0: + args.logger.info('Processing... {:.2f}% '.format(i * args.per_batch_size / data_size * 100)) + + args.logger.info('Calculating mAP...') + detection.do_nms_for_results() + result_file_path = detection.write_result() + args.logger.info('result file path: {}'.format(result_file_path)) + eval_result = detection.get_eval_result() + + cost_time = time.time() - start_time + args.logger.info('\n=============coco eval reulst=========\n' + eval_result) + args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) + + +if __name__ == "__main__": + test() diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_distribute_train.sh b/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_distribute_train.sh new file mode 100644 index 000000000..ced864263 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_distribute_train.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] +then + echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +RESUME_YOLOV3=$(get_real_path $2) +MINDSPORE_HCCL_CONFIG_PATH=$(get_real_path $3) + +echo $DATASET_PATH +echo $RESUME_YOLOV3 +echo $MINDSPORE_HCCL_CONFIG_PATH + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + +if [ ! -f $RESUME_YOLOV3 ] +then + echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file" +exit 1 +fi + +if [ ! -f $MINDSPORE_HCCL_CONFIG_PATH ] +then + echo "error: MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH is not a file" +exit 1 +fi + +export DEVICE_NUM=8 +export RANK_SIZE=8 +export MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH + +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + python train.py \ + --data_dir=$DATASET_PATH \ + --resume_yolov3=$RESUME_YOLOV3 \ + --is_distributed=1 \ + --per_batch_size=16 \ + --lr=0.012 \ + --T_max=135 \ + --max_epoch=135 \ + --warmup_epochs=5 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & + cd .. +done diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_eval.sh b/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_eval.sh new file mode 100644 index 000000000..ad15d6c6d --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_eval.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +DATASET_PATH=$(get_real_path $1) +CHECKPOINT_PATH=$(get_real_path $2) +echo $DATASET_PATH +echo $CHECKPOINT_PATH + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -f $CHECKPOINT_PATH ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi + +mkdir ./eval +cp ../*.py ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start infering for device $DEVICE_ID" +python eval.py \ + --data_dir=$DATASET_PATH \ + --pretrained=$CHECKPOINT_PATH \ + --testing_shape=416 > log.txt 2>&1 & +cd .. diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_standalone_train.sh b/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_standalone_train.sh new file mode 100644 index 000000000..bf642f9e2 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/scripts/run_standalone_train.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +echo $DATASET_PATH +RESUME_YOLOV3=$(get_real_path $2) +echo $RESUME_YOLOV3 + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + +if [ ! -f $RESUME_YOLOV3 ] +then + echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file" +exit 1 +fi + +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log + +python train.py \ + --data_dir=$DATASET_PATH \ + --resume_yolov3=$RESUME_YOLOV3 \ + --is_distributed=0 \ + --per_batch_size=16 \ + --lr=0.004 \ + --T_max=135 \ + --max_epoch=135 \ + --warmup_epochs=5 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & +cd .. diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/__init__.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/config.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/config.py new file mode 100644 index 000000000..c10e0515b --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/config.py @@ -0,0 +1,69 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Config parameters for Darknet based yolov3_darknet53 models.""" + + +class ConfigYOLOV3DarkNet53: + """ + Config parameters for the yolov3_darknet53. + + Examples: + ConfigYOLOV3DarkNet53() + """ + # train_param + # data augmentation related + hue = 0.1 + saturation = 1.5 + value = 1.5 + jitter = 0.3 + + resize_rate = 1 + multi_scale = [[320, 320], + [352, 352], + [384, 384], + [416, 416], + [448, 448], + [480, 480], + [512, 512], + [544, 544], + [576, 576], + [608, 608] + ] + + num_classes = 80 + max_box = 50 + + backbone_input_shape = [32, 64, 128, 256, 512] + backbone_shape = [64, 128, 256, 512, 1024] + backbone_layers = [1, 2, 8, 8, 4] + + # confidence under ignore_threshold means no object when training + ignore_threshold = 0.7 + + # h->w + anchor_scales = [(10, 13), + (16, 30), + (33, 23), + (30, 61), + (62, 45), + (59, 119), + (116, 90), + (156, 198), + (373, 326)] + out_channel = 255 + + quantization_aware = True + # test_param + test_img_shape = [416, 416] diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/darknet.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/darknet.py new file mode 100644 index 000000000..012386830 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/darknet.py @@ -0,0 +1,208 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DarkNet model.""" +import mindspore.nn as nn +from mindspore.ops import operations as P + + +def conv_block(in_channels, + out_channels, + kernel_size, + stride, + dilation=1): + """Get a conv2d batchnorm and relu layer""" + pad_mode = 'same' + padding = 0 + + return nn.Conv2dBnAct(in_channels, out_channels, kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + has_bn=True, + momentum=0.1, + activation='relu') + + +class ResidualBlock(nn.Cell): + """ + DarkNet V1 residual block definition. + + Args: + in_channels: Integer. Input channel. + out_channels: Integer. Output channel. + + Returns: + Tensor, output tensor. + Examples: + ResidualBlock(3, 208) + """ + expansion = 4 + + def __init__(self, + in_channels, + out_channels): + + super(ResidualBlock, self).__init__() + out_chls = out_channels//2 + self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) + self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = self.add(out, identity) + + return out + + +class DarkNet(nn.Cell): + """ + DarkNet V1 network. + + Args: + block: Cell. Block for network. + layer_nums: List. Numbers of different layers. + in_channels: Integer. Input channel. + out_channels: Integer. Output channel. + detect: Bool. Whether detect or not. Default:False. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3,f4,f5). + + Examples: + DarkNet(ResidualBlock, + [1, 2, 8, 8, 4], + [32, 64, 128, 256, 512], + [64, 128, 256, 512, 1024], + 100) + """ + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + detect=False): + super(DarkNet, self).__init__() + + self.outchannel = out_channels[-1] + self.detect = detect + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 5: + raise ValueError("the length of layer_num, inchannel, outchannel list must be 5!") + self.conv0 = conv_block(3, + in_channels[0], + kernel_size=3, + stride=1) + self.conv1 = conv_block(in_channels[0], + out_channels[0], + kernel_size=3, + stride=2) + self.conv2 = conv_block(in_channels[1], + out_channels[1], + kernel_size=3, + stride=2) + self.conv3 = conv_block(in_channels[2], + out_channels[2], + kernel_size=3, + stride=2) + self.conv4 = conv_block(in_channels[3], + out_channels[3], + kernel_size=3, + stride=2) + self.conv5 = conv_block(in_channels[4], + out_channels[4], + kernel_size=3, + stride=2) + + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=out_channels[0], + out_channel=out_channels[0]) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=out_channels[1], + out_channel=out_channels[1]) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=out_channels[2], + out_channel=out_channels[2]) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=out_channels[3], + out_channel=out_channels[3]) + self.layer5 = self._make_layer(block, + layer_nums[4], + in_channel=out_channels[4], + out_channel=out_channels[4]) + + def _make_layer(self, block, layer_num, in_channel, out_channel): + """ + Make Layer for DarkNet. + + :param block: Cell. DarkNet block. + :param layer_num: Integer. Layer number. + :param in_channel: Integer. Input channel. + :param out_channel: Integer. Output channel. + + Examples: + _make_layer(ConvBlock, 1, 128, 256) + """ + layers = [] + darkblk = block(in_channel, out_channel) + layers.append(darkblk) + + for _ in range(1, layer_num): + darkblk = block(out_channel, out_channel) + layers.append(darkblk) + + return nn.SequentialCell(layers) + + def construct(self, x): + c1 = self.conv0(x) + c2 = self.conv1(c1) + c3 = self.layer1(c2) + c4 = self.conv2(c3) + c5 = self.layer2(c4) + c6 = self.conv3(c5) + c7 = self.layer3(c6) + c8 = self.conv4(c7) + c9 = self.layer4(c8) + c10 = self.conv5(c9) + c11 = self.layer5(c10) + if self.detect: + return c7, c9, c11 + + return c11 + + def get_out_channels(self): + return self.outchannel + + +def darknet53(): + """ + Get DarkNet53 neural network. + + Returns: + Cell, cell instance of DarkNet53 neural network. + + Examples: + darknet53() + """ + return DarkNet(ResidualBlock, [1, 2, 8, 8, 4], + [32, 64, 128, 256, 512], + [64, 128, 256, 512, 1024]) diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/distributed_sampler.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/distributed_sampler.py new file mode 100644 index 000000000..d31048ee9 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/distributed_sampler.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Yolo dataset distributed sampler.""" +from __future__ import division +import math +import numpy as np + + +class DistributedSampler: + """Distributed sampler.""" + def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + print("***********Setting world_size to 1 since it is not passed in ******************") + num_replicas = 1 + if rank is None: + print("***********Setting rank to 0 since it is not passed in ******************") + rank = 0 + self.dataset_size = dataset_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size) + # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset + indices = indices.tolist() + self.epoch += 1 + # change to list type + else: + indices = list(range(self.dataset_size)) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/initializer.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/initializer.py new file mode 100644 index 000000000..f3c03a8ad --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/initializer.py @@ -0,0 +1,179 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parameter init.""" +import math +import numpy as np +from mindspore.common import initializer as init +from mindspore.common.initializer import Initializer as MeInitializer +import mindspore.nn as nn +from mindspore import Tensor + + +np.random.seed(5) + + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + if nonlinearity == 'tanh': + return 5.0 / 3 + if nonlinearity == 'relu': + return math.sqrt(2.0) + if nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +def _assignment(arr, num): + """Assign the value of 'num' and 'arr'.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + + +def _calculate_correct_fan(array, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(array) + return fan_in if mode == 'fan_in' else fan_out + + +def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Fills the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + + Examples: + >>> w = np.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') + """ + fan = _calculate_correct_fan(arr, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return np.random.uniform(-bound, bound, arr.shape) + + +def _calculate_fan_in_and_fan_out(arr): + """Calculate fan in and fan out.""" + dimensions = len(arr.shape) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions") + + num_input_fmaps = arr.shape[1] + num_output_fmaps = arr.shape[0] + receptive_field_size = 1 + if dimensions > 2: + receptive_field_size = arr[0][0].size + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +class KaimingUniform(MeInitializer): + """Kaiming uniform initializer.""" + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingUniform, self).__init__() + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + + def _initialize(self, arr): + tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity) + _assignment(arr, tmp) + + +def default_recurisive_init(custom_cell): + """Initialize parameter.""" + for _, cell in custom_cell.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.default_input.shape, + cell.weight.default_input.dtype).to_tensor() + if cell.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy()) + bound = 1 / math.sqrt(fan_in) + cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), + cell.bias.default_input.dtype) + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.default_input.shape, + cell.weight.default_input.dtype).to_tensor() + if cell.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy()) + bound = 1 / math.sqrt(fan_in) + cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), + cell.bias.default_input.dtype) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + pass diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/logger.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/logger.py new file mode 100644 index 000000000..b41ab405f --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/logger.py @@ -0,0 +1,80 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Custom Logger.""" +import os +import sys +import logging +from datetime import datetime + + +class LOGGER(logging.Logger): + """ + Logger. + + Args: + logger_name: String. Logger name. + rank: Integer. Rank id. + """ + def __init__(self, logger_name, rank=0): + super(LOGGER, self).__init__(logger_name) + self.rank = rank + if rank % 8 == 0: + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + console.setFormatter(formatter) + self.addHandler(console) + + def setup_logging_file(self, log_dir, rank=0): + """Setup logging file.""" + self.rank = rank + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank) + self.log_fn = os.path.join(log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + fh.setFormatter(formatter) + self.addHandler(fh) + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO): + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info('Args:') + args_dict = vars(args) + for key in args_dict.keys(): + self.info('--> %s: %s', key, args_dict[key]) + self.info('') + + def important_info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.rank == 0: + line_width = 2 + important_msg = '\n' + important_msg += ('*'*70 + '\n')*line_width + important_msg += ('*'*line_width + '\n')*2 + important_msg += '*'*line_width + ' '*8 + msg + '\n' + important_msg += ('*'*line_width + '\n')*2 + important_msg += ('*'*70 + '\n')*line_width + self.info(important_msg, *args, **kwargs) + + +def get_logger(path, rank): + """Get Logger.""" + logger = LOGGER('yolov3_darknet53', rank) + logger.setup_logging_file(path, rank) + return logger diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/loss.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/loss.py new file mode 100644 index 000000000..acdc6dba6 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/loss.py @@ -0,0 +1,70 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""YOLOV3 loss.""" +from mindspore.ops import operations as P +import mindspore.nn as nn + + +class XYLoss(nn.Cell): + """Loss for x and y.""" + def __init__(self): + super(XYLoss, self).__init__() + self.cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.reduce_sum = P.ReduceSum() + + def construct(self, object_mask, box_loss_scale, predict_xy, true_xy): + xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy) + xy_loss = self.reduce_sum(xy_loss, ()) + return xy_loss + + +class WHLoss(nn.Cell): + """Loss for w and h.""" + def __init__(self): + super(WHLoss, self).__init__() + self.square = P.Square() + self.reduce_sum = P.ReduceSum() + + def construct(self, object_mask, box_loss_scale, predict_wh, true_wh): + wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - predict_wh) + wh_loss = self.reduce_sum(wh_loss, ()) + return wh_loss + + +class ConfidenceLoss(nn.Cell): + """Loss for confidence.""" + def __init__(self): + super(ConfidenceLoss, self).__init__() + self.cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.reduce_sum = P.ReduceSum() + + def construct(self, object_mask, predict_confidence, ignore_mask): + confidence_loss = self.cross_entropy(predict_confidence, object_mask) + confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask + confidence_loss = self.reduce_sum(confidence_loss, ()) + return confidence_loss + + +class ClassLoss(nn.Cell): + """Loss for classification.""" + def __init__(self): + super(ClassLoss, self).__init__() + self.cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.reduce_sum = P.ReduceSum() + + def construct(self, object_mask, predict_class, class_probs): + class_loss = object_mask * self.cross_entropy(predict_class, class_probs) + class_loss = self.reduce_sum(class_loss, ()) + return class_loss diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/lr_scheduler.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/lr_scheduler.py new file mode 100644 index 000000000..c745edfd9 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/lr_scheduler.py @@ -0,0 +1,143 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Learning rate scheduler.""" +import math +from collections import Counter + +import numpy as np + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + """Linear learning rate.""" + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """Warmup step learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma**milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Cosine annealing learning rate.""" + base_lr = lr + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = 0 + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Cosine annealing learning rate V2.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + last_lr = 0 + last_epoch_V1 = 0 + + T_max_V2 = int(max_epoch*1/3) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + if i < total_steps*2/3: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + last_lr = lr + last_epoch_V1 = last_epoch + else: + base_lr = last_lr + last_epoch = last_epoch-last_epoch_V1 + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2 + + lr_each_step.append(lr) + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Warmup cosine annealing learning rate.""" + start_sample_epoch = 60 + step_sample = 2 + tobe_sampled_epoch = 60 + end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch + max_sampled_epoch = max_epoch+tobe_sampled_epoch + T_max = max_sampled_epoch + + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + total_sampled_steps = int(max_sampled_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + + for i in range(total_sampled_steps): + last_epoch = i // steps_per_epoch + if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample): + continue + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + assert total_steps == len(lr_each_step) + return np.array(lr_each_step).astype(np.float32) diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/transforms.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/transforms.py new file mode 100644 index 000000000..837d1a25e --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/transforms.py @@ -0,0 +1,577 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Preprocess dataset.""" +import random +import threading +import copy + +import numpy as np +from PIL import Image +import cv2 + + +def _rand(a=0., b=1.): + return np.random.rand() * (b - a) + a + + +def bbox_iou(bbox_a, bbox_b, offset=0): + """Calculate Intersection-Over-Union(IOU) of two bounding boxes. + + Parameters + ---------- + bbox_a : numpy.ndarray + An ndarray with shape :math:`(N, 4)`. + bbox_b : numpy.ndarray + An ndarray with shape :math:`(M, 4)`. + offset : float or int, default is 0 + The ``offset`` is used to control the whether the width(or height) is computed as + (right - left + ``offset``). + Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``. + + Returns + ------- + numpy.ndarray + An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of + bounding boxes in `bbox_a` and `bbox_b`. + + """ + if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: + raise IndexError("Bounding boxes axis 1 must have at least length 4") + + tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) + br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) + + area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) + area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) + area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) + return area_i / (area_a[:, None] + area_b - area_i) + + +def statistic_normalize_img(img, statistic_norm): + """Statistic normalize images.""" + # img: RGB + if isinstance(img, Image.Image): + img = np.array(img) + img = img/255. + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + if statistic_norm: + img = (img - mean) / std + return img + + +def get_interp_method(interp, sizes=()): + """Get the interpolation method for resize functions. + The major purpose of this function is to wrap a random interp method selection + and a auto-estimation method. + + Parameters + ---------- + interp : int + interpolation method for all resizing operations + + Possible values: + 0: Nearest Neighbors Interpolation. + 1: Bilinear interpolation. + 2: Bicubic interpolation over 4x4 pixel neighborhood. + 3: Nearest Neighbors. [Originally it should be Area-based, + as we cannot find Area-based, so we use NN instead. + Area-based (resampling using pixel area relation). It may be a + preferred method for image decimation, as it gives moire-free + results. But when the image is zoomed, it is similar to the Nearest + Neighbors method. (used by default). + 4: Lanczos interpolation over 8x8 pixel neighborhood. + 9: Cubic for enlarge, area for shrink, bilinear for others + 10: Random select from interpolation method metioned above. + Note: + When shrinking an image, it will generally look best with AREA-based + interpolation, whereas, when enlarging an image, it will generally look best + with Bicubic (slow) or Bilinear (faster but still looks OK). + More details can be found in the documentation of OpenCV, please refer to + http://docs.opencv.org/master/da/d54/group__imgproc__transform.html. + sizes : tuple of int + (old_height, old_width, new_height, new_width), if None provided, auto(9) + will return Area(2) anyway. + + Returns + ------- + int + interp method from 0 to 4 + """ + if interp == 9: + if sizes: + assert len(sizes) == 4 + oh, ow, nh, nw = sizes + if nh > oh and nw > ow: + return 2 + if nh < oh and nw < ow: + return 0 + return 1 + return 2 + if interp == 10: + return random.randint(0, 4) + if interp not in (0, 1, 2, 3, 4): + raise ValueError('Unknown interp method %d' % interp) + return interp + + +def pil_image_reshape(interp): + """Reshape pil image.""" + reshape_type = { + 0: Image.NEAREST, + 1: Image.BILINEAR, + 2: Image.BICUBIC, + 3: Image.NEAREST, + 4: Image.LANCZOS, + } + return reshape_type[interp] + + +def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, + max_boxes, label_smooth, label_smooth_factor=0.1): + """Preprocess annotation boxes.""" + anchors = np.array(anchors) + num_layers = anchors.shape[0] // 3 + anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + true_boxes = np.array(true_boxes, dtype='float32') + input_shape = np.array(in_shape, dtype='int32') + boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. + # trans to box center point + boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] + # input_shape is [h, w] + true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] + true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] + # true_boxes = [xywh] + + grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] + # grid_shape [h, w] + y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), + 5 + num_classes), dtype='float32') for l in range(num_layers)] + # y_true [gridy, gridx] + anchors = np.expand_dims(anchors, 0) + anchors_max = anchors / 2. + anchors_min = -anchors_max + valid_mask = boxes_wh[..., 0] > 0 + + wh = boxes_wh[valid_mask] + if wh.size > 0: + wh = np.expand_dims(wh, -2) + boxes_max = wh / 2. + boxes_min = -boxes_max + + intersect_min = np.maximum(boxes_min, anchors_min) + intersect_max = np.minimum(boxes_max, anchors_max) + intersect_wh = np.maximum(intersect_max - intersect_min, 0.) + intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] + box_area = wh[..., 0] * wh[..., 1] + anchor_area = anchors[..., 0] * anchors[..., 1] + iou = intersect_area / (box_area + anchor_area - intersect_area) + + best_anchor = np.argmax(iou, axis=-1) + for t, n in enumerate(best_anchor): + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x + + k = anchor_mask[l].index(n) + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + + # lable-smooth + if label_smooth: + sigma = label_smooth_factor/(num_classes-1) + y_true[l][j, i, k, 5:] = sigma + y_true[l][j, i, k, 5+c] = 1-label_smooth_factor + else: + y_true[l][j, i, k, 5 + c] = 1. + + # pad_gt_boxes for avoiding dynamic shape + pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + + mask0 = np.reshape(y_true[0][..., 4:5], [-1]) + gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) + # gt_box [boxes, [x,y,w,h]] + gt_box0 = gt_box0[mask0 == 1] + # gt_box0: get all boxes which have object + pad_gt_box0[:gt_box0.shape[0]] = gt_box0 + # gt_box0.shape[0]: total number of boxes in gt_box0 + # top N of pad_gt_box0 is real box, and after are pad by zero + + mask1 = np.reshape(y_true[1][..., 4:5], [-1]) + gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) + gt_box1 = gt_box1[mask1 == 1] + pad_gt_box1[:gt_box1.shape[0]] = gt_box1 + + mask2 = np.reshape(y_true[2][..., 4:5], [-1]) + gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) + + gt_box2 = gt_box2[mask2 == 1] + pad_gt_box2[:gt_box2.shape[0]] = gt_box2 + return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 + + +def _reshape_data(image, image_size): + """Reshape image.""" + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + ori_w, ori_h = image.size + ori_image_shape = np.array([ori_w, ori_h], np.int32) + # original image shape fir:H sec:W + h, w = image_size + interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w)) + image = image.resize((w, h), pil_image_reshape(interp)) + image_data = statistic_normalize_img(image, statistic_norm=True) + if len(image_data.shape) == 2: + image_data = np.expand_dims(image_data, axis=-1) + image_data = np.concatenate([image_data, image_data, image_data], axis=-1) + image_data = image_data.astype(np.float32) + return image_data, ori_image_shape + + +def color_distortion(img, hue, sat, val, device_num): + """Color distortion.""" + hue = _rand(-hue, hue) + sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) + val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) + if device_num != 1: + cv2.setNumThreads(1) + x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) + x = x / 255. + x[..., 0] += hue + x[..., 0][x[..., 0] > 1] -= 1 + x[..., 0][x[..., 0] < 0] += 1 + x[..., 1] *= sat + x[..., 2] *= val + x[x > 1] = 1 + x[x < 0] = 0 + x = x * 255. + x = x.astype(np.uint8) + image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL) + return image_data + + +def filp_pil_image(img): + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def convert_gray_to_color(img): + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + img = np.concatenate([img, img, img], axis=-1) + return img + + +def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box): + iou = bbox_iou(box, crop_box) + return min_iou <= iou.min() and max_iou >= iou.max() + + +def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints): + """Choose candidate by constraints.""" + if use_constraints: + constraints = ( + (0.1, None), + (0.3, None), + (0.5, None), + (0.7, None), + (0.9, None), + (None, 1), + ) + else: + constraints = ( + (None, None), + ) + # add default candidate + candidates = [(0, 0, input_w, input_h)] + for constraint in constraints: + min_iou, max_iou = constraint + min_iou = -np.inf if min_iou is None else min_iou + max_iou = np.inf if max_iou is None else max_iou + + for _ in range(max_trial): + # box_data should have at least one box + new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter) + scale = _rand(0.25, 2) + + if new_ar < 1: + nh = int(scale * input_h) + nw = int(nh * new_ar) + else: + nw = int(scale * input_w) + nh = int(nw / new_ar) + + dx = int(_rand(0, input_w - nw)) + dy = int(_rand(0, input_h - nh)) + + if box.size > 0: + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + + crop_box = np.array((0, 0, input_w, input_h)) + if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]): + continue + else: + candidates.append((dx, dy, nw, nh)) + else: + raise Exception("!!! annotation box is less than 1") + return candidates + + +def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, + image_h, flip, box, box_data, allow_outside_center): + """Calculate correct boxes.""" + while candidates: + if len(candidates) > 1: + # ignore default candidate which do not crop + candidate = candidates.pop(np.random.randint(1, len(candidates))) + else: + candidate = candidates.pop(np.random.randint(0, len(candidates))) + dx, dy, nw, nh = candidate + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + if flip: + t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]] + + if allow_outside_center: + pass + else: + t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2])/2. >= 0., (t_box[:, 1] + t_box[:, 3])/2. >= 0.)] + t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w, + (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)] + + # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero + t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 + # recorrect w,h not higher than input size + t_box[:, 2][t_box[:, 2] > input_w] = input_w + t_box[:, 3][t_box[:, 3] > input_h] = input_h + box_w = t_box[:, 2] - t_box[:, 0] + box_h = t_box[:, 3] - t_box[:, 1] + # discard invalid box: w or h smaller than 1 pixel + t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] + + if t_box.shape[0] > 0: + # break if number of find t_box + box_data[: len(t_box)] = t_box + return box_data, candidate + raise Exception('all candidates can not satisfied re-correct bbox') + + +def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, + anchors, num_classes, max_trial=10, device_num=1): + """Crop an image randomly with bounding box constraints. + + This data augmentation is used in training of + Single Shot Multibox Detector [#]_. More details can be found in + data augmentation section of the original paper. + .. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, + Scott Reed, Cheng-Yang Fu, Alexander C. Berg. + SSD: Single Shot MultiBox Detector. ECCV 2016.""" + + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + + image_w, image_h = image.size + input_h, input_w = image_input_size + + np.random.shuffle(box) + if len(box) > max_boxes: + box = box[:max_boxes] + flip = _rand() < .5 + box_data = np.zeros((max_boxes, 5)) + + candidates = _choose_candidate_by_constraints(use_constraints=False, + max_trial=max_trial, + input_w=input_w, + input_h=input_h, + image_w=image_w, + image_h=image_h, + jitter=jitter, + box=box) + box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, + input_w=input_w, + input_h=input_h, + image_w=image_w, + image_h=image_h, + flip=flip, + box=box, + box_data=box_data, + allow_outside_center=True) + dx, dy, nw, nh = candidate + interp = get_interp_method(interp=10) + image = image.resize((nw, nh), pil_image_reshape(interp)) + # place image, gray color as back graoud + new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128)) + new_image.paste(image, (dx, dy)) + image = new_image + + if flip: + image = filp_pil_image(image) + + image = np.array(image) + + image = convert_gray_to_color(image) + + image_data = color_distortion(image, hue, sat, val, device_num) + image_data = statistic_normalize_img(image_data, statistic_norm=True) + + image_data = image_data.astype(np.float32) + + return image_data, box_data + + +def preprocess_fn(image, box, config, input_size, device_num): + """Preprocess data function.""" + config_anchors = config.anchor_scales + anchors = np.array([list(x) for x in config_anchors]) + max_boxes = config.max_box + num_classes = config.num_classes + jitter = config.jitter + hue = config.hue + sat = config.saturation + val = config.value + image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val, + image_input_size=input_size, max_boxes=max_boxes, + num_classes=num_classes, anchors=anchors, device_num=device_num) + return image, anno + + +def reshape_fn(image, img_id, config): + input_size = config.test_img_shape + image, ori_image_shape = _reshape_data(image, image_size=input_size) + return image, ori_image_shape, img_id + + +class MultiScaleTrans: + """Multi scale transform.""" + def __init__(self, config, device_num): + self.config = config + self.seed = 0 + self.size_list = [] + self.resize_rate = config.resize_rate + self.dataset_size = config.dataset_size + self.size_dict = {} + self.seed_num = int(1e6) + self.seed_list = self.generate_seed_list(seed_num=self.seed_num) + self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) + self.device_num = device_num + + def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): + seed_list = [] + random.seed(init_seed) + for _ in range(seed_num): + seed = random.randint(seed_range[0], seed_range[1]) + seed_list.append(seed) + return seed_list + + def __call__(self, imgs, annos, batchInfo): + epoch_num = batchInfo.get_epoch_num() + size_idx = int(batchInfo.get_batch_num() / self.resize_rate) + seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num] + ret_imgs = [] + ret_annos = [] + + if self.size_dict.get(seed_key, None) is None: + random.seed(seed_key) + new_size = random.choice(self.config.multi_scale) + self.size_dict[seed_key] = new_size + seed = seed_key + + input_size = self.size_dict[seed] + for img, anno in zip(imgs, annos): + img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) + ret_imgs.append(img.transpose(2, 0, 1).copy()) + ret_annos.append(anno) + return np.array(ret_imgs), np.array(ret_annos) + + +def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3): + """Preprocess true box for multi-thread.""" + i = 0 + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1[result_index + i] = bbox_true_1 + batch_bbox_true_2[result_index + i] = bbox_true_2 + batch_bbox_true_3[result_index + i] = bbox_true_3 + batch_gt_box1[result_index + i] = gt_box1 + batch_gt_box2[result_index + i] = gt_box2 + batch_gt_box3[result_index + i] = gt_box3 + i = i + 1 + + +def batch_preprocess_true_box(annos, config, input_shape): + """Preprocess true box with multi-thread.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + batch_gt_box3 = [] + threads = [] + + step = 4 + for index in range(0, len(annos), step): + for _ in range(step): + batch_bbox_true_1.append(None) + batch_bbox_true_2.append(None) + batch_bbox_true_3.append(None) + batch_gt_box1.append(None) + batch_gt_box2.append(None) + batch_gt_box3.append(None) + step_anno = annos[index: index + step] + t = threading.Thread(target=thread_batch_preprocess_true_box, + args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3)) + t.start() + threads.append(t) + + for t in threads: + t.join() + + return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) + + +def batch_preprocess_true_box_single(annos, config, input_shape): + """Preprocess true boxes.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + batch_gt_box3 = [] + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1.append(bbox_true_1) + batch_bbox_true_2.append(bbox_true_2) + batch_bbox_true_3.append(bbox_true_3) + batch_gt_box1.append(gt_box1) + batch_gt_box2.append(gt_box2) + batch_gt_box3.append(gt_box3) + + return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/util.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/util.py new file mode 100644 index 000000000..1a3da9918 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/util.py @@ -0,0 +1,177 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Util class or function.""" +from mindspore.train.serialization import load_checkpoint +import mindspore.nn as nn + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f', tb_writer=None): + self.name = name + self.fmt = fmt + self.reset() + self.tb_writer = tb_writer + self.cur_step = 1 + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + if self.tb_writer is not None: + self.tb_writer.add_scalar(self.name, self.val, self.cur_step) + self.cur_step += 1 + + def __str__(self): + fmtstr = '{name}:{avg' + self.fmt + '}' + return fmtstr.format(**self.__dict__) + + +def load_backbone(net, ckpt_path, args): + """Load darknet53 backbone checkpoint.""" + param_dict = load_checkpoint(ckpt_path) + yolo_backbone_prefix = 'feature_map.backbone' + darknet_backbone_prefix = 'network.backbone' + find_param = [] + not_found_param = [] + + for name, cell in net.cells_and_names(): + if name.startswith(yolo_backbone_prefix): + name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix) + if isinstance(cell, (nn.Conv2d, nn.Dense)): + darknet_weight = '{}.weight'.format(name) + darknet_bias = '{}.bias'.format(name) + if darknet_weight in param_dict: + cell.weight.default_input = param_dict[darknet_weight].data + find_param.append(darknet_weight) + else: + not_found_param.append(darknet_weight) + if darknet_bias in param_dict: + cell.bias.default_input = param_dict[darknet_bias].data + find_param.append(darknet_bias) + else: + not_found_param.append(darknet_bias) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + darknet_moving_mean = '{}.moving_mean'.format(name) + darknet_moving_variance = '{}.moving_variance'.format(name) + darknet_gamma = '{}.gamma'.format(name) + darknet_beta = '{}.beta'.format(name) + if darknet_moving_mean in param_dict: + cell.moving_mean.default_input = param_dict[darknet_moving_mean].data + find_param.append(darknet_moving_mean) + else: + not_found_param.append(darknet_moving_mean) + if darknet_moving_variance in param_dict: + cell.moving_variance.default_input = param_dict[darknet_moving_variance].data + find_param.append(darknet_moving_variance) + else: + not_found_param.append(darknet_moving_variance) + if darknet_gamma in param_dict: + cell.gamma.default_input = param_dict[darknet_gamma].data + find_param.append(darknet_gamma) + else: + not_found_param.append(darknet_gamma) + if darknet_beta in param_dict: + cell.beta.default_input = param_dict[darknet_beta].data + find_param.append(darknet_beta) + else: + not_found_param.append(darknet_beta) + + args.logger.info('================found_param {}========='.format(len(find_param))) + args.logger.info(find_param) + args.logger.info('================not_found_param {}========='.format(len(not_found_param))) + args.logger.info(not_found_param) + args.logger.info('=====load {} successfully ====='.format(ckpt_path)) + + return net + + +def default_wd_filter(x): + """default weight decay filter.""" + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + return False + if parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + return False + if parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + return False + + return True + + +def get_param_groups(network): + """Param groups for optimizer.""" + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] + + +class ShapeRecord: + """Log image shape.""" + def __init__(self): + self.shape_record = { + 320: 0, + 352: 0, + 384: 0, + 416: 0, + 448: 0, + 480: 0, + 512: 0, + 544: 0, + 576: 0, + 608: 0, + 'total': 0 + } + + def set(self, shape): + if len(shape) > 1: + shape = shape[0] + shape = int(shape) + self.shape_record[shape] += 1 + self.shape_record['total'] += 1 + + def show(self, logger): + for key in self.shape_record: + rate = self.shape_record[key] / float(self.shape_record['total']) + logger.info('shape {}: {:.2f}%'.format(key, rate*100)) diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/yolo.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/yolo.py new file mode 100644 index 000000000..e010ddef2 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/yolo.py @@ -0,0 +1,437 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""YOLOv3 based on DarkNet.""" +import mindspore as ms +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore import context +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.communication.management import get_group_size +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C + +from src.darknet import DarkNet, ResidualBlock +from src.config import ConfigYOLOV3DarkNet53 +from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss + + +def _conv_bn_relu(in_channel, + out_channel, + ksize, + stride=1, + padding=0, + dilation=1, + alpha=0.1, + momentum=0.9, + eps=1e-5, + pad_mode="same"): + """Get a conv2d batchnorm and relu layer""" + return nn.Conv2dBnAct(in_channel, out_channel, ksize, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + has_bn=True, + momentum=momentum, + eps=eps, + activation='leakyrelu', + alpha=alpha) + + +class YoloBlock(nn.Cell): + """ + YoloBlock for YOLOv3. + + Args: + in_channels: Integer. Input channel. + out_chls: Interger. Middle channel. + out_channels: Integer. Output channel. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3). + + Examples: + YoloBlock(1024, 512, 255) + + """ + def __init__(self, in_channels, out_chls, out_channels): + super(YoloBlock, self).__init__() + out_chls_2 = out_chls*2 + + self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1) + self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) + + self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) + self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) + + self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) + self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) + + self.conv6 = nn.Conv2dBnAct(out_chls_2, out_channels, kernel_size=1, stride=1, + has_bias=True, has_bn=False, activation=None, after_fake=False) + + def construct(self, x): + c1 = self.conv0(x) + c2 = self.conv1(c1) + + c3 = self.conv2(c2) + c4 = self.conv3(c3) + + c5 = self.conv4(c4) + c6 = self.conv5(c5) + + out = self.conv6(c6) + return c5, out + + +class YOLOv3(nn.Cell): + """ + YOLOv3 Network. + + Note: + backbone = darknet53 + + Args: + backbone_shape: List. Darknet output channels shape. + backbone: Cell. Backbone Network. + out_channel: Interger. Output channel. + + Returns: + Tensor, output tensor. + + Examples: + YOLOv3(backbone_shape=[64, 128, 256, 512, 1024] + backbone=darknet53(), + out_channel=255) + """ + def __init__(self, backbone_shape, backbone, out_channel): + super(YOLOv3, self).__init__() + self.out_channel = out_channel + self.backbone = backbone + self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel) + + self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1) + self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3], + out_chls=backbone_shape[-3], + out_channels=out_channel) + + self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1) + self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4], + out_chls=backbone_shape[-4], + out_channels=out_channel) + self.concat = P.Concat(axis=1) + + def construct(self, x): + # input_shape of x is (batch_size, 3, h, w) + # feature_map1 is (batch_size, backbone_shape[2], h/8, w/8) + # feature_map2 is (batch_size, backbone_shape[3], h/16, w/16) + # feature_map3 is (batch_size, backbone_shape[4], h/32, w/32) + img_hight = P.Shape()(x)[2] + img_width = P.Shape()(x)[3] + feature_map1, feature_map2, feature_map3 = self.backbone(x) + con1, big_object_output = self.backblock0(feature_map3) + + con1 = self.conv1(con1) + ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con1) + con1 = self.concat((ups1, feature_map2)) + con2, medium_object_output = self.backblock1(con1) + + con2 = self.conv2(con2) + ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con2) + con3 = self.concat((ups2, feature_map1)) + _, small_object_output = self.backblock2(con3) + + return big_object_output, medium_object_output, small_object_output + + +class DetectionBlock(nn.Cell): + """ + YOLOv3 detection Network. It will finally output the detection result. + + Args: + scale: Character. + config: ConfigYOLOV3DarkNet53, Configuration instance. + is_training: Bool, Whether train or not, default True. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3). + + Examples: + DetectionBlock(scale='l',stride=32) + """ + + def __init__(self, scale, config=ConfigYOLOV3DarkNet53(), is_training=True): + super(DetectionBlock, self).__init__() + self.config = config + if scale == 's': + idx = (0, 1, 2) + elif scale == 'm': + idx = (3, 4, 5) + elif scale == 'l': + idx = (6, 7, 8) + else: + raise KeyError("Invalid scale value for DetectionBlock") + self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) + self.num_anchors_per_scale = 3 + self.num_attrib = 4+1+self.config.num_classes + self.lambda_coord = 1 + + self.sigmoid = nn.Sigmoid() + self.reshape = P.Reshape() + self.tile = P.Tile() + self.concat = P.Concat(axis=-1) + self.conf_training = is_training + + def construct(self, x, input_shape): + num_batch = P.Shape()(x)[0] + grid_size = P.Shape()(x)[2:4] + + # Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib] + prediction = P.Reshape()(x, (num_batch, + self.num_anchors_per_scale, + self.num_attrib, + grid_size[0], + grid_size[1])) + prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2)) + + range_x = range(grid_size[1]) + range_y = range(grid_size[0]) + grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32) + grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32) + # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid + # [batch, gridx, gridy, 1, 1] + grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1)) + grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1)) + # Shape is [grid_size[0], grid_size[1], 1, 2] + grid = self.concat((grid_x, grid_y)) + + box_xy = prediction[:, :, :, :, :2] + box_wh = prediction[:, :, :, :, 2:4] + box_confidence = prediction[:, :, :, :, 4:5] + box_probs = prediction[:, :, :, :, 5:] + + # gridsize1 is x + # gridsize0 is y + box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32) + # box_wh is w->h + box_wh = P.Exp()(box_wh) * self.anchors / input_shape + box_confidence = self.sigmoid(box_confidence) + box_probs = self.sigmoid(box_probs) + + if self.conf_training: + return grid, prediction, box_xy, box_wh + return self.concat((box_xy, box_wh, box_confidence, box_probs)) + + +class Iou(nn.Cell): + """Calculate the iou of boxes""" + def __init__(self): + super(Iou, self).__init__() + self.min = P.Minimum() + self.max = P.Maximum() + + def construct(self, box1, box2): + # box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h] + # box2: gt_box [batch, 1, 1, 1, maxbox, 4] + # convert to topLeft and rightDown + box1_xy = box1[:, :, :, :, :, :2] + box1_wh = box1[:, :, :, :, :, 2:4] + box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft + box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown + + box2_xy = box2[:, :, :, :, :, :2] + box2_wh = box2[:, :, :, :, :, 2:4] + box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0) + box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0) + + intersect_mins = self.max(box1_mins, box2_mins) + intersect_maxs = self.min(box1_maxs, box2_maxs) + intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0)) + # P.squeeze: for effiecient slice + intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \ + P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2]) + box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2]) + box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2]) + iou = intersect_area / (box1_area + box2_area - intersect_area) + # iou : [batch, gx, gy, anchors, maxboxes] + return iou + + +class YoloLossBlock(nn.Cell): + """ + Loss block cell of YOLOV3 network. + """ + def __init__(self, scale, config=ConfigYOLOV3DarkNet53()): + super(YoloLossBlock, self).__init__() + self.config = config + if scale == 's': + # anchor mask + idx = (0, 1, 2) + elif scale == 'm': + idx = (3, 4, 5) + elif scale == 'l': + idx = (6, 7, 8) + else: + raise KeyError("Invalid scale value for DetectionBlock") + self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) + self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32) + self.concat = P.Concat(axis=-1) + self.iou = Iou() + self.reduce_max = P.ReduceMax(keep_dims=False) + self.xy_loss = XYLoss() + self.wh_loss = WHLoss() + self.confidenceLoss = ConfidenceLoss() + self.classLoss = ClassLoss() + + def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape): + # prediction : origin output from yolo + # pred_xy: (sigmoid(xy)+grid)/grid_size + # pred_wh: (exp(wh)*anchors)/input_shape + # y_true : after normalize + # gt_box: [batch, maxboxes, xyhw] after normalize + + object_mask = y_true[:, :, :, :, 4:5] + class_probs = y_true[:, :, :, :, 5:] + + grid_shape = P.Shape()(prediction)[1:3] + grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32) + + pred_boxes = self.concat((pred_xy, pred_wh)) + true_xy = y_true[:, :, :, :, :2] * grid_shape - grid + true_wh = y_true[:, :, :, :, 2:4] + true_wh = P.Select()(P.Equal()(true_wh, 0.0), + P.Fill()(P.DType()(true_wh), + P.Shape()(true_wh), 1.0), + true_wh) + true_wh = P.Log()(true_wh / self.anchors * input_shape) + # 2-w*h for large picture, use small scale, since small obj need more precise + box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4] + + gt_shape = P.Shape()(gt_box) + gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2])) + + # add one more dimension for broadcast + iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box) + # gt_box is x,y,h,w after normalize + # [batch, grid[0], grid[1], num_anchor, num_gt] + best_iou = self.reduce_max(iou, -1) + # [batch, grid[0], grid[1], num_anchor] + + # ignore_mask IOU too small + ignore_mask = best_iou < self.ignore_threshold + ignore_mask = P.Cast()(ignore_mask, ms.float32) + ignore_mask = P.ExpandDims()(ignore_mask, -1) + # ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume. + # so we turn off its gradient + ignore_mask = F.stop_gradient(ignore_mask) + + xy_loss = self.xy_loss(object_mask, box_loss_scale, prediction[:, :, :, :, :2], true_xy) + wh_loss = self.wh_loss(object_mask, box_loss_scale, prediction[:, :, :, :, 2:4], true_wh) + confidence_loss = self.confidenceLoss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask) + class_loss = self.classLoss(object_mask, prediction[:, :, :, :, 5:], class_probs) + loss = xy_loss + wh_loss + confidence_loss + class_loss + batch_size = P.Shape()(prediction)[0] + return loss / batch_size + + +class YOLOV3DarkNet53(nn.Cell): + """ + Darknet based YOLOV3 network. + + Args: + is_training: Bool. Whether train or not. + + Returns: + Cell, cell instance of Darknet based YOLOV3 neural network. + + Examples: + YOLOV3DarkNet53(True) + """ + + def __init__(self, is_training): + super(YOLOV3DarkNet53, self).__init__() + self.config = ConfigYOLOV3DarkNet53() + + # YOLOv3 network + self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers, + self.config.backbone_input_shape, + self.config.backbone_shape, + detect=True), + backbone_shape=self.config.backbone_shape, + out_channel=self.config.out_channel) + + # prediction on the default anchor boxes + self.detect_1 = DetectionBlock('l', is_training=is_training) + self.detect_2 = DetectionBlock('m', is_training=is_training) + self.detect_3 = DetectionBlock('s', is_training=is_training) + + def construct(self, x, input_shape): + big_object_output, medium_object_output, small_object_output = self.feature_map(x) + output_big = self.detect_1(big_object_output, input_shape) + output_me = self.detect_2(medium_object_output, input_shape) + output_small = self.detect_3(small_object_output, input_shape) + # big is the final output which has smallest feature map + return output_big, output_me, output_small + + +class YoloWithLossCell(nn.Cell): + """YOLOV3 loss.""" + def __init__(self, network): + super(YoloWithLossCell, self).__init__() + self.yolo_network = network + self.config = ConfigYOLOV3DarkNet53() + self.loss_big = YoloLossBlock('l', self.config) + self.loss_me = YoloLossBlock('m', self.config) + self.loss_small = YoloLossBlock('s', self.config) + + def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape): + yolo_out = self.yolo_network(x, input_shape) + loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) + loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) + loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) + return loss_l + loss_m + loss_s + + +class TrainingWrapper(nn.Cell): + """Training wrapper.""" + def __init__(self, network, optimizer, sens=1.0): + super(TrainingWrapper, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + if auto_parallel_context().get_device_num_is_set(): + degree = context.get_auto_parallel_context("device_num") + else: + degree = get_group_size() + self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, *args): + weights = self.weights + loss = self.network(*args) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*args, sens) + if self.reducer_flag: + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/src/yolo_dataset.py b/model_zoo/official/cv/yolov3_darknet53_quant/src/yolo_dataset.py new file mode 100644 index 000000000..45657db82 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/src/yolo_dataset.py @@ -0,0 +1,184 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""YOLOV3 dataset.""" +import os + +from PIL import Image +from pycocotools.coco import COCO +import mindspore.dataset as de +import mindspore.dataset.transforms.vision.c_transforms as CV + +from src.distributed_sampler import DistributedSampler +from src.transforms import reshape_fn, MultiScaleTrans + + +min_keypoints_per_image = 10 + + +def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + +def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + +def has_valid_annotation(anno): + """Check annotation file.""" + # if it's empty, there is no annotation + if not anno: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different critera for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= min_keypoints_per_image: + return True + return False + + +class COCOYoloDataset: + """YOLOV3 Dataset for COCO.""" + def __init__(self, root, ann_file, remove_images_without_annotations=True, + filter_crowd_anno=True, is_training=True): + self.coco = COCO(ann_file) + self.root = root + self.img_ids = list(sorted(self.coco.imgs.keys())) + self.filter_crowd_anno = filter_crowd_anno + self.is_training = is_training + + # filter images without any annotations + if remove_images_without_annotations: + img_ids = [] + for img_id in self.img_ids: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + img_ids.append(img_id) + self.img_ids = img_ids + + self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()} + + self.cat_ids_to_continuous_ids = { + v: i for i, v in enumerate(self.coco.getCatIds()) + } + self.continuous_ids_cat_ids = { + v: k for k, v in self.cat_ids_to_continuous_ids.items() + } + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + (img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints", + generated by the image's annotation. img is a PIL image. + """ + coco = self.coco + img_id = self.img_ids[index] + img_path = coco.loadImgs(img_id)[0]["file_name"] + img = Image.open(os.path.join(self.root, img_path)).convert("RGB") + if not self.is_training: + return img, img_id + + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + # filter crowd annotations + if self.filter_crowd_anno: + annos = [anno for anno in target if anno["iscrowd"] == 0] + else: + annos = [anno for anno in target] + + target = {} + boxes = [anno["bbox"] for anno in annos] + target["bboxes"] = boxes + + classes = [anno["category_id"] for anno in annos] + classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes] + target["labels"] = classes + + bboxes = target['bboxes'] + labels = target['labels'] + out_target = [] + for bbox, label in zip(bboxes, labels): + tmp = [] + # convert to [x_min y_min x_max y_max] + bbox = self._convetTopDown(bbox) + tmp.extend(bbox) + tmp.append(int(label)) + # tmp [x_min y_min x_max y_max, label] + out_target.append(tmp) + return img, out_target + + def __len__(self): + return len(self.img_ids) + + def _convetTopDown(self, bbox): + x_min = bbox[0] + y_min = bbox[1] + w = bbox[2] + h = bbox[3] + return [x_min, y_min, x_min+w, y_min+h] + + +def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank, + config=None, is_training=True, shuffle=True): + """Create dataset for YOLOV3.""" + if is_training: + filter_crowd = True + remove_empty_anno = True + else: + filter_crowd = False + remove_empty_anno = False + + yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd, + remove_images_without_annotations=remove_empty_anno, is_training=is_training) + distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle) + hwc_to_chw = CV.HWC2CHW() + + config.dataset_size = len(yolo_dataset) + num_parallel_workers1 = int(64 / device_num) + num_parallel_workers2 = int(16 / device_num) + if is_training: + multi_scale_trans = MultiScaleTrans(config, device_num) + if device_num != 8: + ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], + num_parallel_workers=num_parallel_workers1, + sampler=distributed_sampler) + ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], + num_parallel_workers=num_parallel_workers2, drop_remainder=True) + else: + ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler) + ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], + num_parallel_workers=8, drop_remainder=True) + else: + ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], + sampler=distributed_sampler) + compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config)) + ds = ds.map(input_columns=["image", "img_id"], + output_columns=["image", "image_shape", "img_id"], + columns_order=["image", "image_shape", "img_id"], + operations=compose_map_func, num_parallel_workers=8) + ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(max_epoch) + + return ds, len(yolo_dataset) diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/train.py b/model_zoo/official/cv/yolov3_darknet53_quant/train.py new file mode 100644 index 000000000..75d1eb090 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/train.py @@ -0,0 +1,362 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""YoloV3 train.""" + +import os +import time +import argparse +import datetime + +from mindspore import ParallelMode +from mindspore.nn.optim.momentum import Momentum +from mindspore import Tensor +from mindspore import context +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.train.callback import ModelCheckpoint, RunContext +from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig +import mindspore as ms +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.quant import quant + +from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper +from src.logger import get_logger +from src.util import AverageMeter, load_backbone, get_param_groups +from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \ + warmup_cosine_annealing_lr_V2, warmup_cosine_annealing_lr_sample +from src.yolo_dataset import create_yolo_dataset +from src.initializer import default_recurisive_init +from src.config import ConfigYOLOV3DarkNet53 +from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single +from src.util import ShapeRecord + + +devid = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target="Ascend", save_graphs=True, device_id=devid) + + +def parse_args(): + """Parse train arguments.""" + parser = argparse.ArgumentParser('mindspore coco training') + + # dataset related + parser.add_argument('--data_dir', type=str, default='', help='train data dir') + parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per gpu') + + # network related + parser.add_argument('--pretrained_backbone', default='', type=str, help='model_path, local pretrained backbone' + ' model to load') + parser.add_argument('--resume_yolov3', default='', type=str, help='path of pretrained yolov3') + + # optimizer and lr related + parser.add_argument('--lr_scheduler', default='exponential', type=str, + help='lr-scheduler, option type: exponential, cosine_annealing') + parser.add_argument('--lr', default=0.001, type=float, help='learning rate of the training') + parser.add_argument('--lr_epochs', type=str, default='220,250', help='epoch of lr changing') + parser.add_argument('--lr_gamma', type=float, default=0.1, + help='decrease lr by a factor of exponential lr_scheduler') + parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') + parser.add_argument('--T_max', type=int, default=320, help='T-max in cosine_annealing scheduler') + parser.add_argument('--max_epoch', type=int, default=320, help='max epoch num to train the model') + parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch') + parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay') + parser.add_argument('--momentum', type=float, default=0.9, help='momentum') + + # loss related + parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale') + parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smooth in CE') + parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smooth strength of original one-hot') + + # logging related + parser.add_argument('--log_interval', type=int, default=100, help='logging interval') + parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') + parser.add_argument('--ckpt_interval', type=int, default=None, help='ckpt_interval') + parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank') + + # distributed related + parser.add_argument('--is_distributed', type=int, default=1, help='if multi device') + parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') + parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') + + # roma obs + parser.add_argument('--train_url', type=str, default="", help='train url') + + # profiler init + parser.add_argument('--need_profiler', type=int, default=0, help='whether use profiler') + + # reset default config + parser.add_argument('--training_shape', type=str, default="", help='fix training shape') + parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training') + + args, _ = parser.parse_known_args() + if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max: + args.T_max = args.max_epoch + + args.lr_epochs = list(map(int, args.lr_epochs.split(','))) + args.data_root = os.path.join(args.data_dir, 'train2014') + args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json') + + return args + + +def conver_training_shape(args): + training_shape = [int(args.training_shape), int(args.training_shape)] + return training_shape + + +def train(): + """Train function.""" + args = parse_args() + + # init distributed + if args.is_distributed: + init() + args.rank = get_rank() + args.group_size = get_group_size() + + # select for master rank save ckpt or all rank save, compatiable for model parallel + args.rank_save_ckpt_flag = 0 + if args.is_save_on_master: + if args.rank == 0: + args.rank_save_ckpt_flag = 1 + else: + args.rank_save_ckpt_flag = 1 + + # logger + args.outputs_dir = os.path.join(args.ckpt_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + args.logger = get_logger(args.outputs_dir, args.rank) + args.logger.save_args(args) + + if args.need_profiler: + from mindinsight.profiler.profiling import Profiler + profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) + + loss_meter = AverageMeter('loss') + + context.reset_auto_parallel_context() + if args.is_distributed: + parallel_mode = ParallelMode.DATA_PARALLEL + degree = get_group_size() + else: + parallel_mode = ParallelMode.STAND_ALONE + degree = 1 + context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree) + + network = YOLOV3DarkNet53(is_training=True) + # default is kaiming-normal + default_recurisive_init(network) + + if args.pretrained_backbone: + network = load_backbone(network, args.pretrained_backbone, args) + args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone)) + else: + args.logger.info('Not load pre-trained backbone, please be careful') + + if args.resume_yolov3: + param_dict = load_checkpoint(args.resume_yolov3) + param_dict_new = {} + for key, values in param_dict.items(): + args.logger.info('ckpt param name = {}'.format(key)) + if key.startswith('moments.') or key.startswith('global_') or \ + key.startswith('learning_rate') or key.startswith('momentum'): + continue + elif key.startswith('yolo_network.'): + key_new = key[13:] + + if key_new.endswith('1.beta'): + key_new = key_new.replace('1.beta', 'batchnorm.beta') + + if key_new.endswith('1.gamma'): + key_new = key_new.replace('1.gamma', 'batchnorm.gamma') + + if key_new.endswith('1.moving_mean'): + key_new = key_new.replace('1.moving_mean', 'batchnorm.moving_mean') + + if key_new.endswith('1.moving_variance'): + key_new = key_new.replace('1.moving_variance', 'batchnorm.moving_variance') + + if key_new.endswith('.weight'): + if key_new.endswith('0.weight'): + key_new = key_new.replace('0.weight', 'conv.weight') + else: + key_new = key_new.replace('.weight', '.conv.weight') + + if key_new.endswith('.bias'): + key_new = key_new.replace('.bias', '.conv.bias') + param_dict_new[key_new] = values + + args.logger.info('in resume {}'.format(key_new)) + else: + param_dict_new[key] = values + args.logger.info('in resume {}'.format(key)) + + args.logger.info('resume finished') + for _, param in network.parameters_and_names(): + args.logger.info('network param name = {}'.format(param.name)) + if param.name not in param_dict_new: + args.logger.info('not match param name = {}'.format(param.name)) + load_param_into_net(network, param_dict_new) + args.logger.info('load_model {} success'.format(args.resume_yolov3)) + + config = ConfigYOLOV3DarkNet53() + # convert fusion network to quantization aware network + if config.quantization_aware: + network = quant.convert_quant_network(network, + bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + + network = YoloWithLossCell(network) + args.logger.info('finish get network') + + config.label_smooth = args.label_smooth + config.label_smooth_factor = args.label_smooth_factor + + if args.training_shape: + config.multi_scale = [conver_training_shape(args)] + + if args.resize_rate: + config.resize_rate = args.resize_rate + + ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True, + batch_size=args.per_batch_size, max_epoch=args.max_epoch, + device_num=args.group_size, rank=args.rank, config=config) + args.logger.info('Finish loading dataset') + + args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size) + + if not args.ckpt_interval: + args.ckpt_interval = args.steps_per_epoch + + # lr scheduler + if args.lr_scheduler == 'exponential': + lr = warmup_step_lr(args.lr, + args.lr_epochs, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma, + ) + elif args.lr_scheduler == 'cosine_annealing': + lr = warmup_cosine_annealing_lr(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + elif args.lr_scheduler == 'cosine_annealing_V2': + lr = warmup_cosine_annealing_lr_V2(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + elif args.lr_scheduler == 'cosine_annealing_sample': + lr = warmup_cosine_annealing_lr_sample(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + else: + raise NotImplementedError(args.lr_scheduler) + + opt = Momentum(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=args.momentum, + weight_decay=args.weight_decay, + loss_scale=args.loss_scale) + + network = TrainingWrapper(network, opt) + network.set_train() + + if args.rank_save_ckpt_flag: + # checkpoint save + ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, + keep_checkpoint_max=ckpt_max_num) + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=args.outputs_dir, + prefix='{}'.format(args.rank)) + cb_params = _InternalCallbackParam() + cb_params.train_network = network + cb_params.epoch_num = ckpt_max_num + cb_params.cur_epoch_num = 1 + run_context = RunContext(cb_params) + ckpt_cb.begin(run_context) + + old_progress = -1 + t_end = time.time() + data_loader = ds.create_dict_iterator() + + shape_record = ShapeRecord() + for i, data in enumerate(data_loader): + images = data["image"] + input_shape = images.shape[2:4] + args.logger.info('iter[{}], shape{}'.format(i, input_shape[0])) + shape_record.set(input_shape) + + images = Tensor(images) + annos = data["annotation"] + if args.group_size == 1: + batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ + batch_preprocess_true_box(annos, config, input_shape) + else: + batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ + batch_preprocess_true_box_single(annos, config, input_shape) + + batch_y_true_0 = Tensor(batch_y_true_0) + batch_y_true_1 = Tensor(batch_y_true_1) + batch_y_true_2 = Tensor(batch_y_true_2) + batch_gt_box0 = Tensor(batch_gt_box0) + batch_gt_box1 = Tensor(batch_gt_box1) + batch_gt_box2 = Tensor(batch_gt_box2) + + input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) + loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, + batch_gt_box2, input_shape) + loss_meter.update(loss.asnumpy()) + + if args.rank_save_ckpt_flag: + # ckpt progress + cb_params.cur_step_num = i + 1 # current step number + cb_params.batch_num = i + 2 + ckpt_cb.step_end(run_context) + + if i % args.log_interval == 0: + time_used = time.time() - t_end + epoch = int(i / args.steps_per_epoch) + fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used + if args.rank == 0: + args.logger.info( + 'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i])) + t_end = time.time() + loss_meter.reset() + old_progress = i + + if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag: + cb_params.cur_epoch_num += 1 + + if args.need_profiler: + if i == 10: + profiler.analyse() + break + + args.logger.info('==========end training===============') + + +if __name__ == "__main__": + train() -- GitLab