提交 5acbb589 编写于 作者: B barrierye

add default preprocess in pyserving

上级 f21134b0
......@@ -22,14 +22,6 @@ import python_service_channel_pb2
# channel data: {name(str): data(bytes)}
class ImdbOp(Op):
def preprocess(self, input_data):
data = input_data[0] # batchsize=1
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='int64')
# feed[inst.name] = np.frombuffer(inst.data)
return feed
def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
......@@ -58,14 +50,6 @@ class CombineOp(Op):
class UciOp(Op):
def preprocess(self, input_data):
data = input_data[0] # batchsize=1
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='float')
# feed[inst.name] = np.frombuffer(inst.data)
return feed
def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
......@@ -83,7 +67,9 @@ bow_out_channel = Channel()
combine_out_channel = Channel()
cnn_op = UciOp(
inputs=[read_channel],
in_dtype='float',
outputs=[cnn_out_channel],
out_dtype='float',
server_model="./uci_housing_model",
server_port="9393",
device="cpu",
......@@ -92,7 +78,9 @@ cnn_op = UciOp(
fetch_names=["price"])
bow_op = UciOp(
inputs=[read_channel],
in_dtype='float',
outputs=[bow_out_channel],
out_dtype='float',
server_model="./uci_housing_model",
server_port="9292",
device="cpu",
......@@ -120,7 +108,10 @@ bow_op = ImdbOp(
fetch_names=["acc", "cost", "prediction"])
'''
combine_op = CombineOp(
inputs=[cnn_out_channel, bow_out_channel], outputs=[combine_out_channel])
inputs=[cnn_out_channel, bow_out_channel],
in_dtype='float',
outputs=[combine_out_channel],
out_dtype='float')
pyserver = PyServer()
pyserver.add_channel(read_channel)
......
......@@ -19,7 +19,7 @@ import os
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
import numpy
import numpy as np
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
......@@ -63,7 +63,10 @@ class Channel(Queue.Queue):
class Op(object):
def __init__(self,
inputs,
in_dtype,
outputs,
out_dtype,
batchsize=1,
server_model=None,
server_port=None,
device=None,
......@@ -72,7 +75,10 @@ class Op(object):
fetch_names=None):
self._run = False
self.set_inputs(inputs)
self._in_dtype = in_dtype
self.set_outputs(outputs)
self._out_dtype = out_dtype
self._batch_size = batchsize
self._client = None
if client_config is not None and \
server_name is not None and \
......@@ -108,7 +114,19 @@ class Op(object):
self._outputs = channels
def preprocess(self, input_data):
return input_data
if len(input_data) != 1:
raise Exception(
'this Op has multiple previous channels. Please override this method'
)
feed_batch = []
for data in input_data:
if len(data.insts) != self._batch_size:
raise Exception('len(data_insts) != self._batch_size')
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
feed_batch.append(feed)
return feed_batch
def midprocess(self, data):
# data = preprocess(input), which must be a dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册