提交 a03210b8 编写于 作者: M MRXLT

add set_rpc_timeout_ms function

上级 d9453185
...@@ -61,13 +61,18 @@ class SDKConfig(object): ...@@ -61,13 +61,18 @@ class SDKConfig(object):
self.tag_list = [] self.tag_list = []
self.cluster_list = [] self.cluster_list = []
self.variant_weight_list = [] self.variant_weight_list = []
self.rpc_timeout_ms = 20000
self.load_balance_strategy = "la"
def add_server_variant(self, tag, cluster, variant_weight): def add_server_variant(self, tag, cluster, variant_weight):
self.tag_list.append(tag) self.tag_list.append(tag)
self.cluster_list.append(cluster) self.cluster_list.append(cluster)
self.variant_weight_list.append(variant_weight) self.variant_weight_list.append(variant_weight)
def gen_desc(self): def set_load_banlance_strategy(self, strategy):
self.load_balance_strategy = strategy
def gen_desc(self, rpc_timeout_ms):
predictor_desc = sdk.Predictor() predictor_desc = sdk.Predictor()
predictor_desc.name = "general_model" predictor_desc.name = "general_model"
predictor_desc.service_name = \ predictor_desc.service_name = \
...@@ -86,7 +91,7 @@ class SDKConfig(object): ...@@ -86,7 +91,7 @@ class SDKConfig(object):
self.sdk_desc.predictors.extend([predictor_desc]) self.sdk_desc.predictors.extend([predictor_desc])
self.sdk_desc.default_variant_conf.tag = "default" self.sdk_desc.default_variant_conf.tag = "default"
self.sdk_desc.default_variant_conf.connection_conf.connect_timeout_ms = 2000 self.sdk_desc.default_variant_conf.connection_conf.connect_timeout_ms = 2000
self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = 20000 self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = rpc_timeout_ms
self.sdk_desc.default_variant_conf.connection_conf.connect_retry_count = 2 self.sdk_desc.default_variant_conf.connection_conf.connect_retry_count = 2
self.sdk_desc.default_variant_conf.connection_conf.max_connection_per_host = 100 self.sdk_desc.default_variant_conf.connection_conf.max_connection_per_host = 100
self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1 self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1
...@@ -119,6 +124,7 @@ class Client(object): ...@@ -119,6 +124,7 @@ class Client(object):
self.profile_ = _Profiler() self.profile_ = _Profiler()
self.all_numpy_input = True self.all_numpy_input = True
self.has_numpy_input = False self.has_numpy_input = False
self.rpc_timeout_ms = 20000
def load_client_config(self, path): def load_client_config(self, path):
from .serving_client import PredictorClient from .serving_client import PredictorClient
...@@ -171,6 +177,12 @@ class Client(object): ...@@ -171,6 +177,12 @@ class Client(object):
self.predictor_sdk_.add_server_variant(tag, cluster, self.predictor_sdk_.add_server_variant(tag, cluster,
str(variant_weight)) str(variant_weight))
def set_rpc_timeout_ms(self, rpc_timeout):
if not isinstance(rpc_timeout, int):
raise ValueError("rpc_timeout must be int type.")
else:
self.rpc_timeout_ms = rpc_timeout
def connect(self, endpoints=None): def connect(self, endpoints=None):
# check whether current endpoint is available # check whether current endpoint is available
# init from client config # init from client config
...@@ -188,7 +200,7 @@ class Client(object): ...@@ -188,7 +200,7 @@ class Client(object):
print( print(
"parameter endpoints({}) will not take effect, because you use the add_variant function.". "parameter endpoints({}) will not take effect, because you use the add_variant function.".
format(endpoints)) format(endpoints))
sdk_desc = self.predictor_sdk_.gen_desc() sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString( self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
)) ))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册