Commit 73e5be16 authored by barriery

add local_rpc_server_handler into pipeline

Parent ec039781
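
In short: an Op constructed without `server_endpoints` can now carry a `local_rpc_server_handler`; the handler allocates one local port per device, and `PipelineServer.start_local_rpc_service()` launches the local brpc serving processes before the pipeline starts. A rough usage sketch follows (import paths, config files, and model/variable names are illustrative assumptions, not taken from this commit):

```python
# Rough sketch only; module paths and placeholder names below are assumptions.
from paddle_serving_server.pipeline import PipelineServer, Op, RequestOp, ResponseOp
from paddle_serving_server.pipeline.local_rpc_service_handler import \
    DefaultRpcServerHandler  # class added by this commit; module name assumed

read_op = RequestOp()
# No server_endpoints given: the Op falls back to a local rpc service whose
# endpoints ("127.0.0.1:<port>") come from the handler's port list.
infer_op = Op(
    name="example",                                     # placeholder op name
    input_ops=[read_op],
    fetch_list=["prob"],                                # placeholder fetch var
    client_config="example_client/serving_client_conf.prototxt",  # placeholder
    local_rpc_server_handler=DefaultRpcServerHandler(
        model_config="example_model",                   # placeholder model dir
        devices="0,1"))                                 # "" = CPU, "0,1" = two GPUs
response_op = ResponseOp(input_ops=[infer_op])

server = PipelineServer()
server.set_response_op(response_op)   # assumed pre-existing API
server.prepare_server("config.yml")   # assumed pre-existing API, placeholder yml
server.start_local_rpc_service()      # new in this commit: fork local brpc servers
server.run_server()
```

As the operator.py hunk below shows, such an Op marks itself as `with_serving` and builds its endpoints from the handler's port list instead of the `server_endpoints` argument.
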
......@@ -337,6 +337,7 @@ class DAG(object):
self._manager = PipelineProcSyncManager()
_LOGGER.info("[DAG] Succ init")
@staticmethod
def get_use_ops(response_op):
unique_names = set()
used_ops = set()
......@@ -426,7 +427,7 @@ class DAG(object):
_LOGGER.critical("Failed to build DAG: ResponseOp"
" has not been set.")
os._exit(-1)
used_ops, out_degree_ops = self.get_use_ops(response_op)
used_ops, out_degree_ops = DAG.get_use_ops(response_op)
if not self._build_dag_each_worker:
_LOGGER.info("================= USED OP =================")
for op in used_ops:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from multiprocessing import Process
try:
from paddle_serving_server import OpMaker, OpSeqMaker, Server
except ImportError:
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
from .util import AvailablePortGenerator, NameGenerator
_LOGGER = logging.getLogger(__name__)
_workdir_name_gen = NameGenerator("workdir_")
_available_port_gen = AvailablePortGenerator()
class DefaultRpcServerHandler(object):
def __init__(self,
model_config=None,
workdir=None,
thread_num=2,
devices="",
mem_optim=True,
ir_optim=False,
available_port_generator=None):
if available_port_generator is None:
available_port_generator = _available_port_gen
self._port_list = []
if devices == "":
# cpu
devices = [-1]
self._port_list.append(available_port_generator.next())
else:
# gpu
devices = [int(x) for x in devices.split(",")]
for _ in devices:
self._port_list.append(available_port_generator.next())
self._workdir = workdir
self._devices = devices
self._thread_num = thread_num
self._mem_optim = mem_optim
self._ir_optim = ir_optim
self._model_config = model_config
self._rpc_service_list = []
self._server_pros = []
self._fetch_vars = None
def get_fetch_list(self):
return self._fetch_vars
def get_port_list(self):
return self._port_list
def set_model_config(self, model_config):
self._model_config = model_config
def _prepare_one_server(self, workdir, port, gpuid, thread_num, mem_optim,
ir_optim):
device = "gpu"
if gpuid == -1:
device = "cpu"
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.load_model_config(self._model_config)
if gpuid >= 0:
server.set_gpuid(gpuid)
server.prepare_server(workdir=workdir, port=port, device=device)
if self._fetch_vars is None:
self._fetch_vars = server.get_fetch_list()
return server
def _start_one_server(self, service_idx):
self._rpc_service_list[service_idx].run_server()
def prepare_server(self):
for i, device_id in enumerate(self._devices):
if self._workdir is not None:
workdir = "{}_{}".format(self._workdir, i)
else:
workdir = _workdir_name_gen.next()
self._rpc_service_list.append(
self._prepare_one_server(
workdir,
self._port_list[i],
device_id,
thread_num=self._thread_num,
mem_optim=self._mem_optim,
ir_optim=self._ir_optim))
def start_server(self):
for i, service in enumerate(self._rpc_service_list):
p = Process(target=self._start_one_server, args=(i, ))
self._server_pros.append(p)
for p in self._server_pros:
p.start()
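
The handler above allocates one free port per device when it is constructed and forks one serving process per device when started. A minimal standalone sketch of that contract (the model directory is a placeholder, not from this commit):

```python
# Sketch: exercising DefaultRpcServerHandler on its own.
handler = DefaultRpcServerHandler(
    model_config="example_model",   # placeholder model directory
    devices="0,1")                  # two GPU ids -> two ports, two server processes
ports = handler.get_port_list()     # free ports picked by AvailablePortGenerator
handler.prepare_server()            # builds reader -> infer -> response sequence per device
handler.start_server()              # starts one Process per prepared server
print(ports)
```

Keeping one port and one worker process per device keeps the GPU-to-server mapping simple, and the Op can turn the port list directly into "127.0.0.1:<port>" endpoints.
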
......@@ -54,18 +54,33 @@ class Op(object):
timeout=-1,
retry=1,
batch_size=1,
auto_batching_timeout=None):
auto_batching_timeout=None,
local_rpc_server_handler=None):
if name is None:
name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique
self.concurrency = concurrency # amount of concurrency
self.set_input_ops(input_ops)
self._server_endpoints = server_endpoints
self.with_serving = False
if len(self._server_endpoints) != 0:
if len(server_endpoints) != 0:
# remote service
self.with_serving = True
self._client_config = client_config
self._server_endpoints = server_endpoints
self._client_config = client_config
else:
if local_rpc_server_handler is not None:
# local rpc service
self.with_serving = True
service_ports = local_rpc_server_handler.get_port_list()
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in service_ports
]
local_rpc_server_handler.set_client_config(client_config)
self._client_config = client_config
else:
self.with_serving = False
self._local_rpc_server_handler = local_rpc_server_handler
self._fetch_names = fetch_list
if timeout > 0:
......@@ -113,6 +128,16 @@ class Op(object):
self._succ_init_op = False
self._succ_close_op = False
def launch_local_rpc_service(self):
if self._local_rpc_server_handler is None:
raise ValueError("Failed to launch local rpc service: "
"local_rpc_server_handler is None.")
ports = self._local_rpc_server_handler.get_port_list()
self._local_rpc_server_handler.prepare_server()
self._local_rpc_server_handler.start_server()
_LOGGER.info("Op({}) launched local rpc service at ports: {}"
.format(self.name, ports))
def use_default_auto_batching_config(self):
if self._batch_size != 1:
_LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
......
......@@ -24,7 +24,7 @@ import yaml
from .proto import pipeline_service_pb2_grpc
from .operator import ResponseOp
from .dag import DAGExecutor
from .dag import DAGExecutor, DAG
from .util import AvailablePortGenerator
_LOGGER = logging.getLogger(__name__)
......@@ -122,6 +122,12 @@ class PipelineServer(object):
self._conf = conf
def start_local_rpc_service(self):
# only brpc now
used_op, _ = DAG.get_use_ops(self._response_op)
for op in used_op:
op.launch_local_rpc_service()
def run_server(self):
if self._build_dag_each_worker:
with _reserve_port(self._port) as port:
......