提交 a5e0a6c1 编写于 作者: B barrierye

Support multiple different Op feed data or fetch data

上级 da0949c4
...@@ -43,8 +43,9 @@ data = np.ndarray.tobytes(x) ...@@ -43,8 +43,9 @@ data = np.ndarray.tobytes(x)
req.feed_var_names.append("x") req.feed_var_names.append("x")
req.feed_insts.append(data) req.feed_insts.append(data)
resp = stub.inference(req) for i in range(100):
for idx, name in enumerate(resp.fetch_var_names): resp = stub.inference(req)
print('{}: {}'.format( for idx, name in enumerate(resp.fetch_var_names):
name, np.frombuffer( print('{}: {}'.format(
resp.fetch_insts[idx], dtype='float'))) name, np.frombuffer(
resp.fetch_insts[idx], dtype='float')))
...@@ -29,83 +29,80 @@ logging.basicConfig( ...@@ -29,83 +29,80 @@ logging.basicConfig(
class CombineOp(Op): class CombineOp(Op):
#TODO: different id of data
def preprocess(self, input_data): def preprocess(self, input_data):
data_id = None
cnt = 0 cnt = 0
for input in input_data: for op_name, data in input_data.items():
data = input[0] # batchsize=1 logging.debug("CombineOp preprocess: {}".format(op_name))
cnt += np.frombuffer(data.insts[0].data, dtype='float') cnt += np.frombuffer(data.insts[0].data, dtype='float')
if data_id is None:
data_id = data.id
if data_id != data.id:
raise Exception("id not match: {} vs {}".format(data_id,
data.id))
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
inst.data = np.ndarray.tobytes(cnt) inst.data = np.ndarray.tobytes(cnt)
inst.name = "resp" inst.name = "resp"
data.insts.append(inst) data.insts.append(inst)
data.id = data_id
return data return data
def postprocess(self, output_data):
return output_data
class UciOp(Op): class UciOp(Op):
def postprocess(self, output_data): def postprocess(self, output_data):
data_ids = self.get_data_ids()
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["price"][0][0], dtype='float') pred = np.array(output_data["price"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred) inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction" inst.name = "prediction"
data.insts.append(inst) data.insts.append(inst)
data.id = data_ids[0]
return data return data
read_channel = Channel() read_channel = Channel(name="read_channel")
cnn_out_channel = Channel() combine_channel = Channel(name="combine_channel")
bow_out_channel = Channel() out_channel = Channel(name="out_channel")
combine_out_channel = Channel()
cnn_op = UciOp( cnn_op = UciOp(
name="cnn_op", name="cnn_op",
inputs=[read_channel], input=read_channel,
in_dtype='float', in_dtype='float',
outputs=[cnn_out_channel], outputs=[combine_channel],
out_dtype='float', out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9393", server_port="9393",
device="cpu", device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"]) fetch_names=["price"],
concurrency=2)
bow_op = UciOp( bow_op = UciOp(
name="bow_op", name="bow_op",
inputs=[read_channel], input=read_channel,
in_dtype='float', in_dtype='float',
outputs=[bow_out_channel], outputs=[combine_channel],
out_dtype='float', out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9292", server_port="9292",
device="cpu", device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"]) fetch_names=["price"],
concurrency=2)
combine_op = CombineOp( combine_op = CombineOp(
name="combine_op", name="combine_op",
inputs=[cnn_out_channel, bow_out_channel], input=combine_channel,
in_dtype='float', in_dtype='float',
outputs=[combine_out_channel], outputs=[out_channel],
out_dtype='float') out_dtype='float',
concurrency=2)
logging.info(read_channel.debug())
logging.info(combine_channel.debug())
logging.info(out_channel.debug())
pyserver = PyServer() pyserver = PyServer()
pyserver.add_channel(read_channel) pyserver.add_channel(read_channel)
pyserver.add_channel(cnn_out_channel) pyserver.add_channel(combine_channel)
pyserver.add_channel(bow_out_channel) pyserver.add_channel(out_channel)
pyserver.add_channel(combine_out_channel)
pyserver.add_op(cnn_op) pyserver.add_op(cnn_op)
pyserver.add_op(bow_op) pyserver.add_op(bow_op)
pyserver.add_op(combine_op) pyserver.add_op(combine_op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册