Commit 26b1242b authored by chenzomi

mobilenetV2 change for gpu

Parent f80e5796
...@@ -960,7 +960,7 @@ class ActQuant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
-    >>> act_quant = nn.ActQuant(nn.ReLU)
+    >>> act_quant = nn.ActQuant(nn.ReLU())
>>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
>>> result = act_quant(input_x)
"""
...@@ -1009,7 +1009,7 @@ class LeakyReLUQuant(_QuantActivation):
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
-    - **x** (Tensor) - The input of HSwishQuant.
+    - **x** (Tensor) - The input of LeakyReLUQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
...
...@@ -306,7 +306,7 @@ class ExportToQuantInferNetwork:
std_dev (int, float): Input data variance. Default: 127.5.
Returns:
-    Cell, GEIR backend Infer network.
+    Cell, Infer network.
"""
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
...
...@@ -91,6 +91,6 @@ if [ $1 = "Ascend" ] ; then
elif [ $1 = "GPU" ] ; then
run_gpu "$@"
else
-echo "not support platform"
+echo "Unsupported platform."
fi;
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# """MobileNetV2 Quant model define"""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
__all__ = ['mobilenetV2']
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that rounding down does not reduce the value by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
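# e.g. _make_divisible(32 * 0.75, 8) returns 24 and _make_divisible(30, 8) returns 32,
# so scaled channel counts stay multiples of `divisor` (with the default `min_value`).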
class GlobalAvgPooling(nn.Cell):
"""
Global avg pooling definition.
Args:
Returns:
Tensor, output tensor.
Examples:
>>> GlobalAvgPooling()
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(keep_dims=False)
def construct(self, x):
x = self.mean(x, (2, 3))
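# mean over the spatial axes: (N, C, H, W) -> (N, C)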
return x
class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): Channel group. Use 1 for standard convolution and the input channel count for depthwise convolution. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size,
stride=stride,
pad_mode='pad',
padding=padding,
group=groups,
has_bn=True,
activation='relu')
def construct(self, x):
x = self.conv(x)
return x
class InvertedResidual(nn.Cell):
"""
Mobilenetv2 residual block definition.
Args:
inp (int): Input channel.
oup (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
expand_ratio (int): Expansion ratio of the input channel.
Returns:
Tensor, output tensor.
Examples:
>>> InvertedResidual(3, 256, 1, 1)
"""
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = stride == 1 and inp == oup
layers = []
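# layer order: optional 1x1 expansion (ConvBNReLU), 3x3 depthwise conv (groups=hidden_dim),
# then 1x1 linear projection (Conv2dBnAct with fused BN and no activation)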
if expand_ratio != 1:
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
])
self.conv = nn.SequentialCell(layers)
self.add = P.TensorAdd()
def construct(self, x):
out = self.conv(x)
if self.use_res_connect:
out = self.add(out, x)
return out
class mobilenetV2(nn.Cell):
"""
mobilenetV2 fusion architecture.
Args:
num_classes (int): Number of classes. Default: 1000.
width_mult (float): Channel width multiplier, with results rounded via `round_nearest`. Default: 1.0.
has_dropout (bool): Whether to use dropout in the head. Default: False.
inverted_residual_setting (list): Inverted residual settings. Default: None.
round_nearest (int): Round channel counts to a multiple of this value. Default: 8.
Returns:
Tensor, output tensor.
Examples:
>>> mobilenetV2(num_classes=1000)
"""
def __init__(self, num_classes=1000, width_mult=1.,
has_dropout=False, inverted_residual_setting=None, round_nearest=8):
super(mobilenetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
# setting of inverted residual blocks
self.cfgs = inverted_residual_setting
if inverted_residual_setting is None:
self.cfgs = [
# t: expansion factor, c: output channels, n: number of blocks, s: stride of the first block
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in self.cfgs:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
# make it nn.CellList
self.features = nn.SequentialCell(features)
# mobilenet head
head = ([GlobalAvgPooling(),
nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
] if not has_dropout else
[GlobalAvgPooling(),
nn.Dropout(0.2),
nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
])
self.head = nn.SequentialCell(head)
# init weights
self._initialize_weights()
def construct(self, x):
x = self.features(x)
x = self.head(x)
return x
def _initialize_weights(self):
"""
Initialize weights.
Args:
Returns:
None.
Examples:
>>> _initialize_weights()
"""
for _, m in self.cells_and_names():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32"))
m.weight.set_parameter_data(w)
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.Conv2dBnAct):
n = m.conv.kernel_size[0] * m.conv.kernel_size[1] * m.conv.out_channels
w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.conv.weight.data.shape).astype("float32"))
m.conv.weight.set_parameter_data(w)
if m.conv.bias is not None:
m.conv.bias.set_parameter_data(Tensor(np.zeros(m.conv.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.DenseBnAct):
m.dense.weight.set_parameter_data(
Tensor(np.random.normal(0, 0.01, m.dense.weight.data.shape).astype("float32")))
if m.dense.bias is not None:
m.dense.bias.set_parameter_data(Tensor(np.zeros(m.dense.bias.data.shape, dtype="float32")))
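As a quick sanity check of the network defined above, the fusion model can be fed a random input; this is a minimal sketch assuming the file lives at `src/mobilenetV2.py` and a default MindSpore context:

```python
import numpy as np
from mindspore import Tensor
from src.mobilenetV2 import mobilenetV2  # import path assumed from the repo layout

net = mobilenetV2(num_classes=1000)
x = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
out = net(x)        # features -> GlobalAvgPooling -> DenseBnAct head
print(out.shape)    # expected: (1, 1000)
```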
...@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-"""train_imagenet."""
+"""Train mobilenetV2 on ImageNet."""
import os
import time
import argparse
...@@ -165,15 +166,14 @@ if __name__ == '__main__':
print("train args: ", args_opt)
print("cfg: ", config_gpu)
-# define net
+# define network
net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
# define loss
if config_gpu.label_smooth > 0:
-loss = CrossEntropyWithLabelSmooth(
-    smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
+loss = CrossEntropyWithLabelSmooth(smooth_factor=config_gpu.label_smooth,
+                                   num_classes=config_gpu.num_classes)
else:
-loss = SoftmaxCrossEntropyWithLogits(
-    is_grad=False, sparse=True, reduction='mean')
+loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
# define dataset
epoch_size = config_gpu.epoch_size
dataset = create_dataset(dataset_path=args_opt.dataset_path,
...@@ -187,7 +187,8 @@ if __name__ == '__main__':
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
-# define optimizer
+# get learning rate
loss_scale = FixedLossScaleManager(
config_gpu.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(global_step=0,
...@@ -197,12 +198,14 @@ if __name__ == '__main__':
warmup_epochs=config_gpu.warmup_epochs,
total_epochs=epoch_size,
steps_per_epoch=step_size))
+# define optimization
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
config_gpu.weight_decay, config_gpu.loss_scale)
# define model
-model = Model(net, loss_fn=loss, optimizer=opt,
-              loss_scale_manager=loss_scale)
+model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
+print("============== Starting Training ==============")
cb = [Monitor(lr_init=lr.asnumpy())]
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
if config_gpu.save_checkpoint:
...@@ -212,6 +215,7 @@ if __name__ == '__main__':
cb += [ckpt_cb]
# begin train
model.train(epoch_size, dataset, callbacks=cb)
+print("============== End Training ==============")
elif args_opt.platform == "Ascend":
# train on ascend
print("train args: ", args_opt, "\ncfg: ", config_ascend,
...
...@@ -64,12 +64,14 @@ Dataset use: ImageNet

Train a MindSpore fusion MobileNetV2 model for ImageNet, like:

-- sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]
+- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]
+- GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]

You can just run this command instead.

``` bash
->>> sh run_train.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt
+>>> Ascend: sh run_train.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt
+>>> GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
```

Training results will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log` as follows.
...
...@@ -46,16 +46,50 @@ run_ascend()
--device_target=$1 &> train.log & # dataset train folder
}
run_gpu()
{
if [ $2 -lt 1 ] || [ $2 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-8)"
exit 1
fi
if [ ! -d $4 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
export CUDA_VISIBLE_DEVICES="$3"
mpirun -n $2 --allow-run-as-root \
python ${BASEPATH}/../train.py \
--dataset_path=$4 \
--device_target=$1 \
&> ../train.log & # dataset train folder
}
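# Example invocation (from the README): sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/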
if [ $# -gt 6 ] || [ $# -lt 4 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
"
exit 1
fi
if [ $1 = "Ascend" ] ; then
run_ascend "$@"
elif [ $1 = "GPU" ] ; then
run_gpu "$@"
else
echo "Unsupported device target."
fi;
...
...@@ -47,16 +47,51 @@ run_ascend()
--device_target=$1 &> train.log & # dataset train folder
}
run_gpu()
{
if [ $2 -lt 1 ] || [ $2 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-8)"
exit 1
fi
if [ ! -d $4 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
export CUDA_VISIBLE_DEVICES="$3"
mpirun -n $2 --allow-run-as-root \
python ${BASEPATH}/../train.py \
--dataset_path=$4 \
--device_target=$1 \
--quantization_aware=True \
&> ../train.log & # dataset train folder
}
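# Example invocation: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
# (this quantization-aware variant always passes --quantization_aware=True to train.py)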
if [ $# -gt 6 ] || [ $# -lt 4 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
"
exit 1
fi
if [ $1 = "Ascend" ] ; then
run_ascend "$@"
elif [ $1 = "GPU" ] ; then
run_gpu "$@"
else
echo "Unsupported device target."
fi;
...
...@@ -33,7 +33,7 @@ config_ascend = ed({
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
-"keep_checkpoint_max": 200,
+"keep_checkpoint_max": 300,
"save_checkpoint_path": "./checkpoint",
"quantization_aware": False,
})
...@@ -54,7 +54,45 @@ config_ascend_quant = ed({
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
-"keep_checkpoint_max": 200,
+"keep_checkpoint_max": 300,
"save_checkpoint_path": "./checkpoint",
"quantization_aware": True,
})
config_gpu = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 150,
"epoch_size": 200,
"warmup_epochs": 4,
"lr": 0.8,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 300,
"save_checkpoint_path": "./checkpoint",
})
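# Note: config_gpu_quant below sets `start_epoch`, which train.py passes to get_lr()
# (global_step = start_epoch * step_size), so the LR schedule presumably continues from
# a float checkpoint already trained for that many epochs rather than restarting.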
config_gpu_quant = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 134,
"epoch_size": 60,
"start_epoch": 200,
"warmup_epochs": 1,
"lr": 0.3,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 300,
"save_checkpoint_path": "./checkpoint", "save_checkpoint_path": "./checkpoint",
"quantization_aware": True, "quantization_aware": True,
}) })
...@@ -222,6 +222,12 @@ class mobilenetV2(nn.Cell):
m.weight.set_parameter_data(w)
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.Conv2dBnAct):
n = m.conv.kernel_size[0] * m.conv.kernel_size[1] * m.conv.out_channels
w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.conv.weight.data.shape).astype("float32"))
m.conv.weight.set_parameter_data(w)
if m.conv.bias is not None:
m.conv.bias.set_parameter_data(Tensor(np.zeros(m.conv.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
...@@ -229,3 +235,8 @@ class mobilenetV2(nn.Cell):
m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.DenseBnAct):
m.dense.weight.set_parameter_data(
Tensor(np.random.normal(0, 0.01, m.dense.weight.data.shape).astype("float32")))
if m.dense.bias is not None:
m.dense.bias.set_parameter_data(Tensor(np.zeros(m.dense.bias.data.shape, dtype="float32")))
...@@ -23,16 +23,17 @@ from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.model import Model, ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train.quant import quant
import mindspore.dataset.engine as de
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.utils import Monitor, CrossEntropyWithLabelSmooth
-from src.config import config_ascend, config_ascend_quant
+from src.config import config_ascend_quant, config_ascend, config_gpu_quant, config_gpu
from src.mobilenetV2 import mobilenetV2
random.seed(1)
...@@ -55,11 +56,19 @@ if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.device_target == "GPU":
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU",
save_graphs=False)
else:
raise ValueError("Unsupported device target.")
-if __name__ == '__main__':
-    # train on ascend
+def train_on_ascend():
config = config_ascend_quant if args_opt.quantization_aware else config_ascend
print("training args: {}".format(args_opt))
print("training configure: {}".format(config))
...@@ -129,3 +138,72 @@ if __name__ == '__main__':
callback += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=callback)
print("============== End Training ==============")
def train_on_gpu():
config = config_gpu_quant if args_opt.quantization_aware else config_gpu
print("training args: {}".format(args_opt))
print("training configure: {}".format(config))
# define network
network = mobilenetV2(num_classes=config.num_classes)
# define loss
if config.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth,
num_classes=config.num_classes)
else:
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
# define dataset
epoch_size = config.epoch_size
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
config=config,
device_target=args_opt.device_target,
repeat_num=1,
batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
# resume
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(network, param_dict)
# convert fusion network to quantization aware network
if config.quantization_aware:
network = quant.convert_quant_network(network,
bn_fold=True,
per_channel=[True, False],
symmetric=[True, True])
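# per_channel=[True, False]: weights quantized per channel, activations per layer;
# symmetric=[True, True]: both quantized symmetrically
# (assumes the [weights, activations] ordering of these convert_quant_network arguments)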
# get learning rate
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
lr_init=0,
lr_end=0,
lr_max=config.lr,
warmup_epochs=config.warmup_epochs,
total_epochs=epoch_size + config.start_epoch,
steps_per_epoch=step_size))
# define optimization
opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)
# define model
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
print("============== Starting Training ==============")
callback = [Monitor(lr_init=lr.asnumpy())]
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
callback += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=callback)
print("============== End Training ==============")
if __name__ == '__main__':
if args_opt.device_target == "Ascend":
train_on_ascend()
elif args_opt.device_target == "GPU":
train_on_gpu()