提交 75939257 编写于 作者: C chengxianbin

upload yolov3-darknet53 quant code

上级 0df5a561
......@@ -29,7 +29,7 @@ from mindspore._checkparam import Rel
import mindspore.context as context
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU
from .activation import get_activation, ReLU, LeakyReLU
from ..cell import Cell
from . import conv, basic
from ..._checkparam import ParamValidator as validator
......@@ -115,7 +115,11 @@ class Conv2dBnAct(Cell):
weight_init='normal',
bias_init='zeros',
has_bn=False,
activation=None):
momentum=0.9,
eps=1e-5,
activation=None,
alpha=0.2,
after_fake=True):
super(Conv2dBnAct, self).__init__()
if context.get_context('device_target') == "Ascend" and group > 1:
......@@ -145,9 +149,13 @@ class Conv2dBnAct(Cell):
self.has_bn = validator.check_bool("has_bn", has_bn)
self.has_act = activation is not None
self.after_fake = after_fake
if has_bn:
self.batchnorm = BatchNorm2d(out_channels)
self.activation = get_activation(activation)
self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
if activation == "leakyrelu":
self.activation = LeakyReLU(alpha)
else:
self.activation = get_activation(activation)
def construct(self, x):
x = self.conv(x)
......
......@@ -244,7 +244,7 @@ class ConvertToQuantNetwork:
subcell.conv = conv_inner
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
else:
elif subcell.after_fake:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
......@@ -274,7 +274,7 @@ class ConvertToQuantNetwork:
subcell.dense = dense_inner
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
else:
elif subcell.after_fake:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
......
# YOLOV3-DarkNet53-Quant Example
## Description
This is an example of training YOLOV3-DarkNet53-Quant with COCO2014 dataset in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the dataset COCO2014.
> Unzip the COCO2014 dataset to any path you want, the folder should include train and eval dataset as follows:
```
.
└─dataset
├─train2014
├─val2014
└─annotations
```
## Structure
```shell
.
└─yolov3_darknet53_quant
├─README.md
├─scripts
├─run_standalone_train.sh # launch standalone training(1p)
├─run_distribute_train.sh # launch distributed training(8p)
└─run_eval.sh # launch evaluating
├─src
├─__init__.py # python init file
├─config.py # parameter configuration
├─darknet.py # backbone of network
├─distributed_sampler.py # iterator of dataset
├─initializer.py # initializer of parameters
├─logger.py # log function
├─loss.py # loss function
├─lr_scheduler.py # generate learning rate
├─transforms.py # Preprocess data
├─util.py # util function
├─yolo.py # yolov3 network
├─yolo_dataset.py # create dataset for YOLOV3
├─eval.py # eval net
└─train.py # train net
```
## Running the example
### Train
#### Usage
```
# distributed training
sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]
# standalone training
sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]
```
#### Launch
```bash
# distributed training example(8p)
sh run_distribute_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt rank_table_8p.json
# standalone training example(1p)
sh run_standalone_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
#### Result
Training result will be stored in the scripts path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in log.txt.
```
# distribute training result(8p)
epoch[0], iter[0], loss:483.341675, 0.31 imgs/sec, lr:0.0
epoch[0], iter[100], loss:55.690952, 3.46 imgs/sec, lr:0.0
epoch[0], iter[200], loss:54.045728, 126.54 imgs/sec, lr:0.0
epoch[0], iter[300], loss:48.771608, 133.04 imgs/sec, lr:0.0
epoch[0], iter[400], loss:48.486769, 139.69 imgs/sec, lr:0.0
epoch[0], iter[500], loss:48.649275, 143.29 imgs/sec, lr:0.0
epoch[0], iter[600], loss:44.731309, 144.03 imgs/sec, lr:0.0
epoch[1], iter[700], loss:43.037023, 136.08 imgs/sec, lr:0.0
epoch[1], iter[800], loss:41.514788, 132.94 imgs/sec, lr:0.0
epoch[133], iter[85700], loss:33.326716, 136.14 imgs/sec, lr:6.497331924038008e-06
epoch[133], iter[85800], loss:34.968744, 136.76 imgs/sec, lr:6.497331924038008e-06
epoch[134], iter[85900], loss:35.868543, 137.08 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86000], loss:35.740817, 139.49 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86100], loss:34.600463, 141.47 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86200], loss:36.641916, 137.91 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86300], loss:32.819769, 138.17 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86400], loss:35.603033, 142.23 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86500], loss:34.303755, 145.18 imgs/sec, lr:1.6245529650404933e-06
...
```
### Infer
#### Usage
```
# infer
sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Launch
```bash
# infer with checkpoint
sh run_eval.sh dataset/coco2014/ checkpoint/0-135.ckpt
```
> checkpoint can be produced in training process.
#### Result
Inference result will be stored in the scripts path, whose folder name is "eval". Under this, you can find result like the followings in log.txt.
```
=============coco eval reulst=========
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.310
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.531
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.322
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.130
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.425
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.260
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.402
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.429
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.232
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.558
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV3 eval."""
import os
import argparse
import datetime
import time
import sys
from collections import defaultdict
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import Tensor
from mindspore.train import ParallelMode
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
from mindspore.train.quant import quant
from src.yolo import YOLOV3DarkNet53
from src.logger import get_logger
from src.yolo_dataset import create_yolo_dataset
from src.config import ConfigYOLOV3DarkNet53
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
class Redirct:
def __init__(self):
self.content = ""
def write(self, content):
self.content += content
def flush(self):
self.content = ""
class DetectionEngine:
"""Detection engine."""
def __init__(self, args):
self.ignore_threshold = args.ignore_threshold
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
self.num_classes = len(self.labels)
self.results = {}
self.file_path = ''
self.save_prefix = args.outputs_dir
self.annFile = args.annFile
self._coco = COCO(self.annFile)
self._img_ids = list(sorted(self._coco.imgs.keys()))
self.det_boxes = []
self.nms_thresh = args.nms_thresh
self.coco_catIds = self._coco.getCatIds()
def do_nms_for_results(self):
"""Get result boxes."""
for img_id in self.results:
for clsi in self.results[img_id]:
dets = self.results[img_id][clsi]
dets = np.array(dets)
keep_index = self._nms(dets, self.nms_thresh)
keep_box = [{'image_id': int(img_id),
'category_id': int(clsi),
'bbox': list(dets[i][:4].astype(float)),
'score': dets[i][4].astype(float)}
for i in keep_index]
self.det_boxes.extend(keep_box)
def _nms(self, dets, thresh):
"""Calculate NMS."""
# conver xywh -> xmin ymin xmax ymax
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = x1 + dets[:, 2]
y2 = y1 + dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def write_result(self):
"""Save result to file."""
import json
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
try:
self.file_path = self.save_prefix + '/predict' + t + '.json'
f = open(self.file_path, 'w')
json.dump(self.det_boxes, f)
except IOError as e:
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
else:
f.close()
return self.file_path
def get_eval_result(self):
"""Get eval result."""
cocoGt = COCO(self.annFile)
cocoDt = cocoGt.loadRes(self.file_path)
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
rdct = Redirct()
stdout = sys.stdout
sys.stdout = rdct
cocoEval.summarize()
sys.stdout = stdout
return rdct.content
def detect(self, outputs, batch, image_shape, image_id):
"""Detect boxes."""
outputs_num = len(outputs)
# output [|32, 52, 52, 3, 85| ]
for batch_id in range(batch):
for out_id in range(outputs_num):
# 32, 52, 52, 3, 85
out_item = outputs[out_id]
# 52, 52, 3, 85
out_item_single = out_item[batch_id, :]
# get number of items in one head, [B, gx, gy, anchors, 5+80]
dimensions = out_item_single.shape[:-1]
out_num = 1
for d in dimensions:
out_num *= d
ori_w, ori_h = image_shape[batch_id]
img_id = int(image_id[batch_id])
x = out_item_single[..., 0] * ori_w
y = out_item_single[..., 1] * ori_h
w = out_item_single[..., 2] * ori_w
h = out_item_single[..., 3] * ori_h
conf = out_item_single[..., 4:5]
cls_emb = out_item_single[..., 5:]
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
x = x.reshape(-1)
y = y.reshape(-1)
w = w.reshape(-1)
h = h.reshape(-1)
cls_emb = cls_emb.reshape(-1, 80)
conf = conf.reshape(-1)
cls_argmax = cls_argmax.reshape(-1)
x_top_left = x - w / 2.
y_top_left = y - h / 2.
# creat all False
flag = np.random.random(cls_emb.shape) > sys.maxsize
for i in range(flag.shape[0]):
c = cls_argmax[i]
flag[i, c] = True
confidence = cls_emb[flag] * conf
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
if confi < self.ignore_threshold:
continue
if img_id not in self.results:
self.results[img_id] = defaultdict(list)
x_lefti = max(0, x_lefti)
y_lefti = max(0, y_lefti)
wi = min(wi, ori_w)
hi = min(hi, ori_h)
# transform catId to match coco
coco_clsi = self.coco_catIds[clsi]
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser('mindspore coco testing')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
# detect_related
parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
parser.add_argument('--annFile', type=str, default='', help='path to annotation')
parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
args, _ = parser.parse_known_args()
args.data_root = os.path.join(args.data_dir, 'val2014')
args.annFile = os.path.join(args.data_dir, 'annotations/instances_val2014.json')
return args
def conver_testing_shape(args):
"""Convert testing shape to list."""
testing_shape = [int(args.testing_shape), int(args.testing_shape)]
return testing_shape
def test():
"""The function of eval."""
start_time = time.time()
args = parse_args()
# logger
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
rank_id = int(os.environ.get('RANK_ID'))
args.logger = get_logger(args.outputs_dir, rank_id)
context.reset_auto_parallel_context()
parallel_mode = ParallelMode.STAND_ALONE
context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=1)
args.logger.info('Creating Network....')
network = YOLOV3DarkNet53(is_training=False)
config = ConfigYOLOV3DarkNet53()
if args.testing_shape:
config.test_img_shape = conver_testing_shape(args)
# convert fusion network to quantization aware network
if config.quantization_aware:
network = quant.convert_quant_network(network,
bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
args.logger.info(args.pretrained)
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('yolo_network.'):
param_dict_new[key[13:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(args.pretrained))
else:
args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained))
assert FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained))
exit(1)
data_root = args.data_root
ann_file = args.annFile
ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
config=config)
args.logger.info('testing shape : {}'.format(config.test_img_shape))
args.logger.info('totol {} images to eval'.format(data_size))
network.set_train(False)
# init detection engine
detection = DetectionEngine(args)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
args.logger.info('Start inference....')
for i, data in enumerate(ds.create_dict_iterator()):
image = Tensor(data["image"])
image_shape = Tensor(data["image_shape"])
image_id = Tensor(data["img_id"])
prediction = network(image, input_shape)
output_big, output_me, output_small = prediction
output_big = output_big.asnumpy()
output_me = output_me.asnumpy()
output_small = output_small.asnumpy()
image_id = image_id.asnumpy()
image_shape = image_shape.asnumpy()
detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, image_id)
if i % 1000 == 0:
args.logger.info('Processing... {:.2f}% '.format(i * args.per_batch_size / data_size * 100))
args.logger.info('Calculating mAP...')
detection.do_nms_for_results()
result_file_path = detection.write_result()
args.logger.info('result file path: {}'.format(result_file_path))
eval_result = detection.get_eval_result()
cost_time = time.time() - start_time
args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))
if __name__ == "__main__":
test()
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
RESUME_YOLOV3=$(get_real_path $2)
MINDSPORE_HCCL_CONFIG_PATH=$(get_real_path $3)
echo $DATASET_PATH
echo $RESUME_YOLOV3
echo $MINDSPORE_HCCL_CONFIG_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $RESUME_YOLOV3 ]
then
echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file"
exit 1
fi
if [ ! -f $MINDSPORE_HCCL_CONFIG_PATH ]
then
echo "error: MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH is not a file"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--resume_yolov3=$RESUME_YOLOV3 \
--is_distributed=1 \
--per_batch_size=16 \
--lr=0.012 \
--T_max=135 \
--max_epoch=135 \
--warmup_epochs=5 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
CHECKPOINT_PATH=$(get_real_path $2)
echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py \
--data_dir=$DATASET_PATH \
--pretrained=$CHECKPOINT_PATH \
--testing_shape=416 > log.txt 2>&1 &
cd ..
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
RESUME_YOLOV3=$(get_real_path $2)
echo $RESUME_YOLOV3
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $RESUME_YOLOV3 ]
then
echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--resume_yolov3=$RESUME_YOLOV3 \
--is_distributed=0 \
--per_batch_size=16 \
--lr=0.004 \
--T_max=135 \
--max_epoch=135 \
--warmup_epochs=5 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for Darknet based yolov3_darknet53 models."""
class ConfigYOLOV3DarkNet53:
"""
Config parameters for the yolov3_darknet53.
Examples:
ConfigYOLOV3DarkNet53()
"""
# train_param
# data augmentation related
hue = 0.1
saturation = 1.5
value = 1.5
jitter = 0.3
resize_rate = 1
multi_scale = [[320, 320],
[352, 352],
[384, 384],
[416, 416],
[448, 448],
[480, 480],
[512, 512],
[544, 544],
[576, 576],
[608, 608]
]
num_classes = 80
max_box = 50
backbone_input_shape = [32, 64, 128, 256, 512]
backbone_shape = [64, 128, 256, 512, 1024]
backbone_layers = [1, 2, 8, 8, 4]
# confidence under ignore_threshold means no object when training
ignore_threshold = 0.7
# h->w
anchor_scales = [(10, 13),
(16, 30),
(33, 23),
(30, 61),
(62, 45),
(59, 119),
(116, 90),
(156, 198),
(373, 326)]
out_channel = 255
quantization_aware = True
# test_param
test_img_shape = [416, 416]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DarkNet model."""
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv_block(in_channels,
out_channels,
kernel_size,
stride,
dilation=1):
"""Get a conv2d batchnorm and relu layer"""
pad_mode = 'same'
padding = 0
return nn.Conv2dBnAct(in_channels, out_channels, kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
has_bn=True,
momentum=0.1,
activation='relu')
class ResidualBlock(nn.Cell):
"""
DarkNet V1 residual block definition.
Args:
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
Returns:
Tensor, output tensor.
Examples:
ResidualBlock(3, 208)
"""
expansion = 4
def __init__(self,
in_channels,
out_channels):
super(ResidualBlock, self).__init__()
out_chls = out_channels//2
self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.add(out, identity)
return out
class DarkNet(nn.Cell):
"""
DarkNet V1 network.
Args:
block: Cell. Block for network.
layer_nums: List. Numbers of different layers.
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
detect: Bool. Whether detect or not. Default:False.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3,f4,f5).
Examples:
DarkNet(ResidualBlock,
[1, 2, 8, 8, 4],
[32, 64, 128, 256, 512],
[64, 128, 256, 512, 1024],
100)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
detect=False):
super(DarkNet, self).__init__()
self.outchannel = out_channels[-1]
self.detect = detect
if not len(layer_nums) == len(in_channels) == len(out_channels) == 5:
raise ValueError("the length of layer_num, inchannel, outchannel list must be 5!")
self.conv0 = conv_block(3,
in_channels[0],
kernel_size=3,
stride=1)
self.conv1 = conv_block(in_channels[0],
out_channels[0],
kernel_size=3,
stride=2)
self.conv2 = conv_block(in_channels[1],
out_channels[1],
kernel_size=3,
stride=2)
self.conv3 = conv_block(in_channels[2],
out_channels[2],
kernel_size=3,
stride=2)
self.conv4 = conv_block(in_channels[3],
out_channels[3],
kernel_size=3,
stride=2)
self.conv5 = conv_block(in_channels[4],
out_channels[4],
kernel_size=3,
stride=2)
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=out_channels[0],
out_channel=out_channels[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=out_channels[1],
out_channel=out_channels[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=out_channels[2],
out_channel=out_channels[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=out_channels[3],
out_channel=out_channels[3])
self.layer5 = self._make_layer(block,
layer_nums[4],
in_channel=out_channels[4],
out_channel=out_channels[4])
def _make_layer(self, block, layer_num, in_channel, out_channel):
"""
Make Layer for DarkNet.
:param block: Cell. DarkNet block.
:param layer_num: Integer. Layer number.
:param in_channel: Integer. Input channel.
:param out_channel: Integer. Output channel.
Examples:
_make_layer(ConvBlock, 1, 128, 256)
"""
layers = []
darkblk = block(in_channel, out_channel)
layers.append(darkblk)
for _ in range(1, layer_num):
darkblk = block(out_channel, out_channel)
layers.append(darkblk)
return nn.SequentialCell(layers)
def construct(self, x):
c1 = self.conv0(x)
c2 = self.conv1(c1)
c3 = self.layer1(c2)
c4 = self.conv2(c3)
c5 = self.layer2(c4)
c6 = self.conv3(c5)
c7 = self.layer3(c6)
c8 = self.conv4(c7)
c9 = self.layer4(c8)
c10 = self.conv5(c9)
c11 = self.layer5(c10)
if self.detect:
return c7, c9, c11
return c11
def get_out_channels(self):
return self.outchannel
def darknet53():
"""
Get DarkNet53 neural network.
Returns:
Cell, cell instance of DarkNet53 neural network.
Examples:
darknet53()
"""
return DarkNet(ResidualBlock, [1, 2, 8, 8, 4],
[32, 64, 128, 256, 512],
[64, 128, 256, 512, 1024])
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Yolo dataset distributed sampler."""
from __future__ import division
import math
import numpy as np
class DistributedSampler:
"""Distributed sampler."""
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
print("***********Setting world_size to 1 since it is not passed in ******************")
num_replicas = 1
if rank is None:
print("***********Setting rank to 0 since it is not passed in ******************")
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
indices = indices.tolist()
self.epoch += 1
# change to list type
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter init."""
import math
import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
import mindspore.nn as nn
from mindspore import Tensor
np.random.seed(5)
def calculate_gain(nonlinearity, param=None):
r"""Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function (`nn.functional` name)
param: optional parameter for the non-linear function
Examples:
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of 'num' and 'arr'."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_correct_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
uniform distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
tensor: an n-dimensional `Tensor`
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
Examples:
>>> w = np.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
"""
fan = _calculate_correct_fan(arr, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, arr.shape)
def _calculate_fan_in_and_fan_out(arr):
"""Calculate fan in and fan out."""
dimensions = len(arr.shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
num_input_fmaps = arr.shape[1]
num_output_fmaps = arr.shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = arr[0][0].size
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
class KaimingUniform(MeInitializer):
"""Kaiming uniform initializer."""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingUniform, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, arr):
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
_assignment(arr, tmp)
def default_recurisive_init(custom_cell):
"""Initialize parameter."""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, nn.Dense):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
"""Get Logger."""
logger = LOGGER('yolov3_darknet53', rank)
logger.setup_logging_file(path, rank)
return logger
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 loss."""
from mindspore.ops import operations as P
import mindspore.nn as nn
class XYLoss(nn.Cell):
"""Loss for x and y."""
def __init__(self):
super(XYLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, box_loss_scale, predict_xy, true_xy):
xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy)
xy_loss = self.reduce_sum(xy_loss, ())
return xy_loss
class WHLoss(nn.Cell):
"""Loss for w and h."""
def __init__(self):
super(WHLoss, self).__init__()
self.square = P.Square()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, box_loss_scale, predict_wh, true_wh):
wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - predict_wh)
wh_loss = self.reduce_sum(wh_loss, ())
return wh_loss
class ConfidenceLoss(nn.Cell):
"""Loss for confidence."""
def __init__(self):
super(ConfidenceLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, predict_confidence, ignore_mask):
confidence_loss = self.cross_entropy(predict_confidence, object_mask)
confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
confidence_loss = self.reduce_sum(confidence_loss, ())
return confidence_loss
class ClassLoss(nn.Cell):
"""Loss for classification."""
def __init__(self):
super(ClassLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, predict_class, class_probs):
class_loss = object_mask * self.cross_entropy(predict_class, class_probs)
class_loss = self.reduce_sum(class_loss, ())
return class_loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate scheduler."""
import math
from collections import Counter
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""Linear learning rate."""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""Warmup step learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""Cosine annealing learning rate."""
base_lr = lr
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = 0
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""Cosine annealing learning rate V2."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
last_lr = 0
last_epoch_V1 = 0
T_max_V2 = int(max_epoch*1/3)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
if i < total_steps*2/3:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
last_lr = lr
last_epoch_V1 = last_epoch
else:
base_lr = last_lr
last_epoch = last_epoch-last_epoch_V1
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""Warmup cosine annealing learning rate."""
start_sample_epoch = 60
step_sample = 2
tobe_sampled_epoch = 60
end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch
max_sampled_epoch = max_epoch+tobe_sampled_epoch
T_max = max_sampled_epoch
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_sampled_steps):
last_epoch = i // steps_per_epoch
if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample):
continue
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)
assert total_steps == len(lr_each_step)
return np.array(lr_each_step).astype(np.float32)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
def load_backbone(net, ckpt_path, args):
"""Load darknet53 backbone checkpoint."""
param_dict = load_checkpoint(ckpt_path)
yolo_backbone_prefix = 'feature_map.backbone'
darknet_backbone_prefix = 'network.backbone'
find_param = []
not_found_param = []
for name, cell in net.cells_and_names():
if name.startswith(yolo_backbone_prefix):
name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
if isinstance(cell, (nn.Conv2d, nn.Dense)):
darknet_weight = '{}.weight'.format(name)
darknet_bias = '{}.bias'.format(name)
if darknet_weight in param_dict:
cell.weight.default_input = param_dict[darknet_weight].data
find_param.append(darknet_weight)
else:
not_found_param.append(darknet_weight)
if darknet_bias in param_dict:
cell.bias.default_input = param_dict[darknet_bias].data
find_param.append(darknet_bias)
else:
not_found_param.append(darknet_bias)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
darknet_moving_mean = '{}.moving_mean'.format(name)
darknet_moving_variance = '{}.moving_variance'.format(name)
darknet_gamma = '{}.gamma'.format(name)
darknet_beta = '{}.beta'.format(name)
if darknet_moving_mean in param_dict:
cell.moving_mean.default_input = param_dict[darknet_moving_mean].data
find_param.append(darknet_moving_mean)
else:
not_found_param.append(darknet_moving_mean)
if darknet_moving_variance in param_dict:
cell.moving_variance.default_input = param_dict[darknet_moving_variance].data
find_param.append(darknet_moving_variance)
else:
not_found_param.append(darknet_moving_variance)
if darknet_gamma in param_dict:
cell.gamma.default_input = param_dict[darknet_gamma].data
find_param.append(darknet_gamma)
else:
not_found_param.append(darknet_gamma)
if darknet_beta in param_dict:
cell.beta.default_input = param_dict[darknet_beta].data
find_param.append(darknet_beta)
else:
not_found_param.append(darknet_beta)
args.logger.info('================found_param {}========='.format(len(find_param)))
args.logger.info(find_param)
args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
args.logger.info(not_found_param)
args.logger.info('=====load {} successfully ====='.format(ckpt_path))
return net
def default_wd_filter(x):
"""default weight decay filter."""
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
return False
if parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
return False
if parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
return False
return True
def get_param_groups(network):
"""Param groups for optimizer."""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
class ShapeRecord:
"""Log image shape."""
def __init__(self):
self.shape_record = {
320: 0,
352: 0,
384: 0,
416: 0,
448: 0,
480: 0,
512: 0,
544: 0,
576: 0,
608: 0,
'total': 0
}
def set(self, shape):
if len(shape) > 1:
shape = shape[0]
shape = int(shape)
self.shape_record[shape] += 1
self.shape_record['total'] += 1
def show(self, logger):
for key in self.shape_record:
rate = self.shape_record[key] / float(self.shape_record['total'])
logger.info('shape {}: {:.2f}%'.format(key, rate*100))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOv3 based on DarkNet."""
import mindspore as ms
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from src.darknet import DarkNet, ResidualBlock
from src.config import ConfigYOLOV3DarkNet53
from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss
def _conv_bn_relu(in_channel,
out_channel,
ksize,
stride=1,
padding=0,
dilation=1,
alpha=0.1,
momentum=0.9,
eps=1e-5,
pad_mode="same"):
"""Get a conv2d batchnorm and relu layer"""
return nn.Conv2dBnAct(in_channel, out_channel, ksize,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
has_bn=True,
momentum=momentum,
eps=eps,
activation='leakyrelu',
alpha=alpha)
class YoloBlock(nn.Cell):
"""
YoloBlock for YOLOv3.
Args:
in_channels: Integer. Input channel.
out_chls: Interger. Middle channel.
out_channels: Integer. Output channel.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3).
Examples:
YoloBlock(1024, 512, 255)
"""
def __init__(self, in_channels, out_chls, out_channels):
super(YoloBlock, self).__init__()
out_chls_2 = out_chls*2
self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1)
self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
self.conv6 = nn.Conv2dBnAct(out_chls_2, out_channels, kernel_size=1, stride=1,
has_bias=True, has_bn=False, activation=None, after_fake=False)
def construct(self, x):
c1 = self.conv0(x)
c2 = self.conv1(c1)
c3 = self.conv2(c2)
c4 = self.conv3(c3)
c5 = self.conv4(c4)
c6 = self.conv5(c5)
out = self.conv6(c6)
return c5, out
class YOLOv3(nn.Cell):
"""
YOLOv3 Network.
Note:
backbone = darknet53
Args:
backbone_shape: List. Darknet output channels shape.
backbone: Cell. Backbone Network.
out_channel: Interger. Output channel.
Returns:
Tensor, output tensor.
Examples:
YOLOv3(backbone_shape=[64, 128, 256, 512, 1024]
backbone=darknet53(),
out_channel=255)
"""
def __init__(self, backbone_shape, backbone, out_channel):
super(YOLOv3, self).__init__()
self.out_channel = out_channel
self.backbone = backbone
self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel)
self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1)
self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3],
out_chls=backbone_shape[-3],
out_channels=out_channel)
self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1)
self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4],
out_chls=backbone_shape[-4],
out_channels=out_channel)
self.concat = P.Concat(axis=1)
def construct(self, x):
# input_shape of x is (batch_size, 3, h, w)
# feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)
# feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)
# feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)
img_hight = P.Shape()(x)[2]
img_width = P.Shape()(x)[3]
feature_map1, feature_map2, feature_map3 = self.backbone(x)
con1, big_object_output = self.backblock0(feature_map3)
con1 = self.conv1(con1)
ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con1)
con1 = self.concat((ups1, feature_map2))
con2, medium_object_output = self.backblock1(con1)
con2 = self.conv2(con2)
ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con2)
con3 = self.concat((ups2, feature_map1))
_, small_object_output = self.backblock2(con3)
return big_object_output, medium_object_output, small_object_output
class DetectionBlock(nn.Cell):
"""
YOLOv3 detection Network. It will finally output the detection result.
Args:
scale: Character.
config: ConfigYOLOV3DarkNet53, Configuration instance.
is_training: Bool, Whether train or not, default True.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3).
Examples:
DetectionBlock(scale='l',stride=32)
"""
def __init__(self, scale, config=ConfigYOLOV3DarkNet53(), is_training=True):
super(DetectionBlock, self).__init__()
self.config = config
if scale == 's':
idx = (0, 1, 2)
elif scale == 'm':
idx = (3, 4, 5)
elif scale == 'l':
idx = (6, 7, 8)
else:
raise KeyError("Invalid scale value for DetectionBlock")
self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
self.num_anchors_per_scale = 3
self.num_attrib = 4+1+self.config.num_classes
self.lambda_coord = 1
self.sigmoid = nn.Sigmoid()
self.reshape = P.Reshape()
self.tile = P.Tile()
self.concat = P.Concat(axis=-1)
self.conf_training = is_training
def construct(self, x, input_shape):
num_batch = P.Shape()(x)[0]
grid_size = P.Shape()(x)[2:4]
# Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib]
prediction = P.Reshape()(x, (num_batch,
self.num_anchors_per_scale,
self.num_attrib,
grid_size[0],
grid_size[1]))
prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2))
range_x = range(grid_size[1])
range_y = range(grid_size[0])
grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32)
grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32)
# Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid
# [batch, gridx, gridy, 1, 1]
grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
# Shape is [grid_size[0], grid_size[1], 1, 2]
grid = self.concat((grid_x, grid_y))
box_xy = prediction[:, :, :, :, :2]
box_wh = prediction[:, :, :, :, 2:4]
box_confidence = prediction[:, :, :, :, 4:5]
box_probs = prediction[:, :, :, :, 5:]
# gridsize1 is x
# gridsize0 is y
box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
# box_wh is w->h
box_wh = P.Exp()(box_wh) * self.anchors / input_shape
box_confidence = self.sigmoid(box_confidence)
box_probs = self.sigmoid(box_probs)
if self.conf_training:
return grid, prediction, box_xy, box_wh
return self.concat((box_xy, box_wh, box_confidence, box_probs))
class Iou(nn.Cell):
"""Calculate the iou of boxes"""
def __init__(self):
super(Iou, self).__init__()
self.min = P.Minimum()
self.max = P.Maximum()
def construct(self, box1, box2):
# box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h]
# box2: gt_box [batch, 1, 1, 1, maxbox, 4]
# convert to topLeft and rightDown
box1_xy = box1[:, :, :, :, :, :2]
box1_wh = box1[:, :, :, :, :, 2:4]
box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft
box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown
box2_xy = box2[:, :, :, :, :, :2]
box2_wh = box2[:, :, :, :, :, 2:4]
box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0)
box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0)
intersect_mins = self.max(box1_mins, box2_mins)
intersect_maxs = self.min(box1_maxs, box2_maxs)
intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0))
# P.squeeze: for effiecient slice
intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \
P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2])
box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2])
box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2])
iou = intersect_area / (box1_area + box2_area - intersect_area)
# iou : [batch, gx, gy, anchors, maxboxes]
return iou
class YoloLossBlock(nn.Cell):
"""
Loss block cell of YOLOV3 network.
"""
def __init__(self, scale, config=ConfigYOLOV3DarkNet53()):
super(YoloLossBlock, self).__init__()
self.config = config
if scale == 's':
# anchor mask
idx = (0, 1, 2)
elif scale == 'm':
idx = (3, 4, 5)
elif scale == 'l':
idx = (6, 7, 8)
else:
raise KeyError("Invalid scale value for DetectionBlock")
self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
self.concat = P.Concat(axis=-1)
self.iou = Iou()
self.reduce_max = P.ReduceMax(keep_dims=False)
self.xy_loss = XYLoss()
self.wh_loss = WHLoss()
self.confidenceLoss = ConfidenceLoss()
self.classLoss = ClassLoss()
def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
# prediction : origin output from yolo
# pred_xy: (sigmoid(xy)+grid)/grid_size
# pred_wh: (exp(wh)*anchors)/input_shape
# y_true : after normalize
# gt_box: [batch, maxboxes, xyhw] after normalize
object_mask = y_true[:, :, :, :, 4:5]
class_probs = y_true[:, :, :, :, 5:]
grid_shape = P.Shape()(prediction)[1:3]
grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
pred_boxes = self.concat((pred_xy, pred_wh))
true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
true_wh = y_true[:, :, :, :, 2:4]
true_wh = P.Select()(P.Equal()(true_wh, 0.0),
P.Fill()(P.DType()(true_wh),
P.Shape()(true_wh), 1.0),
true_wh)
true_wh = P.Log()(true_wh / self.anchors * input_shape)
# 2-w*h for large picture, use small scale, since small obj need more precise
box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
gt_shape = P.Shape()(gt_box)
gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
# add one more dimension for broadcast
iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
# gt_box is x,y,h,w after normalize
# [batch, grid[0], grid[1], num_anchor, num_gt]
best_iou = self.reduce_max(iou, -1)
# [batch, grid[0], grid[1], num_anchor]
# ignore_mask IOU too small
ignore_mask = best_iou < self.ignore_threshold
ignore_mask = P.Cast()(ignore_mask, ms.float32)
ignore_mask = P.ExpandDims()(ignore_mask, -1)
# ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume.
# so we turn off its gradient
ignore_mask = F.stop_gradient(ignore_mask)
xy_loss = self.xy_loss(object_mask, box_loss_scale, prediction[:, :, :, :, :2], true_xy)
wh_loss = self.wh_loss(object_mask, box_loss_scale, prediction[:, :, :, :, 2:4], true_wh)
confidence_loss = self.confidenceLoss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask)
class_loss = self.classLoss(object_mask, prediction[:, :, :, :, 5:], class_probs)
loss = xy_loss + wh_loss + confidence_loss + class_loss
batch_size = P.Shape()(prediction)[0]
return loss / batch_size
class YOLOV3DarkNet53(nn.Cell):
"""
Darknet based YOLOV3 network.
Args:
is_training: Bool. Whether train or not.
Returns:
Cell, cell instance of Darknet based YOLOV3 neural network.
Examples:
YOLOV3DarkNet53(True)
"""
def __init__(self, is_training):
super(YOLOV3DarkNet53, self).__init__()
self.config = ConfigYOLOV3DarkNet53()
# YOLOv3 network
self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers,
self.config.backbone_input_shape,
self.config.backbone_shape,
detect=True),
backbone_shape=self.config.backbone_shape,
out_channel=self.config.out_channel)
# prediction on the default anchor boxes
self.detect_1 = DetectionBlock('l', is_training=is_training)
self.detect_2 = DetectionBlock('m', is_training=is_training)
self.detect_3 = DetectionBlock('s', is_training=is_training)
def construct(self, x, input_shape):
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
output_big = self.detect_1(big_object_output, input_shape)
output_me = self.detect_2(medium_object_output, input_shape)
output_small = self.detect_3(small_object_output, input_shape)
# big is the final output which has smallest feature map
return output_big, output_me, output_small
class YoloWithLossCell(nn.Cell):
"""YOLOV3 loss."""
def __init__(self, network):
super(YoloWithLossCell, self).__init__()
self.yolo_network = network
self.config = ConfigYOLOV3DarkNet53()
self.loss_big = YoloLossBlock('l', self.config)
self.loss_me = YoloLossBlock('m', self.config)
self.loss_small = YoloLossBlock('s', self.config)
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape):
yolo_out = self.yolo_network(x, input_shape)
loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)
return loss_l + loss_m + loss_s
class TrainingWrapper(nn.Cell):
"""Training wrapper."""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 dataset."""
import os
from PIL import Image
from pycocotools.coco import COCO
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans
min_keypoints_per_image = 10
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
def has_valid_annotation(anno):
"""Check annotation file."""
# if it's empty, there is no annotation
if not anno:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
# keypoints task have a slight different critera for considering
# if an annotation is valid
if "keypoints" not in anno[0]:
return True
# for keypoint detection tasks, only consider valid images those
# containing at least min_keypoints_per_image
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
return True
return False
class COCOYoloDataset:
"""YOLOV3 Dataset for COCO."""
def __init__(self, root, ann_file, remove_images_without_annotations=True,
filter_crowd_anno=True, is_training=True):
self.coco = COCO(ann_file)
self.root = root
self.img_ids = list(sorted(self.coco.imgs.keys()))
self.filter_crowd_anno = filter_crowd_anno
self.is_training = is_training
# filter images without any annotations
if remove_images_without_annotations:
img_ids = []
for img_id in self.img_ids:
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = self.coco.loadAnns(ann_ids)
if has_valid_annotation(anno):
img_ids.append(img_id)
self.img_ids = img_ids
self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
self.cat_ids_to_continuous_ids = {
v: i for i, v in enumerate(self.coco.getCatIds())
}
self.continuous_ids_cat_ids = {
v: k for k, v in self.cat_ids_to_continuous_ids.items()
}
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
(img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints",
generated by the image's annotation. img is a PIL image.
"""
coco = self.coco
img_id = self.img_ids[index]
img_path = coco.loadImgs(img_id)[0]["file_name"]
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
if not self.is_training:
return img, img_id
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
# filter crowd annotations
if self.filter_crowd_anno:
annos = [anno for anno in target if anno["iscrowd"] == 0]
else:
annos = [anno for anno in target]
target = {}
boxes = [anno["bbox"] for anno in annos]
target["bboxes"] = boxes
classes = [anno["category_id"] for anno in annos]
classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
target["labels"] = classes
bboxes = target['bboxes']
labels = target['labels']
out_target = []
for bbox, label in zip(bboxes, labels):
tmp = []
# convert to [x_min y_min x_max y_max]
bbox = self._convetTopDown(bbox)
tmp.extend(bbox)
tmp.append(int(label))
# tmp [x_min y_min x_max y_max, label]
out_target.append(tmp)
return img, out_target
def __len__(self):
return len(self.img_ids)
def _convetTopDown(self, bbox):
x_min = bbox[0]
y_min = bbox[1]
w = bbox[2]
h = bbox[3]
return [x_min, y_min, x_min+w, y_min+h]
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
config=None, is_training=True, shuffle=True):
"""Create dataset for YOLOV3."""
if is_training:
filter_crowd = True
remove_empty_anno = True
else:
filter_crowd = False
remove_empty_anno = False
yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
remove_images_without_annotations=remove_empty_anno, is_training=is_training)
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
hwc_to_chw = CV.HWC2CHW()
config.dataset_size = len(yolo_dataset)
num_parallel_workers1 = int(64 / device_num)
num_parallel_workers2 = int(16 / device_num)
if is_training:
multi_scale_trans = MultiScaleTrans(config, device_num)
if device_num != 8:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"],
num_parallel_workers=num_parallel_workers1,
sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
num_parallel_workers=num_parallel_workers2, drop_remainder=True)
else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
num_parallel_workers=8, drop_remainder=True)
else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
sampler=distributed_sampler)
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
ds = ds.map(input_columns=["image", "img_id"],
output_columns=["image", "image_shape", "img_id"],
columns_order=["image", "image_shape", "img_id"],
operations=compose_map_func, num_parallel_workers=8)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(max_epoch)
return ds, len(yolo_dataset)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV3 train."""
import os
import time
import argparse
import datetime
from mindspore import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
from src.util import AverageMeter, load_backbone, get_param_groups
from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \
warmup_cosine_annealing_lr_V2, warmup_cosine_annealing_lr_sample
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init
from src.config import ConfigYOLOV3DarkNet53
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
from src.util import ShapeRecord
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target="Ascend", save_graphs=True, device_id=devid)
def parse_args():
"""Parse train arguments."""
parser = argparse.ArgumentParser('mindspore coco training')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained_backbone', default='', type=str, help='model_path, local pretrained backbone'
' model to load')
parser.add_argument('--resume_yolov3', default='', type=str, help='path of pretrained yolov3')
# optimizer and lr related
parser.add_argument('--lr_scheduler', default='exponential', type=str,
help='lr-scheduler, option type: exponential, cosine_annealing')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate of the training')
parser.add_argument('--lr_epochs', type=str, default='220,250', help='epoch of lr changing')
parser.add_argument('--lr_gamma', type=float, default=0.1,
help='decrease lr by a factor of exponential lr_scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--T_max', type=int, default=320, help='T-max in cosine_annealing scheduler')
parser.add_argument('--max_epoch', type=int, default=320, help='max epoch num to train the model')
parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
# loss related
parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale')
parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smooth in CE')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smooth strength of original one-hot')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=None, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
# profiler init
parser.add_argument('--need_profiler', type=int, default=0, help='whether use profiler')
# reset default config
parser.add_argument('--training_shape', type=str, default="", help='fix training shape')
parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training')
args, _ = parser.parse_known_args()
if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max:
args.T_max = args.max_epoch
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.data_root = os.path.join(args.data_dir, 'train2014')
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')
return args
def conver_training_shape(args):
training_shape = [int(args.training_shape), int(args.training_shape)]
return training_shape
def train():
"""Train function."""
args = parse_args()
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
# select for master rank save ckpt or all rank save, compatiable for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
if args.need_profiler:
from mindinsight.profiler.profiling import Profiler
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
loss_meter = AverageMeter('loss')
context.reset_auto_parallel_context()
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
degree = get_group_size()
else:
parallel_mode = ParallelMode.STAND_ALONE
degree = 1
context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree)
network = YOLOV3DarkNet53(is_training=True)
# default is kaiming-normal
default_recurisive_init(network)
if args.pretrained_backbone:
network = load_backbone(network, args.pretrained_backbone, args)
args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
else:
args.logger.info('Not load pre-trained backbone, please be careful')
if args.resume_yolov3:
param_dict = load_checkpoint(args.resume_yolov3)
param_dict_new = {}
for key, values in param_dict.items():
args.logger.info('ckpt param name = {}'.format(key))
if key.startswith('moments.') or key.startswith('global_') or \
key.startswith('learning_rate') or key.startswith('momentum'):
continue
elif key.startswith('yolo_network.'):
key_new = key[13:]
if key_new.endswith('1.beta'):
key_new = key_new.replace('1.beta', 'batchnorm.beta')
if key_new.endswith('1.gamma'):
key_new = key_new.replace('1.gamma', 'batchnorm.gamma')
if key_new.endswith('1.moving_mean'):
key_new = key_new.replace('1.moving_mean', 'batchnorm.moving_mean')
if key_new.endswith('1.moving_variance'):
key_new = key_new.replace('1.moving_variance', 'batchnorm.moving_variance')
if key_new.endswith('.weight'):
if key_new.endswith('0.weight'):
key_new = key_new.replace('0.weight', 'conv.weight')
else:
key_new = key_new.replace('.weight', '.conv.weight')
if key_new.endswith('.bias'):
key_new = key_new.replace('.bias', '.conv.bias')
param_dict_new[key_new] = values
args.logger.info('in resume {}'.format(key_new))
else:
param_dict_new[key] = values
args.logger.info('in resume {}'.format(key))
args.logger.info('resume finished')
for _, param in network.parameters_and_names():
args.logger.info('network param name = {}'.format(param.name))
if param.name not in param_dict_new:
args.logger.info('not match param name = {}'.format(param.name))
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(args.resume_yolov3))
config = ConfigYOLOV3DarkNet53()
# convert fusion network to quantization aware network
if config.quantization_aware:
network = quant.convert_quant_network(network,
bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
network = YoloWithLossCell(network)
args.logger.info('finish get network')
config.label_smooth = args.label_smooth
config.label_smooth_factor = args.label_smooth_factor
if args.training_shape:
config.multi_scale = [conver_training_shape(args)]
if args.resize_rate:
config.resize_rate = args.resize_rate
ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
batch_size=args.per_batch_size, max_epoch=args.max_epoch,
device_num=args.group_size, rank=args.rank, config=config)
args.logger.info('Finish loading dataset')
args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)
if not args.ckpt_interval:
args.ckpt_interval = args.steps_per_epoch
# lr scheduler
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_V2':
lr = warmup_cosine_annealing_lr_V2(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_sample':
lr = warmup_cosine_annealing_lr_sample(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
network = TrainingWrapper(network, opt)
network.set_train()
if args.rank_save_ckpt_flag:
# checkpoint save
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=args.outputs_dir,
prefix='{}'.format(args.rank))
cb_params = _InternalCallbackParam()
cb_params.train_network = network
cb_params.epoch_num = ckpt_max_num
cb_params.cur_epoch_num = 1
run_context = RunContext(cb_params)
ckpt_cb.begin(run_context)
old_progress = -1
t_end = time.time()
data_loader = ds.create_dict_iterator()
shape_record = ShapeRecord()
for i, data in enumerate(data_loader):
images = data["image"]
input_shape = images.shape[2:4]
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
shape_record.set(input_shape)
images = Tensor(images)
annos = data["annotation"]
if args.group_size == 1:
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
batch_preprocess_true_box(annos, config, input_shape)
else:
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
batch_preprocess_true_box_single(annos, config, input_shape)
batch_y_true_0 = Tensor(batch_y_true_0)
batch_y_true_1 = Tensor(batch_y_true_1)
batch_y_true_2 = Tensor(batch_y_true_2)
batch_gt_box0 = Tensor(batch_gt_box0)
batch_gt_box1 = Tensor(batch_gt_box1)
batch_gt_box2 = Tensor(batch_gt_box2)
input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
batch_gt_box2, input_shape)
loss_meter.update(loss.asnumpy())
if args.rank_save_ckpt_flag:
# ckpt progress
cb_params.cur_step_num = i + 1 # current step number
cb_params.batch_num = i + 2
ckpt_cb.step_end(run_context)
if i % args.log_interval == 0:
time_used = time.time() - t_end
epoch = int(i / args.steps_per_epoch)
fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
if args.rank == 0:
args.logger.info(
'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
t_end = time.time()
loss_meter.reset()
old_progress = i
if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
cb_params.cur_epoch_num += 1
if args.need_profiler:
if i == 10:
profiler.analyse()
break
args.logger.info('==========end training===============')
if __name__ == "__main__":
train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册