diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index e3302c14239c8bfc37a6bafb39b112cfed5230fd..f2922f577b21d8acc3f8ec629f2473b5339ee725 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -61,13 +61,18 @@ class SDKConfig(object): self.tag_list = [] self.cluster_list = [] self.variant_weight_list = [] + self.rpc_timeout_ms = 20000 + self.load_balance_strategy = "la" def add_server_variant(self, tag, cluster, variant_weight): self.tag_list.append(tag) self.cluster_list.append(cluster) self.variant_weight_list.append(variant_weight) - def gen_desc(self): + def set_load_banlance_strategy(self, strategy): + self.load_balance_strategy = strategy + + def gen_desc(self, rpc_timeout_ms): predictor_desc = sdk.Predictor() predictor_desc.name = "general_model" predictor_desc.service_name = \ @@ -86,7 +91,7 @@ class SDKConfig(object): self.sdk_desc.predictors.extend([predictor_desc]) self.sdk_desc.default_variant_conf.tag = "default" self.sdk_desc.default_variant_conf.connection_conf.connect_timeout_ms = 2000 - self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = 20000 + self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = rpc_timeout_ms self.sdk_desc.default_variant_conf.connection_conf.connect_retry_count = 2 self.sdk_desc.default_variant_conf.connection_conf.max_connection_per_host = 100 self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1 @@ -119,6 +124,7 @@ class Client(object): self.profile_ = _Profiler() self.all_numpy_input = True self.has_numpy_input = False + self.rpc_timeout_ms = 20000 def load_client_config(self, path): from .serving_client import PredictorClient @@ -171,6 +177,12 @@ class Client(object): self.predictor_sdk_.add_server_variant(tag, cluster, str(variant_weight)) + def set_rpc_timeout_ms(self, rpc_timeout): + if not isinstance(rpc_timeout, int): + raise ValueError("rpc_timeout must be int type.") + else: + self.rpc_timeout_ms = rpc_timeout + def connect(self, endpoints=None): # check whether current endpoint is available # init from client config @@ -188,7 +200,7 @@ class Client(object): print( "parameter endpoints({}) will not take effect, because you use the add_variant function.". format(endpoints)) - sdk_desc = self.predictor_sdk_.gen_desc() + sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms) self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString( ))