提交 dc91636b 编写于 作者: B barriery

config op from dict

上级 1e19ccac
......@@ -19,11 +19,10 @@ try:
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
except ImportError:
from paddle_serving_server import OpMaker, OpSeqMaker, Server
from .util import AvailablePortGenerator, NameGenerator
from . import util
_LOGGER = logging.getLogger(__name__)
_workdir_name_gen = NameGenerator("workdir_")
_available_port_gen = AvailablePortGenerator()
_workdir_name_gen = util.NameGenerator("workdir_")
class LocalRpcServiceHandler(object):
......@@ -36,7 +35,7 @@ class LocalRpcServiceHandler(object):
ir_optim=False,
available_port_generator=None):
if available_port_generator is None:
available_port_generator = _available_port_gen
available_port_generator = util.GetAvailablePortGenerator()
self._model_config = model_config
self._port_list = []
......
......@@ -38,6 +38,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
ChannelTimeoutError)
from .util import NameGenerator
from .profiler import UnsafeTimeProfiler as TimeProfiler
from . import local_rpc_service_handler
_LOGGER = logging.getLogger(__name__)
_op_name_gen = NameGenerator("Op")
......@@ -47,13 +48,13 @@ class Op(object):
def __init__(self,
name=None,
input_ops=[],
server_endpoints=[],
fetch_list=[],
server_endpoints=None,
fetch_list=None,
client_config=None,
concurrency=1,
timeout=-1,
retry=1,
batch_size=1,
concurrency=None,
timeout=None,
retry=None,
batch_size=None,
auto_batching_timeout=None,
local_rpc_service_handler=None):
if name is None:
......@@ -62,49 +63,104 @@ class Op(object):
self.concurrency = concurrency # amount of concurrency
self.set_input_ops(input_ops)
if len(server_endpoints) != 0:
# remote service
self.with_serving = True
else:
if local_rpc_service_handler is not None:
# local rpc service
self.with_serving = True
local_rpc_service_handler.prepare_server() # get fetch_list
serivce_ports = local_rpc_service_handler.get_port_list()
server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = local_rpc_service_handler.get_client_config(
)
if len(fetch_list) == 0:
fetch_list = local_rpc_service_handler.get_fetch_list()
else:
self.with_serving = False
self._local_rpc_service_handler = local_rpc_service_handler
self._server_endpoints = server_endpoints
self._fetch_names = fetch_list
self._client_config = client_config
if timeout > 0:
self._timeout = timeout / 1000.0
else:
self._timeout = -1
self._timeout = timeout
self._retry = max(1, retry)
self._batch_size = batch_size
self._auto_batching_timeout = auto_batching_timeout
self._input = None
self._outputs = []
self._batch_size = batch_size
self._auto_batching_timeout = auto_batching_timeout
if self._auto_batching_timeout is not None:
if self._auto_batching_timeout <= 0 or self._batch_size == 1:
_LOGGER.warning(
self._log(
"Because auto_batching_timeout <= 0 or batch_size == 1,"
" set auto_batching_timeout to None."))
self._auto_batching_timeout = None
self._server_use_profile = False
self._tracer = None
# only for thread op
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
def configure_from_dict(self, conf):
if self.concurrency is None:
self.concurrency = conf["concurrency"]
if self._retry is None:
self._retry = conf["retry"]
if self._fetch_names is None:
self._fetch_names = conf.get("fetch_list")
if self._client_config is None:
self._client_config = conf.get("client_config")
if self._timeout is None:
self._timeout = conf["timeout"]
if self._timeout > 0:
self._timeout = self._timeout / 1000.0
else:
self._timeout = -1
if self._batch_size is None:
self._batch_size = conf["batch_size"]
if self._auto_batching_timeout is None:
self._auto_batching_timeout = conf["auto_batching_timeout"]
if self._auto_batching_timeout <= 0 or self._batch_size == 1:
_LOGGER.warning(
self._log(
"Because auto_batching_timeout <= 0 or batch_size == 1,"
" set auto_batching_timeout to None."))
self._auto_batching_timeout = None
else:
self._auto_batching_timeout = self._auto_batching_timeout / 1000.0
if self._server_endpoints is None:
server_endpoints = conf.get("server_endpoints", [])
if len(server_endpoints) != 0:
# remote service
self.with_serving = True
self._server_endpoints = server_endpoints
else:
self._auto_batching_timeout = self._auto_batching_timeout / 1000.0
if self._local_rpc_service_handler is None:
local_service_conf = conf.get("local_service_conf")
model_config = local_service_conf.get("model_config")
if model_config is None:
self.with_serving = False
else:
# local rpc service
self.with_serving = True
service_handler = local_rpc_service_handler.LocalRpcServiceHandler(
model_config=model_config,
workdir=local_service_conf["workdir"],
thread_num=local_service_conf["thread_num"],
devices=local_service_conf["devices"],
mem_optim=local_service_conf["mem_optim"],
ir_optim=local_service_conf["ir_optim"])
service_handler.prepare_server() # get fetch_list
serivce_ports = service_handler.get_port_list()
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = service_handler.get_client_config()
if fetch_list is None:
fetch_list = service_handler.get_fetch_list()
self._local_rpc_service_handler = service_handler
else:
self._local_rpc_service_handler.prepare_server(
) # get fetch_list
serivce_ports = self._local_rpc_service_handler.get_port_list(
)
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = self._local_rpc_service_handler.get_client_config(
)
if fetch_list is None:
fetch_list = self._local_rpc_service_handler.get_fetch_list(
)
if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
_LOGGER.info(
self._log("\n\tinput_ops: {},"
......@@ -116,21 +172,12 @@ class Op(object):
"\n\tretry: {},"
"\n\tbatch_size: {},"
"\n\tauto_batching_timeout(s): {}".format(
", ".join([op.name for op in input_ops
", ".join([op.name for op in self._input_ops
]), self._server_endpoints,
self._fetch_names, self._client_config,
self.concurrency, self._timeout, self._retry,
self._batch_size, self._auto_batching_timeout)))
self._server_use_profile = False
self._tracer = None
# only for thread op
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
def launch_local_rpc_service(self):
if self._local_rpc_service_handler is None:
_LOGGER.warning(
......
......@@ -22,19 +22,19 @@ from contextlib import closing
import multiprocessing
import yaml
from .proto import pipeline_service_pb2_grpc
from .operator import ResponseOp, RequestOp
from .dag import DAGExecutor, DAG
from .util import AvailablePortGenerator
from . import proto
from . import operator
from . import dag
from . import util
_LOGGER = logging.getLogger(__name__)
class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
class PipelineServicer(proto.pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_conf, worker_idx=-1):
super(PipelineServicer, self).__init__()
# init dag executor
self._dag_executor = DAGExecutor(response_op, dag_conf, worker_idx)
self._dag_executor = dag.DAGExecutor(response_op, dag_conf, worker_idx)
self._dag_executor.start()
_LOGGER.info("[PipelineServicer] succ init")
......@@ -59,7 +59,7 @@ def _reserve_port(port):
class PipelineServer(object):
def __init__(self):
self._port = None
self._rpc_port = None
self._worker_num = None
self._response_op = None
self._proxy_server = None
......@@ -77,7 +77,7 @@ class PipelineServer(object):
if http_port <= 0:
_LOGGER.info("Ignore grpc_gateway configuration.")
return
if not AvailablePortGenerator.port_is_available(http_port):
if not util.AvailablePortGenerator.port_is_available(http_port):
raise SystemExit("Failed to run grpc-gateway: prot {} "
"is already used".format(http_port))
if self._proxy_server is not None:
......@@ -90,25 +90,50 @@ class PipelineServer(object):
self._proxy_server.start()
def set_response_op(self, response_op):
if not isinstance(response_op, ResponseOp):
if not isinstance(response_op, operator.ResponseOp):
raise Exception("Failed to set response_op: response_op "
"must be ResponseOp type.")
if len(response_op.get_input_ops()) != 1:
raise Exception("Failed to set response_op: response_op "
"can only have one previous op.")
self._response_op = response_op
self._used_op, _ = dag.DAG.get_use_ops(self._response_op)
def prepare_server(self, yml_file=None, yml_dict=None):
conf = ServerYamlConfChecker.load_server_yaml_conf(
yml_file=yml_file, yml_dict=yml_dict)
self._port = conf["port"]
if not AvailablePortGenerator.port_is_available(self._port):
raise SystemExit("Failed to prepare_server: prot {} "
"is already used".format(self._port))
self._rpc_port = conf.get("rpc_port")
self._http_port = conf.get("http_port")
if self._rpc_port is None:
if self._http_port is None:
raise SystemExit("Failed to prepare_server: rpc_port or "
"http_port can not be None.")
else:
# http mode: generate rpc_port
if not util.AvailablePortGenerator.port_is_available(
self._http_port):
raise SystemExit("Failed to prepare_server: http_port({}) "
"is already used".format(self._http_port))
self._rpc_port = util.GetAvailablePortGenerator().next()
else:
if not util.AvailablePortGenerator.port_is_available(
self._rpc_port):
raise SystemExit("Failed to prepare_server: prot {} "
"is already used".format(self._rpc_port))
if self._http_port is None:
# rpc mode
pass
else:
# http mode
if not util.AvailablePortGenerator.port_is_available(
self._http_port):
raise SystemExit("Failed to prepare_server: http_port({}) "
"is already used".format(self._http_port))
self._worker_num = conf["worker_num"]
self._grpc_gateway_port = conf["grpc_gateway_port"]
self._build_dag_each_worker = conf["build_dag_each_worker"]
self._configure_ops(conf["op"])
_LOGGER.info("============= PIPELINE SERVER =============")
_LOGGER.info("\n{}".format(
......@@ -122,18 +147,37 @@ class PipelineServer(object):
self._conf = conf
def start_local_rpc_service(self):
def _configure_ops(self, op_conf):
default_conf = {
"concurrency": 1,
"timeout": -1,
"retry": 1,
"batch_size": 1,
"auto_batching_timeout": None,
"local_service_conf": {
"workdir": None,
"thread_num": 2,
"devices": "",
"mem_optim": True,
"ir_optim": False,
},
}
for op in self._used_op:
if not isinstance(op, operator.RequestOp):
conf = op_conf.get(op.name, default_conf)
op.configure_from_dict(conf)
def _start_local_rpc_service(self):
# only brpc now
if self._conf["dag"]["client_type"] != "brpc":
raise ValueError("Local service version must be brpc type now.")
used_op, _ = DAG.get_use_ops(self._response_op)
for op in used_op:
if not isinstance(op, RequestOp):
for op in self._used_op:
if not isinstance(op, operator.RequestOp):
op.launch_local_rpc_service()
def run_server(self):
if self._build_dag_each_worker:
with _reserve_port(self._port) as port:
with _reserve_port(self._rpc_port) as port:
bind_address = 'localhost:{}'.format(port)
workers = []
for i in range(self._worker_num):
......@@ -144,8 +188,8 @@ class PipelineServer(object):
worker.start()
workers.append(worker)
self._run_grpc_gateway(
grpc_port=self._port,
http_port=self._grpc_gateway_port) # start grpc_gateway
grpc_port=self._rpc_port,
http_port=self._http_port) # start grpc_gateway
for worker in workers:
worker.join()
else:
......@@ -154,13 +198,13 @@ class PipelineServer(object):
options=[('grpc.max_send_message_length', 256 * 1024 * 1024),
('grpc.max_receive_message_length', 256 * 1024 * 1024)
])
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(self._response_op, self._conf), server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.add_insecure_port('[::]:{}'.format(self._rpc_port))
server.start()
self._run_grpc_gateway(
grpc_port=self._port,
http_port=self._grpc_gateway_port) # start grpc_gateway
grpc_port=self._rpc_port,
http_port=self._http_port) # start grpc_gateway
server.wait_for_termination()
def _run_server_func(self, bind_address, response_op, dag_conf, worker_idx):
......@@ -170,7 +214,7 @@ class PipelineServer(object):
server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=1, ), options=options)
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(response_op, dag_conf, worker_idx), server)
server.add_insecure_port(bind_address)
server.start()
......@@ -197,6 +241,10 @@ class ServerYamlConfChecker(object):
ServerYamlConfChecker.check_server_conf(conf)
ServerYamlConfChecker.check_dag_conf(conf["dag"])
ServerYamlConfChecker.check_tracer_conf(conf["dag"]["tracer"])
for op_name in conf["op"]:
ServerYamlConfChecker.check_op_conf(conf["op"][op_name])
ServerYamlConfChecker.check_local_service_conf(conf["op"][op_name][
"local_service_conf"])
return conf
@staticmethod
......@@ -208,28 +256,81 @@ class ServerYamlConfChecker(object):
@staticmethod
def check_server_conf(conf):
default_conf = {
"port": 9292,
# "rpc_port": 9292,
"worker_num": 1,
"build_dag_each_worker": False,
"grpc_gateway_port": 0,
#"http_port": 0,
"dag": {},
"op": {},
}
conf_type = {
"port": int,
"rpc_port": int,
"http_port": int,
"worker_num": int,
"build_dag_each_worker": bool,
"grpc_gateway_port": int,
}
conf_qualification = {
"port": [(">=", 1024), ("<=", 65535)],
"rpc_port": [(">=", 1024), ("<=", 65535)],
"http_port": [(">=", 1024), ("<=", 65535)],
"worker_num": (">=", 1),
}
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification)
@staticmethod
def check_local_service_conf(conf):
default_conf = {
"workdir": None,
"thread_num": 2,
"devices": "",
"mem_optim": True,
"ir_optim": False,
}
conf_type = {
"model_config": str,
"workdir": str,
"thread_num": int,
"devices": str,
"mem_optim": bool,
"ir_optim": bool,
}
conf_qualification = {"thread_num": (">=", 1), }
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification)
@staticmethod
def check_op_conf(conf):
default_conf = {
"concurrency": 1,
"timeout": -1,
"retry": 1,
"batch_size": 1,
"auto_batching_timeout": None,
"local_service_conf": {},
}
conf_type = {
"server_endpoints": list,
"fetch_list": list,
"client_config": str,
"concurrency": int,
"timeout": int,
"retry": int,
"batch_size": int,
"auto_batching_timeout": int,
}
conf_qualification = {
"concurrency": (">=", 1),
"retry": (">=", 1),
"batch_size": (">=", 1),
}
for op_name in conf:
ServerYamlConfChecker.check_conf(op_conf[op_name], {}, conf_type,
conf_qualification)
@staticmethod
def check_tracer_conf(conf):
default_conf = {"interval_s": -1, }
......@@ -280,6 +381,8 @@ class ServerYamlConfChecker(object):
@staticmethod
def check_conf_type(conf, conf_type):
for key, val in conf_type.items():
if key not in conf:
continue
if not isinstance(conf[key], val):
raise SystemExit("[CONF] {} must be {} type, but get {}."
.format(key, val, type(conf[key])))
......
......@@ -29,6 +29,11 @@ else:
raise Exception("Error Python version")
_LOGGER = logging.getLogger(__name__)
_AvailablePortGenerator = AvailablePortGenerator()
def GetAvailablePortGenerator():
return _AvailablePortGenerator
class AvailablePortGenerator(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册