diff --git a/model_zoo/official/cv/deeplabv3/README.md b/model_zoo/official/cv/deeplabv3/README.md index 464c51e7a2d588971d4ad47d1a9f0835d857668a..1cf1889c2536acdccd57949f777fad27c28886ff 100644 --- a/model_zoo/official/cv/deeplabv3/README.md +++ b/model_zoo/official/cv/deeplabv3/README.md @@ -1,226 +1,284 @@ -# Contents - -- [DeeplabV3 Description](#DeeplabV3-description) -- [Model Architecture](#model-architecture) -- [Dataset](#dataset) -- [Features](#features) - - [Mixed Precision](#mixed-precision) -- [Environment Requirements](#environment-requirements) -- [Script Description](#script-description) - - [Script and Sample Code](#script-and-sample-code) - - [Training Process](#training-process) - - [Evaluation Process](#evaluation-process) - - [Evaluation](#evaluation) -- [Model Description](#model-description) - - [Performance](#performance) - - [Training Performance](#evaluation-performance) - - [Inference Performance](#evaluation-performance) -- [Description of Random Situation](#description-of-random-situation) -- [ModelZoo Homepage](#modelzoo-homepage) - -# [DeeplabV3 Description](#contents) - -DeepLabv3 is a semantic segmentation architecture that improves upon DeepLabv2 with several modifications.To handle the problem of segmenting objects at multiple scales, modules are designed which employ atrous convolution in cascade or in parallel to capture multi-scale context by adopting multiple atrous rates. - -[Paper](https://arxiv.org/pdf/1706.05587.pdf) Chen L C , Papandreou G , Schroff F , et al. Rethinking Atrous Convolution for Semantic Image Segmentation[J]. 2017. - -# [Model architecture](#contents) - -The overall network architecture of DeepLabv3 is show below: - -[Link](https://arxiv.org/pdf/1706.05587.pdf) - - -# [Dataset](#contents) - -Dataset used: [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) - -20 classes. The train/val data has 11,530 images containing 27,450 ROI annotated objects and 6,929 segmentations. And we need to remove color map from annotation. - -# [Features](#contents) - -## [Mixed Precision(Ascend)](#contents) - -The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. - -For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. - -# [Environment Requirements](#contents) - -- Hardware(Ascend) - - Prepare hardware environment with Ascend. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. -- Framework - - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) -- For more information, please check the resources below: - - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - -# [Script description](#contents) - -## [Script and sample code](#contents) - -```shell -. -└─deeplabv3 - ├──README.md - ├──eval.py - ├──train.py - ├──scripts - │ ├──run_distribute_train.sh # launch distributed training with ascend platform(8p) - │ ├──run_eval.sh # launch evaluating with ascend platform - │ ├──run_standalone_train.sh # launch standalone training with ascend platform(1p) - ├──src - │ ├──config.py # parameter configuration - │ ├──deeplabv3.py # network definition - │ ├──ei_dataset.py # data preprocessing for EI - │ ├──losses.py # customized loss function - │ ├──md_dataset.py # data preprocessing - │ ├──miou_precision.py # miou metrics - │ ├──__init__.py - │ │ - │ ├──backbone - │ │ ├──resnet_deeplab.py # backbone network definition - │ │ ├──__init__.py - │ │ - │ └──utils - │ ├──adapter.py # adapter of dataset - │ ├──custom_transforms.py # random process dataset - │ ├──file_io.py # file operation module - │ ├──__init__.py -``` - -## [Script Parameters](#contents) - -```python -Major parameters in train.py and config.py are: - learning_rate Learning rate, default is 0.0014. - weight_decay Weight decay, default is 5e-5. - momentum Momentum, default is 0.97. - crop_size Image crop size [height, width] during training, default is 513. - eval_scales The scales to resize images for evaluation, default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]. - output_stride The ratio of input to output spatial resolution, default is 16. - ignore_label Ignore label value, default is 255. - seg_num_classes Number of semantic classes, including the background class. - foreground classes + 1 background class in the PASCAL VOC 2012 dataset, default is 21. - fine_tune_batch_norm Fine tune the batch norm parameters or not, default is False. - atrous_rates Atrous rates for atrous spatial pyramid pooling, default is None. - decoder_output_stride The ratio of input to output spatial resolution when employing decoder - to refine segmentation results, default is None. - image_pyramid Input scales for multi-scale feature extraction, default is None. - epoch_size Epoch size, default is 6. - batch_size Batch size of input dataset: N, default is 2. - enable_save_ckpt Enable save checkpoint, default is true. - save_checkpoint_steps Save checkpoint steps, default is 1000. - save_checkpoint_num Save checkpoint numbers, default is 1. -``` - -## [Training process](#contents) - -### Usage - - -You can start training using python or shell scripts. The usage of shell scripts as follows: -``` - sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH (CKPT_PATH) -``` -> Notes: - RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools. - -### Launch - -``` -# training example - python: - python train.py --dataset_url DATA_PATH - - shell: - sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH (CKPT_PATH) -``` -> Notes: - If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file. - -### Result - -Training result(8p) will be stored in the example path. Checkpoints will be stored at `. /train_parallel0/` by default, and training log will be redirected to `./train_parallel0/log.txt` like followings. - -``` -epoch: 1 step: 732, loss is 0.11594 -Epoch time: 78748.379, per step time: 107.378 -epoch: 2 step: 732, loss is 0.092868 -Epoch time: 160917.911, per step time: 36.631 -``` -## [Eval process](#contents) - -### Usage - -You can start training using python or shell scripts. The usage of shell scripts as follows: - -``` - sh scripts/run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH -``` -### Launch - -``` -# eval example - python: - python eval.py --device_id DEVICE_ID --dataset_url DATA_DIR --checkpoint_url PATH_CHECKPOINT - - shell: - sh scripts/run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH -``` - -> checkpoint can be produced in training process. - -### Result - -Evaluation result will be stored in the example path, you can find result like the followings in `eval.log`. - -``` -mIoU = 0.65049 -``` -# [Model description](#contents) - -## [Performance](#contents) - -### Training Performance - -| Parameters | DeeplabV3 | -| -------------------------- | ---------------------------------------------------------- | -| Model Version | V1 | -| Resource | Ascend 910, cpu:2.60GHz 56cores, memory:314G | -| Uploaded Date | 08/24/2020(month/day/year) | -| MindSpore Version | 0.6.0-beta | -| Dataset | voc2012/train | -| Batch_size | 2 | -| Optimizer | Momentum | -| Loss Function | SoftmaxCrossEntropy | -| Outputs | probability | -| Loss | 0.98 | -| Accuracy | mIoU:65% | -| Total time | 5mins | -| Params (M) | 94M | -| Checkpoint for Fine tuning | 100M | -| Scripts | [deeplabv3 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) | [deeplabv3 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) | - -#### Inference Performance - -| Parameters | DeeplabV3 | -| -------------------------- | ---------------------------------------------------------- | -| Model Version | V1 | -| Resource | Ascend 910, cpu:2.60GHz 56cores, memory:314G | -| Uploaded Date | 08/24/2020 (month/day/year) | -| MindSpore Version | 0.6.0-beta | -| Dataset | voc2012/val | -| Batch_size | 2 | -| Outputs | probability | -| Accuracy | mIoU:65% | -| Total time | 10mins | -| Model for inference | 97M (.GEIR file) | - -# [Description of Random Situation](#contents) - -We use random in custom_transforms.py for data preprocessing. - -# [ModelZoo Homepage](#contents) - -Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). \ No newline at end of file +# DeepLabV3 for MindSpore + +DeepLab is a series of image semantic segmentation models, DeepLabV3 improves significantly over previous versions. Two keypoints of DeepLabV3:Its multi-grid atrous convolution makes it better to deal with segmenting objects at multiple scales, and augmented ASPP makes image-level features available to capture long range information. +This repository provides a script and recipe to DeepLabV3 model and achieve state-of-the-art performance. + +## Table Of Contents + +* [Model overview](#model-overview) + * [Model Architecture](#model-architecture) + * [Default configuration](#default-configuration) +* [Setup](#setup) + * [Requirements](#requirements) +* [Quick start guide](#quick-start-guide) +* [Performance](#performance) + * [Results](#results) + * [Training accuracy](#training-accuracy) + * [Training performance](#training-performance) + * [One-hour performance](#one-hour-performance) + + +​ + +## Model overview + +Refer to [this paper][1] for network details. + +`Chen L C, Papandreou G, Schroff F, et al. Rethinking atrous convolution for semantic image segmentation[J]. arXiv preprint arXiv:1706.05587, 2017.` + +[1]: https://arxiv.org/abs/1706.05587 + +## Default Configuration + +- network structure + + Resnet101 as backbone, atrous convolution for dense feature extraction. + +- preprocessing on training data: + + crop size: 513 * 513 + + random scale: scale range 0.5 to 2.0 + + random flip + + mean subtraction: means are [103.53, 116.28, 123.675] + +- preprocessing on validation data: + + The image's long side is resized to 513, then the image is padded to 513 * 513 + +- training parameters: + + - Momentum: 0.9 + - LR scheduler: cosine + - Weight decay: 0.0001 + +## Setup + +The following section lists the requirements to start training the deeplabv3 model. + + +### Requirements + +Before running code of this project,please ensure you have the following environments: + - [MindSpore](https://www.mindspore.cn/) + - Hardware environment with the Ascend AI processor + + + For more information about how to get started with MindSpore, see the following sections: + - [MindSpore's Tutorial](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore's Api](https://www.mindspore.cn/api/zh-CN/master/index.html) + + +## Quick Start Guide + +### 1. Clone the respository + +``` +git clone xxx +cd ModelZoo_DeepLabV3_MS_MTI/00-access +``` +### 2. Install python packages in requirements.txt + +### 3. Download and preprocess the dataset + + - Download segmentation dataset. + + - Prepare the training data list file. The list file saves the relative path to image and annotation pairs. Lines are like: + + ``` + JPEGImages/00001.jpg SegmentationClassGray/00001.png + JPEGImages/00002.jpg SegmentationClassGray/00002.png + JPEGImages/00003.jpg SegmentationClassGray/00003.png + JPEGImages/00004.jpg SegmentationClassGray/00004.png + ...... + ``` + + - Configure and run build_data.sh to convert dataset to mindrecords. Arguments in build_data.sh: + + ``` + --data_root root path of training data + --data_lst list of training data(prepared above) + --dst_path where mindrecords are saved + --num_shards number of shards of the mindrecords + --shuffle shuffle or not + ``` + +### 4. Generate config json file for 8-cards training + + ``` + # From the root of this projectcd tools + python get_multicards_json.py 10.111.*.* + # 10.111.*.* is the computer's ip address. + ``` + +### 5. Train + +Based on original DeeplabV3 paper, we reproduce two training experiments on vocaug (also as trainaug) dataset and evaluate on voc val dataset. + +For single device training, please config parameters, training script is as follows: +``` +# run_standalone_train.sh +python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_dir=${train_path}/ckpt \ + --train_epochs=200 \ + --batch_size=32 \ + --crop_size=513 \ + --base_lr=0.015 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s16 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --save_steps=1500 \ + --keep_checkpoint_max=200 >log 2>&1 & +``` +For 8 devices training, training steps are as follows: + +1. Train s16 with vocaug dataset, finetuning from resnet101 pretrained model, script is as follows: + +``` +# run_distribute_train_s16_r1.sh +for((i=0;i<=$RANK_SIZE-1;i++)); +do + export RANK_ID=$i + export DEVICE_ID=`expr $i + $RANK_START_ID` + echo 'start rank='$i', device id='$DEVICE_ID'...' + mkdir ${train_path}/device$DEVICE_ID + cd ${train_path}/device$DEVICE_ID + python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \ + --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_epochs=300 \ + --batch_size=32 \ + --crop_size=513 \ + --base_lr=0.08 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s16 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --is_distributed \ + --save_steps=410 \ + --keep_checkpoint_max=200 >log 2>&1 & +done +``` +2. Train s8 with vocaug dataset, finetuning from model in previous step, training script is as follows: +``` +# run_distribute_train_s8_r1.sh +for((i=0;i<=$RANK_SIZE-1;i++)); +do + export RANK_ID=$i + export DEVICE_ID=`expr $i + $RANK_START_ID` + echo 'start rank='$i', device id='$DEVICE_ID'...' + mkdir ${train_path}/device$DEVICE_ID + cd ${train_path}/device$DEVICE_ID + python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \ + --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_epochs=800 \ + --batch_size=16 \ + --crop_size=513 \ + --base_lr=0.02 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --loss_scale=2048 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --is_distributed \ + --save_steps=820 \ + --keep_checkpoint_max=200 >log 2>&1 & +done +``` +3. Train s8 with voctrain dataset, finetuning from model in pervious step, training script is as follows: +``` +# run_distribute_train_r2.sh +for((i=0;i<=$RANK_SIZE-1;i++)); +do + export RANK_ID=$i + export DEVICE_ID=`expr $i + $RANK_START_ID` + echo 'start rank='$i', device id='$DEVICE_ID'...' + mkdir ${train_path}/device$DEVICE_ID + cd ${train_path}/device$DEVICE_ID + python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \ + --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_epochs=300 \ + --batch_size=16 \ + --crop_size=513 \ + --base_lr=0.008 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --loss_scale=2048 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --is_distributed \ + --save_steps=110 \ + --keep_checkpoint_max=200 >log 2>&1 & +done +``` +### 6. Test + +Config checkpoint with --ckpt_path, run script, mIOU with print in eval_path/eval_log. +``` +./run_eval_s16.sh # test s16 +./run_eval_s8.sh # test s8 +./run_eval_s8_multiscale.sh # test s8 + multiscale +./run_eval_s8_multiscale_flip.sh # test s8 + multiscale + flip +``` +Example of test script is as follows: +``` +python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \ + --data_lst=/PATH/TO/DATA_lst.txt \ + --batch_size=16 \ + --crop_size=513 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --scales=0.5 \ + --scales=0.75 \ + --scales=1.0 \ + --scales=1.25 \ + --scales=1.75 \ + --flip \ + --freeze_bn \ + --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 & +``` + +## Performance + +### Result + +Our result were obtained by running the applicable training script. To achieve the same results, follow the steps in the Quick Start Guide. + +#### Training accuracy + +| **Network** | OS=16 | OS=8 | MS | Flip | mIOU | mIOU in paper | +| :----------: | :-----: | :----: | :----: | :-----: | :-----: | :-------------: | +| deeplab_v3 | √ | | | | 77.37 | 77.21 | +| deeplab_v3 | | √ | | | 78.84 | 78.51 | +| deeplab_v3 | | √ | √ | | 79.70 |79.45 | +| deeplab_v3 | | √ | √ | √ | 79.89 | 79.77 | + +#### Training performance + +| **NPUs** | train performance | +| :------: | :---------------: | +| 1 | 26 img/s | +| 8 | 131 img/s | + + + + + + + + diff --git a/model_zoo/official/cv/deeplabv3/eval.py b/model_zoo/official/cv/deeplabv3/eval.py index 7e435719827e835a60dd8b2c636e2fbc6bf86281..e36cbce4a5c8bbb86f412eecbb7a8527cd40a5a0 100644 --- a/model_zoo/official/cv/deeplabv3/eval.py +++ b/model_zoo/official/cv/deeplabv3/eval.py @@ -1,51 +1,213 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""evaluation.""" -import argparse -from mindspore import context -from mindspore import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.md_dataset import create_dataset -from src.losses import OhemLoss -from src.miou_precision import MiouPrecision -from src.deeplabv3 import deeplabv3_resnet50 -from src.config import config - - -parser = argparse.ArgumentParser(description="Deeplabv3 evaluation") -parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url') -parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') - -args_opt = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) -print(args_opt) - - -if __name__ == "__main__": - args_opt.crop_size = config.crop_size - args_opt.base_size = config.crop_size - eval_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="eval") - net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], - infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, - decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, - fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) - param_dict = load_checkpoint(args_opt.checkpoint_url) - load_param_into_net(net, param_dict) - mIou = MiouPrecision(config.seg_num_classes) - metrics = {'mIou': mIou} - loss = OhemLoss(config.seg_num_classes, config.ignore_label) - model = Model(net, loss, metrics=metrics) - model.eval(eval_dataset) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""eval deeplabv3.""" + +import os +import argparse +import numpy as np +import cv2 +from mindspore import Tensor +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.nets import net_factory +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, + device_id=int(os.getenv('DEVICE_ID'))) + + +def parse_args(): + parser = argparse.ArgumentParser('mindspore deeplabv3 eval') + + # val data + parser.add_argument('--data_root', type=str, default='', help='root path of val data') + parser.add_argument('--data_lst', type=str, default='', help='list of val data') + parser.add_argument('--batch_size', type=int, default=16, help='batch size') + parser.add_argument('--crop_size', type=int, default=513, help='crop size') + parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean') + parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std') + parser.add_argument('--scales', type=float, action='append', help='scales of evaluation') + parser.add_argument('--flip', action='store_true', help='perform left-right flip') + parser.add_argument('--ignore_label', type=int, default=255, help='ignore label') + parser.add_argument('--num_classes', type=int, default=21, help='number of classes') + + # model + parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') + parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn') + parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate') + + args, _ = parser.parse_known_args() + return args + + +def cal_hist(a, b, n): + k = (a >= 0) & (a < n) + return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n) + + +def resize_long(img, long_size=513): + h, w, _ = img.shape + if h > w: + new_h = long_size + new_w = int(1.0 * long_size * w / h) + else: + new_w = long_size + new_h = int(1.0 * long_size * h / w) + imo = cv2.resize(img, (new_w, new_h)) + return imo + + +class BuildEvalNetwork(nn.Cell): + def __init__(self, network): + super(BuildEvalNetwork, self).__init__() + self.network = network + self.softmax = nn.Softmax(axis=1) + + def construct(self, input_data): + output = self.network(input_data) + output = self.softmax(output) + return output + + +def pre_process(args, img_, crop_size=513): + # resize + img_ = resize_long(img_, crop_size) + resize_h, resize_w, _ = img_.shape + + # mean, std + image_mean = np.array(args.image_mean) + image_std = np.array(args.image_std) + img_ = (img_ - image_mean) / image_std + + # pad to crop_size + pad_h = crop_size - img_.shape[0] + pad_w = crop_size - img_.shape[1] + if pad_h > 0 or pad_w > 0: + img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) + + # hwc to chw + img_ = img_.transpose((2, 0, 1)) + return img_, resize_h, resize_w + + +def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True): + result_lst = [] + batch_size = len(img_lst) + batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32) + resize_hw = [] + for l in range(batch_size): + img_ = img_lst[l] + img_, resize_h, resize_w = pre_process(args, img_, crop_size) + batch_img[l] = img_ + resize_hw.append([resize_h, resize_w]) + + batch_img = np.ascontiguousarray(batch_img) + net_out = eval_net(Tensor(batch_img, mstype.float32)) + net_out = net_out.asnumpy() + + if flip: + batch_img = batch_img[:, :, :, ::-1] + net_out_flip = eval_net(Tensor(batch_img, mstype.float32)) + net_out += net_out_flip.asnumpy()[:, :, :, ::-1] + + for bs in range(batch_size): + probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0)) + ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1] + probs_ = cv2.resize(probs_, (ori_w, ori_h)) + result_lst.append(probs_) + + return result_lst + + +def eval_batch_scales(args, eval_net, img_lst, scales, + base_crop_size=513, flip=True): + sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales] + probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip) + print(sizes_) + for crop_size_ in sizes_[1:]: + probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip) + for pl, _ in enumerate(probs_lst): + probs_lst[pl] += probs_lst_tmp[pl] + + result_msk = [] + for i in probs_lst: + result_msk.append(i.argmax(axis=2)) + return result_msk + + +def net_eval(): + args = parse_args() + + # data list + with open(args.data_lst) as f: + img_lst = f.readlines() + + # network + if args.model == 'deeplab_v3_s16': + network = net_factory.nets_map[args.model]('eval', args.num_classes, 16, args.freeze_bn) + elif args.model == 'deeplab_v3_s8': + network = net_factory.nets_map[args.model]('eval', args.num_classes, 8, args.freeze_bn) + else: + raise NotImplementedError('model [{:s}] not recognized'.format(args.model)) + + eval_net = BuildEvalNetwork(network) + + # load model + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(eval_net, param_dict) + eval_net.set_train(False) + + # evaluate + hist = np.zeros((args.num_classes, args.num_classes)) + batch_img_lst = [] + batch_msk_lst = [] + bi = 0 + image_num = 0 + for i, line in enumerate(img_lst): + img_path, msk_path = line.strip().split(' ') + img_path = os.path.join(args.data_root, img_path) + msk_path = os.path.join(args.data_root, msk_path) + img_ = cv2.imread(img_path) + msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE) + batch_img_lst.append(img_) + batch_msk_lst.append(msk_) + bi += 1 + if bi == args.batch_size: + batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales, + base_crop_size=args.crop_size, flip=args.flip) + for mi in range(args.batch_size): + hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes) + + bi = 0 + batch_img_lst = [] + batch_msk_lst = [] + print('processed {} images'.format(i+1)) + image_num = i + + if bi > 0: + batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales, + base_crop_size=args.crop_size, flip=args.flip) + for mi in range(bi): + hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes) + print('processed {} images'.format(image_num + 1)) + + print(hist) + iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) + print('per-class IoU', iu) + print('mean IoU', np.nanmean(iu)) + + +if __name__ == '__main__': + net_eval() diff --git a/model_zoo/official/cv/deeplabv3/requirements.txt b/model_zoo/official/cv/deeplabv3/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..033d55ffced90deb5475269cd958ff316f78ccad --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/requirements.txt @@ -0,0 +1,4 @@ +mindspore +numpy +Pillow +python-opencv diff --git a/model_zoo/official/cv/deeplabv3/src/utils/file_io.py b/model_zoo/official/cv/deeplabv3/scripts/build_data.sh similarity index 61% rename from model_zoo/official/cv/deeplabv3/src/utils/file_io.py rename to model_zoo/official/cv/deeplabv3/scripts/build_data.sh index 9d6db034f3c41369d87f30a0705cced7ffce8991..0d29e8ace7d5a1f5ec02f123b8cad8ef8861b73b 100644 --- a/model_zoo/official/cv/deeplabv3/src/utils/file_io.py +++ b/model_zoo/official/cv/deeplabv3/scripts/build_data.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,25 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""File operation module.""" -import os - -def _is_obs(url): - return url.startswith("obs://") or url.startswith("s3://") - - -def read(url, binary=False): - if _is_obs(url): - # TODO read cloud file. - return None - - with open(url, "rb" if binary else "r") as f: - return f.read() - - -def walk(url): - if _is_obs(url): - # TODO read cloud file. - return None - return os.walk(url) +export DEVICE_ID=7 +python /PATH/TO/MODEL_ZOO_CODE/data/build_seg_data.py --data_root=/PATH/TO/DATA_ROOT \ + --data_lst=/PATH/TO/DATA_lst.txt \ + --dst_path=/PATH/TO/MINDRECORED_NAME.mindrecord \ + --num_shards=8 \ + --shuffle=True \ No newline at end of file diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train.sh deleted file mode 100644 index 51fd741d78d31ca9611b993595a23e7b5fe95856..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH" -echo "for example: bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH [PRETRAINED_CKPT_PATH](option)" -echo "It is better to use absolute path." -echo "==============================================================================================================" - -DATA_DIR=$2 - -export RANK_TABLE_FILE=$1 -export RANK_SIZE=8 -export DEVICE_NUM=8 -PATH_CHECKPOINT="" -if [ $# == 3 ] -then - PATH_CHECKPOINT=$3 -fi -cores=`cat /proc/cpuinfo|grep "processor" |wc -l` -echo "the number of logical core" $cores -avg_core_per_rank=`expr $cores \/ $RANK_SIZE` -core_gap=`expr $avg_core_per_rank \- 1` -echo "avg_core_per_rank" $avg_core_per_rank -echo "core_gap" $core_gap -export SERVER_ID=0 -rank_start=$((DEVICE_NUM * SERVER_ID)) -for((i=0;i env.log - taskset -c $cmdopt python ../train.py \ - --distribute="true" \ - --device_id=$DEVICE_ID \ - --checkpoint_url=$PATH_CHECKPOINT \ - --data_url=$DATA_DIR > log.txt 2>&1 & - cd ../ -done \ No newline at end of file diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s16_r1.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s16_r1.sh new file mode 100644 index 0000000000000000000000000000000000000000..5c490de58e1aa52ec64f1f1c7a09bcc6252c4a42 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s16_r1.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +ulimit -c unlimited +train_path=/PATH/TO/EXPERIMENTS_DIR +export SLOG_PRINT_TO_STDOUT=0 +train_code_path=/PATH/TO/MODEL_ZOO_CODE +export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json +export RANK_SIZE=8 +export RANK_START_ID=0 + +if [ -d ${train_path} ]; then + rm -rf ${train_path} +fi +mkdir -p ${train_path} +mkdir ${train_path}/ckpt + +for((i=0;i<=$RANK_SIZE-1;i++)); +do + export RANK_ID=${i} + export DEVICE_ID=$((i + RANK_START_ID)) + echo 'start rank='${i}', device id='${DEVICE_ID}'...' + mkdir ${train_path}/device${DEVICE_ID} + cd ${train_path}/device${DEVICE_ID} || exit + python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \ + --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_epochs=300 \ + --batch_size=32 \ + --crop_size=513 \ + --base_lr=0.08 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s16 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --is_distributed \ + --save_steps=410 \ + --keep_checkpoint_max=200 >log 2>&1 & +done diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r1.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r1.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e50fba39a8d9b779873c3ad77b680de8c09cd9e --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r1.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +ulimit -c unlimited +train_path=/PATH/TO/EXPERIMENTS_DIR +export SLOG_PRINT_TO_STDOUT=0 +train_code_path=/PATH/TO/MODEL_ZOO_CODE +export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json +export RANK_SIZE=8 +export RANK_START_ID=0 + +if [ -d ${train_path} ]; then + rm -rf ${train_path} +fi +mkdir -p ${train_path} +mkdir ${train_path}/ckpt + +for((i=0;i<=$RANK_SIZE-1;i++)); +do + export RANK_ID=${i} + export DEVICE_ID=$((i + RANK_START_ID)) + echo 'start rank='${i}', device id='${DEVICE_ID}'...' + mkdir ${train_path}/device${DEVICE_ID} + cd ${train_path}/device${DEVICE_ID} || exit + python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \ + --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_epochs=800 \ + --batch_size=16 \ + --crop_size=513 \ + --base_lr=0.02 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --loss_scale=2048 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --is_distributed \ + --save_steps=820 \ + --keep_checkpoint_max=200 >log 2>&1 & +done diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r2.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r2.sh new file mode 100644 index 0000000000000000000000000000000000000000..0a34002ae9e346c6db939836538f21a5f6d612b1 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train_s8_r2.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +ulimit -c unlimited +train_path=/PATH/TO/EXPERIMENTS_DIR +export SLOG_PRINT_TO_STDOUT=0 +train_code_path=/PATH/TO/MODEL_ZOO_CODE +export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json +export RANK_SIZE=8 +export RANK_START_ID=0 + +if [ -d ${train_path} ]; then + rm -rf ${train_path} +fi +mkdir -p ${train_path} +mkdir ${train_path}/ckpt + +for((i=0;i<=$RANK_SIZE-1;i++)); +do + export RANK_ID=${i} + export DEVICE_ID=$((i + RANK_START_ID)) + echo 'start rank='${i}', device id='${DEVICE_ID}'...' + mkdir ${train_path}/device${DEVICE_ID} + cd ${train_path}/device${DEVICE_ID} || exit + python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \ + --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_epochs=300 \ + --batch_size=16 \ + --crop_size=513 \ + --base_lr=0.008 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --loss_scale=2048 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --is_distributed \ + --save_steps=110 \ + --keep_checkpoint_max=200 >log 2>&1 & +done diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval.sh deleted file mode 100644 index 248b84597e5d23c13c2bc551784f31959f27d48e..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/scripts/run_eval.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH" -echo "for example: bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH" -echo "==============================================================================================================" - -DEVICE_ID=$1 -DATA_DIR=$2 -PATH_CHECKPOINT=$3 - - -mkdir -p ms_log -CUR_DIR=`pwd` -export GLOG_log_dir=${CUR_DIR}/ms_log -export GLOG_logtostderr=0 -python eval.py \ - --device_id=$DEVICE_ID \ - --checkpoint_url=$PATH_CHECKPOINT \ - --data_url=$DATA_DIR > eval.log 2>&1 & \ No newline at end of file diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s16.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s16.sh new file mode 100644 index 0000000000000000000000000000000000000000..66305e860a1f6a00112fd0c4162455d3d6f62772 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s16.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=3 +export SLOG_PRINT_TO_STDOUT=0 +train_code_path=/PATH/TO/MODEL_ZOO_CODE +eval_path=/PATH/TO/EVAL + +if [ -d ${eval_path} ]; then + rm -rf ${eval_path} +fi +mkdir -p ${eval_path} + +python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \ + --data_lst=/PATH/TO/DATA_lst.txt \ + --batch_size=32 \ + --crop_size=513 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s16 \ + --scales=1.0 \ + --freeze_bn \ + --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 & + diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8.sh new file mode 100644 index 0000000000000000000000000000000000000000..a189089ceb0d5485f3e7a3b3b2e66399e9c0967c --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=3 +export SLOG_PRINT_TO_STDOUT=0 +train_code_path=/PATH/TO/MODEL_ZOO_CODE +eval_path=/PATH/TO/EVAL + +if [ -d ${eval_path} ]; then + rm -rf ${eval_path} +fi +mkdir -p ${eval_path} + +python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \ + --data_lst=/PATH/TO/DATA_lst.txt \ + --batch_size=16 \ + --crop_size=513 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --scales=1.0 \ + --freeze_bn \ + --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 & + diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale.sh new file mode 100644 index 0000000000000000000000000000000000000000..824d539e3a1d9b52e4ce4f93a182dc5a90eaf4b0 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=3 +export SLOG_PRINT_TO_STDOUT=0 +train_code_path=/PATH/TO/MODEL_ZOO_CODE +eval_path=/PATH/TO/EVAL + +if [ -d ${eval_path} ]; then + rm -rf ${eval_path} +fi +mkdir -p ${eval_path} + +python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \ + --data_lst=/PATH/TO/DATA_lst.txt \ + --batch_size=16 \ + --crop_size=513 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --scales=0.5 \ + --scales=0.75 \ + --scales=1.0 \ + --scales=1.25 \ + --scales=1.75 \ + --freeze_bn \ + --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 & + diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale_flip.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale_flip.sh new file mode 100644 index 0000000000000000000000000000000000000000..88beb11d6fa785c69826a03a38f812e723bf283e --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_eval_s8_multiscale_flip.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=3 +export SLOG_PRINT_TO_STDOUT=0 +train_code_path=/PATH/TO/MODEL_ZOO_CODE +eval_path=/PATH/TO/EVAL + +if [ -d ${eval_path} ]; then + rm -rf ${eval_path} +fi +mkdir -p ${eval_path} + +python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \ + --data_lst=/PATH/TO/DATA_lst.txt \ + --batch_size=16 \ + --crop_size=513 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s8 \ + --scales=0.5 \ + --scales=0.75 \ + --scales=1.0 \ + --scales=1.25 \ + --scales=1.75 \ + --flip \ + --freeze_bn \ + --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 & + diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh b/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh index 6f5e8dbe52a7e66218cca82cef3c9c4a78247970..a9a741ff2ac861f5123aadc5a11de8b59406a02c 100644 --- a/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh @@ -1,38 +1,44 @@ #!/bin/bash # Copyright 2020 Huawei Technologies Co., Ltd # -# Licensed under the Apache License, Version 2.0 (the License); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# httpwww.apache.orglicensesLICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash run_standalone_pretrain.sh DEVICE_ID DATA_PATH" -echo "for example: bash run_standalone_train.sh DEVICE_ID DATA_PATH [PRETRAINED_CKPT_PATH](option)" -echo "==============================================================================================================" - -DEVICE_ID=$1 -DATA_DIR=$2 -PATH_CHECKPOINT="" -if [ $# == 3 ] -then - PATH_CHECKPOINT=$3 + +export DEVICE_ID=5 +export SLOG_PRINT_TO_STDOUT=0 +train_path=/PATH/TO/EXPERIMENTS_DIR +train_code_path=/PATH/TO/MODEL_ZOO_CODE + +if [ -d ${train_path} ]; then + rm -rf ${train_path} fi - -mkdir -p ms_log -CUR_DIR=`pwd` -export GLOG_log_dir=${CUR_DIR}/ms_log -export GLOG_logtostderr=0 -python train.py \ - --distribute="false" \ - --device_id=$DEVICE_ID \ - --checkpoint_url=$PATH_CHECKPOINT \ - --data_url=$DATA_DIR > log.txt 2>&1 & \ No newline at end of file +mkdir -p ${train_path} +mkdir ${train_path}/device${DEVICE_ID} +mkdir ${train_path}/ckpt +cd ${train_path}/device${DEVICE_ID} || exit + +python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \ + --train_dir=${train_path}/ckpt \ + --train_epochs=200 \ + --batch_size=32 \ + --crop_size=513 \ + --base_lr=0.015 \ + --lr_type=cos \ + --min_scale=0.5 \ + --max_scale=2.0 \ + --ignore_label=255 \ + --num_classes=21 \ + --model=deeplab_v3_s16 \ + --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \ + --save_steps=1500 \ + --keep_checkpoint_max=200 >log 2>&1 & \ No newline at end of file diff --git a/model_zoo/official/cv/deeplabv3/src/__init__.py b/model_zoo/official/cv/deeplabv3/src/__init__.py index 64d070799292b2b00357fbbd5c3a05bc929fa971..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/model_zoo/official/cv/deeplabv3/src/__init__.py +++ b/model_zoo/official/cv/deeplabv3/src/__init__.py @@ -1,23 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Init DeepLabv3.""" -from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50 -from .backbone import * - -__all__ = [ - "ASPP", "DeepLabV3", "deeplabv3_resnet50" -] - -__all__.extend(backbone.__all__) diff --git a/model_zoo/official/cv/deeplabv3/src/backbone/__init__.py b/model_zoo/official/cv/deeplabv3/src/backbone/__init__.py deleted file mode 100644 index 6f78084131755bded1f15694ff69685b8cae2611..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/backbone/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Init backbone.""" -from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \ - RootBlockBeta, resnet50_dl - -__all__ = [ - "Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", "resnet50_dl" -] diff --git a/model_zoo/official/cv/deeplabv3/src/backbone/resnet_deeplab.py b/model_zoo/official/cv/deeplabv3/src/backbone/resnet_deeplab.py deleted file mode 100644 index 1dda6fe746d604df4d9ac61b3476f0603ab0d962..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/backbone/resnet_deeplab.py +++ /dev/null @@ -1,577 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""ResNet based DeepLab.""" -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common.initializer import initializer -from mindspore._checkparam import twice -from mindspore.common.parameter import Parameter - - -def _conv_bn_relu(in_channel, - out_channel, - ksize, - stride=1, - padding=0, - dilation=1, - pad_mode="pad", - use_batch_statistics=False): - """Get a conv2d -> batchnorm -> relu layer""" - return nn.SequentialCell( - [nn.Conv2d(in_channel, - out_channel, - kernel_size=ksize, - stride=stride, - padding=padding, - dilation=dilation, - pad_mode=pad_mode), - nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics), - nn.ReLU()] - ) - - -def _deep_conv_bn_relu(in_channel, - channel_multiplier, - ksize, - stride=1, - padding=0, - dilation=1, - pad_mode="pad", - use_batch_statistics=False): - """Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer""" - return nn.SequentialCell( - [DepthwiseConv2dNative(in_channel, - channel_multiplier, - kernel_size=ksize, - stride=stride, - padding=padding, - dilation=dilation, - pad_mode=pad_mode), - nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics), - nn.ReLU()] - ) - - -def _stob_deep_conv_btos_bn_relu(in_channel, - channel_multiplier, - ksize, - space_to_batch_block_shape, - batch_to_space_block_shape, - paddings, - crops, - stride=1, - padding=0, - dilation=1, - pad_mode="pad", - use_batch_statistics=False): - """Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer""" - return nn.SequentialCell( - [SpaceToBatch(space_to_batch_block_shape, paddings), - DepthwiseConv2dNative(in_channel, - channel_multiplier, - kernel_size=ksize, - stride=stride, - padding=padding, - dilation=dilation, - pad_mode=pad_mode), - BatchToSpace(batch_to_space_block_shape, crops), - nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics), - nn.ReLU()] - ) - - -def _stob_conv_btos_bn_relu(in_channel, - out_channel, - ksize, - space_to_batch_block_shape, - batch_to_space_block_shape, - paddings, - crops, - stride=1, - padding=0, - dilation=1, - pad_mode="pad", - use_batch_statistics=False): - """Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer""" - return nn.SequentialCell([SpaceToBatch(space_to_batch_block_shape, paddings), - nn.Conv2d(in_channel, - out_channel, - kernel_size=ksize, - stride=stride, - padding=padding, - dilation=dilation, - pad_mode=pad_mode), - BatchToSpace(batch_to_space_block_shape, crops), - nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics), - nn.ReLU()] - ) - - -def _make_layer(block, - in_channels, - out_channels, - num_blocks, - stride=1, - rate=1, - multi_grads=None, - output_stride=None, - g_current_stride=2, - g_rate=1): - """Make layer for DeepLab-ResNet network.""" - if multi_grads is None: - multi_grads = [1] * num_blocks - # (stride == 2, num_blocks == 4 --> strides == [1, 1, 1, 2]) - strides = [1] * (num_blocks - 1) + [stride] - blocks = [] - if output_stride is not None: - if output_stride % 4 != 0: - raise ValueError('The output_stride needs to be a multiple of 4.') - output_stride //= 4 - for i_stride, _ in enumerate(strides): - if output_stride is not None and g_current_stride > output_stride: - raise ValueError('The target output_stride cannot be reached.') - if output_stride is not None and g_current_stride == output_stride: - b_rate = g_rate - b_stride = 1 - g_rate *= strides[i_stride] - else: - b_rate = rate - b_stride = strides[i_stride] - g_current_stride *= strides[i_stride] - blocks.append(block(in_channels=in_channels, - out_channels=out_channels, - stride=b_stride, - rate=b_rate, - multi_grad=multi_grads[i_stride])) - in_channels = out_channels - layer = nn.SequentialCell(blocks) - return layer, g_current_stride, g_rate - - -class Subsample(nn.Cell): - """ - Subsample for DeepLab-ResNet. - Args: - factor (int): Sample factor. - Returns: - Tensor, the sub sampled tensor. - Examples: - >>> Subsample(2) - """ - def __init__(self, factor): - super(Subsample, self).__init__() - self.factor = factor - self.pool = nn.MaxPool2d(kernel_size=1, - stride=factor) - - def construct(self, x): - if self.factor == 1: - return x - return self.pool(x) - - -class SpaceToBatch(nn.Cell): - def __init__(self, block_shape, paddings): - super(SpaceToBatch, self).__init__() - self.space_to_batch = P.SpaceToBatch(block_shape, paddings) - self.bs = block_shape - self.pd = paddings - - def construct(self, x): - return self.space_to_batch(x) - - -class BatchToSpace(nn.Cell): - def __init__(self, block_shape, crops): - super(BatchToSpace, self).__init__() - self.batch_to_space = P.BatchToSpace(block_shape, crops) - self.bs = block_shape - self.cr = crops - - def construct(self, x): - return self.batch_to_space(x) - - -class _DepthwiseConv2dNative(nn.Cell): - """Depthwise Conv2D Cell.""" - def __init__(self, - in_channels, - channel_multiplier, - kernel_size, - stride, - pad_mode, - padding, - dilation, - group, - weight_init): - super(_DepthwiseConv2dNative, self).__init__() - self.in_channels = in_channels - self.channel_multiplier = channel_multiplier - self.kernel_size = kernel_size - self.stride = stride - self.pad_mode = pad_mode - self.padding = padding - self.dilation = dilation - self.group = group - if not (isinstance(in_channels, int) and in_channels > 0): - raise ValueError('Attr \'in_channels\' of \'DepthwiseConv2D\' Op passed ' - + str(in_channels) + ', should be a int and greater than 0.') - if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \ - (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \ - kernel_size[0] < 1 or kernel_size[1] < 1: - raise ValueError('Attr \'kernel_size\' of \'DepthwiseConv2D\' Op passed ' - + str(self.kernel_size) + ', should be a int or tuple and equal to or greater than 1.') - self.weight = Parameter(initializer(weight_init, [1, in_channels // group, *kernel_size]), - name='weight') - - def construct(self, *inputs): - """Must be overridden by all subclasses.""" - raise NotImplementedError - - -class DepthwiseConv2dNative(_DepthwiseConv2dNative): - """Depthwise Conv2D Cell.""" - def __init__(self, - in_channels, - channel_multiplier, - kernel_size, - stride=1, - pad_mode='same', - padding=0, - dilation=1, - group=1, - weight_init='normal'): - kernel_size = twice(kernel_size) - super(DepthwiseConv2dNative, self).__init__( - in_channels, - channel_multiplier, - kernel_size, - stride, - pad_mode, - padding, - dilation, - group, - weight_init) - self.depthwise_conv2d_native = P.DepthwiseConv2dNative(channel_multiplier=self.channel_multiplier, - kernel_size=self.kernel_size, - mode=3, - pad_mode=self.pad_mode, - pad=self.padding, - stride=self.stride, - dilation=self.dilation, - group=self.group) - - def set_strategy(self, strategy): - self.depthwise_conv2d_native.set_strategy(strategy) - return self - - def construct(self, x): - return self.depthwise_conv2d_native(x, self.weight) - - -class BottleneckV1(nn.Cell): - """ - ResNet V1 BottleneckV1 block definition. - Args: - in_channels (int): Input channel. - out_channels (int): Output channel. - stride (int): Stride size for the initial convolutional layer. Default: 1. - rate (int): Rate for convolution. Default: 1. - multi_grad (int): Employ a rate within network. Default: 1. - Returns: - Tensor, the ResNet unit's output. - Examples: - >>> BottleneckV1(3,256,stride=2) - """ - def __init__(self, - in_channels, - out_channels, - stride=1, - use_batch_statistics=False, - use_batch_to_stob_and_btos=False): - super(BottleneckV1, self).__init__() - expansion = 4 - mid_channels = out_channels // expansion - self.conv_bn1 = _conv_bn_relu(in_channels, - mid_channels, - ksize=1, - stride=1, - use_batch_statistics=use_batch_statistics) - self.conv_bn2 = _conv_bn_relu(mid_channels, - mid_channels, - ksize=3, - stride=stride, - padding=1, - dilation=1, - use_batch_statistics=use_batch_statistics) - if use_batch_to_stob_and_btos: - self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels, - mid_channels, - ksize=3, - stride=stride, - padding=0, - dilation=1, - space_to_batch_block_shape=2, - batch_to_space_block_shape=2, - paddings=[[2, 3], [2, 3]], - crops=[[0, 1], [0, 1]], - pad_mode="valid", - use_batch_statistics=use_batch_statistics) - - self.conv3 = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - stride=1) - self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) - if in_channels != out_channels: - conv = nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=stride) - bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) - self.downsample = nn.SequentialCell([conv, bn]) - else: - self.downsample = Subsample(stride) - self.add = P.TensorAdd() - self.relu = nn.ReLU() - self.Reshape = P.Reshape() - - def construct(self, x): - out = self.conv_bn1(x) - out = self.conv_bn2(out) - out = self.bn3(self.conv3(out)) - out = self.add(out, self.downsample(x)) - out = self.relu(out) - return out - - -class BottleneckV2(nn.Cell): - """ - ResNet V2 Bottleneck variance V2 block definition. - Args: - in_channels (int): Input channel. - out_channels (int): Output channel. - stride (int): Stride size for the initial convolutional layer. Default: 1. - Returns: - Tensor, the ResNet unit's output. - Examples: - >>> BottleneckV2(3,256,stride=2) - """ - def __init__(self, - in_channels, - out_channels, - stride=1, - use_batch_statistics=False, - use_batch_to_stob_and_btos=False, - dilation=1): - super(BottleneckV2, self).__init__() - expansion = 4 - mid_channels = out_channels // expansion - self.conv_bn1 = _conv_bn_relu(in_channels, - mid_channels, - ksize=1, - stride=1, - use_batch_statistics=use_batch_statistics) - self.conv_bn2 = _conv_bn_relu(mid_channels, - mid_channels, - ksize=3, - stride=stride, - padding=1, - dilation=dilation, - use_batch_statistics=use_batch_statistics) - if use_batch_to_stob_and_btos: - self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels, - mid_channels, - ksize=3, - stride=stride, - padding=0, - dilation=1, - space_to_batch_block_shape=2, - batch_to_space_block_shape=2, - paddings=[[2, 3], [2, 3]], - crops=[[0, 1], [0, 1]], - pad_mode="valid", - use_batch_statistics=use_batch_statistics) - self.conv3 = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - stride=1) - self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) - if in_channels != out_channels: - conv = nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=stride) - bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) - self.downsample = nn.SequentialCell([conv, bn]) - else: - self.downsample = Subsample(stride) - self.add = P.TensorAdd() - self.relu = nn.ReLU() - - def construct(self, x): - out = self.conv_bn1(x) - out = self.conv_bn2(out) - out = self.bn3(self.conv3(out)) - out = self.add(out, x) - out = self.relu(out) - return out - - -class BottleneckV3(nn.Cell): - """ - ResNet V1 Bottleneck variance V1 block definition. - Args: - in_channels (int): Input channel. - out_channels (int): Output channel. - stride (int): Stride size for the initial convolutional layer. Default: 1. - Returns: - Tensor, the ResNet unit's output. - Examples: - >>> BottleneckV3(3,256,stride=2) - """ - def __init__(self, - in_channels, - out_channels, - stride=1, - use_batch_statistics=False): - super(BottleneckV3, self).__init__() - expansion = 4 - mid_channels = out_channels // expansion - self.conv_bn1 = _conv_bn_relu(in_channels, - mid_channels, - ksize=1, - stride=1, - use_batch_statistics=use_batch_statistics) - self.conv_bn2 = _conv_bn_relu(mid_channels, - mid_channels, - ksize=3, - stride=stride, - padding=1, - dilation=1, - use_batch_statistics=use_batch_statistics) - self.conv3 = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - stride=1) - self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) - - if in_channels != out_channels: - conv = nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=stride) - bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) - self.downsample = nn.SequentialCell([conv, bn]) - else: - self.downsample = Subsample(stride) - self.downsample = Subsample(stride) - self.add = P.TensorAdd() - self.relu = nn.ReLU() - - def construct(self, x): - out = self.conv_bn1(x) - out = self.conv_bn2(out) - out = self.bn3(self.conv3(out)) - out = self.add(out, self.downsample(x)) - out = self.relu(out) - return out - - -class ResNetV1(nn.Cell): - """ - ResNet V1 for DeepLab. - Args: - Returns: - Tuple, output tensor tuple, (c2,c5). - Examples: - >>> ResNetV1(False) - """ - def __init__(self, fine_tune_batch_norm=False): - super(ResNetV1, self).__init__() - self.layer_root = nn.SequentialCell( - [RootBlockBeta(fine_tune_batch_norm), - nn.MaxPool2d(kernel_size=(3, 3), - stride=(2, 2), - pad_mode='same')]) - self.layer1_1 = BottleneckV1(128, 256, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer1_2 = BottleneckV2(256, 256, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer1_3 = BottleneckV3(256, 256, stride=2, use_batch_statistics=fine_tune_batch_norm) - self.layer2_1 = BottleneckV1(256, 512, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer2_2 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer2_3 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer2_4 = BottleneckV3(512, 512, stride=2, use_batch_statistics=fine_tune_batch_norm) - self.layer3_1 = BottleneckV1(512, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer3_2 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer3_3 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer3_4 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) - self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) - - self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True, - use_batch_statistics=fine_tune_batch_norm) - self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, - use_batch_statistics=fine_tune_batch_norm) - self.layer4_3 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, - use_batch_statistics=fine_tune_batch_norm) - - def construct(self, x): - x = self.layer_root(x) - x = self.layer1_1(x) - c2 = self.layer1_2(x) - x = self.layer1_3(c2) - x = self.layer2_1(x) - x = self.layer2_2(x) - x = self.layer2_3(x) - x = self.layer2_4(x) - x = self.layer3_1(x) - x = self.layer3_2(x) - x = self.layer3_3(x) - x = self.layer3_4(x) - x = self.layer3_5(x) - x = self.layer3_6(x) - - x = self.layer4_1(x) - x = self.layer4_2(x) - c5 = self.layer4_3(x) - return c2, c5 - - -class RootBlockBeta(nn.Cell): - """ - ResNet V1 beta root block definition. - Returns: - Tensor, the block unit's output. - Examples: - >>> RootBlockBeta() - """ - def __init__(self, fine_tune_batch_norm=False): - super(RootBlockBeta, self).__init__() - self.conv1 = _conv_bn_relu(3, 64, ksize=3, stride=2, padding=0, pad_mode="valid", - use_batch_statistics=fine_tune_batch_norm) - self.conv2 = _conv_bn_relu(64, 64, ksize=3, stride=1, padding=0, pad_mode="same", - use_batch_statistics=fine_tune_batch_norm) - self.conv3 = _conv_bn_relu(64, 128, ksize=3, stride=1, padding=0, pad_mode="same", - use_batch_statistics=fine_tune_batch_norm) - - def construct(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = self.conv3(x) - return x - - -def resnet50_dl(fine_tune_batch_norm=False): - return ResNetV1(fine_tune_batch_norm) diff --git a/model_zoo/official/cv/deeplabv3/src/data/__init__.py b/model_zoo/official/cv/deeplabv3/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_zoo/official/cv/deeplabv3/src/data/build_seg_data.py b/model_zoo/official/cv/deeplabv3/src/data/build_seg_data.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7935b48a05a16f42a18a7a0c44d054f1e99943 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/data/build_seg_data.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import argparse +import numpy as np +from mindspore.mindrecord import FileWriter + + +seg_schema = {"file_name": {"type": "string"}, "label": {"type": "bytes"}, "data": {"type": "bytes"}} + + +def parse_args(): + parser = argparse.ArgumentParser('mindrecord') + + parser.add_argument('--data_root', type=str, default='', help='root path of data') + parser.add_argument('--data_lst', type=str, default='', help='list of data') + parser.add_argument('--dst_path', type=str, default='', help='save path of mindrecords') + parser.add_argument('--num_shards', type=int, default=8, help='number of shards') + parser.add_argument('--shuffle', type=bool, default=True, help='shuffle or not') + + parser_args, _ = parser.parse_known_args() + return parser_args + + +if __name__ == '__main__': + args = parse_args() + + datas = [] + with open(args.data_lst) as f: + lines = f.readlines() + if args.shuffle: + np.random.shuffle(lines) + + dst_dir = '/'.join(args.dst_path.split('/')[:-1]) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + + print('number of samples:', len(lines)) + writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards) + writer.add_schema(seg_schema, "seg_schema") + cnt = 0 + for l in lines: + img_path, label_path = l.strip().split(' ') + sample_ = {"file_name": img_path.split('/')[-1]} + with open(os.path.join(args.data_root, img_path), 'rb') as f: + sample_['data'] = f.read() + with open(os.path.join(args.data_root, label_path), 'rb') as f: + sample_['label'] = f.read() + datas.append(sample_) + cnt += 1 + if cnt % 1000 == 0: + writer.write_raw_data(datas) + print('number of samples written:', cnt) + datas = [] + + if datas: + writer.write_raw_data(datas) + writer.commit() + print('number of samples written:', cnt) diff --git a/model_zoo/official/cv/deeplabv3/src/data/data_generator.py b/model_zoo/official/cv/deeplabv3/src/data/data_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ce3106d3b9735cb56a5bca16d3db5b5ca59cef --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/data/data_generator.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import cv2 +import numpy as np +import mindspore.dataset as de + + +class SegDataset: + def __init__(self, + image_mean, + image_std, + data_file='', + batch_size=32, + crop_size=512, + max_scale=2.0, + min_scale=0.5, + ignore_label=255, + num_classes=21, + num_readers=2, + num_parallel_calls=4, + shard_id=None, + shard_num=None): + + self.data_file = data_file + self.batch_size = batch_size + self.crop_size = crop_size + self.image_mean = np.array(image_mean, dtype=np.float32) + self.image_std = np.array(image_std, dtype=np.float32) + self.max_scale = max_scale + self.min_scale = min_scale + self.ignore_label = ignore_label + self.num_classes = num_classes + self.num_readers = num_readers + self.num_parallel_calls = num_parallel_calls + self.shard_id = shard_id + self.shard_num = shard_num + assert max_scale > min_scale + + def preprocess_(self, image, label): + # bgr image + image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR) + label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE) + + sc = np.random.uniform(self.min_scale, self.max_scale) + new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1]) + image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC) + label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST) + + image_out = (image_out - self.image_mean) / self.image_std + h_, w_ = max(new_h, self.crop_size), max(new_w, self.crop_size) + pad_h, pad_w = h_ - new_h, w_ - new_w + if pad_h > 0 or pad_w > 0: + image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) + label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label) + offset_h = np.random.randint(0, h_ - self.crop_size + 1) + offset_w = np.random.randint(0, w_ - self.crop_size + 1) + image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :] + label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size] + + if np.random.uniform(0.0, 1.0) > 0.5: + image_out = image_out[:, ::-1, :] + label_out = label_out[:, ::-1] + + image_out = image_out.transpose((2, 0, 1)) + image_out = image_out.copy() + label_out = label_out.copy() + return image_out, label_out + + def get_dataset(self, repeat=1): + data_set = de.MindDataset(dataset_file=self.data_file, columns_list=["data", "label"], + shuffle=True, num_parallel_workers=self.num_readers, + num_shards=self.shard_num, shard_id=self.shard_id) + transforms_list = self.preprocess_ + data_set = data_set.map(input_columns=["data", "label"], output_columns=["data", "label"], + operations=transforms_list, num_parallel_workers=self.num_parallel_calls) + data_set = data_set.shuffle(buffer_size=self.batch_size * 10) + data_set = data_set.batch(self.batch_size, drop_remainder=True) + data_set = data_set.repeat(repeat) + return data_set diff --git a/model_zoo/official/cv/deeplabv3/src/deeplabv3.py b/model_zoo/official/cv/deeplabv3/src/deeplabv3.py deleted file mode 100644 index 7b3c8eb53b900ca9e6f588bb09be0a99e8e1e4f0..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/deeplabv3.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""DeepLabv3.""" - -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \ - DepthwiseConv2dNative, SpaceToBatch, BatchToSpace - - -class ASPPSampleBlock(nn.Cell): - """ASPP sample block.""" - def __init__(self, feature_shape, scale_size, output_stride): - super(ASPPSampleBlock, self).__init__() - sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1 - sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1 - self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) - - def construct(self, x): - return self.sample(x) - - -class ASPP(nn.Cell): - """ - ASPP model for DeepLabv3. - - Args: - channel (int): Input channel. - depth (int): Output channel. - feature_shape (list): The shape of feature,[h,w]. - scale_sizes (list): Input scales for multi-scale feature extraction. - atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. - output_stride (int): 'The ratio of input to output spatial resolution.' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - - Returns: - Tensor, output tensor. - - Examples: - >>> ASPP(channel=2048,256,[14,14],[1],[6],16) - """ - def __init__(self, channel, depth, feature_shape, scale_sizes, - atrous_rates, output_stride, fine_tune_batch_norm=False): - super(ASPP, self).__init__() - self.aspp0 = _conv_bn_relu(channel, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.atrous_rates = [] - if atrous_rates is not None: - self.atrous_rates = atrous_rates - self.aspp_pointwise = _conv_bn_relu(channel, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel, - channel_multiplier=1, - kernel_size=3, - stride=1, - dilation=1, - pad_mode="valid") - self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm) - self.aspp_depth_relu = nn.ReLU() - self.aspp_depths = [] - self.aspp_depth_spacetobatchs = [] - self.aspp_depth_batchtospaces = [] - - for scale_size in scale_sizes: - aspp_scale_depth_size = np.ceil((feature_shape[0]*scale_size)/16) - if atrous_rates is None: - break - for rate in atrous_rates: - padding = 0 - for j in range(100): - padded_size = rate * j - if padded_size >= aspp_scale_depth_size + 2 * rate: - padding = padded_size - aspp_scale_depth_size - 2 * rate - break - paddings = [[rate, rate + int(padding)], - [rate, rate + int(padding)]] - self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings) - self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch) - crops = [[0, int(padding)], [0, int(padding)]] - self.aspp_depth_batchtospace = BatchToSpace(rate, crops) - self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace) - self.aspp_depths = nn.CellList(self.aspp_depths) - self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs) - self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces) - - self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1]))) - self.global_poolings = [] - for scale_size in scale_sizes: - pooling_h = np.ceil((feature_shape[0]*scale_size)/output_stride) - pooling_w = np.ceil((feature_shape[0]*scale_size)/output_stride) - self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w)))) - self.global_poolings = nn.CellList(self.global_poolings) - self.conv_bn = _conv_bn_relu(channel, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.samples = [] - for scale_size in scale_sizes: - self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride)) - self.samples = nn.CellList(self.samples) - self.feature_shape = feature_shape - self.concat = P.Concat(axis=1) - - def construct(self, x, scale_index=0): - aspp0 = self.aspp0(x) - aspp1 = self.global_poolings[scale_index](x) - aspp1 = self.conv_bn(aspp1) - aspp1 = self.samples[scale_index](aspp1) - output = self.concat((aspp1, aspp0)) - - for i in range(len(self.atrous_rates)): - aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x) - aspp_i = self.aspp_depth_depthwiseconv(aspp_i) - aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i) - aspp_i = self.aspp_depth_bn(aspp_i) - aspp_i = self.aspp_depth_relu(aspp_i) - aspp_i = self.aspp_pointwise(aspp_i) - output = self.concat((output, aspp_i)) - return output - - -class DecoderSampleBlock(nn.Cell): - """Decoder sample block.""" - def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4): - super(DecoderSampleBlock, self).__init__() - sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1 - sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1 - self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) - - def construct(self, x): - return self.sample(x) - - -class Decoder(nn.Cell): - """ - Decode module for DeepLabv3. - Args: - low_level_channel (int): Low level input channel - channel (int): Input channel. - depth (int): Output channel. - feature_shape (list): 'Input image shape, [N,C,H,W].' - scale_sizes (list): 'Input scales for multi-scale feature extraction.' - decoder_output_stride (int): 'The ratio of input to output spatial resolution' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - Returns: - Tensor, output tensor. - Examples: - >>> Decoder(256, 100, [56,56]) - """ - def __init__(self, - low_level_channel, - channel, - depth, - feature_shape, - scale_sizes, - decoder_output_stride, - fine_tune_batch_norm): - super(Decoder, self).__init__() - self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1, - pad_mode="same", use_batch_statistics=fine_tune_batch_norm) - self.decoder_depth0 = _deep_conv_bn_relu(channel + 48, - channel_multiplier=1, - ksize=3, - stride=1, - pad_mode="same", - dilation=1, - use_batch_statistics=fine_tune_batch_norm) - self.decoder_pointwise0 = _conv_bn_relu(channel + 48, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.decoder_depth1 = _deep_conv_bn_relu(depth, - channel_multiplier=1, - ksize=3, - stride=1, - pad_mode="same", - dilation=1, - use_batch_statistics=fine_tune_batch_norm) - self.decoder_pointwise1 = _conv_bn_relu(depth, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.depth = depth - self.concat = P.Concat(axis=1) - self.samples = [] - for scale_size in scale_sizes: - self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride)) - self.samples = nn.CellList(self.samples) - - def construct(self, x, low_level_feature, scale_index): - low_level_feature = self.feature_projection(low_level_feature) - low_level_feature = self.samples[scale_index](low_level_feature) - x = self.samples[scale_index](x) - output = self.concat((x, low_level_feature)) - output = self.decoder_depth0(output) - output = self.decoder_pointwise0(output) - output = self.decoder_depth1(output) - output = self.decoder_pointwise1(output) - return output - - -class SingleDeepLabV3(nn.Cell): - """ - DeepLabv3 Network. - Args: - num_classes (int): Class number. - feature_shape (list): Input image shape, [N,C,H,W]. - backbone (Cell): Backbone Network. - channel (int): Resnet output channel. - depth (int): ASPP block depth. - scale_sizes (list): Input scales for multi-scale feature extraction. - atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. - decoder_output_stride (int): 'The ratio of input to output spatial resolution' - output_stride (int): 'The ratio of input to output spatial resolution.' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - Returns: - Tensor, output tensor. - Examples: - >>> SingleDeepLabV3(num_classes=10, - >>> feature_shape=[1,3,224,224], - >>> backbone=resnet50_dl(), - >>> channel=2048, - >>> depth=256) - >>> scale_sizes=[1.0]) - >>> atrous_rates=[6]) - >>> decoder_output_stride=4) - >>> output_stride=16) - """ - - def __init__(self, - num_classes, - feature_shape, - backbone, - channel, - depth, - scale_sizes, - atrous_rates, - decoder_output_stride, - output_stride, - fine_tune_batch_norm=False): - super(SingleDeepLabV3, self).__init__() - self.num_classes = num_classes - self.channel = channel - self.depth = depth - self.scale_sizes = [] - for scale_size in np.sort(scale_sizes): - self.scale_sizes.append(scale_size) - self.net = backbone - self.aspp = ASPP(channel=self.channel, - depth=self.depth, - feature_shape=[feature_shape[2], - feature_shape[3]], - scale_sizes=self.scale_sizes, - atrous_rates=atrous_rates, - output_stride=output_stride, - fine_tune_batch_norm=fine_tune_batch_norm) - - atrous_rates_len = 0 - if atrous_rates is not None: - atrous_rates_len = len(atrous_rates) - self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.fc2 = nn.Conv2d(depth, - num_classes, - kernel_size=1, - stride=1, - has_bias=True) - self.upsample = P.ResizeBilinear((int(feature_shape[2]), - int(feature_shape[3])), - align_corners=True) - self.samples = [] - for scale_size in self.scale_sizes: - self.samples.append(SampleBlock(feature_shape, scale_size)) - self.samples = nn.CellList(self.samples) - self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]), - float(feature_shape[3])] - - self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1))) - self.dropout = nn.Dropout(keep_prob=0.9) - self.shape = P.Shape() - self.decoder_output_stride = decoder_output_stride - if decoder_output_stride is not None: - self.decoder = Decoder(low_level_channel=depth, - channel=depth, - depth=depth, - feature_shape=[feature_shape[2], - feature_shape[3]], - scale_sizes=self.scale_sizes, - decoder_output_stride=decoder_output_stride, - fine_tune_batch_norm=fine_tune_batch_norm) - - def construct(self, x, scale_index=0): - x = (2.0 / 255.0) * x - 1.0 - x = self.pad(x) - low_level_feature, feature_map = self.net(x) - for scale_size in self.scale_sizes: - if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2: - output = self.aspp(feature_map, scale_index) - output = self.fc1(output) - if self.decoder_output_stride is not None: - output = self.decoder(output, low_level_feature, scale_index) - output = self.fc2(output) - output = self.samples[scale_index](output) - return output - scale_index += 1 - return feature_map - - -class SampleBlock(nn.Cell): - """Sample block.""" - def __init__(self, - feature_shape, - scale_size=1.0): - super(SampleBlock, self).__init__() - sample_h = np.ceil(float(feature_shape[2]) * scale_size) - sample_w = np.ceil(float(feature_shape[3]) * scale_size) - self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) - - def construct(self, x): - return self.sample(x) - - -class DeepLabV3(nn.Cell): - """DeepLabV3 model.""" - def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates, - decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid): - super(DeepLabV3, self).__init__() - self.infer_scale_sizes = [] - if infer_scale_sizes is not None: - self.infer_scale_sizes = infer_scale_sizes - - self.infer_scale_sizes = infer_scale_sizes - if image_pyramid is None: - image_pyramid = [1.0] - - self.image_pyramid = image_pyramid - scale_sizes = [] - for pyramid in image_pyramid: - scale_sizes.append(pyramid) - for scale in infer_scale_sizes: - scale_sizes.append(scale) - self.samples = [] - for scale_size in scale_sizes: - self.samples.append(SampleBlock(feature_shape, scale_size)) - self.samples = nn.CellList(self.samples) - self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes, - feature_shape=feature_shape, - backbone=resnet50_dl(fine_tune_batch_norm), - channel=channel, - depth=depth, - scale_sizes=scale_sizes, - atrous_rates=atrous_rates, - decoder_output_stride=decoder_output_stride, - output_stride=output_stride, - fine_tune_batch_norm=fine_tune_batch_norm) - self.softmax = P.Softmax(axis=1) - self.concat = P.Concat(axis=2) - self.expand_dims = P.ExpandDims() - self.reduce_mean = P.ReduceMean() - self.argmax = P.Argmax(axis=1) - self.sample_common = P.ResizeBilinear((int(feature_shape[2]), - int(feature_shape[3])), - align_corners=True) - - def construct(self, x): - logits = () - if self.training: - if len(self.image_pyramid) >= 1: - if self.image_pyramid[0] == 1: - logits = self.deeplabv3(x) - else: - x1 = self.samples[0](x) - logits = self.deeplabv3(x1) - logits = self.sample_common(logits) - logits = self.expand_dims(logits, 2) - for i in range(len(self.image_pyramid) - 1): - x_i = self.samples[i + 1](x) - logits_i = self.deeplabv3(x_i) - logits_i = self.sample_common(logits_i) - logits_i = self.expand_dims(logits_i, 2) - logits = self.concat((logits, logits_i)) - logits = self.reduce_mean(logits, 2) - return logits - if len(self.infer_scale_sizes) >= 1: - infer_index = len(self.image_pyramid) - x1 = self.samples[infer_index](x) - logits = self.deeplabv3(x1) - logits = self.sample_common(logits) - logits = self.softmax(logits) - logits = self.expand_dims(logits, 2) - for i in range(len(self.infer_scale_sizes) - 1): - x_i = self.samples[i + 1 + infer_index](x) - logits_i = self.deeplabv3(x_i) - logits_i = self.sample_common(logits_i) - logits_i = self.softmax(logits_i) - logits_i = self.expand_dims(logits_i, 2) - logits = self.concat((logits, logits_i)) - logits = self.reduce_mean(logits, 2) - if not self.training: - logits = self.argmax(logits) - return logits - - -def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid, - infer_scale_sizes, atrous_rates=None, decoder_output_stride=None, - output_stride=16, fine_tune_batch_norm=False): - """ - ResNet50 based DeepLabv3 network. - - Args: - num_classes (int): Class number. - feature_shape (list): Input image shape, [N,C,H,W]. - image_pyramid (list): Input scales for multi-scale feature extraction. - atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. - infer_scale_sizes (list): 'The scales to resize images for inference. - decoder_output_stride (int): 'The ratio of input to output spatial resolution' - output_stride (int): 'The ratio of input to output spatial resolution.' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - - Returns: - Cell, cell instance of ResNet50 based DeepLabv3 neural network. - - Examples: - >>> deeplabv3_resnet50(100, [1,3,224,224],[1.0],[1.0]) - """ - return DeepLabV3(num_classes=num_classes, - feature_shape=feature_shape, - backbone=resnet50_dl(fine_tune_batch_norm), - channel=2048, - depth=256, - infer_scale_sizes=infer_scale_sizes, - atrous_rates=atrous_rates, - decoder_output_stride=decoder_output_stride, - output_stride=output_stride, - fine_tune_batch_norm=fine_tune_batch_norm, - image_pyramid=image_pyramid) diff --git a/model_zoo/official/cv/deeplabv3/src/ei_dataset.py b/model_zoo/official/cv/deeplabv3/src/ei_dataset.py deleted file mode 100644 index 8b471065aef0dab5d21cd07b2546ba7006ad0535..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/ei_dataset.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Process Dataset.""" -import abc -import os -import time - -from .utils.adapter import get_raw_samples, read_image - - -class BaseDataset: - """ - Create dataset. - - Args: - data_url (str): The path of data. - usage (str): Whether to use train or eval (default='train'). - - Returns: - Dataset. - """ - def __init__(self, data_url, usage): - self.data_url = data_url - self.usage = usage - self.cur_index = 0 - self.samples = [] - _s_time = time.time() - self._load_samples() - _e_time = time.time() - print(f"load samples success~, time cost = {_e_time - _s_time}") - - def __getitem__(self, item): - sample = self.samples[item] - return self._next_data(sample) - - def __len__(self): - return len(self.samples) - - @staticmethod - def _next_data(sample): - image_path = sample[0] - mask_image_path = sample[1] - - image = read_image(image_path) - mask_image = read_image(mask_image_path) - return [image, mask_image] - - @abc.abstractmethod - def _load_samples(self): - pass - - -class HwVocRawDataset(BaseDataset): - """ - Create dataset with raw data. - - Args: - data_url (str): The path of data. - usage (str): Whether to use train or eval (default='train'). - - Returns: - Dataset. - """ - def __init__(self, data_url, usage="train"): - super().__init__(data_url, usage) - - def _load_samples(self): - try: - self.samples = get_raw_samples(os.path.join(self.data_url, self.usage)) - except Exception as e: - print("load HwVocRawDataset failed!!!") - raise e diff --git a/model_zoo/official/cv/deeplabv3/src/loss/__init__.py b/model_zoo/official/cv/deeplabv3/src/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_zoo/official/cv/deeplabv3/src/loss/loss.py b/model_zoo/official/cv/deeplabv3/src/loss/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..377f099bd7cbdeb337e86625a758e149d74a4e1a --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/loss/loss.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore import Tensor +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.ops import operations as P + + +class SoftmaxCrossEntropyLoss(nn.Cell): + def __init__(self, num_cls=21, ignore_label=255): + super(SoftmaxCrossEntropyLoss, self).__init__() + self.one_hot = P.OneHot(axis=-1) + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.cast = P.Cast() + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.not_equal = P.NotEqual() + self.num_cls = num_cls + self.ignore_label = ignore_label + self.mul = P.Mul() + self.sum = P.ReduceSum(False) + self.div = P.RealDiv() + self.transpose = P.Transpose() + self.reshape = P.Reshape() + + def construct(self, logits, labels): + labels_int = self.cast(labels, mstype.int32) + labels_int = self.reshape(labels_int, (-1,)) + logits_ = self.transpose(logits, (0, 2, 3, 1)) + logits_ = self.reshape(logits_, (-1, self.num_cls)) + weights = self.not_equal(labels_int, self.ignore_label) + weights = self.cast(weights, mstype.float32) + one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value) + loss = self.ce(logits_, one_hot_labels) + loss = self.mul(weights, loss) + loss = self.div(self.sum(loss), self.sum(weights)) + return loss diff --git a/model_zoo/official/cv/deeplabv3/src/losses.py b/model_zoo/official/cv/deeplabv3/src/losses.py deleted file mode 100644 index db45cbb6b670734a33fc98157df65f965bde608d..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/losses.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""OhemLoss.""" -import mindspore.nn as nn -import mindspore.common.dtype as mstype -from mindspore.ops import operations as P -from mindspore.ops import functional as F - - -class OhemLoss(nn.Cell): - """Ohem loss cell.""" - def __init__(self, num, ignore_label): - super(OhemLoss, self).__init__() - self.mul = P.Mul() - self.shape = P.Shape() - self.one_hot = nn.OneHot(-1, num, 1.0, 0.0) - self.squeeze = P.Squeeze() - self.num = num - self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean() - self.select = P.Select() - self.reshape = P.Reshape() - self.cast = P.Cast() - self.not_equal = P.NotEqual() - self.equal = P.Equal() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.fill = P.Fill() - self.transpose = P.Transpose() - self.ignore_label = ignore_label - self.loss_weight = 1.0 - - def construct(self, logits, labels): - if not self.training: - return 0 - logits = self.transpose(logits, (0, 2, 3, 1)) - logits = self.reshape(logits, (-1, self.num)) - labels = F.cast(labels, mstype.int32) - labels = self.reshape(labels, (-1,)) - one_hot_labels = self.one_hot(labels) - losses = self.cross_entropy(logits, one_hot_labels)[0] - weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight - weighted_losses = self.mul(losses, weights) - loss = self.reduce_sum(weighted_losses, (0,)) - zeros = self.fill(mstype.float32, self.shape(weights), 0.0) - ones = self.fill(mstype.float32, self.shape(weights), 1.0) - present = self.select(self.equal(weights, zeros), zeros, ones) - present = self.reduce_sum(present, (0,)) - - zeros = self.fill(mstype.float32, self.shape(present), 0.0) - min_control = self.fill(mstype.float32, self.shape(present), 1.0) - present = self.select(self.equal(present, zeros), min_control, present) - loss = loss / present - return loss diff --git a/model_zoo/official/cv/deeplabv3/src/md_dataset.py b/model_zoo/official/cv/deeplabv3/src/md_dataset.py deleted file mode 100644 index 358c28ef2af8510bce75ba120b1185a4e1eb727c..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/md_dataset.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Dataset module.""" -import numpy as np -from PIL import Image -import mindspore.dataset as de -import mindspore.dataset.transforms.vision.c_transforms as C - -from .ei_dataset import HwVocRawDataset -from .utils import custom_transforms as tr - - -class DataTransform: - """Transform dataset for DeepLabV3.""" - - def __init__(self, args, usage): - self.args = args - self.usage = usage - - def __call__(self, image, label): - if self.usage == "train": - return self._train(image, label) - if self.usage == "eval": - return self._eval(image, label) - return None - - def _train(self, image, label): - """ - Process training data. - - Args: - image (list): Image data. - label (list): Dataset label. - """ - image = Image.fromarray(image) - label = Image.fromarray(label) - - rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size) - image, label = rsc_tr(image, label) - - rhf_tr = tr.RandomHorizontalFlip() - image, label = rhf_tr(image, label) - - image = np.array(image).astype(np.float32) - label = np.array(label).astype(np.float32) - - return image, label - - def _eval(self, image, label): - """ - Process eval data. - - Args: - image (list): Image data. - label (list): Dataset label. - """ - image = Image.fromarray(image) - label = Image.fromarray(label) - - fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size) - image, label = fsc_tr(image, label) - - image = np.array(image).astype(np.float32) - label = np.array(label).astype(np.float32) - - return image, label - - -def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shuffle=True): - """ - Create Dataset for DeepLabV3. - - Args: - args (dict): Train parameters. - data_url (str): Dataset path. - epoch_num (int): Epoch of dataset (default=1). - batch_size (int): Batch size of dataset (default=1). - usage (str): Whether is use to train or eval (default='train'). - - Returns: - Dataset. - """ - # create iter dataset - dataset = HwVocRawDataset(data_url, usage=usage) - dataset_len = len(dataset) - - # wrapped with GeneratorDataset - dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None) - dataset.set_dataset_size(dataset_len) - dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage)) - - channelswap_op = C.HWC2CHW() - dataset = dataset.map(input_columns="image", operations=channelswap_op) - - # 1464 samples / batch_size 8 = 183 batches - # epoch_num is num of steps - # 3658 steps / 183 = 20 epochs - if usage == "train" and shuffle: - dataset = dataset.shuffle(1464) - dataset = dataset.batch(batch_size, drop_remainder=(usage == "train")) - dataset = dataset.repeat(count=epoch_num) - dataset.map_model = 4 - - return dataset diff --git a/model_zoo/official/cv/deeplabv3/src/miou_precision.py b/model_zoo/official/cv/deeplabv3/src/miou_precision.py deleted file mode 100644 index 8b3e2f5d0840fff9a320b31729ab6850667c19e3..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/miou_precision.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""mIou.""" -import numpy as np -from mindspore.nn.metrics.metric import Metric - - -def confuse_matrix(target, pred, n): - k = (target >= 0) & (target < n) - return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n) - - -def iou(hist): - denominator = hist.sum(1) + hist.sum(0) - np.diag(hist) - res = np.diag(hist) / np.where(denominator > 0, denominator, 1) - res = np.sum(res) / np.count_nonzero(denominator) - return res - - -class MiouPrecision(Metric): - """Calculate miou precision.""" - def __init__(self, num_class=21): - super(MiouPrecision, self).__init__() - if not isinstance(num_class, int): - raise TypeError('num_class should be integer type, but got {}'.format(type(num_class))) - if num_class < 1: - raise ValueError('num_class must be at least 1, but got {}'.format(num_class)) - self._num_class = num_class - self._mIoU = [] - self.clear() - - def clear(self): - self._hist = np.zeros((self._num_class, self._num_class)) - self._mIoU = [] - - def update(self, *inputs): - if len(inputs) != 2: - raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) - predict_in = self._convert_data(inputs[0]) - label_in = self._convert_data(inputs[1]) - pred = predict_in - label = label_in - if len(label.flatten()) != len(pred.flatten()): - print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten()))) - raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} ' - 'classes'.format(self._num_class, predict_in.shape[1])) - self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class) - mIoUs = iou(self._hist) - self._mIoU.append(mIoUs) - - def eval(self): - """ - Computes the mIoU categorical accuracy. - """ - mIoU = np.nanmean(self._mIoU) - print('mIoU = {}'.format(mIoU)) - return mIoU diff --git a/model_zoo/official/cv/deeplabv3/src/nets/__init__.py b/model_zoo/official/cv/deeplabv3/src/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/__init__.py b/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/deeplab_v3.py b/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/deeplab_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..74d11cd52fceb9ffbaf5c0575bf1834d919ef9cb --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/nets/deeplab_v3/deeplab_v3.py @@ -0,0 +1,219 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.nn as nn +from mindspore.ops import operations as P + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, weight_init='xavier_uniform') + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, pad_mode='pad', padding=padding, + dilation=dilation, weight_init='xavier_uniform') + + +class Resnet(nn.Cell): + def __init__(self, block, block_num, output_stride, use_batch_statistics=True): + super(Resnet, self).__init__() + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3, + weight_init='xavier_uniform') + self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') + self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics) + self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics) + + if output_stride == 16: + self.layer3 = self._make_layer(block, 256, block_num[2], stride=2, + use_batch_statistics=use_batch_statistics) + self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4], + use_batch_statistics=use_batch_statistics) + elif output_stride == 8: + self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2, + use_batch_statistics=use_batch_statistics) + self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4], + use_batch_statistics=use_batch_statistics) + + def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True): + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.SequentialCell([ + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics) + ]) + + if grids is None: + grids = [1] * blocks + + layers = [ + block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0], + use_batch_statistics=use_batch_statistics) + ] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, planes, dilation=base_dilation * grids[i], + use_batch_statistics=use_batch_statistics)) + + return nn.SequentialCell(layers) + + def construct(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.maxpool(out) + + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + return out + + +class Bottleneck(nn.Cell): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True): + super(Bottleneck, self).__init__() + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics) + + self.conv2 = conv3x3(planes, planes, stride, dilation, dilation) + self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics) + + self.conv3 = conv1x1(planes, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, use_batch_statistics=use_batch_statistics) + + self.relu = nn.ReLU() + self.downsample = downsample + + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.add(out, identity) + out = self.relu(out) + return out + + +class ASPP(nn.Cell): + def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21, + use_batch_statistics=True): + super(ASPP, self).__init__() + self.phase = phase + out_channels = 256 + self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0], use_batch_statistics=use_batch_statistics) + self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics) + self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics) + self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics) + self.aspp_pooling = ASPPPooling(in_channels, out_channels) + self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1, + weight_init='xavier_uniform') + self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(out_channels, num_classes, kernel_size=1, weight_init='xavier_uniform', has_bias=True) + self.concat = P.Concat(axis=1) + self.drop = nn.Dropout(0.3) + + def construct(self, x): + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.aspp_pooling(x) + + x = self.concat((x1, x2)) + x = self.concat((x, x3)) + x = self.concat((x, x4)) + x = self.concat((x, x5)) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if self.phase == 'train': + x = self.drop(x) + x = self.conv2(x) + return x + + +class ASPPPooling(nn.Cell): + def __init__(self, in_channels, out_channels, use_batch_statistics=True): + super(ASPPPooling, self).__init__() + self.conv = nn.SequentialCell([ + nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'), + nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics), + nn.ReLU() + ]) + self.shape = P.Shape() + + def construct(self, x): + size = self.shape(x) + out = nn.AvgPool2d(size[2])(x) + out = self.conv(out) + out = P.ResizeNearestNeighbor((size[2], size[3]), True)(out) + return out + + +class ASPPConv(nn.Cell): + def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True): + super(ASPPConv, self).__init__() + if atrous_rate == 1: + conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False, weight_init='xavier_uniform') + else: + conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate, + dilation=atrous_rate, weight_init='xavier_uniform') + bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) + relu = nn.ReLU() + self.aspp_conv = nn.SequentialCell([conv, bn, relu]) + + def construct(self, x): + out = self.aspp_conv(x) + return out + + +class DeepLabV3(nn.Cell): + def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False): + super(DeepLabV3, self).__init__() + use_batch_statistics = not freeze_bn + self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride, + use_batch_statistics=use_batch_statistics) + self.aspp = ASPP([1, 6, 12, 18], phase, 2048, num_classes, + use_batch_statistics=use_batch_statistics) + self.shape = P.Shape() + + def construct(self, x): + size = self.shape(x) + out = self.resnet(x) + out = self.aspp(out) + out = P.ResizeBilinear((size[2], size[3]), True)(out) + return out diff --git a/model_zoo/official/cv/deeplabv3/src/config.py b/model_zoo/official/cv/deeplabv3/src/nets/net_factory.py similarity index 51% rename from model_zoo/official/cv/deeplabv3/src/config.py rename to model_zoo/official/cv/deeplabv3/src/nets/net_factory.py index 6b5519e46cc766fe822fe61d2d25952b11cd1738..8423844a3a46ec4ff733749741f341152a654c6f 100644 --- a/model_zoo/official/cv/deeplabv3/src/config.py +++ b/model_zoo/official/cv/deeplabv3/src/nets/net_factory.py @@ -1,38 +1,18 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config setting, will be used in train.py and evaluation.py -""" -from easydict import EasyDict as ed - -config = ed({ - "learning_rate": 0.0014, - "weight_decay": 0.00005, - "momentum": 0.97, - "crop_size": 513, - "eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75], - "atrous_rates": None, - "image_pyramid": None, - "output_stride": 16, - "fine_tune_batch_norm": False, - "ignore_label": 255, - "decoder_output_stride": None, - "seg_num_classes": 21, - "epoch_size": 6, - "batch_size": 2, - "enable_save_ckpt": True, - "save_checkpoint_steps": 10000, - "save_checkpoint_num": 1 -}) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from src.nets.deeplab_v3 import deeplab_v3 +nets_map = {'deeplab_v3_s8': deeplab_v3.DeepLabV3, + 'deeplab_v3_s16': deeplab_v3.DeepLabV3} diff --git a/model_zoo/official/cv/deeplabv3/src/tools/__init__.py b/model_zoo/official/cv/deeplabv3/src/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_zoo/official/cv/deeplabv3/src/tools/get_multicards_json.py b/model_zoo/official/cv/deeplabv3/src/tools/get_multicards_json.py new file mode 100644 index 0000000000000000000000000000000000000000..d07f9b65cbbd04620a9f122c4f797e20b14cf2a7 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/tools/get_multicards_json.py @@ -0,0 +1,66 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import sys + + +def get_multicards_json(server_id): + hccn_configs = open('/etc/hccn.conf', 'r').readlines() + device_ips = {} + for hccn_item in hccn_configs: + hccn_item = hccn_item.strip() + if hccn_item.startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip + print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) + hccn_table = {'board_id': '0x0000', 'chip_info': '910', 'deploy_mode': 'lab', 'group_count': '1', 'group_list': []} + instance_list = [] + usable_dev = '' + for instance_id in range(8): + instance = {'devices': []} + device_id = str(instance_id) + device_ip = device_ips[device_id] + usable_dev += str(device_id) + instance['devices'].append({ + 'device_id': device_id, + 'device_ip': device_ip, + }) + instance['rank_id'] = str(instance_id) + instance['server_id'] = server_id + instance_list.append(instance) + hccn_table['group_list'].append({ + 'device_num': '8', + 'server_num': '1', + 'group_name': '', + 'instance_count': '8', + 'instance_list': instance_list, + }) + hccn_table['para_plane_nic_location'] = 'device' + hccn_table['para_plane_nic_name'] = [] + for instance_id in range(8): + hccn_table['para_plane_nic_name'].append('eth{}'.format(instance_id)) + hccn_table['para_plane_nic_num'] = '8' + hccn_table['status'] = 'completed' + import json + table_fn = os.path.join(os.getcwd(), 'rank_table_8p.json') + print(table_fn) + with open(table_fn, 'w') as table_fp: + json.dump(hccn_table, table_fp, indent=4) + + +host_server_id = sys.argv[1] +get_multicards_json(host_server_id) diff --git a/model_zoo/official/cv/deeplabv3/src/utils/__init__.py b/model_zoo/official/cv/deeplabv3/src/utils/__init__.py index e30774307ca2107b3a81c071ad33c042ef924790..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/model_zoo/official/cv/deeplabv3/src/utils/__init__.py +++ b/model_zoo/official/cv/deeplabv3/src/utils/__init__.py @@ -1,14 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ diff --git a/model_zoo/official/cv/deeplabv3/src/utils/adapter.py b/model_zoo/official/cv/deeplabv3/src/utils/adapter.py deleted file mode 100644 index 37173ebf48741e3f6837483bafaf68146fd89fc0..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/utils/adapter.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Adapter dataset.""" -import fnmatch -import io -import os - -import numpy as np -from PIL import Image - -from ..utils import file_io - - -def get_raw_samples(data_url): - """ - Get dataset from raw data. - - Args: - data_url (str): Dataset path. - - Returns: - list, a file list. - """ - def _list_files(dir_path, pattern): - full_files = [] - _, _, files = next(file_io.walk(dir_path)) - for f in files: - if fnmatch.fnmatch(f.lower(), pattern.lower()): - full_files.append(os.path.join(dir_path, f)) - return full_files - - img_files = _list_files(os.path.join(data_url, "Images"), "*.jpg") - seg_files = _list_files(os.path.join(data_url, "SegmentationClassRaw"), "*.png") - - files = [] - for img_file in img_files: - _, file_name = os.path.split(img_file) - name, _ = os.path.splitext(file_name) - seg_file = os.path.join(data_url, "SegmentationClassRaw", ".".join([name, "png"])) - if seg_file in seg_files: - files.append([img_file, seg_file]) - return files - - -def read_image(img_path): - """ - Read image from file. - - Args: - img_path (str): image path. - """ - img = file_io.read(img_path.strip(), binary=True) - data = io.BytesIO(img) - img = Image.open(data) - return np.array(img) diff --git a/model_zoo/official/cv/deeplabv3/src/utils/custom_transforms.py b/model_zoo/official/cv/deeplabv3/src/utils/custom_transforms.py deleted file mode 100644 index 75c78e12409d4cdfbf779abe5092d95c19988b68..0000000000000000000000000000000000000000 --- a/model_zoo/official/cv/deeplabv3/src/utils/custom_transforms.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Random process dataset.""" -import random - -import numpy as np -from PIL import Image, ImageOps, ImageFilter - - -class Normalize: - """Normalize a tensor image with mean and standard deviation. - Args: - mean (tuple): means for each channel. - std (tuple): standard deviations for each channel. - """ - - def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): - self.mean = mean - self.std = std - - def __call__(self, img, mask): - img = np.array(img).astype(np.float32) - mask = np.array(mask).astype(np.float32) - img = ((img - self.mean) / self.std).astype(np.float32) - - return img, mask - - -class RandomHorizontalFlip: - """Randomly decide whether to horizontal flip.""" - def __call__(self, img, mask): - if random.random() < 0.5: - img = img.transpose(Image.FLIP_LEFT_RIGHT) - mask = mask.transpose(Image.FLIP_LEFT_RIGHT) - - return img, mask - - -class RandomRotate: - """ - Randomly decide whether to rotate. - - Args: - degree (float): The degree of rotate. - """ - def __init__(self, degree): - self.degree = degree - - def __call__(self, img, mask): - rotate_degree = random.uniform(-1 * self.degree, self.degree) - img = img.rotate(rotate_degree, Image.BILINEAR) - mask = mask.rotate(rotate_degree, Image.NEAREST) - - return img, mask - - -class RandomGaussianBlur: - """Randomly decide whether to filter image with gaussian blur.""" - def __call__(self, img, mask): - if random.random() < 0.5: - img = img.filter(ImageFilter.GaussianBlur( - radius=random.random())) - - return img, mask - - -class RandomScaleCrop: - """Randomly decide whether to scale and crop image.""" - def __init__(self, base_size, crop_size, fill=0): - self.base_size = base_size - self.crop_size = crop_size - self.fill = fill - - def __call__(self, img, mask): - # random scale (short edge) - short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) - w, h = img.size - if h > w: - ow = short_size - oh = int(1.0 * h * ow / w) - else: - oh = short_size - ow = int(1.0 * w * oh / h) - img = img.resize((ow, oh), Image.BILINEAR) - mask = mask.resize((ow, oh), Image.NEAREST) - # pad crop - if short_size < self.crop_size: - padh = self.crop_size - oh if oh < self.crop_size else 0 - padw = self.crop_size - ow if ow < self.crop_size else 0 - img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) - mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) - # random crop crop_size - w, h = img.size - x1 = random.randint(0, w - self.crop_size) - y1 = random.randint(0, h - self.crop_size) - img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) - mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) - - return img, mask - - -class FixScaleCrop: - """Scale and crop image with fixing size.""" - def __init__(self, crop_size): - self.crop_size = crop_size - - def __call__(self, img, mask): - w, h = img.size - if w > h: - oh = self.crop_size - ow = int(1.0 * w * oh / h) - else: - ow = self.crop_size - oh = int(1.0 * h * ow / w) - img = img.resize((ow, oh), Image.BILINEAR) - mask = mask.resize((ow, oh), Image.NEAREST) - # center crop - w, h = img.size - x1 = int(round((w - self.crop_size) / 2.)) - y1 = int(round((h - self.crop_size) / 2.)) - img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) - mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) - - return img, mask - - -class FixedResize: - """Resize image with fixing size.""" - def __init__(self, size): - self.size = (size, size) - - def __call__(self, img, mask): - assert img.size == mask.size - - img = img.resize(self.size, Image.BILINEAR) - mask = mask.resize(self.size, Image.NEAREST) - return img, mask diff --git a/model_zoo/official/cv/deeplabv3/src/utils/learning_rates.py b/model_zoo/official/cv/deeplabv3/src/utils/learning_rates.py new file mode 100644 index 0000000000000000000000000000000000000000..c70e84f316cbc8b766437fa680f208be09f1c238 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/utils/learning_rates.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + + +def cosine_lr(base_lr, decay_steps, total_steps): + for i in range(total_steps): + step_ = min(i, decay_steps) + yield base_lr * 0.5 * (1 + np.cos(np.pi * step_ / decay_steps)) + + +def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9): + for i in range(total_steps): + step_ = min(i, decay_steps) + yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr + + +def exponential_lr(base_lr, decay_steps, decay_rate, total_steps, staircase=False): + for i in range(total_steps): + if staircase: + power_ = i // decay_steps + else: + power_ = float(i) / decay_steps + yield base_lr * (decay_rate ** power_) diff --git a/model_zoo/official/cv/deeplabv3/train.py b/model_zoo/official/cv/deeplabv3/train.py index da84215fd968bd58480fdc5f92ff288128a34956..e28b065a788595d3d4c38cf3d19af5311a31cfa8 100644 --- a/model_zoo/official/cv/deeplabv3/train.py +++ b/model_zoo/official/cv/deeplabv3/train.py @@ -1,91 +1,166 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train.""" -import argparse -from mindspore import context -from mindspore.communication.management import init -from mindspore.nn.optim.momentum import Momentum -from mindspore import Model -from mindspore.context import ParallelMode -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor -from mindspore.common import set_seed -from src.md_dataset import create_dataset -from src.losses import OhemLoss -from src.deeplabv3 import deeplabv3_resnet50 -from src.config import config - -set_seed(1) - -parser = argparse.ArgumentParser(description="Deeplabv3 training") -parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") -parser.add_argument('--data_url', required=True, default=None, help='Train data url') -parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') - -args_opt = parser.parse_args() -print(args_opt) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) -class LossCallBack(Callback): - """ - Monitor the loss in training. - Note: - if per_print_times is 0 do not print loss. - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - def __init__(self, per_print_times=1): - super(LossCallBack, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be int and >= 0") - self._per_print_times = per_print_times - def step_end(self, run_context): - cb_params = run_context.original_args() - print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs))) -def model_fine_tune(flags, train_net, fix_weight_layer): - checkpoint_path = flags.checkpoint_url - if checkpoint_path is None: - return - param_dict = load_checkpoint(checkpoint_path) - load_param_into_net(train_net, param_dict) - for para in train_net.trainable_params(): - if fix_weight_layer in para.name: - para.requires_grad = False -if __name__ == "__main__": - if args_opt.distribute == "true": - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) - init() - args_opt.base_size = config.crop_size - args_opt.crop_size = config.crop_size - train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train") - dataset_size = train_dataset.get_dataset_size() - time_cb = TimeMonitor(data_size=dataset_size) - callback = [time_cb, LossCallBack()] - if config.enable_save_ckpt: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, - keep_checkpoint_max=config.save_checkpoint_num) - ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) - callback.append(ckpoint_cb) - net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], - infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, - decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, - fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) - net.set_train() - model_fine_tune(args_opt, net, 'layer') - loss = OhemLoss(config.seg_num_classes, config.ignore_label) - opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) - model = Model(net, loss, opt) - model.train(config.epoch_size, train_dataset, callback) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train deeplabv3.""" + +import os +import argparse +from mindspore import context +from mindspore.train.model import ParallelMode, Model +import mindspore.nn as nn +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.train.callback import LossMonitor, TimeMonitor +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from src.data import data_generator +from src.loss import loss +from src.nets import net_factory +from src.utils import learning_rates +context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, + device_target="Ascend", device_id=int(os.getenv('DEVICE_ID'))) + + +class BuildTrainNetwork(nn.Cell): + def __init__(self, network, criterion): + super(BuildTrainNetwork, self).__init__() + self.network = network + self.criterion = criterion + + def construct(self, input_data, label): + output = self.network(input_data) + net_loss = self.criterion(output, label) + return net_loss + + +def parse_args(): + parser = argparse.ArgumentParser('mindspore deeplabv3 training') + parser.add_argument('--train_dir', type=str, default='', help='where training log and ckpts saved') + + # dataset + parser.add_argument('--data_file', type=str, default='', help='path and name of one mindrecord file') + parser.add_argument('--batch_size', type=int, default=32, help='batch size') + parser.add_argument('--crop_size', type=int, default=513, help='crop size') + parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean') + parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std') + parser.add_argument('--min_scale', type=float, default=0.5, help='minimum scale of data argumentation') + parser.add_argument('--max_scale', type=float, default=2.0, help='maximum scale of data argumentation') + parser.add_argument('--ignore_label', type=int, default=255, help='ignore label') + parser.add_argument('--num_classes', type=int, default=21, help='number of classes') + + # optimizer + parser.add_argument('--train_epochs', type=int, default=300, help='epoch') + parser.add_argument('--lr_type', type=str, default='cos', help='type of learning rate') + parser.add_argument('--base_lr', type=float, default=0.015, help='base learning rate') + parser.add_argument('--lr_decay_step', type=int, default=40000, help='learning rate decay step') + parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='learning rate decay rate') + parser.add_argument('--loss_scale', type=float, default=3072.0, help='loss scale') + + # model + parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') + parser.add_argument('--freeze_bn', action='store_true', help='freeze bn') + parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model') + + # train + parser.add_argument('--is_distributed', action='store_true', help='distributed training') + parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') + parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') + parser.add_argument('--save_steps', type=int, default=3000, help='steps interval for saving') + parser.add_argument('--keep_checkpoint_max', type=int, default=int, help='max checkpoint for saving') + + args, _ = parser.parse_known_args() + return args + + +def train(): + args = parse_args() + + # init multicards training + if args.is_distributed: + init() + args.rank = get_rank() + args.group_size = get_group_size() + + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.group_size) + + # dataset + dataset = data_generator.SegDataset(image_mean=args.image_mean, + image_std=args.image_std, + data_file=args.data_file, + batch_size=args.batch_size, + crop_size=args.crop_size, + max_scale=args.max_scale, + min_scale=args.min_scale, + ignore_label=args.ignore_label, + num_classes=args.num_classes, + num_readers=2, + num_parallel_calls=4, + shard_id=args.rank, + shard_num=args.group_size) + dataset = dataset.get_dataset(repeat=1) + + # network + if args.model == 'deeplab_v3_s16': + network = net_factory.nets_map[args.model]('train', args.num_classes, 16, args.freeze_bn) + elif args.model == 'deeplab_v3_s8': + network = net_factory.nets_map[args.model]('train', args.num_classes, 8, args.freeze_bn) + else: + raise NotImplementedError('model [{:s}] not recognized'.format(args.model)) + + # loss + loss_ = loss.SoftmaxCrossEntropyLoss(args.num_classes, args.ignore_label) + loss_.add_flags_recursive(fp32=True) + train_net = BuildTrainNetwork(network, loss_) + + # load pretrained model + if args.ckpt_pre_trained: + param_dict = load_checkpoint(args.ckpt_pre_trained) + load_param_into_net(train_net, param_dict) + + # optimizer + iters_per_epoch = dataset.get_dataset_size() + total_train_steps = iters_per_epoch * args.train_epochs + if args.lr_type == 'cos': + lr_iter = learning_rates.cosine_lr(args.base_lr, total_train_steps, total_train_steps) + elif args.lr_type == 'poly': + lr_iter = learning_rates.poly_lr(args.base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9) + elif args.lr_type == 'exp': + lr_iter = learning_rates.exponential_lr(args.base_lr, args.lr_decay_step, args.lr_decay_rate, + total_train_steps, staircase=True) + else: + raise ValueError('unknown learning rate type') + opt = nn.Momentum(params=train_net.trainable_params(), learning_rate=lr_iter, momentum=0.9, weight_decay=0.0001, + loss_scale=args.loss_scale) + + # loss scale + manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) + model = Model(train_net, optimizer=opt, amp_level="O3", loss_scale_manager=manager_loss_scale) + + # callback for saving ckpts + time_cb = TimeMonitor(data_size=iters_per_epoch) + loss_cb = LossMonitor() + cbs = [time_cb, loss_cb] + + if args.rank == 0: + config_ck = CheckpointConfig(save_checkpoint_steps=args.save_steps, + keep_checkpoint_max=args.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck) + cbs.append(ckpoint_cb) + + model.train(args.train_epochs, dataset, callbacks=cbs) + + +if __name__ == '__main__': + train() diff --git a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh index 7117805f1a115a58e3408c86fb869796f25105ed..64142643ce53ec16a4a972fa35199339917ca749 100755 --- a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh @@ -88,7 +88,7 @@ rank_start=$((DEVICE_NUM * SERVER_ID)) for((i=0; i<${DEVICE_NUM}; i++)) do - export DEVICE_ID=$i + export DEVICE_ID=${i} export RANK_ID=$((rank_start + i)) rm -rf ./train_parallel$i mkdir ./train_parallel$i