提交 fdca4220 编写于 作者: W wangjiawei04

support local on pipelien

上级 e66b54cf
rpc_port: 18080
worker_num: 4
build_dag_each_worker: false
http_port: 9999
dag:
is_thread_op: false
retry: 1
use_profile: false
op:
det:
concurrency: 2
local_service_conf:
client_type: brpc
model_config: ocr_det_model
devices: ""
rec:
concurrency: 1
timeout: -1
retry: 1
local_service_conf:
client_type: brpc
model_config: ocr_rec_model
devices: ""
......@@ -4,19 +4,20 @@ build_dag_each_worker: false
http_port: 9999
dag:
is_thread_op: false
client_type: brpc
retry: 1
use_profile: false
op:
det:
concurrency: 2
local_service_conf:
client_type: local_predictor
model_config: ocr_det_model
devices: "0"
devices: ""
rec:
concurrency: 1
timeout: -1
retry: 1
local_service_conf:
client_type: local_predictor
model_config: ocr_rec_model
devices: "0"
devices: ""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from paddle_serving_server.web_service import WebService, Op
except ImportError:
from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import cv2
import base64
from paddle_serving_app.reader import OCRReader
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
_LOGGER = logging.getLogger()
class DetOp(Op):
def init_op(self):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
def preprocess(self, input_dicts):
(_, input_dict), = input_dicts.items()
data = base64.b64decode(input_dict["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
# Note: class variables(self.var) can only be used in process op mode
self.im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = self.im.shape
det_img = self.det_preprocess(self.im)
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis,:].copy()}
def postprocess(self, input_dicts, fetch_dict):
det_out = fetch_dict["concat_1.tmp_0"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
out_dict = {"dt_boxes": dt_boxes, "image": self.im}
return out_dict
class RecOp(Op):
def init_op(self):
self.ocr_reader = OCRReader()
self.get_rotate_crop_image = GetRotateCropImage()
self.sorted_boxes = SortedBoxes()
def preprocess(self, input_dicts):
(_, input_dict), = input_dicts.items()
im = input_dict["image"]
dt_boxes = input_dict["dt_boxes"]
dt_boxes = self.sorted_boxes(dt_boxes)
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = self.get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
_, w, h = self.ocr_reader.resize_norm_img(img_list[0],
max_wh_ratio).shape
imgs = np.zeros((len(img_list), 3, w, h)).astype('float32')
for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
feed = {"image": imgs.copy()}
return feed
def postprocess(self, input_dicts, fetch_dict):
rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": str(res_lst)}
return res
class OcrService(WebService):
def get_pipeline_response(self, read_op):
det_op = DetOp(name="det", input_ops=[read_op])
rec_op = RecOp(name="rec", input_ops=[det_op])
return rec_op
uci_service = OcrService(name="ocr")
uci_service.prepare_pipeline_config("config.yml")
uci_service.run_service()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from paddle_serving_server_gpu.web_service import WebService, Op
from paddle_serving_server.web_service import WebService, Op
except ImportError:
from paddle_serving_server.web_service import WebService, Op
import logging
......@@ -108,5 +108,5 @@ class OcrService(WebService):
uci_service = OcrService(name="ocr")
uci_service.prepare_pipeline_config("config.yml")
uci_service.prepare_pipeline_config("brpc_config.yml")
uci_service.run_service()
......@@ -76,7 +76,9 @@ class Debugger(object):
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
def predict(self, feed=None, fetch=None):
def predict(self, feed=None, fetch=None, log_id=0):
print("feed", feed)
print("fetch", fetch)
if feed is None or fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
fetch_list = []
......@@ -139,5 +141,5 @@ class Debugger(object):
for i, name in enumerate(fetch):
fetch_map[name] = outputs[i]
if len(output_tensors[i].lod()) > 0:
fetch_map[name + ".lod"] = output_tensors[i].lod()[0]
fetch_map[name + ".lod"] = np.array(output_tensors[i].lod()[0]).astype('int32')
return fetch_map
......@@ -43,7 +43,6 @@ class DAGExecutor(object):
dag_conf = server_conf["dag"]
self._retry = dag_conf["retry"]
client_type = dag_conf["client_type"]
self._server_use_profile = dag_conf["use_profile"]
channel_size = dag_conf["channel_size"]
self._is_thread_op = dag_conf["is_thread_op"]
......@@ -61,7 +60,7 @@ class DAGExecutor(object):
self._is_thread_op, tracer_interval_s, server_worker_num)
self._dag = DAG(self.name, response_op, self._server_use_profile,
self._is_thread_op, client_type, channel_size,
self._is_thread_op, channel_size,
build_dag_each_worker, self._tracer)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
......@@ -324,13 +323,12 @@ class DAGExecutor(object):
class DAG(object):
def __init__(self, request_name, response_op, use_profile, is_thread_op,
client_type, channel_size, build_dag_each_worker, tracer):
channel_size, build_dag_each_worker, tracer):
self._request_name = request_name
self._response_op = response_op
self._use_profile = use_profile
self._is_thread_op = is_thread_op
self._channel_size = channel_size
self._client_type = client_type
self._build_dag_each_worker = build_dag_each_worker
self._tracer = tracer
if not self._is_thread_op:
......@@ -571,10 +569,10 @@ class DAG(object):
op.set_tracer(self._tracer)
if self._is_thread_op:
self._threads_or_proces.extend(
op.start_with_thread(self._client_type))
op.start_with_thread())
else:
self._threads_or_proces.extend(
op.start_with_process(self._client_type))
op.start_with_process())
_LOGGER.info("[DAG] start")
# not join yet
......@@ -582,7 +580,8 @@ class DAG(object):
def join(self):
for x in self._threads_or_proces:
x.join()
if x is not None:
x.join()
def stop(self):
for chl in self._channels:
......
......@@ -21,7 +21,7 @@ try:
except ImportError:
from paddle_serving_server import OpMaker, OpSeqMaker, Server
PACKAGE_VERSION = "CPU"
from . import util
import util
_LOGGER = logging.getLogger(__name__)
_workdir_name_gen = util.NameGenerator("workdir_")
......@@ -132,3 +132,28 @@ class LocalRpcServiceHandler(object):
self._server_pros.append(p)
for p in self._server_pros:
p.start()
class LocalPredictorServiceHandler(LocalRpcServiceHandler):
def prepare_server(self):
from paddle_serving_app.local_predict import Debugger
gpuid = self._devices
if gpuid == -1:
gpu = False
else:
gpu = True
self.predictor = Debugger()
self.predictor.load_model_config(model_path=self._model_config, gpu=gpu, profile=False, cpu_num=1)
def get_client(self):
if self.predictor is None:
raise ValueError("local predictor not yet created.")
return self.predictor
def get_fetch_list(self):
if self.predictor is None:
raise ValueError("local predictor not yet created.")
return self.predictor.fetch_names_
def start_server(self):
pass
......@@ -128,29 +128,44 @@ class Op(object):
_LOGGER.info("local_service_conf: {}".format(
local_service_conf))
model_config = local_service_conf.get("model_config")
self.client_type = local_service_conf.get("client_type")
_LOGGER.info("model_config: {}".format(model_config))
if model_config is None:
self.with_serving = False
else:
# local rpc service
self.with_serving = True
service_handler = local_rpc_service_handler.LocalRpcServiceHandler(
model_config=model_config,
workdir=local_service_conf["workdir"],
thread_num=local_service_conf["thread_num"],
devices=local_service_conf["devices"],
mem_optim=local_service_conf["mem_optim"],
ir_optim=local_service_conf["ir_optim"])
service_handler.prepare_server() # get fetch_list
serivce_ports = service_handler.get_port_list()
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if self._client_config is None:
self._client_config = service_handler.get_client_config(
)
if self._fetch_names is None:
self._fetch_names = service_handler.get_fetch_list()
if self.client_type == "brpc" or self.client_type == "grpc":
service_handler = local_rpc_service_handler.LocalRpcServiceHandler(
model_config=model_config,
workdir=local_service_conf["workdir"],
thread_num=local_service_conf["thread_num"],
devices=local_service_conf["devices"],
mem_optim=local_service_conf["mem_optim"],
ir_optim=local_service_conf["ir_optim"])
service_handler.prepare_server() # get fetch_list
serivce_ports = service_handler.get_port_list()
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if self._client_config is None:
self._client_config = service_handler.get_client_config(
)
if self._fetch_names is None:
self._fetch_names = service_handler.get_fetch_list()
elif self.client_type == "local_predictor":
service_handler = local_rpc_service_handler.LocalPredictorServiceHandler(
model_config=model_config,
workdir=local_service_conf["workdir"],
thread_num=local_service_conf["thread_num"],
devices=local_service_conf["devices"])
service_handler.prepare_server() # get fetch_list
self.local_predictor = service_handler.get_client()
if self._client_config is None:
self._client_config = service_handler.get_client_config(
)
if self._fetch_names is None:
self._fetch_names = service_handler.get_fetch_list()
self._local_rpc_service_handler = service_handler
else:
self.with_serving = True
......@@ -215,21 +230,27 @@ class Op(object):
def set_tracer(self, tracer):
self._tracer = tracer
def init_client(self, client_type, client_config, server_endpoints,
def init_client(self, client_config, server_endpoints,
fetch_names):
print("init client", fetch_names)
if self.with_serving == False:
_LOGGER.info("Op({}) has no client (and it also do not "
"run the process function)".format(self.name))
return None
if client_type == 'brpc':
if self.client_type == 'brpc':
client = Client()
client.load_client_config(client_config)
elif client_type == 'grpc':
elif self.client_type == 'grpc':
client = MultiLangClient()
elif self.client_type == 'local_predictor':
if self.local_predictor is None:
raise ValueError("local predictor not yet created")
client = self.local_predictor
else:
raise ValueError("Failed to init client: unknow client "
"type {}".format(client_type))
client.connect(server_endpoints)
"type {}".format(self.client_type))
if self.client_type != "local_predictor":
client.connect(server_endpoints)
self._fetch_names = fetch_names
return client
......@@ -292,14 +313,19 @@ class Op(object):
return input_dict
def process(self, feed_batch, typical_logid):
print("now we start process")
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
_LOGGER.critical(
self._log("Failed to run process: {}. Please override "
"preprocess func.".format(err_info)))
os._exit(-1)
call_result = self.client.predict(
feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
if self.client_type == "local_predictor":
call_result = self.client.predict(feed=feed_batch[0], fetch=self._fetch_names, log_id=typical_logid)
else:
call_result = self.client.predict(
feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
print("now we end predict")
if isinstance(self.client, MultiLangClient):
if call_result is None or call_result["serving_status_code"] != 0:
return None
......@@ -347,23 +373,23 @@ class Op(object):
for channel in channels:
channel.push(data, name)
def start_with_process(self, client_type):
def start_with_process(self):
trace_buffer = None
if self._tracer is not None:
trace_buffer = self._tracer.data_buffer()
proces = []
process= []
for concurrency_idx in range(self.concurrency):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type, False,
self._get_output_channels(), False,
trace_buffer))
p.daemon = True
p.start()
proces.append(p)
return proces
process.append(p)
return process
def start_with_thread(self, client_type):
def start_with_thread(self):
trace_buffer = None
if self._tracer is not None:
trace_buffer = self._tracer.data_buffer()
......@@ -372,7 +398,7 @@ class Op(object):
t = threading.Thread(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type, True,
self._get_output_channels(), True,
trace_buffer))
# When a process exits, it attempts to terminate
# all of its daemonic child processes.
......@@ -537,6 +563,7 @@ class Op(object):
lod_offset_right = lod_offset[data_offset_right]
midped_data_dict[data_id][name] = value[
lod_offset_left:lod_offset_right]
print(lod_offset[data_offset_left:data_offset_right + 1], lod_offset[data_offset_left])
midped_data_dict[data_id][lod_offset_name] = \
lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
else:
......@@ -652,7 +679,7 @@ class Op(object):
return parsed_data_dict, need_profile_dict, profile_dict
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
def _run(self, concurrency_idx, input_channel, output_channels,
is_thread_op, trace_buffer):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
tid = threading.current_thread().ident
......@@ -660,7 +687,7 @@ class Op(object):
# init op
profiler = None
try:
profiler = self._initialize(is_thread_op, client_type,
profiler = self._initialize(is_thread_op,
concurrency_idx)
except Exception as e:
_LOGGER.critical(
......@@ -801,7 +828,7 @@ class Op(object):
except Queue.Full:
break
def _initialize(self, is_thread_op, client_type, concurrency_idx):
def _initialize(self, is_thread_op, concurrency_idx):
if is_thread_op:
with self._for_init_op_lock:
if not self._succ_init_op:
......@@ -809,7 +836,7 @@ class Op(object):
self.concurrency_idx = None
# init client
self.client = self.init_client(
client_type, self._client_config,
self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
......@@ -818,7 +845,7 @@ class Op(object):
else:
self.concurrency_idx = concurrency_idx
# init client
self.client = self.init_client(client_type, self._client_config,
self.client = self.init_client(self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册