提交 eea1cc66 编写于 作者: B barriery

add local_rpc_service_handler

上级 73e5be16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server_gpu.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server_gpu.pipeline import PipelineServer
from paddle_serving_server_gpu.pipeline.proto import pipeline_service_pb2
from paddle_serving_server_gpu.pipeline.channel import ChannelDataEcode
from paddle_serving_server_gpu.pipeline import LocalRpcServiceHandler
import numpy as np
import cv2
import time
import base64
import json
from paddle_serving_app.reader import OCRReader
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
import time
import re
import base64
import logging
_LOGGER = logging.getLogger()
class DetOp(Op):
    """Text-detection OP: finds text box coordinates in the input image.

    preprocess decodes the base64 image and resizes/normalizes it for the
    detection model; postprocess converts the raw model output into filtered
    text boxes mapped back to the original image size.
    """

    def init_op(self):
        # Preprocessing chain: resize so each side is a multiple of 32 (max
        # side 960), scale pixels to [0, 1], normalize with ImageNet mean/std,
        # then HWC -> CHW.
        self.det_preprocess = Sequential([
            ResizeByFactor(32, 960), Div(255),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
                (2, 0, 1))
        ])
        # Drop boxes smaller than 10x10 pixels.
        self.filter_func = FilterBoxes(10, 10)
        self.post_func = DBPostProcess({
            "thresh": 0.3,
            "box_thresh": 0.5,
            "max_candidates": 1000,
            "unclip_ratio": 1.5,
            "min_size": 3
        })

    def preprocess(self, input_dicts):
        # Exactly one upstream op is expected (the request op).
        (_, input_dict), = input_dicts.items()
        data = base64.b64decode(input_dict["image"].encode('utf8'))
        # np.frombuffer replaces the deprecated np.fromstring; identical
        # behavior for raw uint8 bytes, without the DeprecationWarning.
        data = np.frombuffer(data, np.uint8)
        # Note: class variables(self.var) can only be used in process op mode
        self.im = cv2.imdecode(data, cv2.IMREAD_COLOR)
        self.ori_h, self.ori_w, _ = self.im.shape
        det_img = self.det_preprocess(self.im)
        _, self.new_h, self.new_w = det_img.shape
        return {"image": det_img}

    def postprocess(self, input_dicts, fetch_dict):
        det_out = fetch_dict["concat_1.tmp_0"]
        # Scale factors to map boxes from the resized image back to the
        # original image coordinates.
        ratio_list = [
            float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
        ]
        dt_boxes_list = self.post_func(det_out, [ratio_list])
        dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
        # Pass the original image along so the downstream rec op can crop it.
        out_dict = {"dt_boxes": dt_boxes, "image": self.im}
        return out_dict
class RecOp(Op):
    """Text-recognition OP: crops each detected box and recognizes its text.

    preprocess crops every detected box out of the image and normalizes the
    crops to a shared aspect ratio; postprocess decodes the model output into
    a list of recognized strings.
    """

    def init_op(self):
        self.ocr_reader = OCRReader()
        self.get_rotate_crop_image = GetRotateCropImage()
        self.sorted_boxes = SortedBoxes()

    def preprocess(self, input_dicts):
        # Exactly one upstream op is expected (the det op).
        (_, input_dict), = input_dicts.items()
        im = input_dict["image"]
        dt_boxes = input_dict["dt_boxes"]
        # Sort boxes into reading order before cropping.
        dt_boxes = self.sorted_boxes(dt_boxes)
        feed_list = []
        img_list = []
        max_wh_ratio = 0
        # First pass: crop every box and track the maximum width/height
        # ratio so all crops can be resized to one common target width.
        # (Iterate the boxes directly; no index needed.)
        for dtbox in dt_boxes:
            boximg = self.get_rotate_crop_image(im, dtbox)
            img_list.append(boximg)
            h, w = boximg.shape[0:2]
            max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
        # Second pass: normalize each crop with the shared max ratio.
        for img in img_list:
            norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
            feed_list.append({"image": norm_img})
        return feed_list

    def postprocess(self, input_dicts, fetch_dict):
        rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
        # Keep only the recognized text (res[0]); drop the score.
        res_lst = [res[0] for res in rec_res]
        return {"res": str(res_lst)}
# Build the pipeline DAG: request -> det -> rec -> response, then launch
# the local rpc services and the pipeline server. Statement order matters:
# local rpc services must be started before prepare_server/run_server.
read_op = RequestOp()
det_op = DetOp(
    name="det",
    input_ops=[read_op],
    # Launch the detection model as a locally-managed rpc service instead
    # of connecting to a remote endpoint.
    local_rpc_service_handler=LocalRpcServiceHandler(
        model_config="ocr_det_model",
        workdir="det_workdir",  # default: "workdir"
        thread_num=2,  # default: 2
        devices="0",  # gpu0. default: "" (cpu)
        mem_optim=True,  # default: True
        ir_optim=False,  # default: False
        available_port_generator=None),  # default: None
    concurrency=1)
