提交 383cc4c9 编写于 作者: B bjjwwang

add feed dict error catch, and feed/fetch list in operator. TODO: fix the run_process

上级 8981c1bf
......@@ -179,4 +179,18 @@ class ParamVerify(object):
return False
@staticmethod
def check_feed_dict(feed_dict, feed_list):
if not isinstance(feed_dict, dict):
return False
# read model config, try catch and
feed_dict_key_size = len(feed_dict.keys())
if len(feed_dict.keys()) != len(feed_list):
return False
for key in feed_list:
if key in feed_dict.keys():
return False
return True
ErrorCatch = ErrorCatch()
......@@ -114,6 +114,19 @@ class Op(object):
self._succ_init_op = False
self._succ_close_op = False
# for feed/fetch dict cehck
@staticmethod
def get_feed_fetch_list(client):
from paddle_serving_app.local_predict import LocalPredictor
if isinstance(client, Client):
feed_names = client.get_feed_names()
fetch_names = client.get_fetch_names()
if isinstance(client, LocalPredictor):
feed_names = client.feed_names_
fetch_names = client.fetch_names_
return feed_names, fetch_names
def init_from_dict(self, conf):
"""
Initializing one Op from config.yaml. If server_endpoints exist,
......@@ -134,7 +147,6 @@ class Op(object):
self._fetch_names = conf.get("fetch_list")
if self._client_config is None:
self._client_config = conf.get("client_config")
if self._timeout is None:
self._timeout = conf["timeout"]
if self._timeout > 0:
......@@ -354,12 +366,14 @@ class Op(object):
if self.client_type == 'brpc':
client = Client()
client.load_client_config(client_config)
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client)
elif self.client_type == 'pipeline_grpc':
client = PPClient()
elif self.client_type == 'local_predictor':
if self.local_predictor is None:
raise ValueError("local predictor not yet created")
client = self.local_predictor
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client)
else:
raise ValueError("Failed to init client: unknow client "
"type {}".format(self.client_type))
......@@ -368,6 +382,7 @@ class Op(object):
_LOGGER.info("Op({}) has no fetch name set. So fetch all vars")
if self.client_type != "local_predictor":
client.connect(server_endpoints)
_LOGGER.info("init_client, feed_list:{}, fetch_list: {}".format(self.right_feed_names, self.right_fetch_names))
return client
def get_input_ops(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册