未验证 提交 ab1c367f 编写于 作者: G Guanghua Yu 提交者: GitHub

fix act some bug (#1421)

上级 b747596c
......@@ -98,7 +98,8 @@ def eval():
def main(args):
global global_config
global_config = load_slim_config(args.config_path)
all_config = load_slim_config(args.config_path)
global_config = all_config["Global"]
global data_dir
data_dir = global_config['data_dir']
......
......@@ -55,7 +55,7 @@ def eval():
score_threshold=0.001,
nms_threshold=0.65,
multi_label=True,
num_top_k=global_config.get('nms_num_top_k', 30000))
nms_top_k=global_config.get('nms_num_top_k', 30000))
bboxes_list, bbox_nums_list, image_id_list = [], [], []
with tqdm(
......
......@@ -80,7 +80,7 @@ class YOLOPostProcess(object):
multi_label(bool): Whether keep multi label in boxes.
keep_top_k(int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
num_top_k(int): Maximum number of boxes put into torchvision.ops.nums()
nms_top_k(int): Maximum number of boxes put into nums.
"""
def __init__(self,
......@@ -88,12 +88,12 @@ class YOLOPostProcess(object):
nms_threshold=0.5,
multi_label=False,
keep_top_k=300,
num_top_k=30000):
nms_top_k=30000):
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.multi_label = multi_label
self.keep_top_k = keep_top_k
self.num_top_k = num_top_k
self.nms_top_k = nms_top_k
def _xywh2xyxy(self, x):
# Convert from [x, y, w, h] to [x1, y1, x2, y2]
......
......@@ -71,7 +71,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
res = postprocess(
np.array(outs[0]),
data_all['scale_factor'],
nms_top_k=global_config.get('nms_num_top_k', 30000))
bboxes_list.append(res['bbox'])
bbox_nums_list.append(res['bbox_num'])
image_id_list.append(np.array(data_all['im_id']))
......
......@@ -98,7 +98,8 @@ def eval():
def main(args):
global global_config
global_config = load_slim_config(args.config_path)
all_config = load_slim_config(args.config_path)
global_config = all_config["Global"]
global data_dir
data_dir = global_config['data_dir']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册