提交 bcf75d12 编写于 作者: W wangjiawei04

fix code style

上级 28032c43
......@@ -31,5 +31,5 @@ client.connect(['127.0.0.1:9494'])
fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names)
print(type(fetch_map['multiclass_nms']))
outs = fetch_map.values()
print (len(outs[0]), len(outs[0][0]))
print(len(outs[0]), len(outs[0][0]))
postprocess(fetch_map, fetch_var_names)
......@@ -476,7 +476,7 @@ def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid):
output = fetch_map[fetch_name]
lod = [fetch_map[fetch_name + '.lod']]
lengths = offset_to_lengths(lod)
np_data = np.array(output)
np_data = np.array(output)
result['bbox'] = (np_data, lengths)
result['im_id'] = np.array([[0]])
......@@ -487,10 +487,10 @@ def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid):
def get_mask_result(fetch_map, fetch_var_names, result, conf, clsid2catid):
resolution = conf['mask_resolution']
bbox_out, mask_out = fetch_map[fetch_var_names]
print (bbox_out, mask_out)
print(bbox_out, mask_out)
lengths = offset_to_lengths(bbox_out.lod())
bbox = np.array(bbox_out)
mask = np.array(mask_out)
bbox = np.array(bbox_out)
mask = np.array(mask_out)
result['bbox'] = (bbox, lengths)
result['mask'] = (mask, lengths)
mask_results = mask2out([result], clsid2catid, conf['mask_resolution'])
......@@ -533,11 +533,11 @@ def preprocess(feed_var_names):
np_data = np.array(v[0])
res = np_data
return res
feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)}
return feed_dict
def postprocess(fetch_map, fetch_var_names):
config_path = FLAGS.config_path
res = {}
......@@ -548,11 +548,13 @@ def postprocess(fetch_map, fetch_var_names):
img_data = [img_data]
clsid2catid, catid2name = get_category_info(conf['with_background'],
conf['label_list'])
bbox_result = get_bbox_result(fetch_map, fetch_var_names[0], res, conf, clsid2catid)
bbox_result = get_bbox_result(fetch_map, fetch_var_names[0], res, conf,
clsid2catid)
mask_result = None
if 'mask_resolution' in conf:
res['im_shape'] = img_data[-1]
mask_result = get_mask_result(fetch_map, fetch_var_names, res, conf, clsid2catid)
mask_result = get_mask_result(fetch_map, fetch_var_names, res, conf,
clsid2catid)
if FLAGS.visualize:
if os.path.isdir(FLAGS.output_dir) is False:
os.mkdir(FLAGS.output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册