From c530e15e091c394cf1707c86b0d552a03f498224 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Tue, 30 Jun 2020 14:04:15 +0800 Subject: [PATCH] add mobilenet v2 quant and resnet50 quant to model_zoo --- model_zoo/lenet_quant/README.md | 6 +- model_zoo/lenet_quant/eval.py | 2 +- model_zoo/lenet_quant/eval_quant.py | 4 +- model_zoo/lenet_quant/train.py | 5 +- model_zoo/lenet_quant/train_quant.py | 11 +- model_zoo/mobilenetv2/scripts/run_train.sh | 2 +- model_zoo/mobilenetv2_quant/Readme.md | 142 ++++++++++ model_zoo/mobilenetv2_quant/eval.py | 76 ++++++ .../mobilenetv2_quant/scripts/run_infer.sh | 53 ++++ .../scripts/run_infer_quant.sh | 54 ++++ .../mobilenetv2_quant/scripts/run_train.sh | 62 +++++ .../scripts/run_train_quant.sh | 63 +++++ model_zoo/mobilenetv2_quant/src/config.py | 60 +++++ model_zoo/mobilenetv2_quant/src/dataset.py | 156 +++++++++++ model_zoo/mobilenetv2_quant/src/launch.py | 166 ++++++++++++ .../mobilenetv2_quant/src/lr_generator.py | 54 ++++ .../mobilenetv2_quant/src/mobilenetV2.py | 231 ++++++++++++++++ model_zoo/mobilenetv2_quant/src/utils.py | 113 ++++++++ model_zoo/mobilenetv2_quant/train.py | 131 +++++++++ model_zoo/mobilenetv3/scripts/run_train.sh | 2 +- model_zoo/resnet50_quant/Readme.md | 122 +++++++++ model_zoo/resnet50_quant/eval.py | 78 ++++++ .../resnet50_quant/models/resnet_quant.py | 251 ++++++++++++++++++ model_zoo/resnet50_quant/scripts/run_infer.sh | 54 ++++ model_zoo/resnet50_quant/scripts/run_train.sh | 62 +++++ model_zoo/resnet50_quant/src/config.py | 68 +++++ model_zoo/resnet50_quant/src/crossentropy.py | 39 +++ model_zoo/resnet50_quant/src/dataset.py | 157 +++++++++++ model_zoo/resnet50_quant/src/launch.py | 165 ++++++++++++ model_zoo/resnet50_quant/src/lr_generator.py | 87 ++++++ model_zoo/resnet50_quant/src/utils.py | 46 ++++ model_zoo/resnet50_quant/train.py | 153 +++++++++++ 32 files changed, 2659 insertions(+), 16 deletions(-) create mode 100644 model_zoo/mobilenetv2_quant/Readme.md create mode 100644 
model_zoo/mobilenetv2_quant/eval.py create mode 100644 model_zoo/mobilenetv2_quant/scripts/run_infer.sh create mode 100644 model_zoo/mobilenetv2_quant/scripts/run_infer_quant.sh create mode 100644 model_zoo/mobilenetv2_quant/scripts/run_train.sh create mode 100644 model_zoo/mobilenetv2_quant/scripts/run_train_quant.sh create mode 100644 model_zoo/mobilenetv2_quant/src/config.py create mode 100644 model_zoo/mobilenetv2_quant/src/dataset.py create mode 100644 model_zoo/mobilenetv2_quant/src/launch.py create mode 100644 model_zoo/mobilenetv2_quant/src/lr_generator.py create mode 100644 model_zoo/mobilenetv2_quant/src/mobilenetV2.py create mode 100644 model_zoo/mobilenetv2_quant/src/utils.py create mode 100644 model_zoo/mobilenetv2_quant/train.py create mode 100644 model_zoo/resnet50_quant/Readme.md create mode 100755 model_zoo/resnet50_quant/eval.py create mode 100755 model_zoo/resnet50_quant/models/resnet_quant.py create mode 100644 model_zoo/resnet50_quant/scripts/run_infer.sh create mode 100644 model_zoo/resnet50_quant/scripts/run_train.sh create mode 100755 model_zoo/resnet50_quant/src/config.py create mode 100644 model_zoo/resnet50_quant/src/crossentropy.py create mode 100755 model_zoo/resnet50_quant/src/dataset.py create mode 100644 model_zoo/resnet50_quant/src/launch.py create mode 100755 model_zoo/resnet50_quant/src/lr_generator.py create mode 100644 model_zoo/resnet50_quant/src/utils.py create mode 100755 model_zoo/resnet50_quant/train.py diff --git a/model_zoo/lenet_quant/README.md b/model_zoo/lenet_quant/README.md index 2f949f6d7..2fd3e129a 100644 --- a/model_zoo/lenet_quant/README.md +++ b/model_zoo/lenet_quant/README.md @@ -33,7 +33,7 @@ Then you will get the following display ```bash >>> Found existing installation: mindspore-ascend >>> Uninstalling mindspore-ascend: ->>> Successfully uninstalled mindspore-ascend. +>>> Successfully uninstalled mindspore-ascend. 
``` ### Prepare Dataset @@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) ### train quantization aware model -Also, you can just run this command instread. +Also, you can just run this command instead. ```python python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt @@ -235,7 +235,7 @@ The top1 accuracy would display on shell. Here are some optional parameters: ```bash ---device_target {Ascend,GPU,CPU} +--device_target {Ascend,GPU} device where the code will be implemented (default: Ascend) --data_path DATA_PATH path where the dataset is saved diff --git a/model_zoo/lenet_quant/eval.py b/model_zoo/lenet_quant/eval.py index d94e77279..c0293ae1f 100644 --- a/model_zoo/lenet_quant/eval.py +++ b/model_zoo/lenet_quant/eval.py @@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') diff --git a/model_zoo/lenet_quant/eval_quant.py b/model_zoo/lenet_quant/eval_quant.py index 2c2477123..bc9b62121 100644 --- a/model_zoo/lenet_quant/eval_quant.py +++ b/model_zoo/lenet_quant/eval_quant.py @@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') @@ -61,7 +61,7 @@ if __name__ == "__main__": model = Model(network, 
net_loss, net_opt, metrics={"Accuracy": Accuracy()}) # load quantization aware network checkpoint - param_dict = load_checkpoint(args.ckpt_path, model_type="quant") + param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) print("============== Starting Testing ==============") diff --git a/model_zoo/lenet_quant/train.py b/model_zoo/lenet_quant/train.py index b6040776e..a34b6d5ed 100644 --- a/model_zoo/lenet_quant/train.py +++ b/model_zoo/lenet_quant/train.py @@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') @@ -56,8 +56,7 @@ if __name__ == "__main__": # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max, - model_type=network.type) + keep_checkpoint_max=cfg.keep_checkpoint_max) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) # define model diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py index eb1f783a7..ba54e63d8 100644 --- a/model_zoo/lenet_quant/train_quant.py +++ b/model_zoo/lenet_quant/train_quant.py @@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", 
help='path where the dataset is saved') @@ -50,11 +50,13 @@ if __name__ == "__main__": # define fusion network network = LeNet5Fusion(cfg.num_classes) + + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path, network.type) load_param_into_net(network, param_dict) - # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") @@ -64,8 +66,7 @@ if __name__ == "__main__": # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max, - model_type="quant") + keep_checkpoint_max=cfg.keep_checkpoint_max) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) # define model diff --git a/model_zoo/mobilenetv2/scripts/run_train.sh b/model_zoo/mobilenetv2/scripts/run_train.sh index fc013d474..9b9c13a00 100644 --- a/model_zoo/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/mobilenetv2/scripts/run_train.sh @@ -30,7 +30,7 @@ run_ascend() BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "train" ]; + if [ -d "../train" ]; then rm -rf ../train fi diff --git a/model_zoo/mobilenetv2_quant/Readme.md b/model_zoo/mobilenetv2_quant/Readme.md new file mode 100644 index 000000000..81be5d519 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/Readme.md @@ -0,0 +1,142 @@ +# MobileNetV2 Quantization Aware Training + +MobileNetV2 is a significant improvement over MobileNetV1 and pushes the state of the art for mobile visual recognition including classification, object detection and semantic 
segmentation. + +MobileNetV2 builds upon the ideas from MobileNetV1, using depthwise separable convolution as efficient building blocks. However, V2 introduces two new features to the architecture: 1) linear bottlenecks between the layers, and 2) shortcut connections between the bottlenecks. + +This example trains MobileNetV2 on the ImageNet dataset in MindSpore with quantization aware training. + +This is a simple, basic tutorial for constructing a quantization aware network in MindSpore. + +In this readme tutorial, you will: + +1. Train a MindSpore fusion MobileNetV2 model for ImageNet from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. +2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`; after the network converges, export a quantization aware model checkpoint file. + +[Paper](https://arxiv.org/pdf/1801.04381) Sandler, Mark, et al. "Mobilenetv2: Inverted residuals and linear bottlenecks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018. + +# Dataset + +Dataset used: ImageNet + +- Dataset size: about 125G + - Train: 120G, 1281167 images: 1000 directories + - Test: 5G, 50000 images: images should first be classified into 1000 directories, just like train images +- Data format: RGB images. + - Note: Data will be processed in src/dataset.py + +# Environment Requirements + +- Hardware(Ascend) + - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. 
+- Framework + - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + + +# Script description + +## Script and sample code + +```python +├── mobilenetv2_quant + ├── Readme.md + ├── scripts + │ ├──run_train.sh + │ ├──run_infer.sh + │ ├──run_train_quant.sh + │ ├──run_infer_quant.sh + ├── src + │ ├──config.py + │ ├──dataset.py + │ ├──launch.py + │ ├──lr_generator.py + │ ├──mobilenetV2.py + ├── train.py + ├── eval.py +``` + +## Training process + +### Train MobileNetV2 model + +Train a MindSpore fusion MobileNetV2 model for ImageNet, like: + +- sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] + +You can just run this command instead. + +``` bash +>>> sh run_train.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt +``` + +Training result will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log`, as follows. + +``` +>>> epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] +>>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 +>>> epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] +>>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 +``` + +### Evaluate MobileNetV2 model + +Evaluate a MindSpore fusion MobileNetV2 model for ImageNet, like: + +- sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] + +You can just run this command instead. + +``` bash +>>> sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt +``` + +Inference result will be stored in the example path; you can find a result like the following in `infer.log`. 
+ +``` +>>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt +``` + +### Fine-tune for quantization aware training + +Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`; after the network converges, export a quantization aware model checkpoint file. + +- sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] + +You can just run this command instead. + +``` bash +>>> sh run_train_quant.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt +``` + +Training result will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log`, as follows. + +``` +>>> epoch: [ 0/60], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] +>>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 +>>> epoch: [ 1/60], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] +>>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 +``` + +### Evaluate quantization aware training model + +Evaluate a MindSpore fusion MobileNetV2 model for ImageNet by applying the quantization aware training, like: + +- sh run_infer_quant.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] + +You can just run this command instead. + +``` bash +>>> sh run_infer_quant.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_625.ckpt +``` + +Inference result will be stored in the example path; you can find a result like the following in `infer.log`. 
+ +``` +>>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-60_625.ckpt +``` + +# ModelZoo Homepage + [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) diff --git a/model_zoo/mobilenetv2_quant/eval.py b/model_zoo/mobilenetv2_quant/eval.py new file mode 100644 index 000000000..0976abbe9 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/eval.py @@ -0,0 +1,76 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Evaluate MobilenetV2 on ImageNet""" + +import os +import argparse + +from mindspore import context +from mindspore import nn +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.quant import quant + +from src.mobilenetV2 import mobilenetV2 +from src.dataset import create_dataset +from src.config import config_ascend + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--device_target', type=str, default=None, help='Run device target') +parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training') +args_opt = parser.parse_args() + +if __name__ == '__main__': + config_device_target = None + if args_opt.device_target == "Ascend": + config_device_target = config_ascend + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", + device_id=device_id, save_graphs=False) + + else: + raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) + + # define fusion network + network = mobilenetV2(num_classes=config_device_target.num_classes) + if args_opt.quantization_aware: + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + # define network loss + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + + # define dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=False, + config=config_device_target, + device_target=args_opt.device_target, + batch_size=config_device_target.batch_size) 
+ step_size = dataset.get_dataset_size() + + # load checkpoint + if args_opt.checkpoint_path: + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(network, param_dict) + network.set_train(False) + + # define model + model = Model(network, loss_fn=loss, metrics={'acc'}) + + print("============== Starting Validation ==============") + res = model.eval(dataset) + print("result:", res, "ckpt=", args_opt.checkpoint_path) + print("============== End Validation ==============") diff --git a/model_zoo/mobilenetv2_quant/scripts/run_infer.sh b/model_zoo/mobilenetv2_quant/scripts/run_infer.sh new file mode 100644 index 000000000..308723af2 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/scripts/run_infer.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ $# != 3 ] +then + echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +# check dataset path +if [ ! -d $2 ] && [ ! -f $2 ] +then + echo "error: DATASET_PATH=$2 is not a directory or file" +exit 1 +fi + +# check checkpoint file +if [ ! 
-f $3 ] +then + echo "error: CHECKPOINT_PATH=$3 is not a file" +exit 1 +fi + +# set environment +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 +if [ -d "../eval" ]; +then + rm -rf ../eval +fi +mkdir ../eval +cd ../eval || exit + +# launch +python ${BASEPATH}/../eval.py \ + --device_target=$1 \ + --dataset_path=$2 \ + --checkpoint_path=$3 \ + &> infer.log & # dataset val folder path diff --git a/model_zoo/mobilenetv2_quant/scripts/run_infer_quant.sh b/model_zoo/mobilenetv2_quant/scripts/run_infer_quant.sh new file mode 100644 index 000000000..f8f3c1061 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/scripts/run_infer_quant.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ $# != 3 ] +then + echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +# check dataset path +if [ ! -d $2 ] && [ ! -f $2 ] +then + echo "error: DATASET_PATH=$2 is not a directory or file" +exit 1 +fi + +# check checkpoint file +if [ ! 
-f $3 ] +then + echo "error: CHECKPOINT_PATH=$3 is not a file" +exit 1 +fi + +# set environment +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 +if [ -d "../eval" ]; +then + rm -rf ../eval +fi +mkdir ../eval +cd ../eval || exit + +# launch +python ${BASEPATH}/../eval.py \ + --device_target=$1 \ + --dataset_path=$2 \ + --checkpoint_path=$3 \ + --quantization_aware=True \ + &> infer.log & # dataset val folder path diff --git a/model_zoo/mobilenetv2_quant/scripts/run_train.sh b/model_zoo/mobilenetv2_quant/scripts/run_train.sh new file mode 100644 index 000000000..59b105f92 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/scripts/run_train.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +run_ascend() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-9)" + exit 1 + fi + + if [ ! -d $5 ] && [ ! 
-f $5 ] + then + echo "error: DATASET_PATH=$5 is not a directory or file" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + python ${BASEPATH}/../src/launch.py \ + --nproc_per_node=$2 \ + --visible_devices=$4 \ + --server_id=$3 \ + --training_script=${BASEPATH}/../train.py \ + --dataset_path=$5 \ + --pre_trained=$6 \ + --device_target=$1 &> train.log & # dataset train folder +} + +if [ $# -gt 6 ] || [ $# -lt 4 ] +then + echo "Usage:\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + " +exit 1 +fi + +if [ $1 = "Ascend" ] ; then + run_ascend "$@" +else + echo "Unsupported device target." +fi; + diff --git a/model_zoo/mobilenetv2_quant/scripts/run_train_quant.sh b/model_zoo/mobilenetv2_quant/scripts/run_train_quant.sh new file mode 100644 index 000000000..c82d1b0da --- /dev/null +++ b/model_zoo/mobilenetv2_quant/scripts/run_train_quant.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +run_ascend() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-9)" + exit 1 + fi + + if [ ! -d $5 ] && [ ! 
-f $5 ] + then + echo "error: DATASET_PATH=$5 is not a directory or file" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + python ${BASEPATH}/../src/launch.py \ + --nproc_per_node=$2 \ + --visible_devices=$4 \ + --server_id=$3 \ + --training_script=${BASEPATH}/../train.py \ + --dataset_path=$5 \ + --pre_trained=$6 \ + --quantization_aware=True \ + --device_target=$1 &> train.log & # dataset train folder +} + +if [ $# -gt 6 ] || [ $# -lt 4 ] +then + echo "Usage:\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + " +exit 1 +fi + +if [ $1 = "Ascend" ] ; then + run_ascend "$@" +else + echo "Unsupported device target." +fi; + diff --git a/model_zoo/mobilenetv2_quant/src/config.py b/model_zoo/mobilenetv2_quant/src/config.py new file mode 100644 index 000000000..97fbc52e1 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/src/config.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as ed + +config_ascend = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 256, + "data_load_mode": "mindrecord", + "epoch_size": 200, + "start_epoch": 0, + "warmup_epochs": 4, + "lr": 0.4, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 200, + "save_checkpoint_path": "./checkpoint", + "quantization_aware": False, +}) + +config_ascend_quant = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 192, + "data_load_mode": "mindrecord", + "epoch_size": 60, + "start_epoch": 200, + "warmup_epochs": 1, + "lr": 0.3, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 200, + "save_checkpoint_path": "./checkpoint", + "quantization_aware": True, +}) diff --git a/model_zoo/mobilenetv2_quant/src/dataset.py b/model_zoo/mobilenetv2_quant/src/dataset.py new file mode 100644 index 000000000..105a5e139 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/src/dataset.py @@ -0,0 +1,156 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +create train or eval dataset. +""" +import os +from functools import partial +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as C2 +import mindspore.dataset.transforms.vision.py_transforms as P +from src.config import config_ascend + + +def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): + """ + create a train or eval dataset + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1. + batch_size(int): the batch size of dataset. Default: 32. + + Returns: + dataset + """ + if device_target == "Ascend": + rank_size = int(os.getenv("RANK_SIZE")) + rank_id = int(os.getenv("RANK_ID")) + columns_list = ['image', 'label'] + if config_ascend.data_load_mode == "mindrecord": + load_func = partial(de.MindDataset, dataset_path, columns_list) + else: + load_func = partial(de.ImageFolderDatasetV2, dataset_path) + if do_train: + if rank_size == 1: + ds = load_func(num_parallel_workers=8, shuffle=True) + else: + ds = load_func(num_parallel_workers=8, shuffle=True, + num_shards=rank_size, shard_id=rank_id) + else: + ds = load_func(num_parallel_workers=8, shuffle=False) + else: + raise ValueError("Unsupport device_target.") + + resize_height = config.image_height + + if do_train: + buffer_size = 20480 + # apply shuffle operations + ds = ds.shuffle(buffer_size=buffer_size) + + # define map operations + decode_op = C.Decode() + resize_crop_decode_op = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333)) + horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5) + + resize_op = C.Resize(256) + center_crop = C.CenterCrop(resize_height) + normalize_op = C.Normalize(mean=[0.485 * 
255, 0.456 * 255, 0.406 * 255], + std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) + change_swap_op = C.HWC2CHW() + + if do_train: + trans = [resize_crop_decode_op, horizontal_flip_op, normalize_op, change_swap_op] + else: + trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=16) + ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def create_dataset_py(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): + """ + create a train or eval dataset + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1. + batch_size(int): the batch size of dataset. Default: 32. 
+ + Returns: + dataset + """ + if device_target == "Ascend": + rank_size = int(os.getenv("RANK_SIZE")) + rank_id = int(os.getenv("RANK_ID")) + if do_train: + if rank_size == 1: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=rank_size, shard_id=rank_id) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False) + else: + raise ValueError("Unsupported device target.") + + resize_height = config.image_height + + if do_train: + buffer_size = 20480 + # apply shuffle operations + ds = ds.shuffle(buffer_size=buffer_size) + + # define map operations + decode_op = P.Decode() + resize_crop_op = P.RandomResizedCrop(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333)) + horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5) + + resize_op = P.Resize(256) + center_crop = P.CenterCrop(resize_height) + to_tensor = P.ToTensor() + normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + if do_train: + trans = [decode_op, resize_crop_op, horizontal_flip_op, to_tensor, normalize_op] + else: + trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op] + + compose = P.ComposeOp(trans) + + ds = ds.map(input_columns="image", operations=compose(), num_parallel_workers=8, python_multiprocessing=True) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds diff --git a/model_zoo/mobilenetv2_quant/src/launch.py b/model_zoo/mobilenetv2_quant/src/launch.py new file mode 100644 index 000000000..08477a363 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/src/launch.py @@ -0,0 +1,166 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""launch train script"""
+import os
+import sys
+import json
+import subprocess
+import shutil
+import platform
+from argparse import ArgumentParser
+
+
+def parse_args():
+    """
+    Parse the launcher's own options and collect everything else for the training script.
+
+    Args:
+
+    Returns:
+        args, the parsed namespace; unrecognised command-line arguments are stored
+        in `args.training_script_args` and later forwarded to the training script.
+
+    Examples:
+        >>> parse_args()
+    """
+    parser = ArgumentParser(description="mindspore distributed training launch "
+                                        "helper utilty that will spawn up "
+                                        "multiple distributed processes")
+    parser.add_argument("--nproc_per_node", type=int, default=1,
+                        help="The number of processes to launch on each node, "
+                             "for D training, this is recommended to be set "
+                             "to the number of D in your system so that "
+                             "each process can be bound to a single D.")
+    parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
+                        help="will use the visible devices sequentially")
+    parser.add_argument("--server_id", type=str, default="",
+                        help="server ip")
+    parser.add_argument("--training_script", type=str,
+                        help="The full path to the single D training "
+                             "program/script to be launched in parallel, "
+                             "followed by all the arguments for the "
+                             "training script")
+    # rest from the training program
+    args, unknown = parser.parse_known_args()
+    args.training_script_args = unknown
+    return args
+
+
+def main():
+    """
+    Build the rank table from /etc/hccn.conf and spawn one training process per device.
+
+    The rank table is written as JSON into the current working directory. Each rank
+    runs in its own freshly created `device<rank_id>` directory with RANK_SIZE,
+    RANK_ID and DEVICE_ID exported, and its stdout/stderr redirected to
+    `log<rank_id>.log` inside that directory.
+
+    Raises:
+        subprocess.CalledProcessError: if any spawned process exits with a non-zero code.
+    """
+    print("start", __file__)
+    args = parse_args()
+    print(args)
+    visible_devices = args.visible_devices.split(',')
+    assert os.path.isfile(args.training_script)
+    assert len(visible_devices) >= args.nproc_per_node
+    print('visible_devices:{}'.format(visible_devices))
+    if not args.server_id:
+        print('pleaser input server ip!!!')
+        # a missing server ip is a usage error: exit non-zero so wrapper scripts see the failure
+        sys.exit(1)
+    print('server_id:{}'.format(args.server_id))
+
+    # construct hccn_table
+    # read the device ip table with a context manager so the file handle is always closed
+    with open('/etc/hccn.conf', 'r') as hccn_file:
+        hccn_configs = hccn_file.readlines()
+    device_ips = {}
+    for hccn_item in hccn_configs:
+        hccn_item = hccn_item.strip()
+        if hccn_item.startswith('address_'):
+            # lines look like "address_<device_id>=<device_ip>"
+            device_id, device_ip = hccn_item.split('=')
+            device_id = device_id.split('_')[1]
+            device_ips[device_id] = device_ip
+            print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
+    hccn_table = {}
+    # platform.processor() may return an empty string; fall back to machine() before the lookup
+    arch = platform.processor() or platform.machine()
+    hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch]
+    hccn_table['chip_info'] = '910'
+    hccn_table['deploy_mode'] = 'lab'
+    hccn_table['group_count'] = '1'
+    hccn_table['group_list'] = []
+    instance_list = []
+    usable_dev = ''
+    for instance_id in range(args.nproc_per_node):
+        instance = {}
+        instance['devices'] = []
+        device_id = visible_devices[instance_id]
+        device_ip = device_ips[device_id]
+        usable_dev += str(device_id)
+        instance['devices'].append({
+            'device_id': device_id,
+            'device_ip': device_ip,
+        })
+        instance['rank_id'] = str(instance_id)
+        instance['server_id'] = args.server_id
+        instance_list.append(instance)
+    hccn_table['group_list'].append({
+        'device_num': str(args.nproc_per_node),
+        'server_num': '1',
+        'group_name': '',
+        'instance_count': str(args.nproc_per_node),
+        'instance_list': instance_list,
+    })
+    hccn_table['para_plane_nic_location'] = 'device'
+    hccn_table['para_plane_nic_name'] = []
+    for instance_id in range(args.nproc_per_node):
+        eth_id = visible_devices[instance_id]
+        hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
+    hccn_table['para_plane_nic_num'] = str(args.nproc_per_node)
+    hccn_table['status'] = 'completed'
+
+    # save hccn_table to file
+    table_path = os.getcwd()
+    if not os.path.exists(table_path):
+        os.mkdir(table_path)
+    table_fn = os.path.join(table_path,
+                            'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id))
+    with open(table_fn, 'w') as table_fp:
+        json.dump(hccn_table, table_fp, indent=4)
+    sys.stdout.flush()
+
+    # spawn the processes
+    processes = []
+    cmds = []
+    log_files = []
+    env = os.environ.copy()
+    env['RANK_SIZE'] = str(args.nproc_per_node)
+    cur_path = os.getcwd()
+    for rank_id in range(0, args.nproc_per_node):
+        os.chdir(cur_path)
+        device_id = visible_devices[rank_id]
+        device_dir = os.path.join(cur_path, 'device{}'.format(rank_id))
+        env['RANK_ID'] = str(rank_id)
+        env['DEVICE_ID'] = str(device_id)
+        if args.nproc_per_node > 1:
+            env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn
+            env['RANK_TABLE_FILE'] = table_fn
+        # every rank starts from an empty working directory
+        if os.path.exists(device_dir):
+            shutil.rmtree(device_dir)
+        os.mkdir(device_dir)
+        os.chdir(device_dir)
+        cmd = [sys.executable, '-u']
+        cmd.append(args.training_script)
+        cmd.extend(args.training_script_args)
+        log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
+        process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
+        processes.append(process)
+        cmds.append(cmd)
+        log_files.append(log_file)
+    for process, cmd, log_file in zip(processes, cmds, log_files):
+        try:
+            process.wait()
+        finally:
+            # close our handle even when the child failed or wait() was interrupted
+            log_file.close()
+        if process.returncode != 0:
+            # CalledProcessError expects the integer exit status, not the Popen object
+            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/model_zoo/mobilenetv2_quant/src/lr_generator.py b/model_zoo/mobilenetv2_quant/src/lr_generator.py
new file mode 100644
index 000000000..68bbfe315
--- /dev/null
+++ b/model_zoo/mobilenetv2_quant/src/lr_generator.py
@@ -0,0 +1,54 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate generator"""
+import math
+import numpy as np
+
+
+def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
+    """
+    generate learning rate array: linear warmup followed by cosine decay
+
+    Args:
+       global_step(int): step to resume from; the returned array starts at this step
+       lr_init(float): init learning rate (start of the warmup)
+       lr_end(float): end learning rate (floor of the cosine decay)
+       lr_max(float): max learning rate (reached at the end of the warmup)
+       warmup_epochs(int): number of warmup epochs
+       total_epochs(int): total epoch of training
+       steps_per_epoch(int): steps of one epoch
+
+    Returns:
+       np.array, float32 learning rate for every step from global_step to the last step
+    """
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    warmup_steps = steps_per_epoch * warmup_epochs
+    for i in range(total_steps):
+        if i < warmup_steps:
+            # linear ramp from lr_init towards lr_max
+            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
+        else:
+            # half-cosine from lr_max down to lr_end over the remaining steps
+            lr = lr_end + \
+                 (lr_max - lr_end) * \
+                 (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
+        if lr < 0.0:
+            lr = 0.0
+        lr_each_step.append(lr)
+
+    current_step = global_step
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+    # drop the steps that were already trained
+    learning_rate = lr_each_step[current_step:]
+
+    return learning_rate
diff --git a/model_zoo/mobilenetv2_quant/src/mobilenetV2.py b/model_zoo/mobilenetv2_quant/src/mobilenetV2.py
new file mode 100644
index 000000000..25dccfed1
--- /dev/null
+++ b/model_zoo/mobilenetv2_quant/src/mobilenetV2.py
@@ -0,0 +1,231 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MobileNetV2 Quant model define"""
+
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore import Tensor
+
+__all__ = ['mobilenetV2']
+
+
+def _make_divisible(v, divisor, min_value=None):
+    """
+    Round v to the nearest multiple of divisor, never below min_value.
+
+    Args:
+        v (float): value to round.
+        divisor (int): the result is a multiple of this value.
+        min_value (int): lower bound of the result. Default: None, meaning divisor.
+
+    Returns:
+        int, the rounded value; bumped up by one divisor if rounding lost more than 10% of v.
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class GlobalAvgPooling(nn.Cell):
+    """
+    Global avg pooling definition.
+
+    Args:
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> GlobalAvgPooling()
+    """
+
+    def __init__(self):
+        super(GlobalAvgPooling, self).__init__()
+        self.mean = P.ReduceMean(keep_dims=False)
+
+    def construct(self, x):
+        # reduce over axes 2 and 3; keep_dims=False drops them from the output
+        x = self.mean(x, (2, 3))
+        return x
+
+
+class ConvBNReLU(nn.Cell):
+    """
+    Convolution/Depthwise fused with Batchnorm and ReLU block definition.
+
+    Args:
+        in_planes (int): Input channel.
+        out_planes (int): Output channel.
+        kernel_size (int): Input kernel size.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        groups (int): channel group. Convolution is 1 while Depthwise is input channel. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
+    """
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        super(ConvBNReLU, self).__init__()
+        # explicit padding of (kernel_size - 1) // 2 on each side
+        padding = (kernel_size - 1) // 2
+        self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size,
+                                   stride=stride,
+                                   pad_mode='pad',
+                                   padding=padding,
+                                   group=groups,
+                                   has_bn=True,
+                                   activation='relu')
+
+    def construct(self, x):
+        x = self.conv(x)
+        return x
+
+
+class InvertedResidual(nn.Cell):
+    """
+    Mobilenetv2 residual block definition.
+
+    Args:
+        inp (int): Input channel.
+        oup (int): Output channel.
+        stride (int): Stride size of the depthwise convolution, must be 1 or 2.
+        expand_ratio (int): expand ratio of input channel
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> InvertedResidual(3, 256, 1, 1)
+    """
+
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        # the shortcut is only added when input and output shapes match
+        self.use_res_connect = stride == 1 and inp == oup
+
+        layers = []
+        if expand_ratio != 1:
+            # pw expansion, skipped when there is nothing to expand
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            # dw
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            # pw-linear
+            nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
+        ])
+        self.conv = nn.SequentialCell(layers)
+        self.add = P.TensorAdd()
+
+    def construct(self, x):
+        out = self.conv(x)
+        if self.use_res_connect:
+            out = self.add(out, x)
+        return out
+
+
+class mobilenetV2(nn.Cell):
+    """
+    mobilenetV2 fusion architecture.
+
+    Args:
+        num_classes (int): number of classes. Default is 1000.
+        width_mult (float): Channels multiplier for round to 8/16 and others. Default is 1.
+        has_dropout (bool): Is dropout used. Default is false
+        inverted_residual_setting (list): Inverted residual settings, rows of [t, c, n, s]. Default is None
+        round_nearest (int): Channel counts are rounded to a multiple of this value. Default is 8
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> mobilenetV2(num_classes=1000)
+    """
+
+    def __init__(self, num_classes=1000, width_mult=1.,
+                 has_dropout=False, inverted_residual_setting=None, round_nearest=8):
+        super(mobilenetV2, self).__init__()
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        # setting of inverted residual blocks
+        self.cfgs = inverted_residual_setting
+        if inverted_residual_setting is None:
+            self.cfgs = [
+                # t, c, n, s
+                [1, 16, 1, 1],
+                [6, 24, 2, 2],
+                [6, 32, 3, 2],
+                [6, 64, 4, 2],
+                [6, 96, 3, 1],
+                [6, 160, 3, 2],
+                [6, 320, 1, 1],
+            ]
+
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+
+        features = [ConvBNReLU(3, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in self.cfgs:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                # only the first block of each stage uses the configured stride
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+        # building last several layers
+        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
+        # make it nn.CellList
+        self.features = nn.SequentialCell(features)
+        # mobilenet head
+        head = ([GlobalAvgPooling(),
+                 nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
+                 ] if not has_dropout else
+                [GlobalAvgPooling(),
+                 nn.Dropout(0.2),
+                 nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
+                 ])
+        self.head = nn.SequentialCell(head)
+
+        # init weights
+        self._initialize_weights()
+
+    def construct(self, x):
+        x = self.features(x)
+        x = self.head(x)
+        return x
+
+    def _initialize_weights(self):
+        """
+        Initialize weights of every Conv2d, BatchNorm2d and Dense cell found in the network.
+
+        Args:
+
+        Returns:
+            None.
+
+        Examples:
+            >>> _initialize_weights()
+        """
+        for _, m in self.cells_and_names():
+            if isinstance(m, nn.Conv2d):
+                # normal(0, sqrt(2 / n)) with n = kernel_h * kernel_w * out_channels
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32"))
+                m.weight.set_parameter_data(w)
+                if m.bias is not None:
+                    m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
+                m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
+            elif isinstance(m, nn.Dense):
+                m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
+                if m.bias is not None:
+                    m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
diff --git a/model_zoo/mobilenetv2_quant/src/utils.py b/model_zoo/mobilenetv2_quant/src/utils.py
new file mode 100644
index 000000000..8690d9c38
--- /dev/null
+++ b/model_zoo/mobilenetv2_quant/src/utils.py
@@ -0,0 +1,113 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MobileNetV2 utils"""
+
+import time
+import numpy as np
+
+from mindspore.train.callback import Callback
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.nn.loss.loss import _Loss
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+
+
+class Monitor(Callback):
+    """
+    Monitor loss and time.
+
+    Prints loss, step time and learning rate after every step, and the epoch
+    time, per-step time and average loss after every epoch.
+
+    Args:
+        lr_init (numpy array): per-step learning rates, indexed by `cur_step_num - 1`.
+            Default: None, in which case the learning rate is printed as nan.
+
+    Returns:
+        None
+
+    Examples:
+        >>> Monitor(lr_init=Tensor([0.05]*100).asnumpy())
+    """
+
+    def __init__(self, lr_init=None):
+        super(Monitor, self).__init__()
+        self.lr_init = lr_init
+        # the documented default (None) must not crash on len()
+        self.lr_init_len = 0 if lr_init is None else len(lr_init)
+
+    def epoch_begin(self, run_context):
+        self.losses = []
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / cb_params.batch_num
+        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
+                                                                                      per_step_mseconds,
+                                                                                      np.mean(self.losses)))
+
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        step_mseconds = (time.time() - self.step_time) * 1000
+        step_loss = cb_params.net_outputs
+
+        # the network may return (loss, ...) or a bare loss tensor; reduce either to a scalar
+        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
+            step_loss = step_loss[0]
+        if isinstance(step_loss, Tensor):
+            step_loss = np.mean(step_loss.asnumpy())
+
+        self.losses.append(step_loss)
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
+
+        # report nan instead of raising when no schedule was given or the step runs past its end
+        lr_index = cb_params.cur_step_num - 1
+        cur_lr = self.lr_init[lr_index] if lr_index < self.lr_init_len else float('nan')
+
+        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
+            cb_params.cur_epoch_num -
+            1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
+            np.mean(self.losses), step_mseconds, cur_lr))
+
+
+class CrossEntropyWithLabelSmooth(_Loss):
+    """
+    CrossEntropyWith LabelSmooth.
+
+    Labels are one-hot encoded with `1 - smooth_factor` as the on value and
+    `smooth_factor / (num_classes - 1)` as the off value before the softmax
+    cross entropy is computed.
+
+    Args:
+        smooth_factor (float): smooth factor, default=0.
+        num_classes (int): num classes
+
+    Returns:
+        None.
+
+    Examples:
+        >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
+    """
+
+    def __init__(self, smooth_factor=0., num_classes=1000):
+        super(CrossEntropyWithLabelSmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor /
+                                (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean(False)
+        self.cast = P.Cast()
+
+    def construct(self, logit, label):
+        # the one-hot depth is taken from dimension 1 of the logits
+        one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
+                                    self.on_value, self.off_value)
+        out_loss = self.ce(logit, one_hot_label)
+        out_loss = self.mean(out_loss, 0)
+        return out_loss
diff --git a/model_zoo/mobilenetv2_quant/train.py b/model_zoo/mobilenetv2_quant/train.py
new file mode 100644
index 000000000..1302c3cf2
--- /dev/null
+++ b/model_zoo/mobilenetv2_quant/train.py
@@ -0,0 +1,131 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Train mobilenetV2 on ImageNet"""
+
+import os
+import argparse
+import random
+import numpy as np
+
+from mindspore import context
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.train.model import Model, ParallelMode
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.communication.management import init
+from mindspore.train.quant import quant
+import mindspore.dataset.engine as de
+
+from src.dataset import create_dataset
+from src.lr_generator import get_lr
+from src.utils import Monitor, CrossEntropyWithLabelSmooth
+from src.config import config_ascend, config_ascend_quant
+from src.mobilenetV2 import mobilenetV2
+
+random.seed(1)
+np.random.seed(1)
+de.config.set_seed(1)
+
+
+def _str2bool(value):
+    """
+    Convert a command-line string to bool.
+
+    argparse's `type=bool` calls bool() on the raw string, so any non-empty
+    value - including "False" - becomes True. This converter parses the text.
+
+    Args:
+        value (str): text given on the command line.
+
+    Returns:
+        bool, the parsed flag; an empty string is False.
+
+    Raises:
+        argparse.ArgumentTypeError: if the text is not a recognised boolean.
+    """
+    if isinstance(value, bool):
+        return value
+    lowered = value.strip().lower()
+    if lowered in ('true', 't', 'yes', 'y', '1'):
+        return True
+    if lowered in ('false', 'f', 'no', 'n', '0', ''):
+        return False
+    raise argparse.ArgumentTypeError("Boolean value expected, got {!r}".format(value))
+
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
+parser.add_argument('--device_target', type=str, default=None, help='Run device target')
+parser.add_argument('--quantization_aware', type=_str2bool, default=False, help='Use quantization aware training')
+args_opt = parser.parse_args()
+
+if args_opt.device_target == "Ascend":
+    # DEVICE_ID / RANK_ID / RANK_SIZE are read from the environment
+    device_id = int(os.getenv('DEVICE_ID'))
+    rank_id = int(os.getenv('RANK_ID'))
+    rank_size = int(os.getenv('RANK_SIZE'))
+    run_distribute = rank_size > 1
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target="Ascend",
+                        device_id=device_id, save_graphs=False)
+else:
+    raise ValueError("Unsupported device target.")
+
+if __name__ == '__main__':
+    # train on ascend
+    config = config_ascend_quant if args_opt.quantization_aware else config_ascend
+    print("training args: {}".format(args_opt))
+    print("training configure: {}".format(config))
+    print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
+    epoch_size = config.epoch_size
+
+    # distribute init
+    if run_distribute:
+        context.set_auto_parallel_context(device_num=rank_size,
+                                          parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          parameter_broadcast=True,
+                                          mirror_mean=True)
+        init()
+
+    # define network
+    network = mobilenetV2(num_classes=config.num_classes)
+    # define loss
+    if config.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes)
+    else:
+        loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
+    # define dataset
+    dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                             do_train=True,
+                             config=config,
+                             device_target=args_opt.device_target,
+                             repeat_num=epoch_size,
+                             batch_size=config.batch_size)
+    step_size = dataset.get_dataset_size()
+    # load pre trained ckpt
+    if args_opt.pre_trained:
+        param_dict = load_checkpoint(args_opt.pre_trained)
+        load_param_into_net(network, param_dict)
+
+    # convert fusion network to quantization aware network
+    # NOTE(review): reads the flag from config, not from args_opt - confirm both configs define quantization_aware
+    if config.quantization_aware:
+        network = quant.convert_quant_network(network,
+                                              bn_fold=True,
+                                              per_channel=[True, False],
+                                              symmetric=[True, False])
+
+    # get learning rate
+    lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
+                       lr_init=0,
+                       lr_end=0,
+                       lr_max=config.lr,
+                       warmup_epochs=config.warmup_epochs,
+                       total_epochs=epoch_size + config.start_epoch,
+                       steps_per_epoch=step_size))
+
+    # define optimization
+    opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum,
+                      config.weight_decay)
+    # define model
+    model = Model(network, loss_fn=loss, optimizer=opt)
+
+    print("============== Starting Training ==============")
+    # only rank 0 logs progress and saves checkpoints
+    callback = None
+    if rank_id == 0:
+        callback = [Monitor(lr_init=lr.asnumpy())]
+        if config.save_checkpoint:
+            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
+                                         keep_checkpoint_max=config.keep_checkpoint_max)
+            ckpt_cb = ModelCheckpoint(prefix="mobilenetV2",
+                                      directory=config.save_checkpoint_path,
+                                      config=config_ck)
+            callback += [ckpt_cb]
+    model.train(epoch_size, dataset, callbacks=callback)
+    print("============== End Training ==============")
diff --git a/model_zoo/mobilenetv3/scripts/run_train.sh b/model_zoo/mobilenetv3/scripts/run_train.sh
index 78b79b235..4912c895d 100644
--- a/model_zoo/mobilenetv3/scripts/run_train.sh
+++ b/model_zoo/mobilenetv3/scripts/run_train.sh
@@ -29,7 +29,7 @@ run_ascend()
     BASEPATH=$(cd "`dirname $0`" || exit; pwd)
     export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-    if [ -d "train" ];
+    if [ -d "../train" ];
     then
         rm -rf ../train
     fi
diff --git a/model_zoo/resnet50_quant/Readme.md b/model_zoo/resnet50_quant/Readme.md
new file mode 100644
index 000000000..9e843b222
--- /dev/null
+++ b/model_zoo/resnet50_quant/Readme.md
@@ -0,0 +1,122 @@
+# ResNet-50_quant Example
+
+## Description
+
+This is an example of training ResNet-50_quant with ImageNet2012 dataset in MindSpore.
+
+## Requirements
+
+- Install [MindSpore](https://www.mindspore.cn/install/en).
+
+- Download the dataset ImageNet2012
+
+> Unzip the ImageNet2012 dataset to any path you want and the folder structure should include train and eval dataset as follows:
+> ```
+> .
+> ├── ilsvrc                  # train dataset
+> └── ilsvrc_eval             # infer dataset: images should be classified into 1000 directories firstly, just like train images
+> ```
+
+
+## Example structure
+
+```shell
+.
+├── Resnet50_quant
+  ├── Readme.md
+  ├── scripts
+  │   ├──run_train.sh
+  │   ├──run_infer.sh
+  ├── src
+  │   ├──config.py
+  │   ├──crossentropy.py
+  │   ├──dataset.py
+  │   ├──launch.py
+  │   ├──lr_generator.py
+  │   ├──utils.py
+  ├── models
+  │   ├──resnet_quant.py
+  ├── train.py
+  ├── eval.py
+```
+
+
+## Parameter configuration
+
+Parameters for both training and inference can be set in config.py.
+
+```
+"class_num": 1001,                # dataset class number
+"batch_size": 32,                 # batch size of input tensor
+"loss_scale": 1024,               # loss scale
+"momentum": 0.9,                  # momentum optimizer
+"weight_decay": 1e-4,             # weight decay
+"epoch_size": 120,                # only valid for training, which is always 1 for inference
+"pretrained_epoch_size": 90,      # epoch size that model has been trained before load pretrained checkpoint
+"buffer_size": 1000,              # number of queue size in data preprocessing
+"image_height": 224,              # image height
+"image_width": 224,               # image width
+"save_checkpoint": True,          # whether save checkpoint or not
+"save_checkpoint_epochs": 1,      # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch
+"keep_checkpoint_max": 50,        # only keep the last keep_checkpoint_max checkpoint
+"save_checkpoint_path": "./",     # path to save checkpoint relative to the executed path
+"warmup_epochs": 0,               # number of warmup epoch
+"lr_decay_mode": "cosine",        # decay mode for generating learning rate
+"label_smooth": True,             # label smooth
+"label_smooth_factor": 0.1,       # label smooth factor
+"lr_init": 0,                     # initial learning rate
+"lr_max": 0.005,                  # maximum learning rate
+```
+
+## Running the example
+
+### Train
+
+### Usage
+
+- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]
+
+
+### Launch
+
+```
+# training example
+    Ascend: sh run_train.sh Ascend 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet/train/
+```
+
+### Result
+
+Training result will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and training log will be redirected to `./train/train.log` as follows.
+
+```
+epoch: 1 step: 5004, loss is 4.8995576
+epoch: 2 step: 5004, loss is 3.9235563
+epoch: 3 step: 5004, loss is 3.833077
+epoch: 4 step: 5004, loss is 3.2795618
+epoch: 5 step: 5004, loss is 3.1978393
+```
+
+## Eval process
+
+### Usage
+
+- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
+
+### Launch
+
+```
+# infer example
+    Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/checkpoint/resnet50-110_5004.ckpt
+```
+
+
+> checkpoint can be produced in training process.
+
+#### Result
+
+Inference result will be stored in the example path, whose folder name is "infer". Under this, you can find a result like the following in the log.
+
+```
+result: {'acc': 0.75252054737516005} ckpt=train_parallel0/resnet-110_5004.ckpt
+```
+
diff --git a/model_zoo/resnet50_quant/eval.py b/model_zoo/resnet50_quant/eval.py
new file mode 100755
index 000000000..481e4bb85
--- /dev/null
+++ b/model_zoo/resnet50_quant/eval.py
@@ -0,0 +1,78 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluate Resnet50 on ImageNet"""
+
+import os
+import argparse
+
+from src.config import quant_set, config_quant, config_noquant
+from src.dataset import create_dataset
+from src.crossentropy import CrossEntropy
+from src.utils import _load_param_into_net
+from models.resnet_quant import resnet50_quant
+
+from mindspore import context
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint
+from mindspore.train.quant import quant
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
+args_opt = parser.parse_args()
+
+context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
+config = config_quant if quant_set.quantization_aware else config_noquant
+
+if args_opt.device_target == "Ascend":
+    device_id = int(os.getenv('DEVICE_ID'))
+    context.set_context(device_id=device_id)
+
+if __name__ == '__main__':
+    # define fusion network
+    net = resnet50_quant(class_num=config.class_num)
+    if quant_set.quantization_aware:
+        # convert fusion network to quantization aware network
+        net = quant.convert_quant_network(net,
+                                          bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])
+    # define network loss
+    if not config.use_label_smooth:
+        config.label_smooth_factor = 0.0
+    loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
+                        num_classes=config.class_num)
+
+    # define dataset
+    dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                             do_train=False,
+                             batch_size=config.batch_size,
+                             target=args_opt.device_target)
+    step_size = dataset.get_dataset_size()
+
+    # load checkpoint
+    if args_opt.checkpoint_path:
+        param_dict = load_checkpoint(args_opt.checkpoint_path)
+        _load_param_into_net(net, param_dict)
+    net.set_train(False)
+
+    # define model
+    model = Model(net, loss_fn=loss, metrics={'acc'})
+
+    print("============== Starting Validation ==============")
+    res = model.eval(dataset)
+    print("result:", res, "ckpt=", args_opt.checkpoint_path)
+    print("============== End Validation ==============")
diff --git a/model_zoo/resnet50_quant/models/resnet_quant.py b/model_zoo/resnet50_quant/models/resnet_quant.py
new file mode 100755
index 000000000..63fa32222
--- /dev/null
+++ b/model_zoo/resnet50_quant/models/resnet_quant.py
@@ -0,0 +1,251 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ResNet."""
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+
+
+class ConvBNReLU(nn.Cell):
+    """
+    Convolution/Depthwise fused with Batchnorm and ReLU block definition.
+
+    Args:
+        in_planes (int): Input channel.
+        out_planes (int): Output channel.
+        kernel_size (int): Input kernel size.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        groups (int): channel group. Convolution is 1 while Depthwise is input channel. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
+    """
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        super(ConvBNReLU, self).__init__()
+        # explicit padding of (kernel_size - 1) // 2 on each side
+        padding = (kernel_size - 1) // 2
+        conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
+                              group=groups, has_bn=True, activation='relu')
+        self.features = conv
+
+    def construct(self, x):
+        output = self.features(x)
+        return output
+
+
+class ResidualBlock(nn.Cell):
+    """
+    ResNet V1 residual block definition.
+
+    Args:
+        in_channel (int): Input channel.
+        out_channel (int): Output channel.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ResidualBlock(3, 256, stride=2)
+    """
+    expansion = 4
+
+    def __init__(self,
+                 in_channel,
+                 out_channel,
+                 stride=1):
+        super(ResidualBlock, self).__init__()
+
+        # bottleneck width is out_channel / expansion
+        channel = out_channel // self.expansion
+        self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
+        self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
+        # NOTE(review): conv3 and the down-sample branch both apply ReLU before the residual add - confirm intended
+        self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', padding=0,
+                                    has_bn=True, activation='relu')
+
+        # the shortcut needs a projection whenever the block changes resolution or channel count
+        self.down_sample = False
+        if stride != 1 or in_channel != out_channel:
+            self.down_sample = True
+        self.down_sample_layer = None
+
+        if self.down_sample:
+            self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
+                                                    kernel_size=1, stride=stride,
+                                                    pad_mode='same', padding=0, has_bn=True, activation='relu')
+        self.add = P.TensorAdd()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.down_sample:
+            identity = self.down_sample_layer(identity)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Cell):
+    """
+    ResNet architecture.
+
+    Args:
+        block (Cell): Block for network.
+        layer_nums (list): Numbers of block in different layers.
+        in_channels (list): Input channel in each layer.
+        out_channels (list): Output channel in each layer.
+        strides (list): Stride size in each layer.
+        num_classes (int): The number of classes that the training images are belonging to.
+    Returns:
+        Tensor, output tensor.
+
+    Raises:
+        ValueError: if layer_nums, in_channels and out_channels do not all have length 4.
+
+    Examples:
+        >>> ResNet(ResidualBlock,
+        >>>        [3, 4, 6, 3],
+        >>>        [64, 256, 512, 1024],
+        >>>        [256, 512, 1024, 2048],
+        >>>        [1, 2, 2, 2],
+        >>>        10)
+    """
+
+    def __init__(self,
+                 block,
+                 layer_nums,
+                 in_channels,
+                 out_channels,
+                 strides,
+                 num_classes):
+        super(ResNet, self).__init__()
+
+        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
+            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
+
+        self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+
+        self.layer1 = self._make_layer(block,
+                                       layer_nums[0],
+                                       in_channel=in_channels[0],
+                                       out_channel=out_channels[0],
+                                       stride=strides[0])
+        self.layer2 = self._make_layer(block,
+                                       layer_nums[1],
+                                       in_channel=in_channels[1],
+                                       out_channel=out_channels[1],
+                                       stride=strides[1])
+        self.layer3 = self._make_layer(block,
+                                       layer_nums[2],
+                                       in_channel=in_channels[2],
+                                       out_channel=out_channels[2],
+                                       stride=strides[2])
+        self.layer4 = self._make_layer(block,
+                                       layer_nums[3],
+                                       in_channel=in_channels[3],
+                                       out_channel=out_channels[3],
+                                       stride=strides[3])
+
+        self.mean = P.ReduceMean(keep_dims=True)
+        self.flatten = nn.Flatten()
+        self.end_point = nn.DenseBnAct(out_channels[3], num_classes, has_bias=True, has_bn=False)
+
+    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
+        """
+        Make stage network of ResNet.
+
+        Args:
+            block (Cell): Resnet block.
+            layer_num (int): Layer number.
+            in_channel (int): Input channel.
+            out_channel (int): Output channel.
+            stride (int): Stride size for the first convolutional layer.
+
+        Returns:
+            SequentialCell, the output layer.
+
+        Examples:
+            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
+        """
+        layers = []
+
+        # only the first block of a stage changes channel count and stride
+        resnet_block = block(in_channel, out_channel, stride=stride)
+        layers.append(resnet_block)
+
+        for _ in range(1, layer_num):
+            resnet_block = block(out_channel, out_channel, stride=1)
+            layers.append(resnet_block)
+
+        return nn.SequentialCell(layers)
+
+    def construct(self, x):
+        x = self.conv1(x)
+        c1 = self.maxpool(x)
+
+        c2 = self.layer1(c1)
+        c3 = self.layer2(c2)
+        c4 = self.layer3(c3)
+        c5 = self.layer4(c4)
+
+        # average over axes 2 and 3 (kept as size-1 dims), then flatten for the dense layer
+        out = self.mean(c5, (2, 3))
+        out = self.flatten(out)
+        out = self.end_point(out)
+        return out
+
+
+def resnet50_quant(class_num=1001):
+    """
+    Get ResNet50 neural network.
+
+    Args:
+        class_num (int): Class number. Default: 1001, the same default as resnet101_quant.
+
+    Returns:
+        Cell, cell instance of ResNet50 neural network.
+
+    Examples:
+        >>> net = resnet50_quant(10)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 6, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num)
+
+
+def resnet101_quant(class_num=1001):
+    """
+    Get ResNet101 neural network.
+
+    Args:
+        class_num (int): Class number. Default: 1001.
+
+    Returns:
+        Cell, cell instance of ResNet101 neural network.
+
+    Examples:
+        >>> net = resnet101_quant(1001)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 23, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num)
diff --git a/model_zoo/resnet50_quant/scripts/run_infer.sh b/model_zoo/resnet50_quant/scripts/run_infer.sh
new file mode 100644
index 000000000..1d74f6373
--- /dev/null
+++ b/model_zoo/resnet50_quant/scripts/run_infer.sh
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
#!/usr/bin/env bash
# Launch a single-device evaluation of the ResNet-50 quantization network.
#
# Usage: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]
#
# The evaluation runs in the background inside ../eval (recreated on every
# invocation) and writes its output to ../eval/infer.log.

if [ $# != 3 ]
then
    echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# check dataset path (quoted so that paths containing spaces are tested as one word)
if [ ! -d "$2" ] && [ ! -f "$2" ]
then
    echo "error: DATASET_PATH=$2 is not a directory or file"
    exit 1
fi

# check checkpoint file
if [ ! -f "$3" ]
then
    echo "error: CHECKPOINT_PATH=$3 is not a file"
    exit 1
fi

# set environment
BASEPATH=$(cd "$(dirname "$0")" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "../eval" ];
then
    rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

# launch
# "> file 2>&1" is used instead of "&>" because the script is documented to be
# run with "sh", where "&>" is parsed as "run in background" followed by a redirect.
python "${BASEPATH}/../eval.py" \
    --device_target="$1" \
    --dataset_path="$2" \
    --checkpoint_path="$3" \
    > infer.log 2>&1 &  # dataset val folder path
#!/usr/bin/env bash
# Launch (multi-)device training of the ResNet-50 quantization network on Ascend.
#
# Usage:
#   sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]
#
# CKPT_PATH is optional. Training runs in the background inside ../train
# (recreated on every invocation) and writes its output to ../train/train.log.

run_ascend()
{
    # DEVICE_NUM must lie in [1, 8]. A value cannot be both <1 and >8, so the
    # two range tests have to be OR-ed for the check to ever trigger.
    if [ "$2" -lt 1 ] || [ "$2" -gt 8 ]
    then
        echo "error: DEVICE_NUM=$2 is not in (1-8)"
        exit 1
    fi

    if [ ! -d "$5" ] && [ ! -f "$5" ]
    then
        echo "error: DATASET_PATH=$5 is not a directory or file"
        exit 1
    fi

    BASEPATH=$(cd "$(dirname "$0")" || exit; pwd)
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    if [ -d "../train" ];
    then
        rm -rf ../train
    fi
    mkdir ../train
    cd ../train || exit
    # "> file 2>&1" instead of "&>" so that the redirect also works under plain sh.
    python "${BASEPATH}/../src/launch.py" \
        --nproc_per_node="$2" \
        --visible_devices="$4" \
        --server_id="$3" \
        --training_script="${BASEPATH}/../train.py" \
        --dataset_path="$5" \
        --pre_trained="$6" \
        --device_target="$1" > train.log 2>&1 &  # dataset train folder
}

# DATASET_PATH ($5) is mandatory, so at least 5 arguments are needed; only CKPT_PATH ($6) is optional.
if [ $# -gt 6 ] || [ $# -lt 5 ]
then
    # printf interprets "\n" consistently across shells, unlike echo.
    printf 'Usage:\n  Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n'
    exit 1
fi

if [ "$1" = "Ascend" ] ; then
    run_ascend "$@"
else
    echo "not support platform"
fi;
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

# Switch selecting which of the two configurations below the scripts use.
quant_set = ed({
    "quantization_aware": True,
})

# Settings of the plain (non-quantized) training run. The quantization-aware
# run reuses every entry except the three overridden further down.
_base_settings = {
    "class_num": 1001,
    "batch_size": 32,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 90,
    "pretrained_epoch_size": 1,
    "buffer_size": 1000,
    "image_height": 224,
    "image_width": 224,
    "data_load_mode": "mindrecord",
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 50,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0,
    "lr_max": 0.1,
}

config_noquant = ed(dict(_base_settings))

# Overriding existing keys keeps their original position, so key order matches config_noquant.
config_quant = ed({
    **_base_settings,
    "epoch_size": 120,
    "pretrained_epoch_size": 90,
    "lr_max": 0.005,
})
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn


class CrossEntropy(_Loss):
    """
    Label-smoothed cross entropy built on SoftmaxCrossEntropyWithLogits.

    Args:
        smooth_factor (float): Label smoothing factor. Default: 0.
        num_classes (int): Number of classes the smoothing mass is spread over. Default: 1001.
    """

    def __init__(self, smooth_factor=0, num_classes=1001):
        super(CrossEntropy, self).__init__()
        # The target class keeps 1 - smooth_factor; the remainder is shared
        # evenly between the other num_classes - 1 classes.
        target_weight = 1.0 - smooth_factor
        other_weight = 1.0 * smooth_factor / (num_classes - 1)
        self.onehot = P.OneHot()
        self.on_value = Tensor(target_weight, mstype.float32)
        self.off_value = Tensor(other_weight, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        # One-hot depth is taken from axis 1 of the logits.
        depth = F.shape(logit)[1]
        smoothed_label = self.onehot(label, depth, self.on_value, self.off_value)
        per_sample = self.ce(logit, smoothed_label)
        return self.mean(per_sample, 0)
"""
create train or eval dataset.
"""
import os
from functools import partial
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.py_transforms as P
from mindspore.communication.management import init, get_rank, get_group_size
from src.config import quant_set, config_quant, config_noquant

config = config_quant if quant_set.quantization_aware else config_noquant


def _get_rank_info(target):
    """
    Resolve the number of devices and the rank of the current process.

    Args:
        target(str): the device target.

    Returns:
        tuple, (device_num, rank_id).
    """
    if target == "Ascend":
        # Fall back to a single-device setup when the launcher did not export
        # the variables, instead of failing with int(None).
        device_num = int(os.getenv("RANK_SIZE", "1"))
        rank_id = int(os.getenv("RANK_ID", "0"))
    else:
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    return device_num, rank_id


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval dataset

    The loader is chosen by config.data_load_mode ("mindrecord" uses
    MindDataset, anything else ImageFolderDatasetV2); images are processed with
    the C transforms and labels are cast to int32.

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    device_num, rank_id = _get_rank_info(target)

    columns_list = ['image', 'label']
    if config.data_load_mode == "mindrecord":
        load_func = partial(de.MindDataset, dataset_path, columns_list)
    else:
        load_func = partial(de.ImageFolderDatasetV2, dataset_path)
    if device_num == 1:
        ds = load_func(num_parallel_workers=8, shuffle=True)
    else:
        # Each device reads its own shard of the data.
        ds = load_func(num_parallel_workers=8, shuffle=True,
                       num_shards=device_num, shard_id=rank_id)

    image_size = config.image_height
    # Mean/std are scaled by 255 because normalization is applied to the
    # un-rescaled output of the decode/resize ops below.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds


def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval dataset

    Variant of create_dataset that always reads an image folder and uses the
    Python transforms; the label column is left untouched.

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    device_num, rank_id = _get_rank_info(target)

    if do_train:
        if device_num == 1:
            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
        else:
            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
                                         num_shards=device_num, shard_id=rank_id)
    else:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)

    image_size = 224

    # define map operations
    decode_op = P.Decode()
    resize_crop_op = P.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333))
    horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)

    resize_op = P.Resize(256)
    center_crop = P.CenterCrop(image_size)
    to_tensor = P.ToTensor()
    normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # select the train or eval pipeline
    if do_train:
        trans = [decode_op, resize_crop_op, horizontal_flip_op, to_tensor, normalize_op]
    else:
        trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]

    compose = P.ComposeOp(trans)
    ds = ds.map(input_columns="image", operations=compose(), num_parallel_workers=8, python_multiprocessing=True)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
"""launch train script"""
import os
import sys
import json
import subprocess
import shutil
import platform
from argparse import ArgumentParser


def parse_args():
    """
    parse args .

    Options not known to the launcher are collected unchanged in
    ``args.training_script_args`` and later forwarded to the training script.

    Args:

    Returns:
        args.

    Examples:
        >>> parse_args()
    """
    parser = ArgumentParser(description="mindspore distributed training launch "
                                        "helper utilty that will spawn up "
                                        "multiple distributed processes")
    parser.add_argument("--nproc_per_node", type=int, default=1,
                        help="The number of processes to launch on each node, "
                             "for D training, this is recommended to be set "
                             "to the number of D in your system so that "
                             "each process can be bound to a single D.")
    parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
                        help="will use the visible devices sequentially")
    parser.add_argument("--server_id", type=str, default="",
                        help="server ip")
    parser.add_argument("--training_script", type=str,
                        help="The full path to the single D training "
                             "program/script to be launched in parallel, "
                             "followed by all the arguments for the "
                             "training script")
    # rest from the training program
    args, unknown = parser.parse_known_args()
    args.training_script_args = unknown
    return args


def _read_device_ips(conf_path='/etc/hccn.conf'):
    """Return {device_id: device_ip} parsed from the ``address_<id>=<ip>`` lines of conf_path."""
    device_ips = {}
    with open(conf_path, 'r') as conf_fp:
        for hccn_item in conf_fp:
            hccn_item = hccn_item.strip()
            if hccn_item.startswith('address_'):
                device_id, device_ip = hccn_item.split('=')
                device_id = device_id.split('_')[1]
                device_ips[device_id] = device_ip
                print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
    return device_ips


def _build_hccn_table(args, visible_devices, device_ips):
    """Build the rank table dict; return it with the concatenated ids of the devices used."""
    hccn_table = {}
    # platform.processor() can be an empty string; fall back to the machine type in that case.
    arch = platform.processor() or platform.machine()
    hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch]
    hccn_table['chip_info'] = '910'
    hccn_table['deploy_mode'] = 'lab'
    hccn_table['group_count'] = '1'
    hccn_table['group_list'] = []
    instance_list = []
    usable_dev = ''
    for instance_id in range(args.nproc_per_node):
        device_id = visible_devices[instance_id]
        device_ip = device_ips[device_id]
        usable_dev += str(device_id)
        instance_list.append({
            'devices': [{
                'device_id': device_id,
                'device_ip': device_ip,
            }],
            'rank_id': str(instance_id),
            'server_id': args.server_id,
        })
    hccn_table['group_list'].append({
        'device_num': str(args.nproc_per_node),
        'server_num': '1',
        'group_name': '',
        'instance_count': str(args.nproc_per_node),
        'instance_list': instance_list,
    })
    hccn_table['para_plane_nic_location'] = 'device'
    hccn_table['para_plane_nic_name'] = [
        'eth{}'.format(visible_devices[instance_id]) for instance_id in range(args.nproc_per_node)
    ]
    hccn_table['para_plane_nic_num'] = str(args.nproc_per_node)
    hccn_table['status'] = 'completed'
    return hccn_table, usable_dev


def _spawn_and_wait(args, visible_devices, table_fn):
    """Start one training process per rank in its own ``device<rank>`` directory and wait for all."""
    processes = []
    cmds = []
    log_files = []
    env = os.environ.copy()
    env['RANK_SIZE'] = str(args.nproc_per_node)
    cur_path = os.getcwd()
    try:
        for rank_id in range(args.nproc_per_node):
            device_id = visible_devices[rank_id]
            device_dir = os.path.join(cur_path, 'device{}'.format(rank_id))
            env['RANK_ID'] = str(rank_id)
            env['DEVICE_ID'] = str(device_id)
            if args.nproc_per_node > 1:
                env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn
                env['RANK_TABLE_FILE'] = table_fn
            # Every rank starts from a fresh working directory.
            if os.path.exists(device_dir):
                shutil.rmtree(device_dir)
            os.mkdir(device_dir)
            cmd = [sys.executable, '-u']
            cmd.append(args.training_script)
            cmd.extend(args.training_script_args)
            log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
            log_files.append(log_file)
            # The child runs inside its device directory.
            process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env, cwd=device_dir)
            processes.append(process)
            cmds.append(cmd)
        for process, cmd in zip(processes, cmds):
            process.wait()
            if process.returncode != 0:
                # returncode must be the integer exit status, not the Popen object.
                raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
    finally:
        # Close every log file, including when a rank failed.
        for log_file in log_files:
            log_file.close()


def main():
    """
    Write the rank table for the selected devices and launch the training processes.

    Raises:
        FileNotFoundError: If the training script does not exist.
        ValueError: If fewer devices are visible than processes are requested.
        subprocess.CalledProcessError: If a training process exits with a non-zero status.
    """
    print("start", __file__)
    args = parse_args()
    print(args)
    visible_devices = args.visible_devices.split(',')
    # Explicit checks instead of assert, which is stripped when running with -O.
    if not (args.training_script and os.path.isfile(args.training_script)):
        raise FileNotFoundError('training script not found: {}'.format(args.training_script))
    if len(visible_devices) < args.nproc_per_node:
        raise ValueError('nproc_per_node={} exceeds the number of visible devices {}'.format(
            args.nproc_per_node, visible_devices))
    print('visible_devices:{}'.format(visible_devices))
    if not args.server_id:
        print('pleaser input server ip!!!')
        # A missing server ip is an error, so report a non-zero exit status.
        sys.exit(1)
    print('server_id:{}'.format(args.server_id))

    # construct hccn_table
    device_ips = _read_device_ips()
    hccn_table, usable_dev = _build_hccn_table(args, visible_devices, device_ips)

    # save hccn_table to file
    table_fn = os.path.join(os.getcwd(),
                            'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id))
    with open(table_fn, 'w') as table_fp:
        json.dump(hccn_table, table_fp, indent=4)
    sys.stdout.flush()

    # spawn the processes
    _spawn_and_wait(args, visible_devices, table_fn)


if __name__ == "__main__":
    main()
"""learning rate generator"""
import math
import numpy as np


def _steps_lr(lr_max, total_steps):
    """Piecewise-constant schedule: lr_max scaled by 1, 0.1, 0.01, 0.001 at 30%/60%/80% of training."""
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr = lr_max
        elif i < decay_epoch_index[1]:
            lr = lr_max * 0.1
        elif i < decay_epoch_index[2]:
            lr = lr_max * 0.01
        else:
            lr = lr_max * 0.001
        lr_each_step.append(lr)
    return lr_each_step


def _poly_lr(lr_init, lr_max, warmup_steps, total_steps):
    """Linear warmup from lr_init followed by a quadratic decay of lr_max towards zero."""
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr = float(lr_max) * base * base
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)
    return lr_each_step


def _cosine_lr(lr_init, lr_max, warmup_steps, total_steps):
    """Linear warmup followed by the product of a linear and a cosine decay."""
    decay_steps = total_steps - warmup_steps
    # Loop invariant, so computed once; it is only used while i < warmup_steps,
    # which requires warmup_steps > 0.
    lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) if warmup_steps > 0 else 0.0
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            linear_decay = (total_steps - i) / decay_steps
            cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
            # The small offset keeps the learning rate strictly positive.
            decayed = linear_decay * cosine_decay + 0.00001
            lr = lr_max * decayed
        lr_each_step.append(lr)
    return lr_each_step


def _linear_lr(lr_init, lr_end, lr_max, warmup_steps, total_steps):
    """Linear warmup from lr_init to lr_max followed by a linear decay to lr_end."""
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr)
    return lr_each_step


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

    Args:
       lr_init(float): init learning rate
       lr_end(float): end learning rate (only used by the default linear mode)
       lr_max(float): max learning rate
       warmup_epochs(int): number of warmup epochs (ignored by the 'steps' mode)
       total_epochs(int): total epoch of training
       steps_per_epoch(int): steps of one epoch
       lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default

    Returns:
       np.array, float32 learning rate array with one entry per training step
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if lr_decay_mode == 'steps':
        lr_each_step = _steps_lr(lr_max, total_steps)
    elif lr_decay_mode == 'poly':
        lr_each_step = _poly_lr(lr_init, lr_max, warmup_steps, total_steps)
    elif lr_decay_mode == 'cosine':
        lr_each_step = _cosine_lr(lr_init, lr_max, warmup_steps, total_steps)
    else:
        # Any other mode name selects the linear schedule.
        lr_each_step = _linear_lr(lr_init, lr_end, lr_max, warmup_steps, total_steps)

    learning_rate = np.array(lr_each_step).astype(np.float32)

    return learning_rate


"""utils script"""


def _load_param_into_net(model, params_dict):
    """
    load fp32 model parameters to quantization model.

    Parameters are matched by position, not by full name: the n-th network
    parameter whose last name component is e.g. ``weight`` receives the n-th
    checkpoint entry whose name ends with ``weight``. Network parameters left
    over once the checkpoint entries of their kind are exhausted are skipped.

    Args:
        model: quantization model
        params_dict: f32 param

    Returns:
        None
    """
    suffixes = ('weight', 'bias', 'gamma', 'beta', 'moving_mean', 'moving_variance', 'minq', 'maxq')
    # Group the checkpoint entries in a single pass, keeping checkpoint order within each group.
    grouped = {suffix: [] for suffix in suffixes}
    for item in params_dict.items():
        for suffix in suffixes:
            if item[0].endswith(suffix):
                grouped[suffix].append(item)
    iterable_dict = {suffix: iter(items) for suffix, items in grouped.items()}

    for name, param in model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict:
            continue
        value_param = next(iterable_dict[key_name], None)
        if value_param is not None:
            param.set_parameter_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')
"""Train Resnet50 on ImageNet"""

import os
import argparse

from mindspore import context
from mindspore import Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant import quant
from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init

from models.resnet_quant import resnet50_quant
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import quant_set, config_quant, config_noquant
from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net


def _str2bool(value):
    """Parse a command line boolean; ``type=bool`` would turn any non-empty string into True."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got '{}'".format(value))


parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=_str2bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
args_opt = parser.parse_args()
config = config_quant if quant_set.quantization_aware else config_noquant

if args_opt.device_target == "Ascend":
    # Default to a single-device run when the launcher did not export the variables.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    rank_id = int(os.getenv('RANK_ID', '0'))
    rank_size = int(os.getenv('RANK_SIZE', '1'))
    # Distributed mode follows the environment, not the --run_distribute flag.
    run_distribute = rank_size > 1
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False,
                        device_id=device_id,
                        enable_auto_mixed_precision=True)
else:
    raise ValueError("Unsupported device target.")

if __name__ == '__main__':
    # train on ascend
    print("training args: {}".format(args_opt))
    print("training configure: {}".format(config))
    print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
    epoch_size = config.epoch_size

    # distribute init
    if run_distribute:
        # The parallel context is configured exactly once, from rank_size.
        # Configuring it a second time from --device_num (default 1) would
        # overwrite the real device count, because the launcher never passes that option.
        context.set_auto_parallel_context(device_num=rank_size,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          parameter_broadcast=True,
                                          mirror_mean=True)
        init()
        auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])

    # define network
    net = resnet50_quant(class_num=config.class_num)
    net.set_train(True)

    # weight init and load checkpoint file
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        _load_param_into_net(net, param_dict)
        # Only the epochs remaining after pre-training are run.
        epoch_size = config.epoch_size - config.pretrained_epoch_size
    else:
        for _, cell in net.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
                                                                    cell.weight.default_input.shape,
                                                                    cell.weight.default_input.dtype).to_tensor()
            if isinstance(cell, nn.Dense):
                cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
                                                                    cell.weight.default_input.shape,
                                                                    cell.weight.default_input.dtype).to_tensor()
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    # define dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             do_train=True,
                             repeat_num=epoch_size,
                             batch_size=config.batch_size,
                             target=args_opt.device_target)
    step_size = dataset.get_dataset_size()

    if quant_set.quantization_aware:
        # convert fusion network to quantization aware network
        net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])

    # get learning rate; the schedule always spans the full config.epoch_size
    lr = get_lr(lr_init=config.lr_init,
                lr_end=0.0,
                lr_max=config.lr_max,
                warmup_epochs=config.warmup_epochs,
                total_epochs=config.epoch_size,
                steps_per_epoch=step_size,
                lr_decay_mode=config.lr_decay_mode)
    if args_opt.pre_trained:
        # Drop the part of the schedule that belongs to the pre-training epochs.
        lr = lr[config.pretrained_epoch_size * step_size:]
    lr = Tensor(lr)

    # define optimization
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                   config.weight_decay, config.loss_scale)

    # define model
    if quant_set.quantization_aware:
        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
    else:
        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                      amp_level="O2")

    print("============== Starting Training ==============")
    time_callback = TimeMonitor(data_size=step_size)
    loss_callback = LossMonitor()
    callbacks = [time_callback, loss_callback]
    # Only rank 0 writes checkpoints.
    if rank_id == 0:
        if config.save_checkpoint:
            config_ckpt = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                           keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_callback = ModelCheckpoint(prefix="ResNet50",
                                            directory=config.save_checkpoint_path,
                                            config=config_ckpt)
            callbacks += [ckpt_callback]
    model.train(epoch_size, dataset, callbacks=callbacks)
    print("============== End Training ==============")