提交 a8635745 编写于 作者: B barrierye

add default preprocess in pyserving

上级 3b2fa3d5
...@@ -22,14 +22,6 @@ import python_service_channel_pb2 ...@@ -22,14 +22,6 @@ import python_service_channel_pb2
# channel data: {name(str): data(bytes)} # channel data: {name(str): data(bytes)}
class ImdbOp(Op): class ImdbOp(Op):
def preprocess(self, input_data):
data = input_data[0] # batchsize=1
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='int64')
# feed[inst.name] = np.frombuffer(inst.data)
return feed
def postprocess(self, output_data): def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
...@@ -58,14 +50,6 @@ class CombineOp(Op): ...@@ -58,14 +50,6 @@ class CombineOp(Op):
class UciOp(Op): class UciOp(Op):
def preprocess(self, input_data):
data = input_data[0] # batchsize=1
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='float')
# feed[inst.name] = np.frombuffer(inst.data)
return feed
def postprocess(self, output_data): def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
...@@ -83,7 +67,9 @@ bow_out_channel = Channel() ...@@ -83,7 +67,9 @@ bow_out_channel = Channel()
combine_out_channel = Channel() combine_out_channel = Channel()
cnn_op = UciOp( cnn_op = UciOp(
inputs=[read_channel], inputs=[read_channel],
in_dtype='float',
outputs=[cnn_out_channel], outputs=[cnn_out_channel],
out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9393", server_port="9393",
device="cpu", device="cpu",
...@@ -92,7 +78,9 @@ cnn_op = UciOp( ...@@ -92,7 +78,9 @@ cnn_op = UciOp(
fetch_names=["price"]) fetch_names=["price"])
bow_op = UciOp( bow_op = UciOp(
inputs=[read_channel], inputs=[read_channel],
in_dtype='float',
outputs=[bow_out_channel], outputs=[bow_out_channel],
out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9292", server_port="9292",
device="cpu", device="cpu",
...@@ -120,7 +108,10 @@ bow_op = ImdbOp( ...@@ -120,7 +108,10 @@ bow_op = ImdbOp(
fetch_names=["acc", "cost", "prediction"]) fetch_names=["acc", "cost", "prediction"])
''' '''
combine_op = CombineOp( combine_op = CombineOp(
inputs=[cnn_out_channel, bow_out_channel], outputs=[combine_out_channel]) inputs=[cnn_out_channel, bow_out_channel],
in_dtype='float',
outputs=[combine_out_channel],
out_dtype='float')
pyserver = PyServer() pyserver = PyServer()
pyserver.add_channel(read_channel) pyserver.add_channel(read_channel)
......
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
import paddle_serving_server import paddle_serving_server
from paddle_serving_client import Client from paddle_serving_client import Client
from concurrent import futures from concurrent import futures
import numpy import numpy as np
import grpc import grpc
import general_python_service_pb2 import general_python_service_pb2
import general_python_service_pb2_grpc import general_python_service_pb2_grpc
...@@ -63,7 +63,10 @@ class Channel(Queue.Queue): ...@@ -63,7 +63,10 @@ class Channel(Queue.Queue):
class Op(object): class Op(object):
def __init__(self, def __init__(self,
inputs, inputs,
in_dtype,
outputs, outputs,
out_dtype,
batchsize=1,
server_model=None, server_model=None,
server_port=None, server_port=None,
device=None, device=None,
...@@ -72,7 +75,10 @@ class Op(object): ...@@ -72,7 +75,10 @@ class Op(object):
fetch_names=None): fetch_names=None):
self._run = False self._run = False
self.set_inputs(inputs) self.set_inputs(inputs)
self._in_dtype = in_dtype
self.set_outputs(outputs) self.set_outputs(outputs)
self._out_dtype = out_dtype
self._batch_size = batchsize
self._client = None self._client = None
if client_config is not None and \ if client_config is not None and \
server_name is not None and \ server_name is not None and \
...@@ -108,7 +114,19 @@ class Op(object): ...@@ -108,7 +114,19 @@ class Op(object):
self._outputs = channels self._outputs = channels
def preprocess(self, input_data): def preprocess(self, input_data):
return input_data if len(input_data) != 1:
raise Exception(
'this Op has multiple previous channels. Please override this method'
)
feed_batch = []
for data in input_data:
if len(data.insts) != self._batch_size:
raise Exception('len(data_insts) != self._batch_size')
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
feed_batch.append(feed)
return feed_batch
def midprocess(self, data): def midprocess(self, data):
# data = preprocess(input), which must be a dict # data = preprocess(input), which must be a dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册