Commit 7036ce8a authored by Dong Daxiang, committed by GitHub

Merge pull request #346 from MRXLT/0.2.0-fix

fix bert demo
@@ -42,7 +42,7 @@ pip install paddle_serving_app
 ```
 Run:
 ```
-cat data-c.txt | python bert_client.py
+head data-c.txt | python bert_client.py --model bert_seq20_client/serving_client_conf.prototxt
 ```
 This starts the client, which reads the data in data-c.txt and runs prediction; the results are the vector representations of the texts (the script does not print them because there is a lot of data). The server address can be changed inside the script.
......
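For reference, the updated README command pipes text into bert_client.py line by line and points --model at the client-side serving_client_conf.prototxt. Below is a minimal sketch of the client loop that command drives, assembled from the bert_client.py hunk further down in this diff; the argparse handling stands in for the benchmark_args() helper the real script uses, and the import paths and endpoint are assumptions based on the 0.2.0-era paddle_serving_client API.

```python
import sys
import argparse

from paddle_serving_client import Client
from paddle_serving_app import ChineseBertReader

parser = argparse.ArgumentParser()
parser.add_argument("--model", help="path to serving_client_conf.prototxt")
args = parser.parse_args()

# Preprocessor that turns a line of raw text into BERT feed tensors.
reader = ChineseBertReader({"max_seq_len": 20})
fetch = ["pooled_output"]

client = Client()
client.load_client_config(args.model)
client.connect(["127.0.0.1:9292"])

for line in sys.stdin:                # text arrives through the shell pipe
    feed_dict = reader.process(line)  # tokenize and build the feed dict
    result = client.predict(feed=feed_dict, fetch=fetch)
```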
@@ -35,21 +35,29 @@ def single_func(idx, resource):
     dataset = []
     for line in fin:
         dataset.append(line.strip())
+    profile_flags = False
+    if os.environ["FLAGS_profile_client"]:
+        profile_flags = True
     if args.request == "rpc":
         reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
         fetch = ["pooled_output"]
         client = Client()
         client.load_client_config(args.model)
         client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
-        feed_batch = []
-        for bi in range(args.batch_size):
-            feed_batch.append(reader.process(dataset[bi]))
         start = time.time()
         for i in range(1000):
             if args.batch_size >= 1:
-                result = client.batch_predict(
-                    feed_batch=feed_batch, fetch=fetch)
+                feed_batch = []
+                b_start = time.time()
+                for bi in range(args.batch_size):
+                    feed_batch.append(reader.process(dataset[bi]))
+                b_end = time.time()
+                if profile_flags:
+                    print("PROFILE\tpid:{}\tbert+pre_0:{} bert_pre_1:{}".format(
+                        os.getpid(),
+                        int(round(b_start * 1000000)),
+                        int(round(b_end * 1000000))))
+                result = client.predict(feed_batch=feed_batch, fetch=fetch)
             else:
                 print("unsupport batch size {}".format(args.batch_size))
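The added lines move the feed-batch construction inside the timing loop and, when the FLAGS_profile_client environment variable is set, emit the preprocessing start/end timestamps as integer microseconds. A small, self-contained sketch of that pattern follows; os.environ.get is a defensive stand-in (the diff indexes os.environ directly, which raises KeyError when the flag is not exported), and the PROFILE tag names are illustrative.

```python
import os
import time

# True only when the benchmark is launched with FLAGS_profile_client set.
profile_flags = bool(os.environ.get("FLAGS_profile_client"))

b_start = time.time()
# ... per-batch preprocessing would go here, e.g. reader.process(sample) ...
b_end = time.time()

if profile_flags:
    # Timestamps are reported as integer microseconds since the epoch,
    # matching the format used in the diff above.
    print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
        os.getpid(),
        int(round(b_start * 1000000)),
        int(round(b_end * 1000000))))
```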
@@ -62,7 +70,7 @@ def single_func(idx, resource):
 if __name__ == '__main__':
     multi_thread_runner = MultiThreadRunner()
     endpoint_list = [
-        "127.0.0.1:9295", "127.0.0.1:9296", "127.0.0.1:9297", "127.0.0.1:9298"
+        "127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
     ]
     result = multi_thread_runner.run(single_func, args.thread,
                                      {"endpoint": endpoint_list})
......
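In the benchmark above, each worker thread picks its server with idx % len(resource["endpoint"]) (see the client.connect call in the first hunk), so threads are spread round-robin across the four endpoints in the updated list. A tiny illustration of that mapping:

```python
endpoint_list = [
    "127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
]

# Thread idx -> endpoint, as used in client.connect([...]) above.
for idx in range(8):
    print(idx, endpoint_list[idx % len(endpoint_list)])
# 0..3 map to :9292..:9295, then 4 wraps back to :9292, and so on.
```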
@@ -29,13 +29,13 @@ from paddle_serving_app import ChineseBertReader
 args = benchmark_args()
-reader = ChineseBertReader(max_seq_len=20)
+reader = ChineseBertReader({"max_seq_len": 20})
 fetch = ["pooled_output"]
 endpoint_list = ["127.0.0.1:9292"]
 client = Client()
 client.load_client_config(args.model)
 client.connect(endpoint_list)
-for line in fin:
+for line in sys.stdin:
     feed_dict = reader.process(line)
     result = client.predict(feed=feed_dict, fetch=fetch)
@@ -32,8 +32,7 @@ bert_service = BertService(name="bert")
 bert_service.load()
 bert_service.load_model_config(sys.argv[1])
 gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
-gpus = [int(x) for x in gpu_ids.split(",")]
-bert_service.set_gpus(gpus)
+bert_service.set_gpus(gpu_ids)
 bert_service.prepare_server(
     workdir="workdir", port=int(sys.argv[2]), device="gpu")
 bert_service.run_server()
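With this change the web service hands the raw CUDA_VISIBLE_DEVICES string to set_gpus instead of first splitting it into a list of ints. A short before/after sketch; the device string and launch command in the comment are placeholders, not taken from the diff.

```python
import os

# e.g. launched as: CUDA_VISIBLE_DEVICES=0,1 python bert_web_service.py <model_conf> 9292
gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0")

# Old behaviour (removed by this PR): parse into a list of ints first.
gpus = [int(x) for x in gpu_ids.split(",")]

# New behaviour: pass the comma-separated string straight through, i.e.
#     bert_service.set_gpus(gpu_ids)
```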
@@ -55,8 +55,7 @@ def single_func(idx, resource):
                     for i in range(1, 27):
                         feed_dict["sparse_{}".format(i - 1)] = data[0][i]
                     feed_batch.append(feed_dict)
-                result = client.batch_predict(
-                    feed_batch=feed_batch, fetch=fetch)
+                result = client.predict(feed_batch=feed_batch, fetch=fetch)
             else:
                 print("unsupport batch size {}".format(args.batch_size))
......
@@ -50,8 +50,7 @@ def single_func(idx, resource):
                     img = reader.process_image(img_list[i])
                     img = img.reshape(-1)
                     feed_batch.append({"image": img})
-                result = client.batch_predict(
-                    feed_batch=feed_batch, fetch=fetch)
+                result = client.predict(feed_batch=feed_batch, fetch=fetch)
             else:
                 print("unsupport batch size {}".format(args.batch_size))
......
@@ -42,7 +42,7 @@ def single_func(idx, resource):
                 for bi in range(args.batch_size):
                     word_ids, label = imdb_dataset.get_words_and_label(line)
                     feed_batch.append({"words": word_ids})
-                result = client.batch_predict(
+                result = client.predict(
                     feed_batch=feed_batch, fetch=["prediction"])
             else:
                 print("unsupport batch size {}".format(args.batch_size))
......
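The last three hunks (the criteo_ctr, imagenet and imdb benchmarks) all make the same API substitution: client.batch_predict is dropped in favour of client.predict, which now takes the batch as a feed_batch list of per-sample feed dicts plus a fetch list. A hedged sketch of that calling convention, reusing the feed/fetch names from the imdb hunk; the config path, endpoint and word-id lists are placeholders.

```python
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
client.connect(["127.0.0.1:9292"])                          # placeholder endpoint

# Each element of feed_batch is one sample's feed dict; "words" / "prediction"
# mirror the imdb benchmark above, and the word ids here are dummy data.
feed_batch = [{"words": [1, 2, 3]}, {"words": [4, 5, 6]}]
result = client.predict(feed_batch=feed_batch, fetch=["prediction"])
```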