Commit 31b165a5 authored by wangjun260

add vgg scripts

Parent 930a1fb0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
cifar_cfg = edict({
'num_classes': 10,
'lr_init': 0.05,
'batch_size': 64,
'epoch_size': 70,
'momentum': 0.9,
'weight_decay': 5e-4,
'buffer_size': 10,
'image_height': 224,
'image_width': 224,
'keep_checkpoint_max': 10
})
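For a quick sanity check, these settings behave like any EasyDict: attribute
access and key access are interchangeable. A minimal sketch:

    from config import cifar_cfg as cfg

    assert cfg.batch_size == cfg['batch_size'] == 64
    print(cfg.image_height, cfg.image_width)  # 224 224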
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.common.dtype as mstype
from config import cifar_cfg as cfg
def create_dataset(data_home, repeat_num=1, training=True):
    """Data operations."""
    ds.config.set_seed(1)
    data_dir = os.path.join(data_home, "cifar-10-batches-bin")
    if not training:
        data_dir = os.path.join(data_home, "cifar-10-verify-bin")

    data_set = ds.Cifar10Dataset(data_dir)

    resize_height = cfg.image_height
    resize_width = cfg.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = vision.RandomHorizontalFlip()
    resize_op = vision.Resize((resize_height, resize_width))  # interpolation default BILINEAR
    rescale_op = vision.Rescale(rescale, shift)
    # CIFAR-10 per-channel mean and std
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]

    # apply map operations on images
    data_set = data_set.map(input_columns="label", operations=type_cast_op)
    data_set = data_set.map(input_columns="image", operations=c_trans)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=cfg.buffer_size)

    # apply batch operations
    data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)

    return data_set
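A hedged usage sketch of create_dataset (it assumes the CIFAR-10 binary
archive has already been extracted under ./cifar, the same default path
train.py uses):

    import dataset

    train_ds = dataset.create_dataset('./cifar', repeat_num=1, training=True)
    print(train_ds.get_dataset_size())  # 50000 // 64 = 781 batches (drop_remainder=True)
    for data in train_ds.create_dict_iterator():
        print(data['image'].shape)  # (64, 3, 224, 224) after Resize, Normalize and HWC2CHW
        break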
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test vgg16 example on cifar10#################
python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import argparse
import mindspore.nn as nn
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.vgg import vgg16
from config import cifar_cfg as cfg
import dataset
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Cifar10 classification')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.')
    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    if args_opt.device_id is not None:
        context.set_context(device_id=args_opt.device_id)
    context.set_context(enable_mem_reuse=True, enable_hccl=False)

    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
                   weight_decay=cfg.weight_decay)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # avoid shadowing the imported `dataset` module
    eval_dataset = dataset.create_dataset(args_opt.data_path, 1, False)
    res = model.eval(eval_dataset)
    print("result: ", res)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg16 example on cifar10########################
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import argparse
import random
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.vgg import vgg16
import dataset
from config import cifar_cfg as cfg
random.seed(1)
np.random.seed(1)
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
    """Set learning rate."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr_each_step.append(lr_max)
        elif i < decay_epoch_index[1]:
            lr_each_step.append(lr_max * 0.1)
        elif i < decay_epoch_index[2]:
            lr_each_step.append(lr_max * 0.01)
        else:
            lr_each_step.append(lr_max * 0.001)
    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]
    return learning_rate
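# A worked example of the schedule above, using the defaults from config.py:
# with epoch_size=70 and batch_size=64, steps_per_epoch = 50000 // 64 = 781,
# so total_steps = 781 * 70 = 54670. The rate holds at lr_init (0.05) for the
# first 30% of steps, then decays to 0.005, 0.0005 and finally 0.00005 at the
# 30% / 60% / 80% boundaries.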
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Cifar10 classification')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    if args_opt.device_id is not None:
        context.set_context(device_id=args_opt.device_id)
    context.set_context(enable_mem_reuse=True, enable_hccl=False)

    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
    # CIFAR-10 has 50000 training images, so steps_per_epoch = 50000 // batch_size
    lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
                   weight_decay=cfg.weight_decay)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    # avoid shadowing the imported `dataset` module
    train_dataset = dataset.create_dataset(args_opt.data_path, cfg.epoch_size)
    batch_num = train_dataset.get_dataset_size()

    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck)
    loss_cb = LossMonitor()
    model.train(cfg.epoch_size, train_dataset, callbacks=[ckpoint_cb, loss_cb])
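Because Momentum receives the whole per-step schedule as a Tensor, the
optimizer consumes one entry per training step. A minimal sketch for checking
the schedule offline (pure NumPy, no device required; lr_steps is the function
defined above, safe to import because the training code sits under the
__main__ guard):

    from train import lr_steps

    lr = lr_steps(0, lr_max=0.05, total_epochs=70, steps_per_epoch=781)
    print(lr.shape)       # (54670,)
    print(lr[0], lr[-1])  # 0.05 and 5e-05, up to float32 rounding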
@@ -15,7 +15,8 @@
 """VGG."""
 import mindspore.nn as nn
 from mindspore.ops import operations as P
+from mindspore.common.initializer import initializer
+import mindspore.common.dtype as mstype


 def _make_layer(base, batch_norm):
     """Make stage network of VGG."""
@@ -25,11 +26,14 @@ def _make_layer(base, batch_norm):
         if v == 'M':
             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
         else:
+            weight_shape = (v, in_channels, 3, 3)
+            weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
             conv2d = nn.Conv2d(in_channels=in_channels,
                                out_channels=v,
                                kernel_size=3,
-                               padding=1,
-                               pad_mode='pad')
+                               padding=0,
+                               pad_mode='same',
+                               weight_init=weight)
             if batch_norm:
                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
             else:
@@ -52,13 +56,13 @@ class Vgg(nn.Cell):
         Tensor, infer output tensor.

     Examples:
-        >>> VGG([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+        >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
         >>>     num_classes=1000, batch_norm=False, batch_size=1)
     """

     def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
         super(Vgg, self).__init__()
         self.layers = _make_layer(base, batch_norm=batch_norm)
-        self.avgpool = nn.AvgPool2d(7)
         self.reshape = P.Reshape()
         self.shp = (batch_size, -1)
         self.classifier = nn.SequentialCell([
@@ -70,7 +74,6 @@ class Vgg(nn.Cell):
     def construct(self, x):
         x = self.layers(x)
-        x = self.avgpool(x)
         x = self.reshape(x, self.shp)
         x = self.classifier(x)
         return x
@@ -84,15 +87,20 @@ cfg = {
 }


-def vgg16():
+def vgg16(batch_size=1, num_classes=1000):
     """
-    Get VGG16 neural network.
+    Get Vgg16 neural network with batch normalization.
+
+    Args:
+        batch_size (int): Batch size. Default: 1.
+        num_classes (int): Class numbers. Default: 1000.

     Returns:
-        Cell, cell instance of VGG16 neural network.
+        Cell, cell instance of Vgg16 neural network with batch normalization.

     Examples:
-        >>> vgg16()
+        >>> vgg16(batch_size=1, num_classes=1000)
     """
-    net = Vgg(cfg['16'], num_classes=1000, batch_norm=False, batch_size=1)
+    net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size)
     return net
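After this change the factory takes the batch size and class count directly,
and convolutions use 'same' padding with Xavier-initialized weights. A minimal
smoke-test sketch (shapes assumed from the 224x224 inputs that config.py
prescribes; with avgpool removed, five 2x2 max-pools reduce 224 to 7):

    import numpy as np
    from mindspore import Tensor
    from mindspore.model_zoo.vgg import vgg16

    net = vgg16(batch_size=2, num_classes=10)
    x = Tensor(np.random.rand(2, 3, 224, 224).astype(np.float32))
    out = net(x)
    print(out.asnumpy().shape)  # (2, 10)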