Commit fe5e781a authored by mindspore-ci-bot, committed by Gitee

!5345 Add new model_zoo net densenet

Merge pull request !5345 from TuDouNi/r0.5
# Contents
- [DenseNet121 Description](#densenet121-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training accuracy results](#training-accuracy-results)
        - [Training performance results](#training-performance-results)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DenseNet121 Description](#contents)
DenseNet121 is a convolution-based neural network for the task of image classification. The paper describing the model can be found [here](https://arxiv.org/abs/1608.06993). Huawei's DenseNet121 is an implementation on [MindSpore](https://www.mindspore.cn/).
The repository also contains scripts to launch training and inference routines.
# [Model Architecture](#contents)
DenseNet121 is built from 4 densely connected blocks. In every dense block, each layer obtains additional inputs from all preceding layers and passes on its own feature maps to all subsequent layers. Concatenation is used, so each layer receives the "collective knowledge" of all preceding layers.
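Formally, the `l`-th layer inside a dense block computes

```
x_l = H_l([x_0, x_1, ..., x_{l-1}])
```

where `[x_0, ..., x_{l-1}]` denotes the channel-wise concatenation of all preceding feature maps and `H_l` is the BN-ReLU-Conv composite function (implemented by `_DenseLayer` in `src/network/densenet.py`).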
# [Dataset](#contents)
Dataset used: ImageNet
The default configuration of the dataset is as follows (a minimal sketch of the corresponding transform pipelines is shown after this list):
- Training dataset preprocess:
    - Input size of images is 224\*224
    - Range (min, max) of the area fraction of the original image to be cropped is (0.08, 1.0)
    - Range (min, max) of the aspect ratio to be cropped is (0.75, 1.333)
    - Probability of the image being horizontally flipped is set to 0.5
    - Randomly adjust the brightness, contrast, and saturation (0.4, 0.4, 0.4)
    - Normalize the input image with respect to mean and standard deviation
- Test dataset preprocess:
    - Input size of images is 224\*224 (resize to 256\*256, then crop the center 224\*224)
    - Normalize the input image with respect to mean and standard deviation
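The following is a minimal sketch of these two transform pipelines, mirroring `src/datasets/classification.py` and assuming the MindSpore 0.5 `c_transforms` API that this repository targets:

```python
# Minimal sketch of the default preprocessing, mirroring src/datasets/classification.py
# (MindSpore 0.5 c_transforms API as used in this repository).
import mindspore.dataset.transforms.vision.c_transforms as vision_C

mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

train_transform = [
    vision_C.RandomCropDecodeResize((224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    vision_C.RandomHorizontalFlip(prob=0.5),
    vision_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
    vision_C.Normalize(mean=mean, std=std),
    vision_C.HWC2CHW()
]

eval_transform = [
    vision_C.Decode(),
    vision_C.Resize((256, 256)),
    vision_C.CenterCrop((224, 224)),
    vision_C.Normalize(mean=mean, std=std),
    vision_C.HWC2CHW()
]
```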
# [Features](#contents)
## Mixed Precision
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log level and then searching for 'reduce precision'.
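As a reference, the snippet below is a minimal, hypothetical sketch of the O3 mixed-precision setup used by `train.py` in this directory, with a toy `nn.Dense` network standing in for DenseNet121 (the stand-in network, learning rate, and momentum values are illustrative only):

```python
# Hypothetical sketch of O3 mixed precision with MindSpore's Model API,
# following the same pattern as train.py (toy network instead of DenseNet121).
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

net = nn.Dense(16, 10)                          # stand-in for DenseNet121
criterion = nn.SoftmaxCrossEntropyWithLogits()  # same loss primitive the repo's CrossEntropy wraps
criterion.add_flags_recursive(fp32=True)        # keep the loss computation in FP32, as train.py does
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
loss_scale_manager = FixedLossScaleManager(1024, drop_overflow_update=False)

# amp_level="O3" casts the network to FP16; the rest of training proceeds as in train.py.
model = Model(net, loss_fn=criterion, optimizer=opt,
              loss_scale_manager=loss_scale_manager, amp_level="O3")
```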
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend AI processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```bash
# run training example
python train.py --data_dir /PATH/TO/DATASET --is_distributed 0 > train.log 2>&1 &
# run distributed training example
sh scripts/run_distribute_train.sh 8 rank_table.json /PATH/TO/DATASET
# run evaluation example
python eval.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/CHECKPOINT > eval.log 2>&1 &
OR
sh scripts/run_distribute_eval.sh 8 rank_table.json /PATH/TO/DATASET /PATH/TO/CHECKPOINT
```
For distributed training, an HCCL configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
├── model_zoo
    ├── README.md                          // descriptions about all the models
    ├── densenet121
        ├── README.md                      // descriptions about densenet121
        ├── scripts
        │   ├── run_distribute_train.sh    // shell script for distributed training on Ascend
        │   ├── run_distribute_eval.sh     // shell script for evaluation on Ascend
        ├── src
        │   ├── datasets                   // dataset processing functions
        │   ├── losses
        │   │   ├── crossentropy.py        // densenet loss function
        │   ├── lr_scheduler
        │   │   ├── lr_scheduler.py        // densenet learning rate schedule functions
        │   ├── network
        │   │   ├── densenet.py            // densenet architecture
        │   ├── optimizers                 // densenet optimizer functions
        │   ├── utils
        │   │   ├── logging.py             // logging function
        │   │   ├── var_init.py            // densenet variable init functions
        │   ├── config.py                  // network config
        ├── train.py                       // training script
        ├── eval.py                        // evaluation script
```
## [Script Parameters](#contents)
You can modify the training behaviour through the various flags in the `train.py` script, listed below. A sketch of the cosine-annealing learning-rate option follows the list:
```
--data_dir              train data dir
--num_classes           number of classes in the dataset (default: 1000)
--image_size            image size of the dataset
--per_batch_size        mini-batch size per device (default: 256)
--pretrained            path of pretrained model
--lr_scheduler          type of LR schedule: exponential, cosine_annealing
--lr                    initial learning rate
--lr_epochs             epoch milestones for lr changes
--lr_gamma              decrease lr by this factor for the exponential lr_scheduler
--eta_min               eta_min in the cosine_annealing scheduler
--T_max                 T_max in the cosine_annealing scheduler
--max_epoch             max number of epochs to train the model
--warmup_epochs         warmup epochs (useful when the batch size is large)
--weight_decay          weight decay (default: 1e-4)
--momentum              momentum (default: 0.9)
--label_smooth          whether to use label smoothing in the cross-entropy loss
--label_smooth_factor   smoothing strength applied to the original one-hot labels
--log_interval          logging interval (default: 100)
--ckpt_path             path to save checkpoints
--ckpt_interval         interval for saving checkpoints
--is_save_on_master     save checkpoints on the master rank only, or on all ranks
--is_distributed        whether to run on multiple devices (default: 1)
--rank                  local rank in distributed mode (default: 0)
--group_size            world size in distributed mode (default: 1)
```
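The `--lr_scheduler`, `--T_max`, `--eta_min`, and `--warmup_epochs` flags select the learning-rate schedule. Below is a minimal, hypothetical sketch of a per-step cosine-annealing schedule with optional linear warmup, matching the defaults in `src/config.py` (lr=0.1, T_max=120, eta_min=0); the repository's actual implementation lives in `src/lr_scheduler/lr_scheduler.py` and may differ in detail:

```python
# Illustrative sketch only, not the repository's lr_scheduler.py.
import math
import numpy as np

def cosine_annealing_lr(base_lr, t_max, steps_per_epoch, max_epoch,
                        warmup_epochs=0, eta_min=0.0):
    """Return one learning-rate value per training step."""
    total_steps = steps_per_epoch * max_epoch
    warmup_steps = steps_per_epoch * warmup_epochs
    lr_each_step = []
    for step in range(total_steps):
        if warmup_steps and step < warmup_steps:
            lr = base_lr * (step + 1) / warmup_steps            # linear warmup
        else:
            cur_epoch = step // steps_per_epoch
            lr = eta_min + (base_lr - eta_min) * \
                 (1.0 + math.cos(math.pi * cur_epoch / t_max)) / 2.0
        lr_each_step.append(lr)
    return np.array(lr_each_step, dtype=np.float32)

# Example with the default ImageNet setting (~5004 steps per epoch, 120 epochs).
lr_steps = cosine_annealing_lr(0.1, 120, 5004, 120)
```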
## [Training Process](#contents)
### Training
- running on Ascend
```
python train.py --data_dir /PATH/TO/DATASET --is_distributed 0 > train.log 2>&1 &
```
The Python command above runs in the background. The log and model checkpoints are generated in `output/202x-xx-xx_time_xx_xx_xx/`. Loss values are printed as follows:
```
2020-08-22 16:58:56,617:INFO:epoch[0], iter[5003], loss:4.367, mean_fps:0.00 imgs/sec
2020-08-22 16:58:56,619:INFO:local passed
2020-08-22 17:02:19,920:INFO:epoch[1], iter[10007], loss:3.193, mean_fps:6301.11 imgs/sec
2020-08-22 17:02:19,921:INFO:local passed
2020-08-22 17:05:43,112:INFO:epoch[2], iter[15011], loss:3.096, mean_fps:6304.53 imgs/sec
2020-08-22 17:05:43,113:INFO:local passed
...
```
### Distributed Training
- running on Ascend
```
sh scripts/run_distribute_train.sh 8 rank_table.json /PATH/TO/DATASET
```
The above shell script runs distributed training in the background. You can find the logs and model checkpoints under `LOG[X]/output/202x-xx-xx_time_xx_xx_xx/`. Loss values are printed as follows:
```
2020-08-22 16:58:54,556:INFO:epoch[0], iter[5003], loss:3.857, mean_fps:0.00 imgs/sec
2020-08-22 17:02:19,188:INFO:epoch[1], iter[10007], loss:3.18, mean_fps:6260.18 imgs/sec
2020-08-22 17:05:42,490:INFO:epoch[2], iter[15011], loss:2.621, mean_fps:6301.11 imgs/sec
2020-08-22 17:09:05,686:INFO:epoch[3], iter[20015], loss:3.113, mean_fps:6304.37 imgs/sec
2020-08-22 17:12:28,925:INFO:epoch[4], iter[25019], loss:3.29, mean_fps:6303.07 imgs/sec
2020-08-22 17:15:52,167:INFO:epoch[5], iter[30023], loss:2.865, mean_fps:6302.98 imgs/sec
...
...
```
## [Evaluation Process](#contents)
### Evaluation
- evaluation on Ascend
Run the command below for evaluation.
```
python eval.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/CHECKPOINT > eval.log 2>&1 &
OR
sh scripts/run_distribute_eval.sh 8 rank_table.json /PATH/TO/DATASET /PATH/TO/CHECKPOINT
```
The above Python command runs in the background. You can view the results in the file "output/202x-xx-xx_time_xx_xx_xx/202x_xxxx.log". The accuracy on the test dataset is as follows:
```
2020-08-24 09:21:50,551:INFO:after allreduce eval: top1_correct=37657, tot=49920, acc=75.43%
2020-08-24 09:21:50,551:INFO:after allreduce eval: top5_correct=46224, tot=49920, acc=92.60%
```
# [Model Description](#contents)
## [Performance](#contents)
### Training accuracy results
| Parameters | Densenet |
| ------------------- | --------------------------- |
| Model Version       | DenseNet121                 |
| Resource | Ascend 910 |
| Uploaded Date | 08/28/2020 (month/day/year) |
| MindSpore Version | 0.5.0-alpha |
| Dataset | ImageNet |
| epochs | 120 |
| outputs | probability |
| train performance | Top1:75.13%; Top5:92.57% |
### Training performance results
| Parameters | Densenet |
| ------------------- | --------------------------- |
| Model Version       | DenseNet121                 |
| Resource | Ascend 910 |
| Uploaded Date | 08/28/2020 (month/day/year) |
| MindSpore Version | 0.5.0-alpha |
| Dataset | ImageNet |
| batch_size | 32 |
| outputs | probability |
| speed               | 1pc: 760 img/s; 8pc: 6000 img/s |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test densenet example#################
python eval.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/CHECKPOINT
"""
import os
import argparse
import datetime
import glob
import numpy as np
from mindspore import context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import init, get_rank, get_group_size, release
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from src.utils.logging import get_logger
from src.datasets import classification_dataset
from src.network import DenseNet121
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
save_graphs=True, device_id=devid)
class ParameterReduce(nn.Cell):
"""
reduce parameter
"""
def __init__(self):
super(ParameterReduce, self).__init__()
self.cast = P.Cast()
self.reduce = P.AllReduce()
def construct(self, x):
one = self.cast(F.scalar_to_array(1.0), mstype.float32)
out = x * one
ret = self.reduce(out)
return ret
def parse_args(cloud_args=None):
"""
parse args
"""
parser = argparse.ArgumentParser('mindspore classification test')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='eval data dir')
parser.add_argument('--num_classes', type=int, default=1000, help='num of classes in dataset')
parser.add_argument('--image_size', type=str, default='224,224', help='image size of the dataset')
# network related
parser.add_argument('--backbone', default='resnet50', help='backbone')
    parser.add_argument('--pretrained', default='', type=str, help='full path of pretrained model to load. '
                        'If it is a directory, all checkpoints in it will be tested')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.per_batch_size = config.per_batch_size
args.image_size = list(map(int, args.image_size.split(',')))
return args
def get_top5_acc(top5_arg, gt_class):
sub_count = 0
for top5, gt in zip(top5_arg, gt_class):
if gt in top5:
sub_count += 1
return sub_count
def merge_args(args, cloud_args):
"""
merge args and cloud_args
"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def test(cloud_args=None):
"""
network eval function. Get top1 and top5 ACC from classification.
    The result will be saved at [./outputs] by default.
"""
args = parse_args(cloud_args)
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
if os.path.isdir(args.pretrained):
models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
print(models)
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])
args.models = sorted(models, key=f)
else:
args.models = [args.pretrained,]
for model in args.models:
de_dataset = classification_dataset(args.data_dir, image_size=args.image_size,
per_batch_size=args.per_batch_size,
max_epoch=1, rank=args.rank, group_size=args.group_size,
mode='eval')
eval_dataloader = de_dataset.create_tuple_iterator()
network = DenseNet121(args.num_classes)
param_dict = load_checkpoint(model)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
print("key:" + key)
print(values.data)
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(model))
# must add
network.add_flags_recursive(fp16=True)
img_tot = 0
top1_correct = 0
top5_correct = 0
network.set_train(False)
for data, gt_classes in eval_dataloader:
output = network(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += args.per_batch_size
results = [[top1_correct], [top5_correct], [img_tot]]
args.logger.info('before results={}'.format(results))
if args.is_distributed:
model_md5 = model.replace('/', '')
tmp_dir = '/cache'
if not os.path.exists(tmp_dir):
os.mkdir(tmp_dir)
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5)
np.save(top1_correct_npy, top1_correct)
np.save(top5_correct_npy, top5_correct)
np.save(img_tot_npy, img_tot)
while True:
rank_ok = True
for other_rank in range(args.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) \
or not os.path.exists(img_tot_npy):
rank_ok = False
if rank_ok:
break
top1_correct_all = 0
top5_correct_all = 0
img_tot_all = 0
for other_rank in range(args.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
top1_correct_all += np.load(top1_correct_npy)
top5_correct_all += np.load(top5_correct_npy)
img_tot_all += np.load(img_tot_npy)
results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
results = np.array(results)
else:
results = np.array(results)
args.logger.info('after results={}'.format(results))
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
args.logger.info('after allreduce eval: top1_correct={}, tot={}, acc={:.2f}%'.format(top1_correct,
img_tot,
acc1))
args.logger.info('after allreduce eval: top5_correct={}, tot={}, acc={:.2f}%'.format(top5_correct,
img_tot,
acc5))
if args.is_distributed:
release()
if __name__ == "__main__":
test()
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distribute_eval.sh DEVICE_NUM RANK_TABLE_FILE DATASET CKPT_PATH"
echo "for example: sh run_distribute_train.sh 8 /data/hccl.json /path/to/dataset /path/to/ckpt"
echo "It is better to use absolute path."
echo "================================================================================================================="
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
export RANK_TABLE_FILE=$2
DATASET=$3
CKPT_PATH=$4
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
rm -rf LOG$i
mkdir ./LOG$i
cp ./*.py ./LOG$i
cp -r ./src ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python eval.py \
--data_dir=$DATASET \
--pretrained=$CKPT_PATH > log.txt 2>&1 &
cd ../
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE DATASET"
echo "for example: sh run_distribute_train.sh 8 /data/hccl.json /path/to/dataset"
echo "It is better to use absolute path."
echo "================================================================================================================="
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
export RANK_TABLE_FILE=$2
DATASET=$3
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
rm -rf LOG$i
mkdir ./LOG$i
cp ./*.py ./LOG$i
cp -r ./src ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET > log.txt 2>&1 &
cd ../
done
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config"""
from easydict import EasyDict as ed
config = ed({
"image_size": '224,224',
"num_classes": 1000,
"lr": 0.1,
"lr_scheduler": 'cosine_annealing',
"lr_epochs": '30,60,90,120',
"lr_gamma": 0.1,
"eta_min": 0,
"T_max": 120,
"max_epoch": 120,
"per_batch_size": 32,
"warmup_epochs": 0,
"weight_decay": 0.0001,
"momentum": 0.9,
"is_dynamic_loss_scale": 0,
"loss_scale": 1024,
"label_smooth": 0,
"label_smooth_factor": 0.1,
"log_interval": 100,
"ckpt_interval": 2000,
"ckpt_path": 'outputs/',
"is_save_on_master": 1,
"rank": 0,
"group_size": 1
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
read dataset for classification
"""
from .classification import classification_dataset
__all__ = ["classification_dataset"]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
A function that returns a dataset for classification.
"""
import os
from PIL import Image, ImageFile
from mindspore import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as vision_C
import mindspore.dataset.transforms.c_transforms as normal_C
from src.datasets.sampler import DistributedSampler
ImageFile.LOAD_TRUNCATED_IMAGES = True
class TxtDataset():
"""
read dataset from txt
"""
def __init__(self, root, txt_name):
super(TxtDataset, self).__init__()
self.imgs = []
self.labels = []
fin = open(txt_name, "r")
for line in fin:
img_name, label = line.strip().split(' ')
self.imgs.append(os.path.join(root, img_name))
self.labels.append(int(label))
fin.close()
def __getitem__(self, index):
img = Image.open(self.imgs[index]).convert('RGB')
return img, self.labels[index]
def __len__(self):
return len(self.imgs)
def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
mode='train',
input_mode='folder',
root='',
num_parallel_workers=None,
shuffle=None,
sampler=None,
class_indexing=None,
drop_remainder=True,
transform=None,
target_transform=None):
"""
A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
are written into a textfile.
Args:
data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
Or path of the textfile that contains every image's path of the dataset.
image_size (str): Size of the input images.
        per_batch_size (int): the batch size of every step during training.
max_epoch (int): the number of epochs.
rank (int): The shard ID within num_shards (default=None).
group_size (int): Number of shards that the dataset should be divided
into (default=None).
mode (str): "train" or others. Default: " train".
input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
root (str): the images path for "input_mode="txt"". Default: " ".
num_parallel_workers (int): Number of workers to read the data. Default: None.
shuffle (bool): Whether or not to perform shuffle on the dataset
(default=None, performs shuffle).
sampler (Sampler): Object used to choose samples from the dataset. Default: None.
class_indexing (dict): A str-to-int mapping from folder name to index
(default=None, the folder names will be sorted
alphabetically and each class will be given a
unique index starting from 0).
Examples:
>>> from src.datasets.classification import classification_dataset
>>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
>>> dataset_dir = "/path/to/imagefolder_directory"
        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4)
>>> # Path of the textfile that contains every image's path of the dataset.
>>> dataset_dir = "/path/to/dataset/images/train.txt"
>>> images_dir = "/path/to/dataset/images"
        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4,
>>> input_mode="txt", root=images_dir)
"""
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if transform is None:
if mode == 'train':
transform_img = [
vision_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
vision_C.RandomHorizontalFlip(prob=0.5),
vision_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
vision_C.Normalize(mean=mean, std=std),
vision_C.HWC2CHW()
]
else:
transform_img = [
vision_C.Decode(),
vision_C.Resize((256, 256)),
vision_C.CenterCrop(image_size),
vision_C.Normalize(mean=mean, std=std),
vision_C.HWC2CHW()
]
else:
transform_img = transform
if target_transform is None:
transform_label = [
normal_C.TypeCast(mstype.int32)
]
else:
transform_label = target_transform
if input_mode == 'folder':
de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers,
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
num_shards=group_size, shard_id=rank)
else:
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset.set_dataset_size(len(sampler))
de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
columns_to_project = ["image", "label"]
de_dataset = de_dataset.project(columns=columns_to_project)
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
de_dataset = de_dataset.repeat(max_epoch)
return de_dataset
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
shuffle and distribute sample
"""
import math
import numpy as np
class DistributedSampler():
"""
function to distribute and shuffle sample
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank
self.group_size = group_size
self.dataset_length = len(self.dataset)
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
self.total_size = self.num_samples * self.group_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self):
if self.shuffle:
self.seed = (self.seed + 1) & 0xffffffff
np.random.seed(self.seed)
indices = np.random.permutation(self.dataset_length).tolist()
else:
            indices = list(range(self.dataset_length))
indices += indices[:(self.total_size - len(indices))]
indices = indices[self.rank::self.group_size]
return iter(indices)
def __len__(self):
return self.num_samples
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function
"""
from .crossentropy import *
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function CrossEntropy
"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""
loss function CrossEntropy
"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label,
F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
learning rate scheduler
"""
from .lr_scheduler import *
This diff has been collapsed.
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
densenet network
"""
from .densenet import DenseNet121
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
model architecture of densenet
"""
import math
from collections import OrderedDict
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import initializer as init
from src.utils.var_init import default_recurisive_init, KaimingNormal
__all__ = ["DenseNet121"]
class GlobalAvgPooling(nn.Cell):
"""
GlobalAvgPooling function.
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(True)
self.shape = P.Shape()
self.reshape = P.Reshape()
def construct(self, x):
x = self.mean(x, (2, 3))
b, c, _, _ = self.shape(x)
x = self.reshape(x, (b, c))
return x
class CommonHead(nn.Cell):
def __init__(self, num_classes, out_channels):
super(CommonHead, self).__init__()
self.avgpool = GlobalAvgPooling()
self.fc = nn.Dense(out_channels, num_classes, has_bias=True)
def construct(self, x):
x = self.avgpool(x)
x = self.fc(x)
return x
def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False):
return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad")
def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False):
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad")
def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False):
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad")
class _DenseLayer(nn.Cell):
"""
the dense layer, include 2 conv layer
"""
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
self.norm1 = nn.BatchNorm2d(num_input_features)
self.relu1 = nn.ReLU()
self.conv1 = conv1x1(num_input_features, bn_size*growth_rate)
self.norm2 = nn.BatchNorm2d(bn_size*growth_rate)
self.relu2 = nn.ReLU()
self.conv2 = conv3x3(bn_size*growth_rate, growth_rate)
        # nn.Dropout in MindSpore uses keep_prob, which differs from PyTorch
self.keep_prob = 1 - drop_rate
self.dropout = nn.Dropout(keep_prob=self.keep_prob)
def construct(self, features):
bottleneck = self.conv1(self.relu1(self.norm1(features)))
new_features = self.conv2(self.relu2(self.norm2(bottleneck)))
if self.keep_prob < 1:
new_features = self.dropout(new_features)
return new_features
class _DenseBlock(nn.Cell):
"""
the dense block
"""
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
super(_DenseBlock, self).__init__()
self.cell_list = nn.CellList()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate
)
self.cell_list.append(layer)
self.concate = P.Concat(axis=1)
def construct(self, init_features):
features = init_features
for layer in self.cell_list:
new_features = layer(features)
features = self.concate((features, new_features))
return features
class _Transition(nn.Cell):
"""
    the transition layer
"""
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.features = nn.SequentialCell(OrderedDict([
('norm', nn.BatchNorm2d(num_input_features)),
('relu', nn.ReLU()),
('conv', conv1x1(num_input_features, num_output_features)),
('pool', nn.MaxPool2d(kernel_size=2, stride=2))
]))
def construct(self, x):
x = self.features(x)
return x
class Densenet(nn.Cell):
"""
the densenet architecture
"""
__constants__ = ['features']
def __init__(self, growth_rate, block_config, num_init_features, bn_size=4, drop_rate=0):
super(Densenet, self).__init__()
layers = OrderedDict()
layers['conv0'] = conv7x7(3, num_init_features, stride=2, padding=3)
layers['norm0'] = nn.BatchNorm2d(num_init_features)
layers['relu0'] = nn.ReLU()
layers['pool0'] = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate
)
layers['denseblock%d'%(i+1)] = block
num_features = num_features + num_layers*growth_rate
if i != len(block_config)-1:
trans = _Transition(num_input_features=num_features,
num_output_features=num_features // 2)
layers['transition%d'%(i+1)] = trans
num_features = num_features // 2
# Final batch norm
layers['norm5'] = nn.BatchNorm2d(num_features)
layers['relu5'] = nn.ReLU()
self.features = nn.SequentialCell(layers)
self.out_channels = num_features
def construct(self, x):
x = self.features(x)
return x
def get_out_channels(self):
return self.out_channels
def _densenet121(**kwargs):
return Densenet(growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, **kwargs)
def _densenet161(**kwargs):
return Densenet(growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96, **kwargs)
def _densenet169(**kwargs):
return Densenet(growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64, **kwargs)
def _densenet201(**kwargs):
return Densenet(growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64, **kwargs)
class DenseNet121(nn.Cell):
"""
    the densenet121 architecture
"""
def __init__(self, num_classes):
super(DenseNet121, self).__init__()
self.backbone = _densenet121()
out_channels = self.backbone.get_out_channels()
self.head = CommonHead(num_classes, out_channels)
default_recurisive_init(self)
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(KaimingNormal(a=math.sqrt(5), mode='fan_out',
nonlinearity='relu'),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.default_input = init.initializer('ones', cell.gamma.default_input.shape).to_tensor()
cell.beta.default_input = init.initializer('zeros', cell.beta.default_input.shape).to_tensor()
elif isinstance(cell, nn.Dense):
cell.bias.default_input = init.initializer('zeros', cell.bias.default_input.shape).to_tensor()
def construct(self, x):
x = self.backbone(x)
x = self.head(x)
return x
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get parameter function
"""
def get_param_groups(network):
"""
get parameter groups
"""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
            # BN gamma (weight) does not use weight decay; be careful, for now x does not include BN
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
            # BN beta (bias) does not use weight decay; be careful, for now x does not include BN
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime
class LOGGER(logging.Logger):
"""
set up logging file.
Args:
logger_name (string): logger name.
log_dir (string): path of logger.
Returns:
string, logger path
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""set up log file"""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER("mindversion", rank)
logger.setup_logging_file(path, rank)
return logger
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import initializer as init
def _calculate_gain(nonlinearity, param=None):
r"""
Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function
param: optional parameter for the non-linear function
Examples:
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of `num` to `arr`."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_in_and_out(arr):
"""
Calculate n_in and n_out.
Args:
arr (Array): Input array.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dim = len(arr.shape)
if dim < 2:
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
n_in = arr.shape[1]
n_out = arr.shape[0]
if dim > 2:
counter = reduce(lambda x, y: x * y, arr.shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def _select_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_in_and_out(array)
return fan_in if mode == 'fan_in' else fan_out
class KaimingInit(init.Initializer):
r"""
Base Class. Initialize the array with He kaiming algorithm.
Args:
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function, recommended to use only with
``'relu'`` or ``'leaky_relu'`` (default).
"""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingInit, self).__init__()
self.mode = mode
self.gain = _calculate_gain(nonlinearity, a)
class KaimingUniform(KaimingInit):
r"""
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
        >>> w = np.empty((3, 5))
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
np.random.seed(1)
data = np.random.uniform(-bound, bound, arr.shape)
_assignment(arr, data)
class KaimingNormal(KaimingInit):
r"""
Initialize the array with He kaiming normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
        >>> w = np.empty((3, 5))
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
std = self.gain / math.sqrt(fan)
np.random.seed(1)
data = np.random.normal(0, std, arr.shape)
_assignment(arr, data)
def default_recurisive_init(custom_cell):
"""default_recurisive_init"""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
np.random.seed(1)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, nn.Dense):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
np.random.seed(1)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train launch."""
import os
import time
import argparse
import datetime
import mindspore.nn as nn
from mindspore import Tensor, ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore import context
from src.optimizers import get_param_groups
from src.network import DenseNet121
from src.datasets import classification_dataset
from src.losses.crossentropy import CrossEntropy
from src.lr_scheduler import MultiStepLR, CosineAnnealingLR
from src.utils.logging import get_logger
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target="Davinci", save_graphs=False, device_id=devid)
class BuildTrainNetwork(nn.Cell):
"""build training network"""
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
loss = self.criterion(output, label)
return loss
class ProgressMonitor(Callback):
"""monitor loss and time"""
def __init__(self, args):
super(ProgressMonitor, self).__init__()
self.me_epoch_start_time = 0
self.me_epoch_start_step_num = 0
self.args = args
self.ckpt_history = []
def begin(self, run_context):
self.args.logger.info('start network train...')
def epoch_begin(self, run_context):
pass
def epoch_end(self, run_context, *me_args):
"""process epoch end"""
cb_params = run_context.original_args()
me_step = cb_params.cur_step_num - 1
real_epoch = me_step // self.args.steps_per_epoch
time_used = time.time() - self.me_epoch_start_time
fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used
self.args.logger.info('epoch[{}], iter[{}], loss:{},'
'mean_fps:{:.2f} imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean))
if self.args.rank_save_ckpt_flag:
import glob
ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
for ckpt in ckpts:
ckpt_fn = os.path.basename(ckpt)
if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
continue
if ckpt in self.ckpt_history:
continue
self.ckpt_history.append(ckpt)
self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},'
'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn))
self.me_epoch_start_step_num = me_step
self.me_epoch_start_time = time.time()
def step_begin(self, run_context):
pass
def step_end(self, run_context, *me_args):
pass
def end(self, run_context):
self.args.logger.info('end network train...')
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore classification training')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
# network related
parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config.image_size
args.num_classes = config.num_classes
args.lr = config.lr
args.lr_scheduler = config.lr_scheduler
args.lr_epochs = config.lr_epochs
args.lr_gamma = config.lr_gamma
args.eta_min = config.eta_min
args.T_max = config.T_max
args.max_epoch = config.max_epoch
args.warmup_epochs = config.warmup_epochs
args.weight_decay = config.weight_decay
args.momentum = config.momentum
args.is_dynamic_loss_scale = config.is_dynamic_loss_scale
args.loss_scale = config.loss_scale
args.label_smooth = config.label_smooth
args.label_smooth_factor = config.label_smooth_factor
args.ckpt_interval = config.ckpt_interval
args.ckpt_path = config.ckpt_path
args.is_save_on_master = config.is_save_on_master
args.rank = config.rank
args.group_size = config.group_size
args.log_interval = config.log_interval
args.per_batch_size = config.per_batch_size
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.image_size = list(map(int, args.image_size.split(',')))
return args
def merge_args(args, cloud_args):
"""dictionary"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def train(cloud_args=None):
"""training process"""
args = parse_args(cloud_args)
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
if args.is_dynamic_loss_scale == 1:
        args.loss_scale = 1  # with dynamic loss scale, the loss scale cannot be set in the momentum optimizer
    # select whether only the master rank or all ranks save ckpt; compatible with model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
# dataloader
de_dataset = classification_dataset(args.data_dir, args.image_size,
args.per_batch_size, args.max_epoch,
args.rank, args.group_size)
de_dataset.map_model = 4 # !!!important
args.steps_per_epoch = de_dataset.get_dataset_size()
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
# get network and init
network = DenseNet121(args.num_classes)
# loss
if not args.label_smooth:
args.label_smooth_factor = 0.0
criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
num_classes=args.num_classes)
# load pretrain model
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(args.pretrained))
# lr scheduler
if args.lr_scheduler == 'exponential':
lr_scheduler = MultiStepLR(args.lr,
args.lr_epochs,
args.lr_gamma,
args.steps_per_epoch,
args.max_epoch,
warmup_epochs=args.warmup_epochs)
elif args.lr_scheduler == 'cosine_annealing':
lr_scheduler = CosineAnnealingLR(args.lr,
args.T_max,
args.steps_per_epoch,
args.max_epoch,
warmup_epochs=args.warmup_epochs,
eta_min=args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
lr_schedule = lr_scheduler.get_lr()
# optimizer
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr_schedule),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
# mixed precision training
criterion.add_flags_recursive(fp32=True)
# package training process, adjust lr + forward + backward + optimizer
train_net = BuildTrainNetwork(network, criterion)
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
else:
parallel_mode = ParallelMode.STAND_ALONE
if args.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
parameter_broadcast=True, mirror_mean=True)
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3")
# checkpoint save
progress_cb = ProgressMonitor(args)
callbacks = [progress_cb,]
if args.rank_save_ckpt_flag:
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=args.outputs_dir,
prefix='{}'.format(args.rank))
callbacks.append(ckpt_cb)
model.train(args.max_epoch, de_dataset, callbacks=callbacks)
if __name__ == "__main__":
train()