提交 a03210b8 编写于 作者: M MRXLT

add set_rpc_timeout_ms function

上级 d9453185
......@@ -61,13 +61,18 @@ class SDKConfig(object):
self.tag_list = []
self.cluster_list = []
self.variant_weight_list = []
self.rpc_timeout_ms = 20000
self.load_balance_strategy = "la"
def add_server_variant(self, tag, cluster, variant_weight):
self.tag_list.append(tag)
self.cluster_list.append(cluster)
self.variant_weight_list.append(variant_weight)
def gen_desc(self):
def set_load_banlance_strategy(self, strategy):
self.load_balance_strategy = strategy
def gen_desc(self, rpc_timeout_ms):
predictor_desc = sdk.Predictor()
predictor_desc.name = "general_model"
predictor_desc.service_name = \
......@@ -86,7 +91,7 @@ class SDKConfig(object):
self.sdk_desc.predictors.extend([predictor_desc])
self.sdk_desc.default_variant_conf.tag = "default"
self.sdk_desc.default_variant_conf.connection_conf.connect_timeout_ms = 2000
self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = 20000
self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = rpc_timeout_ms
self.sdk_desc.default_variant_conf.connection_conf.connect_retry_count = 2
self.sdk_desc.default_variant_conf.connection_conf.max_connection_per_host = 100
self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1
......@@ -119,6 +124,7 @@ class Client(object):
self.profile_ = _Profiler()
self.all_numpy_input = True
self.has_numpy_input = False
self.rpc_timeout_ms = 20000
def load_client_config(self, path):
from .serving_client import PredictorClient
......@@ -171,6 +177,12 @@ class Client(object):
self.predictor_sdk_.add_server_variant(tag, cluster,
str(variant_weight))
def set_rpc_timeout_ms(self, rpc_timeout):
if not isinstance(rpc_timeout, int):
raise ValueError("rpc_timeout must be int type.")
else:
self.rpc_timeout_ms = rpc_timeout
def connect(self, endpoints=None):
# check whether current endpoint is available
# init from client config
......@@ -188,7 +200,7 @@ class Client(object):
print(
"parameter endpoints({}) will not take effect, because you use the add_variant function.".
format(endpoints))
sdk_desc = self.predictor_sdk_.gen_desc()
sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册