Commit 7962287f authored by mindspore-ci-bot, committed by Gitee

!5097 Adding ShuffleNetV2 model to modelzoo

Merge pull request !5097 from alashkari/shufflenet
# Contents
- [ShuffleNetV2 Description](#shufflenetv2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
# [ShuffleNetV2 Description](#contents)
ShuffleNetV2 is considerably faster and more accurate than previous networks on different platforms such as Ascend or GPU.
[Paper](https://arxiv.org/pdf/1807.11164.pdf) Ma, N., Zhang, X., Zheng, H. T., & Sun, J. (2018). Shufflenet v2: Practical guidelines for efficient cnn architecture design. In Proceedings of the European conference on computer vision (ECCV) (pp. 116-131).
# [Model Architecture](#contents)
The overall network architecture of ShuffleNetV2 is shown below:
[Link](https://arxiv.org/pdf/1807.11164.pdf)
# [Dataset](#contents)
Dataset used: [ImageNet](http://www.image-net.org/)
- Dataset size: ~125G, over 1.2 million color images in 1000 classes
    - Train: 120G, 1.2 million images
    - Test: 5G, 50000 images
- Data format: RGB images
    - Note: Data will be processed in src/dataset.py
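
As a quick illustration (the path below is a placeholder), the same training pipeline defined in `src/dataset.py` can be built directly in Python:
```python
from src.dataset import create_dataset

# do_train=True selects the random-crop/flip/color-jitter training transforms;
# rank=0, group_size=1 corresponds to a single, unsharded device.
dataset = create_dataset('/path/to/imagenet/train/', True, 0, 1)
print('batches per epoch:', dataset.get_dataset_size())
```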
# [Environment Requirements](#contents)
- Hardware (GPU)
    - Prepare the hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
+-- ShuffleNetV2
    +-- Readme.md                            # descriptions about ShuffleNetV2
    +-- scripts
        +-- run_distribute_train_for_gpu.sh  # shell script for distributed training
        +-- run_eval_for_multi_gpu.sh        # shell script for evaluation
        +-- run_standalone_train_for_gpu.sh  # shell script for standalone training
    +-- src
        +-- config.py                        # parameter configuration
        +-- dataset.py                       # creating dataset
        +-- loss.py                          # loss function for network
        +-- lr_generator.py                  # learning rate config
    +-- train.py                             # training script
    +-- eval.py                              # evaluation script
    +-- blocks.py                            # ShuffleNetV2 blocks
    +-- network.py                           # ShuffleNetV2 model network
```
## [Training Process](#contents)
### Usage
You can start training using python or shell scripts. The usage of the shell scripts is as follows:
- Distributed training on GPU: `sh run_distribute_train_for_gpu.sh [DATA_DIR]`
- Standalone training on GPU: `sh run_standalone_train_for_gpu.sh [DEVICE_ID] [DATA_DIR]`
### Launch
```
# training example
python:
GPU: mpirun --allow-run-as-root -n 8 python train.py --is_distributed --platform 'GPU' --dataset_path '~/imagenet/train/' > train.log 2>&1 &
shell:
GPU: sh run_distribute_train_for_gpu.sh ~/imagenet/train/
```
### Result
Training results will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train.log`.
## [Evaluation Process](#contents)
### Usage
You can start evaluation using python or shell scripts. The usage of the shell script is as follows:
- GPU: `sh run_eval_for_multi_gpu.sh [DEVICE_ID] [EPOCH]`
### Launch
```
# infer example
python:
GPU: CUDA_VISIBLE_DEVICES=0 python eval.py --platform 'GPU' --dataset_path '~/imagenet/val/' --checkpoint ./checkpoint/[CKPT_FILE] --epoch 250 > eval.log 2>&1 &
shell:
GPU: sh run_eval_for_multi_gpu.sh 0 250
```
> Checkpoints are produced during the training process.
### Result
Inference results will be stored in the example path; you can find the results in `eval.log`.
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
import mindspore.ops.operations as P
class ShuffleV2Block(nn.Cell):
    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
        super(ShuffleV2Block, self).__init__()
        assert stride in [1, 2]
        self.stride = stride
        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp
        outputs = oup - inp

        branch_main = [
            # pw
            nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1,
                      pad_mode='pad', padding=0, has_bias=False),
            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
            nn.ReLU(),
            # dw
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels, has_bias=False),
            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
            # pw-linear
            nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1,
                      pad_mode='pad', padding=0, has_bias=False),
            nn.BatchNorm2d(num_features=outputs, momentum=0.9),
            nn.ReLU(),
        ]
        self.branch_main = nn.SequentialCell(branch_main)

        if stride == 2:
            branch_proj = [
                # dw
                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride,
                          pad_mode='pad', padding=pad, group=inp, has_bias=False),
                nn.BatchNorm2d(num_features=inp, momentum=0.9),
                # pw-linear
                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1,
                          pad_mode='pad', padding=0, has_bias=False),
                nn.BatchNorm2d(num_features=inp, momentum=0.9),
                nn.ReLU(),
            ]
            self.branch_proj = nn.SequentialCell(branch_proj)
        else:
            self.branch_proj = None

    def construct(self, old_x):
        if self.stride == 1:
            # split channels: keep one half as identity, transform the other half
            x_proj, x = self.channel_shuffle(old_x)
            return P.Concat(1)((x_proj, self.branch_main(x)))
        if self.stride == 2:
            # downsampling block: both branches see the full input
            x_proj = old_x
            x = old_x
            return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x)))
        return None

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = P.Shape()(x)
        # num_channels is assumed to be divisible by 4 (asserts are not supported
        # in graph-mode construct, so the check is documented rather than enforced)
        x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,))
        x = P.Transpose()(x, (1, 0, 2,))
        x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,))
        return x[0], x[1]
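

# A hypothetical NumPy self-test of the channel_shuffle layout above
# (illustrative only, not used by the model): even-indexed channels feed the
# identity branch (x_proj) and odd-indexed channels feed the main branch (x).
if __name__ == '__main__':
    import numpy as np
    b, c, h, w = 1, 4, 2, 2
    arr = np.arange(b * c * h * w).reshape(b, c, h, w)
    out = arr.reshape(b * c // 2, 2, h * w).transpose(1, 0, 2).reshape(2, -1, c // 2, h, w)
    print(out[0].reshape(2, h * w)[:, 0] // (h * w))  # [0 2] -> identity branch channels
    print(out[1].reshape(2, h * w)[:, 0] // (h * w))  # [1 3] -> main branch channels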
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate_imagenet"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_gpu as cfg
from src.dataset import create_dataset
from network import ShuffleNetV2
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of ShuffleNetV2 (Default: None)')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    parser.add_argument('--epoch', type=str, default='', help='epoch tag of the checkpoint to evaluate')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
    if args_opt.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    net = ShuffleNetV2(n_class=cfg.num_classes)
    ckpt = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, ckpt)
    net.set_train(False)

    # create_dataset(dataset_path, do_train, rank, group_size): evaluation runs
    # on a single device, so rank=0 and group_size=1
    dataset = create_dataset(args_opt.dataset_path, False, 0, 1)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False,
                                            smooth_factor=0.1, num_classes=cfg.num_classes)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    metrics = model.eval(dataset)
    print("metric: ", metrics)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from blocks import ShuffleV2Block
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops.operations as P
class ShuffleNetV2(nn.Cell):
    def __init__(self, input_size=224, n_class=1000, model_size='1.0x'):
        super(ShuffleNetV2, self).__init__()
        print('model size is ', model_size)

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if model_size == '0.5x':
            self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
        elif model_size == '1.0x':
            self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
        elif model_size == '1.5x':
            self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
        elif model_size == '2.0x':
            self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
        else:
            raise NotImplementedError

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell([
            nn.Conv2d(in_channels=3, out_channels=input_channel, kernel_size=3, stride=2,
                      pad_mode='pad', padding=1, has_bias=False),
            nn.BatchNorm2d(num_features=input_channel, momentum=0.9),
            nn.ReLU(),
        ])
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                if i == 0:
                    # first block of each stage downsamples (stride 2)
                    self.features.append(ShuffleV2Block(input_channel, output_channel,
                                                        mid_channels=output_channel // 2, ksize=3, stride=2))
                else:
                    self.features.append(ShuffleV2Block(input_channel // 2, output_channel,
                                                        mid_channels=output_channel // 2, ksize=3, stride=1))
                input_channel = output_channel
        self.features = nn.SequentialCell([*self.features])

        self.conv_last = nn.SequentialCell([
            nn.Conv2d(in_channels=input_channel, out_channels=self.stage_out_channels[-1], kernel_size=1, stride=1,
                      pad_mode='pad', padding=0, has_bias=False),
            nn.BatchNorm2d(num_features=self.stage_out_channels[-1], momentum=0.9),
            nn.ReLU()
        ])
        self.globalpool = nn.AvgPool2d(kernel_size=7, stride=7, pad_mode='valid')
        if self.model_size == '2.0x':
            self.dropout = nn.Dropout(keep_prob=0.8)
        self.classifier = nn.SequentialCell([nn.Dense(in_channels=self.stage_out_channels[-1],
                                                      out_channels=n_class, has_bias=False)])
        self._initialize_weights()

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.conv_last(x)
        x = self.globalpool(x)
        if self.model_size == '2.0x':
            x = self.dropout(x)
        x = P.Reshape()(x, (-1, self.stage_out_channels[-1],))
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for name, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01,
                                                                        m.weight.data.shape).astype("float32")))
                else:
                    m.weight.set_parameter_data(Tensor(np.random.normal(0, 1.0 / m.weight.data.shape[1],
                                                                        m.weight.data.shape).astype("float32")))
            if isinstance(m, nn.Dense):
                m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
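

# A minimal smoke test (illustrative only; adjust device_target to your
# platform) that builds the 1.0x network and checks the logits shape on a
# dummy batch: the expected output is (1, 1000).
if __name__ == '__main__':
    from mindspore import context
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = ShuffleNetV2(n_class=1000, model_size='1.0x')
    dummy = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
    print(net(dummy).shape)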
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID=$1
EPOCH=$2
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path '/home/data/ImageNet_Original/val/' --epoch $EPOCH > eval.log 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID=$1
DATA_DIR=$2
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict
config_gpu = edict({
    'random_seed': 1,
    'rank': 0,
    'group_size': 1,
    'work_nums': 8,
    'epoch_size': 250,
    'keep_checkpoint_max': 100,
    'ckpt_path': './checkpoint/',
    'is_save_on_master': 0,

    ### Dataset Config
    'batch_size': 128,
    'num_classes': 1000,

    ### Loss Config
    'label_smooth_factor': 0.1,
    'aux_factor': 0.4,

    ### Learning Rate Config
    'lr_init': 0.5,

    ### Optimization Config
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'opt_eps': 1.0,
    'rmsprop_decay': 0.9,
    'loss_scale': 1,
})
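
# config_gpu is an EasyDict, so entries can be overridden in place before the
# model is built; the values below are hypothetical examples, not defaults:
#
#     from src.config import config_gpu as cfg
#     cfg.batch_size = 64    # e.g. reduce the batch size for a smaller GPU
#     cfg.epoch_size = 10    # e.g. shorten training for a quick sanity run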
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import numpy as np
from src.config import config_gpu as cfg
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.c_transforms as C
class toBGR():
    """Convert an HWC RGB image array to BGR channel order."""
    def __call__(self, img):
        img = img[:, :, ::-1]
        img = np.ascontiguousarray(img)
        return img


def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1):
    """
    create a train or eval dataset

    Args:
        dataset_path (string): the path of the dataset.
        do_train (bool): whether the dataset is used for training or evaluation.
        rank (int): the shard ID within group_size shards.
        group_size (int): the number of shards that the dataset is divided into.
        repeat_num (int): the repeat count of the dataset. Default: 1.

    Returns:
        dataset
    """
    if group_size == 1:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
    else:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
                                     num_shards=group_size, shard_id=rank)
    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(224),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(224)
        ]
    trans += [
        toBGR(),
        C.Rescale(1.0 / 255.0, 0.0),
        # C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        C.HWC2CHW(),
        C2.TypeCast(mstype.float32)
    ]
    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=cfg.work_nums)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=cfg.work_nums)
    # apply batch operations
    ds = ds.batch(cfg.batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds
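

# A minimal sketch (the path below is a placeholder) of pulling one batch from
# the pipeline defined above; with the default config this yields float32
# image tensors of shape (128, 3, 224, 224) and int32 labels of shape (128,).
if __name__ == '__main__':
    train_ds = create_dataset('/path/to/imagenet/train/', True, 0, 1)
    for batch in train_ds.create_dict_iterator():
        print(batch['image'].shape, batch['label'].shape)
        break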
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
import mindspore.nn as nn
class CrossEntropy(_Loss):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
    def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
        super(CrossEntropy, self).__init__()
        self.factor = factor
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        logit, aux = logits
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss_logit = self.ce(logit, one_hot_label)
        loss_logit = self.mean(loss_logit, 0)
        one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value)
        loss_aux = self.ce(aux, one_hot_label_aux)
        loss_aux = self.mean(loss_aux, 0)
        return loss_logit + self.factor * loss_aux


class CrossEntropy_Val(_Loss):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
    def __init__(self, smooth_factor=0, num_classes=1000):
        super(CrossEntropy_Val, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss_logit = self.ce(logits, one_hot_label)
        loss_logit = self.mean(loss_logit, 0)
        return loss_logit
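

# The on/off values above implement label smoothing: for smooth_factor=0.1 and
# num_classes=1000, the true class receives probability 0.9 and each remaining
# class receives 0.1/999, so the target distribution still sums to 1.
# A quick pure-Python check (illustrative only):
if __name__ == '__main__':
    smooth_factor, num_classes = 0.1, 1000
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    print(on_value, off_value, on_value + off_value * (num_classes - 1))  # 0.9 ~0.0001 1.0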
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate exponential decay generator"""
import math
import numpy as np
def get_lr(lr_init, lr_decay_rate, num_epoch_per_decay, total_epochs, steps_per_epoch, is_stair=False):
    """
    generate learning rate array

    Args:
        lr_init (float): initial learning rate
        lr_decay_rate (float): multiplicative decay rate applied every num_epoch_per_decay epochs
        num_epoch_per_decay (int): number of epochs between decays
        total_epochs (int): total epochs of training
        steps_per_epoch (int): steps of one epoch
        is_stair (bool): If `True`, decay the learning rate at discrete intervals (default=False)

    Returns:
        learning_rate, learning rate numpy array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    decay_steps = steps_per_epoch * num_epoch_per_decay
    for i in range(total_steps):
        p = i / decay_steps
        if is_stair:
            p = math.floor(p)
        lr_each_step.append(lr_init * math.pow(lr_decay_rate, p))
    learning_rate = np.array(lr_each_step).astype(np.float32)
    return learning_rate


def get_lr_basic(lr_init, total_epochs, steps_per_epoch, is_stair=False):
    """
    generate a linearly decaying learning rate array

    Args:
        lr_init (float): initial learning rate
        total_epochs (int): total epochs of training
        steps_per_epoch (int): steps of one epoch
        is_stair (bool): kept for interface compatibility; unused in this linear schedule

    Returns:
        learning_rate, learning rate numpy array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    for i in range(total_steps):
        lr = lr_init - lr_init * i / total_steps
        lr_each_step.append(lr)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    return learning_rate
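

# A quick check of the linear schedule above (hypothetical numbers): with
# lr_init=0.5, 250 epochs and 100 steps per epoch, the rate starts at 0.5,
# reaches 0.25 at the halfway step and approaches 0 at the final step.
if __name__ == '__main__':
    lrs = get_lr_basic(lr_init=0.5, total_epochs=250, steps_per_epoch=100)
    print(lrs[0], lrs[len(lrs) // 2], lrs[-1])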
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import argparse
import os
import random
import numpy as np
from network import ShuffleNetV2
import mindspore.nn as nn
from mindspore import context
from mindspore import dataset as de
from mindspore import ParallelMode
from mindspore import Tensor
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_gpu as cfg
from src.dataset import create_dataset
from src.lr_generator import get_lr_basic
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
de.config.set_seed(cfg.random_seed)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification training')
    parser.add_argument('--dataset_path', type=str, default='/home/data/imagenet_jpeg/train/', help='Dataset path')
    parser.add_argument('--resume', type=str, default='', help='resume training with existing checkpoint')
    parser.add_argument('--is_distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    parser.add_argument('--model_size', type=str, default='1.0x', help='ShuffleNetV2 model size parameter')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args_opt.is_distributed:
        if args_opt.platform == "Ascend":
            init()
        else:
            init("nccl")
        cfg.rank = get_rank()
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
                                          parameter_broadcast=True, mirror_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1

    # dataloader
    dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size)
    batches_per_epoch = dataset.get_dataset_size()
    print("Batches Per Epoch: ", batches_per_epoch)

    # network
    net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size)

    # loss
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
                                            smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)

    # learning rate schedule
    lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size,
                      steps_per_epoch=batches_per_epoch, is_stair=True)
    lr = Tensor(lr)

    # optimizer: apply weight decay to conv/dense weights only, not to
    # BatchNorm parameters (beta/gamma) or biases
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    optimizer = Momentum(params=group_params, learning_rate=lr, momentum=cfg.momentum)

    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}

    if args_opt.resume:
        ckpt = load_checkpoint(args_opt.resume)
        load_param_into_net(net, ckpt)

    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=eval_metrics)

    print("============== Starting Training ==============")
    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    callbacks = [loss_cb, time_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"shufflenet-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
    if args_opt.is_distributed and cfg.is_save_on_master:
        # save checkpoints on rank 0 only
        if cfg.rank == 0:
            callbacks.append(ckpoint_cb)
        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
    else:
        callbacks.append(ckpoint_cb)
        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
    print("train success")