# eval.py: evaluate a trained MobileNet-SSD detection model and report its mAP.
import os
import time
import numpy as np
import argparse
import functools
import math

import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

# Command-line options. An empty --data_dir / --test_list makes the
# __main__ block fall back to the per-dataset defaults it chooses.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset',          str,   'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('batch_size',       int,   32,          "Minibatch size.")
add_arg('use_gpu',          bool,  True,        "Whether use GPU.")
add_arg('data_dir',         str,   '',          "The data root path.")
add_arg('test_list',        str,   '',          "The testing data lists.")
add_arg('model_dir',        str,   '',          "The model path.")
add_arg('nms_threshold',    float, 0.45,        "NMS threshold.")
add_arg('ap_version',       str,   '11point',   "integral, 11point.")
add_arg('resize_h',         int,   300,         "The resized image height.")
add_arg('resize_w',         int,   300,         "The resized image width.")
# NOTE(review): the trailing numbers look like alternative (ImageNet-style)
# per-channel means, but 123.68 is usually quoted for R and 103.94 for B --
# confirm the channel order before switching to them.
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
# yapf: enable

def build_program(main_prog, startup_prog, args, data_args):
    """Build the MobileNet-SSD evaluation graph into ``main_prog``.

    Args:
        main_prog: fluid Program that receives the reader, network, NMS and
            mAP evaluator ops.
        startup_prog: fluid Program that receives the initializers.
        args: parsed command-line arguments; ``nms_threshold`` and
            ``ap_version`` are read here.
        data_args: dataset settings; ``dataset``, ``resize_h`` and
            ``resize_w`` are read here.

    Returns:
        ``(py_reader, map_eval)``. ``py_reader`` feeds
        (image, gt_box, gt_label, difficult). ``map_eval`` is the
        ``fluid.evaluator.DetectionMAP`` built on the NMS output.

    Raises:
        ValueError: if ``data_args.dataset`` names neither a COCO nor a
            Pascal VOC dataset.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
    else:
        # Without this, num_classes would be unbound and the failure would
        # surface later as an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported dataset [%s]; expected a coco* or pascalvoc dataset."
            % data_args.dataset)

    with fluid.program_guard(main_prog, startup_prog):
        # Four inputs: image batch, ground-truth boxes, labels and
        # "difficult" flags; the three ground-truth inputs are LoD
        # (variable-length per image).
        py_reader = fluid.layers.py_reader(
            capacity=64,
            shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1]],
            lod_levels=[0, 1, 1, 1],
            dtypes=["float32", "float32", "int32", "int32"],
            use_double_buffer=True)
        with fluid.unique_name.guard():
            image, gt_box, gt_label, difficult = fluid.layers.read_file(
                py_reader)
            locs, confs, box, box_var = mobile_net(num_classes, image,
                                                   image_shape)
            nmsed_out = fluid.layers.detection_output(
                locs, confs, box, box_var, nms_threshold=args.nms_threshold)
            # NOTE(review): main_prog is already the active program from the
            # outer guard, so this inner guard looks redundant -- confirm
            # before removing.
            with fluid.program_guard(main_prog):
                map_eval = fluid.evaluator.DetectionMAP(
                    nmsed_out,
                    gt_label,
                    gt_box,
                    difficult,
                    num_classes,
                    overlap_threshold=0.5,
                    evaluate_difficult=False,
                    ap_version=args.ap_version)
    return py_reader, map_eval

def eval(args, data_args, test_list, batch_size, model_dir=None):
    """Evaluate a detection model on a test set and print its accumulated mAP.

    Args:
        args: parsed command-line arguments. ``use_gpu`` selects the device;
            the remaining fields are consumed by ``build_program``.
        data_args: dataset/preprocessing settings passed to
            ``build_program`` and ``reader.test``.
        test_list: test list / annotation file handed to ``reader.test``.
        batch_size: number of images per evaluation batch.
        model_dir: directory of saved variables. Every variable of the test
            program with a same-named file in this directory is loaded. When
            ``None``, no weights are loaded.

    Returns:
        The last accumulated mAP fetched from the executor, or ``None`` if
        the reader produced no batches.
    """
    startup_prog = fluid.Program()
    test_prog = fluid.Program()
    test_py_reader, map_eval = build_program(
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args,
        data_args=data_args)
    test_prog = test_prog.clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # os.path.join(None, ...) would raise TypeError, so only try to load
    # weights when a directory was actually given.
    if model_dir is not None:
        def if_exist(var):
            # Load only variables that were actually saved to model_dir.
            return os.path.exists(os.path.join(model_dir, var.name))

        fluid.io.load_vars(
            exe, model_dir, main_program=test_prog, predicate=if_exist)

    test_reader = reader.test(data_args, test_list, batch_size=batch_size)
    test_py_reader.decorate_paddle_reader(test_reader)

    _, accum_map = map_eval.get_map_var()
    map_eval.reset(exe)
    # Pre-bind so the final report works even if the reader is empty.
    test_map = None
    test_py_reader.start()
    try:
        batch_id = 0
        while True:
            test_map, = exe.run(test_prog, fetch_list=[accum_map])
            if batch_id % 10 == 0:
                print("Batch {0}, map {1}".format(batch_id, test_map))
            batch_id += 1
    except (fluid.core.EOFException, StopIteration):
        pass  # The test set is exhausted; this is the normal exit.
    finally:
        # Reset on every exit path so the reader is never left running.
        test_py_reader.reset()
    print("Test model {0}, map {1}".format(model_dir, test_map))
    return test_map


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    # Pascal VOC defaults; COCO overrides them further down. A non-empty
    # --data_dir / --test_list on the command line always wins.
    data_dir = 'data/pascalvoc'
    test_list = 'test.txt'
    label_file = 'label_list'

    if not os.path.exists(args.model_dir):
        raise ValueError("The model path [%s] does not exist." %
                         (args.model_dir))

    if 'coco' in args.dataset:
        data_dir = 'data/coco'
        # The first matching year wins, so 2014 is checked before 2017.
        for year, ann_file in (('2014', 'annotations/instances_val2014.json'),
                               ('2017', 'annotations/instances_val2017.json')):
            if year in args.dataset:
                test_list = ann_file
                break

    chosen_data_dir = args.data_dir or data_dir
    chosen_test_list = args.test_list or test_list
    mean_bgr = [args.mean_value_B, args.mean_value_G, args.mean_value_R]

    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=chosen_data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=mean_bgr,
        apply_distort=False,
        apply_expand=False,
        ap_version=args.ap_version)
    eval(
        args,
        data_args=data_args,
        test_list=chosen_test_list,
        batch_size=args.batch_size,
        model_dir=args.model_dir)