提交 02f9966c 编写于 作者: B barrierye

Support multiple different Op feed data or fetch data

上级 20985a8c
...@@ -43,8 +43,9 @@ data = np.ndarray.tobytes(x) ...@@ -43,8 +43,9 @@ data = np.ndarray.tobytes(x)
req.feed_var_names.append("x") req.feed_var_names.append("x")
req.feed_insts.append(data) req.feed_insts.append(data)
resp = stub.inference(req) for i in range(100):
for idx, name in enumerate(resp.fetch_var_names): resp = stub.inference(req)
print('{}: {}'.format( for idx, name in enumerate(resp.fetch_var_names):
name, np.frombuffer( print('{}: {}'.format(
resp.fetch_insts[idx], dtype='float'))) name, np.frombuffer(
resp.fetch_insts[idx], dtype='float')))
...@@ -29,83 +29,80 @@ logging.basicConfig( ...@@ -29,83 +29,80 @@ logging.basicConfig(
class CombineOp(Op): class CombineOp(Op):
#TODO: different id of data
def preprocess(self, input_data): def preprocess(self, input_data):
data_id = None
cnt = 0 cnt = 0
for input in input_data: for op_name, data in input_data.items():
data = input[0] # batchsize=1 logging.debug("CombineOp preprocess: {}".format(op_name))
cnt += np.frombuffer(data.insts[0].data, dtype='float') cnt += np.frombuffer(data.insts[0].data, dtype='float')
if data_id is None:
data_id = data.id
if data_id != data.id:
raise Exception("id not match: {} vs {}".format(data_id,
data.id))
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
inst.data = np.ndarray.tobytes(cnt) inst.data = np.ndarray.tobytes(cnt)
inst.name = "resp" inst.name = "resp"
data.insts.append(inst) data.insts.append(inst)
data.id = data_id
return data return data
def postprocess(self, output_data):
return output_data
class UciOp(Op): class UciOp(Op):
def postprocess(self, output_data): def postprocess(self, output_data):
data_ids = self.get_data_ids()
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["price"][0][0], dtype='float') pred = np.array(output_data["price"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred) inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction" inst.name = "prediction"
data.insts.append(inst) data.insts.append(inst)
data.id = data_ids[0]
return data return data
read_channel = Channel() read_channel = Channel(name="read_channel")
cnn_out_channel = Channel() combine_channel = Channel(name="combine_channel")
bow_out_channel = Channel() out_channel = Channel(name="out_channel")
combine_out_channel = Channel()
cnn_op = UciOp( cnn_op = UciOp(
name="cnn_op", name="cnn_op",
inputs=[read_channel], input=read_channel,
in_dtype='float', in_dtype='float',
outputs=[cnn_out_channel], outputs=[combine_channel],
out_dtype='float', out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9393", server_port="9393",
device="cpu", device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"]) fetch_names=["price"],
concurrency=2)
bow_op = UciOp( bow_op = UciOp(
name="bow_op", name="bow_op",
inputs=[read_channel], input=read_channel,
in_dtype='float', in_dtype='float',
outputs=[bow_out_channel], outputs=[combine_channel],
out_dtype='float', out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9292", server_port="9292",
device="cpu", device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"]) fetch_names=["price"],
concurrency=2)
combine_op = CombineOp( combine_op = CombineOp(
name="combine_op", name="combine_op",
inputs=[cnn_out_channel, bow_out_channel], input=combine_channel,
in_dtype='float', in_dtype='float',
outputs=[combine_out_channel], outputs=[out_channel],
out_dtype='float') out_dtype='float',
concurrency=2)
logging.info(read_channel.debug())
logging.info(combine_channel.debug())
logging.info(out_channel.debug())
pyserver = PyServer() pyserver = PyServer()
pyserver.add_channel(read_channel) pyserver.add_channel(read_channel)
pyserver.add_channel(cnn_out_channel) pyserver.add_channel(combine_channel)
pyserver.add_channel(bow_out_channel) pyserver.add_channel(out_channel)
pyserver.add_channel(combine_out_channel)
pyserver.add_op(cnn_op) pyserver.add_op(cnn_op)
pyserver.add_op(bow_op) pyserver.add_op(bow_op)
pyserver.add_op(combine_op) pyserver.add_op(combine_op)
......
...@@ -25,81 +25,202 @@ import general_python_service_pb2 ...@@ -25,81 +25,202 @@ import general_python_service_pb2
import general_python_service_pb2_grpc import general_python_service_pb2_grpc
import python_service_channel_pb2 import python_service_channel_pb2
import logging import logging
import random
import time import time
class Channel(Queue.Queue): class Channel(Queue.Queue):
def __init__(self, maxsize=-1, timeout=None, batchsize=1): """
The channel used for communication between Ops.
1. Support multiple different Op feed data (multiple producer)
Different types of data will be packaged through the data ID
2. Support multiple different Op fetch data (multiple consumer)
Only when all types of Ops get the data of the same ID,
the data will be poped; The Op of the same type will not
get the data of the same ID.
3. (TODO) Timeout and BatchSize are not fully supported.
Note:
1. The ID of the data in the channel must be different.
2. The function add_producer() and add_consumer() are not thread safe,
and can only be called during initialization.
"""
def __init__(self, name=None, maxsize=-1, timeout=None, batchsize=1):
Queue.Queue.__init__(self, maxsize=maxsize) Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout self._timeout = timeout
self._batchsize = batchsize self._name = name
self._pushlock = threading.Lock() #self._batchsize = batchsize
self._frontlock = threading.Lock() # self._pushbatch = []
self._pushbatch = []
self._consumer = {} # {op_name: idx} self._cv = threading.Condition()
self._producers = []
self._producer_res_count = {} # {data_id: count}
self._push_res = {} # {data_id: {op_name: data}}
self._front_wait_interval = 0.1 # second
self._consumers = {} # {op_name: idx}
self._idx_consumer_num = {} # {idx: num}
self._consumer_base_idx = 0 self._consumer_base_idx = 0
self._frontbatch = [] self._front_res = []
self._idx_consumer_num = {}
def get_producers(self):
return self._producers
def get_consumers(self):
return self._consumers.keys()
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
def debug(self):
return self._log("p: {}, c: {}".format(self.get_producers(),
self.get_consumers()))
def add_producer(self, op_name):
""" not thread safe, and can only be called during initialization """
if op_name in self._producers:
raise ValueError(
self._log("producer({}) is already in channel".format(op_name)))
self._producers.append(op_name)
def add_consumer(self, op_name): def add_consumer(self, op_name):
""" not thread safe """ """ not thread safe, and can only be called during initialization """
if op_name in self._consumer: if op_name in self._consumers:
raise ValueError("op_name({}) is already in channel".format( raise ValueError(
op_name)) self._log("consumer({}) is already in channel".format(op_name)))
self._consumer_id[op_name] = 0 self._consumers[op_name] = 0
if self._idx_consumer_num.get(0) is None: if self._idx_consumer_num.get(0) is None:
self._idx_consumer_num[0] = 0 self._idx_consumer_num[0] = 0
self._idx_consumer_num[0] += 1 self._idx_consumer_num[0] += 1
def push(self, item): def push(self, data, op_name=None):
with self._pushlock: logging.debug(
self._pushbatch.append(item) self._log("{} try to push data: {}".format(op_name, data)))
if len(self._pushbatch) == self._batchsize: if len(self._producers) == 0:
self.put(self._pushbatch, timeout=self._timeout)
self._pushbatch = []
def front(self, op_name):
if len(self._consumer) == 0:
raise Exception( raise Exception(
"expected number of consumers to be greater than 0, but the it is 0." self._log(
) "expected number of producers to be greater than 0, but the it is 0."
elif len(self._consumer) == 1: ))
return self.get(timeout=self._timeout) elif len(self._producers) == 1:
self._cv.acquire()
while True:
try:
self.put(data, timeout=0)
break
except Queue.Empty:
self._cv.wait()
self._cv.notify_all()
self._cv.release()
logging.debug(self._log("{} push data succ!".format(op_name)))
return True
elif op_name is None:
raise Exception(
self._log(
"There are multiple producers, so op_name cannot be None."))
with self._frontlock: producer_num = len(self._producers)
consumer_idx = self._consumer[op_name] data_id = data.id
base_idx = self._consumer_base_idx put_data = None
data_idx = consumer_idx - base_idx self._cv.acquire()
logging.debug(self._log("{} get lock ~".format(op_name)))
if data_id not in self._push_res:
self._push_res[data_id] = {name: None for name in self._producers}
self._producer_res_count[data_id] = 0
self._push_res[data_id][op_name] = data
if self._producer_res_count[data_id] + 1 == producer_num:
put_data = self._push_res[data_id]
self._push_res.pop(data_id)
self._producer_res_count.pop(data_id)
else:
self._producer_res_count[data_id] += 1
if data_idx >= len(self._frontbatch): if put_data is None:
batch_data = self.get(timeout=self._timeout) logging.debug(
self._frontbatch.append(batch_data) self._log("{} push data succ, not not push to queue.".format(
op_name)))
else:
while True:
try:
self.put(put_data, timeout=0)
break
except Queue.Empty:
self._cv.wait()
logging.debug(
self._log("multi | {} push data succ!".format(op_name)))
self._cv.notify_all()
self._cv.release()
return True
resp = self._frontbatch[data_idx] def front(self, op_name=None):
logging.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumers) == 0:
raise Exception(
self._log(
"expected number of consumers to be greater than 0, but the it is 0."
))
elif len(self._consumers) == 1:
self._cv.acquire()
resp = None
while resp is None:
try:
resp = self.get(timeout=0)
break
except Queue.Empty:
self._cv.wait()
logging.debug(self._log("{} get data succ!".format(op_name)))
return resp
elif op_name is None:
raise Exception(
self._log(
"There are multiple consumers, so op_name cannot be None."))
self._idx_consumer_num[consumer_idx] -= 1 self._cv.acquire()
if consumer_idx == base_idx and self._idx_consumer_num[ # data_idx = consumer_idx - base_idx
consumer_idx] == 0: while self._consumers[op_name] - self._consumer_base_idx >= len(
self._idx_consumer_num.pop(consumer_idx) self._front_res):
self._frontbatch.pop(0) try:
self._consumer_base_idx += 1 data = self.get(timeout=0)
self._front_res.append(data)
break
except Queue.Empty:
self._cv.wait()
consumer_idx = self._consumers[op_name]
base_idx = self._consumer_base_idx
data_idx = consumer_idx - base_idx
resp = self._front_res[data_idx]
logging.debug(self._log("{} get data: {}".format(op_name, resp)))
self._idx_consumer_num[consumer_idx] -= 1
if consumer_idx == base_idx and self._idx_consumer_num[
consumer_idx] == 0:
self._idx_consumer_num.pop(consumer_idx)
self._front_res.pop(0)
self._consumer_base_idx += 1
self._consumers[op_name] += 1
new_consumer_idx = self._consumers[op_name]
if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1
self._consumer[op_name] += 1 self._cv.notify_all()
new_consumer_idx = self._consumer[op_name] self._cv.release()
if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1
logging.debug(self._log("multi | {} get data succ!".format(op_name)))
return resp # reference, read only return resp # reference, read only
class Op(object): class Op(object):
def __init__(self, def __init__(self,
name, name,
inputs, input,
in_dtype, in_dtype,
outputs, outputs,
out_dtype, out_dtype,
...@@ -115,11 +236,11 @@ class Op(object): ...@@ -115,11 +236,11 @@ class Op(object):
# TODO: globally unique check # TODO: globally unique check
self._name = name # to identify the type of OP, it must be globally unique self._name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency self._concurrency = concurrency # amount of concurrency
self.set_inputs(inputs) self.set_input(input)
self._in_dtype = in_dtype self._in_dtype = in_dtype
self.set_outputs(outputs) self.set_outputs(outputs)
self._out_dtype = out_dtype self._out_dtype = out_dtype
self._batch_size = batchsize # self._batch_size = batchsize
self._client = None self._client = None
if client_config is not None and \ if client_config is not None and \
server_name is not None and \ server_name is not None and \
...@@ -128,7 +249,6 @@ class Op(object): ...@@ -128,7 +249,6 @@ class Op(object):
self._server_model = server_model self._server_model = server_model
self._server_port = server_port self._server_port = server_port
self._device = device self._device = device
self._data_ids = []
def set_client(self, client_config, server_name, fetch_names): def set_client(self, client_config, server_name, fetch_names):
self._client = Client() self._client = Client()
...@@ -139,59 +259,57 @@ class Op(object): ...@@ -139,59 +259,57 @@ class Op(object):
def with_serving(self): def with_serving(self):
return self._client is not None return self._client is not None
def get_inputs(self): def get_input(self):
return self._inputs return self._input
def set_inputs(self, channels): def set_input(self, channel):
if not isinstance(channels, list): if not isinstance(channel, Channel):
raise TypeError('channels must be list type') raise TypeError(
for channel in channels: self._log('input channel must be Channel type, not {}'.format(
channel.add_consumer(self._name) type(channel))))
self._inputs = channels channel.add_consumer(self._name)
self._input = channel
def get_outputs(self): def get_outputs(self):
return self._outputs return self._outputs
def set_outputs(self, channels): def set_outputs(self, channels):
if not isinstance(channels, list): if not isinstance(channels, list):
raise TypeError('channels must be list type') raise TypeError(
self._log('output channels must be list type, not {}'.format(
type(channels))))
for channel in channels:
channel.add_producer(self._name)
self._outputs = channels self._outputs = channels
def get_data_ids(self): def preprocess(self, data):
return self._data_ids if isinstance(data, dict):
def clear_data_ids(self):
self._data_ids = []
def append_id_to_data_ids(self, data_id):
self._data_ids.append(data_id)
def preprocess(self, input_data):
if len(input_data) != 1:
raise Exception( raise Exception(
'this Op has multiple previous channels. Please override this method' self._log(
) 'this Op has multiple previous inputs. Please override this method'
feed_batch = [] ))
self.clear_data_ids() feed = {}
for data in input_data: for inst in data.insts:
if len(data.insts) != self._batch_size: feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
raise Exception('len(data_insts) != self._batch_size') return feed
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
feed_batch.append(feed)
self.append_id_to_data_ids(data.id)
return feed_batch
def midprocess(self, data): def midprocess(self, data):
# data = preprocess(input), which must be a dict if not isinstance(data, dict):
logging.debug('data: {}'.format(data)) raise Exception(
logging.debug('fetch: {}'.format(self._fetch_names)) self._log(
'data must be dict type(the output of preprocess()), but get {}'.
format(type(data))))
logging.debug(self._log('data: {}'.format(data)))
logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names) fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
logging.debug(self._log("finish predict"))
return fetch_map return fetch_map
def postprocess(self, output_data): def postprocess(self, output_data):
return output_data raise Exception(
self._log(
'Please override this method to convert data to the format in channel.'
))
def stop(self): def stop(self):
self._run = False self._run = False
...@@ -199,22 +317,33 @@ class Op(object): ...@@ -199,22 +317,33 @@ class Op(object):
def start(self): def start(self):
self._run = True self._run = True
while self._run: while self._run:
input_data = [] input_data = self._input.front(self._name)
for channel in self._inputs: data_id = None
input_data.append(channel.front(self._name)) logging.debug(self._log("input_data: {}".format(input_data)))
if len(input_data) > 1: if isinstance(input_data, dict):
data = self.preprocess(input_data) key = input_data.keys()[0]
data_id = input_data[key].id
else: else:
data = self.preprocess(input_data[0]) data_id = input_data.id
data = self.preprocess(input_data)
if self.with_serving(): if self.with_serving():
fetch_map = self.midprocess(data) data = self.midprocess(data)
output_data = self.postprocess(fetch_map) output_data = self.postprocess(data)
else:
output_data = self.postprocess(data) if not isinstance(output_data,
python_service_channel_pb2.ChannelData):
raise TypeError(
self._log(
'output_data must be ChannelData type, but get {}'.
format(type(output_data))))
output_data.id = data_id
for channel in self._outputs: for channel in self._outputs:
channel.push(output_data) channel.push(output_data, self._name)
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
def get_concurrency(self): def get_concurrency(self):
return self._concurrency return self._concurrency
...@@ -224,8 +353,11 @@ class GeneralPythonService( ...@@ -224,8 +353,11 @@ class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService): general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel): def __init__(self, in_channel, out_channel):
super(GeneralPythonService, self).__init__() super(GeneralPythonService, self).__init__()
self._name = "__GeneralPythonService__"
self.set_in_channel(in_channel) self.set_in_channel(in_channel)
self.set_out_channel(out_channel) self.set_out_channel(out_channel)
logging.debug(self._log(in_channel.debug()))
logging.debug(self._log(out_channel.debug()))
#TODO: #TODO:
# multi-lock for different clients # multi-lock for different clients
# diffenert lock for server and client # diffenert lock for server and client
...@@ -236,29 +368,35 @@ class GeneralPythonService( ...@@ -236,29 +368,35 @@ class GeneralPythonService(
self._recive_func = threading.Thread( self._recive_func = threading.Thread(
target=GeneralPythonService._recive_out_channel_func, args=(self, )) target=GeneralPythonService._recive_out_channel_func, args=(self, ))
self._recive_func.start() self._recive_func.start()
logging.debug('succ init')
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
def set_in_channel(self, in_channel): def set_in_channel(self, in_channel):
if not isinstance(in_channel, Channel):
raise TypeError(
self._log('in_channel must be Channel type, but get {}'.format(
type(in_channel))))
in_channel.add_producer(self._name)
self._in_channel = in_channel self._in_channel = in_channel
def set_out_channel(self, out_channel): def set_out_channel(self, out_channel):
if isinstance(out_channel, list): if not isinstance(out_channel, Channel):
raise TypeError('out_channel can not be list type') raise TypeError(
out_channel.add_consumer("__GeneralPythonService__") self._log('out_channel must be Channel type, but get {}'.format(
type(out_channel))))
out_channel.add_consumer(self._name)
self._out_channel = out_channel self._out_channel = out_channel
def _recive_out_channel_func(self): def _recive_out_channel_func(self):
while True: while True:
data = self._out_channel.front() data = self._out_channel.front(self._name)
data_id = None if not isinstance(data, python_service_channel_pb2.ChannelData):
for d in data: raise TypeError(
if data_id is None: self._log('data must be ChannelData type, but get {}'.
data_id = d.id format(type(data))))
if data_id != d.id:
raise Exception("id not match: {} vs {}".format(data_id,
d.id))
self._cv.acquire() self._cv.acquire()
self._globel_resp_dict[data_id] = data self._globel_resp_dict[data.id] = data
self._cv.notify_all() self._cv.notify_all()
self._cv.release() self._cv.release()
...@@ -277,13 +415,14 @@ class GeneralPythonService( ...@@ -277,13 +415,14 @@ class GeneralPythonService(
return resp return resp
def _pack_data_for_infer(self, request): def _pack_data_for_infer(self, request):
logging.debug('start inferce') logging.debug(self._log('start inferce'))
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
data_id = self._get_next_id() data_id = self._get_next_id()
data.id = data_id data.id = data_id
for idx, name in enumerate(request.feed_var_names): for idx, name in enumerate(request.feed_var_names):
logging.debug('name: {}'.format(request.feed_var_names[idx])) logging.debug(
logging.debug('data: {}'.format(request.feed_insts[idx])) self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
inst.data = request.feed_insts[idx] inst.data = request.feed_insts[idx]
inst.name = name inst.name = name
...@@ -291,23 +430,22 @@ class GeneralPythonService( ...@@ -291,23 +430,22 @@ class GeneralPythonService(
return data, data_id return data, data_id
def _pack_data_for_resp(self, data): def _pack_data_for_resp(self, data):
data = data[0] #TODO batchsize = 1 logging.debug(self._log('get data'))
logging.debug('get data')
resp = general_python_service_pb2.Response() resp = general_python_service_pb2.Response()
logging.debug('gen resp') logging.debug(self._log('gen resp'))
logging.debug(data) logging.debug(data)
for inst in data.insts: for inst in data.insts:
logging.debug('append data') logging.debug(self._log('append data'))
resp.fetch_insts.append(inst.data) resp.fetch_insts.append(inst.data)
logging.debug('append name') logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name) resp.fetch_var_names.append(inst.name)
return resp return resp
def inference(self, request, context): def inference(self, request, context):
data, data_id = self._pack_data_for_infer(request) data, data_id = self._pack_data_for_infer(request)
logging.debug('push data') logging.debug(self._log('push data'))
self._in_channel.push(data) self._in_channel.push(data, self._name)
logging.debug('wait for infer') logging.debug(self._log('wait for infer'))
resp_data = None resp_data = None
resp_data = self._get_data_in_globel_resp_dict(data_id) resp_data = self._get_data_in_globel_resp_dict(data_id)
resp = self._pack_data_for_resp(resp_data) resp = self._pack_data_for_resp(resp_data)
...@@ -340,7 +478,7 @@ class PyServer(object): ...@@ -340,7 +478,7 @@ class PyServer(object):
inputs = set() inputs = set()
outputs = set() outputs = set()
for op in self._ops: for op in self._ops:
inputs |= set(op.get_inputs()) inputs |= set([op.get_input()])
outputs |= set(op.get_outputs()) outputs |= set(op.get_outputs())
if op.with_serving(): if op.with_serving():
self.prepare_serving(op) self.prepare_serving(op)
...@@ -360,6 +498,8 @@ class PyServer(object): ...@@ -360,6 +498,8 @@ class PyServer(object):
def _run_ops(self): def _run_ops(self):
for op in self._ops: for op in self._ops:
op_concurrency = op.get_concurrency() op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format(
op._name, op_concurrency))
for c in range(op_concurrency): for c in range(op_concurrency):
# th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, )) # th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, ))
th = threading.Thread( th = threading.Thread(
...@@ -387,13 +527,13 @@ class PyServer(object): ...@@ -387,13 +527,13 @@ class PyServer(object):
port = op._server_port port = op._server_port
device = op._device device = op._device
# run a server (not in PyServing)
if device == "cpu": if device == "cpu":
cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format( cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port) model_path, port)
else: else:
cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format( cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port) model_path, port)
logging.info(cmd) # run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd))
return return
# os.system(cmd) # os.system(cmd)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册