提交 1e1168af 编写于 作者: W wangjiawei04

fix code style

上级 70b2ae05
...@@ -31,5 +31,5 @@ client.connect(['127.0.0.1:9494']) ...@@ -31,5 +31,5 @@ client.connect(['127.0.0.1:9494'])
fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names) fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names)
print(type(fetch_map['multiclass_nms'])) print(type(fetch_map['multiclass_nms']))
outs = fetch_map.values() outs = fetch_map.values()
print (len(outs[0]), len(outs[0][0])) print(len(outs[0]), len(outs[0][0]))
postprocess(fetch_map, fetch_var_names) postprocess(fetch_map, fetch_var_names)
...@@ -476,7 +476,7 @@ def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid): ...@@ -476,7 +476,7 @@ def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid):
output = fetch_map[fetch_name] output = fetch_map[fetch_name]
lod = [fetch_map[fetch_name + '.lod']] lod = [fetch_map[fetch_name + '.lod']]
lengths = offset_to_lengths(lod) lengths = offset_to_lengths(lod)
np_data = np.array(output) np_data = np.array(output)
result['bbox'] = (np_data, lengths) result['bbox'] = (np_data, lengths)
result['im_id'] = np.array([[0]]) result['im_id'] = np.array([[0]])
...@@ -487,10 +487,10 @@ def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid): ...@@ -487,10 +487,10 @@ def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid):
def get_mask_result(fetch_map, fetch_var_names, result, conf, clsid2catid): def get_mask_result(fetch_map, fetch_var_names, result, conf, clsid2catid):
resolution = conf['mask_resolution'] resolution = conf['mask_resolution']
bbox_out, mask_out = fetch_map[fetch_var_names] bbox_out, mask_out = fetch_map[fetch_var_names]
print (bbox_out, mask_out) print(bbox_out, mask_out)
lengths = offset_to_lengths(bbox_out.lod()) lengths = offset_to_lengths(bbox_out.lod())
bbox = np.array(bbox_out) bbox = np.array(bbox_out)
mask = np.array(mask_out) mask = np.array(mask_out)
result['bbox'] = (bbox, lengths) result['bbox'] = (bbox, lengths)
result['mask'] = (mask, lengths) result['mask'] = (mask, lengths)
mask_results = mask2out([result], clsid2catid, conf['mask_resolution']) mask_results = mask2out([result], clsid2catid, conf['mask_resolution'])
...@@ -533,11 +533,11 @@ def preprocess(feed_var_names): ...@@ -533,11 +533,11 @@ def preprocess(feed_var_names):
np_data = np.array(v[0]) np_data = np.array(v[0])
res = np_data res = np_data
return res return res
feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)} feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)}
return feed_dict return feed_dict
def postprocess(fetch_map, fetch_var_names): def postprocess(fetch_map, fetch_var_names):
config_path = FLAGS.config_path config_path = FLAGS.config_path
res = {} res = {}
...@@ -548,11 +548,13 @@ def postprocess(fetch_map, fetch_var_names): ...@@ -548,11 +548,13 @@ def postprocess(fetch_map, fetch_var_names):
img_data = [img_data] img_data = [img_data]
clsid2catid, catid2name = get_category_info(conf['with_background'], clsid2catid, catid2name = get_category_info(conf['with_background'],
conf['label_list']) conf['label_list'])
bbox_result = get_bbox_result(fetch_map, fetch_var_names[0], res, conf, clsid2catid) bbox_result = get_bbox_result(fetch_map, fetch_var_names[0], res, conf,
clsid2catid)
mask_result = None mask_result = None
if 'mask_resolution' in conf: if 'mask_resolution' in conf:
res['im_shape'] = img_data[-1] res['im_shape'] = img_data[-1]
mask_result = get_mask_result(fetch_map, fetch_var_names, res, conf, clsid2catid) mask_result = get_mask_result(fetch_map, fetch_var_names, res, conf,
clsid2catid)
if FLAGS.visualize: if FLAGS.visualize:
if os.path.isdir(FLAGS.output_dir) is False: if os.path.isdir(FLAGS.output_dir) is False:
os.mkdir(FLAGS.output_dir) os.mkdir(FLAGS.output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册