rec_op = RecOp(
    name="rec",
    input_ops=[det_op],
    # Recognition model with all-default handler settings (cpu, workdir_*).
    local_rpc_service_handler=LocalRpcServiceHandler(
        model_config="ocr_rec_model"),
    concurrency=1)
response_op = ResponseOp(input_ops=[rec_op])
server = PipelineServer()
server.set_response_op(response_op)
server.start_local_rpc_service()  # add this line
server.prepare_server('config.yml')
server.run_server()
...@@ -15,4 +15,5 @@ import logger # this module must be the first to import ...@@ -15,4 +15,5 @@ import logger # this module must be the first to import
from operator import Op, RequestOp, ResponseOp from operator import Op, RequestOp, ResponseOp
from pipeline_server import PipelineServer from pipeline_server import PipelineServer
from pipeline_client import PipelineClient from pipeline_client import PipelineClient
from local_rpc_service_handler import LocalRpcServiceHandler
from analyse import Analyst from analyse import Analyst
...@@ -338,7 +338,7 @@ class DAG(object): ...@@ -338,7 +338,7 @@ class DAG(object):
_LOGGER.info("[DAG] Succ init") _LOGGER.info("[DAG] Succ init")
@staticmethod @staticmethod
def get_use_ops(self, response_op): def get_use_ops(response_op):
unique_names = set() unique_names = set()
used_ops = set() used_ops = set()
succ_ops_of_use_op = {} # {op_name: succ_ops} succ_ops_of_use_op = {} # {op_name: succ_ops}
...@@ -431,7 +431,7 @@ class DAG(object): ...@@ -431,7 +431,7 @@ class DAG(object):
if not self._build_dag_each_worker: if not self._build_dag_each_worker:
_LOGGER.info("================= USED OP =================") _LOGGER.info("================= USED OP =================")
for op in used_ops: for op in used_ops:
if op.name != self._request_name: if not isinstance(op, RequestOp):
_LOGGER.info(op.name) _LOGGER.info(op.name)
_LOGGER.info("-------------------------------------------") _LOGGER.info("-------------------------------------------")
if len(used_ops) <= 1: if len(used_ops) <= 1:
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import logging import logging
import multiprocessing
try: try:
from paddle_serving_server import OpMaker, OpSeqMaker, Server from paddle_serving_server import OpMaker, OpSeqMaker, Server
except ImportError: except ImportError:
...@@ -24,9 +26,9 @@ _workdir_name_gen = NameGenerator("workdir_") ...@@ -24,9 +26,9 @@ _workdir_name_gen = NameGenerator("workdir_")
_available_port_gen = AvailablePortGenerator() _available_port_gen = AvailablePortGenerator()
class DefaultRpcServerHandler(object): class LocalRpcServiceHandler(object):
def __init__(self, def __init__(self,
model_config=None, model_config,
workdir=None, workdir=None,
thread_num=2, thread_num=2,
devices="", devices="",
...@@ -36,22 +38,26 @@ class DefaultRpcServerHandler(object): ...@@ -36,22 +38,26 @@ class DefaultRpcServerHandler(object):
if available_port_generator is None: if available_port_generator is None:
available_port_generator = _available_port_gen available_port_generator = _available_port_gen
self._model_config = model_config
self._port_list = [] self._port_list = []
if devices == "": if devices == "":
# cpu # cpu
devices = [-1] devices = [-1]
self._port_list.append(available_port_generator.next()) self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in cpu device. Port({})"
.format(model_config, self._port_list))
else: else:
# gpu # gpu
devices = [int(x) for x in devices.split(",")] devices = [int(x) for x in devices.split(",")]
for _ in devices: for _ in devices:
self._port_list.append(available_port_generator.next()) self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in gpu device: {}. Port({})"
.format(model_config, devices, self._port_list))
self._workdir = workdir self._workdir = workdir
self._devices = devices self._devices = devices
self._thread_num = thread_num self._thread_num = thread_num
self._mem_optim = mem_optim self._mem_optim = mem_optim
self._ir_optim = ir_optim self._ir_optim = ir_optim
self._model_config = model_config
self._rpc_service_list = [] self._rpc_service_list = []
self._server_pros = [] self._server_pros = []
...@@ -63,15 +69,15 @@ class DefaultRpcServerHandler(object): ...@@ -63,15 +69,15 @@ class DefaultRpcServerHandler(object):
def get_port_list(self): def get_port_list(self):
return self._port_list return self._port_list
def set_model_config(self, model_config): def get_client_config(self):
self._model_config = model_config return os.path.join(self._model_config, "serving_server_conf.prototxt")
def _prepare_one_server(self, workdir, port, gpuid, thread_num, mem_optim, def _prepare_one_server(self, workdir, port, gpuid, thread_num, mem_optim,
ir_optim): ir_optim):
device = "gpu" device = "gpu"
if gpuid == -1: if gpuid == -1:
device = "cpu" device = "cpu"
op_maker = serving.OpMaker() op_maker = OpMaker()
read_op = op_maker.create('general_reader') read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer') general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response') general_response_op = op_maker.create('general_response')
...@@ -115,7 +121,9 @@ class DefaultRpcServerHandler(object): ...@@ -115,7 +121,9 @@ class DefaultRpcServerHandler(object):
def start_server(self): def start_server(self):
for i, service in enumerate(self._rpc_service_list): for i, service in enumerate(self._rpc_service_list):
p = Process(target=self._start_one_server, args=(i, )) p = multiprocessing.Process(
target=self._start_one_server, args=(i, ))
p.daemon = True
self._server_pros.append(p) self._server_pros.append(p)
for p in self._server_pros: for p in self._server_pros:
p.start() p.start()
...@@ -55,7 +55,7 @@ class Op(object): ...@@ -55,7 +55,7 @@ class Op(object):
retry=1, retry=1,
batch_size=1, batch_size=1,
auto_batching_timeout=None, auto_batching_timeout=None,
local_rpc_server_handler=None): local_rpc_service_handler=None):
if name is None: if name is None:
name = _op_name_gen.next() name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique self.name = name # to identify the type of OP, it must be globally unique
...@@ -65,23 +65,26 @@ class Op(object): ...@@ -65,23 +65,26 @@ class Op(object):
if len(server_endpoints) != 0: if len(server_endpoints) != 0:
# remote service # remote service
self.with_serving = True self.with_serving = True
self._server_endpoints = server_endpoints
self._client_config = client_config
else: else:
if local_rpc_server_handler is not None: if local_rpc_service_handler is not None:
# local rpc service # local rpc service
self.with_serving = True self.with_serving = True
serivce_ports = local_rpc_server_handler.get_port_list() local_rpc_service_handler.prepare_server() # get fetch_list
self._server_endpoints = [ serivce_ports = local_rpc_service_handler.get_port_list()
server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports "127.0.0.1:{}".format(p) for p in serivce_ports
] ]
local_rpc_server_handler.set_client_config(client_config) if client_config is None:
self._client_config = client_config client_config = local_rpc_service_handler.get_client_config(
)
if len(fetch_list) == 0:
fetch_list = local_rpc_service_handler.get_fetch_list()
else: else:
self.with_serving = False self.with_serving = False
self._local_rpc_server_handler = local_rpc_server_handler self._local_rpc_service_handler = local_rpc_service_handler
self._server_endpoints = server_endpoints
self._fetch_names = fetch_list self._fetch_names = fetch_list
self._client_config = client_config
if timeout > 0: if timeout > 0:
self._timeout = timeout / 1000.0 self._timeout = timeout / 1000.0
...@@ -129,13 +132,14 @@ class Op(object): ...@@ -129,13 +132,14 @@ class Op(object):
self._succ_close_op = False self._succ_close_op = False
def launch_local_rpc_service(self): def launch_local_rpc_service(self):
if self._local_rpc_server_handler is None: if self._local_rpc_service_handler is None:
raise ValueError("Failed to launch local rpc service: " _LOGGER.warning(
"local_rpc_server_handler is None.") self._log("Failed to launch local rpc"
port = self._local_rpc_server_handler.get_port_list() " service: local_rpc_service_handler is None."))
self._local_rpc_server_handler.prepare_server() return
self._local_rpc_server_handler.start_server() port = self._local_rpc_service_handler.get_port_list()
_LOGGER.info("Op({}) launch local rpc service at port: {}" self._local_rpc_service_handler.start_server()
_LOGGER.info("Op({}) use local rpc service at port: {}"
.format(self.name, port)) .format(self.name, port))
def use_default_auto_batching_config(self): def use_default_auto_batching_config(self):
......
...@@ -23,7 +23,7 @@ import multiprocessing ...@@ -23,7 +23,7 @@ import multiprocessing
import yaml import yaml
from .proto import pipeline_service_pb2_grpc from .proto import pipeline_service_pb2_grpc
from .operator import ResponseOp from .operator import ResponseOp, RequestOp
from .dag import DAGExecutor, DAG from .dag import DAGExecutor, DAG
from .util import AvailablePortGenerator from .util import AvailablePortGenerator
...@@ -126,7 +126,8 @@ class PipelineServer(object): ...@@ -126,7 +126,8 @@ class PipelineServer(object):
# only brpc now # only brpc now
used_op, _ = DAG.get_use_ops(self._response_op) used_op, _ = DAG.get_use_ops(self._response_op)
for op in used_op: for op in used_op:
op.launch_local_rpc_service() if not isinstance(op, RequestOp):
op.launch_local_rpc_service()
def run_server(self): def run_server(self):
if self._build_dag_each_worker: if self._build_dag_each_worker:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册