train.py 9.9 KB
Newer Older
W
wangjun260 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg16 example on cifar10########################
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import argparse
20
import datetime
C
caojian05 已提交
21
import os
W
wangjun260 已提交
22
import random
23

W
wangjun260 已提交
24
import numpy as np
25

W
wangjun260 已提交
26 27
import mindspore.nn as nn
from mindspore import Tensor
28
from mindspore import context
29 30
from mindspore import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
W
wangjun260 已提交
31
from mindspore.nn.optim.momentum import Momentum
C
caojian05 已提交
32
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
33
from mindspore.train.model import Model
34
from mindspore.train.serialization import load_param_into_net, load_checkpoint
35
from mindspore.train.loss_scale_manager import FixedLossScaleManager
36
from src.dataset import vgg_create_dataset
37 38 39 40 41 42 43 44
from src.dataset import classification_dataset

from src.crossentropy import CrossEntropy
from src.warmup_step_lr import warmup_step_lr
from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.warmup_step_lr import lr_steps
from src.utils.logging import get_logger
from src.utils.util import get_param_groups
45 46
from src.vgg import vgg16

47

W
wangjun260 已提交
48 49 50
# Seed Python's `random` module and NumPy's global RNG with a fixed value
# (presumably so that runs are repeatable -- confirm the dataset pipeline uses these RNGs).
random.seed(1)
np.random.seed(1)

51

52 53 54 55 56 57
def parse_args(cloud_args=None):
    """Build the training configuration from the command line and the dataset config.

    The command line (``sys.argv``) is parsed first, then ``cloud_args`` is
    overlaid through ``merge_args``, and finally the hyper-parameters of the
    selected dataset config (``src.config.cifar_cfg`` or
    ``src.config.imagenet_cfg``) are copied onto the resulting namespace.

    Args:
        cloud_args: Optional dict of overrides keyed by argument name. Anything
            that is not a dict is ignored by ``merge_args``.

    Returns:
        argparse.Namespace holding the command-line options plus the values
        copied from the dataset config.
    """
    parser = argparse.ArgumentParser('mindspore classification training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: 1)')

    # dataset related
    parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
    parser.add_argument('--data_path', type=str, default='', help='train data dir')

    # network related
    parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
    parser.add_argument('--lr_gamma', type=float, default=0.1,
                        help='decrease lr by a factor of exponential lr_scheduler')
    parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
    parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler')

    # logging and checkpoint related
    parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
    parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')

    args_opt = parser.parse_args()
    args_opt = merge_args(args_opt, cloud_args)

    # The config module is chosen only after --dataset is known, hence the
    # function-scope imports.
    if args_opt.dataset == "cifar10":
        from src.config import cifar_cfg as cfg
    else:
        from src.config import imagenet_cfg as cfg

    # Copy the dataset-specific hyper-parameters onto the namespace so the rest
    # of the script reads everything from one object.
    args_opt.label_smooth = cfg.label_smooth
    args_opt.label_smooth_factor = cfg.label_smooth_factor
    args_opt.lr_scheduler = cfg.lr_scheduler
    args_opt.loss_scale = cfg.loss_scale
    args_opt.max_epoch = cfg.max_epoch
    args_opt.warmup_epochs = cfg.warmup_epochs
    args_opt.lr = cfg.lr
    args_opt.lr_init = cfg.lr_init
    args_opt.lr_max = cfg.lr_max
    args_opt.momentum = cfg.momentum
    args_opt.weight_decay = cfg.weight_decay
    # Note the renames: cfg.batch_size -> per_batch_size,
    # cfg.keep_checkpoint_max -> ckpt_save_max.
    args_opt.per_batch_size = cfg.batch_size
    args_opt.num_classes = cfg.num_classes
    args_opt.buffer_size = cfg.buffer_size
    args_opt.ckpt_save_max = cfg.keep_checkpoint_max
    args_opt.pad_mode = cfg.pad_mode
    args_opt.padding = cfg.padding
    args_opt.has_bias = cfg.has_bias
    args_opt.batch_norm = cfg.batch_norm
    args_opt.initialize_mode = cfg.initialize_mode
    args_opt.has_dropout = cfg.has_dropout

    # Both settings are stored in the config as comma-separated integer strings.
    args_opt.lr_epochs = [int(item) for item in cfg.lr_epochs.split(',')]
    args_opt.image_size = [int(item) for item in cfg.image_size.split(',')]

    return args_opt


