eval.py 5.2 KB
Newer Older
D
Dang Qingqing 已提交
1 2 3 4 5
import os
import time
import numpy as np
import argparse
import functools
B
Bai Yifan 已提交
6
import math
D
Dang Qingqing 已提交
7 8 9 10 11 12 13 14 15 16

import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

# Command-line interface for the evaluation script. `add_arg` binds the
# project-local `add_arguments` helper to this parser; presumably its
# positional order is (name, type, default, help) — confirm in utility.py.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset',          str,   'pascalvoc',  "coco2014, coco2017, and pascalvoc.")
add_arg('batch_size',       int,   32,        "Minibatch size.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
# Empty data_dir / test_list mean "use the per-dataset default chosen in
# the __main__ section".
add_arg('data_dir',         str,   '',        "The data root path.")
add_arg('test_list',        str,   '',        "The testing data lists.")
add_arg('model_dir',        str,   '',     "The model path.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('ap_version',       str,   '11point',   "integral, 11point.")
add_arg('resize_h',         int,   300,    "The resized image height.")
add_arg('resize_w',         int,   300,    "The resized image width.")
# NOTE(review): the trailing numbers look like alternative per-channel means
# that were considered; confirm channel order before switching to them.
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
# yapf: enable

D
Dang Qingqing 已提交
32

B
Bai Yifan 已提交
33
def build_program(main_prog, startup_prog, args, data_args):
    """Build the MobileNet-SSD evaluation graph and its mAP evaluator.

    The graph consists of a ``py_reader`` that yields (image, gt_box,
    gt_label, difficult). The image is followed by the ``mobile_net`` SSD
    head, then by ``detection_output`` (NMS) and a ``DetectionMAP``
    evaluator. All of it is added to ``main_prog`` / ``startup_prog``.

    Args:
        main_prog: program that receives the evaluation ops.
        startup_prog: program that receives the initialization ops.
        args: parsed CLI arguments; reads ``nms_threshold`` and
            ``ap_version``.
        data_args: dataset settings; reads ``dataset``, ``resize_h`` and
            ``resize_w``.

    Returns:
        Tuple ``(py_reader, map_evaluator)``.

    Raises:
        ValueError: if ``data_args.dataset`` names neither a COCO nor a
            PASCAL VOC dataset (``num_classes`` would otherwise be unbound).
    """
    # Channel-first image shape fed to the network.
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
    else:
        raise ValueError(
            "Unsupported dataset [%s]; expected coco2014, coco2017 or "
            "pascalvoc." % data_args.dataset)

    with fluid.program_guard(main_prog, startup_prog):
        # Four inputs: images (dense), and LoD (variable-length per image)
        # ground-truth boxes, labels and difficult flags.
        py_reader = fluid.layers.py_reader(
            capacity=64,
            shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1]],
            lod_levels=[0, 1, 1, 1],
            dtypes=["float32", "float32", "int32", "int32"],
            use_double_buffer=True)
        with fluid.unique_name.guard():
            image, gt_box, gt_label, difficult = fluid.layers.read_file(
                py_reader)
            locs, confs, box, box_var = mobile_net(num_classes, image,
                                                   image_shape)
            nmsed_out = fluid.layers.detection_output(
                locs, confs, box, box_var, nms_threshold=args.nms_threshold)
            # NOTE(review): re-entering main_prog without a startup program
            # looks redundant given the outer guard; kept as-is until
            # confirmed against DetectionMAP's variable placement.
            with fluid.program_guard(main_prog):
                map_eval = fluid.evaluator.DetectionMAP(
                    nmsed_out,
                    gt_label,
                    gt_box,
                    difficult,
                    num_classes,
                    overlap_threshold=0.5,
                    evaluate_difficult=False,
                    ap_version=args.ap_version)
    return py_reader, map_eval

D
Dang Qingqing 已提交
66

B
Bai Yifan 已提交
67 68 69
def eval(args, data_args, test_list, batch_size, model_dir=None):
    """Evaluate an SSD model on ``test_list`` and print its detection mAP.

    NOTE: the name shadows the builtin ``eval``. It is kept because the
    ``__main__`` section calls this function by that name.

    Args:
        args: parsed CLI arguments; reads ``use_gpu`` here and, via
            ``build_program``, ``nms_threshold`` and ``ap_version``.
        data_args: dataset settings passed to ``build_program`` and
            ``reader.test``.
        test_list: test list / annotation file passed to ``reader.test``.
        batch_size: minibatch size for the test reader.
        model_dir: directory of saved variables. When falsy, the parameters
            produced by the startup program are evaluated as-is.

    Returns:
        The last value fetched from the evaluator's accumulated-mAP
        variable, or ``None`` if the reader produced no batches.
    """
    startup_prog = fluid.Program()
    test_prog = fluid.Program()
    test_py_reader, map_eval = build_program(
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args,
        data_args=data_args)
    # Evaluate with the test-mode clone of the program.
    test_prog = test_prog.clone(for_test=True)
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    if model_dir:
        # Load only variables that have a matching file under model_dir;
        # the others keep their startup-program values.
        def if_exist(var):
            return os.path.exists(os.path.join(model_dir, var.name))

        fluid.io.load_vars(
            exe, model_dir, main_program=test_prog, predicate=if_exist)

    test_reader = reader.test(data_args, test_list, batch_size=batch_size)
    test_py_reader.decorate_paddle_reader(test_reader)

    _, accum_map = map_eval.get_map_var()
    map_eval.reset(exe)
    # Initialized so that an empty test set reports "map None" instead of
    # raising UnboundLocalError at the final print.
    test_map = None
    batch_id = 0
    test_py_reader.start()
    try:
        while True:
            test_map, = exe.run(test_prog, fetch_list=[accum_map])
            if batch_id % 10 == 0:
                print("Batch {0}, map {1}".format(batch_id, test_map))
            batch_id += 1
    except fluid.core.EOFException:
        # The reader signals end of data with EOFException. Presumably a
        # reset is required before it can be started again.
        test_py_reader.reset()
    print("Test model {0}, map {1}".format(model_dir, test_map))
    return test_map
D
Dang Qingqing 已提交
102

D
Dang Qingqing 已提交
103

D
Dang Qingqing 已提交
104 105 106
if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    # PASCAL VOC defaults. The COCO branch below overrides them, and the
    # --data_dir / --test_list flags override them when non-empty.
    label_file = 'label_list'
    data_dir = 'data/pascalvoc'
    test_list = 'test.txt'

    if not os.path.exists(args.model_dir):
        raise ValueError("The model path [%s] does not exist." %
                         (args.model_dir))

    if 'coco' in args.dataset:
        data_dir = 'data/coco'
        # The first matching year wins ('2014' is checked before '2017').
        # If neither year appears, test_list keeps its default.
        for year, annotation_file in (
                ('2014', 'annotations/instances_val2014.json'),
                ('2017', 'annotations/instances_val2017.json')):
            if year in args.dataset:
                test_list = annotation_file
                break

    # Explicit command-line paths take precedence over the defaults.
    if len(args.data_dir) > 0:
        data_dir = args.data_dir
    if len(args.test_list) > 0:
        test_list = args.test_list

    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=False,
        apply_expand=False,
        ap_version=args.ap_version)
    eval(
        args,
        data_args=data_args,
        test_list=test_list,
        batch_size=args.batch_size,
        model_dir=args.model_dir)