提交 eea1cc66 编写于 作者: B barriery

add local_rpc_service_handler

上级 73e5be16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server_gpu.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server_gpu.pipeline import PipelineServer
from paddle_serving_server_gpu.pipeline.proto import pipeline_service_pb2
from paddle_serving_server_gpu.pipeline.channel import ChannelDataEcode
from paddle_serving_server_gpu.pipeline import LocalRpcServiceHandler
import numpy as np
import cv2
import time
import base64
import json
from paddle_serving_app.reader import OCRReader
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
import time
import re
import base64
import logging
_LOGGER = logging.getLogger()
class DetOp(Op):
def init_op(self):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
def preprocess(self, input_dicts):
(_, input_dict), = input_dicts.items()
data = base64.b64decode(input_dict["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
# Note: class variables(self.var) can only be used in process op mode
self.im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = self.im.shape
det_img = self.det_preprocess(self.im)
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img}
def postprocess(self, input_dicts, fetch_dict):
det_out = fetch_dict["concat_1.tmp_0"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
out_dict = {"dt_boxes": dt_boxes, "image": self.im}
return out_dict
class RecOp(Op):
def init_op(self):
self.ocr_reader = OCRReader()
self.get_rotate_crop_image = GetRotateCropImage()
self.sorted_boxes = SortedBoxes()
def preprocess(self, input_dicts):
(_, input_dict), = input_dicts.items()
im = input_dict["image"]
dt_boxes = input_dict["dt_boxes"]
dt_boxes = self.sorted_boxes(dt_boxes)
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = self.get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img}
feed_list.append(feed)
return feed_list
def postprocess(self, input_dicts, fetch_dict):
rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": str(res_lst)}
return res
read_op = RequestOp()
det_op = DetOp(
name="det",
input_ops=[read_op],
local_rpc_service_handler=LocalRpcServiceHandler(
model_config="ocr_det_model",
workdir="det_workdir", # defalut: "workdir"
thread_num=2, # defalut: 2
devices="0", # gpu0. defalut: "" (cpu)
mem_optim=True, # defalut: True
ir_optim=False, # defalut: False
available_port_generator=None), # defalut: None
concurrency=1)
rec_op = RecOp(
name="rec",
input_ops=[det_op],
local_rpc_service_handler=LocalRpcServiceHandler(
model_config="ocr_rec_model"),
concurrency=1)
response_op = ResponseOp(input_ops=[rec_op])
server = PipelineServer()
server.set_response_op(response_op)
server.start_local_rpc_service() # add this line
server.prepare_server('config.yml')
server.run_server()
......@@ -15,4 +15,5 @@ import logger # this module must be the first to import
from operator import Op, RequestOp, ResponseOp
from pipeline_server import PipelineServer
from pipeline_client import PipelineClient
from local_rpc_service_handler import LocalRpcServiceHandler
from analyse import Analyst
......@@ -338,7 +338,7 @@ class DAG(object):
_LOGGER.info("[DAG] Succ init")
@staticmethod
def get_use_ops(self, response_op):
def get_use_ops(response_op):
unique_names = set()
used_ops = set()
succ_ops_of_use_op = {} # {op_name: succ_ops}
......@@ -431,7 +431,7 @@ class DAG(object):
if not self._build_dag_each_worker:
_LOGGER.info("================= USED OP =================")
for op in used_ops:
if op.name != self._request_name:
if not isinstance(op, RequestOp):
_LOGGER.info(op.name)
_LOGGER.info("-------------------------------------------")
if len(used_ops) <= 1:
......
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import multiprocessing
try:
from paddle_serving_server import OpMaker, OpSeqMaker, Server
except ImportError:
......@@ -24,9 +26,9 @@ _workdir_name_gen = NameGenerator("workdir_")
_available_port_gen = AvailablePortGenerator()
class DefaultRpcServerHandler(object):
class LocalRpcServiceHandler(object):
def __init__(self,
model_config=None,
model_config,
workdir=None,
thread_num=2,
devices="",
......@@ -36,22 +38,26 @@ class DefaultRpcServerHandler(object):
if available_port_generator is None:
available_port_generator = _available_port_gen
self._model_config = model_config
self._port_list = []
if devices == "":
# cpu
devices = [-1]
self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in cpu device. Port({})"
.format(model_config, self._port_list))
else:
# gpu
devices = [int(x) for x in devices.split(",")]
for _ in devices:
self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in gpu device: {}. Port({})"
.format(model_config, devices, self._port_list))
self._workdir = workdir
self._devices = devices
self._thread_num = thread_num
self._mem_optim = mem_optim
self._ir_optim = ir_optim
self._model_config = model_config
self._rpc_service_list = []
self._server_pros = []
......@@ -63,15 +69,15 @@ class DefaultRpcServerHandler(object):
def get_port_list(self):
return self._port_list
def set_model_config(self, model_config):
self._model_config = model_config
def get_client_config(self):
return os.path.join(self._model_config, "serving_server_conf.prototxt")
def _prepare_one_server(self, workdir, port, gpuid, thread_num, mem_optim,
ir_optim):
device = "gpu"
if gpuid == -1:
device = "cpu"
op_maker = serving.OpMaker()
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
......@@ -115,7 +121,9 @@ class DefaultRpcServerHandler(object):
def start_server(self):
for i, service in enumerate(self._rpc_service_list):
p = Process(target=self._start_one_server, args=(i, ))
p = multiprocessing.Process(
target=self._start_one_server, args=(i, ))
p.daemon = True
self._server_pros.append(p)
for p in self._server_pros:
p.start()
......@@ -55,7 +55,7 @@ class Op(object):
retry=1,
batch_size=1,
auto_batching_timeout=None,
local_rpc_server_handler=None):
local_rpc_service_handler=None):
if name is None:
name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique
......@@ -65,23 +65,26 @@ class Op(object):
if len(server_endpoints) != 0:
# remote service
self.with_serving = True
self._server_endpoints = server_endpoints
self._client_config = client_config
else:
if local_rpc_server_handler is not None:
if local_rpc_service_handler is not None:
# local rpc service
self.with_serving = True
serivce_ports = local_rpc_server_handler.get_port_list()
self._server_endpoints = [
local_rpc_service_handler.prepare_server() # get fetch_list
serivce_ports = local_rpc_service_handler.get_port_list()
server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
local_rpc_server_handler.set_client_config(client_config)
self._client_config = client_config
if client_config is None:
client_config = local_rpc_service_handler.get_client_config(
)
if len(fetch_list) == 0:
fetch_list = local_rpc_service_handler.get_fetch_list()
else:
self.with_serving = False
self._local_rpc_server_handler = local_rpc_server_handler
self._local_rpc_service_handler = local_rpc_service_handler
self._server_endpoints = server_endpoints
self._fetch_names = fetch_list
self._client_config = client_config
if timeout > 0:
self._timeout = timeout / 1000.0
......@@ -129,13 +132,14 @@ class Op(object):
self._succ_close_op = False
def launch_local_rpc_service(self):
if self._local_rpc_server_handler is None:
raise ValueError("Failed to launch local rpc service: "
"local_rpc_server_handler is None.")
port = self._local_rpc_server_handler.get_port_list()
self._local_rpc_server_handler.prepare_server()
self._local_rpc_server_handler.start_server()
_LOGGER.info("Op({}) launch local rpc service at port: {}"
if self._local_rpc_service_handler is None:
_LOGGER.warning(
self._log("Failed to launch local rpc"
" service: local_rpc_service_handler is None."))
return
port = self._local_rpc_service_handler.get_port_list()
self._local_rpc_service_handler.start_server()
_LOGGER.info("Op({}) use local rpc service at port: {}"
.format(self.name, port))
def use_default_auto_batching_config(self):
......
......@@ -23,7 +23,7 @@ import multiprocessing
import yaml
from .proto import pipeline_service_pb2_grpc
from .operator import ResponseOp
from .operator import ResponseOp, RequestOp
from .dag import DAGExecutor, DAG
from .util import AvailablePortGenerator
......@@ -126,7 +126,8 @@ class PipelineServer(object):
# only brpc now
used_op, _ = DAG.get_use_ops(self._response_op)
for op in used_op:
op.launch_local_rpc_service()
if not isinstance(op, RequestOp):
op.launch_local_rpc_service()
def run_server(self):
if self._build_dag_each_worker:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册