diff --git a/python/examples/faster_rcnn_model/test_client.py b/python/examples/faster_rcnn_model/test_client.py index cd78005926c4a10140afa0f22aef6aecb0c16464..ca54ea7117284001d05f8dd2403b76def656472d 100755 --- a/python/examples/faster_rcnn_model/test_client.py +++ b/python/examples/faster_rcnn_model/test_client.py @@ -16,12 +16,20 @@ from paddle_serving_client import Client import sys import os import time -from paddle_serving_app.pddet import infer, ArgParse +from paddle_serving_app.pddet import preprocess, postprocess, ArgParse import numpy as np py_version = sys.version_info[0] feed_var_names = ['image', 'im_shape', 'im_info'] fetch_var_names = ['multiclass_nms'] -ArgParse() -infer(['127.0.0.1:9494'], feed_var_names, fetch_var_names) +FLAGS = ArgParse() +feed_dict = preprocess(feed_var_names) +client = Client() +client.load_client_config(FLAGS.serving_client_conf) +client.connect(['127.0.0.1:9494']) +fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names) +print(fetch_map) +outs = fetch_map.values() +print (len(outs[0]), len(outs[0][0])) +postprocess(fetch_map, fetch_var_names) diff --git a/python/paddle_serving_app/pddet/__init__.py b/python/paddle_serving_app/pddet/__init__.py index ed5e97a62a3d800ef7b4787a38ea55855987e0e0..1bc009e25fd961b0f44d46a902a2f546b749bf53 100644 --- a/python/paddle_serving_app/pddet/__init__.py +++ b/python/paddle_serving_app/pddet/__init__.py @@ -24,7 +24,7 @@ import argparse import cv2 import yaml import copy - +import json import logging FORMAT = '%(asctime)s-%(levelname)s: %(message)s' @@ -471,11 +471,12 @@ def draw_mask(image, masks, threshold, color_list, alpha=0.7): return Image.fromarray(img_array.astype('uint8')) -def get_bbox_result(output, result, conf, clsid2catid): +def get_bbox_result(fetch_map, fetch_name, result, conf, clsid2catid): is_bbox_normalized = True if 'SSD' in conf['arch'] else False - lengths = offset_to_lengths(output.lod()) - np_data = np.array(output) if conf[ - 'use_python_inference'] else output.copy_to_cpu() + output = fetch_map[fetch_name] + lod = [fetch_map[fetch_name + '.lod']] + lengths = offset_to_lengths(lod) + np_data = np.array(output) result['bbox'] = (np_data, lengths) result['im_id'] = np.array([[0]]) @@ -483,14 +484,13 @@ def get_bbox_result(output, result, conf, clsid2catid): return bbox_results -def get_mask_result(output, result, conf, clsid2catid): +def get_mask_result(fetch_map, fetch_var_names, result, conf, clsid2catid): resolution = conf['mask_resolution'] - bbox_out, mask_out = output + bbox_out, mask_out = fetch_map[fetch_var_names] + print (bbox_out, mask_out) lengths = offset_to_lengths(bbox_out.lod()) - bbox = np.array(bbox_out) if conf[ - 'use_python_inference'] else bbox_out.copy_to_cpu() - mask = np.array(mask_out) if conf[ - 'use_python_inference'] else mask_out.copy_to_cpu() + bbox = np.array(bbox_out) + mask = np.array(mask_out) result['bbox'] = (bbox, lengths) result['mask'] = (mask, lengths) mask_results = mask2out([result], clsid2catid, conf['mask_resolution']) @@ -511,7 +511,7 @@ def visualize(bbox_results, catid2name, num_classes, mask_results=None): logger.info('Save visualize result to {}'.format(out_path)) -def infer(server_ip_list, feed_var_names, fetch_var_names): +def preprocess(feed_var_names): global FLAGS config_path = FLAGS.config_path res = {} @@ -530,37 +530,41 @@ def infer(server_ip_list, feed_var_names, fetch_var_names): ) def processImg(v): - np_data = np.array(v) - np_feed = np.reshape(np_data, (-1)) - res = np_feed.tolist() + np_data = np.array(v[0]) + res = np_data return res - feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)} - # Infer from Server - from paddle_serving_client import Client - client = Client() - client.load_client_config(FLAGS.serving_client_conf) - client.connect(server_ip_list) - fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names) - print(fetch_map) - outs = fetch_map.values() - - # post process + return feed_dict + + + +def postprocess(fetch_map, fetch_var_names): + config_path = FLAGS.config_path + res = {} + with open(config_path) as f: + conf = yaml.safe_load(f) + if 'SSD' in conf['arch']: + img_data, res['im_shape'] = img_data + img_data = [img_data] clsid2catid, catid2name = get_category_info(conf['with_background'], conf['label_list']) - bbox_result = get_bbox_result(outs[0], res, conf, clsid2catid) + bbox_result = get_bbox_result(fetch_map, fetch_var_names[0], res, conf, clsid2catid) mask_result = None if 'mask_resolution' in conf: res['im_shape'] = img_data[-1] - mask_result = get_mask_result(outs, res, conf, clsid2catid) + mask_result = get_mask_result(fetch_map, fetch_var_names, res, conf, clsid2catid) if FLAGS.visualize: + if os.path.isdir(FLAGS.output_dir) is False: + os.mkdir(FLAGS.output_dir) visualize(bbox_result, catid2name, len(conf['label_list']), mask_result) - if flags.dump_result: - bbox_file = os.path.join(flags.output_dir, 'bbox.json') + if FLAGS.dump_result: + if os.path.isdir(FLAGS.output_dir) is False: + os.mkdir(FLAGS.output_dir) + bbox_file = os.path.join(FLAGS.output_dir, 'bbox.json') logger.info('dump bbox to {}'.format(bbox_file)) with open(bbox_file, 'w') as f: json.dump(bbox_result, f, indent=4) - if mask_result is not none: + if mask_result is not None: mask_file = os.path.join(flags.output_dir, 'mask.json') logger.info('dump mask to {}'.format(mask_file)) with open(mask_file, 'w') as f: @@ -569,8 +573,6 @@ def infer(server_ip_list, feed_var_names, fetch_var_names): def ArgParse(): parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--model_path", type=str, default=None, help="model path.") parser.add_argument( "--config_path", type=str, default=None, help="preprocess config path.") parser.add_argument(