# eval.py (5.0 KB)
# Committed by liuluobin.
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""
import os
import argparse
import datetime
import mindspore.nn as nn

from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindarmour.utils import LogUtil

from vgg.vgg import vgg16
from vgg.dataset import vgg_create_dataset100
from vgg.config import cifar_cfg as cfg


class ParameterReduce(nn.Cell):
    """Apply ``P.AllReduce`` (default settings) to ``x`` scaled by a float32 ``1.0``.

    The multiplication by a float32 scalar one leaves the values unchanged; it is
    presumably there to make the reduced tensor float32 — TODO confirm intent.
    """

    def __init__(self):
        super().__init__()
        self.cast = P.Cast()
        self.reduce = P.AllReduce()

    def construct(self, x):
        # Build a float32 scalar 1.0 and reduce x * 1.0 across the collective group.
        unit = self.cast(F.scalar_to_array(1.0), mstype.float32)
        return self.reduce(x * unit)


def parse_args(cloud_args=None):
    """Parse evaluation arguments and fill in model/data settings from ``cfg``.

    Command-line options are parsed first, then ``cloud_args`` (when it is a dict)
    is merged on top via ``merge_args``, and finally the model/data
    hyper-parameters are copied from ``cfg``.

    Args:
        cloud_args: optional dict of overrides for the command-line options.

    Returns:
        argparse.Namespace with the CLI options plus the ``cfg`` settings.
        ``per_batch_size`` is the user-supplied value, or ``cfg.batch_size`` when
        none was given; ``image_size`` is a list of ints.
    """
    parser = argparse.ArgumentParser('mindspore classification test')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    # dataset related
    parser.add_argument('--data_path', type=str, default='', help='eval data dir')
    # None means "not given"; it is resolved to cfg.batch_size below.
    parser.add_argument('--per_batch_size', default=None, type=int,
                        help='batch size for per npu (Default: cfg.batch_size)')
    # network related
    parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
    parser.add_argument('--pre_trained', default='', type=str, help='full path of pretrained model to load. '
                        'If it is a directory, it will test all ckpt')

    # logging related
    parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')

    args_opt = parser.parse_args()
    args_opt = merge_args(args_opt, cloud_args)

    # Honor an explicit --per_batch_size (CLI or cloud) instead of always
    # overwriting it with the config value.
    if args_opt.per_batch_size is None:
        args_opt.per_batch_size = cfg.batch_size
    else:
        # A cloud override merged onto a None default is not type-converted.
        args_opt.per_batch_size = int(args_opt.per_batch_size)

    args_opt.num_classes = cfg.num_classes
    args_opt.momentum = cfg.momentum
    args_opt.weight_decay = cfg.weight_decay
    args_opt.buffer_size = cfg.buffer_size
    args_opt.pad_mode = cfg.pad_mode
    args_opt.padding = cfg.padding
    args_opt.has_bias = cfg.has_bias
    args_opt.batch_norm = cfg.batch_norm
    args_opt.initialize_mode = cfg.initialize_mode
    args_opt.has_dropout = cfg.has_dropout

    # cfg.image_size may be an "H,W" string or an already-split sequence;
    # normalize either form to a list of ints.
    image_size = cfg.image_size
    if isinstance(image_size, str):
        image_size = image_size.split(',')
    args_opt.image_size = [int(dim) for dim in image_size]

    return args_opt


def merge_args(args, cloud_args):
    """Overlay ``cloud_args`` onto the parsed namespace ``args`` in place.

    Only keys that already exist on ``args`` are applied. Overrides that are
    ``None`` or an empty non-numeric value (``''``, ``[]``, ``{}``) are treated
    as "not provided" and skipped. Any other value, including ``0``, ``0.0``
    and ``False``, replaces the current setting. When the current setting is
    not ``None``, the incoming value is converted to that setting's type
    (e.g. ``"3"`` becomes ``3`` for an int option).

    Args:
        args: namespace returned by ``ArgumentParser.parse_args``; mutated in place.
        cloud_args: dict of overrides; any non-dict (e.g. ``None``) is ignored.

    Returns:
        The same ``args`` object.
    """
    if not isinstance(cloud_args, dict):
        return args
    args_dict = vars(args)
    for key, val in cloud_args.items():
        if key not in args_dict:
            continue
        # Numbers (including 0 and False) are real overrides; only None and
        # empty strings/containers mean "unset".
        if val is None or (not isinstance(val, (int, float)) and not val):
            continue
        current = args_dict[key]
        if current is not None:
            val = type(current)(val)
        args_dict[key] = val
    return args


def _collect_checkpoints(path):
    """Resolve ``--pre_trained`` into the list of checkpoint files to evaluate.

    Args:
        path: a checkpoint file path, or a directory containing ``.ckpt`` files.

    Returns:
        ``[path]`` when ``path`` is not a directory (a missing file is left for
        ``load_checkpoint`` to report); otherwise every regular ``*.ckpt`` file
        directly inside the directory, sorted by path.

    Raises:
        ValueError: if ``path`` is empty, or is a directory without ``.ckpt`` files.
    """
    if not path:
        raise ValueError("--pre_trained must be set to a checkpoint file or directory")
    if not os.path.isdir(path):
        return [path]
    ckpt_paths = sorted(
        os.path.join(path, name) for name in os.listdir(path)
        if name.endswith('.ckpt') and os.path.isfile(os.path.join(path, name)))
    if not ckpt_paths:
        raise ValueError("no .ckpt files found in directory: {}".format(path))
    return ckpt_paths


def test(cloud_args=None):
    """Evaluate VGG16 checkpoint(s) named by ``--pre_trained`` and print the metrics.

    ``--pre_trained`` may be a single checkpoint file or a directory. For a
    directory, every ``*.ckpt`` file directly inside it is evaluated in sorted
    order.

    Args:
        cloud_args: optional dict of argument overrides, forwarded to ``parse_args``.

    Returns:
        The result of ``model.eval`` for a single checkpoint file. For a
        directory, a dict mapping each checkpoint path to its ``model.eval``
        result.

    Raises:
        ValueError: if ``--pre_trained`` is empty or names a directory without
            ``.ckpt`` files.
    """
    args = parse_args(cloud_args)
    # Resolve checkpoints first so a bad --pre_trained fails before device/network setup.
    ckpt_paths = _collect_checkpoints(args.pre_trained)
    eval_all = os.path.isdir(args.pre_trained)

    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    args.outputs_dir = os.path.join(args.log_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    args.logger = LogUtil.get_instance()
    args.logger.set_level(20)  # 20 is presumably logging.INFO — confirm against LogUtil

    net = vgg16(num_classes=args.num_classes, args=args)
    # Only model.eval is called below, so the optimizer is never stepped here.
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum,
                   weight_decay=args.weight_decay)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
    net.set_train(False)

    results = {}
    for ckpt_path in ckpt_paths:
        param_dict = load_checkpoint(ckpt_path)
        load_param_into_net(net, param_dict)
        # A fresh dataset per checkpoint avoids re-iterating one dataset across eval calls.
        dataset_test = vgg_create_dataset100(args.data_path, args.image_size, args.per_batch_size, training=False)
        res = model.eval(dataset_test)
        if eval_all:
            print("ckpt: ", ckpt_path)
        print("result: ", res)
        results[ckpt_path] = res

    return results if eval_all else results[ckpt_paths[0]]


if __name__ == "__main__":
    # Script entry point: evaluate using command-line arguments only (no cloud overrides).
    test(cloud_args=None)