diff --git a/example/mobilenetv2_quant/Readme.md b/example/mobilenetv2_quant/Readme.md
index e8ea05ba618006b25b921367a499d3b56f670d3a..5d8f55c394e360ca47ff4cafde1a49d8c6cd41a4 100644
--- a/example/mobilenetv2_quant/Readme.md
+++ b/example/mobilenetv2_quant/Readme.md
@@ -41,7 +41,7 @@ Dataset used: imagenet
 ## Script and sample code
 
 ```python
-├── MobileNetV2
+├── mobilenetv2_quant
   ├── Readme.md
   ├── scripts
   │   ├──run_train.sh
@@ -51,7 +51,7 @@ Dataset used: imagenet
   │   ├──dataset.py
   │   ├──luanch.py
   │   ├──lr_generator.py
-  │   ├──mobilenetV2.py
+  │   ├──mobilenetV2_quant.py
  ├── train.py
  ├── eval.py
 ```
diff --git a/example/mobilenetv2_quant/eval.py b/example/mobilenetv2_quant/eval.py
index d8e25ff93b96d356a7836364c7daeb4f034bb5b3..8513f15171e76613fe1432bf3619dc0dd893ba46 100644
--- a/example/mobilenetv2_quant/eval.py
+++ b/example/mobilenetv2_quant/eval.py
@@ -21,11 +21,9 @@ from mindspore import context
 from mindspore import nn
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.common import dtype as mstype
-from mindspore.model_zoo.mobilenetV2 import mobilenet_v2
+from src.mobilenetV2_quant import mobilenet_v2_quant
 from src.dataset import create_dataset
-from src.config import config_ascend, config_gpu
-
+from src.config import config_ascend
 
 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
@@ -33,7 +31,6 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
 parser.add_argument('--platform', type=str, default=None, help='run platform')
 args_opt = parser.parse_args()
 
-
 if __name__ == '__main__':
     config_platform = None
     net = None
@@ -42,24 +39,13 @@ if __name__ == '__main__':
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                             device_id=device_id, save_graphs=False)
-        net = mobilenet_v2(num_classes=config_platform.num_classes, platform="Ascend")
-    elif args_opt.platform == "GPU":
-        config_platform = config_gpu
-        context.set_context(mode=context.GRAPH_MODE,
-                            device_target="GPU", save_graphs=False)
-        net = mobilenet_v2(num_classes=config_platform.num_classes, platform="GPU")
+        net = mobilenet_v2_quant(num_classes=config_platform.num_classes)
     else:
         raise ValueError("Unsupport platform.")
 
     loss = nn.SoftmaxCrossEntropyWithLogits(
         is_grad=False, sparse=True, reduction='mean')
-    if args_opt.platform == "Ascend":
-        net.to_float(mstype.float16)
-        for _, cell in net.cells_and_names():
-            if isinstance(cell, nn.Dense):
-                cell.to_float(mstype.float32)
-
     dataset = create_dataset(dataset_path=args_opt.dataset_path,
                              do_train=False,
                              config=config_platform,
diff --git a/example/mobilenetv2_quant/scripts/run_infer.sh b/example/mobilenetv2_quant/scripts/run_infer.sh
index e200e600bfec0f8e725cdeaad4a8de3f7fe95f76..907b823a4759e5c3ca36c66145ea68953d0c988f 100644
--- a/example/mobilenetv2_quant/scripts/run_infer.sh
+++ b/example/mobilenetv2_quant/scripts/run_infer.sh
@@ -15,8 +15,7 @@
 # ============================================================================
 if [ $# != 3 ]
 then
-    echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH] \
-    GPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
+    echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
     exit 1
 fi
 
diff --git a/example/mobilenetv2_quant/scripts/run_train.sh b/example/mobilenetv2_quant/scripts/run_train.sh
index aabe09cf34e80cbbba61de32d8ffda8335c0c7ab..c18d03e418653c89c2375acfd7425ccf02dc32a8 100644
--- a/example/mobilenetv2_quant/scripts/run_train.sh
+++ b/example/mobilenetv2_quant/scripts/run_train.sh
@@ -82,15 +82,12 @@ if [ $# -gt 6 ] || [ $# -lt 4 ]
 then
     echo "Usage:\n \
           Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
-          GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
           "
     exit 1
 fi
 
 if [ $1 = "Ascend" ] ; then
     run_ascend "$@"
-elif [ $1 = "GPU" ] ; then
-    run_gpu "$@"
 else
     echo "not support platform"
 fi;
diff --git a/example/mobilenetv2_quant/src/config.py b/example/mobilenetv2_quant/src/config.py
index c8885336b2e7e85719e581c10544963f871ea275..8cbab844d4fe70a6d64b2a5e6834b0309b4bd4c1 100644
--- a/example/mobilenetv2_quant/src/config.py
+++ b/example/mobilenetv2_quant/src/config.py
@@ -21,10 +21,11 @@ config_ascend = ed({
     "num_classes": 1000,
     "image_height": 224,
     "image_width": 224,
-    "batch_size": 256,
-    "epoch_size": 200,
-    "warmup_epochs": 4,
-    "lr": 0.4,
+    "batch_size": 192,
+    "epoch_size": 40,
+    "start_epoch": 200,
+    "warmup_epochs": 1,
+    "lr": 0.15,
     "momentum": 0.9,
     "weight_decay": 4e-5,
     "label_smooth": 0.1,
diff --git a/example/mobilenetv2_quant/src/dataset.py b/example/mobilenetv2_quant/src/dataset.py
index 1edfcabfac7eca72344d3b930fa8ca037a76d8ed..e4d757ec0a383433ad337569947dad6a20227080 100644
--- a/example/mobilenetv2_quant/src/dataset.py
+++ b/example/mobilenetv2_quant/src/dataset.py
@@ -37,24 +37,31 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
     if platform == "Ascend":
         rank_size = int(os.getenv("RANK_SIZE"))
         rank_id = int(os.getenv("RANK_ID"))
-        if rank_size == 1:
-            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+        if do_train:
+            if rank_size == 1:
+                ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+            else:
+                ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+                                             num_shards=rank_size, shard_id=rank_id)
         else:
-            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
-                                         num_shards=rank_size, shard_id=rank_id)
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
     elif platform == "GPU":
         if do_train:
             from mindspore.communication.management import get_rank, get_group_size
             ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
                                          num_shards=get_group_size(), shard_id=get_rank())
         else:
-            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
     else:
         raise ValueError("Unsupport platform.")
 
     resize_height = config.image_height
     resize_width = config.image_width
-    buffer_size = 1000
+
+    if do_train:
+        buffer_size = 20480
+        # apply shuffle operations
+        ds = ds.shuffle(buffer_size=buffer_size)
 
     # define map operations
     decode_op = C.Decode()
@@ -63,12 +70,15 @@
     resize_op = C.Resize((256, 256))
     center_crop = C.CenterCrop(resize_width)
-    rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
+    random_color_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
     normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
     change_swap_op = C.HWC2CHW()
 
+    transform_uniform = [horizontal_flip_op, random_color_op]
+    uni_aug = C.UniformAugment(operations=transform_uniform, num_ops=2)
+
     if do_train:
-        trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op]
+        trans = [resize_crop_op, uni_aug, normalize_op, change_swap_op]
     else:
         trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
 
@@ -77,9 +87,6 @@
     ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
     ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
 
-    # apply shuffle operations
-    ds = ds.shuffle(buffer_size=buffer_size)
-
     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
 
diff --git a/example/mobilenetv2_quant/src/launch.py b/example/mobilenetv2_quant/src/launch.py
index 48c81596645d75628cecb0eb02acd23dd736bbea..52b0b0b3a87a04b9493cec79a1f3205a54e277fb 100644
--- a/example/mobilenetv2_quant/src/launch.py
+++ b/example/mobilenetv2_quant/src/launch.py
@@ -20,6 +20,7 @@ import subprocess
 import shutil
 from argparse import ArgumentParser
 
+
 def parse_args():
     """
     parse args .
@@ -79,7 +80,7 @@ def main():
             device_ips[device_id] = device_ip
             print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
     hccn_table = {}
-    hccn_table['board_id'] = '0x0000'
+    hccn_table['board_id'] = '0x0020'
     hccn_table['chip_info'] = '910'
     hccn_table['deploy_mode'] = 'lab'
     hccn_table['group_count'] = '1'
diff --git a/example/mobilenetv2_quant/src/mobilenetV2.py b/example/mobilenetv2_quant/src/mobilenetV2.py
deleted file mode 100644
index df35c5f3693164eba2638c301b0988ef504b2475..0000000000000000000000000000000000000000
--- a/example/mobilenetv2_quant/src/mobilenetV2.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""MobileNetV2 model define"""
-import numpy as np
-import mindspore.nn as nn
-from mindspore.ops import operations as P
-from mindspore.ops.operations import TensorAdd
-from mindspore import Parameter, Tensor
-from mindspore.common.initializer import initializer
-
-__all__ = ['mobilenet_v2']
-
-
-def _make_divisible(v, divisor, min_value=None):
-    if min_value is None:
-        min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
-class GlobalAvgPooling(nn.Cell):
-    """
-    Global avg pooling definition.
-
-    Args:
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> GlobalAvgPooling()
-    """
-
-    def __init__(self):
-        super(GlobalAvgPooling, self).__init__()
-        self.mean = P.ReduceMean(keep_dims=False)
-
-    def construct(self, x):
-        x = self.mean(x, (2, 3))
-        return x
-
-
-class DepthwiseConv(nn.Cell):
-    """
-    Depthwise Convolution warpper definition.
-
-    Args:
-        in_planes (int): Input channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size.
-        pad_mode (str): pad mode in (pad, same, valid)
-        channel_multiplier (int): Output channel multiplier
-        has_bias (bool): has bias or not
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
-    """
-
-    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
-        super(DepthwiseConv, self).__init__()
-        self.has_bias = has_bias
-        self.in_channels = in_planes
-        self.channel_multiplier = channel_multiplier
-        self.out_channels = in_planes * channel_multiplier
-        self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                                      kernel_size=self.kernel_size,
-                                                      stride=stride, pad_mode=pad_mode, pad=pad)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
-
-        if has_bias:
-            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
-        else:
-            self.bias = None
-
-    def construct(self, x):
-        output = self.depthwise_conv(x, self.weight)
-        if self.has_bias:
-            output = self.bias_add(output, self.bias)
-        return output
-
-
-class ConvBNReLU(nn.Cell):
-    """
-    Convolution/Depthwise fused with Batchnorm and ReLU block definition.
-
-    Args:
-        in_planes (int): Input channel.
-        out_planes (int): Output channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size for the first convolutional layer. Default: 1.
-        groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
-    """
-
-    def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
-        super(ConvBNReLU, self).__init__()
-        padding = (kernel_size - 1) // 2
-        if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
-        else:
-            if platform == "Ascend":
-                conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
-            elif platform == "GPU":
-                conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
-                                 group=in_planes, pad_mode='pad', padding=padding)
-
-        layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
-        self.features = nn.SequentialCell(layers)
-
-    def construct(self, x):
-        output = self.features(x)
-        return output
-
-
-class InvertedResidual(nn.Cell):
-    """
-    Mobilenetv2 residual block definition.
-
-    Args:
-        inp (int): Input channel.
-        oup (int): Output channel.
-        stride (int): Stride size for the first convolutional layer. Default: 1.
-        expand_ratio (int): expand ration of input channel
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> ResidualBlock(3, 256, 1, 1)
-    """
-
-    def __init__(self, platform, inp, oup, stride, expand_ratio):
-        super(InvertedResidual, self).__init__()
-        assert stride in [1, 2]
-
-        hidden_dim = int(round(inp * expand_ratio))
-        self.use_res_connect = stride == 1 and inp == oup
-
-        layers = []
-        if expand_ratio != 1:
-            layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1))
-        layers.extend([
-            # dw
-            ConvBNReLU(platform, hidden_dim, hidden_dim,
-                       stride=stride, groups=hidden_dim),
-            # pw-linear
-            nn.Conv2d(hidden_dim, oup, kernel_size=1,
-                      stride=1, has_bias=False),
-            nn.BatchNorm2d(oup),
-        ])
-        self.conv = nn.SequentialCell(layers)
-        self.add = TensorAdd()
-        self.cast = P.Cast()
-
-    def construct(self, x):
-        identity = x
-        x = self.conv(x)
-        if self.use_res_connect:
-            return self.add(identity, x)
-        return x
-
-
-class MobileNetV2(nn.Cell):
-    """
-    MobileNetV2 architecture.
-
-    Args:
-        class_num (Cell): number of classes.
-        width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
-        has_dropout (bool): Is dropout used. Default is false
-        inverted_residual_setting (list): Inverted residual settings. Default is None
-        round_nearest (list): Channel round to . Default is 8
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> MobileNetV2(num_classes=1000)
-    """
-
-    def __init__(self, platform, num_classes=1000, width_mult=1.,
-                 has_dropout=False, inverted_residual_setting=None, round_nearest=8):
-        super(MobileNetV2, self).__init__()
-        block = InvertedResidual
-        input_channel = 32
-        last_channel = 1280
-        # setting of inverted residual blocks
-        self.cfgs = inverted_residual_setting
-        if inverted_residual_setting is None:
-            self.cfgs = [
-                # t, c, n, s
-                [1, 16, 1, 1],
-                [6, 24, 2, 2],
-                [6, 32, 3, 2],
-                [6, 64, 4, 2],
-                [6, 96, 3, 1],
-                [6, 160, 3, 2],
-                [6, 320, 1, 1],
-            ]
-
-        # building first layer
-        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
-        self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features = [ConvBNReLU(platform, 3, input_channel, stride=2)]
-        # building inverted residual blocks
-        for t, c, n, s in self.cfgs:
-            output_channel = _make_divisible(c * width_mult, round_nearest)
-            for i in range(n):
-                stride = s if i == 0 else 1
-                features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t))
-                input_channel = output_channel
-        # building last several layers
-        features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1))
-        # make it nn.CellList
-        self.features = nn.SequentialCell(features)
-        # mobilenet head
-        head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else
-                [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
-        self.head = nn.SequentialCell(head)
-
-        self._initialize_weights()
-
-    def construct(self, x):
-        x = self.features(x)
-        x = self.head(x)
-        return x
-
-    def _initialize_weights(self):
-        """
-        Initialize weights.
-
-        Args:
-
-        Returns:
-            None.
-
-        Examples:
-            >>> _initialize_weights()
-        """
-        for _, m in self.cells_and_names():
-            if isinstance(m, (nn.Conv2d, DepthwiseConv)):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
-                                                                    m.weight.data.shape()).astype("float32")))
-                if m.bias is not None:
-                    m.bias.set_parameter_data(
-                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.gamma.set_parameter_data(
-                    Tensor(np.ones(m.gamma.data.shape(), dtype="float32")))
-                m.beta.set_parameter_data(
-                    Tensor(np.zeros(m.beta.data.shape(), dtype="float32")))
-            elif isinstance(m, nn.Dense):
-                m.weight.set_parameter_data(Tensor(np.random.normal(
-                    0, 0.01, m.weight.data.shape()).astype("float32")))
-                if m.bias is not None:
-                    m.bias.set_parameter_data(
-                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
-
-
-def mobilenet_v2(**kwargs):
-    """
-    Constructs a MobileNet V2 model
-    """
-    return MobileNetV2(**kwargs)
diff --git a/example/mobilenetv2_quant/src/mobilenetV2_quant.py b/example/mobilenetv2_quant/src/mobilenetV2_quant.py
index 7e8d2cb231b7b5cafa0cefc40c6fa25df83860c2..84679c967601f54af8d414cf125bcd4ae2be8802 100644
--- a/example/mobilenetv2_quant/src/mobilenetV2_quant.py
+++ b/example/mobilenetv2_quant/src/mobilenetV2_quant.py
@@ -13,18 +13,16 @@
 # limitations under the License.
 # ============================================================================
 """MobileNetV2 Quant model define"""
-import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops.operations import TensorAdd
-from mindspore import Parameter, Tensor
-from mindspore.common.initializer import initializer
 
 __all__ = ['mobilenet_v2_quant']
 
 _ema_decay = 0.999
 _symmetric = False
 
+
 def _make_divisible(v, divisor, min_value=None):
     if min_value is None:
         min_value = divisor
@@ -57,52 +55,6 @@ class GlobalAvgPooling(nn.Cell):
         return x
 
 
-class DepthwiseConv(nn.Cell):
-    """
-    Depthwise Convolution warpper definition.
-
-    Args:
-        in_planes (int): Input channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size.
-        pad_mode (str): pad mode in (pad, same, valid)
-        channel_multiplier (int): Output channel multiplier
-        has_bias (bool): has bias or not
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
-    """
-
-    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
-        super(DepthwiseConv, self).__init__()
-        self.has_bias = has_bias
-        self.in_channels = in_planes
-        self.channel_multiplier = channel_multiplier
-        self.out_channels = in_planes * channel_multiplier
-        self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                                      kernel_size=self.kernel_size,
-                                                      stride=stride, pad_mode=pad_mode, pad=pad)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
-
-        if has_bias:
-            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
-        else:
-            self.bias = None
-
-    def construct(self, x):
-        output = self.depthwise_conv(x, self.weight)
-        if self.has_bias:
-            output = self.bias_add(output, self.bias)
-        return output
-
-
 class ConvBNReLU(nn.Cell):
     """
     Convolution/Depthwise fused with Batchnorm and ReLU block definition.
@@ -121,21 +73,14 @@ class ConvBNReLU(nn.Cell):
         >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
     """
 
-    def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
         super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
-        if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
-        else:
-            if platform == "Ascend":
-                conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
-            elif platform == "GPU":
-                conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
-                                 group=in_planes, pad_mode='pad', padding=padding)
-
-        layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
+        conv = nn.Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
+                                       group=groups)
+        layers = [conv, nn.ReLU()]
         self.features = nn.SequentialCell(layers)
-        self.fake = nn.FakeQuantWithMinMax(in_planes, ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric, min_init=0)
 
     def construct(self, x):
         output = self.features(x)
@@ -160,7 +105,7 @@ class InvertedResidual(nn.Cell):
         >>> ResidualBlock(3, 256, 1, 1)
     """
 
-    def __init__(self, platform, inp, oup, stride, expand_ratio):
+    def __init__(self, inp, oup, stride, expand_ratio):
         super(InvertedResidual, self).__init__()
         assert stride in [1, 2]
 
@@ -169,19 +114,17 @@ class InvertedResidual(nn.Cell):
 
         layers = []
         if expand_ratio != 1:
-            layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1))
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
         layers.extend([
             # dw
-            ConvBNReLU(platform, hidden_dim, hidden_dim,
-                       stride=stride, groups=hidden_dim),
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
             # pw-linear
             nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1),
-            nn.FakeQuantWithMinMax(oup, ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+            nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
         ])
         self.conv = nn.SequentialCell(layers)
         self.add = TensorAdd()
-        self.add_fake = nn.FakeQuantWithMinMax(oup, ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
-        self.cast = P.Cast()
+        self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
 
     def construct(self, x):
         identity = x
@@ -209,7 +152,7 @@ class MobileNetV2Quant(nn.Cell):
         >>> MobileNetV2Quant(num_classes=1000)
     """
 
-    def __init__(self, platform, num_classes=1000, width_mult=1.,
+    def __init__(self, num_classes=1000, width_mult=1.,
                  has_dropout=False, inverted_residual_setting=None, round_nearest=8):
         super(MobileNetV2Quant, self).__init__()
         block = InvertedResidual
@@ -232,16 +175,17 @@ class MobileNetV2Quant(nn.Cell):
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features = [ConvBNReLU(platform, 3, input_channel, stride=2)]
+        self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+        features = [ConvBNReLU(3, input_channel, stride=2)]
         # building inverted residual blocks
         for t, c, n, s in self.cfgs:
             output_channel = _make_divisible(c * width_mult, round_nearest)
             for i in range(n):
                 stride = s if i == 0 else 1
-                features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t))
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                 input_channel = output_channel
         # building last several layers
-        features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1))
+        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
         # make it nn.CellList
         self.features = nn.SequentialCell(features)
         # mobilenet head
@@ -249,45 +193,12 @@ class MobileNetV2Quant(nn.Cell):
                 [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
         self.head = nn.SequentialCell(head)
 
-        self._initialize_weights()
-
     def construct(self, x):
+        x = self.input_fake(x)
         x = self.features(x)
         x = self.head(x)
         return x
 
-    def _initialize_weights(self):
-        """
-        Initialize weights.
-
-        Args:
-
-        Returns:
-            None.
-
-        Examples:
-            >>> _initialize_weights()
-        """
-        for _, m in self.cells_and_names():
-            if isinstance(m, (nn.Conv2d, DepthwiseConv)):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
-                                                                    m.weight.data.shape()).astype("float32")))
-                if m.bias is not None:
-                    m.bias.set_parameter_data(
-                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.gamma.set_parameter_data(
-                    Tensor(np.ones(m.gamma.data.shape(), dtype="float32")))
-                m.beta.set_parameter_data(
-                    Tensor(np.zeros(m.beta.data.shape(), dtype="float32")))
-            elif isinstance(m, nn.Dense):
-                m.weight.set_parameter_data(Tensor(np.random.normal(
-                    0, 0.01, m.weight.data.shape()).astype("float32")))
-                if m.bias is not None:
-                    m.bias.set_parameter_data(
-                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
-
 
 def mobilenet_v2_quant(**kwargs):
     """
diff --git a/example/mobilenetv2_quant/train.py b/example/mobilenetv2_quant/train.py
index d99514034030841e36b91639d3b67e15fd26188e..0491bdb25130b6f202a41e22c96c105e5b186866 100644
--- a/example/mobilenetv2_quant/train.py
+++ b/example/mobilenetv2_quant/train.py
@@ -30,14 +30,12 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init, get_group_size
+from mindspore.train.serialization import load_checkpoint
+from mindspore.communication.management import init
 import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
-from src.config import config_gpu, config_ascend
-from src.mobilenetV2 import mobilenet_v2
+from src.config import config_ascend
 from src.mobilenetV2_quant import mobilenet_v2_quant
 
 random.seed(1)
@@ -153,122 +151,87 @@ class Monitor(Callback):
                   np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
 
 
-if __name__ == '__main__':
-    if args_opt.platform == "GPU":
-        # train on gpu
-        print("train args: ", args_opt, "\ncfg: ", config_gpu)
-
-        init('nccl')
-        context.set_auto_parallel_context(parallel_mode="data_parallel",
-                                          mirror_mean=True,
-                                          device_num=get_group_size())
-
-        # define net
-        net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
-        # define loss
-        if config_gpu.label_smooth > 0:
-            loss = CrossEntropyWithLabelSmooth(
-                smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(
-                is_grad=False, sparse=True, reduction='mean')
-        # define dataset
-        epoch_size = config_gpu.epoch_size
-        dataset = create_dataset(dataset_path=args_opt.dataset_path,
-                                 do_train=True,
-                                 config=config_gpu,
-                                 platform=args_opt.platform,
-                                 repeat_num=epoch_size,
-                                 batch_size=config_gpu.batch_size)
-        step_size = dataset.get_dataset_size()
-        # resume
-        if args_opt.pre_trained:
-            param_dict = load_checkpoint(args_opt.pre_trained)
-            load_param_into_net(net, param_dict)
-        # define optimizer
-        loss_scale = FixedLossScaleManager(
-            config_gpu.loss_scale, drop_overflow_update=False)
-        lr = Tensor(get_lr(global_step=0,
-                           lr_init=0,
-                           lr_end=0,
-                           lr_max=config_gpu.lr,
-                           warmup_epochs=config_gpu.warmup_epochs,
-                           total_epochs=epoch_size,
-                           steps_per_epoch=step_size))
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
-                       config_gpu.weight_decay, config_gpu.loss_scale)
-        # define model
-        model = Model(net, loss_fn=loss, optimizer=opt,
-                      loss_scale_manager=loss_scale)
+def _load_param_into_net(ori_model, ckpt_param_dict):
+    """
+    load fp32 model parameters to quantization model.
+
+    Args:
+        ori_model: quantization model
+        ckpt_param_dict: f32 param
+
+    Returns:
+        None
+    """
+    iterable_dict = {
+        'weight': iter([item for item in ckpt_param_dict.items() if item[0].endswith('weight')]),
+        'bias': iter([item for item in ckpt_param_dict.items() if item[0].endswith('bias')]),
+        'gamma': iter([item for item in ckpt_param_dict.items() if item[0].endswith('gamma')]),
+        'beta': iter([item for item in ckpt_param_dict.items() if item[0].endswith('beta')]),
+        'moving_mean': iter([item for item in ckpt_param_dict.items() if item[0].endswith('moving_mean')]),
+        'moving_variance': iter(
+            [item for item in ckpt_param_dict.items() if item[0].endswith('moving_variance')]),
+        'minq': iter([item for item in ckpt_param_dict.items() if item[0].endswith('minq')]),
+        'maxq': iter([item for item in ckpt_param_dict.items() if item[0].endswith('maxq')])
+    }
+    for name, param in ori_model.parameters_and_names():
+        key_name = name.split(".")[-1]
+        if key_name not in iterable_dict.keys():
+            continue
+        value_param = next(iterable_dict[key_name], None)
+        if value_param is not None:
+            param.set_parameter_data(value_param[1].data)
+            print(f'init model param {name} with checkpoint param {value_param[0]}')
+
+
+if __name__ == '__main__':
+    # train on ascend
+    print("train args: ", args_opt, "\ncfg: ", config_ascend,
+          "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
+
+    if run_distribute:
+        context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          parameter_broadcast=True, mirror_mean=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()
+
+    epoch_size = config_ascend.epoch_size
+    net = mobilenet_v2_quant(num_classes=config_ascend.num_classes)
+    if config_ascend.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(
+            smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(
+            is_grad=False, sparse=True, reduction='mean')
+    dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                             do_train=True,
+                             config=config_ascend,
+                             platform=args_opt.platform,
+                             repeat_num=epoch_size,
+                             batch_size=config_ascend.batch_size)
+    step_size = dataset.get_dataset_size()
+    if args_opt.pre_trained:
+        param_dict = load_checkpoint(args_opt.pre_trained)
+        _load_param_into_net(net, param_dict)
+
+    lr = Tensor(get_lr(global_step=config_ascend.start_epoch * step_size,
+                       lr_init=0,
+                       lr_end=0,
+                       lr_max=config_ascend.lr,
+                       warmup_epochs=config_ascend.warmup_epochs,
+                       total_epochs=epoch_size + config_ascend.start_epoch,
+                       steps_per_epoch=step_size))
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum,
+                   config_ascend.weight_decay)
+
+    model = Model(net, loss_fn=loss, optimizer=opt)
+
+    cb = None
+    if rank_id == 0:
         cb = [Monitor(lr_init=lr.asnumpy())]
-        if config_gpu.save_checkpoint:
-            config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
-                                         keep_checkpoint_max=config_gpu.keep_checkpoint_max)
+        if config_ascend.save_checkpoint:
+            config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size,
+                                         keep_checkpoint_max=config_ascend.keep_checkpoint_max)
             ckpt_cb = ModelCheckpoint(
-                prefix="mobilenet", directory=config_gpu.save_checkpoint_path, config=config_ck)
+                prefix="mobilenet", directory=config_ascend.save_checkpoint_path, config=config_ck)
             cb += [ckpt_cb]
-        # begine train
-        model.train(epoch_size, dataset, callbacks=cb)
-    elif args_opt.platform == "Ascend":
-        # train on ascend
-        print("train args: ", args_opt, "\ncfg: ", config_ascend,
-              "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
-
-        if run_distribute:
-            context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              parameter_broadcast=True, mirror_mean=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-
-        epoch_size = config_ascend.epoch_size
-        net = mobilenet_v2(num_classes=config_ascend.num_classes, platform="Ascend")
-        net = mobilenet_v2_quant(num_classes=config_ascend.num_classes, platform="Ascend")
-        net.to_float(mstype.float16)
-        for _, cell in net.cells_and_names():
-            if isinstance(cell, nn.Dense):
-                cell.to_float(mstype.float32)
-        if config_ascend.label_smooth > 0:
-            loss = CrossEntropyWithLabelSmooth(
-                smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(
-                is_grad=False, sparse=True, reduction='mean')
-        dataset = create_dataset(dataset_path=args_opt.dataset_path,
-                                 do_train=True,
-                                 config=config_ascend,
-                                 platform=args_opt.platform,
-                                 repeat_num=epoch_size,
-                                 batch_size=config_ascend.batch_size)
-        step_size = dataset.get_dataset_size()
-        if args_opt.pre_trained:
-            param_dict = load_checkpoint(args_opt.pre_trained)
-            load_param_into_net(net, param_dict)
-
-        loss_scale = FixedLossScaleManager(
-            config_ascend.loss_scale, drop_overflow_update=False)
-        lr = Tensor(get_lr(global_step=0,
-                           lr_init=0,
-                           lr_end=0,
-                           lr_max=config_ascend.lr,
-                           warmup_epochs=config_ascend.warmup_epochs,
-                           total_epochs=epoch_size,
-                           steps_per_epoch=step_size))
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum,
-                       config_ascend.weight_decay, config_ascend.loss_scale)
-
-        model = Model(net, loss_fn=loss, optimizer=opt,
-                      loss_scale_manager=loss_scale)
-
-        cb = None
-        if rank_id == 0:
-            cb = [Monitor(lr_init=lr.asnumpy())]
-        if config_ascend.save_checkpoint:
-            config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size,
-                                         keep_checkpoint_max=config_ascend.keep_checkpoint_max)
-            ckpt_cb = ModelCheckpoint(
-                prefix="mobilenet", directory=config_ascend.save_checkpoint_path, config=config_ck)
-            cb += [ckpt_cb]
-
-        model.train(epoch_size, dataset, callbacks=cb)
-    else:
-        raise ValueError("Unsupport platform.")
+    model.train(epoch_size, dataset, callbacks=cb)
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index fdbaa01ff6fe7478495d4e431f494d34e2071bc6..af30af215e9782039e3d02cb39c33447fd52f669 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -69,11 +69,12 @@
     """
 
-    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
+    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0, freeze_bn_ascend=True):
         """init batch norm fold layer"""
         super(BatchNormFoldCell, self).__init__()
         self.epsilon = epsilon
         self.is_gpu = context.get_context('device_target') == "GPU"
+        self.freeze_bn_ascend = freeze_bn_ascend
         if self.is_gpu:
             self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
             self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
@@ -88,7 +89,7 @@
             else:
                 batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
         else:
-            if self.training:
+            if self.training and not self.freeze_bn_ascend:
                 x_sum, x_square_sum = self.bn_reduce(x)
                 _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
                     self.bn_update(x, x_sum, x_square_sum, mean, variance)
@@ -279,7 +280,8 @@
                  num_bits=8,
                  per_channel=False,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 freeze_bn_ascend=True):
         """init Conv2dBatchNormQuant layer"""
         super(Conv2dBatchNormQuant, self).__init__()
         self.in_channels = in_channels
@@ -300,6 +302,7 @@
         self.symmetric = symmetric
         self.narrow_range = narrow_range
         self.is_gpu = context.get_context('device_target') == "GPU"
+        self.freeze_bn_ascend = freeze_bn_ascend
 
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
@@ -398,7 +401,7 @@
             out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean,
                                              running_std, running_mean, self.step)
         else:
-            if self.training:
+            if self.training and not self.freeze_bn_ascend:
                 out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                 F.control_depend(out, self.assignadd(self.step, self.one))
             else:
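
Notes on the change set follow; the sketches are illustrative and not part of the patch.

The quant model drops the float Conv/BN/ReLU6 stack in favour of MindSpore's quantization-aware cells: `nn.Conv2dBatchNormQuant` simulates folding the batch norm into the convolution weights, and `nn.FakeQuantWithMinMax` observes activation ranges with an EMA and fake-quantizes them (`min_init=0` matches the non-negative output of the ReLU that replaces ReLU6). A minimal sketch of the pattern `ConvBNReLU` now builds, using the same cells and defaults the diff uses:

```python
import mindspore.nn as nn

_ema_decay = 0.999

class QuantConvBlock(nn.Cell):
    """Conv with simulated BN fold, ReLU, then activation fake-quant."""
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(QuantConvBlock, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride,
                                            pad_mode='pad', padding=padding, group=groups)
        self.relu = nn.ReLU()
        # min_init=0 because ReLU output is non-negative; ema tracks the running range
        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, min_init=0)

    def construct(self, x):
        return self.fake(self.relu(self.conv(x)))
```

Note also that the old code passed a positional channel argument (`in_planes`/`oup`) to `FakeQuantWithMinMax`; the new code constructs it with keyword arguments only, which is why every call site changed.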
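
On the data side, evaluation now reads with `shuffle=False`, and training gets an explicit `ds.shuffle(buffer_size=20480)` before the map stage instead of the old post-map shuffle with a 1000-element buffer; flip and color jitter are wrapped in `C.UniformAugment`, which samples from the given operations per image. A condensed sketch of the resulting training pipeline; `resize_crop_op` is defined earlier in `dataset.py` and not shown in the hunk, so `RandomCropDecodeResize` here is an assumption:

```python
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.common import dtype as mstype

def train_pipeline(dataset_path, batch_size=192):
    ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
    ds = ds.shuffle(buffer_size=20480)  # explicit, larger shuffle before the maps
    resize_crop_op = C.RandomCropDecodeResize(224)  # assumed; defined earlier in dataset.py
    uni_aug = C.UniformAugment(operations=[C.RandomHorizontalFlip(prob=0.5),
                                           C.RandomColorAdjust(brightness=0.4, contrast=0.4,
                                                               saturation=0.4)],
                               num_ops=2)
    normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                               std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    trans = [resize_crop_op, uni_aug, normalize_op, C.HWC2CHW()]
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
    ds = ds.map(input_columns="label", operations=C2.TypeCast(mstype.int32),
                num_parallel_workers=8)
    return ds.batch(batch_size, drop_remainder=True)
```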
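
`train.py` no longer uses `load_param_into_net` for `--pre_trained`: the fp32 checkpoint's parameter names do not match the renamed quant cells, so the new `_load_param_into_net` helper pairs parameters purely by suffix (`weight`, `bias`, `gamma`, and so on) in traversal order. That only works if the fp32 and quant networks enumerate those shared parameters in the same order, which holds here because both are built from the same layer sequence. A hypothetical resume flow (the checkpoint name is a placeholder):

```python
from mindspore.train.serialization import load_checkpoint
from src.mobilenetV2_quant import mobilenet_v2_quant

net = mobilenet_v2_quant(num_classes=1000)
param_dict = load_checkpoint("mobilenetv2_fp32-200.ckpt")  # placeholder path
_load_param_into_net(net, param_dict)  # helper defined in train.py
```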
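
The config changes read together with the new `get_lr` call: the quant fine-tune runs `epoch_size=40` epochs but positions them at the tail of a 240-epoch schedule by passing `global_step=start_epoch * step_size` and `total_epochs=epoch_size + start_epoch`, so the learning rate continues where the 200-epoch fp32 run left off rather than restarting at `lr=0.15`. A sketch of the offset logic, assuming the usual cosine-with-warmup shape of `src/lr_generator.py` (the real implementation may differ in detail):

```python
import numpy as np

def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """Cosine decay with linear warmup, returned from global_step onwards."""
    total_steps = total_epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            decay = 0.5 * (1 + np.cos(np.pi * (i - warmup_steps) / (total_steps - warmup_steps)))
            lr = lr_end + (lr_max - lr_end) * decay
        lr_each_step.append(lr)
    # slicing at global_step is what makes start_epoch a resume offset
    return np.array(lr_each_step[global_step:]).astype(np.float32)
```

Under this shape, a `global_step` offset of 200 epochs also skips the single warmup epoch entirely, so `warmup_epochs=1` only matters when training from scratch.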
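
Finally, the framework-side change: `BatchNormFoldCell` and `Conv2dBatchNormQuant` gain a `freeze_bn_ascend` flag, defaulting to `True`, which makes the Ascend path take the inference (frozen-statistics) branch even when `self.training` is set. That suits this example's use case of fine-tuning a quant model from an already-converged fp32 checkpoint whose BN statistics should not drift. To restore the old training-time behaviour on Ascend, construct the cell with the flag off:

```python
import mindspore.nn as nn

# freeze_bn_ascend=False re-enables BN statistics updates during training on Ascend
conv = nn.Conv2dBatchNormQuant(32, 64, kernel_size=3, stride=1,
                               pad_mode='pad', padding=1, group=1,
                               freeze_bn_ascend=False)
```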