def merge_args(args_opt, cloud_args):
    """Overlay values from ``cloud_args`` onto the namespace ``args_opt`` in place.

    A key of ``cloud_args`` is applied only when it already exists on
    ``args_opt`` and its value is truthy; unknown keys and falsy values
    (``0``, ``''``, ``None``, ...) are skipped. An applied value is converted
    to the type of the value it replaces. When the existing value is ``None``
    there is no type to convert to, so the new value is stored unchanged.

    Args:
        args_opt: Object with a ``__dict__`` (e.g. ``argparse.Namespace``) to update.
        cloud_args: Dict of overrides; any non-dict value leaves ``args_opt`` untouched.

    Returns:
        The same ``args_opt`` object that was passed in.
    """
    args_dict = vars(args_opt)
    if not isinstance(cloud_args, dict):
        return args_opt

    for key_arg, val in cloud_args.items():
        if key_arg not in args_dict or not val:
            continue
        current = args_dict[key_arg]
        # type(None)(val) raises TypeError, so only coerce when the existing
        # value actually carries a usable type.
        if current is not None:
            val = type(current)(val)
        args_dict[key_arg] = val
    return args_opt


if __name__ == '__main__':
    # Script entry point: configure the device context, build the dataset,
    # network, learning-rate schedule, optimizer and loss, then train.
    args = parse_args()

    device_num = int(os.environ.get("DEVICE_NUM", 1))
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
            context.set_context(device_id=args.device_id)
        elif args.device_target == "GPU":
            init("nccl")

        # In distributed mode the rank and world size come from the
        # communication layer and replace the command-line values.
        args.rank = get_rank()
        args.group_size = get_group_size()
        device_num = args.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          parameter_broadcast=True, mirror_mean=True)
    else:
        context.set_context(device_id=args.device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # select for master rank save ckpt or all rank save, compatible for model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger: every run writes into its own timestamped directory under ckpt_path
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    # dataset
    if args.dataset == "cifar10":
        dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size,
                                     repeat_num=args.max_epoch)
    else:
        dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size,
                                         args.rank, args.group_size, repeat_num=args.max_epoch)

    # Query the dataset size once; it is used both as the steps-per-epoch
    # value and as the callback reporting interval.
    batch_num = dataset.get_dataset_size()
    args.steps_per_epoch = batch_num
    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')

    # get network and init
    network = vgg16(args.num_classes, args)

    # pre_trained
    if args.pre_trained:
        load_param_into_net(network, load_checkpoint(args.pre_trained))

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'step':
        lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs,
                      total_epochs=args.max_epoch, steps_per_epoch=batch_num)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    # loss and model: the cifar10 branch uses the built-in sparse softmax
    # cross-entropy without a loss-scale manager; the other branch uses the
    # project's CrossEntropy (with optional label smoothing) and a fixed loss scale.
    if args.dataset == "cifar10":
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
        model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
    else:
        if not args.label_smooth:
            args.label_smooth_factor = 0.0
        loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)

        loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")

    # define callbacks
    time_cb = TimeMonitor(data_size=batch_num)
    loss_cb = LossMonitor(per_print_times=batch_num)
    callbacks = [time_cb, loss_cb]
    if args.rank_save_ckpt_flag:
        # ckpt_interval is expressed in epochs; the checkpoint config takes steps.
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
                                       keep_checkpoint_max=args.ckpt_save_max)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch, dataset, callbacks=callbacks)