Commit b2bfbe9a authored by barriery

update code

Parent dc91636b
@@ -28,7 +28,7 @@ _workdir_name_gen = util.NameGenerator("workdir_")
class LocalRpcServiceHandler(object):
def __init__(self,
model_config,
workdir=None,
workdir="",
thread_num=2,
devices="",
mem_optim=True,
@@ -105,7 +105,7 @@ class LocalRpcServiceHandler(object):
def prepare_server(self):
for i, device_id in enumerate(self._devices):
if self._workdir is not None:
if self._workdir != "":
workdir = "{}_{}".format(self._workdir, i)
else:
workdir = _workdir_name_gen.next()
......
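In the LocalRpcServiceHandler hunks above, the sentinel for "no workdir given" changes from None to the empty string, so prepare_server now tests for a non-empty string instead of None. The following is a minimal, runnable sketch of the resulting per-device workdir selection; the simplified NameGenerator stands in for util.NameGenerator from the hunk context, and the rest is illustrative rather than the actual Paddle Serving implementation.

import itertools

class NameGenerator(object):
    # Simplified stand-in for util.NameGenerator seen in the hunk context.
    def __init__(self, prefix):
        self._prefix = prefix
        self._counter = itertools.count()

    def next(self):
        return "{}{}".format(self._prefix, next(self._counter))

_workdir_name_gen = NameGenerator("workdir_")

def choose_workdir(workdir, device_idx):
    # After this commit an empty string means "not set": a user-provided
    # workdir gets a per-device suffix, otherwise a name is generated.
    if workdir != "":
        return "{}_{}".format(workdir, device_idx)
    return _workdir_name_gen.next()

print(choose_workdir("my_workdir", 0))  # my_workdir_0
print(choose_workdir("", 1))            # workdir_0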
@@ -57,6 +57,7 @@ class Op(object):
batch_size=None,
auto_batching_timeout=None,
local_rpc_service_handler=None):
# In __init__, all the parameters are just saved and Op is not initialized
if name is None:
name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique
@@ -84,7 +85,8 @@
self._succ_init_op = False
self._succ_close_op = False
def configure_from_dict(self, conf):
def init_from_dict(self, conf):
# init op
if self.concurrency is None:
self.concurrency = conf["concurrency"]
if self._retry is None:
@@ -123,7 +125,10 @@
else:
if self._local_rpc_service_handler is None:
local_service_conf = conf.get("local_service_conf")
_LOGGER.info("local_service_conf: {}".format(
local_service_conf))
model_config = local_service_conf.get("model_config")
_LOGGER.info("model_config: {}".format(model_config))
if model_config is None:
self.with_serving = False
else:
@@ -141,12 +146,14 @@
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = service_handler.get_client_config()
if fetch_list is None:
fetch_list = service_handler.get_fetch_list()
if self._client_config is None:
self._client_config = service_handler.get_client_config(
)
if self._fetch_names is None:
self._fetch_names = service_handler.get_fetch_list()
self._local_rpc_service_handler = service_handler
else:
self.with_serving = True
self._local_rpc_service_handler.prepare_server(
) # get fetch_list
serivce_ports = self._local_rpc_service_handler.get_port_list(
@@ -154,12 +161,14 @@
self._server_endpoints = [
"127.0.0.1:{}".format(p) for p in serivce_ports
]
if client_config is None:
client_config = self._local_rpc_service_handler.get_client_config(
if self._client_config is None:
self._client_config = self._local_rpc_service_handler.get_client_config(
)
if fetch_list is None:
fetch_list = self._local_rpc_service_handler.get_fetch_list(
if self._fetch_names is None:
self._fetch_names = self._local_rpc_service_handler.get_fetch_list(
)
else:
self.with_serving = True
if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
_LOGGER.info(
......
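The Op hunks above defer initialization: __init__ now only records its arguments (see the added comment), and the renamed init_from_dict later fills in whatever was left as None from the per-op config, keeping the resolved client config and fetch names on the instance (self._client_config, self._fetch_names) instead of in local variables. Below is a minimal, runnable sketch of this pattern with simplified stand-in names, not the real Op class.

class MiniOp(object):
    # Illustrative stand-in for the deferred-initialization pattern.
    def __init__(self, name, concurrency=None, retry=None):
        # Parameters are only saved here; nothing is resolved yet.
        self.name = name
        self.concurrency = concurrency
        self._retry = retry

    def init_from_dict(self, conf):
        # Explicit constructor arguments win; the config fills the gaps.
        if self.concurrency is None:
            self.concurrency = conf["concurrency"]
        if self._retry is None:
            self._retry = conf["retry"]

op = MiniOp("det_op", retry=5)
op.init_from_dict({"concurrency": 4, "retry": 2})
print(op.concurrency, op._retry)  # 4 5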
@@ -22,7 +22,7 @@ from contextlib import closing
import multiprocessing
import yaml
from . import proto
from .proto import pipeline_service_pb2_grpc
from . import operator
from . import dag
from . import util
@@ -30,7 +30,7 @@ from . import util
_LOGGER = logging.getLogger(__name__)
class PipelineServicer(proto.pipeline_service_pb2_grpc.PipelineServiceServicer):
class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_conf, worker_idx=-1):
super(PipelineServicer, self).__init__()
# init dag executor
@@ -133,7 +133,7 @@ class PipelineServer(object):
self._worker_num = conf["worker_num"]
self._build_dag_each_worker = conf["build_dag_each_worker"]
self._configure_ops(conf["op"])
self._init_ops(conf["op"])
_LOGGER.info("============= PIPELINE SERVER =============")
_LOGGER.info("\n{}".format(
@@ -146,16 +146,17 @@ class PipelineServer(object):
_LOGGER.info("-------------------------------------------")
self._conf = conf
self._start_local_rpc_service()
def _configure_ops(self, op_conf):
def _init_ops(self, op_conf):
default_conf = {
"concurrency": 1,
"timeout": -1,
"retry": 1,
"batch_size": 1,
"auto_batching_timeout": None,
"auto_batching_timeout": -1,
"local_service_conf": {
"workdir": None,
"workdir": "",
"thread_num": 2,
"devices": "",
"mem_optim": True,
@@ -163,9 +164,10 @@ class PipelineServer(object):
},
}
for op in self._used_op:
if not isinstance(op, operator.RequestOp):
if not isinstance(op, operator.RequestOp) and not isinstance(
op, operator.ResponseOp):
conf = op_conf.get(op.name, default_conf)
op.configure_from_dict(conf)
op.init_from_dict(conf)
def _start_local_rpc_service(self):
# only brpc now
@@ -198,7 +200,7 @@ class PipelineServer(object):
options=[('grpc.max_send_message_length', 256 * 1024 * 1024),
('grpc.max_receive_message_length', 256 * 1024 * 1024)
])
proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(self._response_op, self._conf), server)
server.add_insecure_port('[::]:{}'.format(self._rpc_port))
server.start()
@@ -214,7 +216,7 @@ class PipelineServer(object):
server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=1, ), options=options)
proto.pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(response_op, dag_conf, worker_idx), server)
server.add_insecure_port(bind_address)
server.start()
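Both server-start hunks register the servicer through pipeline_service_pb2_grpc, which this commit imports directly (from .proto import pipeline_service_pb2_grpc) instead of reaching it as an attribute of the proto package. The helper below is a hedged sketch that factors out the wiring the two hunks share (message-size options, a thread-pool executor, an insecure bind); servicer_adder stands for the generated add_PipelineServiceServicer_to_server function, and the helper itself is an illustration, not part of the actual PipelineServer API.

from concurrent import futures
import grpc

def start_pipeline_grpc_server(servicer_adder, servicer, port, worker_num=1):
    # Wiring shared by the two hunks: large message limits, a thread-pool
    # executor, and an insecure bind on the given port.
    options = [('grpc.max_send_message_length', 256 * 1024 * 1024),
               ('grpc.max_receive_message_length', 256 * 1024 * 1024)]
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=worker_num), options=options)
    servicer_adder(servicer, server)
    server.add_insecure_port('[::]:{}'.format(port))
    server.start()
    return server

In the diff, the single-server path constructs PipelineServicer(self._response_op, self._conf), while the per-worker path constructs PipelineServicer(response_op, dag_conf, worker_idx).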
@@ -284,7 +286,7 @@ class ServerYamlConfChecker(object):
@staticmethod
def check_local_service_conf(conf):
default_conf = {
"workdir": None,
"workdir": "",
"thread_num": 2,
"devices": "",
"mem_optim": True,
@@ -309,7 +311,7 @@ class ServerYamlConfChecker(object):
"timeout": -1,
"retry": 1,
"batch_size": 1,
"auto_batching_timeout": None,
"auto_batching_timeout": -1,
"local_service_conf": {},
}
conf_type = {
@@ -327,9 +329,8 @@ class ServerYamlConfChecker(object):
"retry": (">=", 1),
"batch_size": (">=", 1),
}
for op_name in conf:
ServerYamlConfChecker.check_conf(op_conf[op_name], {}, conf_type,
conf_qualification)
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification)
@staticmethod
def check_tracer_conf(conf):
@@ -390,6 +391,8 @@ class ServerYamlConfChecker(object):
@staticmethod
def check_conf_qualification(conf, conf_qualification):
for key, qualification in conf_qualification.items():
if key not in conf:
continue
if not isinstance(qualification, list):
qualification = [qualification]
if not ServerYamlConfChecker.qualification_check(conf[key],
......
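The ServerYamlConfChecker hunks above make the op-level checks tolerant of partially specified configs: the op conf is now validated against its default_conf directly rather than by iterating over op names, and check_conf_qualification skips constraints for keys the config does not set (workdir and auto_batching_timeout also pick up the new defaults "" and -1). Below is a runnable sketch of the skip-missing-keys behavior; the qualification_check helper is a simplified assumption, not the real method.

def qualification_check(value, qualification):
    # Simplified: only the ">=" relation used by the op-level checks.
    relation, bound = qualification
    if relation == ">=":
        return value >= bound
    raise ValueError("unknown relation: {}".format(relation))

def check_conf_qualification(conf, conf_qualification):
    for key, qualification in conf_qualification.items():
        if key not in conf:
            continue  # unset keys fall back to defaults, nothing to check
        if not isinstance(qualification, list):
            qualification = [qualification]
        for q in qualification:
            if not qualification_check(conf[key], q):
                raise ValueError("{} must be {} {}".format(key, q[0], q[1]))

# "retry" is absent, so only "batch_size" is validated here.
check_conf_qualification({"batch_size": 8},
                         {"retry": (">=", 1), "batch_size": (">=", 1)})
print("qualification check passed")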
@@ -29,11 +29,6 @@ else:
raise Exception("Error Python version")
_LOGGER = logging.getLogger(__name__)
_AvailablePortGenerator = AvailablePortGenerator()
def GetAvailablePortGenerator():
return _AvailablePortGenerator
class AvailablePortGenerator(object):
@@ -57,6 +52,13 @@ class AvailablePortGenerator(object):
return self._curr_port - 1
_AvailablePortGenerator = AvailablePortGenerator()
def GetAvailablePortGenerator():
return _AvailablePortGenerator
class NameGenerator(object):
# use unsafe-id-generator
def __init__(self, prefix):
......
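The final hunks move the module-level _AvailablePortGenerator instance and GetAvailablePortGenerator below the AvailablePortGenerator class definition; instantiating the class above its class statement would raise NameError when the module is imported. Below is a minimal sketch of the fixed ordering; the port-probing logic is an assumption for illustration, not the actual implementation.

import socket
from contextlib import closing

class AvailablePortGenerator(object):
    def __init__(self, start_port=12000):
        self._curr_port = start_port

    def next(self):
        # Probe ports upward until one is not currently accepting connections.
        while True:
            port = self._curr_port
            self._curr_port += 1
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as sock:
                if sock.connect_ex(("127.0.0.1", port)) != 0:
                    return port

# Defined only after the class exists; placing these lines above the class
# (as before this commit) would fail with NameError at import time.
_AvailablePortGenerator = AvailablePortGenerator()

def GetAvailablePortGenerator():
    return _AvailablePortGenerator

print(GetAvailablePortGenerator().next())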