提交 f3904a1d 编写于 作者: W wangjiawei04

fix faster rcnn

上级 7cb0028c
...@@ -16,12 +16,20 @@ from paddle_serving_client import Client ...@@ -16,12 +16,20 @@ from paddle_serving_client import Client
import sys import sys
import os import os
import time import time
from paddle_serving_app.pddet import infer, ArgParse from paddle_serving_app.pddet import preprocess, postprocess, ArgParse
import numpy as np import numpy as np
py_version = sys.version_info[0] py_version = sys.version_info[0]
feed_var_names = ['image', 'im_shape', 'im_info'] feed_var_names = ['image', 'im_shape', 'im_info']
fetch_var_names = ['multiclass_nms'] fetch_var_names = ['multiclass_nms']
ArgParse() FLAGS = ArgParse()
infer(['127.0.0.1:9494'], feed_var_names, fetch_var_names) feed_dict = preprocess(feed_var_names)
client = Client()
client.load_client_config(FLAGS.serving_client_conf)
client.connect(['127.0.0.1:9494'])
fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names)
print(fetch_map)
outs = fetch_map.values()
print (len(outs[0]), len(outs[0][0]))
postprocess(fetch_map, fetch_var_names)
...@@ -24,7 +24,7 @@ import argparse ...@@ -24,7 +24,7 @@ import argparse
import cv2 import cv2
import yaml import yaml
import copy import copy
import json
import logging import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s' FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
...@@ -471,11 +471,12 @@ def draw_mask(image, masks, threshold, color_list, alpha=0.7): ...@@ -471,11 +471,12 @@ def draw_mask(image, masks, threshold, color_list, alpha=0.7):
return Image.fromarray(img_array.astype('uint8')) return Image.fromarray(img_array.astype('uint8'))
def get_bbox_result(output, result, conf, clsid2catid): def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid):
is_bbox_normalized = True if 'SSD' in conf['arch'] else False is_bbox_normalized = True if 'SSD' in conf['arch'] else False
lengths = offset_to_lengths(output.lod()) output = fetch_map[fetch_name]
np_data = np.array(output) if conf[ lod = [fetch_map[fetch_name + '.lod']]
'use_python_inference'] else output.copy_to_cpu() lengths = offset_to_lengths(lod)
np_data = np.array(output)
result['bbox'] = (np_data, lengths) result['bbox'] = (np_data, lengths)
result['im_id'] = np.array([[0]]) result['im_id'] = np.array([[0]])
...@@ -483,14 +484,13 @@ def get_bbox_result(output, result, conf, clsid2catid): ...@@ -483,14 +484,13 @@ def get_bbox_result(output, result, conf, clsid2catid):
return bbox_results return bbox_results
def get_mask_result(output, result, conf, clsid2catid): def get_mask_result(fetch_map, fetch_var_names, result, conf, clsid2catid):
resolution = conf['mask_resolution'] resolution = conf['mask_resolution']
bbox_out, mask_out = output bbox_out, mask_out = fetch_map[fetch_var_names]
print (bbox_out, mask_out)
lengths = offset_to_lengths(bbox_out.lod()) lengths = offset_to_lengths(bbox_out.lod())
bbox = np.array(bbox_out) if conf[ bbox = np.array(bbox_out)
'use_python_inference'] else bbox_out.copy_to_cpu() mask = np.array(mask_out)
mask = np.array(mask_out) if conf[
'use_python_inference'] else mask_out.copy_to_cpu()
result['bbox'] = (bbox, lengths) result['bbox'] = (bbox, lengths)
result['mask'] = (mask, lengths) result['mask'] = (mask, lengths)
mask_results = mask2out([result], clsid2catid, conf['mask_resolution']) mask_results = mask2out([result], clsid2catid, conf['mask_resolution'])
...@@ -511,7 +511,7 @@ def visualize(bbox_results, catid2name, num_classes, mask_results=None): ...@@ -511,7 +511,7 @@ def visualize(bbox_results, catid2name, num_classes, mask_results=None):
logger.info('Save visualize result to {}'.format(out_path)) logger.info('Save visualize result to {}'.format(out_path))
def infer(server_ip_list, feed_var_names, fetch_var_names): def preprocess(feed_var_names):
global FLAGS global FLAGS
config_path = FLAGS.config_path config_path = FLAGS.config_path
res = {} res = {}
...@@ -530,37 +530,41 @@ def infer(server_ip_list, feed_var_names, fetch_var_names): ...@@ -530,37 +530,41 @@ def infer(server_ip_list, feed_var_names, fetch_var_names):
) )
def processImg(v): def processImg(v):
np_data = np.array(v) np_data = np.array(v[0])
np_feed = np.reshape(np_data, (-1)) res = np_data
res = np_feed.tolist()
return res return res
feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)} feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)}
# Infer from Server return feed_dict
from paddle_serving_client import Client
client = Client()
client.load_client_config(FLAGS.serving_client_conf)
client.connect(server_ip_list) def postprocess(fetch_map, fetch_var_names):
fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names) config_path = FLAGS.config_path
print(fetch_map) res = {}
outs = fetch_map.values() with open(config_path) as f:
conf = yaml.safe_load(f)
# post process if 'SSD' in conf['arch']:
img_data, res['im_shape'] = img_data
img_data = [img_data]
clsid2catid, catid2name = get_category_info(conf['with_background'], clsid2catid, catid2name = get_category_info(conf['with_background'],
conf['label_list']) conf['label_list'])
bbox_result = get_bbox_result(outs[0], res, conf, clsid2catid) bbox_result = get_bbox_result(fetch_map, fetch_var_names[0], res, conf, clsid2catid)
mask_result = None mask_result = None
if 'mask_resolution' in conf: if 'mask_resolution' in conf:
res['im_shape'] = img_data[-1] res['im_shape'] = img_data[-1]
mask_result = get_mask_result(outs, res, conf, clsid2catid) mask_result = get_mask_result(fetch_map, fetch_var_names, res, conf, clsid2catid)
if FLAGS.visualize: if FLAGS.visualize:
if os.path.isdir(FLAGS.output_dir) is False:
os.mkdir(FLAGS.output_dir)
visualize(bbox_result, catid2name, len(conf['label_list']), mask_result) visualize(bbox_result, catid2name, len(conf['label_list']), mask_result)
if flags.dump_result: if FLAGS.dump_result:
bbox_file = os.path.join(flags.output_dir, 'bbox.json') if os.path.isdir(FLAGS.output_dir) is False:
os.mkdir(FLAGS.output_dir)
bbox_file = os.path.join(FLAGS.output_dir, 'bbox.json')
logger.info('dump bbox to {}'.format(bbox_file)) logger.info('dump bbox to {}'.format(bbox_file))
with open(bbox_file, 'w') as f: with open(bbox_file, 'w') as f:
json.dump(bbox_result, f, indent=4) json.dump(bbox_result, f, indent=4)
if mask_result is not none: if mask_result is not None:
mask_file = os.path.join(flags.output_dir, 'mask.json') mask_file = os.path.join(flags.output_dir, 'mask.json')
logger.info('dump mask to {}'.format(mask_file)) logger.info('dump mask to {}'.format(mask_file))
with open(mask_file, 'w') as f: with open(mask_file, 'w') as f:
...@@ -569,8 +573,6 @@ def infer(server_ip_list, feed_var_names, fetch_var_names): ...@@ -569,8 +573,6 @@ def infer(server_ip_list, feed_var_names, fetch_var_names):
def ArgParse(): def ArgParse():
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model_path", type=str, default=None, help="model path.")
parser.add_argument( parser.add_argument(
"--config_path", type=str, default=None, help="preprocess config path.") "--config_path", type=str, default=None, help="preprocess config path.")
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册