提交 9428785b 编写于 作者: B barrierye

update proto && PredictorRes in client

上级 99ed075f
...@@ -32,7 +32,7 @@ message Request { ...@@ -32,7 +32,7 @@ message Request {
repeated FeedInst insts = 1; repeated FeedInst insts = 1;
repeated string feed_var_names = 2; repeated string feed_var_names = 2;
repeated string fetch_var_names = 3; repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = true ]; required bool is_python = 4 [ default = false ];
}; };
message Response { message Response {
......
...@@ -21,7 +21,6 @@ import google.protobuf.text_format ...@@ -21,7 +21,6 @@ import google.protobuf.text_format
import numpy as np import numpy as np
import time import time
import sys import sys
from .serving_client import PredictorRes
import grpc import grpc
from .proto import multi_lang_general_model_service_pb2 from .proto import multi_lang_general_model_service_pb2
...@@ -129,6 +128,8 @@ class Client(object): ...@@ -129,6 +128,8 @@ class Client(object):
self.all_numpy_input = True self.all_numpy_input = True
self.has_numpy_input = False self.has_numpy_input = False
self.rpc_timeout_ms = 20000 self.rpc_timeout_ms = 20000
from .serving_client import PredictorRes
self.predictorres_constructor = PredictorRes
def load_client_config(self, path): def load_client_config(self, path):
from .serving_client import PredictorClient from .serving_client import PredictorClient
...@@ -308,7 +309,7 @@ class Client(object): ...@@ -308,7 +309,7 @@ class Client(object):
self.profile_.record('py_prepro_1') self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0') self.profile_.record('py_client_infer_0')
result_batch_handle = PredictorRes() result_batch_handle = self.predictorres_constructor()
if self.all_numpy_input: if self.all_numpy_input:
res = self.client_handle_.numpy_predict( res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
...@@ -495,9 +496,11 @@ class MultiLangClient(object): ...@@ -495,9 +496,11 @@ class MultiLangClient(object):
raise Exception("error type.") raise Exception("error type.")
else: else:
if v_type == 0: # int64 if v_type == 0: # int64
result_map[name] = np.array(list(var.int64_data)) result_map[name] = np.array(
list(var.int64_data), dtype="int64")
elif v_type == 1: # float32 elif v_type == 1: # float32
result_map[name] = np.array(list(var.float_data)) result_map[name] = np.array(
list(var.float_data), dtype="float32")
else: else:
raise Exception("error type.") raise Exception("error type.")
result_map[name].shape = list(var.shape) result_map[name].shape = list(var.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册