提交 653b33b9 编写于 作者: G guru4elephant

update python client api

上级 b87806d8
......@@ -13,6 +13,51 @@
# limitations under the License.
from .serving_client import PredictorClient
from ..proto import sdk_configure_pb2 as sdk
import time
class SDKConfig(object):
def __init__(self):
self.sdk_desc = sdk.SDKConf()
self.endpoints = []
def set_server_endpoints(self, endpoints):
self.endpoints = endpoints
def gen_desc(self):
predictor_desc = sdk.Predictor()
predictor_desc.name = "general_model"
predictor_desc.service_name = \
"baidu.paddle_serving.predictor.general_model.GeneralModelService"
predictor_desc.endpoint_router = "WeightedRandomRender"
predictor_desc.weighted_random_render_conf.variant_weight_list = "30"
variant_desc = sdk.VariantConf()
variant_desc.tag = "var1"
variant_desc.naming_conf.cluster = "list://%s".format(":".join(self.endpoints))
predictor_desc.variants.extend([variant_desc])
self.sdk_desc.predictors.extend([predictor_desc])
self.sdk_desc.default_variant_conf.tag = "default"
self.sdk_desc.default_variant_conf.connection_conf.connect_timeout_ms = 2000
self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = 20000
self.sdk_desc.default_variant_conf.connection_conf.connect_retry_count = 2
self.sdk_desc.default_variant_conf.connection_conf.max_connection_per_host = 100
self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1
self.sdk_desc.default_variant_conf.connection_conf.hedge_fetch_retry_count = 2
self.sdk_desc.default_variant_conf.connection_conf.connection_type = "pooled"
self.sdk_desc.default_variant_conf.naming_conf.cluster_filter_strategy = "Default"
self.sdk_desc.default_variant_conf.naming_conf.load_balance_strategy = "la"
self.sdk_desc.default_variant_conf.rpc_parameter.compress_type = 0
self.sdk_desc.default_variant_conf.rpc_parameter.package_size = 20
self.sdk_desc.default_variant_conf.rpc_parameter.protocol = "baidu_std"
self.sdk_desc.default_variant_conf.rpc_parameter.max_channel_per_request = 3
return str(self.sdk_desc)
class Client(object):
def __init__(self):
......@@ -28,13 +73,26 @@ class Client(object):
# get feed vars, fetch vars
# get feed shapes, feed types
# map feed names to index
self.client_handle_ = PredictorClient()
self.client_handle_.init(path)
return
def connect(self, endpoint):
def connect(self, endpoints):
# check whether current endpoint is available
# init from client config
# create predictor here
return
predictor_sdk = SDKConfig()
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
print(sdk_desc)
timestamp = time.asctime(time.localtime(time.time()))
predictor_path = "/tmp/"
predictor_file = "%s_predictor.conf" % timestamp
with open(predictor_path + predictor_file, "w") as fout:
fout.write(sdk_desc)
self.client_handle_.set_predictor_conf(
predictor_path, predictor_file)
self.client_handle_.create_predictor()
def get_feed_names(self):
return self.feed_names_
......
......@@ -33,10 +33,13 @@ REQUIRED_PACKAGES = [
]
packages=['paddle_serving',
'paddle_serving.serving_client']
'paddle_serving.serving_client',
'paddle_serving.proto']
package_data={'paddle_serving.serving_client': ['serving_client.so']}
package_dir={'paddle_serving.serving_client':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving/serving_client'}
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving/serving_client',
'paddle_serving.proto':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving/proto'}
setup(
name='paddle-serving-client',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册