diff --git a/mindspore/model_zoo/mobilenetv2/Readme.md b/example/mobilenetv2_imagenet/Readme.md similarity index 98% rename from mindspore/model_zoo/mobilenetv2/Readme.md rename to example/mobilenetv2_imagenet/Readme.md index 2ee9f0a6abdc29f3398c99a71fc40b22a95c8cce..5d18579d6bfdeba61736a51ab47567830cc6117e 100644 --- a/mindspore/model_zoo/mobilenetv2/Readme.md +++ b/example/mobilenetv2_imagenet/Readme.md @@ -13,7 +13,7 @@ The overall network architecture of MobileNetV2 is show below: # Dataset -Dataset used: [imagenet](http://www.image-net.org/) +Dataset used: imagenet - Dataset size: ~125G, 1.2W colorful images in 1000 classes - Train: 120G, 1.2W images @@ -60,8 +60,8 @@ Dataset used: [imagenet](http://www.image-net.org/) ### Usage -- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] -- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] +- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] +- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] ### Launch diff --git a/mindspore/model_zoo/mobilenetv2/eval.py b/example/mobilenetv2_imagenet/eval.py similarity index 100% rename from mindspore/model_zoo/mobilenetv2/eval.py rename to example/mobilenetv2_imagenet/eval.py diff --git a/mindspore/model_zoo/mobilenetv2/scripts/run_infer.sh b/example/mobilenetv2_imagenet/scripts/run_infer.sh similarity index 100% rename from mindspore/model_zoo/mobilenetv2/scripts/run_infer.sh rename to example/mobilenetv2_imagenet/scripts/run_infer.sh diff --git a/mindspore/model_zoo/mobilenetv2/scripts/run_train.sh b/example/mobilenetv2_imagenet/scripts/run_train.sh similarity index 91% rename from mindspore/model_zoo/mobilenetv2/scripts/run_train.sh rename to example/mobilenetv2_imagenet/scripts/run_train.sh index 95f9b39b93fbd3dc71839fd65df55f2f749e7500..aabe09cf34e80cbbba61de32d8ffda8335c0c7ab 100644 --- a/mindspore/model_zoo/mobilenetv2/scripts/run_train.sh +++ b/example/mobilenetv2_imagenet/scripts/run_train.sh @@ -42,6 +42,7 @@ run_ascend() --server_id=$3 \ --training_script=${BASEPATH}/../train.py \ --dataset_path=$5 \ + --pre_trained=$6 \ --platform=$1 &> ../train.log & # dataset train folder } @@ -73,14 +74,15 @@ run_gpu() python ${BASEPATH}/../train.py \ --dataset_path=$4 \ --platform=$1 \ + --pre_trained=$5 \ &> ../train.log & # dataset train folder } -if [ $# -gt 5 ] || [ $# -lt 4 ] +if [ $# -gt 6 ] || [ $# -lt 4 ] then echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ " exit 1 fi diff --git a/mindspore/model_zoo/mobilenetv2/src/config.py b/example/mobilenetv2_imagenet/src/config.py similarity index 100% rename from mindspore/model_zoo/mobilenetv2/src/config.py rename to example/mobilenetv2_imagenet/src/config.py diff --git a/mindspore/model_zoo/mobilenetv2/src/dataset.py b/example/mobilenetv2_imagenet/src/dataset.py similarity index 100% rename from mindspore/model_zoo/mobilenetv2/src/dataset.py rename to example/mobilenetv2_imagenet/src/dataset.py diff --git a/mindspore/model_zoo/mobilenetv2/src/launch.py b/example/mobilenetv2_imagenet/src/launch.py similarity index 100% rename from mindspore/model_zoo/mobilenetv2/src/launch.py rename to example/mobilenetv2_imagenet/src/launch.py diff --git a/mindspore/model_zoo/mobilenetv2/src/lr_generator.py b/example/mobilenetv2_imagenet/src/lr_generator.py similarity index 100% rename from mindspore/model_zoo/mobilenetv2/src/lr_generator.py rename to example/mobilenetv2_imagenet/src/lr_generator.py diff --git a/mindspore/model_zoo/mobilenetv2/train.py b/example/mobilenetv2_imagenet/train.py similarity index 99% rename from mindspore/model_zoo/mobilenetv2/train.py rename to example/mobilenetv2_imagenet/train.py index 775981f0304d0d3461b1a6c96d2be04dd2e9fdcf..9ba2d82966feae48748576d4e3cc0fc05c86ec88 100644 --- a/mindspore/model_zoo/mobilenetv2/train.py +++ b/example/mobilenetv2_imagenet/train.py @@ -33,11 +33,11 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.communication.management import init, get_group_size +from mindspore.model_zoo.mobilenetV2 import mobilenet_v2 import mindspore.dataset.engine as de from src.dataset import create_dataset from src.lr_generator import get_lr from src.config import config_gpu, config_ascend -from src.mobilenetV2 import mobilenet_v2 random.seed(1) np.random.seed(1) diff --git a/mindspore/model_zoo/mobilenetv3/Readme.md b/example/mobilenetv3_imagenet/Readme.md similarity index 97% rename from mindspore/model_zoo/mobilenetv3/Readme.md rename to example/mobilenetv3_imagenet/Readme.md index fa5ca1ae77659b32e858108c9b22313ea1d5b297..cbd3bbcc8adbe1487ef871f408096f4d52964427 100644 --- a/mindspore/model_zoo/mobilenetv3/Readme.md +++ b/example/mobilenetv3_imagenet/Readme.md @@ -13,7 +13,7 @@ The overall network architecture of MobileNetV3 is show below: # Dataset -Dataset used: [imagenet](http://www.image-net.org/) +Dataset used: imagenet - Dataset size: ~125G, 1.2W colorful images in 1000 classes - Train: 120G, 1.2W images @@ -67,8 +67,8 @@ Dataset used: [imagenet](http://www.image-net.org/) ``` # training example - Ascend: sh run_train.sh Ascend 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet/train/ - GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ + Ascend: sh run_train.sh Ascend 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet/train/ mobilenet_199.ckpt + GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ mobilenet_199.ckpt ``` ### Result diff --git a/mindspore/model_zoo/mobilenetv3/eval.py b/example/mobilenetv3_imagenet/eval.py similarity index 100% rename from mindspore/model_zoo/mobilenetv3/eval.py rename to example/mobilenetv3_imagenet/eval.py diff --git a/mindspore/model_zoo/mobilenetv3/scripts/run_infer.sh b/example/mobilenetv3_imagenet/scripts/run_infer.sh similarity index 100% rename from mindspore/model_zoo/mobilenetv3/scripts/run_infer.sh rename to example/mobilenetv3_imagenet/scripts/run_infer.sh diff --git a/mindspore/model_zoo/mobilenetv3/scripts/run_train.sh b/example/mobilenetv3_imagenet/scripts/run_train.sh similarity index 91% rename from mindspore/model_zoo/mobilenetv3/scripts/run_train.sh rename to example/mobilenetv3_imagenet/scripts/run_train.sh index 78b79b235fd459e9938bf57d9926c4c56b2a06d0..06e8b485335857a4e00a8a8edd7bc0e728d76fb0 100644 --- a/mindspore/model_zoo/mobilenetv3/scripts/run_train.sh +++ b/example/mobilenetv3_imagenet/scripts/run_train.sh @@ -41,6 +41,7 @@ run_ascend() --server_id=$3 \ --training_script=${BASEPATH}/../train.py \ --dataset_path=$5 \ + --pre_trained=$6 \ --platform=$1 &> ../train.log & # dataset train folder } @@ -72,14 +73,15 @@ run_gpu() python ${BASEPATH}/../train.py \ --dataset_path=$4 \ --platform=$1 \ + --pre_trained=$5 \ &> ../train.log & # dataset train folder } -if [ $# -gt 5 ] || [ $# -lt 4 ] +if [ $# -gt 6 ] || [ $# -lt 4 ] then echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ " exit 1 fi diff --git a/mindspore/model_zoo/mobilenetv3/src/config.py b/example/mobilenetv3_imagenet/src/config.py similarity index 100% rename from mindspore/model_zoo/mobilenetv3/src/config.py rename to example/mobilenetv3_imagenet/src/config.py diff --git a/mindspore/model_zoo/mobilenetv3/src/dataset.py b/example/mobilenetv3_imagenet/src/dataset.py similarity index 100% rename from mindspore/model_zoo/mobilenetv3/src/dataset.py rename to example/mobilenetv3_imagenet/src/dataset.py diff --git a/mindspore/model_zoo/mobilenetv3/src/launch.py b/example/mobilenetv3_imagenet/src/launch.py similarity index 100% rename from mindspore/model_zoo/mobilenetv3/src/launch.py rename to example/mobilenetv3_imagenet/src/launch.py diff --git a/mindspore/model_zoo/mobilenetv3/src/lr_generator.py b/example/mobilenetv3_imagenet/src/lr_generator.py similarity index 100% rename from mindspore/model_zoo/mobilenetv3/src/lr_generator.py rename to example/mobilenetv3_imagenet/src/lr_generator.py diff --git a/mindspore/model_zoo/mobilenetv3/train.py b/example/mobilenetv3_imagenet/train.py similarity index 99% rename from mindspore/model_zoo/mobilenetv3/train.py rename to example/mobilenetv3_imagenet/train.py index 724fed7cb84c1ec1be9c91d20f5e13eddbfecad7..478da10fc9c9ff8c9f173f15575cd5d396c6f53a 100644 --- a/mindspore/model_zoo/mobilenetv3/train.py +++ b/example/mobilenetv3_imagenet/train.py @@ -34,10 +34,10 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset.engine as de from mindspore.communication.management import init, get_group_size +from mindspore.model_zoo.mobilenetV3 import mobilenet_v3_large from src.dataset import create_dataset from src.lr_generator import get_lr from src.config import config_gpu, config_ascend -from src.mobilenetV3 import mobilenet_v3_large random.seed(1) np.random.seed(1) diff --git a/mindspore/model_zoo/mobilenet.py b/mindspore/model_zoo/mobilenet.py deleted file mode 100644 index 6539c3e2690073a70c9c6d16877d42b739ef7c56..0000000000000000000000000000000000000000 --- a/mindspore/model_zoo/mobilenet.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""MobileNetV2 model define""" -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.ops.operations import TensorAdd -from mindspore import Parameter, Tensor -from mindspore.common.initializer import initializer - -__all__ = ['MobileNetV2', 'mobilenet_v2'] - - -def _make_divisible(v, divisor, min_value=None): - """ - This function is taken from the original tf repo. - It ensures that all layers have a channel number that is divisible by 8 - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - :param v: - :param divisor: - :param min_value: - :return: - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class GlobalAvgPooling(nn.Cell): - """ - Global avg pooling definition. - - Args: - - Returns: - Tensor, output tensor. - - Examples: - >>> GlobalAvgPooling() - """ - def __init__(self): - super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(keep_dims=False) - - def construct(self, x): - x = self.mean(x, (2, 3)) - return x - - -class DepthwiseConv(nn.Cell): - """ - Depthwise Convolution warpper definition. - - Args: - in_planes (int): Input channel. - kernel_size (int): Input kernel size. - stride (int): Stride size. - pad_mode (str): pad mode in (pad, same, valid) - channel_multiplier (int): Output channel multiplier - has_bias (bool): has bias or not - - Returns: - Tensor, output tensor. - - Examples: - >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) - """ - def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): - super(DepthwiseConv, self).__init__() - self.has_bias = has_bias - self.in_channels = in_planes - self.channel_multiplier = channel_multiplier - self.out_channels = in_planes * channel_multiplier - self.kernel_size = (kernel_size, kernel_size) - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, - kernel_size=self.kernel_size, - stride=stride, pad_mode=pad_mode, pad=pad) - self.bias_add = P.BiasAdd() - weight_shape = [channel_multiplier, in_planes, *self.kernel_size] - self.weight = Parameter(initializer('ones', weight_shape), name='weight') - - if has_bias: - bias_shape = [channel_multiplier * in_planes] - self.bias = Parameter(initializer('zeros', bias_shape), name='bias') - else: - self.bias = None - - def construct(self, x): - output = self.depthwise_conv(x, self.weight) - if self.has_bias: - output = self.bias_add(output, self.bias) - return output - - -class ConvBNReLU(nn.Cell): - """ - Convolution/Depthwise fused with Batchnorm and ReLU block definition. - - Args: - in_planes (int): Input channel. - out_planes (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size for the first convolutional layer. Default: 1. - groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) - """ - def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - if groups == 1: - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', - padding=padding) - else: - conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) - layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] - self.features = nn.SequentialCell(layers) - - def construct(self, x): - output = self.features(x) - return output - - -class InvertedResidual(nn.Cell): - """ - Mobilenetv2 residual block definition. - - Args: - inp (int): Input channel. - oup (int): Output channel. - stride (int): Stride size for the first convolutional layer. Default: 1. - expand_ratio (int): expand ration of input channel - - Returns: - Tensor, output tensor. - - Examples: - >>> ResidualBlock(3, 256, 1, 1) - """ - def __init__(self, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - assert stride in [1, 2] - - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = stride == 1 and inp == oup - - layers = [] - if expand_ratio != 1: - layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) - layers.extend([ - # dw - ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), - # pw-linear - nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), - nn.BatchNorm2d(oup), - ]) - self.conv = nn.SequentialCell(layers) - self.add = TensorAdd() - self.cast = P.Cast() - - def construct(self, x): - identity = x - x = self.conv(x) - if self.use_res_connect: - return self.add(identity, x) - return x - - -class MobileNetV2(nn.Cell): - """ - MobileNetV2 architecture. - - Args: - class_num (Cell): number of classes. - width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. - has_dropout (bool): Is dropout used. Default is false - inverted_residual_setting (list): Inverted residual settings. Default is None - round_nearest (list): Channel round to . Default is 8 - Returns: - Tensor, output tensor. - - Examples: - >>> MobileNetV2(num_classes=1000) - """ - def __init__(self, num_classes=1000, width_mult=1., - has_dropout=False, inverted_residual_setting=None, round_nearest=8): - super(MobileNetV2, self).__init__() - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - # setting of inverted residual blocks - self.cfgs = inverted_residual_setting - if inverted_residual_setting is None: - self.cfgs = [ - # t, c, n, s - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # building first layer - input_channel = _make_divisible(input_channel * width_mult, round_nearest) - self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - features = [ConvBNReLU(3, input_channel, stride=2)] - # building inverted residual blocks - for t, c, n, s in self.cfgs: - output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): - stride = s if i == 0 else 1 - features.append(block(input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - # building last several layers - features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1)) - # make it nn.CellList - self.features = nn.SequentialCell(features) - # mobilenet head - head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else - [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) - self.head = nn.SequentialCell(head) - - self._initialize_weights() - - def construct(self, x): - x = self.features(x) - x = self.head(x) - return x - - def _initialize_weights(self): - """ - Initialize weights. - - Args: - - Returns: - None. - - Examples: - >>> _initialize_weights() - """ - for _, m in self.cells_and_names(): - if isinstance(m, (nn.Conv2d, DepthwiseConv)): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), - m.weight.data.shape()).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) - m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) - elif isinstance(m, nn.Dense): - m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape()).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) - - -def mobilenet_v2(**kwargs): - """ - Constructs a MobileNet V2 model - """ - return MobileNetV2(**kwargs) diff --git a/mindspore/model_zoo/mobilenetv2/src/mobilenetV2.py b/mindspore/model_zoo/mobilenetV2.py similarity index 100% rename from mindspore/model_zoo/mobilenetv2/src/mobilenetV2.py rename to mindspore/model_zoo/mobilenetV2.py diff --git a/mindspore/model_zoo/mobilenetv3/src/mobilenetV3.py b/mindspore/model_zoo/mobilenetV3.py similarity index 100% rename from mindspore/model_zoo/mobilenetv3/src/mobilenetV3.py rename to mindspore/model_zoo/mobilenetV3.py