提交 dc91636b 编写于 作者: B barriery

config op from dict

上级 1e19ccac
...@@ -19,11 +19,10 @@ try: ...@@ -19,11 +19,10 @@ try:
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
except ImportError: except ImportError:
from paddle_serving_server import OpMaker, OpSeqMaker, Server from paddle_serving_server import OpMaker, OpSeqMaker, Server
from .util import AvailablePortGenerator, NameGenerator from . import util
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_workdir_name_gen = NameGenerator("workdir_") _workdir_name_gen = util.NameGenerator("workdir_")
_available_port_gen = AvailablePortGenerator()
class LocalRpcServiceHandler(object): class LocalRpcServiceHandler(object):
...@@ -36,7 +35,7 @@ class LocalRpcServiceHandler(object): ...@@ -36,7 +35,7 @@ class LocalRpcServiceHandler(object):
ir_optim=False, ir_optim=False,
available_port_generator=None): available_port_generator=None):
if available_port_generator is None: if available_port_generator is None:
available_port_generator = _available_port_gen available_port_generator = util.GetAvailablePortGenerator()
self._model_config = model_config self._model_config = model_config
self._port_list = [] self._port_list = []
......
...@@ -38,6 +38,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode, ...@@ -38,6 +38,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
ChannelTimeoutError) ChannelTimeoutError)
from .util import NameGenerator from .util import NameGenerator
from .profiler import UnsafeTimeProfiler as TimeProfiler from .profiler import UnsafeTimeProfiler as TimeProfiler
from . import local_rpc_service_handler
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_op_name_gen = NameGenerator("Op") _op_name_gen = NameGenerator("Op")
...@@ -47,13 +48,13 @@ class Op(object): ...@@ -47,13 +48,13 @@ class Op(object):
def __init__(self, def __init__(self,
name=None, name=None,
input_ops=[], input_ops=[],
server_endpoints=[], server_endpoints=None,
fetch_list=[], fetch_list=None,
client_config=None, client_config=None,
concurrency=1, concurrency=None,
timeout=-1, timeout=None,
retry=1, retry=None,
batch_size=1, batch_size=None,
auto_batching_timeout=None, auto_batching_timeout=None,
local_rpc_service_handler=None): local_rpc_service_handler=None):
if name is None: if name is None:
...@@ -62,41 +63,48 @@ class Op(object): ...@@ -62,41 +63,48 @@ class Op(object):
self.concurrency = concurrency # amount of concurrency self.concurrency = concurrency # amount of concurrency
self.set_input_ops(input_ops) self.set_input_ops(input_ops)
if len(server_endpoints) != 0:
# remote service
self.with_serving = True
else:
if local_rpc_service_handler is not None:
# local rpc service
self.with_serving = True
local_rpc_service_handler.prepare_server() # get fetch_list
serivce_ports = local_rpc_service_handler.get_port_list()
server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = local_rpc_service_handler.get_client_config(
)
if len(fetch_list) == 0:
fetch_list = local_rpc_service_handler.get_fetch_list()
else:
self.with_serving = False
self._local_rpc_service_handler = local_rpc_service_handler self._local_rpc_service_handler = local_rpc_service_handler
self._server_endpoints = server_endpoints self._server_endpoints = server_endpoints
self._fetch_names = fetch_list self._fetch_names = fetch_list
self._client_config = client_config self._client_config = client_config
self._timeout = timeout
if timeout > 0:
self._timeout = timeout / 1000.0
else:
self._timeout = -1
self._retry = max(1, retry) self._retry = max(1, retry)
self._batch_size = batch_size
self._auto_batching_timeout = auto_batching_timeout
self._input = None self._input = None
self._outputs = [] self._outputs = []
self._batch_size = batch_size self._server_use_profile = False
self._auto_batching_timeout = auto_batching_timeout self._tracer = None
if self._auto_batching_timeout is not None:
# only for thread op
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
def configure_from_dict(self, conf):
if self.concurrency is None:
self.concurrency = conf["concurrency"]
if self._retry is None:
self._retry = conf["retry"]
if self._fetch_names is None:
self._fetch_names = conf.get("fetch_list")
if self._client_config is None:
self._client_config = conf.get("client_config")
if self._timeout is None:
self._timeout = conf["timeout"]
if self._timeout > 0:
self._timeout = self._timeout / 1000.0
else:
self._timeout = -1
if self._batch_size is None:
self._batch_size = conf["batch_size"]
if self._auto_batching_timeout is None:
self._auto_batching_timeout = conf["auto_batching_timeout"]
if self._auto_batching_timeout <= 0 or self._batch_size == 1: if self._auto_batching_timeout <= 0 or self._batch_size == 1:
_LOGGER.warning( _LOGGER.warning(
self._log( self._log(
...@@ -105,6 +113,54 @@ class Op(object): ...@@ -105,6 +113,54 @@ class Op(object):
self._auto_batching_timeout = None self._auto_batching_timeout = None
else: else:
self._auto_batching_timeout = self._auto_batching_timeout / 1000.0 self._auto_batching_timeout = self._auto_batching_timeout / 1000.0
if self._server_endpoints is None:
server_endpoints = conf.get("server_endpoints", [])
if len(server_endpoints) != 0:
# remote service
self.with_serving = True
self._server_endpoints = server_endpoints
else:
if self._local_rpc_service_handler is None:
local_service_conf = conf.get("local_service_conf")
model_config = local_service_conf.get("model_config")
if model_config is None:
self.with_serving = False
else:
# local rpc service
self.with_serving = True
service_handler = local_rpc_service_handler.LocalRpcServiceHandler(
model_config=model_config,
workdir=local_service_conf["workdir"],
thread_num=local_service_conf["thread_num"],
devices=local_service_conf["devices"],
mem_optim=local_service_conf["mem_optim"],
ir_optim=local_service_conf["ir_optim"])
service_handler.prepare_server() # get fetch_list
serivce_ports = service_handler.get_port_list()
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = service_handler.get_client_config()
if fetch_list is None:
fetch_list = service_handler.get_fetch_list()
self._local_rpc_service_handler = service_handler
else:
self._local_rpc_service_handler.prepare_server(
) # get fetch_list
serivce_ports = self._local_rpc_service_handler.get_port_list(
)
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = self._local_rpc_service_handler.get_client_config(
)
if fetch_list is None:
fetch_list = self._local_rpc_service_handler.get_fetch_list(
)
if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp): if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
_LOGGER.info( _LOGGER.info(
self._log("\n\tinput_ops: {}," self._log("\n\tinput_ops: {},"
...@@ -116,21 +172,12 @@ class Op(object): ...@@ -116,21 +172,12 @@ class Op(object):
"\n\tretry: {}," "\n\tretry: {},"
"\n\tbatch_size: {}," "\n\tbatch_size: {},"
"\n\tauto_batching_timeout(s): {}".format( "\n\tauto_batching_timeout(s): {}".format(
", ".join([op.name for op in input_ops ", ".join([op.name for op in self._input_ops
]), self._server_endpoints, ]), self._server_endpoints,
self._fetch_names, self._client_config, self._fetch_names, self._client_config,
self.concurrency, self._timeout, self._retry, self.concurrency, self._timeout, self._retry,
self._batch_size, self._auto_batching_timeout))) self._batch_size, self._auto_batching_timeout)))
self._server_use_profile = False
self._tracer = None
# only for thread op
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
def launch_local_rpc_service(self): def launch_local_rpc_service(self):
if self._local_rpc_service_handler is None: if self._local_rpc_service_handler is None:
_LOGGER.warning( _LOGGER.warning(
......
...@@ -22,19 +22,19 @@ from contextlib import closing ...@@ -22,19 +22,19 @@ from contextlib import closing
import multiprocessing import multiprocessing
import yaml import yaml
from .proto import pipeline_service_pb2_grpc from . import proto
from .operator import ResponseOp, RequestOp from . import operator
from .dag import DAGExecutor, DAG from . import dag
from .util import AvailablePortGenerator from . import util
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer): class PipelineServicer(proto.pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_conf, worker_idx=-1): def __init__(self, response_op, dag_conf, worker_idx=-1):
super(PipelineServicer, self).__init__() super(PipelineServicer, self).__init__()
# init dag executor # init dag executor
self._dag_executor = DAGExecutor(response_op, dag_conf, worker_idx) self._dag_executor = dag.DAGExecutor(response_op, dag_conf, worker_idx)
self._dag_executor.start() self._dag_executor.start()
_LOGGER.info("[PipelineServicer] succ init") _LOGGER.info("[PipelineServicer] succ init")
...@@ -59,7 +59,7 @@ def _reserve_port(port): ...@@ -59,7 +59,7 @@ def _reserve_port(port):
class PipelineServer(object): class PipelineServer(object):
def __init__(self): def __init__(self):
self._port = None self._rpc_port = None
self._worker_num = None self._worker_num = None
self._response_op = None self._response_op = None
self._proxy_server = None self._proxy_server = None
...@@ -77,7 +77,7 @@ class PipelineServer(object): ...@@ -77,7 +77,7 @@ class PipelineServer(object):
if http_port <= 0: if http_port <= 0:
_LOGGER.info("Ignore grpc_gateway configuration.") _LOGGER.info("Ignore grpc_gateway configuration.")
return return
if not AvailablePortGenerator.port_is_available(http_port): if not util.AvailablePortGenerator.port_is_available(http_port):
raise SystemExit("Failed to run grpc-gateway: prot {} " raise SystemExit("Failed to run grpc-gateway: prot {} "
"is already used".format(http_port)) "is already used".format(http_port))
if self._proxy_server is not None: if self._proxy_server is not None:
...@@ -90,25 +90,50 @@ class PipelineServer(object): ...@@ -90,25 +90,50 @@ class PipelineServer(object):
self._proxy_server.start() self._proxy_server.start()
def set_response_op(self, response_op): def set_response_op(self, response_op):
if not isinstance(response_op, ResponseOp): if not isinstance(response_op, operator.ResponseOp):
raise Exception("Failed to set response_op: response_op " raise Exception("Failed to set response_op: response_op "
"must be ResponseOp type.") "must be ResponseOp type.")
if len(response_op.get_input_ops()) != 1: if len(response_op.get_input_ops()) != 1:
raise Exception("Failed to set response_op: response_op " raise Exception("Failed to set response_op: response_op "
"can only have one previous op.") "can only have one previous op.")
self._response_op = response_op self._response_op = response_op
self._used_op, _ = dag.DAG.get_use_ops(self._response_op)
def prepare_server(self, yml_file=None, yml_dict=None): def prepare_server(self, yml_file=None, yml_dict=None):
conf = ServerYamlConfChecker.load_server_yaml_conf( conf = ServerYamlConfChecker.load_server_yaml_conf(
yml_file=yml_file, yml_dict=yml_dict) yml_file=yml_file, yml_dict=yml_dict)
self._port = conf["port"] self._rpc_port = conf.get("rpc_port")
if not AvailablePortGenerator.port_is_available(self._port): self._http_port = conf.get("http_port")
if self._rpc_port is None:
if self._http_port is None:
raise SystemExit("Failed to prepare_server: rpc_port or "
"http_port can not be None.")
else:
# http mode: generate rpc_port
if not util.AvailablePortGenerator.port_is_available(
self._http_port):
raise SystemExit("Failed to prepare_server: http_port({}) "
"is already used".format(self._http_port))
self._rpc_port = util.GetAvailablePortGenerator().next()
else:
if not util.AvailablePortGenerator.port_is_available(
self._rpc_port):
raise SystemExit("Failed to prepare_server: prot {} " raise SystemExit("Failed to prepare_server: prot {} "
"is already used".format(self._port)) "is already used".format(self._rpc_port))
if self._http_port is None:
# rpc mode
pass
else:
# http mode
if not util.AvailablePortGenerator.port_is_available(
self._http_port):
raise SystemExit("Failed to prepare_server: http_port({}) "
"is already used".format(self._http_port))
self._worker_num = conf["worker_num"] self._worker_num = conf["worker_num"]
self._grpc_gateway_port = conf["grpc_gateway_port"]
self._build_dag_each_worker = conf["build_dag_each_worker"] self._build_dag_each_worker = conf["build_dag_each_worker"]
self._configure_ops(conf["op"])
_LOGGER.info("============= PIPELINE SERVER =============") _LOGGER.info("============= PIPELINE SERVER =============")
_LOGGER.info("\n{}".format( _LOGGER.info("\n{}".format(
...@@ -122,18 +147,37 @@ class PipelineServer(object): ...@@ -122,18 +147,37 @@ class PipelineServer(object):
self._conf = conf self._conf = conf
def start_local_rpc_service(self): def _configure_ops(self, op_conf):
default_conf = {
"concurrency": 1,
"timeout": -1,
"retry": 1,
"batch_size": 1,
"auto_batching_timeout": None,
"local_service_conf": {
"workdir": None,
"thread_num": 2,
"devices": "",
"mem_optim": True,
"ir_optim": False,
},
}
for op in self._used_op:
if not isinstance(op, operator.RequestOp):
conf = op_conf.get(op.name, default_conf)
op.configure_from_dict(conf)
def _start_local_rpc_service(self):
# only brpc now # only brpc now
if self._conf["dag"]["client_type"] != "brpc": if self._conf["dag"]["client_type"] != "brpc":
raise ValueError("Local service version must be brpc type now.") raise ValueError("Local service version must be brpc type now.")
used_op, _ = DAG.get_use_ops(self._response_op) for op in self._used_op:
for op in used_op: if not isinstance(op, operator.RequestOp):
if not isinstance(op, RequestOp):
op.launch_local_rpc_service() op.launch_local_rpc_service()
def run_server(self): def run_server(self):
if self._build_dag_each_worker: if self._build_dag_each_worker:
with _reserve_port(self._port) as port: with _reserve_port(self._rpc_port) as port:
bind_address = 'localhost:{}'.format(port) bind_address = 'localhost:{}'.format(port)
workers = [] workers = []
for i in range(self._worker_num): for i in range(self._worker_num):
...@@ -144,8 +188,8 @@ class PipelineServer(object): ...@@ -144,8 +188,8 @@ class PipelineServer(object):
worker.start() worker.start()
workers.append(worker) workers.append(worker)
self._run_grpc_gateway( self._run_grpc_gateway(
grpc_port=self._port, grpc_port=self._rpc_port,
http_port=self._grpc_gateway_port) # start grpc_gateway http_port=self._http_port) # start grpc_gateway
for worker in workers: for worker in workers:
worker.join() worker.join()
else: else:
...@@ -154,13 +198,13 @@ class PipelineServer(object): ...@@ -154,13 +198,13 @@ class PipelineServer(object):
options=[('grpc.max_send_message_length', 256 * 1024 * 1024), options=[('grpc.max_send_message_length', 256 * 1024 * 1024),
('grpc.max_receive_message_length', 256 * 1024 * 1024) ('grpc.max_receive_message_length', 256 * 1024 * 1024)
]) ])
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(self._response_op, self._conf), server) PipelineServicer(self._response_op, self._conf), server)
server.add_insecure_port('[::]:{}'.format(self._port)) server.add_insecure_port('[::]:{}'.format(self._rpc_port))
server.start() server.start()
self._run_grpc_gateway( self._run_grpc_gateway(
grpc_port=self._port, grpc_port=self._rpc_port,
http_port=self._grpc_gateway_port) # start grpc_gateway http_port=self._http_port) # start grpc_gateway
server.wait_for_termination() server.wait_for_termination()
def _run_server_func(self, bind_address, response_op, dag_conf, worker_idx): def _run_server_func(self, bind_address, response_op, dag_conf, worker_idx):
...@@ -170,7 +214,7 @@ class PipelineServer(object): ...@@ -170,7 +214,7 @@ class PipelineServer(object):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor( futures.ThreadPoolExecutor(
max_workers=1, ), options=options) max_workers=1, ), options=options)
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(response_op, dag_conf, worker_idx), server) PipelineServicer(response_op, dag_conf, worker_idx), server)
server.add_insecure_port(bind_address) server.add_insecure_port(bind_address)
server.start() server.start()
...@@ -197,6 +241,10 @@ class ServerYamlConfChecker(object): ...@@ -197,6 +241,10 @@ class ServerYamlConfChecker(object):
ServerYamlConfChecker.check_server_conf(conf) ServerYamlConfChecker.check_server_conf(conf)
ServerYamlConfChecker.check_dag_conf(conf["dag"]) ServerYamlConfChecker.check_dag_conf(conf["dag"])
ServerYamlConfChecker.check_tracer_conf(conf["dag"]["tracer"]) ServerYamlConfChecker.check_tracer_conf(conf["dag"]["tracer"])
for op_name in conf["op"]:
ServerYamlConfChecker.check_op_conf(conf["op"][op_name])
ServerYamlConfChecker.check_local_service_conf(conf["op"][op_name][
"local_service_conf"])
return conf return conf
@staticmethod @staticmethod
...@@ -208,28 +256,81 @@ class ServerYamlConfChecker(object): ...@@ -208,28 +256,81 @@ class ServerYamlConfChecker(object):
@staticmethod @staticmethod
def check_server_conf(conf): def check_server_conf(conf):
default_conf = { default_conf = {
"port": 9292, # "rpc_port": 9292,
"worker_num": 1, "worker_num": 1,
"build_dag_each_worker": False, "build_dag_each_worker": False,
"grpc_gateway_port": 0, #"http_port": 0,
"dag": {}, "dag": {},
"op": {},
} }
conf_type = { conf_type = {
"port": int, "rpc_port": int,
"http_port": int,
"worker_num": int, "worker_num": int,
"build_dag_each_worker": bool, "build_dag_each_worker": bool,
"grpc_gateway_port": int, "grpc_gateway_port": int,
} }
conf_qualification = { conf_qualification = {
"port": [(">=", 1024), ("<=", 65535)], "rpc_port": [(">=", 1024), ("<=", 65535)],
"http_port": [(">=", 1024), ("<=", 65535)],
"worker_num": (">=", 1), "worker_num": (">=", 1),
} }
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type, ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification) conf_qualification)
@staticmethod
def check_local_service_conf(conf):
default_conf = {
"workdir": None,
"thread_num": 2,
"devices": "",
"mem_optim": True,
"ir_optim": False,
}
conf_type = {
"model_config": str,
"workdir": str,
"thread_num": int,
"devices": str,
"mem_optim": bool,
"ir_optim": bool,
}
conf_qualification = {"thread_num": (">=", 1), }
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification)
@staticmethod
def check_op_conf(conf):
default_conf = {
"concurrency": 1,
"timeout": -1,
"retry": 1,
"batch_size": 1,
"auto_batching_timeout": None,
"local_service_conf": {},
}
conf_type = {
"server_endpoints": list,
"fetch_list": list,
"client_config": str,
"concurrency": int,
"timeout": int,
"retry": int,
"batch_size": int,
"auto_batching_timeout": int,
}
conf_qualification = {
"concurrency": (">=", 1),
"retry": (">=", 1),
"batch_size": (">=", 1),
}
for op_name in conf:
ServerYamlConfChecker.check_conf(op_conf[op_name], {}, conf_type,
conf_qualification)
@staticmethod @staticmethod
def check_tracer_conf(conf): def check_tracer_conf(conf):
default_conf = {"interval_s": -1, } default_conf = {"interval_s": -1, }
...@@ -280,6 +381,8 @@ class ServerYamlConfChecker(object): ...@@ -280,6 +381,8 @@ class ServerYamlConfChecker(object):
@staticmethod @staticmethod
def check_conf_type(conf, conf_type): def check_conf_type(conf, conf_type):
for key, val in conf_type.items(): for key, val in conf_type.items():
if key not in conf:
continue
if not isinstance(conf[key], val): if not isinstance(conf[key], val):
raise SystemExit("[CONF] {} must be {} type, but get {}." raise SystemExit("[CONF] {} must be {} type, but get {}."
.format(key, val, type(conf[key]))) .format(key, val, type(conf[key])))
......
...@@ -29,6 +29,11 @@ else: ...@@ -29,6 +29,11 @@ else:
raise Exception("Error Python version") raise Exception("Error Python version")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_AvailablePortGenerator = AvailablePortGenerator()
def GetAvailablePortGenerator():
return _AvailablePortGenerator
class AvailablePortGenerator(object): class AvailablePortGenerator(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册