diff --git a/python/pipeline/local_rpc_service_handler.py b/python/pipeline/local_rpc_service_handler.py index 080f91000cd5005f39dc215e1654f4ac3d25c4e6..4f2a817ba0007923df2cf4a1f18bd4dfe84108f6 100644 --- a/python/pipeline/local_rpc_service_handler.py +++ b/python/pipeline/local_rpc_service_handler.py @@ -19,11 +19,10 @@ try: from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server except ImportError: from paddle_serving_server import OpMaker, OpSeqMaker, Server -from .util import AvailablePortGenerator, NameGenerator +from . import util _LOGGER = logging.getLogger(__name__) -_workdir_name_gen = NameGenerator("workdir_") -_available_port_gen = AvailablePortGenerator() +_workdir_name_gen = util.NameGenerator("workdir_") class LocalRpcServiceHandler(object): @@ -36,7 +35,7 @@ class LocalRpcServiceHandler(object): ir_optim=False, available_port_generator=None): if available_port_generator is None: - available_port_generator = _available_port_gen + available_port_generator = util.GetAvailablePortGenerator() self._model_config = model_config self._port_list = [] diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index a16bdddafc9e67584dd35b7a1aee2440192933fd..946813e57d287f2e970a6b175f2f781f637b9591 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -38,6 +38,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelTimeoutError) from .util import NameGenerator from .profiler import UnsafeTimeProfiler as TimeProfiler +from . 
import local_rpc_service_handler _LOGGER = logging.getLogger(__name__) _op_name_gen = NameGenerator("Op") @@ -47,13 +48,13 @@ class Op(object): def __init__(self, name=None, input_ops=[], - server_endpoints=[], - fetch_list=[], + server_endpoints=None, + fetch_list=None, client_config=None, - concurrency=1, - timeout=-1, - retry=1, - batch_size=1, + concurrency=None, + timeout=None, + retry=None, + batch_size=None, auto_batching_timeout=None, local_rpc_service_handler=None): if name is None: @@ -62,49 +63,104 @@ class Op(object): self.concurrency = concurrency # amount of concurrency self.set_input_ops(input_ops) - if len(server_endpoints) != 0: - # remote service - self.with_serving = True - else: - if local_rpc_service_handler is not None: - # local rpc service - self.with_serving = True - local_rpc_service_handler.prepare_server() # get fetch_list - serivce_ports = local_rpc_service_handler.get_port_list() - server_endpoints = [ - "127.0.0.1:{}".format(p) for p in serivce_ports - ] - if client_config is None: - client_config = local_rpc_service_handler.get_client_config( - ) - if len(fetch_list) == 0: - fetch_list = local_rpc_service_handler.get_fetch_list() - else: - self.with_serving = False self._local_rpc_service_handler = local_rpc_service_handler self._server_endpoints = server_endpoints self._fetch_names = fetch_list self._client_config = client_config - - if timeout > 0: - self._timeout = timeout / 1000.0 - else: - self._timeout = -1 + self._timeout = timeout - self._retry = max(1, retry) + self._retry = retry if retry is None else max(1, retry) + self._batch_size = batch_size + self._auto_batching_timeout = auto_batching_timeout + self._input = None self._outputs = [] - self._batch_size = batch_size - self._auto_batching_timeout = auto_batching_timeout - if self._auto_batching_timeout is not None: - if self._auto_batching_timeout <= 0 or self._batch_size == 1: - _LOGGER.warning( - self._log( - "Because auto_batching_timeout <= 0 or batch_size == 1," - " set auto_batching_timeout to None.")) - 
self._auto_batching_timeout = None + self._server_use_profile = False + self._tracer = None + + # only for thread op + self._for_init_op_lock = threading.Lock() + self._for_close_op_lock = threading.Lock() + self._succ_init_op = False + self._succ_close_op = False + + def configure_from_dict(self, conf): + if self.concurrency is None: + self.concurrency = conf["concurrency"] + if self._retry is None: + self._retry = conf["retry"] + if self._fetch_names is None: + self._fetch_names = conf.get("fetch_list") + if self._client_config is None: + self._client_config = conf.get("client_config") + + if self._timeout is None: + self._timeout = conf["timeout"] + if self._timeout > 0: + self._timeout = self._timeout / 1000.0 + else: + self._timeout = -1 + + if self._batch_size is None: + self._batch_size = conf["batch_size"] + if self._auto_batching_timeout is None: + self._auto_batching_timeout = conf["auto_batching_timeout"] + if self._auto_batching_timeout is None or self._auto_batching_timeout <= 0 or self._batch_size == 1: + _LOGGER.warning( + self._log( + "Because auto_batching_timeout <= 0 or batch_size == 1," + " set auto_batching_timeout to None.")) + self._auto_batching_timeout = None + else: + self._auto_batching_timeout = self._auto_batching_timeout / 1000.0 + + if self._server_endpoints is None: + server_endpoints = conf.get("server_endpoints", []) + if len(server_endpoints) != 0: + # remote service + self.with_serving = True + self._server_endpoints = server_endpoints else: - self._auto_batching_timeout = self._auto_batching_timeout / 1000.0 + if self._local_rpc_service_handler is None: + local_service_conf = conf.get("local_service_conf") + model_config = local_service_conf.get("model_config") + if model_config is None: + self.with_serving = False + else: + # local rpc service + self.with_serving = True + service_handler = local_rpc_service_handler.LocalRpcServiceHandler( + model_config=model_config, + workdir=local_service_conf["workdir"], + thread_num=local_service_conf["thread_num"], + 
devices=local_service_conf["devices"], + mem_optim=local_service_conf["mem_optim"], + ir_optim=local_service_conf["ir_optim"]) + service_handler.prepare_server() # get fetch_list + service_ports = service_handler.get_port_list() + self._server_endpoints = [ + "127.0.0.1:{}".format(p) for p in service_ports + ] + if self._client_config is None: + self._client_config = service_handler.get_client_config() + if self._fetch_names is None: + self._fetch_names = service_handler.get_fetch_list() + self._local_rpc_service_handler = service_handler + else: + self._local_rpc_service_handler.prepare_server( + ) # get fetch_list + service_ports = self._local_rpc_service_handler.get_port_list( + ) + self._server_endpoints = [ + "127.0.0.1:{}".format(p) for p in service_ports + ] + if self._client_config is None: + self._client_config = self._local_rpc_service_handler.get_client_config( + ) + if self._fetch_names is None: + self._fetch_names = self._local_rpc_service_handler.get_fetch_list( + ) + if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp): _LOGGER.info( self._log("\n\tinput_ops: {}," @@ -116,21 +172,12 @@ class Op(object): "\n\tretry: {}," "\n\tbatch_size: {}," "\n\tauto_batching_timeout(s): {}".format( - ", ".join([op.name for op in input_ops + ", ".join([op.name for op in self._input_ops ]), self._server_endpoints, self._fetch_names, self._client_config, self.concurrency, self._timeout, self._retry, self._batch_size, self._auto_batching_timeout))) - self._server_use_profile = False - self._tracer = None - - # only for thread op - self._for_init_op_lock = threading.Lock() - self._for_close_op_lock = threading.Lock() - self._succ_init_op = False - self._succ_close_op = False - def launch_local_rpc_service(self): if self._local_rpc_service_handler is None: _LOGGER.warning( diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py index ac92efd73ac6ab8b6f7be86e3067a9bad7dc5183..f6770025449ea890a44db6d4bc63ef0d88ffb4f7 100644 --- a/python/pipeline/pipeline_server.py +++ 
b/python/pipeline/pipeline_server.py @@ -22,19 +22,19 @@ from contextlib import closing import multiprocessing import yaml -from .proto import pipeline_service_pb2_grpc -from .operator import ResponseOp, RequestOp -from .dag import DAGExecutor, DAG -from .util import AvailablePortGenerator +from . import proto +from . import operator +from . import dag +from . import util _LOGGER = logging.getLogger(__name__) -class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer): +class PipelineServicer(proto.pipeline_service_pb2_grpc.PipelineServiceServicer): def __init__(self, response_op, dag_conf, worker_idx=-1): super(PipelineServicer, self).__init__() # init dag executor - self._dag_executor = DAGExecutor(response_op, dag_conf, worker_idx) + self._dag_executor = dag.DAGExecutor(response_op, dag_conf, worker_idx) self._dag_executor.start() _LOGGER.info("[PipelineServicer] succ init") @@ -59,7 +59,7 @@ def _reserve_port(port): class PipelineServer(object): def __init__(self): - self._port = None + self._rpc_port = None self._worker_num = None self._response_op = None self._proxy_server = None @@ -77,7 +77,7 @@ class PipelineServer(object): if http_port <= 0: _LOGGER.info("Ignore grpc_gateway configuration.") return - if not AvailablePortGenerator.port_is_available(http_port): + if not util.AvailablePortGenerator.port_is_available(http_port): raise SystemExit("Failed to run grpc-gateway: prot {} " "is already used".format(http_port)) if self._proxy_server is not None: @@ -90,25 +90,50 @@ class PipelineServer(object): self._proxy_server.start() def set_response_op(self, response_op): - if not isinstance(response_op, ResponseOp): + if not isinstance(response_op, operator.ResponseOp): raise Exception("Failed to set response_op: response_op " "must be ResponseOp type.") if len(response_op.get_input_ops()) != 1: raise Exception("Failed to set response_op: response_op " "can only have one previous op.") self._response_op = response_op + self._used_op, _ = 
dag.DAG.get_use_ops(self._response_op) def prepare_server(self, yml_file=None, yml_dict=None): conf = ServerYamlConfChecker.load_server_yaml_conf( yml_file=yml_file, yml_dict=yml_dict) - self._port = conf["port"] - if not AvailablePortGenerator.port_is_available(self._port): - raise SystemExit("Failed to prepare_server: prot {} " - "is already used".format(self._port)) + self._rpc_port = conf.get("rpc_port") + self._http_port = conf.get("http_port") + if self._rpc_port is None: + if self._http_port is None: + raise SystemExit("Failed to prepare_server: rpc_port or " + "http_port can not be None.") + else: + # http mode: generate rpc_port + if not util.AvailablePortGenerator.port_is_available( + self._http_port): + raise SystemExit("Failed to prepare_server: http_port({}) " + "is already used".format(self._http_port)) + self._rpc_port = util.GetAvailablePortGenerator().next() + else: + if not util.AvailablePortGenerator.port_is_available( + self._rpc_port): + raise SystemExit("Failed to prepare_server: prot {} " + "is already used".format(self._rpc_port)) + if self._http_port is None: + # rpc mode + pass + else: + # http mode + if not util.AvailablePortGenerator.port_is_available( + self._http_port): + raise SystemExit("Failed to prepare_server: http_port({}) " + "is already used".format(self._http_port)) + self._worker_num = conf["worker_num"] - self._grpc_gateway_port = conf["grpc_gateway_port"] self._build_dag_each_worker = conf["build_dag_each_worker"] + self._configure_ops(conf["op"]) _LOGGER.info("============= PIPELINE SERVER =============") _LOGGER.info("\n{}".format( @@ -122,18 +147,37 @@ class PipelineServer(object): self._conf = conf - def start_local_rpc_service(self): + def _configure_ops(self, op_conf): + default_conf = { + "concurrency": 1, + "timeout": -1, + "retry": 1, + "batch_size": 1, + "auto_batching_timeout": None, + "local_service_conf": { + "workdir": None, + "thread_num": 2, + "devices": "", + "mem_optim": True, + "ir_optim": False, + }, + } 
+ for op in self._used_op: + if not isinstance(op, operator.RequestOp): + conf = op_conf.get(op.name, default_conf) + op.configure_from_dict(conf) + + def _start_local_rpc_service(self): # only brpc now if self._conf["dag"]["client_type"] != "brpc": raise ValueError("Local service version must be brpc type now.") - used_op, _ = DAG.get_use_ops(self._response_op) - for op in used_op: - if not isinstance(op, RequestOp): + for op in self._used_op: + if not isinstance(op, operator.RequestOp): op.launch_local_rpc_service() def run_server(self): if self._build_dag_each_worker: - with _reserve_port(self._port) as port: + with _reserve_port(self._rpc_port) as port: bind_address = 'localhost:{}'.format(port) workers = [] for i in range(self._worker_num): @@ -144,8 +188,8 @@ class PipelineServer(object): worker.start() workers.append(worker) self._run_grpc_gateway( - grpc_port=self._port, - http_port=self._grpc_gateway_port) # start grpc_gateway + grpc_port=self._rpc_port, + http_port=self._http_port) # start grpc_gateway for worker in workers: worker.join() else: @@ -154,13 +198,13 @@ class PipelineServer(object): options=[('grpc.max_send_message_length', 256 * 1024 * 1024), ('grpc.max_receive_message_length', 256 * 1024 * 1024) ]) - pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( + proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( PipelineServicer(self._response_op, self._conf), server) - server.add_insecure_port('[::]:{}'.format(self._port)) + server.add_insecure_port('[::]:{}'.format(self._rpc_port)) server.start() self._run_grpc_gateway( - grpc_port=self._port, - http_port=self._grpc_gateway_port) # start grpc_gateway + grpc_port=self._rpc_port, + http_port=self._http_port) # start grpc_gateway server.wait_for_termination() def _run_server_func(self, bind_address, response_op, dag_conf, worker_idx): @@ -170,7 +214,7 @@ class PipelineServer(object): server = grpc.server( futures.ThreadPoolExecutor( max_workers=1, ), options=options) 
- pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( + proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( PipelineServicer(response_op, dag_conf, worker_idx), server) server.add_insecure_port(bind_address) server.start() @@ -197,6 +241,10 @@ class ServerYamlConfChecker(object): ServerYamlConfChecker.check_server_conf(conf) ServerYamlConfChecker.check_dag_conf(conf["dag"]) ServerYamlConfChecker.check_tracer_conf(conf["dag"]["tracer"]) + for op_name in conf["op"]: + ServerYamlConfChecker.check_op_conf(conf["op"][op_name]) + ServerYamlConfChecker.check_local_service_conf(conf["op"][op_name][ + "local_service_conf"]) return conf @staticmethod @@ -208,28 +256,81 @@ class ServerYamlConfChecker(object): @staticmethod def check_server_conf(conf): default_conf = { - "port": 9292, + # "rpc_port": 9292, "worker_num": 1, "build_dag_each_worker": False, - "grpc_gateway_port": 0, + #"http_port": 0, "dag": {}, + "op": {}, } conf_type = { - "port": int, + "rpc_port": int, + "http_port": int, "worker_num": int, "build_dag_each_worker": bool, "grpc_gateway_port": int, } conf_qualification = { - "port": [(">=", 1024), ("<=", 65535)], + "rpc_port": [(">=", 1024), ("<=", 65535)], + "http_port": [(">=", 1024), ("<=", 65535)], "worker_num": (">=", 1), } ServerYamlConfChecker.check_conf(conf, default_conf, conf_type, conf_qualification) + @staticmethod + def check_local_service_conf(conf): + default_conf = { + "workdir": None, + "thread_num": 2, + "devices": "", + "mem_optim": True, + "ir_optim": False, + } + conf_type = { + "model_config": str, + "workdir": str, + "thread_num": int, + "devices": str, + "mem_optim": bool, + "ir_optim": bool, + } + conf_qualification = {"thread_num": (">=", 1), } + ServerYamlConfChecker.check_conf(conf, default_conf, conf_type, + conf_qualification) + + @staticmethod + def check_op_conf(conf): + default_conf = { + "concurrency": 1, + "timeout": -1, + "retry": 1, + "batch_size": 1, + "auto_batching_timeout": None, + 
"local_service_conf": {}, + } + conf_type = { + "server_endpoints": list, + "fetch_list": list, + "client_config": str, + "concurrency": int, + "timeout": int, + "retry": int, + "batch_size": int, + "auto_batching_timeout": int, + } + conf_qualification = { + "concurrency": (">=", 1), + "retry": (">=", 1), + "batch_size": (">=", 1), + } + ServerYamlConfChecker.check_conf(conf, default_conf, + conf_type, + conf_qualification) + @staticmethod def check_tracer_conf(conf): default_conf = {"interval_s": -1, } @@ -280,6 +381,8 @@ class ServerYamlConfChecker(object): @staticmethod def check_conf_type(conf, conf_type): for key, val in conf_type.items(): + if key not in conf: + continue if not isinstance(conf[key], val): raise SystemExit("[CONF] {} must be {} type, but get {}." .format(key, val, type(conf[key]))) diff --git a/python/pipeline/util.py b/python/pipeline/util.py index 52b79c9fb59b0d89fb79b2f83bec171cdb21f388..d03598915f33bb85488c02f394c93839b085417b 100644 --- a/python/pipeline/util.py +++ b/python/pipeline/util.py @@ -29,6 +29,16 @@ else: raise Exception("Error Python version") _LOGGER = logging.getLogger(__name__) +_AvailablePortGenerator = None + + +def GetAvailablePortGenerator(): + # Lazily construct the shared generator: the AvailablePortGenerator class + # is defined below in this module, so instantiating at import time here + # would raise NameError. + global _AvailablePortGenerator + if _AvailablePortGenerator is None: + _AvailablePortGenerator = AvailablePortGenerator() + return _AvailablePortGenerator class AvailablePortGenerator(object):