提交 9428785b 编写于 作者: B barrierye

update proto && PredictorRes in client

上级 99ed075f
......@@ -32,7 +32,7 @@ message Request {
repeated FeedInst insts = 1;
repeated string feed_var_names = 2;
repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = true ];
required bool is_python = 4 [ default = false ];
};
message Response {
......
......@@ -21,7 +21,6 @@ import google.protobuf.text_format
import numpy as np
import time
import sys
from .serving_client import PredictorRes
import grpc
from .proto import multi_lang_general_model_service_pb2
......@@ -129,6 +128,8 @@ class Client(object):
self.all_numpy_input = True
self.has_numpy_input = False
self.rpc_timeout_ms = 20000
from .serving_client import PredictorRes
self.predictorres_constructor = PredictorRes
def load_client_config(self, path):
from .serving_client import PredictorClient
......@@ -308,7 +309,7 @@ class Client(object):
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
result_batch_handle = PredictorRes()
result_batch_handle = self.predictorres_constructor()
if self.all_numpy_input:
res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch,
......@@ -495,9 +496,11 @@ class MultiLangClient(object):
raise Exception("error type.")
else:
if v_type == 0: # int64
result_map[name] = np.array(list(var.int64_data))
result_map[name] = np.array(
list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
result_map[name] = np.array(list(var.float_data))
result_map[name] = np.array(
list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
result_map[name].shape = list(var.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册