未验证 提交 5471a176 编写于 作者: W wangguanzhong 提交者: GitHub

Update pphuman (#5393)

* refine attr vis & refine model_dir in config

* support model_dir in command line
上级 6e1fa92d
...@@ -11,6 +11,6 @@ ATTR: ...@@ -11,6 +11,6 @@ ATTR:
batch_size: 8 batch_size: 8
MOT: MOT:
model_dir: output_inference/pedestrian_yolov3_darknet/ model_dir: output_inference/mot_ppyolov3/
tracker_config: deploy/pphuman/tracker_config.yml tracker_config: deploy/pphuman/config/tracker_config.yml
batch_size: 1 batch_size: 1
...@@ -45,6 +45,8 @@ def argsparser(): ...@@ -45,6 +45,8 @@ def argsparser():
default=None, default=None,
help="Path of video file, `video_file` or `camera_id` has a highest priority." help="Path of video file, `video_file` or `camera_id` has a highest priority."
) )
parser.add_argument(
"--model_dir", nargs='*', help="set model dir in pipeline")
parser.add_argument( parser.add_argument(
"--camera_id", "--camera_id",
type=int, type=int,
...@@ -182,6 +184,21 @@ class PipeTimer(Times): ...@@ -182,6 +184,21 @@ class PipeTimer(Times):
return dic return dic
def merge_model_dir(args, model_dir):
# set --model_dir DET=ppyoloe/ to overwrite the model_dir in config file
task_set = ['DET', 'ATTR', 'MOT', 'KPT', 'ACTION']
if not model_dir:
return args
for md in model_dir:
md = md.strip()
k, v = md.split('=', 1)
k_upper = k.upper()
assert k_upper in task_set, 'Illegal type of task, expect task are: {}, but received {}'.format(
task_set, k)
args[k_upper].update({'model_dir': v})
return args
def merge_cfg(args): def merge_cfg(args):
with open(args.config) as f: with open(args.config) as f:
pred_config = yaml.safe_load(f) pred_config = yaml.safe_load(f)
...@@ -196,14 +213,17 @@ def merge_cfg(args): ...@@ -196,14 +213,17 @@ def merge_cfg(args):
merge_cfg[k] = merge(v, arg) merge_cfg[k] = merge(v, arg)
return merge_cfg return merge_cfg
pred_config = merge(pred_config, vars(args)) args_dict = vars(args)
model_dir = args_dict.pop('model_dir')
pred_config = merge_model_dir(pred_config, model_dir)
pred_config = merge(pred_config, args_dict)
return pred_config return pred_config
def print_arguments(cfg): def print_arguments(cfg):
print('----------- Running Arguments -----------') print('----------- Running Arguments -----------')
for arg, value in sorted(cfg.items()): buffer = yaml.dump(cfg)
print('%s: %s' % (arg, value)) print(buffer)
print('------------------------------------------') print('------------------------------------------')
......
...@@ -96,13 +96,13 @@ class AttrDetector(Detector): ...@@ -96,13 +96,13 @@ class AttrDetector(Detector):
age_list = ['AgeLess18', 'Age18-60', 'AgeOver60'] age_list = ['AgeLess18', 'Age18-60', 'AgeOver60']
direct_list = ['Front', 'Side', 'Back'] direct_list = ['Front', 'Side', 'Back']
bag_list = ['HandBag', 'ShoulderBag', 'Backpack'] bag_list = ['HandBag', 'ShoulderBag', 'Backpack']
upper_list = [ upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice']
'UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice', 'LongCoat'
]
lower_list = [ lower_list = [
'LowerStripe', 'LowerPattern', 'Trousers', 'Shorts', 'Skirt&Dress' 'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts',
'Skirt&Dress'
] ]
glasses_threshold = 0.3
hold_threshold = 0.6
batch_res = [] batch_res = []
for res in im_results: for res in im_results:
res = res.tolist() res = res.tolist()
...@@ -118,7 +118,7 @@ class AttrDetector(Detector): ...@@ -118,7 +118,7 @@ class AttrDetector(Detector):
label_res.append(direction) label_res.append(direction)
# glasses # glasses
glasses = 'Glasses: ' glasses = 'Glasses: '
if res[1] > self.threshold: if res[1] > glasses_threshold:
glasses += 'True' glasses += 'True'
else: else:
glasses += 'False' glasses += 'False'
...@@ -132,7 +132,7 @@ class AttrDetector(Detector): ...@@ -132,7 +132,7 @@ class AttrDetector(Detector):
label_res.append(hat) label_res.append(hat)
# hold obj # hold obj
hold_obj = 'HoldObjectsInFront: ' hold_obj = 'HoldObjectsInFront: '
if res[18] > self.threshold: if res[18] > hold_threshold:
hold_obj += 'True' hold_obj += 'True'
else: else:
hold_obj += 'False' hold_obj += 'False'
...@@ -143,7 +143,7 @@ class AttrDetector(Detector): ...@@ -143,7 +143,7 @@ class AttrDetector(Detector):
bag_label = bag if bag_score > self.threshold else 'No bag' bag_label = bag if bag_score > self.threshold else 'No bag'
label_res.append(bag_label) label_res.append(bag_label)
# upper # upper
upper_res = res[4:8] + res[10:11] upper_res = res[4:8]
upper_label = 'Upper:' upper_label = 'Upper:'
sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve' sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
upper_label += ' {}'.format(sleeve) upper_label += ' {}'.format(sleeve)
...@@ -152,7 +152,7 @@ class AttrDetector(Detector): ...@@ -152,7 +152,7 @@ class AttrDetector(Detector):
upper_label += ' {}'.format(upper_list[i]) upper_label += ' {}'.format(upper_list[i])
label_res.append(upper_label) label_res.append(upper_label)
# lower # lower
lower_res = res[8:10] + res[11:14] lower_res = res[8:14]
lower_label = 'Lower: ' lower_label = 'Lower: '
has_lower = False has_lower = False
for i, l in enumerate(lower_res): for i, l in enumerate(lower_res):
......
...@@ -338,8 +338,8 @@ def visualize_attr(im, results, boxes=None): ...@@ -338,8 +338,8 @@ def visualize_attr(im, results, boxes=None):
im = np.ascontiguousarray(np.copy(im)) im = np.ascontiguousarray(np.copy(im))
im_h, im_w = im.shape[:2] im_h, im_w = im.shape[:2]
text_scale = max(1, int(im.shape[0] / 1600.)) text_scale = max(1, int(im.shape[0] / 1200.))
text_thickness = 2 text_thickness = 3
line_inter = im.shape[0] / 50. line_inter = im.shape[0] / 50.
for i, res in enumerate(results): for i, res in enumerate(results):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册