提交 d2f4fdff 编写于 作者: B barrierye

add default postprocess func

上级 937e7719
...@@ -34,52 +34,40 @@ class CombineOp(Op): ...@@ -34,52 +34,40 @@ class CombineOp(Op):
for op_name, channeldata in input_data.items(): for op_name, channeldata in input_data.items():
logging.debug("CombineOp preprocess: {}".format(op_name)) logging.debug("CombineOp preprocess: {}".format(op_name))
data = channeldata.parse() data = channeldata.parse()
cnt += data["prediction"] cnt += data["price"]
data = {"combine_op_output": cnt} data = {"combine_op_output": cnt}
return data return data
def postprocess(self, output_data):
return output_data
class UciOp(Op):
def postprocess(self, output_data):
pred = np.array(output_data["price"][0][0], dtype='float32')
data = {"prediction": pred}
return data
read_channel = Channel(name="read_channel") read_channel = Channel(name="read_channel")
combine_channel = Channel(name="combine_channel") combine_channel = Channel(name="combine_channel")
out_channel = Channel(name="out_channel") out_channel = Channel(name="out_channel")
uci1_op = UciOp( uci1_op = Op(name="uci1",
name="uci1", input=read_channel,
input=read_channel, outputs=[combine_channel],
outputs=[combine_channel], server_model="./uci_housing_model",
server_model="./uci_housing_model", server_port="9393",
server_port="9393", device="cpu",
device="cpu", client_config="uci_housing_client/serving_client_conf.prototxt",
client_config="uci_housing_client/serving_client_conf.prototxt", server_name="127.0.0.1:9393",
server_name="127.0.0.1:9393", fetch_names=["price"],
fetch_names=["price"], concurrency=1,
concurrency=1, timeout=0.1,
timeout=0.1, retry=2)
retry=2)
uci2_op = UciOp( uci2_op = Op(name="uci2",
name="uci2", input=read_channel,
input=read_channel, outputs=[combine_channel],
outputs=[combine_channel], server_model="./uci_housing_model",
server_model="./uci_housing_model", server_port="9292",
server_port="9292", device="cpu",
device="cpu", client_config="uci_housing_client/serving_client_conf.prototxt",
client_config="uci_housing_client/serving_client_conf.prototxt", server_name="127.0.0.1:9393",
server_name="127.0.0.1:9393", fetch_names=["price"],
fetch_names=["price"], concurrency=1,
concurrency=1, timeout=-1,
timeout=-1, retry=1)
retry=1)
combine_op = CombineOp( combine_op = CombineOp(
name="combine", name="combine",
......
...@@ -396,11 +396,7 @@ class Op(object): ...@@ -396,11 +396,7 @@ class Op(object):
return call_future return call_future
def postprocess(self, output_data): def postprocess(self, output_data):
raise Exception( return output_data
self._log(
'Please override this method to convert data to the format in channel.' \
' The return value format should be in {name(str): var(narray)}'
))
def errorprocess(self, error_info, data_id): def errorprocess(self, error_info, data_id):
data = channel_pb2.ChannelData() data = channel_pb2.ChannelData()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册