Unverified commit 07ef9374, authored by Jiawei Wang, committed by GitHub

Merge pull request #1111 from TeslaZhao/develop

Update paddle predictor API version to 2.0
@@ -48,7 +48,7 @@ class DetOp(Op):
         imgs = []
         for key in input_dict.keys():
             data = base64.b64decode(input_dict[key].encode('utf8'))
-            data = np.fromstring(data, np.uint8)
+            data = np.frombuffer(data, np.uint8)
             self.im = cv2.imdecode(data, cv2.IMREAD_COLOR)
             self.ori_h, self.ori_w, _ = self.im.shape
             det_img = self.det_preprocess(self.im)
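Review note: `np.fromstring` is deprecated for binary input in NumPy and copies the data, while `np.frombuffer` reads the bytes in place. A minimal, self-contained sketch of the decode path above, assuming `numpy` and `opencv-python` are installed (the dummy image stands in for a real request payload):

```python
import base64

import cv2
import numpy as np

# Stand-in for the base64 payload a real request would carry.
dummy = np.zeros((4, 4, 3), dtype=np.uint8)
ok, buf = cv2.imencode(".jpg", dummy)
payload = base64.b64encode(buf.tobytes())

# The decode path from the hunk above.
data = base64.b64decode(payload)
arr = np.frombuffer(data, np.uint8)   # no copy, no DeprecationWarning
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
print(img.shape)  # (4, 4, 3)
```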
@@ -57,7 +57,7 @@ class DetOp(Op):
         return {"image": np.concatenate(imgs, axis=0)}, False, None, ""

     def postprocess(self, input_dicts, fetch_dict, log_id):
         # print(fetch_dict)
         det_out = fetch_dict["concat_1.tmp_0"]
         ratio_list = [
             float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
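The `ratio_list` above rescales detector output from the resized network input back to the original image. An illustrative computation with made-up sizes (in the real Op, `new_h`/`new_w` come from `det_preprocess` and `ori_h`/`ori_w` from the decoded image):

```python
ori_h, ori_w = 1080, 1920   # original photo (illustrative)
new_h, new_w = 960, 960     # detector input after resizing (illustrative)
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
print(ratio_list)  # [0.8888888888888888, 0.5]
```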
@@ -114,5 +114,5 @@ class OcrService(WebService):
 uci_service = OcrService(name="ocr")
-uci_service.prepare_pipeline_config("config2.yml")
+uci_service.prepare_pipeline_config("config.yml")
 uci_service.run_service()
@@ -14,7 +14,7 @@
 try:
     from paddle_serving_server.web_service import WebService, Op
 except ImportError:
-    from paddle_serving_server.web_service import WebService, Op
+    from paddle_serving_server_gpu.web_service import WebService, Op
 import logging
 import numpy as np
 import sys
@@ -34,8 +34,11 @@ class UciOp(Op):
         x_value = input_dict["x"].split(self.batch_separator)
         x_lst = []
         for x_val in x_value:
-            x_lst.append(np.array([float(x.strip()) for x in x_val.split(self.separator)]).reshape(1, 13))
-        input_dict["x"] = np.concatenate(x_lst, axis=0)
+            x_lst.append(
+                np.array([
+                    float(x.strip()) for x in x_val.split(self.separator)
+                ]).reshape(1, 13))
+        input_dict["x"] = np.concatenate(x_lst, axis=0)
         proc_dict = {}
         return input_dict, False, None, ""
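The reformatted loop above packs one request string holding several 13-feature samples into a `(batch, 13)` array. A standalone sketch; the `,` and `;` separators are assumptions for illustration, since the real values are attributes of the Op:

```python
import numpy as np

separator = ","        # assumed: splits the 13 features of one sample
batch_separator = ";"  # assumed: splits samples within one request

raw = ",".join(["0.1"] * 13) + ";" + ",".join(["0.2"] * 13)
x_lst = []
for x_val in raw.split(batch_separator):
    x_lst.append(
        np.array([float(x.strip()) for x in x_val.split(separator)])
        .reshape(1, 13))
x = np.concatenate(x_lst, axis=0)
print(x.shape)  # (2, 13)
```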
@@ -53,5 +56,5 @@ class UciService(WebService):
 uci_service = UciService(name="uci")
-uci_service.prepare_pipeline_config("config2.yml")
+uci_service.prepare_pipeline_config("config.yml")
 uci_service.run_service()
@@ -19,16 +19,12 @@ import os
 import google.protobuf.text_format
 import numpy as np
 import argparse
-import paddle.fluid as fluid
-import paddle.inference as inference
 from .proto import general_model_config_pb2 as m_config
-from paddle.fluid.core import PaddleTensor
-from paddle.fluid.core import AnalysisConfig
-from paddle.fluid.core import create_paddle_predictor
+import paddle.inference as paddle_infer
 import logging

 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger("fluid")
+logger = logging.getLogger("LocalPredictor")
 logger.setLevel(logging.INFO)
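This hunk is the heart of the migration: the `paddle.fluid.core` symbols (`PaddleTensor`, `AnalysisConfig`, `create_paddle_predictor`) are replaced by the unified `paddle.inference` API from Paddle 2.0. A minimal sketch of the new entry points; the model directory is a placeholder, so this only runs against a real saved model:

```python
import paddle.inference as paddle_infer

# Old -> new:
#   AnalysisConfig          -> paddle_infer.Config
#   create_paddle_predictor -> paddle_infer.create_predictor
config = paddle_infer.Config("serving_model_dir")  # placeholder path
config.disable_gpu()
config.switch_use_feed_fetch_ops(False)
predictor = paddle_infer.create_predictor(config)
```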
@@ -62,7 +58,7 @@ class LocalPredictor(object):
                           use_xpu=False,
                           use_feed_fetch_ops=False):
         """
-        Load model config and set the engine config for the paddle predictor
+        Load model configs and create the paddle predictor by Paddle Inference API.
         Args:
             model_path: model config path.
@@ -83,14 +79,18 @@ class LocalPredictor(object):
             model_conf = google.protobuf.text_format.Merge(
                 str(f.read()), model_conf)
         if os.path.exists(os.path.join(model_path, "__params__")):
-            config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "__params__"))
+            config = paddle_infer.Config(
+                os.path.join(model_path, "__model__"),
+                os.path.join(model_path, "__params__"))
         else:
-            config = AnalysisConfig(model_path)
+            config = paddle_infer.Config(model_path)

-        logger.info("load_model_config params: model_path:{}, use_gpu:{},\
+        logger.info(
+            "LocalPredictor load_model_config params: model_path:{}, use_gpu:{},\
             gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\
             use_trt:{}, use_lite:{}, use_xpu: {}, use_feed_fetch_ops:{}".format(
                 model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
                 ir_optim, use_trt, use_lite, use_xpu, use_feed_fetch_ops))
         self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
         self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
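For context, `paddle_infer.Config` keeps both loading modes the old `AnalysisConfig` had, which is exactly what the branch above selects on: explicit `__model__`/`__params__` paths for the combined format, or a bare directory for the non-combined format. A distilled sketch:

```python
import os

import paddle.inference as paddle_infer

def make_config(model_path):
    # Combined format: one __model__ file plus one __params__ file.
    params = os.path.join(model_path, "__params__")
    if os.path.exists(params):
        return paddle_infer.Config(
            os.path.join(model_path, "__model__"), params)
    # Non-combined format: a directory with one file per parameter.
    return paddle_infer.Config(model_path)
```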
@@ -129,7 +129,7 @@ class LocalPredictor(object):
         if use_lite:
             config.enable_lite_engine(
-                precision_mode=inference.PrecisionType.Float32,
+                precision_mode=paddle_infer.PrecisionType.Float32,
                 zero_copy=True,
                 passes_filter=[],
                 ops_filter=[])
@@ -138,11 +138,11 @@ class LocalPredictor(object):
             # 2MB l3 cache
             config.enable_xpu(8 * 1024 * 1024)
-        self.predictor = create_paddle_predictor(config)
+        self.predictor = paddle_infer.create_predictor(config)

     def predict(self, feed=None, fetch=None, batch=False, log_id=0):
         """
-        Predict locally
+        Run model inference by Paddle Inference API.
         Args:
             feed: feed var
@@ -155,14 +155,16 @@ class LocalPredictor(object):
             fetch_map: dict
         """
         if feed is None or fetch is None:
-            raise ValueError("You should specify feed and fetch for prediction")
+            raise ValueError("You should specify feed and fetch for prediction.\
+                log_id:{}".format(log_id))
         fetch_list = []
         if isinstance(fetch, str):
             fetch_list = [fetch]
         elif isinstance(fetch, list):
             fetch_list = fetch
         else:
-            raise ValueError("Fetch only accepts string and list of string")
+            raise ValueError("Fetch only accepts string and list of string.\
+                log_id:{}".format(log_id))
         feed_batch = []
         if isinstance(feed, dict):
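One subtlety in the new messages: a backslash continuation inside a string literal keeps the next line's leading spaces in the string, so the formatted error carries a run of whitespace before `log_id`. Harmless for logging, but worth knowing when grepping:

```python
log_id = 7  # illustrative value
msg = "You should specify feed and fetch for prediction.\
    log_id:{}".format(log_id)
print(repr(msg))
# 'You should specify feed and fetch for prediction.    log_id:7'
```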
@@ -170,27 +172,21 @@ class LocalPredictor(object):
         elif isinstance(feed, list):
             feed_batch = feed
         else:
-            raise ValueError("Feed only accepts dict and list of dict")
+            raise ValueError("Feed only accepts dict and list of dict.\
+                log_id:{}".format(log_id))

-        int_slot_batch = []
-        float_slot_batch = []
-        int_feed_names = []
-        float_feed_names = []
-        int_shape = []
-        float_shape = []
-        fetch_names = []
-        counter = 0
-        batch_size = len(feed_batch)
+        fetch_names = []
+        # Filter invalid fetch names
         for key in fetch_list:
             if key in self.fetch_names_:
                 fetch_names.append(key)
         if len(fetch_names) == 0:
             raise ValueError(
-                "Fetch names should not be empty or out of saved fetch list.")
-            return {}
+                "Fetch names should not be empty or out of saved fetch list.\
+                log_id:{}".format(log_id))
+        # Assemble the input data of paddle predictor
         input_names = self.predictor.get_input_names()
         for name in input_names:
             if isinstance(feed[name], list):
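Besides tagging the error with `log_id`, this hunk deletes nine locals left over from the pre-2.0 feed path (`int_slot_batch`, `float_shape`, and friends) and the unreachable `return {}` after the raise. What survives is a plain membership filter:

```python
# Illustrative names; in the class, self.fetch_names_ comes from the
# model config prototxt.
fetch_list = ["price", "not_a_real_output"]
saved_fetch_names = ["price"]
fetch_names = [k for k in fetch_list if k in saved_fetch_names]
assert fetch_names == ["price"]
```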
@@ -204,27 +200,31 @@ class LocalPredictor(object):
                 feed[name] = feed[name].astype("int32")
             else:
                 raise ValueError("local predictor receives wrong data type")
-            input_tensor = self.predictor.get_input_tensor(name)
+            input_tensor_handle = self.predictor.get_input_handle(name)
             if "{}.lod".format(name) in feed:
-                input_tensor.set_lod([feed["{}.lod".format(name)]])
+                input_tensor_handle.set_lod([feed["{}.lod".format(name)]])
             if batch == False:
-                input_tensor.copy_from_cpu(feed[name][np.newaxis, :])
+                input_tensor_handle.copy_from_cpu(feed[name][np.newaxis, :])
             else:
-                input_tensor.copy_from_cpu(feed[name])
+                input_tensor_handle.copy_from_cpu(feed[name])

-        output_tensors = []
+        output_tensor_handles = []
         output_names = self.predictor.get_output_names()
         for output_name in output_names:
-            output_tensor = self.predictor.get_output_tensor(output_name)
-            output_tensors.append(output_tensor)
+            output_tensor_handle = self.predictor.get_output_handle(output_name)
+            output_tensor_handles.append(output_tensor_handle)
+        # Run inference
+        self.predictor.run()
+        # Assemble output data of predict results
         outputs = []
-        self.predictor.zero_copy_run()
-        for output_tensor in output_tensors:
-            output = output_tensor.copy_to_cpu()
+        for output_tensor_handle in output_tensor_handles:
+            output = output_tensor_handle.copy_to_cpu()
             outputs.append(output)
         fetch_map = {}
         for i, name in enumerate(fetch):
             fetch_map[name] = outputs[i]
-            if len(output_tensors[i].lod()) > 0:
-                fetch_map[name + ".lod"] = np.array(output_tensors[i].lod()[
-                    0]).astype('int32')
+            if len(output_tensor_handles[i].lod()) > 0:
+                fetch_map[name + ".lod"] = np.array(output_tensor_handles[i]
+                                                    .lod()[0]).astype('int32')
         return fetch_map
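Putting the whole migrated `predict` path together: input and output tensors are now addressed through handles, and `zero_copy_run()` becomes plain `run()`. A minimal end-to-end sketch under the assumption of a saved model with one input; the path and shape are placeholders, not this repo's model:

```python
import numpy as np
import paddle.inference as paddle_infer

config = paddle_infer.Config("serving_model_dir")  # placeholder path
predictor = paddle_infer.create_predictor(config)

x = np.random.rand(1, 13).astype("float32")        # placeholder shape
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(x)         # was get_input_tensor(...).copy_from_cpu

predictor.run()                       # was zero_copy_run()

output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
result = output_handle.copy_to_cpu()  # was get_output_tensor(...).copy_to_cpu
print(result.shape)
```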