提交 b2bfbe9a 编写于 作者: B barriery

update code

上级 dc91636b
...@@ -28,7 +28,7 @@ _workdir_name_gen = util.NameGenerator("workdir_") ...@@ -28,7 +28,7 @@ _workdir_name_gen = util.NameGenerator("workdir_")
class LocalRpcServiceHandler(object): class LocalRpcServiceHandler(object):
def __init__(self, def __init__(self,
model_config, model_config,
workdir=None, workdir="",
thread_num=2, thread_num=2,
devices="", devices="",
mem_optim=True, mem_optim=True,
...@@ -105,7 +105,7 @@ class LocalRpcServiceHandler(object): ...@@ -105,7 +105,7 @@ class LocalRpcServiceHandler(object):
def prepare_server(self): def prepare_server(self):
for i, device_id in enumerate(self._devices): for i, device_id in enumerate(self._devices):
if self._workdir is not None: if self._workdir != "":
workdir = "{}_{}".format(self._workdir, i) workdir = "{}_{}".format(self._workdir, i)
else: else:
workdir = _workdir_name_gen.next() workdir = _workdir_name_gen.next()
......
...@@ -57,6 +57,7 @@ class Op(object): ...@@ -57,6 +57,7 @@ class Op(object):
batch_size=None, batch_size=None,
auto_batching_timeout=None, auto_batching_timeout=None,
local_rpc_service_handler=None): local_rpc_service_handler=None):
# In __init__, all the parameters are just saved and Op is not initialized
if name is None: if name is None:
name = _op_name_gen.next() name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique self.name = name # to identify the type of OP, it must be globally unique
...@@ -84,7 +85,8 @@ class Op(object): ...@@ -84,7 +85,8 @@ class Op(object):
self._succ_init_op = False self._succ_init_op = False
self._succ_close_op = False self._succ_close_op = False
def configure_from_dict(self, conf): def init_from_dict(self, conf):
# init op
if self.concurrency is None: if self.concurrency is None:
self.concurrency = conf["concurrency"] self.concurrency = conf["concurrency"]
if self._retry is None: if self._retry is None:
...@@ -123,7 +125,10 @@ class Op(object): ...@@ -123,7 +125,10 @@ class Op(object):
else: else:
if self._local_rpc_service_handler is None: if self._local_rpc_service_handler is None:
local_service_conf = conf.get("local_service_conf") local_service_conf = conf.get("local_service_conf")
_LOGGER.info("local_service_conf: {}".format(
local_service_conf))
model_config = local_service_conf.get("model_config") model_config = local_service_conf.get("model_config")
_LOGGER.info("model_config: {}".format(model_config))
if model_config is None: if model_config is None:
self.with_serving = False self.with_serving = False
else: else:
...@@ -141,12 +146,14 @@ class Op(object): ...@@ -141,12 +146,14 @@ class Op(object):
self._server_endpoints = [ self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports "127.0.0.1:{}".format(p) for p in serivce_ports
] ]
if client_config is None: if self._client_config is None:
client_config = service_handler.get_client_config() self._client_config = service_handler.get_client_config(
if fetch_list is None: )
fetch_list = service_handler.get_fetch_list() if self._fetch_names is None:
self._fetch_names = service_handler.get_fetch_list()
self._local_rpc_service_handler = service_handler self._local_rpc_service_handler = service_handler
else: else:
self.with_serving = True
self._local_rpc_service_handler.prepare_server( self._local_rpc_service_handler.prepare_server(
) # get fetch_list ) # get fetch_list
serivce_ports = self._local_rpc_service_handler.get_port_list( serivce_ports = self._local_rpc_service_handler.get_port_list(
...@@ -154,12 +161,14 @@ class Op(object): ...@@ -154,12 +161,14 @@ class Op(object):
self._server_endpoints = [ self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports "127.0.0.1:{}".format(p) for p in serivce_ports
] ]
if client_config is None: if self._client_config is None:
client_config = self._local_rpc_service_handler.get_client_config( self._client_config = self._local_rpc_service_handler.get_client_config(
) )
if fetch_list is None: if self._fetch_names is None:
fetch_list = self._local_rpc_service_handler.get_fetch_list( self._fetch_names = self._local_rpc_service_handler.get_fetch_list(
) )
else:
self.with_serving = True
if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp): if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
_LOGGER.info( _LOGGER.info(
......
...@@ -22,7 +22,7 @@ from contextlib import closing ...@@ -22,7 +22,7 @@ from contextlib import closing
import multiprocessing import multiprocessing
import yaml import yaml
from . import proto from .proto import pipeline_service_pb2_grpc
from . import operator from . import operator
from . import dag from . import dag
from . import util from . import util
...@@ -30,7 +30,7 @@ from . import util ...@@ -30,7 +30,7 @@ from . import util
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class PipelineServicer(proto.pipeline_service_pb2_grpc.PipelineServiceServicer): class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_conf, worker_idx=-1): def __init__(self, response_op, dag_conf, worker_idx=-1):
super(PipelineServicer, self).__init__() super(PipelineServicer, self).__init__()
# init dag executor # init dag executor
...@@ -133,7 +133,7 @@ class PipelineServer(object): ...@@ -133,7 +133,7 @@ class PipelineServer(object):
self._worker_num = conf["worker_num"] self._worker_num = conf["worker_num"]
self._build_dag_each_worker = conf["build_dag_each_worker"] self._build_dag_each_worker = conf["build_dag_each_worker"]
self._configure_ops(conf["op"]) self._init_ops(conf["op"])
_LOGGER.info("============= PIPELINE SERVER =============") _LOGGER.info("============= PIPELINE SERVER =============")
_LOGGER.info("\n{}".format( _LOGGER.info("\n{}".format(
...@@ -146,16 +146,17 @@ class PipelineServer(object): ...@@ -146,16 +146,17 @@ class PipelineServer(object):
_LOGGER.info("-------------------------------------------") _LOGGER.info("-------------------------------------------")
self._conf = conf self._conf = conf
self._start_local_rpc_service()
def _configure_ops(self, op_conf): def _init_ops(self, op_conf):
default_conf = { default_conf = {
"concurrency": 1, "concurrency": 1,
"timeout": -1, "timeout": -1,
"retry": 1, "retry": 1,
"batch_size": 1, "batch_size": 1,
"auto_batching_timeout": None, "auto_batching_timeout": -1,
"local_service_conf": { "local_service_conf": {
"workdir": None, "workdir": "",
"thread_num": 2, "thread_num": 2,
"devices": "", "devices": "",
"mem_optim": True, "mem_optim": True,
...@@ -163,9 +164,10 @@ class PipelineServer(object): ...@@ -163,9 +164,10 @@ class PipelineServer(object):
}, },
} }
for op in self._used_op: for op in self._used_op:
if not isinstance(op, operator.RequestOp): if not isinstance(op, operator.RequestOp) and not isinstance(
op, operator.ResponseOp):
conf = op_conf.get(op.name, default_conf) conf = op_conf.get(op.name, default_conf)
op.configure_from_dict(conf) op.init_from_dict(conf)
def _start_local_rpc_service(self): def _start_local_rpc_service(self):
# only brpc now # only brpc now
...@@ -198,7 +200,7 @@ class PipelineServer(object): ...@@ -198,7 +200,7 @@ class PipelineServer(object):
options=[('grpc.max_send_message_length', 256 * 1024 * 1024), options=[('grpc.max_send_message_length', 256 * 1024 * 1024),
('grpc.max_receive_message_length', 256 * 1024 * 1024) ('grpc.max_receive_message_length', 256 * 1024 * 1024)
]) ])
proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(self._response_op, self._conf), server) PipelineServicer(self._response_op, self._conf), server)
server.add_insecure_port('[::]:{}'.format(self._rpc_port)) server.add_insecure_port('[::]:{}'.format(self._rpc_port))
server.start() server.start()
...@@ -214,7 +216,7 @@ class PipelineServer(object): ...@@ -214,7 +216,7 @@ class PipelineServer(object):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor( futures.ThreadPoolExecutor(
max_workers=1, ), options=options) max_workers=1, ), options=options)
proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(response_op, dag_conf, worker_idx), server) PipelineServicer(response_op, dag_conf, worker_idx), server)
server.add_insecure_port(bind_address) server.add_insecure_port(bind_address)
server.start() server.start()
...@@ -284,7 +286,7 @@ class ServerYamlConfChecker(object): ...@@ -284,7 +286,7 @@ class ServerYamlConfChecker(object):
@staticmethod @staticmethod
def check_local_service_conf(conf): def check_local_service_conf(conf):
default_conf = { default_conf = {
"workdir": None, "workdir": "",
"thread_num": 2, "thread_num": 2,
"devices": "", "devices": "",
"mem_optim": True, "mem_optim": True,
...@@ -309,7 +311,7 @@ class ServerYamlConfChecker(object): ...@@ -309,7 +311,7 @@ class ServerYamlConfChecker(object):
"timeout": -1, "timeout": -1,
"retry": 1, "retry": 1,
"batch_size": 1, "batch_size": 1,
"auto_batching_timeout": None, "auto_batching_timeout": -1,
"local_service_conf": {}, "local_service_conf": {},
} }
conf_type = { conf_type = {
...@@ -327,9 +329,8 @@ class ServerYamlConfChecker(object): ...@@ -327,9 +329,8 @@ class ServerYamlConfChecker(object):
"retry": (">=", 1), "retry": (">=", 1),
"batch_size": (">=", 1), "batch_size": (">=", 1),
} }
for op_name in conf: ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
ServerYamlConfChecker.check_conf(op_conf[op_name], {}, conf_type, conf_qualification)
conf_qualification)
@staticmethod @staticmethod
def check_tracer_conf(conf): def check_tracer_conf(conf):
...@@ -390,6 +391,8 @@ class ServerYamlConfChecker(object): ...@@ -390,6 +391,8 @@ class ServerYamlConfChecker(object):
@staticmethod @staticmethod
def check_conf_qualification(conf, conf_qualification): def check_conf_qualification(conf, conf_qualification):
for key, qualification in conf_qualification.items(): for key, qualification in conf_qualification.items():
if key not in conf:
continue
if not isinstance(qualification, list): if not isinstance(qualification, list):
qualification = [qualification] qualification = [qualification]
if not ServerYamlConfChecker.qualification_check(conf[key], if not ServerYamlConfChecker.qualification_check(conf[key],
......
...@@ -29,11 +29,6 @@ else: ...@@ -29,11 +29,6 @@ else:
raise Exception("Error Python version") raise Exception("Error Python version")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_AvailablePortGenerator = AvailablePortGenerator()
def GetAvailablePortGenerator():
return _AvailablePortGenerator
class AvailablePortGenerator(object): class AvailablePortGenerator(object):
...@@ -57,6 +52,13 @@ class AvailablePortGenerator(object): ...@@ -57,6 +52,13 @@ class AvailablePortGenerator(object):
return self._curr_port - 1 return self._curr_port - 1
_AvailablePortGenerator = AvailablePortGenerator()
def GetAvailablePortGenerator():
return _AvailablePortGenerator
class NameGenerator(object): class NameGenerator(object):
# use unsafe-id-generator # use unsafe-id-generator
def __init__(self, prefix): def __init__(self, prefix):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册