eval.py 5.1 KB
Newer Older
D
Dang Qingqing 已提交
1 2 3 4 5
import os
import time
import numpy as np
import argparse
import functools
B
Bai Yifan 已提交
6
import math
D
Dang Qingqing 已提交
7 8 9 10 11 12 13 14 15 16

import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

# Command-line interface. Each add_arg call passes (name, type, default, help)
# positionally to utility.add_arguments, bound to this parser; the parsed
# values are read back later as attributes of the same name (e.g. args.dataset).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset',          str,   'pascalvoc',  "coco2014, coco2017, and pascalvoc.")
add_arg('batch_size',       int,   32,        "Minibatch size.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
add_arg('data_dir',         str,   '',        "The data root path.")
add_arg('test_list',        str,   '',        "The testing data lists.")
add_arg('model_dir',        str,   '',     "The model path.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('ap_version',       str,   '11point',   "integral, 11point.")
add_arg('resize_h',         int,   300,    "The resized image height.")
add_arg('resize_w',         int,   300,    "The resized image width.")
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
# yapf: enable

D
Dang Qingqing 已提交
32

B
Bai Yifan 已提交
33
def build_program(main_prog, startup_prog, args, data_args):
    """Build the detection network and its mAP metric inside ``main_prog``.

    Args:
        main_prog: program passed to ``fluid.program_guard`` as the main
            program; the network and metric are created under it.
        startup_prog: program passed to ``fluid.program_guard`` as the
            startup program.
        args: parsed command-line options; ``nms_threshold`` and
            ``ap_version`` are read.
        data_args: reader settings; ``resize_h``, ``resize_w`` and
            ``dataset`` are read.

    Returns:
        A ``(py_reader, detection_map)`` tuple: the reader created by
        ``fluid.layers.py_reader`` and the ``fluid.metrics.DetectionMAP``
        evaluator.

    Raises:
        ValueError: if ``data_args.dataset`` contains neither ``'coco'`` nor
            ``'pascalvoc'``, since the class count is unknown in that case.
    """
    # Resolve the class count first so an unsupported dataset fails with a
    # clear message instead of an unbound-variable error further down.
    if 'coco' in data_args.dataset:
        num_classes = 91
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
    else:
        raise ValueError(
            "Unsupported dataset [%s]; expected a name containing 'coco' "
            "or 'pascalvoc'." % (data_args.dataset))

    image_shape = [3, data_args.resize_h, data_args.resize_w]

    with fluid.program_guard(main_prog, startup_prog):
        # Four reader slots: image, ground-truth boxes, ground-truth labels
        # and the "difficult" flags, unpacked in that order by read_file.
        py_reader = fluid.layers.py_reader(
            capacity=64,
            shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1]],
            lod_levels=[0, 1, 1, 1],
            dtypes=["float32", "float32", "int32", "int32"],
            use_double_buffer=True)
        with fluid.unique_name.guard():
            image, gt_box, gt_label, difficult = fluid.layers.read_file(
                py_reader)
            locs, confs, box, box_var = mobile_net(num_classes, image,
                                                   image_shape)
            nmsed_out = fluid.layers.detection_output(
                locs, confs, box, box_var, nms_threshold=args.nms_threshold)
            with fluid.program_guard(main_prog):
                # Named detection_map rather than map to avoid shadowing
                # the builtin.
                detection_map = fluid.metrics.DetectionMAP(
                    nmsed_out,
                    gt_label,
                    gt_box,
                    difficult,
                    num_classes,
                    overlap_threshold=0.5,
                    evaluate_difficult=False,
                    ap_version=args.ap_version)
    return py_reader, detection_map

D
Dang Qingqing 已提交
66

B
Bai Yifan 已提交
67 68 69
def eval(args, data_args, test_list, batch_size, model_dir=None):
    """Evaluate a saved model on ``test_list`` and print the accumulated mAP.

    Args:
        args: parsed command-line options; ``use_gpu`` is read here and the
            object is forwarded to ``build_program``.
        data_args: reader settings forwarded to ``build_program`` and
            ``reader.test``.
        test_list: test list forwarded to ``reader.test``.
        batch_size: batch size forwarded to ``reader.test``.
        model_dir: directory the variables are loaded from. Only variables
            with a file of the same name in this directory are loaded.

    Progress is printed every 10 batches and the final accumulated mAP is
    printed once the reader is exhausted. Nothing is returned.
    """
    startup_prog = fluid.Program()
    test_prog = fluid.Program()

    test_py_reader, map_eval = build_program(
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args,
        data_args=data_args)
    test_prog = test_prog.clone(for_test=True)
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    def if_exist(var):
        # Load a variable only when a file with its name exists in model_dir.
        return os.path.exists(os.path.join(model_dir, var.name))

    fluid.io.load_vars(
        exe, model_dir, main_program=test_prog, predicate=if_exist)

    test_reader = reader.test(data_args, test_list, batch_size=batch_size)
    test_py_reader.decorate_paddle_reader(test_reader)

    _, accum_map = map_eval.get_map_var()
    map_eval.reset(exe)
    test_py_reader.start()
    # Stays None if the reader signals end-of-data before the first batch
    # completes, so the final print below cannot hit an unbound name.
    test_map = None
    try:
        batch_id = 0
        # The loop ends when the reader raises at end-of-data.
        while True:
            test_map, = exe.run(test_prog, fetch_list=[accum_map])
            if batch_id % 10 == 0:
                print("Batch {0}, map {1}".format(batch_id, test_map))
            batch_id += 1
    except (fluid.core.EOFException, StopIteration):
        test_py_reader.reset()
    print("Test model {0}, map {1}".format(model_dir, test_map))
D
Dang Qingqing 已提交
103

D
Dang Qingqing 已提交
104

D
Dang Qingqing 已提交
105 106 107
if __name__ == '__main__':
    # Script entry point: parse options, resolve dataset defaults, evaluate.
    args = parser.parse_args()
    print_arguments(args)

    # Defaults for Pascal VOC; the coco branch below overrides data_dir and,
    # when the dataset name carries a year, test_list.
    data_dir = 'data/pascalvoc'
    test_list = 'test.txt'
    label_file = 'label_list'

    if not os.path.exists(args.model_dir):
        raise ValueError("The model path [%s] does not exist." %
                         (args.model_dir))
    if 'coco' in args.dataset:
        data_dir = 'data/coco'
        if '2014' in args.dataset:
            test_list = 'annotations/instances_val2014.json'
        elif '2017' in args.dataset:
            test_list = 'annotations/instances_val2017.json'

    data_args = reader.Settings(
        dataset=args.dataset,
        # An empty --data_dir falls back to the dataset default above.
        data_dir=args.data_dir or data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=False,
        apply_expand=False,
        ap_version=args.ap_version)
    eval(
        args,
        data_args=data_args,
        # An empty --test_list falls back to the dataset default above.
        test_list=args.test_list or test_list,
        batch_size=args.batch_size,
        model_dir=args.model_dir)