提交 51bd1035 编写于 作者: B barrierye

recover client.connect() function

上级 0afd7195
...@@ -41,9 +41,7 @@ def single_func(idx, resource): ...@@ -41,9 +41,7 @@ def single_func(idx, resource):
fetch = ["pooled_output"] fetch = ["pooled_output"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant( client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
start = time.time() start = time.time()
for i in range(1000): for i in range(1000):
......
...@@ -40,9 +40,7 @@ def single_func(idx, resource): ...@@ -40,9 +40,7 @@ def single_func(idx, resource):
fetch = ["pooled_output"] fetch = ["pooled_output"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant( client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
feed_batch = [] feed_batch = []
for bi in range(args.batch_size): for bi in range(args.batch_size):
feed_batch.append(reader.process(dataset[bi])) feed_batch.append(reader.process(dataset[bi]))
......
...@@ -33,8 +33,7 @@ fetch = ["pooled_output"] ...@@ -33,8 +33,7 @@ fetch = ["pooled_output"]
endpoint_list = ["127.0.0.1:9494"] endpoint_list = ["127.0.0.1:9494"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant("var1", endpoint_list, 50) client.connect(endpoint_list)
client.connect()
for line in fin: for line in fin:
feed_dict = reader.process(line) feed_dict = reader.process(line)
......
...@@ -43,9 +43,7 @@ def single_func(idx, resource): ...@@ -43,9 +43,7 @@ def single_func(idx, resource):
fetch = ["prob"] fetch = ["prob"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant( client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
start = time.time() start = time.time()
for i in range(1000): for i in range(1000):
......
...@@ -43,9 +43,7 @@ def single_func(idx, resource): ...@@ -43,9 +43,7 @@ def single_func(idx, resource):
fetch = ["prob"] fetch = ["prob"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant( client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
start = time.time() start = time.time()
for i in range(1000): for i in range(1000):
......
...@@ -23,8 +23,7 @@ from paddle_serving_client.metric import auc ...@@ -23,8 +23,7 @@ from paddle_serving_client.metric import auc
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9292"], 50) client.connect(["127.0.0.1:9292"])
client.connect()
batch = 1 batch = 1
buf_size = 100 buf_size = 100
......
...@@ -30,8 +30,7 @@ args = benchmark_args() ...@@ -30,8 +30,7 @@ args = benchmark_args()
def single_func(idx, resource): def single_func(idx, resource):
client = Client() client = Client()
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt') client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.add_variant("var1", ['127.0.0.1:9292'], 50) client.connect(['127.0.0.1:9292'])
client.connect()
batch = 1 batch = 1
buf_size = 100 buf_size = 100
dataset = criteo.CriteoDataset() dataset = criteo.CriteoDataset()
......
...@@ -31,8 +31,7 @@ def single_func(idx, resource): ...@@ -31,8 +31,7 @@ def single_func(idx, resource):
client = Client() client = Client()
print([resource["endpoint"][idx % len(resource["endpoint"])]]) print([resource["endpoint"][idx % len(resource["endpoint"])]])
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt') client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.add_variant("var1", ['127.0.0.1:9292'], 50) client.connect(['127.0.0.1:9292'])
client.connect()
batch = 1 batch = 1
buf_size = 100 buf_size = 100
dataset = criteo.CriteoDataset() dataset = criteo.CriteoDataset()
......
...@@ -22,8 +22,7 @@ from paddle_serving_client.metric import auc ...@@ -22,8 +22,7 @@ from paddle_serving_client.metric import auc
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.add_variant("var1", ['127.0.0.1:9292'], 50) client.connect(['127.0.0.1:9292'])
client.connect()
batch = 1 batch = 1
buf_size = 100 buf_size = 100
......
...@@ -28,8 +28,7 @@ def single_func(idx, resource): ...@@ -28,8 +28,7 @@ def single_func(idx, resource):
if args.request == "rpc": if args.request == "rpc":
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant("var1", [args.endpoint], 50) client.connect([args.endpoint])
client.connect()
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500), paddle.dataset.uci_housing.train(), buf_size=500),
......
...@@ -18,8 +18,7 @@ import sys ...@@ -18,8 +18,7 @@ import sys
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9393"], 50) client.connect(["127.0.0.1:9393"])
client.connect()
import paddle import paddle
test_reader = paddle.batch( test_reader = paddle.batch(
......
...@@ -36,9 +36,7 @@ def single_func(idx, resource): ...@@ -36,9 +36,7 @@ def single_func(idx, resource):
fetch = ["score"] fetch = ["score"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant( client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
start = time.time() start = time.time()
for i in range(1000): for i in range(1000):
......
...@@ -41,9 +41,7 @@ def single_func(idx, resource): ...@@ -41,9 +41,7 @@ def single_func(idx, resource):
fetch = ["score"] fetch = ["score"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant( client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
start = time.time() start = time.time()
for i in range(1000): for i in range(1000):
if args.batch_size >= 1: if args.batch_size >= 1:
......
...@@ -19,8 +19,7 @@ import time ...@@ -19,8 +19,7 @@ import time
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9295"], 50) client.connect(["127.0.0.1:9295"])
client.connect()
reader = ImageReader() reader = ImageReader()
start = time.time() start = time.time()
......
...@@ -35,8 +35,7 @@ def single_func(idx, resource): ...@@ -35,8 +35,7 @@ def single_func(idx, resource):
if args.request == "rpc": if args.request == "rpc":
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant("var1", [args.endpoint], 50) client.connect([args.endpoint])
client.connect()
for i in range(1000): for i in range(1000):
if args.batch_size >= 1: if args.batch_size >= 1:
feed_batch = [] feed_batch = []
......
...@@ -18,8 +18,7 @@ import sys ...@@ -18,8 +18,7 @@ import sys
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9292"], 50) client.connect(["127.0.0.1:9292"])
client.connect()
# you can define any english sentence or dataset here # you can define any english sentence or dataset here
# This example reuses imdb reader in training, you # This example reuses imdb reader in training, you
......
...@@ -23,8 +23,7 @@ import time ...@@ -23,8 +23,7 @@ import time
def batch_predict(batch_size=4): def batch_predict(batch_size=4):
client = Client() client = Client()
client.load_client_config(conf_file) client.load_client_config(conf_file)
client.add_variant("var1", ["127.0.0.1:9292"], 50) client.connect(["127.0.0.1:9292"])
client.connect()
fetch = ["acc", "cost", "prediction"] fetch = ["acc", "cost", "prediction"]
feed_batch = [] feed_batch = []
for line in sys.stdin: for line in sys.stdin:
......
...@@ -30,8 +30,7 @@ def single_func(idx, resource): ...@@ -30,8 +30,7 @@ def single_func(idx, resource):
if args.request == "rpc": if args.request == "rpc":
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.add_variant("var1", [args.endpoint], 50) client.connect([args.endpoint])
client.connect()
fin = open("jieba_test.txt") fin = open("jieba_test.txt")
for line in fin: for line in fin:
feed_data = reader.process(line) feed_data = reader.process(line)
......
...@@ -22,8 +22,7 @@ import io ...@@ -22,8 +22,7 @@ import io
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9280"], 50) client.connect(["127.0.0.1:9280"])
client.connect()
reader = LACReader(sys.argv[2]) reader = LACReader(sys.argv[2])
for line in sys.stdin: for line in sys.stdin:
......
...@@ -85,7 +85,7 @@ class Client(object): ...@@ -85,7 +85,7 @@ class Client(object):
self.feed_names_to_idx_ = {} self.feed_names_to_idx_ = {}
self.rpath() self.rpath()
self.pid = os.getpid() self.pid = os.getpid()
self.predictor_sdk_ = SDKConfig() self.predictor_sdk_ = None
def rpath(self): def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__) lib_path = os.path.dirname(paddle_serving_client.__file__)
...@@ -138,13 +138,27 @@ class Client(object): ...@@ -138,13 +138,27 @@ class Client(object):
return return
def add_variant(self, tag, cluster, variant_weight): def add_variant(self, tag, cluster, variant_weight):
if self.predictor_sdk_ is None:
self.predictor_sdk_ = SDKConfig()
self.predictor_sdk_.add_server_variant(tag, cluster, self.predictor_sdk_.add_server_variant(tag, cluster,
str(variant_weight)) str(variant_weight))
def connect(self): def connect(self, endpoints=None):
# check whether current endpoint is available # check whether current endpoint is available
# init from client config # init from client config
# create predictor here # create predictor here
if endpoints is None:
if self.predictor_sdk_ is None:
raise SystemExit(
"You must set the endpoints parameter or use add_variant function to create a variant."
)
else:
if self.predictor_sdk_ is None:
self.add_variant('var1', endpoints, 100)
else:
print(
"endpoints({}) will not be enabled because you use the add_variant function.".
format(endpoints))
sdk_desc = self.predictor_sdk_.gen_desc() sdk_desc = self.predictor_sdk_.gen_desc()
print(sdk_desc) print(sdk_desc)
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString( self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
......
...@@ -54,9 +54,7 @@ class WebService(object): ...@@ -54,9 +54,7 @@ class WebService(object):
client_service = Client() client_service = Client()
client_service.load_client_config( client_service.load_client_config(
"{}/serving_server_conf.prototxt".format(self.model_config)) "{}/serving_server_conf.prototxt".format(self.model_config))
client_service.add_variant("var1", client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
["0.0.0.0:{}".format(self.port + 1)], 100)
client_service.connect()
service_name = "/" + self.name + "/prediction" service_name = "/" + self.name + "/prediction"
@app_instance.route(service_name, methods=['POST']) @app_instance.route(service_name, methods=['POST'])
......
...@@ -91,8 +91,7 @@ class WebService(object): ...@@ -91,8 +91,7 @@ class WebService(object):
client = Client() client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format( client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config)) self.model_config))
client.add_variant("var1", [endpoint], 100) client.connect([endpoint])
client.connect()
while True: while True:
request_json = inputqueue.get() request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"]) feed, fetch = self.preprocess(request_json, request_json["fetch"])
...@@ -135,8 +134,7 @@ class WebService(object): ...@@ -135,8 +134,7 @@ class WebService(object):
client = Client() client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format( client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config)) self.model_config))
client.add_variant("var1", ["0.0.0.0:{}".format(self.port + 1)], 100) client.connect(["0.0.0.0:{}".format(self.port + 1)])
client.connect()
self.idx = 0 self.idx = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册