Commit cc497424 authored by chenzupeng

Adapt for MobileNetV2 quantization-aware training in r0.3

Parent fb65a1a9
@@ -41,7 +41,7 @@ Dataset used: imagenet
 ## Script and sample code
 ```python
-├── MobileNetV2
+├── mobilenetv2_quant
 ├── Readme.md
 ├── scripts
 ├──run_train.sh
@@ -51,7 +51,7 @@ Dataset used: imagenet
 ├──dataset.py
 ├──launch.py
 ├──lr_generator.py
-├──mobilenetV2.py
+├──mobilenetV2_quant.py
 ├── train.py
 ├── eval.py
 ```
...
@@ -21,11 +21,9 @@ from mindspore import context
 from mindspore import nn
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.common import dtype as mstype
-from mindspore.model_zoo.mobilenetV2 import mobilenet_v2
+from src.mobilenetV2_quant import mobilenet_v2_quant
 from src.dataset import create_dataset
-from src.config import config_ascend, config_gpu
+from src.config import config_ascend

 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
@@ -33,7 +31,6 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
 parser.add_argument('--platform', type=str, default=None, help='run platform')
 args_opt = parser.parse_args()

 if __name__ == '__main__':
     config_platform = None
     net = None
@@ -42,24 +39,13 @@ if __name__ == '__main__':
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                             device_id=device_id, save_graphs=False)
-        net = mobilenet_v2(num_classes=config_platform.num_classes, platform="Ascend")
-    elif args_opt.platform == "GPU":
-        config_platform = config_gpu
-        context.set_context(mode=context.GRAPH_MODE,
-                            device_target="GPU", save_graphs=False)
-        net = mobilenet_v2(num_classes=config_platform.num_classes, platform="GPU")
+        net = mobilenet_v2_quant(num_classes=config_platform.num_classes)
     else:
         raise ValueError("Unsupport platform.")

     loss = nn.SoftmaxCrossEntropyWithLogits(
         is_grad=False, sparse=True, reduction='mean')

-    if args_opt.platform == "Ascend":
-        net.to_float(mstype.float16)
-        for _, cell in net.cells_and_names():
-            if isinstance(cell, nn.Dense):
-                cell.to_float(mstype.float32)
     dataset = create_dataset(dataset_path=args_opt.dataset_path,
                              do_train=False,
                              config=config_platform,
...
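For context, the evaluation entry above typically finishes by wrapping the quantized network in a `Model` and running `eval` on the validation split. The snippet below is a minimal sketch of that tail end, not the verbatim remainder of eval.py: it reuses `net`, `loss`, `config_platform`, and `args_opt` from the hunk above and assumes `Model`'s standard `metrics`/`eval` interface.

```python
# Minimal sketch (assumed, not the verbatim rest of eval.py): build the eval
# dataset, restore the checkpoint, and report accuracy.
dataset = create_dataset(dataset_path=args_opt.dataset_path,
                         do_train=False,
                         config=config_platform,
                         platform=args_opt.platform,
                         batch_size=config_platform.batch_size)

if args_opt.checkpoint_path:
    param_dict = load_checkpoint(args_opt.checkpoint_path)  # read the quant-aware checkpoint
    load_param_into_net(net, param_dict)                    # copy parameters into the network
net.set_train(False)

model = Model(net, loss_fn=loss, metrics={'acc'})
res = model.eval(dataset)
print("result:", res, "ckpt:", args_opt.checkpoint_path)
```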
@@ -15,8 +15,7 @@
 # ============================================================================
 if [ $# != 3 ]
 then
-    echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH] \
-    GPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
+    echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
     exit 1
 fi
...
@@ -82,15 +82,12 @@ if [ $# -gt 6 ] || [ $# -lt 4 ]
 then
     echo "Usage:\n \
           Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
-          GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
           "
     exit 1
 fi

 if [ $1 = "Ascend" ] ; then
     run_ascend "$@"
-elif [ $1 = "GPU" ] ; then
-    run_gpu "$@"
 else
     echo "not support platform"
 fi;
...
@@ -21,10 +21,11 @@ config_ascend = ed({
     "num_classes": 1000,
     "image_height": 224,
     "image_width": 224,
-    "batch_size": 256,
-    "epoch_size": 200,
-    "warmup_epochs": 4,
-    "lr": 0.4,
+    "batch_size": 192,
+    "epoch_size": 40,
+    "start_epoch": 200,
+    "warmup_epochs": 1,
+    "lr": 0.15,
     "momentum": 0.9,
     "weight_decay": 4e-5,
     "label_smooth": 0.1,
...
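Read together, these config changes describe a short fine-tuning schedule: the quantization-aware run starts from a model already trained for `start_epoch` (200) FP32 epochs, then trains 40 more epochs with a smaller batch size and learning rate. Below is a minimal sketch of how such an `easydict`-style config is assembled and consumed; the `EasyDict` import is an assumption about config.py, and the checkpoint-related values at the bottom are placeholders (the keys themselves are referenced by train.py elsewhere in this commit).

```python
from easydict import EasyDict as ed  # assumed import; config.py builds config_ascend with ed({...})

# Quantization-aware fine-tuning hyperparameters from the hunk above.
config_ascend = ed({
    "num_classes": 1000,
    "image_height": 224,
    "image_width": 224,
    "batch_size": 192,        # was 256 for the FP32 run
    "epoch_size": 40,         # short fine-tuning run on top of the FP32 model
    "start_epoch": 200,       # FP32 epochs already trained; offsets the LR schedule
    "warmup_epochs": 1,       # was 4
    "lr": 0.15,               # was 0.4
    "momentum": 0.9,
    "weight_decay": 4e-5,
    "label_smooth": 0.1,
    "save_checkpoint": True,                 # value assumed
    "save_checkpoint_epochs": 1,             # value assumed
    "keep_checkpoint_max": 200,              # value assumed
    "save_checkpoint_path": "./checkpoint",  # value assumed
})

print(config_ascend.lr, config_ascend.batch_size)
```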
@@ -37,24 +37,31 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
     if platform == "Ascend":
         rank_size = int(os.getenv("RANK_SIZE"))
         rank_id = int(os.getenv("RANK_ID"))
-        if rank_size == 1:
-            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+        if do_train:
+            if rank_size == 1:
+                ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+            else:
+                ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+                                             num_shards=rank_size, shard_id=rank_id)
         else:
-            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
-                                         num_shards=rank_size, shard_id=rank_id)
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
     elif platform == "GPU":
         if do_train:
             from mindspore.communication.management import get_rank, get_group_size
             ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
                                          num_shards=get_group_size(), shard_id=get_rank())
         else:
-            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
     else:
         raise ValueError("Unsupport platform.")

     resize_height = config.image_height
     resize_width = config.image_width
+    buffer_size = 1000
+    if do_train:
+        buffer_size = 20480
+        # apply shuffle operations
+        ds = ds.shuffle(buffer_size=buffer_size)

     # define map operations
     decode_op = C.Decode()
@@ -63,12 +70,15 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
     resize_op = C.Resize((256, 256))
     center_crop = C.CenterCrop(resize_width)
-    rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
+    random_color_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
     normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
     change_swap_op = C.HWC2CHW()
+    transform_uniform = [horizontal_flip_op, random_color_op]
+    uni_aug = C.UniformAugment(operations=transform_uniform, num_ops=2)
     if do_train:
-        trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op]
+        trans = [resize_crop_op, uni_aug, normalize_op, change_swap_op]
     else:
         trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
@@ -77,9 +87,6 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
     ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
     ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)

-    # apply shuffle operations
-    ds = ds.shuffle(buffer_size=buffer_size)

     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
...
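With these changes, shuffling only happens in the training pipeline (with a 20480-sample buffer applied right after sharding), evaluation data is read in order, and the training augmentation becomes the crop op plus a 2-op UniformAugment over horizontal flip and color jitter. A minimal usage sketch of the resulting `create_dataset` interface follows; `config_ascend` and the environment variables come from elsewhere in this commit, while the dataset paths are placeholders.

```python
import os
from src.config import config_ascend    # config from this commit
from src.dataset import create_dataset  # function shown in the hunk above

# RANK_SIZE / RANK_ID are read inside create_dataset on the Ascend branch.
os.environ.setdefault("RANK_SIZE", "1")
os.environ.setdefault("RANK_ID", "0")

# Training pipeline: sharded, shuffled (buffer_size=20480), UniformAugment.
train_ds = create_dataset(dataset_path="/path/to/imagenet/train",  # placeholder path
                          do_train=True,
                          config=config_ascend,
                          platform="Ascend",
                          repeat_num=config_ascend.epoch_size,
                          batch_size=config_ascend.batch_size)

# Eval pipeline: no shuffle, deterministic resize + center crop.
eval_ds = create_dataset(dataset_path="/path/to/imagenet/val",     # placeholder path
                         do_train=False,
                         config=config_ascend,
                         platform="Ascend",
                         batch_size=config_ascend.batch_size)

print("train batches per epoch:", train_ds.get_dataset_size())
```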
@@ -20,6 +20,7 @@ import subprocess
 import shutil
 from argparse import ArgumentParser
+

 def parse_args():
     """
     parse args .
@@ -79,7 +80,7 @@ def main():
         device_ips[device_id] = device_ip
         print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
     hccn_table = {}
-    hccn_table['board_id'] = '0x0000'
+    hccn_table['board_id'] = '0x0020'
     hccn_table['chip_info'] = '910'
     hccn_table['deploy_mode'] = 'lab'
     hccn_table['group_count'] = '1'
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2 model define"""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops.operations import TensorAdd
from mindspore import Parameter, Tensor
from mindspore.common.initializer import initializer
__all__ = ['mobilenet_v2']
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class GlobalAvgPooling(nn.Cell):
"""
Global avg pooling definition.
Args:
Returns:
Tensor, output tensor.
Examples:
>>> GlobalAvgPooling()
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(keep_dims=False)
def construct(self, x):
x = self.mean(x, (2, 3))
return x
class DepthwiseConv(nn.Cell):
"""
Depthwise Convolution warpper definition.
Args:
in_planes (int): Input channel.
kernel_size (int): Input kernel size.
stride (int): Stride size.
pad_mode (str): pad mode in (pad, same, valid)
channel_multiplier (int): Output channel multiplier
has_bias (bool): has bias or not
Returns:
Tensor, output tensor.
Examples:
>>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
"""
def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
super(DepthwiseConv, self).__init__()
self.has_bias = has_bias
self.in_channels = in_planes
self.channel_multiplier = channel_multiplier
self.out_channels = in_planes * channel_multiplier
self.kernel_size = (kernel_size, kernel_size)
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
kernel_size=self.kernel_size,
stride=stride, pad_mode=pad_mode, pad=pad)
self.bias_add = P.BiasAdd()
weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
self.weight = Parameter(initializer('ones', weight_shape), name='weight')
if has_bias:
bias_shape = [channel_multiplier * in_planes]
self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
else:
self.bias = None
def construct(self, x):
output = self.depthwise_conv(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output
class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
if groups == 1:
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
else:
if platform == "Ascend":
conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
elif platform == "GPU":
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
group=in_planes, pad_mode='pad', padding=padding)
layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class InvertedResidual(nn.Cell):
"""
Mobilenetv2 residual block definition.
Args:
inp (int): Input channel.
oup (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
expand_ratio (int): expand ration of input channel
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, 1, 1)
"""
def __init__(self, platform, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(platform, hidden_dim, hidden_dim,
stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, kernel_size=1,
stride=1, has_bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.SequentialCell(layers)
self.add = TensorAdd()
self.cast = P.Cast()
def construct(self, x):
identity = x
x = self.conv(x)
if self.use_res_connect:
return self.add(identity, x)
return x
class MobileNetV2(nn.Cell):
"""
MobileNetV2 architecture.
Args:
class_num (Cell): number of classes.
width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
has_dropout (bool): Is dropout used. Default is false
inverted_residual_setting (list): Inverted residual settings. Default is None
round_nearest (list): Channel round to . Default is 8
Returns:
Tensor, output tensor.
Examples:
>>> MobileNetV2(num_classes=1000)
"""
def __init__(self, platform, num_classes=1000, width_mult=1.,
has_dropout=False, inverted_residual_setting=None, round_nearest=8):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
# setting of inverted residual blocks
self.cfgs = inverted_residual_setting
if inverted_residual_setting is None:
self.cfgs = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(platform, 3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in self.cfgs:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1))
# make it nn.CellList
self.features = nn.SequentialCell(features)
# mobilenet head
head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
self.head = nn.SequentialCell(head)
self._initialize_weights()
def construct(self, x):
x = self.features(x)
x = self.head(x)
return x
def _initialize_weights(self):
"""
Initialize weights.
Args:
Returns:
None.
Examples:
>>> _initialize_weights()
"""
for _, m in self.cells_and_names():
if isinstance(m, (nn.Conv2d, DepthwiseConv)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape()).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data(
Tensor(np.ones(m.gamma.data.shape(), dtype="float32")))
m.beta.set_parameter_data(
Tensor(np.zeros(m.beta.data.shape(), dtype="float32")))
elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape()).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
def mobilenet_v2(**kwargs):
"""
Constructs a MobileNet V2 model
"""
return MobileNetV2(**kwargs)
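The listing above is the plain FP32 MobileNetV2 definition (`mobilenetV2.py`), which the updated script tree replaces with `mobilenetV2_quant.py`. Two details from it are worth a quick illustration: how `_make_divisible` rounds channel counts, and how the factory was called by the old scripts. The snippet below is a small sketch under the assumption that the file is importable as `src.mobilenetV2`; the numeric results just evaluate the function defined above.

```python
# Channel-rounding rule used throughout the model: round to the nearest
# multiple of `divisor`, but never round down by more than 10%.
from src.mobilenetV2 import mobilenet_v2, _make_divisible  # assumed module path

print(_make_divisible(32 * 1.0, 8))    # 32   (already a multiple of 8)
print(_make_divisible(24 * 0.75, 8))   # 24   (16 would be a >10% drop from 18, so round up)
print(_make_divisible(1280 * 1.0, 8))  # 1280

# Construct the FP32 backbone the way the old train.py/eval.py did:
net = mobilenet_v2(num_classes=1000, platform="Ascend")
print(net.out_channels)                # 1280 feature channels before the classifier head
```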
@@ -13,18 +13,16 @@
 # limitations under the License.
 # ============================================================================
 """MobileNetV2 Quant model define"""
-import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops.operations import TensorAdd
-from mindspore import Parameter, Tensor
-from mindspore.common.initializer import initializer

 __all__ = ['mobilenet_v2_quant']

 _ema_decay = 0.999
 _symmetric = False

 def _make_divisible(v, divisor, min_value=None):
     if min_value is None:
         min_value = divisor
@@ -57,52 +55,6 @@ class GlobalAvgPooling(nn.Cell):
         return x

-class DepthwiseConv(nn.Cell):
-    """
-    Depthwise Convolution warpper definition.
-
-    Args:
-        in_planes (int): Input channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size.
-        pad_mode (str): pad mode in (pad, same, valid)
-        channel_multiplier (int): Output channel multiplier
-        has_bias (bool): has bias or not
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
-    """
-
-    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
-        super(DepthwiseConv, self).__init__()
-        self.has_bias = has_bias
-        self.in_channels = in_planes
-        self.channel_multiplier = channel_multiplier
-        self.out_channels = in_planes * channel_multiplier
-        self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                                      kernel_size=self.kernel_size,
-                                                      stride=stride, pad_mode=pad_mode, pad=pad)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
-        if has_bias:
-            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
-        else:
-            self.bias = None
-
-    def construct(self, x):
-        output = self.depthwise_conv(x, self.weight)
-        if self.has_bias:
-            output = self.bias_add(output, self.bias)
-        return output

 class ConvBNReLU(nn.Cell):
     """
     Convolution/Depthwise fused with Batchnorm and ReLU block definition.
@@ -121,21 +73,14 @@ class ConvBNReLU(nn.Cell):
         >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
     """

-    def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
         super(ConvBNReLU, self).__init__()
         padding = (kernel_size - 1) // 2
-        if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
-        else:
-            if platform == "Ascend":
-                conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
-            elif platform == "GPU":
-                conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
-                                 group=in_planes, pad_mode='pad', padding=padding)
-        layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
+        conv = nn.Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
+                                       group=groups)
+        layers = [conv, nn.ReLU()]
         self.features = nn.SequentialCell(layers)
-        self.fake = nn.FakeQuantWithMinMax(in_planes, ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric, min_init=0)

     def construct(self, x):
         output = self.features(x)
@@ -160,7 +105,7 @@ class InvertedResidual(nn.Cell):
         >>> ResidualBlock(3, 256, 1, 1)
     """

-    def __init__(self, platform, inp, oup, stride, expand_ratio):
+    def __init__(self, inp, oup, stride, expand_ratio):
         super(InvertedResidual, self).__init__()
         assert stride in [1, 2]
@@ -169,19 +114,17 @@ class InvertedResidual(nn.Cell):
         layers = []
         if expand_ratio != 1:
-            layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1))
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
         layers.extend([
             # dw
-            ConvBNReLU(platform, hidden_dim, hidden_dim,
-                       stride=stride, groups=hidden_dim),
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
             # pw-linear
             nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1),
-            nn.FakeQuantWithMinMax(oup, ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+            nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
         ])
         self.conv = nn.SequentialCell(layers)
         self.add = TensorAdd()
-        self.add_fake = nn.FakeQuantWithMinMax(oup, ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
-        self.cast = P.Cast()
+        self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)

     def construct(self, x):
         identity = x
@@ -209,7 +152,7 @@ class MobileNetV2Quant(nn.Cell):
         >>> MobileNetV2Quant(num_classes=1000)
     """

-    def __init__(self, platform, num_classes=1000, width_mult=1.,
+    def __init__(self, num_classes=1000, width_mult=1.,
                  has_dropout=False, inverted_residual_setting=None, round_nearest=8):
         super(MobileNetV2Quant, self).__init__()
         block = InvertedResidual
@@ -232,16 +175,17 @@ class MobileNetV2Quant(nn.Cell):
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features = [ConvBNReLU(platform, 3, input_channel, stride=2)]
+        self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+        features = [ConvBNReLU(3, input_channel, stride=2)]
         # building inverted residual blocks
         for t, c, n, s in self.cfgs:
             output_channel = _make_divisible(c * width_mult, round_nearest)
             for i in range(n):
                 stride = s if i == 0 else 1
-                features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t))
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                 input_channel = output_channel
         # building last several layers
-        features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1))
+        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
         # make it nn.CellList
         self.features = nn.SequentialCell(features)
         # mobilenet head
@@ -249,45 +193,12 @@
             [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
         self.head = nn.SequentialCell(head)
-        self._initialize_weights()

     def construct(self, x):
+        x = self.input_fake(x)
         x = self.features(x)
         x = self.head(x)
         return x

-    def _initialize_weights(self):
-        """
-        Initialize weights.
-
-        Args:
-
-        Returns:
-            None.
-
-        Examples:
-            >>> _initialize_weights()
-        """
-        for _, m in self.cells_and_names():
-            if isinstance(m, (nn.Conv2d, DepthwiseConv)):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
-                                                                     m.weight.data.shape()).astype("float32")))
-                if m.bias is not None:
-                    m.bias.set_parameter_data(
-                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.gamma.set_parameter_data(
-                    Tensor(np.ones(m.gamma.data.shape(), dtype="float32")))
-                m.beta.set_parameter_data(
-                    Tensor(np.zeros(m.beta.data.shape(), dtype="float32")))
-            elif isinstance(m, nn.Dense):
-                m.weight.set_parameter_data(Tensor(np.random.normal(
-                    0, 0.01, m.weight.data.shape()).astype("float32")))
-                if m.bias is not None:
-                    m.bias.set_parameter_data(
-                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))

 def mobilenet_v2_quant(**kwargs):
     """
...
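The net effect of this hunk is that the quantized network becomes platform-agnostic: the hand-rolled `DepthwiseConv` and the per-platform branching disappear, every conv+BN pair is expressed through `nn.Conv2dBatchNormQuant` (with `group=groups` covering the depthwise case), and `nn.FakeQuantWithMinMax` observers are attached to activations, residual adds, and now also the network input. A minimal construction sketch under those assumptions is shown below; it presumes the module is importable as `src.mobilenetV2_quant` and that the forward pass runs on a machine with the configured backend.

```python
import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from src.mobilenetV2_quant import mobilenet_v2_quant  # factory shown above

# GRAPH_MODE on Ascend is what train.py/eval.py use; the fake-quant/BN-fold
# kernels must be available on the chosen backend.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = mobilenet_v2_quant(num_classes=1000)   # no platform argument any more
net.set_train(False)

x = Tensor(np.random.rand(1, 3, 224, 224), mstype.float32)
logits = net(x)                 # input_fake -> features -> head
print(logits.shape())           # (1, 1000); shape() is called as a method, as elsewhere in this repo
```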
@@ -30,14 +30,12 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init, get_group_size
+from mindspore.train.serialization import load_checkpoint
+from mindspore.communication.management import init
 import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
-from src.config import config_gpu, config_ascend
-from src.mobilenetV2 import mobilenet_v2
+from src.config import config_ascend
 from src.mobilenetV2_quant import mobilenet_v2_quant

 random.seed(1)
@@ -153,122 +151,87 @@ class Monitor(Callback):
                            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

-if __name__ == '__main__':
-    if args_opt.platform == "GPU":
-        # train on gpu
-        print("train args: ", args_opt, "\ncfg: ", config_gpu)
-
-        init('nccl')
-        context.set_auto_parallel_context(parallel_mode="data_parallel",
-                                          mirror_mean=True,
-                                          device_num=get_group_size())
-
-        # define net
-        net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
-        # define loss
-        if config_gpu.label_smooth > 0:
-            loss = CrossEntropyWithLabelSmooth(
-                smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(
-                is_grad=False, sparse=True, reduction='mean')
-        # define dataset
-        epoch_size = config_gpu.epoch_size
-        dataset = create_dataset(dataset_path=args_opt.dataset_path,
-                                 do_train=True,
-                                 config=config_gpu,
-                                 platform=args_opt.platform,
-                                 repeat_num=epoch_size,
-                                 batch_size=config_gpu.batch_size)
-        step_size = dataset.get_dataset_size()
-        # resume
-        if args_opt.pre_trained:
-            param_dict = load_checkpoint(args_opt.pre_trained)
-            load_param_into_net(net, param_dict)
-        # define optimizer
-        loss_scale = FixedLossScaleManager(
-            config_gpu.loss_scale, drop_overflow_update=False)
-        lr = Tensor(get_lr(global_step=0,
-                           lr_init=0,
-                           lr_end=0,
-                           lr_max=config_gpu.lr,
-                           warmup_epochs=config_gpu.warmup_epochs,
-                           total_epochs=epoch_size,
-                           steps_per_epoch=step_size))
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
-                       config_gpu.weight_decay, config_gpu.loss_scale)
-        # define model
-        model = Model(net, loss_fn=loss, optimizer=opt,
-                      loss_scale_manager=loss_scale)
-        cb = None
-        if rank_id == 0:
-            cb = [Monitor(lr_init=lr.asnumpy())]
-            if config_gpu.save_checkpoint:
-                config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
-                                             keep_checkpoint_max=config_gpu.keep_checkpoint_max)
-                ckpt_cb = ModelCheckpoint(
-                    prefix="mobilenet", directory=config_gpu.save_checkpoint_path, config=config_ck)
-                cb += [ckpt_cb]
-        # begine train
-        model.train(epoch_size, dataset, callbacks=cb)
-    elif args_opt.platform == "Ascend":
-        # train on ascend
-        print("train args: ", args_opt, "\ncfg: ", config_ascend,
-              "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
-        if run_distribute:
-            context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              parameter_broadcast=True, mirror_mean=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        epoch_size = config_ascend.epoch_size
-        net = mobilenet_v2(num_classes=config_ascend.num_classes, platform="Ascend")
-        net = mobilenet_v2_quant(num_classes=config_ascend.num_classes, platform="Ascend")
-        net.to_float(mstype.float16)
-        for _, cell in net.cells_and_names():
-            if isinstance(cell, nn.Dense):
-                cell.to_float(mstype.float32)
-        if config_ascend.label_smooth > 0:
-            loss = CrossEntropyWithLabelSmooth(
-                smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(
-                is_grad=False, sparse=True, reduction='mean')
-        dataset = create_dataset(dataset_path=args_opt.dataset_path,
-                                 do_train=True,
-                                 config=config_ascend,
-                                 platform=args_opt.platform,
-                                 repeat_num=epoch_size,
-                                 batch_size=config_ascend.batch_size)
-        step_size = dataset.get_dataset_size()
-        if args_opt.pre_trained:
-            param_dict = load_checkpoint(args_opt.pre_trained)
-            load_param_into_net(net, param_dict)
-        loss_scale = FixedLossScaleManager(
-            config_ascend.loss_scale, drop_overflow_update=False)
-        lr = Tensor(get_lr(global_step=0,
-                           lr_init=0,
-                           lr_end=0,
-                           lr_max=config_ascend.lr,
-                           warmup_epochs=config_ascend.warmup_epochs,
-                           total_epochs=epoch_size,
-                           steps_per_epoch=step_size))
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum,
-                       config_ascend.weight_decay, config_ascend.loss_scale)
-        model = Model(net, loss_fn=loss, optimizer=opt,
-                      loss_scale_manager=loss_scale)
-        cb = None
-        if rank_id == 0:
-            cb = [Monitor(lr_init=lr.asnumpy())]
-            if config_ascend.save_checkpoint:
-                config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size,
-                                             keep_checkpoint_max=config_ascend.keep_checkpoint_max)
-                ckpt_cb = ModelCheckpoint(
-                    prefix="mobilenet", directory=config_ascend.save_checkpoint_path, config=config_ck)
-                cb += [ckpt_cb]
-        model.train(epoch_size, dataset, callbacks=cb)
-    else:
-        raise ValueError("Unsupport platform.")
+def _load_param_into_net(ori_model, ckpt_param_dict):
+    """
+    load fp32 model parameters to quantization model.
+
+    Args:
+        ori_model: quantization model
+        ckpt_param_dict: f32 param
+
+    Returns:
+        None
+    """
+    iterable_dict = {
+        'weight': iter([item for item in ckpt_param_dict.items() if item[0].endswith('weight')]),
+        'bias': iter([item for item in ckpt_param_dict.items() if item[0].endswith('bias')]),
+        'gamma': iter([item for item in ckpt_param_dict.items() if item[0].endswith('gamma')]),
+        'beta': iter([item for item in ckpt_param_dict.items() if item[0].endswith('beta')]),
+        'moving_mean': iter([item for item in ckpt_param_dict.items() if item[0].endswith('moving_mean')]),
+        'moving_variance': iter(
+            [item for item in ckpt_param_dict.items() if item[0].endswith('moving_variance')]),
+        'minq': iter([item for item in ckpt_param_dict.items() if item[0].endswith('minq')]),
+        'maxq': iter([item for item in ckpt_param_dict.items() if item[0].endswith('maxq')])
+    }
+    for name, param in ori_model.parameters_and_names():
+        key_name = name.split(".")[-1]
+        if key_name not in iterable_dict.keys():
+            continue
+        value_param = next(iterable_dict[key_name], None)
+        if value_param is not None:
+            param.set_parameter_data(value_param[1].data)
+            print(f'init model param {name} with checkpoint param {value_param[0]}')
+
+
+if __name__ == '__main__':
+    # train on ascend
+    print("train args: ", args_opt, "\ncfg: ", config_ascend,
+          "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
+    if run_distribute:
+        context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          parameter_broadcast=True, mirror_mean=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()
+    epoch_size = config_ascend.epoch_size
+    net = mobilenet_v2_quant(num_classes=config_ascend.num_classes)
+    if config_ascend.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(
+            smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(
+            is_grad=False, sparse=True, reduction='mean')
+    dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                             do_train=True,
+                             config=config_ascend,
+                             platform=args_opt.platform,
+                             repeat_num=epoch_size,
+                             batch_size=config_ascend.batch_size)
+    step_size = dataset.get_dataset_size()
+    if args_opt.pre_trained:
+        param_dict = load_checkpoint(args_opt.pre_trained)
+        _load_param_into_net(net, param_dict)
+    lr = Tensor(get_lr(global_step=config_ascend.start_epoch * step_size,
+                       lr_init=0,
+                       lr_end=0,
+                       lr_max=config_ascend.lr,
+                       warmup_epochs=config_ascend.warmup_epochs,
+                       total_epochs=epoch_size + config_ascend.start_epoch,
+                       steps_per_epoch=step_size))
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum,
+                   config_ascend.weight_decay)
+    model = Model(net, loss_fn=loss, optimizer=opt)
+    cb = None
+    if rank_id == 0:
+        cb = [Monitor(lr_init=lr.asnumpy())]
+        if config_ascend.save_checkpoint:
+            config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size,
+                                         keep_checkpoint_max=config_ascend.keep_checkpoint_max)
+            ckpt_cb = ModelCheckpoint(
+                prefix="mobilenet", directory=config_ascend.save_checkpoint_path, config=config_ck)
+            cb += [ckpt_cb]
+    model.train(epoch_size, dataset, callbacks=cb)
...
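The key addition in this hunk is `_load_param_into_net`: instead of `load_param_into_net`, which matches parameters by exact name, it buckets the FP32 checkpoint entries by suffix (`weight`, `gamma`, `moving_mean`, ...) and pours them into the quantized network in declaration order, so the fused `Conv2dBatchNormQuant` cells pick up the pretrained convolution and BN statistics even though the parameter paths differ. Below is a minimal sketch of the resume step, assuming it runs in the same module as the helper above (so `_load_param_into_net` is in scope); the checkpoint filename is hypothetical.

```python
from mindspore.train.serialization import load_checkpoint
from src.config import config_ascend
from src.mobilenetV2_quant import mobilenet_v2_quant

net = mobilenet_v2_quant(num_classes=config_ascend.num_classes)

# "./mobilenet-200_5004.ckpt" is a hypothetical FP32 checkpoint from the
# original non-quant training run; _load_param_into_net suffix-matches its
# entries into the fused quant cells.
param_dict = load_checkpoint("./mobilenet-200_5004.ckpt")
_load_param_into_net(net, param_dict)

# The LR schedule is then offset by the epochs already trained in FP32:
# global_step starts at config_ascend.start_epoch * step_size and the total
# schedule length is epoch_size + start_epoch (see the get_lr call above).
```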
@@ -69,11 +69,12 @@ class BatchNormFoldCell(Cell):
     """

-    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
+    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0, freeze_bn_ascend=True):
         """init batch norm fold layer"""
         super(BatchNormFoldCell, self).__init__()
         self.epsilon = epsilon
         self.is_gpu = context.get_context('device_target') == "GPU"
+        self.freeze_bn_ascend = freeze_bn_ascend
         if self.is_gpu:
             self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
             self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
@@ -88,7 +89,7 @@ class BatchNormFoldCell(Cell):
             else:
                 batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
         else:
-            if self.training:
+            if self.training and not self.freeze_bn_ascend:
                 x_sum, x_square_sum = self.bn_reduce(x)
                 _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
                     self.bn_update(x, x_sum, x_square_sum, mean, variance)
@@ -279,7 +280,8 @@ class Conv2dBatchNormQuant(Cell):
                  num_bits=8,
                  per_channel=False,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 freeze_bn_ascend=True):
         """init Conv2dBatchNormQuant layer"""
         super(Conv2dBatchNormQuant, self).__init__()
         self.in_channels = in_channels
@@ -300,6 +302,7 @@ class Conv2dBatchNormQuant(Cell):
         self.symmetric = symmetric
         self.narrow_range = narrow_range
         self.is_gpu = context.get_context('device_target') == "GPU"
+        self.freeze_bn_ascend = freeze_bn_ascend

         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
@@ -398,7 +401,7 @@ class Conv2dBatchNormQuant(Cell):
             out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                              batch_std, batch_mean, running_std, running_mean, self.step)
         else:
-            if self.training:
+            if self.training and not self.freeze_bn_ascend:
                 out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                 F.control_depend(out, self.assignadd(self.step, self.one))
             else:
...
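The `freeze_bn_ascend` flag added here defaults to True and effectively pins the Ascend path to the inference branch of the folded batch norm: even when the cell is in training mode, batch statistics are not re-estimated, which suits fine-tuning from an already well-trained FP32 model. The toy function below only mirrors that branch selection so the control flow is easy to see; it is an illustration of the logic, not MindSpore API.

```python
def select_bn_fold_branch(is_gpu: bool, training: bool, freeze_bn_ascend: bool = True) -> str:
    """Mirror the branch choice in BatchNormFoldCell / Conv2dBatchNormQuant after this commit."""
    if is_gpu:
        # GPU keeps the original behaviour: a dedicated BatchNormFold kernel
        # switches between training and inference statistics.
        return "bn_train" if training else "bn_infer"
    # Ascend: only take the training-time fold (batch statistics + step update)
    # when BN freezing is explicitly disabled.
    if training and not freeze_bn_ascend:
        return "batchnorm_fold2_train"
    return "batchnorm_fold2_infer"

# Default Ascend fine-tuning: BN statistics stay frozen even in training mode.
print(select_bn_fold_branch(is_gpu=False, training=True))                          # batchnorm_fold2_infer
print(select_bn_fold_branch(is_gpu=False, training=True, freeze_bn_ascend=False))  # batchnorm_fold2_train
print(select_bn_fold_branch(is_gpu=True, training=True))                           # bn_train
```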