提交 f3904a1d 编写于 作者: W wangjiawei04

fix faster rcnn

上级 7cb0028c
......@@ -16,12 +16,20 @@ from paddle_serving_client import Client
import sys
import os
import time
from paddle_serving_app.pddet import infer, ArgParse
from paddle_serving_app.pddet import preprocess, postprocess, ArgParse
import numpy as np
py_version = sys.version_info[0]
feed_var_names = ['image', 'im_shape', 'im_info']
fetch_var_names = ['multiclass_nms']
ArgParse()
infer(['127.0.0.1:9494'], feed_var_names, fetch_var_names)
FLAGS = ArgParse()
feed_dict = preprocess(feed_var_names)
client = Client()
client.load_client_config(FLAGS.serving_client_conf)
client.connect(['127.0.0.1:9494'])
fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names)
print(fetch_map)
outs = fetch_map.values()
print (len(outs[0]), len(outs[0][0]))
postprocess(fetch_map, fetch_var_names)
......@@ -24,7 +24,7 @@ import argparse
import cv2
import yaml
import copy
import json
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
......@@ -471,11 +471,12 @@ def draw_mask(image, masks, threshold, color_list, alpha=0.7):
return Image.fromarray(img_array.astype('uint8'))
def get_bbox_result(output, result, conf, clsid2catid):
def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid):
is_bbox_normalized = True if 'SSD' in conf['arch'] else False
lengths = offset_to_lengths(output.lod())
np_data = np.array(output) if conf[
'use_python_inference'] else output.copy_to_cpu()
output = fetch_map[fetch_name]
lod = [fetch_map[fetch_name + '.lod']]
lengths = offset_to_lengths(lod)
np_data = np.array(output)
result['bbox'] = (np_data, lengths)
result['im_id'] = np.array([[0]])
......@@ -483,14 +484,13 @@ def get_bbox_result(output, result, conf, clsid2catid):
return bbox_results
def get_mask_result(output, result, conf, clsid2catid):
def get_mask_result(fetch_map, fetch_var_names, result, conf, clsid2catid):
resolution = conf['mask_resolution']
bbox_out, mask_out = output
bbox_out, mask_out = fetch_map[fetch_var_names]
print (bbox_out, mask_out)
lengths = offset_to_lengths(bbox_out.lod())
bbox = np.array(bbox_out) if conf[
'use_python_inference'] else bbox_out.copy_to_cpu()
mask = np.array(mask_out) if conf[
'use_python_inference'] else mask_out.copy_to_cpu()
bbox = np.array(bbox_out)
mask = np.array(mask_out)
result['bbox'] = (bbox, lengths)
result['mask'] = (mask, lengths)
mask_results = mask2out([result], clsid2catid, conf['mask_resolution'])
......@@ -511,7 +511,7 @@ def visualize(bbox_results, catid2name, num_classes, mask_results=None):
logger.info('Save visualize result to {}'.format(out_path))
def infer(server_ip_list, feed_var_names, fetch_var_names):
def preprocess(feed_var_names):
global FLAGS
config_path = FLAGS.config_path
res = {}
......@@ -530,37 +530,41 @@ def infer(server_ip_list, feed_var_names, fetch_var_names):
)
def processImg(v):
np_data = np.array(v)
np_feed = np.reshape(np_data, (-1))
res = np_feed.tolist()
np_data = np.array(v[0])
res = np_data
return res
feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)}
# Infer from Server
from paddle_serving_client import Client
client = Client()
client.load_client_config(FLAGS.serving_client_conf)
client.connect(server_ip_list)
fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names)
print(fetch_map)
outs = fetch_map.values()
# post process
return feed_dict
def postprocess(fetch_map, fetch_var_names):
config_path = FLAGS.config_path
res = {}
with open(config_path) as f:
conf = yaml.safe_load(f)
if 'SSD' in conf['arch']:
img_data, res['im_shape'] = img_data
img_data = [img_data]
clsid2catid, catid2name = get_category_info(conf['with_background'],
conf['label_list'])
bbox_result = get_bbox_result(outs[0], res, conf, clsid2catid)
bbox_result = get_bbox_result(fetch_map, fetch_var_names[0], res, conf, clsid2catid)
mask_result = None
if 'mask_resolution' in conf:
res['im_shape'] = img_data[-1]
mask_result = get_mask_result(outs, res, conf, clsid2catid)
mask_result = get_mask_result(fetch_map, fetch_var_names, res, conf, clsid2catid)
if FLAGS.visualize:
if os.path.isdir(FLAGS.output_dir) is False:
os.mkdir(FLAGS.output_dir)
visualize(bbox_result, catid2name, len(conf['label_list']), mask_result)
if flags.dump_result:
bbox_file = os.path.join(flags.output_dir, 'bbox.json')
if FLAGS.dump_result:
if os.path.isdir(FLAGS.output_dir) is False:
os.mkdir(FLAGS.output_dir)
bbox_file = os.path.join(FLAGS.output_dir, 'bbox.json')
logger.info('dump bbox to {}'.format(bbox_file))
with open(bbox_file, 'w') as f:
json.dump(bbox_result, f, indent=4)
if mask_result is not none:
if mask_result is not None:
mask_file = os.path.join(flags.output_dir, 'mask.json')
logger.info('dump mask to {}'.format(mask_file))
with open(mask_file, 'w') as f:
......@@ -569,8 +573,6 @@ def infer(server_ip_list, feed_var_names, fetch_var_names):
def ArgParse():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model_path", type=str, default=None, help="model path.")
parser.add_argument(
"--config_path", type=str, default=None, help="preprocess config path.")
parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册