Commit 5936dfc6 authored by barrierye

add ResponseOp and RequestOp

Parent: 5ad0861c
Pipeline example: config.yml

 use_multithread: true
 client_type: brpc
-retry: 2
+retry: 1
 profile: false
 prot: 8080
 worker_num: 2
Pipeline example: client profiling script

@@ -23,7 +23,7 @@ lp_wrapper = lp(client.predict)
 words = 'i am very sad | 0'
-for i in range(10):
+for i in range(1):
     fetch_map = lp_wrapper(feed_dict={"words": words}, fetch=["prediction"])
     print(fetch_map)
...
Pipeline example: server script

@@ -13,59 +13,81 @@
 # limitations under the License.
 # pylint: disable=doc-string-missing
-from paddle_serving_server.pipeline import Op, ReadOp
+from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
 from paddle_serving_server.pipeline import PipelineServer
+from paddle_serving_server.pipeline.proto import pipeline_service_pb2
+from paddle_serving_server.pipeline.channel import ChannelDataEcode
 import numpy as np
 import logging
 from paddle_serving_app.reader import IMDBDataset
+_LOGGER = logging.getLogger(__name__)
 logging.basicConfig(
     format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
     datefmt='%Y-%m-%d %H:%M',
-    level=logging.INFO)
+    level=logging.DEBUG)
-class ImdbOp(Op):
+class ImdbRequestOp(RequestOp):
     def load_user_resources(self):
         self.imdb_dataset = IMDBDataset()
         self.imdb_dataset.load_resource('imdb.vocab')
-    def preprocess(self, input_data):
-        data = input_data.parse()
-        word_ids, _ = self.imdb_dataset.get_words_and_label(data['words'])
-        return {"words": word_ids}
+    def unpack_request_package(self, request):
+        dictdata = {}
+        for idx, key in enumerate(request.key):
+            if key != "words":
+                continue
+            words = request.value[idx]
+            word_ids, _ = self.imdb_dataset.get_words_and_label(words)
+            dictdata[key] = np.array(word_ids)
+        return dictdata
 class CombineOp(Op):
     def preprocess(self, input_data):
         combined_prediction = 0
-        for op_name, channeldata in input_data.items():
-            data = channeldata.parse()
-            logging.info("{}: {}".format(op_name, data["prediction"]))
+        for op_name, data in input_data.items():
+            _LOGGER.info("{}: {}".format(op_name, data["prediction"]))
             combined_prediction += data["prediction"]
         data = {"prediction": combined_prediction / 2}
         return data
-read_op = ReadOp()
-bow_op = ImdbOp(
-    name="bow",
-    input_ops=[read_op],
-    server_endpoints=["127.0.0.1:9393"],
-    fetch_list=["prediction"],
-    client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
-    concurrency=1,
-    timeout=-1,
-    retry=1)
-cnn_op = ImdbOp(
-    name="cnn",
-    input_ops=[read_op],
-    server_endpoints=["127.0.0.1:9292"],
-    fetch_list=["prediction"],
-    client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
-    concurrency=1,
-    timeout=-1,
-    retry=1)
+class ImdbResponseOp(ResponseOp):
+    # Here ImdbResponseOp is consistent with the default ResponseOp implementation
+    def pack_response_package(self, channeldata):
+        resp = pipeline_service_pb2.Response()
+        resp.ecode = channeldata.ecode
+        if resp.ecode == ChannelDataEcode.OK.value:
+            feed = channeldata.parse()
+            # ndarray to string
+            for name, var in feed.items():
+                resp.value.append(var.__repr__())
+                resp.key.append(name)
+        else:
+            resp.error_info = channeldata.error_info
+        return resp
+read_op = ImdbRequestOp()
+bow_op = Op(name="bow",
+            input_ops=[read_op],
+            server_endpoints=["127.0.0.1:9393"],
+            fetch_list=["prediction"],
+            client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
+            concurrency=1,
+            timeout=-1,
+            retry=1)
+cnn_op = Op(name="cnn",
+            input_ops=[read_op],
+            server_endpoints=["127.0.0.1:9292"],
+            fetch_list=["prediction"],
+            client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
+            concurrency=1,
+            timeout=-1,
+            retry=1)
 combine_op = CombineOp(
     name="combine",
     input_ops=[bow_op, cnn_op],
@@ -73,9 +95,16 @@ combine_op = CombineOp(
     timeout=-1,
     retry=1)
+# fetch output of bow_op
+# response_op = ImdbResponseOp(input_ops=[bow_op])
+# fetch output of combine_op
+response_op = ImdbResponseOp(input_ops=[combine_op])
+# use default ResponseOp implementation
+# response_op = ResponseOp(input_ops=[combine_op])
 server = PipelineServer()
-server.add_ops([read_op, bow_op, cnn_op, combine_op])
-#server.set_response_op(bow_op)
-server.set_response_op(combine_op)
+server.set_response_op(response_op)
 server.prepare_server('config.yml')
 server.run_server()
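
A minimal client-side sketch for the server above (assumptions, not part of this commit: that PipelineClient.connect() accepts the gRPC endpoint string as shown; the predict() signature follows the profiling script earlier in this commit, and the port follows the "prot: 8080" entry in config.yml):

    from paddle_serving_server.pipeline import PipelineClient

    client = PipelineClient()
    client.connect('127.0.0.1:8080')  # assumed endpoint form; port from config.yml
    words = 'i am very sad | 0'
    fetch_map = client.predict(feed_dict={"words": words}, fetch=["prediction"])
    print(fetch_map)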
pipeline/__init__.py

@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from operator import Op, ReadOp
+from operator import Op, RequestOp, ResponseOp
 from pipeline_server import PipelineServer
 from pipeline_client import PipelineClient
pipeline/channel.py

@@ -27,6 +27,8 @@ import logging
 import enum
 import copy
+_LOGGER = logging.getLogger(__name__)
 class ChannelDataEcode(enum.Enum):
     OK = 0
@@ -71,12 +73,12 @@ class ChannelData(object):
             ecode, error_info = ChannelData.check_npdata(npdata)
             if ecode != ChannelDataEcode.OK.value:
                 datatype = ChannelDataType.ERROR.value
-                logging.error(error_info)
+                _LOGGER.error(error_info)
         elif datatype == ChannelDataType.DICT.value:
             ecode, error_info = ChannelData.check_dictdata(dictdata)
             if ecode != ChannelDataEcode.OK.value:
                 datatype = ChannelDataType.ERROR.value
-                logging.error(error_info)
+                _LOGGER.error(error_info)
         else:
             raise ValueError("datatype not match")
         self.datatype = datatype
@@ -92,8 +94,8 @@ class ChannelData(object):
         error_info = None
         if not isinstance(dictdata, dict):
             ecode = ChannelDataEcode.TYPE_ERROR.value
-            error_info = "the value of postped_data must " \
-                    "be dict, but get {}".format(type(dictdata))
+            error_info = "the value of data must " \
+                    "be dict, but get {}.".format(type(dictdata))
         return ecode, error_info
     @staticmethod
@@ -103,8 +105,8 @@ class ChannelData(object):
         for _, value in npdata.items():
             if not isinstance(value, np.ndarray):
                 ecode = ChannelDataEcode.TYPE_ERROR.value
-                error_info = "the value of postped_data must " \
-                        "be np.ndarray, but get {}".format(type(value))
+                error_info = "the value of data must " \
+                        "be np.ndarray, but get {}.".format(type(value))
                 break
         return ecode, error_info
@@ -200,7 +202,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
         self._idx_consumer_num[0] += 1
     def push(self, channeldata, op_name=None):
-        logging.debug(
+        _LOGGER.debug(
             self._log("{} try to push data: {}".format(op_name,
                                                        channeldata.__str__())))
         if len(self._producers) == 0:
@@ -212,16 +214,16 @@ class ProcessChannel(multiprocessing.queues.Queue):
         with self._cv:
             while self._stop is False:
                 try:
-                    self.put(channeldata, timeout=0)
+                    self.put({op_name: channeldata}, timeout=0)
                     break
                 except Queue.Full:
                     self._cv.wait()
-            logging.debug(
+            _LOGGER.debug(
                 self._log("{} channel size: {}".format(op_name,
                                                        self.qsize())))
             self._cv.notify_all()
-            logging.debug(self._log("{} notify all".format(op_name)))
-            logging.debug(self._log("{} push data succ!".format(op_name)))
+            _LOGGER.debug(self._log("{} notify all".format(op_name)))
+            _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
             return True
         elif op_name is None:
             raise Exception(
@@ -232,7 +234,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
         data_id = channeldata.id
         put_data = None
         with self._cv:
-            logging.debug(self._log("{} get lock".format(op_name)))
+            _LOGGER.debug(self._log("{} get lock".format(op_name)))
             if data_id not in self._push_res:
                 self._push_res[data_id] = {
                     name: None
@@ -253,13 +255,13 @@ class ProcessChannel(multiprocessing.queues.Queue):
                 self._producer_res_count[data_id] += 1
             if put_data is None:
-                logging.debug(
+                _LOGGER.debug(
                     self._log("{} push data succ, but not push to queue.".
                               format(op_name)))
             else:
                 while self._stop is False:
                     try:
-                        logging.debug(
+                        _LOGGER.debug(
                             self._log("{} push data succ: {}".format(
                                 op_name, put_data.__str__())))
                         self.put(put_data, timeout=0)
@@ -267,13 +269,13 @@ class ProcessChannel(multiprocessing.queues.Queue):
                     except Queue.Empty:
                         self._cv.wait()
-                logging.debug(
+                _LOGGER.debug(
                     self._log("multi | {} push data succ!".format(op_name)))
             self._cv.notify_all()
         return True
     def front(self, op_name=None):
-        logging.debug(self._log("{} try to get data...".format(op_name)))
+        _LOGGER.debug(self._log("{} try to get data...".format(op_name)))
         if len(self._consumers) == 0:
             raise Exception(
                 self._log(
@@ -284,7 +286,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
             with self._cv:
                 while self._stop is False and resp is None:
                     try:
-                        logging.debug(
+                        _LOGGER.debug(
                             self._log("{} try to get(with channel empty: {})".
                                       format(op_name, self.empty())))
                         # For queue multiprocess: after putting an object on
@@ -296,12 +298,12 @@ class ProcessChannel(multiprocessing.queues.Queue):
                         resp = self.get(timeout=1e-3)
                         break
                     except Queue.Empty:
-                        logging.debug(
+                        _LOGGER.debug(
                             self._log(
                                 "{} wait for empty queue(with channel empty: {})".
                                 format(op_name, self.empty())))
                         self._cv.wait()
-            logging.debug(
+            _LOGGER.debug(
                 self._log("{} get data succ: {}".format(op_name, resp.__str__(
                 ))))
             return resp
@@ -315,13 +317,13 @@ class ProcessChannel(multiprocessing.queues.Queue):
             while self._stop is False and self._consumers[
                     op_name] - self._consumer_base_idx.value >= len(
                         self._front_res):
-                logging.debug(
+                _LOGGER.debug(
                     self._log(
                         "({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                         format(op_name, self._consumers, self.
                                _consumer_base_idx.value, len(self._front_res))))
                 try:
-                    logging.debug(
+                    _LOGGER.debug(
                         self._log("{} try to get(with channel size: {})".format(
                             op_name, self.qsize())))
                     # For queue multiprocess: after putting an object on
@@ -334,7 +336,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
                     self._front_res.append(channeldata)
                     break
                 except Queue.Empty:
-                    logging.debug(
+                    _LOGGER.debug(
                         self._log(
                             "{} wait for empty queue(with channel size: {})".
                             format(op_name, self.qsize())))
@@ -344,7 +346,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
             base_idx = self._consumer_base_idx.value
             data_idx = consumer_idx - base_idx
             resp = self._front_res[data_idx]
-            logging.debug(self._log("{} get data: {}".format(op_name, resp)))
+            _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
             self._idx_consumer_num[consumer_idx] -= 1
             if consumer_idx == base_idx and self._idx_consumer_num[
@@ -358,15 +360,15 @@ class ProcessChannel(multiprocessing.queues.Queue):
             if self._idx_consumer_num.get(new_consumer_idx) is None:
                 self._idx_consumer_num[new_consumer_idx] = 0
             self._idx_consumer_num[new_consumer_idx] += 1
-            logging.debug(
+            _LOGGER.debug(
                 self._log(
                     "({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                     format(op_name, self._consumers, self._consumer_base_idx.
                            value, len(self._front_res))))
-            logging.debug(self._log("{} notify all".format(op_name)))
+            _LOGGER.debug(self._log("{} notify all".format(op_name)))
             self._cv.notify_all()
-        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
+        _LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
         return resp  # reference, read only
     def stop(self):
@@ -444,7 +446,7 @@ class ThreadChannel(Queue.Queue):
         self._idx_consumer_num[0] += 1
     def push(self, channeldata, op_name=None):
-        logging.debug(
+        _LOGGER.debug(
             self._log("{} try to push data: {}".format(op_name,
                                                        channeldata.__str__())))
         if len(self._producers) == 0:
@@ -456,12 +458,12 @@ class ThreadChannel(Queue.Queue):
         with self._cv:
             while self._stop is False:
                 try:
-                    self.put(channeldata, timeout=0)
+                    self.put({op_name: channeldata}, timeout=0)
                     break
                 except Queue.Full:
                     self._cv.wait()
             self._cv.notify_all()
-            logging.debug(self._log("{} push data succ!".format(op_name)))
+            _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
             return True
         elif op_name is None:
             raise Exception(
@@ -472,7 +474,7 @@ class ThreadChannel(Queue.Queue):
         data_id = channeldata.id
         put_data = None
         with self._cv:
-            logging.debug(self._log("{} get lock".format(op_name)))
+            _LOGGER.debug(self._log("{} get lock".format(op_name)))
             if data_id not in self._push_res:
                 self._push_res[data_id] = {
                     name: None
@@ -488,7 +490,7 @@ class ThreadChannel(Queue.Queue):
                 self._producer_res_count[data_id] += 1
             if put_data is None:
-                logging.debug(
+                _LOGGER.debug(
                     self._log("{} push data succ, but not push to queue.".
                               format(op_name)))
             else:
@@ -499,13 +501,13 @@ class ThreadChannel(Queue.Queue):
                 except Queue.Empty:
                     self._cv.wait()
-                logging.debug(
+                _LOGGER.debug(
                     self._log("multi | {} push data succ!".format(op_name)))
             self._cv.notify_all()
         return True
     def front(self, op_name=None):
-        logging.debug(self._log("{} try to get data".format(op_name)))
+        _LOGGER.debug(self._log("{} try to get data".format(op_name)))
         if len(self._consumers) == 0:
             raise Exception(
                 self._log(
@@ -520,7 +522,7 @@ class ThreadChannel(Queue.Queue):
                     break
                 except Queue.Empty:
                     self._cv.wait()
-            logging.debug(
+            _LOGGER.debug(
                 self._log("{} get data succ: {}".format(op_name, resp.__str__(
                 ))))
             return resp
@@ -544,7 +546,7 @@ class ThreadChannel(Queue.Queue):
             base_idx = self._consumer_base_idx
             data_idx = consumer_idx - base_idx
             resp = self._front_res[data_idx]
-            logging.debug(self._log("{} get data: {}".format(op_name, resp)))
+            _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
             self._idx_consumer_num[consumer_idx] -= 1
             if consumer_idx == base_idx and self._idx_consumer_num[
@@ -561,7 +563,7 @@ class ThreadChannel(Queue.Queue):
             self._cv.notify_all()
-        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
+        _LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
         # return resp  # reference, read only
         return copy.deepcopy(resp)
...
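
The channel diffs above change what push() enqueues: payloads are now always {producer_op_name: ChannelData}, so front() hands every consumer a dict keyed by producer name even when there is only one producer. A toy stdlib illustration of the convention (plain queue.Queue, not repo code):

    import queue

    chan = queue.Queue()
    chan.put({"bow": {"prediction": 0.9}})  # push() now wraps even a single producer
    channeldata_dict = chan.get()           # front() returns a dict in all cases
    for op_name, data in channeldata_dict.items():
        print(op_name, data["prediction"])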
pipeline/operator.py

@@ -20,9 +20,11 @@ from concurrent import futures
 import logging
 import func_timeout
+from .proto import pipeline_service_pb2
 from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
 from .util import NameGenerator
+_LOGGER = logging.getLogger(__name__)
 _op_name_gen = NameGenerator("Op")
@@ -67,17 +69,17 @@ class Op(object):
     def init_client(self, client_type, client_config, server_endpoints,
                     fetch_names):
         if self.with_serving == False:
-            logging.debug("{} no client".format(self.name))
+            _LOGGER.debug("{} no client".format(self.name))
             return
-        logging.debug("{} client_config: {}".format(self.name, client_config))
-        logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
+        _LOGGER.debug("{} client_config: {}".format(self.name, client_config))
+        _LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
         if client_type == 'brpc':
             self._client = Client()
-            self._client.load_client_config(client_config)
         elif client_type == 'grpc':
             self._client = MultiLangClient()
         else:
             raise ValueError("unknow client type: {}".format(client_type))
+        self._client.load_client_config(client_config)
         self._client.connect(server_endpoints)
         self._fetch_names = fetch_names
@@ -117,38 +119,31 @@ class Op(object):
             channel.add_producer(self.name)
         self._outputs.append(channel)
-    def preprocess(self, channeldata):
+    def preprocess(self, input_dicts):
         # multiple previous Op
-        if isinstance(channeldata, dict):
+        if len(input_dicts) != 1:
             raise NotImplementedError(
-                'this Op has multiple previous inputs. Please override this method'
+                'this Op has multiple previous inputs. Please override this func.'
             )
-        if channeldata.datatype is not ChannelDataType.CHANNEL_NPDATA.value:
-            raise NotImplementedError(
-                'datatype in channeldata is not CHANNEL_NPDATA({}). '
-                'Please override this method'.format(channeldata.datatype))
-        # get numpy dict
-        feed_data = channeldata.parse()
-        return feed_data
+        (_, input_dict), = input_dicts.items()
+        return input_dict
     def process(self, feed_dict):
-        if not isinstance(feed_dict, dict):
-            raise Exception(
-                self._log(
-                    'feed_dict must be dict type(the output of preprocess()), but get {}'.
-                    format(type(feed_dict))))
-        logging.debug(self._log('feed_dict: {}'.format(feed_dict)))
-        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
+        err, err_info = ChannelData.check_npdata(feed_dict)
+        if err != 0:
+            raise NotImplementedError(
+                "{} Please override preprocess func.".format(err_info))
+        _LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict)))
+        _LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names)))
         if isinstance(self._client, MultiLangClient):
             call_result = self._client.predict(
                 feed=feed_dict, fetch=self._fetch_names)
-            logging.debug(self._log("get call_result"))
+            _LOGGER.debug(self._log("get call_result"))
         else:
             call_result = self._client.predict(
                 feed=feed_dict, fetch=self._fetch_names)
-            logging.debug(self._log("get fetch_dict"))
+            _LOGGER.debug(self._log("get fetch_dict"))
         return call_result
     def postprocess(self, fetch_dict):
@@ -160,21 +155,19 @@ class Op(object):
             channel.stop()
         self._is_run = False
-    def _parse_channeldata(self, channeldata):
+    def _parse_channeldata(self, channeldata_dict):
         data_id, error_channeldata = None, None
-        if isinstance(channeldata, dict):
-            parsed_data = {}
-            key = list(channeldata.keys())[0]
-            data_id = channeldata[key].id
-            for _, data in channeldata.items():
-                if data.ecode != ChannelDataEcode.OK.value:
-                    error_channeldata = data
-                    break
-        else:
-            data_id = channeldata.id
-            if channeldata.ecode != ChannelDataEcode.OK.value:
-                error_channeldata = channeldata
-        return data_id, error_channeldata
+        parsed_data = {}
+        key = list(channeldata_dict.keys())[0]
+        data_id = channeldata_dict[key].id
+        for name, data in channeldata_dict.items():
+            if data.ecode != ChannelDataEcode.OK.value:
+                error_channeldata = data
+                break
+            parsed_data[name] = data.parse()
+        return data_id, error_channeldata, parsed_data
     def _push_to_output_channels(self, data, channels, name=None):
         if name is None:
@@ -229,11 +222,12 @@ class Op(object):
         self._is_run = True
         while self._is_run:
             self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
-            channeldata = input_channel.front(self.name)
+            channeldata_dict = input_channel.front(self.name)
             self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
-            logging.debug(log("input_data: {}".format(channeldata)))
-            data_id, error_channeldata = self._parse_channeldata(channeldata)
+            _LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
+            data_id, error_channeldata, parsed_data = self._parse_channeldata(
+                channeldata_dict)
             # error data in predecessor Op
             if error_channeldata is not None:
@@ -245,13 +239,13 @@ class Op(object):
             try:
                 self._profiler_record("{}-prep#{}_0".format(op_info_prefix,
                                                             tid))
-                preped_data = self.preprocess(channeldata)
+                preped_data = self.preprocess(parsed_data)
                 self._profiler_record("{}-prep#{}_1".format(op_info_prefix,
                                                             tid))
             except NotImplementedError as e:
                 # preprocess function not implemented
                 error_info = log(e)
-                logging.error(error_info)
+                _LOGGER.error(error_info)
                 self._push_to_output_channels(
                     ChannelData(
                         ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
@@ -262,7 +256,7 @@ class Op(object):
             except TypeError as e:
                 # Error type in channeldata.datatype
                 error_info = log(e)
-                logging.error(error_info)
+                _LOGGER.error(error_info)
                 self._push_to_output_channels(
                     ChannelData(
                         ecode=ChannelDataEcode.TYPE_ERROR.value,
@@ -272,7 +266,7 @@ class Op(object):
                 continue
             except Exception as e:
                 error_info = log(e)
-                logging.error(error_info)
+                _LOGGER.error(error_info)
                 self._push_to_output_channels(
                     ChannelData(
                         ecode=ChannelDataEcode.UNKNOW.value,
@@ -293,7 +287,7 @@ class Op(object):
                 except Exception as e:
                     ecode = ChannelDataEcode.UNKNOW.value
                     error_info = log(e)
-                    logging.error(error_info)
+                    _LOGGER.error(error_info)
                 else:
                     for i in range(self._retry):
                         try:
@@ -305,14 +299,14 @@ class Op(object):
                             if i + 1 >= self._retry:
                                 ecode = ChannelDataEcode.TIMEOUT.value
                                 error_info = log(e)
-                                logging.error(error_info)
+                                _LOGGER.error(error_info)
                             else:
-                                logging.warn(
+                                _LOGGER.warn(
                                     log("timeout, retry({})".format(i + 1)))
                         except Exception as e:
                             ecode = ChannelDataEcode.UNKNOW.value
                             error_info = log(e)
-                            logging.error(error_info)
+                            _LOGGER.error(error_info)
                             break
                         else:
                             break
@@ -346,7 +340,7 @@ class Op(object):
                 except Exception as e:
                     ecode = ChannelDataEcode.UNKNOW.value
                     error_info = log(e)
-                    logging.error(error_info)
+                    _LOGGER.error(error_info)
                     self._push_to_output_channels(
                         ChannelData(
                             ecode=ecode, error_info=error_info, data_id=data_id),
@@ -356,7 +350,7 @@ class Op(object):
                     ecode = ChannelDataEcode.TYPE_ERROR.value
                     error_info = log("output of postprocess funticon must be " \
                         "dict type, but get {}".format(type(postped_data)))
-                    logging.error(error_info)
+                    _LOGGER.error(error_info)
                     self._push_to_output_channels(
                         ChannelData(
                             ecode=ecode, error_info=error_info, data_id=data_id),
@@ -385,11 +379,62 @@ class Op(object):
         return "{} {}".format(self.name, info)
-class ReadOp(Op):
+class RequestOp(Op):
+    """ RequestOp do not run preprocess, process, postprocess. """
     def __init__(self, concurrency=1):
         # PipelineService.name = "#G"
-        super(ReadOp, self).__init__(
+        super(RequestOp, self).__init__(
             name="#G", input_ops=[], concurrency=concurrency)
+        # load user resources
+        self.load_user_resources()
+    def unpack_request_package(self, request):
+        dictdata = {}
+        for idx, key in enumerate(request.key):
+            dictdata[key] = request.value[idx]
+        return dictdata
+class ResponseOp(Op):
+    """ ResponseOp do not run preprocess, process, postprocess. """
+    def __init__(self, input_ops, concurrency=1):
+        super(ResponseOp, self).__init__(
+            name="#R", input_ops=input_ops, concurrency=concurrency)
+        # load user resources
+        self.load_user_resources()
+    def pack_response_package(self, channeldata):
+        resp = pipeline_service_pb2.Response()
+        resp.ecode = channeldata.ecode
+        if resp.ecode == ChannelDataEcode.OK.value:
+            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
+                feed = channeldata.parse()
+                # ndarray to string:
+                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
+                for name, var in feed.items():
+                    resp.value.append(var.__repr__())
+                    resp.key.append(name)
+            elif channeldata.datatype == ChannelDataType.DICT.value:
+                feed = channeldata.parse()
+                for name, var in feed.items():
+                    if not isinstance(var, str):
+                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
+                        resp.error_info = self._log(
+                            "fetch var type must be str({}).".format(
+                                type(var)))
+                        break
+                    resp.value.append(var)
+                    resp.key.append(name)
+            else:
+                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
+                resp.error_info = self._log(
+                    "Error type({}) in datatype.".format(channeldata.datatype))
+                _LOGGER.error(resp.error_info)
+        else:
+            resp.error_info = channeldata.error_info
+        return resp
 class VirtualOp(Op):
@@ -427,17 +472,11 @@ class VirtualOp(Op):
         self._is_run = True
         while self._is_run:
             self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
-            channeldata = input_channel.front(self.name)
+            channeldata_dict = input_channel.front(self.name)
             self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
             self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
-            if isinstance(channeldata, dict):
-                for name, data in channeldata.items():
-                    self._push_to_output_channels(
-                        data, channels=output_channels, name=name)
-            else:
-                self._push_to_output_channels(
-                    channeldata,
-                    channels=output_channels,
-                    name=self._virtual_pred_ops[0].name)
+            for name, data in channeldata_dict.items():
+                self._push_to_output_channels(
+                    data, channels=output_channels, name=name)
             self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))
pipeline/pipeline_client.py

@@ -14,10 +14,13 @@
 # pylint: disable=doc-string-missing
 import grpc
 import numpy as np
-from numpy import array
+from numpy import *
+import logging
 from .proto import pipeline_service_pb2
 from .proto import pipeline_service_pb2_grpc
+_LOGGER = logging.getLogger(__name__)
 class PipelineClient(object):
     def __init__(self):
...
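
The switch from "from numpy import array" to "from numpy import *" pairs with the default ResponseOp, which serializes each fetched ndarray as repr(arr); rebuilding the array with eval() then needs numpy names such as array and float32 in scope. A sketch of that round trip (the eval() decode is an assumption about the client side, not shown in this diff):

    import numpy as np
    from numpy import *  # brings array, float32, ... into scope for eval

    arr = np.array([[0.1, 0.9]], dtype=np.float32)
    wire = repr(arr)       # what ResponseOp.pack_response_package appends to resp.value
    restored = eval(wire)  # assumed client-side decode
    assert np.allclose(arr, restored)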
pipeline/pipeline_server.py

@@ -40,22 +40,24 @@ import yaml
 from .proto import pipeline_service_pb2
 from .proto import pipeline_service_pb2_grpc
-from .operator import Op, ReadOp, VirtualOp
+from .operator import Op, RequestOp, ResponseOp, VirtualOp
 from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
 from .profiler import TimeProfiler
 from .util import NameGenerator
+_LOGGER = logging.getLogger(__name__)
 _profiler = TimeProfiler()
 class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
-    def __init__(self, in_channel, out_channel, retry=2):
+    def __init__(self, in_channel, out_channel, unpack_func, pack_func,
+                 retry=2):
         super(PipelineService, self).__init__()
         self.name = "#G"
         self.set_in_channel(in_channel)
         self.set_out_channel(out_channel)
-        logging.debug(self._log(in_channel.debug()))
-        logging.debug(self._log(out_channel.debug()))
+        _LOGGER.debug(self._log(in_channel.debug()))
+        _LOGGER.debug(self._log(out_channel.debug()))
         #TODO:
         #  multi-lock for different clients
         #  diffenert lock for server and client
@@ -64,6 +66,8 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
         self._globel_resp_dict = {}
         self._id_counter = 0
         self._retry = retry
+        self._pack_func = pack_func
+        self._unpack_func = unpack_func
         self._recive_func = threading.Thread(
             target=PipelineService._recive_out_channel_func, args=(self, ))
         self._recive_func.start()
@@ -89,7 +93,10 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
     def _recive_out_channel_func(self):
         while True:
-            channeldata = self._out_channel.front(self.name)
+            channeldata_dict = self._out_channel.front(self.name)
+            if len(channeldata_dict) != 1:
+                raise Exception("out_channel cannot have multiple input ops")
+            (_, channeldata), = channeldata_dict.items()
             if not isinstance(channeldata, ChannelData):
                 raise TypeError(
                     self._log('data must be ChannelData type, but get {}'.
@@ -114,18 +121,15 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
         return resp
     def _pack_data_for_infer(self, request):
-        logging.debug(self._log('start inferce'))
+        _LOGGER.debug(self._log('start inferce'))
         data_id = self._get_next_id()
-        dictdata = {}
+        dictdata = None
         try:
-            for idx, key in enumerate(request.key):
-                logging.debug(self._log('key: {}'.format(key)))
-                logging.debug(self._log('value: {}'.format(request.value[idx])))
-                dictdata[key] = request.value[idx]
+            dictdata = self._unpack_func(request)
         except Exception as e:
             return ChannelData(
                 ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
-                error_info="rpc package error",
+                error_info="rpc package error: {}".format(e),
                 data_id=data_id), data_id
         else:
             return ChannelData(
@@ -134,35 +138,8 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
                 data_id=data_id), data_id
     def _pack_data_for_resp(self, channeldata):
-        logging.debug(self._log('get channeldata'))
-        resp = pipeline_service_pb2.Response()
-        resp.ecode = channeldata.ecode
-        if resp.ecode == ChannelDataEcode.OK.value:
-            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
-                feed = channeldata.parse()
-                # ndarray to string
-                for name, var in feed.items():
-                    resp.value.append(var.__repr__())
-                    resp.key.append(name)
-            elif channeldata.datatype == ChannelDataType.DICT.value:
-                feed = channeldata.parse()
-                for name, var in feed.items():
-                    if not isinstance(var, str):
-                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
-                        resp.error_info = self._log(
-                            "fetch var type must be str({}).".format(
-                                type(var)))
-                        break
-                    resp.value.append(var)
-                    resp.key.append(name)
-            else:
-                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
-                resp.error_info = self._log(
-                    "Error type({}) in datatype.".format(channeldata.datatype))
-                logging.error(resp.error_info)
-        else:
-            resp.error_info = channeldata.error_info
-        return resp
+        _LOGGER.debug(self._log('get channeldata'))
+        return self._pack_func(channeldata)
     def inference(self, request, context):
         _profiler.record("{}-prepack_0".format(self.name))
@@ -171,12 +148,12 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
         resp_channeldata = None
         for i in range(self._retry):
-            logging.debug(self._log('push data'))
+            _LOGGER.debug(self._log('push data'))
             _profiler.record("{}-push_0".format(self.name))
             self._in_channel.push(data, self.name)
             _profiler.record("{}-push_1".format(self.name))
-            logging.debug(self._log('wait for infer'))
+            _LOGGER.debug(self._log('wait for infer'))
             _profiler.record("{}-fetch_0".format(self.name))
             resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
             _profiler.record("{}-fetch_1".format(self.name))
@@ -184,7 +161,7 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
             if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                 break
             if i + 1 < self._retry:
-                logging.warn("retry({}): {}".format(
+                _LOGGER.warn("retry({}): {}".format(
                     i + 1, resp_channeldata.error_info))
         _profiler.record("{}-postpack_0".format(self.name))
@@ -197,32 +174,27 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
 class PipelineServer(object):
     def __init__(self):
         self._channels = []
-        self._user_ops = []
         self._actual_ops = []
         self._port = None
         self._worker_num = None
         self._in_channel = None
         self._out_channel = None
         self._response_op = None
+        self._pack_func = None
+        self._unpack_func = None
     def add_channel(self, channel):
         self._channels.append(channel)
-    def add_op(self, op):
-        self._user_ops.append(op)
-    def add_ops(self, ops):
-        self._user_ops.extend(ops)
     def gen_desc(self):
-        logging.info('here will generate desc for PAAS')
+        _LOGGER.info('here will generate desc for PAAS')
         pass
     def set_response_op(self, response_op):
         if not isinstance(response_op, Op):
             raise Exception("response_op must be Op type.")
-        if len(response_op.get_input_ops()) == 0:
-            raise Exception("response_op cannot be ReadOp.")
+        if len(response_op.get_input_ops()) != 1:
+            raise Exception("response_op can only have one previous op.")
         self._response_op = response_op
     def _topo_sort(self, response_op):
@@ -230,19 +202,21 @@ class PipelineServer(object):
             raise Exception("response_op has not been set.")
         def get_use_ops(root):
+            # root: response_op
             unique_names = set()
             use_ops = set()
             succ_ops_of_use_op = {}  # {op_name: succ_ops}
             que = Queue.Queue()
             que.put(root)
-            use_ops.add(root)
-            unique_names.add(root.name)
+            #use_ops.add(root)
+            #unique_names.add(root.name)
             while que.qsize() != 0:
                 op = que.get()
                 for pred_op in op.get_input_ops():
                     if pred_op.name not in succ_ops_of_use_op:
                         succ_ops_of_use_op[pred_op.name] = []
-                    succ_ops_of_use_op[pred_op.name].append(op)
+                    if op != root:
+                        succ_ops_of_use_op[pred_op.name].append(op)
                     if pred_op not in use_ops:
                         que.put(pred_op)
                         use_ops.add(pred_op)
@@ -268,7 +242,8 @@ class PipelineServer(object):
                     zero_indegree_num += 1
             if zero_indegree_num != 1:
                 raise Exception("DAG contains multiple input Ops")
-            ques[que_idx].put(response_op)
+            last_op = response_op.get_input_ops()[0]
+            ques[que_idx].put(last_op)
             # topo sort to get dag_views
             dag_views = []
@@ -344,7 +319,7 @@ class PipelineServer(object):
                     continue
                 channel = gen_channel(channel_name_gen)
                 channels.append(channel)
-                logging.debug("{} => {}".format(channel.name, op.name))
+                _LOGGER.debug("{} => {}".format(channel.name, op.name))
                 op.add_input_channel(channel)
                 pred_ops = pred_op_of_next_view_op[op.name]
                 if v_idx == 0:
@@ -352,7 +327,7 @@ class PipelineServer(object):
                 else:
                     # if pred_op is virtual op, it will use ancestors as producers to channel
                     for pred_op in pred_ops:
-                        logging.debug("{} => {}".format(pred_op.name,
+                        _LOGGER.debug("{} => {}".format(pred_op.name,
                                                         channel.name))
                         pred_op.add_output_channel(channel)
                 processed_op.add(op.name)
@@ -369,24 +344,26 @@ class PipelineServer(object):
                             same_flag = False
                             break
                     if same_flag:
-                        logging.debug("{} => {}".format(channel.name,
+                        _LOGGER.debug("{} => {}".format(channel.name,
                                                         other_op.name))
                         other_op.add_input_channel(channel)
                         processed_op.add(other_op.name)
         output_channel = gen_channel(channel_name_gen)
         channels.append(output_channel)
-        response_op.add_output_channel(output_channel)
+        last_op.add_output_channel(output_channel)
+        pack_func, unpack_func = None, None
+        pack_func = self._response_op.pack_response_package
         self._actual_ops = virtual_ops
         for op in use_ops:
             if len(op.get_input_ops()) == 0:
-                # pass read op
+                unpack_func = op.unpack_request_package
                 continue
             self._actual_ops.append(op)
         self._channels = channels
         for c in channels:
-            logging.debug(c.debug())
+            _LOGGER.debug(c.debug())
-        return input_channel, output_channel
+        return input_channel, output_channel, pack_func, unpack_func
     def _port_is_available(self, port):
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
@@ -414,7 +391,8 @@ class PipelineServer(object):
                 "profile cannot be used in multiprocess version temporarily")
         _profiler.enable(profile)
-        input_channel, output_channel = self._topo_sort(self._response_op)
+        input_channel, output_channel, self._pack_func, self._unpack_func = self._topo_sort(
+            self._response_op)
         self._in_channel = input_channel
         self._out_channel = output_channel
         for op in self._actual_ops:
@@ -443,7 +421,8 @@ class PipelineServer(object):
         server = grpc.server(
             futures.ThreadPoolExecutor(max_workers=self._worker_num))
         pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-            PipelineService(self._in_channel, self._out_channel, self._retry),
+            PipelineService(self._in_channel, self._out_channel,
+                            self._unpack_func, self._pack_func, self._retry),
             server)
         server.add_insecure_port('[::]:{}'.format(self._port))
         server.start()
@@ -454,4 +433,4 @@ class PipelineServer(object):
     def prepare_serving(self, op):
         # run a server (not in PyServing)
-        logging.info("run a server (not in PyServing)")
+        _LOGGER.info("run a server (not in PyServing)")
pipeline/profiler.py

@@ -15,6 +15,7 @@
 import os
 import sys
+import logging
 if sys.version_info.major == 2:
     import Queue
 elif sys.version_info.major == 3:
@@ -23,6 +24,8 @@ else:
     raise Exception("Error Python version")
 import time
+_LOGGER = logging.getLogger(__name__)
 class TimeProfiler(object):
     def __init__(self):
...