未验证 提交 ab1c367f 编写于 作者: G Guanghua Yu 提交者: GitHub

fix act some bug (#1421)

上级 b747596c
...@@ -98,7 +98,8 @@ def eval(): ...@@ -98,7 +98,8 @@ def eval():
def main(args): def main(args):
global global_config global global_config
global_config = load_slim_config(args.config_path) all_config = load_slim_config(args.config_path)
global_config = all_config["Global"]
global data_dir global data_dir
data_dir = global_config['data_dir'] data_dir = global_config['data_dir']
......
...@@ -55,7 +55,7 @@ def eval(): ...@@ -55,7 +55,7 @@ def eval():
score_threshold=0.001, score_threshold=0.001,
nms_threshold=0.65, nms_threshold=0.65,
multi_label=True, multi_label=True,
num_top_k=global_config.get('nms_num_top_k', 30000)) nms_top_k=global_config.get('nms_num_top_k', 30000))
bboxes_list, bbox_nums_list, image_id_list = [], [], [] bboxes_list, bbox_nums_list, image_id_list = [], [], []
with tqdm( with tqdm(
......
...@@ -80,7 +80,7 @@ class YOLOPostProcess(object): ...@@ -80,7 +80,7 @@ class YOLOPostProcess(object):
multi_label(bool): Whether keep multi label in boxes. multi_label(bool): Whether keep multi label in boxes.
keep_top_k(int): Number of total bboxes to be kept per image after NMS keep_top_k(int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step. step. -1 means keeping all bboxes after NMS step.
num_top_k(int): Maximum number of boxes put into torchvision.ops.nums() nms_top_k(int): Maximum number of boxes put into nums.
""" """
def __init__(self, def __init__(self,
...@@ -88,12 +88,12 @@ class YOLOPostProcess(object): ...@@ -88,12 +88,12 @@ class YOLOPostProcess(object):
nms_threshold=0.5, nms_threshold=0.5,
multi_label=False, multi_label=False,
keep_top_k=300, keep_top_k=300,
num_top_k=30000): nms_top_k=30000):
self.score_threshold = score_threshold self.score_threshold = score_threshold
self.nms_threshold = nms_threshold self.nms_threshold = nms_threshold
self.multi_label = multi_label self.multi_label = multi_label
self.keep_top_k = keep_top_k self.keep_top_k = keep_top_k
self.num_top_k = num_top_k self.nms_top_k = nms_top_k
def _xywh2xyxy(self, x): def _xywh2xyxy(self, x):
# Convert from [x, y, w, h] to [x1, y1, x2, y2] # Convert from [x, y, w, h] to [x1, y1, x2, y2]
......
...@@ -71,7 +71,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -71,7 +71,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
feed={test_feed_names[0]: data_all['image']}, feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list, fetch_list=test_fetch_list,
return_numpy=False) return_numpy=False)
res = postprocess(np.array(outs[0]), data_all['scale_factor']) res = postprocess(
np.array(outs[0]),
data_all['scale_factor'],
nms_top_k=global_config.get('nms_num_top_k', 30000))
bboxes_list.append(res['bbox']) bboxes_list.append(res['bbox'])
bbox_nums_list.append(res['bbox_num']) bbox_nums_list.append(res['bbox_num'])
image_id_list.append(np.array(data_all['im_id'])) image_id_list.append(np.array(data_all['im_id']))
......
...@@ -98,7 +98,8 @@ def eval(): ...@@ -98,7 +98,8 @@ def eval():
def main(args): def main(args):
global global_config global global_config
global_config = load_slim_config(args.config_path) all_config = load_slim_config(args.config_path)
global_config = all_config["Global"]
global data_dir global data_dir
data_dir = global_config['data_dir'] data_dir = global_config['data_dir']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册