提交 f86a565b 编写于 作者: B barrierye

recover client.connect() function

上级 8636d0d4
......@@ -41,9 +41,7 @@ def single_func(idx, resource):
fetch = ["pooled_output"]
client = Client()
client.load_client_config(args.model)
client.add_variant(
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
......
......@@ -40,9 +40,7 @@ def single_func(idx, resource):
fetch = ["pooled_output"]
client = Client()
client.load_client_config(args.model)
client.add_variant(
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
feed_batch = []
for bi in range(args.batch_size):
feed_batch.append(reader.process(dataset[bi]))
......
......@@ -33,8 +33,7 @@ fetch = ["pooled_output"]
endpoint_list = ["127.0.0.1:9494"]
client = Client()
client.load_client_config(args.model)
client.add_variant("var1", endpoint_list, 50)
client.connect()
client.connect(endpoint_list)
for line in fin:
feed_dict = reader.process(line)
......
......@@ -43,9 +43,7 @@ def single_func(idx, resource):
fetch = ["prob"]
client = Client()
client.load_client_config(args.model)
client.add_variant(
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
......
......@@ -43,9 +43,7 @@ def single_func(idx, resource):
fetch = ["prob"]
client = Client()
client.load_client_config(args.model)
client.add_variant(
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
......
......@@ -23,8 +23,7 @@ from paddle_serving_client.metric import auc
client = Client()
client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9292"], 50)
client.connect()
client.connect(["127.0.0.1:9292"])
batch = 1
buf_size = 100
......
......@@ -30,8 +30,7 @@ args = benchmark_args()
def single_func(idx, resource):
client = Client()
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.add_variant("var1", ['127.0.0.1:9292'], 50)
client.connect()
client.connect(['127.0.0.1:9292'])
batch = 1
buf_size = 100
dataset = criteo.CriteoDataset()
......
......@@ -31,8 +31,7 @@ def single_func(idx, resource):
client = Client()
print([resource["endpoint"][idx % len(resource["endpoint"])]])
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.add_variant("var1", ['127.0.0.1:9292'], 50)
client.connect()
client.connect(['127.0.0.1:9292'])
batch = 1
buf_size = 100
dataset = criteo.CriteoDataset()
......
......@@ -22,8 +22,7 @@ from paddle_serving_client.metric import auc
client = Client()
client.load_client_config(sys.argv[1])
client.add_variant("var1", ['127.0.0.1:9292'], 50)
client.connect()
client.connect(['127.0.0.1:9292'])
batch = 1
buf_size = 100
......
......@@ -28,8 +28,7 @@ def single_func(idx, resource):
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.add_variant("var1", [args.endpoint], 50)
client.connect()
client.connect([args.endpoint])
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
......
......@@ -18,8 +18,7 @@ import sys
client = Client()
client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9393"], 50)
client.connect()
client.connect(["127.0.0.1:9393"])
import paddle
test_reader = paddle.batch(
......
......@@ -36,9 +36,7 @@ def single_func(idx, resource):
fetch = ["score"]
client = Client()
client.load_client_config(args.model)
client.add_variant(
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
......
......@@ -41,9 +41,7 @@ def single_func(idx, resource):
fetch = ["score"]
client = Client()
client.load_client_config(args.model)
client.add_variant(
"var1", [resource["endpoint"][idx % len(resource["endpoint"])]], 50)
client.connect()
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
if args.batch_size >= 1:
......
......@@ -19,8 +19,7 @@ import time
client = Client()
client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9295"], 50)
client.connect()
client.connect(["127.0.0.1:9295"])
reader = ImageReader()
start = time.time()
......
......@@ -35,8 +35,7 @@ def single_func(idx, resource):
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.add_variant("var1", [args.endpoint], 50)
client.connect()
client.connect([args.endpoint])
for i in range(1000):
if args.batch_size >= 1:
feed_batch = []
......
......@@ -18,8 +18,7 @@ import sys
client = Client()
client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9292"], 50)
client.connect()
client.connect(["127.0.0.1:9292"])
# you can define any english sentence or dataset here
# This example reuses imdb reader in training, you
......
......@@ -23,8 +23,7 @@ import time
def batch_predict(batch_size=4):
client = Client()
client.load_client_config(conf_file)
client.add_variant("var1", ["127.0.0.1:9292"], 50)
client.connect()
client.connect(["127.0.0.1:9292"])
fetch = ["acc", "cost", "prediction"]
feed_batch = []
for line in sys.stdin:
......
......@@ -30,8 +30,7 @@ def single_func(idx, resource):
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.add_variant("var1", [args.endpoint], 50)
client.connect()
client.connect([args.endpoint])
fin = open("jieba_test.txt")
for line in fin:
feed_data = reader.process(line)
......
......@@ -22,8 +22,7 @@ import io
client = Client()
client.load_client_config(sys.argv[1])
client.add_variant("var1", ["127.0.0.1:9280"], 50)
client.connect()
client.connect(["127.0.0.1:9280"])
reader = LACReader(sys.argv[2])
for line in sys.stdin:
......
......@@ -85,7 +85,7 @@ class Client(object):
self.feed_names_to_idx_ = {}
self.rpath()
self.pid = os.getpid()
self.predictor_sdk_ = SDKConfig()
self.predictor_sdk_ = None
def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__)
......@@ -138,13 +138,27 @@ class Client(object):
return
def add_variant(self, tag, cluster, variant_weight):
if self.predictor_sdk_ is None:
self.predictor_sdk_ = SDKConfig()
self.predictor_sdk_.add_server_variant(tag, cluster,
str(variant_weight))
def connect(self):
def connect(self, endpoints=None):
# check whether current endpoint is available
# init from client config
# create predictor here
if endpoints is None:
if self.predictor_sdk_ is None:
raise SystemExit(
"You must set the endpoints parameter or use add_variant function to create a variant."
)
else:
if self.predictor_sdk_ is None:
self.add_variant('var1', endpoints, 100)
else:
print(
"endpoints({}) will not be enabled because you use the add_variant function.".
format(endpoints))
sdk_desc = self.predictor_sdk_.gen_desc()
print(sdk_desc)
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
......
......@@ -54,9 +54,7 @@ class WebService(object):
client_service = Client()
client_service.load_client_config(
"{}/serving_server_conf.prototxt".format(self.model_config))
client_service.add_variant("var1",
["0.0.0.0:{}".format(self.port + 1)], 100)
client_service.connect()
client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
service_name = "/" + self.name + "/prediction"
@app_instance.route(service_name, methods=['POST'])
......
......@@ -91,8 +91,7 @@ class WebService(object):
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.add_variant("var1", [endpoint], 100)
client.connect()
client.connect([endpoint])
while True:
request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
......@@ -135,8 +134,7 @@ class WebService(object):
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.add_variant("var1", ["0.0.0.0:{}".format(self.port + 1)], 100)
client.connect()
client.connect(["0.0.0.0:{}".format(self.port + 1)])
self.idx = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册