diff --git a/model_zoo/official/cv/vgg16/eval.py b/model_zoo/official/cv/vgg16/eval.py index 8cdcc86031be48c2a74eec943d9985efb4c9d2c3..504a79207d659ebda7d3bb273ac9c62ff241b9c0 100644 --- a/model_zoo/official/cv/vgg16/eval.py +++ b/model_zoo/official/cv/vgg16/eval.py @@ -12,42 +12,204 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -##############test vgg16 example on cifar10################# -python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID -""" +"""Eval""" +import os +import time import argparse - +import datetime +import glob +import numpy as np import mindspore.nn as nn -from mindspore import context + +from mindspore import Tensor, context from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.config import cifar_cfg as cfg -from src.dataset import vgg_create_dataset +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + +from src.utils.logging import get_logger from src.vgg import vgg16 +from src.dataset import vgg_create_dataset +from src.dataset import classification_dataset + + +class ParameterReduce(nn.Cell): + """ParameterReduce""" + def __init__(self): + super(ParameterReduce, self).__init__() + self.cast = P.Cast() + self.reduce = P.AllReduce() + + def construct(self, x): + one = self.cast(F.scalar_to_array(1.0), mstype.float32) + out = x * one + ret = self.reduce(out) + return ret + -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Cifar10 classification') - parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], +def parse_args(cloud_args=None): + """parse_args""" + parser = argparse.ArgumentParser('mindspore classification test') + parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'], help='device where the code will be implemented. (Default: Ascend)') - parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved') - parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.') - parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') + # dataset related + parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="imagenet2012") + parser.add_argument('--data_path', type=str, default='', help='eval data dir') + parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu') + # network related + parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt') + parser.add_argument('--pretrained', default='', type=str, help='fully path of pretrained model to load. ' + 'If it is a direction, it will test all ckpt') + + # logging related + parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log') + parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') + parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') + + # roma obs + parser.add_argument('--train_url', type=str, default="", help='train url') + args_opt = parser.parse_args() + args_opt = merge_args(args_opt, cloud_args) + + if args_opt.dataset == "cifar10": + from src.config import cifar_cfg as cfg + else: + from src.config import imagenet_cfg as cfg + + args_opt.image_size = cfg.image_size + args_opt.num_classes = cfg.num_classes + args_opt.per_batch_size = cfg.batch_size + args_opt.buffer_size = cfg.buffer_size + args_opt.pad_mode = cfg.pad_mode + args_opt.padding = cfg.padding + args_opt.has_bias = cfg.has_bias + args_opt.batch_norm = cfg.batch_norm + + args_opt.image_size = list(map(int, args_opt.image_size.split(','))) + + return args_opt + + +def get_top5_acc(top5_arg, gt_class): + sub_count = 0 + for top5, gt in zip(top5_arg, gt_class): + if gt in top5: + sub_count += 1 + return sub_count + + +def merge_args(args, cloud_args): + """merge_args""" + args_dict = vars(args) + if isinstance(cloud_args, dict): + for key in cloud_args.keys(): + val = cloud_args[key] + if key in args_dict and val: + arg_type = type(args_dict[key]) + if arg_type is not type(None): + val = arg_type(val) + args_dict[key] = val + return args + + +def test(cloud_args=None): + """test""" + args = parse_args(cloud_args) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.device_target, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + + args.outputs_dir = os.path.join(args.log_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + + args.logger = get_logger(args.outputs_dir, args.rank) + args.logger.save_args(args) + + if args.dataset == "cifar10": + net = vgg16(num_classes=args.num_classes) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, + weight_decay=args.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + param_dict = load_checkpoint(args.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + dataset = vgg_create_dataset(args.data_path, 1, False) + res = model.eval(dataset) + print("result: ", res) + else: + # network + args.logger.important_info('start create network') + if os.path.isdir(args.pretrained): + models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt'))) + print(models) + if args.graph_ckpt: + f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0]) + else: + f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1]) + args.models = sorted(models, key=f) + else: + args.models = [args.pretrained,] + + for model in args.models: + if args.dataset == "cifar10": + dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False) + else: + dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size) + + eval_dataloader = dataset.create_tuple_iterator() + network = vgg16(args.num_classes, args, phase="test") + + # pre_trained + load_param_into_net(network, load_checkpoint(model)) + network.add_flags_recursive(fp16=True) + + img_tot = 0 + top1_correct = 0 + top5_correct = 0 + + network.set_train(False) + t_end = time.time() + it = 0 + for data, gt_classes in eval_dataloader: + output = network(Tensor(data, mstype.float32)) + output = output.asnumpy() + + top1_output = np.argmax(output, (-1)) + top5_output = np.argsort(output)[:, -5:] + + t1_correct = np.equal(top1_output, gt_classes).sum() + top1_correct += t1_correct + top5_correct += get_top5_acc(top5_output, gt_classes) + img_tot += args.per_batch_size + + if args.rank == 0 and it == 0: + t_end = time.time() + it = 1 + if args.rank == 0: + time_used = time.time() - t_end + fps = (img_tot - args.per_batch_size) * args.group_size / time_used + args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps)) + results = [[top1_correct], [top5_correct], [img_tot]] + args.logger.info('before results={}'.format(results)) + results = np.array(results) + + args.logger.info('after results={}'.format(results)) + top1_correct = results[0, 0] + top5_correct = results[1, 0] + img_tot = results[2, 0] + acc1 = 100.0 * top1_correct / img_tot + acc5 = 100.0 * top5_correct / img_tot + args.logger.info('after allreduce eval: top1_correct={}, tot={},' + 'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1)) + args.logger.info('after allreduce eval: top5_correct={}, tot={},' + 'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5)) + - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) - context.set_context(device_id=args_opt.device_id) - - net = vgg16(num_classes=cfg.num_classes) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - - param_dict = load_checkpoint(args_opt.checkpoint_path) - load_param_into_net(net, param_dict) - net.set_train(False) - dataset = vgg_create_dataset(args_opt.data_path, 1, False) - res = model.eval(dataset) - print("result: ", res) +if __name__ == "__main__": + test() diff --git a/model_zoo/official/cv/vgg16/src/config.py b/model_zoo/official/cv/vgg16/src/config.py old mode 100644 new mode 100755 index a34cf7a1d3ee61a13dd71e4bed4d42d9982a04b9..d8bfab71e7ad86547eb023f7bc9151049dac5b35 --- a/model_zoo/official/cv/vgg16/src/config.py +++ b/model_zoo/official/cv/vgg16/src/config.py @@ -13,21 +13,56 @@ # limitations under the License. # ============================================================================ """ -network config setting, will be used in main.py +network config setting, will be used in train.py and eval.py """ from easydict import EasyDict as edict +# config for vgg16, cifar10 cifar_cfg = edict({ 'num_classes': 10, + "lr": 0.01, 'lr_init': 0.01, 'lr_max': 0.1, + "lr_epochs": '30,60,90,120', + "lr_scheduler": "step", 'warmup_epochs': 5, 'batch_size': 64, - 'epoch_size': 70, + 'max_epoch': 70, 'momentum': 0.9, 'weight_decay': 5e-4, + "loss_scale": 1.0, + "label_smooth": 0, + "label_smooth_factor": 0, 'buffer_size': 10, - 'image_height': 224, - 'image_width': 224, + "image_size": '224,224', + 'pad_mode': 'same', + 'padding': 0, + 'has_bias': False, + "batch_norm": True, + 'keep_checkpoint_max': 10 +}) + +# config for vgg16, imagenet2012 +imagenet_cfg = edict({ + 'num_classes': 1000, + "lr": 0.01, + 'lr_init': 0.01, + 'lr_max': 0.1, + "lr_epochs": '30,60,90,120', + "lr_scheduler": 'cosine_annealing', + 'warmup_epochs': 0, + 'batch_size': 32, + 'max_epoch': 150, + 'momentum': 0.9, + 'weight_decay': 1e-4, + "loss_scale": 1024, + "label_smooth": 1, + "label_smooth_factor": 0.1, + 'buffer_size': 10, + "image_size": '224,224', + 'pad_mode': 'pad', + 'padding': 1, + 'has_bias': True, + "batch_norm": False, 'keep_checkpoint_max': 10 }) diff --git a/model_zoo/official/cv/vgg16/src/crossentropy.py b/model_zoo/official/cv/vgg16/src/crossentropy.py new file mode 100755 index 0000000000000000000000000000000000000000..5118cb5161218035cc88ac18dfce6ea086322566 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/crossentropy.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network""" +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +from mindspore.common import dtype as mstype +import mindspore.nn as nn + + +class CrossEntropy(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits""" + + def __init__(self, smooth_factor=0., num_classes=1001): + super(CrossEntropy, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logit, label): + one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss = self.ce(logit, one_hot_label) + loss = self.mean(loss, 0) + return loss diff --git a/model_zoo/official/cv/vgg16/src/dataset.py b/model_zoo/official/cv/vgg16/src/dataset.py index b08659fb5ea972c8c640b6a65ba1a6ead431fd17..c510b494977be11728a23bb327b0a2d8b2e37d54 100644 --- a/model_zoo/official/cv/vgg16/src/dataset.py +++ b/model_zoo/official/cv/vgg16/src/dataset.py @@ -13,37 +13,35 @@ # limitations under the License. # ============================================================================ """ -Data operations, will be used in train.py and eval.py +dataset processing. """ import os - -import mindspore.common.dtype as mstype -import mindspore.dataset as ds +from mindspore.common import dtype as mstype +import mindspore.dataset as de import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as vision -from .config import cifar_cfg as cfg +from PIL import Image, ImageFile +from src.utils.sampler import DistributedSampler + +ImageFile.LOAD_TRUNCATED_IMAGES = True -def vgg_create_dataset(data_home, repeat_num=1, training=True): +def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True): """Data operations.""" - ds.config.set_seed(1) + de.config.set_seed(1) data_dir = os.path.join(data_home, "cifar-10-batches-bin") if not training: data_dir = os.path.join(data_home, "cifar-10-verify-bin") - rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None - rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None - data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) + data_set = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) - resize_height = cfg.image_height - resize_width = cfg.image_width rescale = 1.0 / 255.0 shift = 0.0 # define map operations random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_horizontal_op = vision.RandomHorizontalFlip() - resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + resize_op = vision.Resize(image_size) # interpolation default BILINEAR rescale_op = vision.Rescale(rescale, shift) normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) changeswap_op = vision.HWC2CHW() @@ -66,6 +64,134 @@ def vgg_create_dataset(data_home, repeat_num=1, training=True): data_set = data_set.shuffle(buffer_size=10) # apply batch operations - data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) + data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) return data_set + + +def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_size=1, + mode='train', + input_mode='folder', + root='', + num_parallel_workers=None, + shuffle=None, + sampler=None, + repeat_num=1, + class_indexing=None, + drop_remainder=True, + transform=None, + target_transform=None): + """ + A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt". + If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images + are written into a textfile. + + Args: + data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"". + Or path of the textfile that contains every image's path of the dataset. + image_size (str): Size of the input images. + per_batch_size (int): the batch size of evey step during training. + rank (int): The shard ID within num_shards (default=None). + group_size (int): Number of shards that the dataset should be divided + into (default=None). + mode (str): "train" or others. Default: " train". + input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder". + root (str): the images path for "input_mode="txt"". Default: " ". + num_parallel_workers (int): Number of workers to read the data. Default: None. + shuffle (bool): Whether or not to perform shuffle on the dataset + (default=None, performs shuffle). + sampler (Sampler): Object used to choose samples from the dataset. Default: None. + repeat_num (int): the num of repeat dataset. + class_indexing (dict): A str-to-int mapping from folder name to index + (default=None, the folder names will be sorted + alphabetically and each class will be given a + unique index starting from 0). + + Examples: + >>> from mindvision.common.datasets.classification import classification_dataset + >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images + >>> dataset_dir = "/path/to/imagefolder_directory" + >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], + >>> per_batch_size=64, rank=0, group_size=4) + >>> # Path of the textfile that contains every image's path of the dataset. + >>> dataset_dir = "/path/to/dataset/images/train.txt" + >>> images_dir = "/path/to/dataset/images" + >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], + >>> per_batch_size=64, rank=0, group_size=4, + >>> input_mode="txt", root=images_dir) + """ + + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + if transform is None: + if mode == 'train': + transform_img = [ + vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0)), + vision.RandomHorizontalFlip(prob=0.5), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + else: + transform_img = [ + vision.Decode(), + vision.Resize((256, 256)), + vision.CenterCrop(image_size), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + else: + transform_img = transform + + if target_transform is None: + transform_label = [C.TypeCast(mstype.int32)] + else: + transform_label = target_transform + + if input_mode == 'folder': + de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers, + shuffle=shuffle, sampler=sampler, class_indexing=class_indexing, + num_shards=group_size, shard_id=rank) + else: + dataset = TxtDataset(root, data_dir) + sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) + de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) + de_dataset.set_dataset_size(len(sampler)) + + de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) + de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) + + columns_to_project = ["image", "label"] + de_dataset = de_dataset.project(columns=columns_to_project) + + de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder) + de_dataset = de_dataset.repeat(repeat_num) + + return de_dataset + + +class TxtDataset: + """ + create txt dataset. + + Args: + Returns: + de_dataset. + """ + def __init__(self, root, txt_name): + super(TxtDataset, self).__init__() + self.imgs = [] + self.labels = [] + fin = open(txt_name, "r") + for line in fin: + img_name, label = line.strip().split(' ') + self.imgs.append(os.path.join(root, img_name)) + self.labels.append(int(label)) + fin.close() + + def __getitem__(self, index): + img = Image.open(self.imgs[index]).convert('RGB') + return img, self.labels[index] + + def __len__(self): + return len(self.imgs) diff --git a/tests/ut/python/model/test_vgg.py b/model_zoo/official/cv/vgg16/src/linear_warmup.py similarity index 64% rename from tests/ut/python/model/test_vgg.py rename to model_zoo/official/cv/vgg16/src/linear_warmup.py index 16365d5ee0d8e835c8d9ec1f38ad8bbdc52e307b..dc926e5ce118a44472af0671a1e78a4f80beaac7 100644 --- a/tests/ut/python/model/test_vgg.py +++ b/model_zoo/official/cv/vgg16/src/linear_warmup.py @@ -12,18 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""test_vgg""" -import numpy as np -import pytest +""" +linear warm up learning rate. +""" -from mindspore import Tensor -from model_zoo.official.cv.vgg16.src.vgg import vgg16 -from ..ut_filter import non_graph_engine - -@non_graph_engine -def test_vgg16(): - inputs = Tensor(np.random.rand(1, 3, 112, 112).astype(np.float32)) - net = vgg16() - with pytest.raises(ValueError): - print(net.construct(inputs)) +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr diff --git a/model_zoo/official/cv/vgg16/src/utils/logging.py b/model_zoo/official/cv/vgg16/src/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..ac37bec4ecc17cc3e0705fad09e5f578bacf8532 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/logging.py @@ -0,0 +1,82 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +get logger. +""" +import logging +import os +import sys +from datetime import datetime + +class LOGGER(logging.Logger): + """ + set up logging file. + + Args: + logger_name (string): logger name. + log_dir (string): path of logger. + + Returns: + string, logger path + """ + def __init__(self, logger_name, rank=0): + super(LOGGER, self).__init__(logger_name) + if rank % 8 == 0: + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + console.setFormatter(formatter) + self.addHandler(console) + + def setup_logging_file(self, log_dir, rank=0): + """set up log file""" + self.rank = rank + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank) + self.log_fn = os.path.join(log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + fh.setFormatter(formatter) + self.addHandler(fh) + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO): + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info('Args:') + args_dict = vars(args) + for key in args_dict.keys(): + self.info('--> %s: %s', key, args_dict[key]) + self.info('') + + def important_info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.rank == 0: + line_width = 2 + important_msg = '\n' + important_msg += ('*'*70 + '\n')*line_width + important_msg += ('*'*line_width + '\n')*2 + important_msg += '*'*line_width + ' '*8 + msg + '\n' + important_msg += ('*'*line_width + '\n')*2 + important_msg += ('*'*70 + '\n')*line_width + self.info(important_msg, *args, **kwargs) + + +def get_logger(path, rank): + logger = LOGGER("mindversion", rank) + logger.setup_logging_file(path, rank) + return logger diff --git a/model_zoo/official/cv/vgg16/src/utils/sampler.py b/model_zoo/official/cv/vgg16/src/utils/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..5b68f8325e496521321186fd3e801624873b3a16 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/sampler.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +choose samples from the dataset +""" +import math +import numpy as np + +class DistributedSampler(): + """ + sampling the dataset. + + Args: + Returns: + num_samples, number of samples. + """ + def __init__(self, dataset, rank, group_size, shuffle=True, seed=0): + self.dataset = dataset + self.rank = rank + self.group_size = group_size + self.dataset_length = len(self.dataset) + self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size)) + self.total_size = self.num_samples * self.group_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + if self.shuffle: + self.seed = (self.seed + 1) & 0xffffffff + np.random.seed(self.seed) + indices = np.random.permutation(self.dataset_length).tolist() + else: + indices = list(range(len(self.dataset_length))) + + indices += indices[:(self.total_size - len(indices))] + indices = indices[self.rank::self.group_size] + return iter(indices) + + def __len__(self): + return self.num_samples + \ No newline at end of file diff --git a/model_zoo/official/cv/vgg16/src/utils/util.py b/model_zoo/official/cv/vgg16/src/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6f84045a8981a293eceab9cb77c20c66a339cf59 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/util.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Util class or function.""" + + +def get_param_groups(network): + """Param groups for optimizer.""" + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] diff --git a/model_zoo/official/cv/vgg16/src/utils/var_init.py b/model_zoo/official/cv/vgg16/src/utils/var_init.py new file mode 100644 index 0000000000000000000000000000000000000000..51fc109990be1e6ce30335c85f3209eb723c2e45 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/var_init.py @@ -0,0 +1,213 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Initialize. +""" +import math +from functools import reduce +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import initializer as init + +def _calculate_gain(nonlinearity, param=None): + r""" + Return the recommended gain value for the given nonlinearity function. + + The values are as follows: + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function + param: optional parameter for the non-linear function + + Examples: + >>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + if nonlinearity == 'tanh': + return 5.0 / 3 + if nonlinearity == 'relu': + return math.sqrt(2.0) + if nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + +def _assignment(arr, num): + """Assign the value of `num` to `arr`.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + +def _calculate_in_and_out(arr): + """ + Calculate n_in and n_out. + + Args: + arr (Array): Input array. + + Returns: + Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. + """ + dim = len(arr.shape) + if dim < 2: + raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.") + + n_in = arr.shape[1] + n_out = arr.shape[0] + + if dim > 2: + counter = reduce(lambda x, y: x * y, arr.shape[2:]) + n_in *= counter + n_out *= counter + return n_in, n_out + +def _select_fan(array, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_in_and_out(array) + return fan_in if mode == 'fan_in' else fan_out + +class KaimingInit(init.Initializer): + r""" + Base Class. Initialize the array with He kaiming algorithm. + + Args: + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function, recommended to use only with + ``'relu'`` or ``'leaky_relu'`` (default). + """ + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingInit, self).__init__() + self.mode = mode + self.gain = _calculate_gain(nonlinearity, a) + def _initialize(self, arr): + pass + + +class KaimingUniform(KaimingInit): + r""" + Initialize the array with He kaiming uniform algorithm. The resulting tensor will + have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Input: + arr (Array): The array to be assigned. + + Returns: + Array, assigned array. + + Examples: + >>> w = np.empty(3, 5) + >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu') + """ + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + bound = math.sqrt(3.0) * self.gain / math.sqrt(fan) + np.random.seed(0) + data = np.random.uniform(-bound, bound, arr.shape) + + _assignment(arr, data) + + +class KaimingNormal(KaimingInit): + r""" + Initialize the array with He kaiming normal algorithm. The resulting tensor will + have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Input: + arr (Array): The array to be assigned. + + Returns: + Array, assigned array. + + Examples: + >>> w = np.empty(3, 5) + >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu') + """ + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + std = self.gain / math.sqrt(fan) + np.random.seed(0) + data = np.random.normal(0, std, arr.shape) + + _assignment(arr, data) + + +def default_recurisive_init(custom_cell): + """default_recurisive_init""" + for _, cell in custom_cell.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.default_input.shape, + cell.weight.default_input.dtype).to_tensor() + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy()) + bound = 1 / math.sqrt(fan_in) + np.random.seed(0) + cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), + cell.bias.default_input.dtype) + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.default_input.shape, + cell.weight.default_input.dtype).to_tensor() + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy()) + bound = 1 / math.sqrt(fan_in) + np.random.seed(0) + cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), + cell.bias.default_input.dtype) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + pass diff --git a/model_zoo/official/cv/vgg16/src/vgg.py b/model_zoo/official/cv/vgg16/src/vgg.py index 55130871cc97af0257b1549f277ebbe1c8eeb29c..835d2a0b5d61c4f0dca131e03f836781fe99ba63 100644 --- a/model_zoo/official/cv/vgg16/src/vgg.py +++ b/model_zoo/official/cv/vgg16/src/vgg.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""VGG.""" +""" +Image classifiation. +""" +import math import mindspore.nn as nn -from mindspore.common.initializer import initializer import mindspore.common.dtype as mstype +from mindspore.common import initializer as init +from mindspore.common.initializer import initializer +from .utils.var_init import default_recurisive_init, KaimingNormal -def _make_layer(base, batch_norm): + +def _make_layer(base, args, batch_norm): """Make stage network of VGG.""" layers = [] in_channels = 3 @@ -27,11 +33,14 @@ def _make_layer(base, batch_norm): else: weight_shape = (v, in_channels, 3, 3) weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() + if args.dataset == "imagenet2012": + weight = 'normal' conv2d = nn.Conv2d(in_channels=in_channels, out_channels=v, kernel_size=3, - padding=0, - pad_mode='same', + padding=args.padding, + pad_mode=args.pad_mode, + has_bias=args.has_bias, weight_init=weight) if batch_norm: layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] @@ -59,17 +68,25 @@ class Vgg(nn.Cell): >>> num_classes=1000, batch_norm=False, batch_size=1) """ - def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): + def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"): super(Vgg, self).__init__() _ = batch_size - self.layers = _make_layer(base, batch_norm=batch_norm) + self.layers = _make_layer(base, args, batch_norm=batch_norm) self.flatten = nn.Flatten() + dropout_ratio = 0.5 + if args.dataset == "cifar10" or phase == "test": + dropout_ratio = 1.0 self.classifier = nn.SequentialCell([ nn.Dense(512 * 7 * 7, 4096), nn.ReLU(), + nn.Dropout(dropout_ratio), nn.Dense(4096, 4096), nn.ReLU(), + nn.Dropout(dropout_ratio), nn.Dense(4096, num_classes)]) + if args.dataset == "imagenet2012": + default_recurisive_init(self) + self.custom_init_weight() def construct(self, x): x = self.layers(x) @@ -77,6 +94,25 @@ class Vgg(nn.Cell): x = self.classifier(x) return x + def custom_init_weight(self): + """ + Init the weight of Conv2d and Dense in the net. + """ + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer( + KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), + cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor() + if cell.bias is not None: + cell.bias.default_input = init.initializer( + 'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor() + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer( + init.Normal(0.01), cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor() + if cell.bias is not None: + cell.bias.default_input = init.initializer( + 'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor() + cfg = { '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], @@ -86,12 +122,14 @@ cfg = { } -def vgg16(num_classes=1000): +def vgg16(num_classes=1000, args=None, phase="train"): """ Get Vgg16 neural network with batch normalization. Args: num_classes (int): Class numbers. Default: 1000. + args(dict): param for net init. + phase(str): train or test mode. Returns: Cell, cell instance of Vgg16 neural network with batch normalization. @@ -100,5 +138,5 @@ def vgg16(num_classes=1000): >>> vgg16(num_classes=1000) """ - net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True) + net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=True, phase=phase) return net diff --git a/model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py b/model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9fce9af4c8b2236ff25faed1a019c540ef72f3 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +warm up cosine annealing learning rate. +""" +import math +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """warm up cosine annealing learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) diff --git a/model_zoo/official/cv/vgg16/src/warmup_step_lr.py b/model_zoo/official/cv/vgg16/src/warmup_step_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..2ffaa493e9bef9307fa6c7d3e736ea49557352f7 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/warmup_step_lr.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +warm up step learning rate. +""" +from collections import Counter +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): + """Set learning rate.""" + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr_value = float(lr_init) + inc_each_step * float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr_value = float(lr_max) * base * base + if lr_value < 0.0: + lr_value = 0.0 + lr_each_step.append(lr_value) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """warmup_step_lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma**milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) diff --git a/model_zoo/official/cv/vgg16/train.py b/model_zoo/official/cv/vgg16/train.py index 5195476b97eff1b1fe8a6a704e92841da6fd9a11..2bd78c4685eea4503c17da8cb52edb40690fff57 100644 --- a/model_zoo/official/cv/vgg16/train.py +++ b/model_zoo/official/cv/vgg16/train.py @@ -17,6 +17,8 @@ python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID """ import argparse +import datetime +import time import os import random @@ -25,83 +27,264 @@ import numpy as np import mindspore.nn as nn from mindspore import Tensor from mindspore import context -from mindspore.communication.management import init +from mindspore.communication.management import init, get_rank, get_group_size from mindspore.nn.optim.momentum import Momentum -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig from mindspore.train.model import Model, ParallelMode from mindspore.train.serialization import load_param_into_net, load_checkpoint -from src.config import cifar_cfg as cfg +from mindspore.train.loss_scale_manager import FixedLossScaleManager from src.dataset import vgg_create_dataset +from src.dataset import classification_dataset + +from src.crossentropy import CrossEntropy +from src.warmup_step_lr import warmup_step_lr +from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr +from src.warmup_step_lr import lr_steps +from src.utils.logging import get_logger +from src.utils.util import get_param_groups from src.vgg import vgg16 + random.seed(1) np.random.seed(1) -def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): - """Set learning rate.""" - lr_each_step = [] - total_steps = steps_per_epoch * total_epochs - warmup_steps = steps_per_epoch * warmup_epochs - if warmup_steps != 0: - inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) - else: - inc_each_step = 0 - for i in range(total_steps): - if i < warmup_steps: - lr_value = float(lr_init) + inc_each_step * float(i) - else: - base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) - lr_value = float(lr_max) * base * base - if lr_value < 0.0: - lr_value = 0.0 - lr_each_step.append(lr_value) +class ProgressMonitor(Callback): + """monitor loss and time""" + def __init__(self, args_param): + super(ProgressMonitor, self).__init__() + self.me_epoch_start_time = 0 + self.me_epoch_start_step_num = 0 + self.args = args_param + self.ckpt_history = [] - current_step = global_step - lr_each_step = np.array(lr_each_step).astype(np.float32) - learning_rate = lr_each_step[current_step:] + def begin(self, run_context): + self.args.logger.info('start network train...') - return learning_rate + def epoch_begin(self, run_context): + pass + def epoch_end(self, run_context): + """ + Called after each epoch finished. -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Cifar10 classification') - parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + Args: + run_context (RunContext): Include some information of the model. + """ + cb_params = run_context.original_args() + me_step = cb_params.cur_step_num - 1 + + real_epoch = me_step // self.args.steps_per_epoch + time_used = time.time() - self.me_epoch_start_time + fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used + self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}' + 'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean)) + + if self.args.rank_save_ckpt_flag: + import glob + ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt')) + for ckpt in ckpts: + ckpt_fn = os.path.basename(ckpt) + if not ckpt_fn.startswith('{}-'.format(self.args.rank)): + continue + if ckpt in self.ckpt_history: + continue + self.ckpt_history.append(ckpt) + self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},' + 'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn)) + + self.me_epoch_start_step_num = me_step + self.me_epoch_start_time = time.time() + + def step_begin(self, run_context): + pass + + def step_end(self, run_context, *me_args): + pass + + def end(self, run_context): + self.args.logger.info('end network train...') + + +def parse_args(cloud_args=None): + """parameters""" + parser = argparse.ArgumentParser('mindspore classification training') + parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'], help='device where the code will be implemented. (Default: Ascend)') - parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved') - parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') - parser.add_argument('--pre_trained', type=str, default=None, help='the pretrained checkpoint file path.') + parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)') + + # dataset related + parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10") + parser.add_argument('--data_path', type=str, default='', help='train data dir') + + # network related + parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load') + parser.add_argument('--lr_gamma', type=float, default=0.1, + help='decrease lr by a factor of exponential lr_scheduler') + parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') + parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler') + + # logging and checkpoint related + parser.add_argument('--log_interval', type=int, default=100, help='logging interval') + parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') + parser.add_argument('--ckpt_interval', type=int, default=5000, help='ckpt_interval') + parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank') + + # distributed related + parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') + parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') + parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') args_opt = parser.parse_args() + args_opt = merge_args(args_opt, cloud_args) + + if args_opt.dataset == "cifar10": + from src.config import cifar_cfg as cfg + else: + from src.config import imagenet_cfg as cfg + + args_opt.label_smooth = cfg.label_smooth + args_opt.label_smooth_factor = cfg.label_smooth_factor + args_opt.lr_scheduler = cfg.lr_scheduler + args_opt.loss_scale = cfg.loss_scale + args_opt.max_epoch = cfg.max_epoch + args_opt.warmup_epochs = cfg.warmup_epochs + args_opt.lr = cfg.lr + args_opt.lr_init = cfg.lr_init + args_opt.lr_max = cfg.lr_max + args_opt.momentum = cfg.momentum + args_opt.weight_decay = cfg.weight_decay + args_opt.per_batch_size = cfg.batch_size + args_opt.num_classes = cfg.num_classes + args_opt.buffer_size = cfg.buffer_size + args_opt.ckpt_save_max = cfg.keep_checkpoint_max + args_opt.pad_mode = cfg.pad_mode + args_opt.padding = cfg.padding + args_opt.has_bias = cfg.has_bias + args_opt.batch_norm = cfg.batch_norm + + args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(','))) + args_opt.image_size = list(map(int, cfg.image_size.split(','))) + + return args_opt - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) - context.set_context(device_id=args_opt.device_id) + +def merge_args(args_opt, cloud_args): + """dictionary""" + args_dict = vars(args_opt) + if isinstance(cloud_args, dict): + for key_arg in cloud_args.keys(): + val = cloud_args[key_arg] + if key_arg in args_dict and val: + arg_type = type(args_dict[key_arg]) + if arg_type is not None: + val = arg_type(val) + args_dict[key_arg] = val + return args_opt + + +if __name__ == '__main__': + args = parse_args() device_num = int(os.environ.get("DEVICE_NUM", 1)) - if device_num > 1: + if args.is_distributed: + if args.device_target == "Ascend": + init() + elif args.device_target == "GPU": + init("nccl") + args.rank = get_rank() + args.group_size = get_group_size() + device_num = args.group_size + context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) - init() + else: + context.set_context(device_id=args.device_id) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + + # select for master rank save ckpt or all rank save, compatiable for model parallel + args.rank_save_ckpt_flag = 0 + if args.is_save_on_master: + if args.rank == 0: + args.rank_save_ckpt_flag = 1 + else: + args.rank_save_ckpt_flag = 1 + + # logger + args.outputs_dir = os.path.join(args.ckpt_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + args.logger = get_logger(args.outputs_dir, args.rank) + + if args.dataset == "cifar10": + dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size) + else: + dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, + args.rank, args.group_size) - dataset = vgg_create_dataset(args_opt.data_path, 1) batch_num = dataset.get_dataset_size() + args.steps_per_epoch = dataset.get_dataset_size() + args.logger.save_args(args) + + # network + args.logger.important_info('start create network') + + # get network and init + network = vgg16(args.num_classes, args) - net = vgg16(num_classes=cfg.num_classes) # pre_trained - if args_opt.pre_trained: - load_param_into_net(net, load_checkpoint(args_opt.pre_trained)) - - lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs, - total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) - - config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) - time_cb = TimeMonitor(data_size=batch_num) - ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck) - loss_cb = LossMonitor() - model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) - print("train success") + if args.pre_trained: + load_param_into_net(network, load_checkpoint(args.pre_trained)) + + # lr scheduler + if args.lr_scheduler == 'exponential': + lr = warmup_step_lr(args.lr, + args.lr_epochs, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma, + ) + elif args.lr_scheduler == 'cosine_annealing': + lr = warmup_cosine_annealing_lr(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + elif args.lr_scheduler == 'step': + lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs, + total_epochs=args.max_epoch, steps_per_epoch=batch_num) + else: + raise NotImplementedError(args.lr_scheduler) + + # optimizer + opt = Momentum(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=args.momentum, + weight_decay=args.weight_decay, + loss_scale=args.loss_scale) + + if args.dataset == "cifar10": + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) + else: + if not args.label_smooth: + args.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes) + + loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2") + + # checkpoint save + progress_cb = ProgressMonitor(args) + callbacks = [progress_cb,] + if args.rank_save_ckpt_flag: + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch, + keep_checkpoint_max=args.ckpt_save_max) + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=args.outputs_dir, + prefix='{}'.format(args.rank)) + callbacks.append(ckpt_cb) + + model.train(args.max_epoch, dataset, callbacks=callbacks)