# utility.py (5.2 KB)
# Last committed by dengkaipeng.
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
Contains common utility functions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import distutils.util
import numpy as np
import six
from collections import deque
from paddle.fluid import core
import argparse
import functools
from config.config import *


def print_arguments(args):
    """Print every parsed command-line option, one per line, sorted by name.

    Example:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Parsed arguments to display as ``name: value`` lines,
        framed by a header and a footer rule.
    :type args: argparse.Namespace
    """
    print("-----------  Configuration Arguments -----------")
    options = vars(args)
    # Option names are unique dict keys, so sorting the names alone yields the
    # same order as sorting (name, value) pairs.
    for name in sorted(options):
        print("%s: %s" % (name, options[name]))
    print("------------------------------------------------")


def _str2bool(value):
    """Convert a command-line string to a bool.

    Accepts the same case-insensitive spellings as the deprecated
    ``distutils.util.strtobool``: y/yes/t/true/on/1 for True and
    n/no/f/false/off/0 for False.

    :param value: Raw string from the command line.
    :type value: str
    :return: The parsed truth value.
    :rtype: bool
    :raises ValueError: If ``value`` is not a recognised truth value; argparse
        reports this as an "invalid value" usage error.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid truth value %r" % (value,))


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register a ``--<argname>`` option on ``argparser``.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        add_arguments("name", str, "John", "User name.", parser)
        args = parser.parse_args()

    :param argname: Option name without the leading ``--``.
    :param type: Conversion callable. ``bool`` is replaced by a string parser,
        so ``--flag False`` yields False rather than ``bool('False') == True``.
    :param default: Value used when the option is absent.
    :param help: Help text; the default value is appended to it.
    :param argparser: Parser to register the option on.
    :type argparser: argparse.ArgumentParser
    :param kwargs: Extra keyword arguments forwarded to ``add_argument``.
    """
    # distutils.util.strtobool is deprecated (PEP 632) and gone in Python 3.12;
    # the local converter also returns real bools that match the bool defaults.
    type = _str2bool if type is bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)


class SmoothedValue(object):
    """Hold the most recent ``window_size`` values and report their median.

    Once the window is full, each new value evicts the oldest one, so the
    median always reflects only the latest window of values.
    """

    def __init__(self, window_size):
        # A bounded deque drops its oldest entry automatically when full.
        self.deque = deque(maxlen=window_size)

    def add_value(self, value):
        """Push ``value`` into the window, evicting the oldest if full."""
        window = self.deque
        window.append(value)

    def get_median_value(self):
        """Return the median of the values currently held in the window."""
        values = np.asarray(self.deque)
        return np.median(values)


def parse_args(argv=None):
    """Define the YOLOv3 command-line options, parse them and merge the result.

    The parsed namespace is passed to ``merge_cfg_from_args`` (presumably
    provided by the ``config.config`` star import) before being returned.

    :param argv: Argument strings to parse. ``None`` (the default) parses
        ``sys.argv[1:]``, exactly as before; passing a list makes the function
        usable from tests and other programs.
    :type argv: list of str or None
    :return: The parsed arguments.
    :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)
    # yapf: disable
    # ENV
    add_arg('parallel',         bool,   True,                "Whether use parallel.")
    add_arg('use_gpu',          bool,   True,                "Whether use GPU.")
    add_arg('model_cfg_path',   str,    'config/yolov3.cfg', "YOLO model config file path.")
    add_arg('model_save_dir',   str,    'checkpoints',       "The path to save model.")
    add_arg('pretrain_base',    str,    'weights/darknet53', "The init model weights path.")
    add_arg('pretrained_model', str,    'weights/mxnet',     "The pretrained model path.")
    add_arg('dataset',          str,    'coco2017',          "Dataset: coco2014, coco2017.")
    add_arg('class_num',        int,    80,                  "Class number.")
    add_arg('data_dir',         str,    'dataset/coco',      "The data root path.")
    add_arg('use_pyreader',     bool,   True,                "Use pyreader.")
    add_arg('use_multiprocess', bool,   True,                "Use multiprocessing for train reader.")
    add_arg('use_profile',      bool,   False,               "Whether use profiler.")
    add_arg('start_iter',       int,    0,                   "Start iteration.")
    # SOLVER
    add_arg('learning_rate',    float,  0.001,               "Learning rate.")
    add_arg('max_iter',         int,    500200,              "Iter number.")
    add_arg('snapshot_iter',    int,    2000,                "Save model every snapshot stride.")
    add_arg('log_window',       int,    20,                  "Log smooth window, set 1 for debug, set 20 for train.")
    # TRAIN TEST INFER
    add_arg('input_size',       int,    608,                 "Image input size of YOLOv3.")
    add_arg('random_shape',     bool,   True,                "Resize to random shape for train reader.")
    add_arg('label_smooth',     bool,   True,                "Use label smooth in class label.")
    add_arg('no_mixup_iter',    int,    40000,               "Disable mixup in last N iter.")
    add_arg('valid_thresh',     float,  0.01,                "Valid confidence score for NMS.")
    add_arg('nms_thresh',       float,  0.45,                "NMS threshold.")
    add_arg('nms_topk',         int,    400,                 "The number of boxes to perform NMS.")
    add_arg('nms_posk',         int,    100,                 "The number of boxes of NMS output.")
    add_arg('debug',            bool,   False,               "Debug mode")
    # SINGLE EVAL AND DRAW
    add_arg('image_path',       str,    'image',             "The image path used to inference and visualize.")
    add_arg('image_name',       str,    None,                "The single image used to inference and visualize. None to inference all images in image_path")
    add_arg('draw_thresh',      float,  0.5,                 "Confidence score threshold to draw prediction box in image in debug mode")
    # yapf: enable
    # argparse falls back to sys.argv[1:] when argv is None.
    args = parser.parse_args(argv)
    merge_cfg_from_args(args)
    return args