Commit fab692cb authored by wangjiawei04

add bert

Parent 756cf163
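# benchmark.py (invoked by the shell script further below): HTTP/RPC benchmark
# client for the BERT pipeline service.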
import sys
import os
import yaml
import requests
import time
import json
try:
from paddle_serving_server_gpu.pipeline import PipelineClient
except ImportError:
from paddle_serving_server.pipeline import PipelineClient
import numpy as np
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency
def gen_yml(device):
    # Generate config2.yml from config.yml: enable the tracer and, for GPU,
    # point the local predictor at the chosen device.
    fin = open("config.yml", "r")
    config = yaml.load(fin, Loader=yaml.SafeLoader)
    fin.close()
    config["dag"]["tracer"] = {"interval_s": 20}
    if device == "gpu":
        config["op"]["bert"]["local_service_conf"]["device_type"] = 1
        config["op"]["bert"]["local_service_conf"]["devices"] = "2"
    with open("config2.yml", "w") as fout:
        yaml.dump(config, fout, default_flow_style=False)
def run_http(idx, batch_size):
    # Send the lines of data-c.txt to the HTTP endpoint in batches and
    # return the total elapsed time for this thread.
    print("start thread ({})".format(idx))
    url = "http://127.0.0.1:18082/bert/prediction"
    with open("data-c.txt", 'r') as fin:
        start = time.time()
        lines = fin.readlines()
        start_idx = 0
        while start_idx < len(lines):
            end_idx = min(len(lines), start_idx + batch_size)
            feed = {}
            for i in range(start_idx, end_idx):
                feed[str(i - start_idx)] = lines[i]
            keys = list(feed.keys())
            values = [feed[x] for x in keys]
            data = {"key": keys, "value": values}
            r = requests.post(url=url, data=json.dumps(data))
            start_idx += batch_size
    end = time.time()
    return [[end - start]]
def multithread_http(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_http, thread, batch_size)
def run_rpc(thread, batch_size):
    # Same batching loop as run_http, but through the pipeline RPC client.
    client = PipelineClient()
    client.connect(['127.0.0.1:9998'])
    with open("data-c.txt", 'r') as fin:
        start = time.time()
        lines = fin.readlines()
        start_idx = 0
        while start_idx < len(lines):
            end_idx = min(len(lines), start_idx + batch_size)
            feed = {}
            for i in range(start_idx, end_idx):
                feed[str(i - start_idx)] = lines[i]
            ret = client.predict(feed_dict=feed, fetch=["res"])
            start_idx += batch_size
    end = time.time()
    return [[end - start]]
def multithread_rpc(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_rpc, thread, batch_size)
if __name__ == "__main__":
    if sys.argv[1] == "yaml":
        mode = sys.argv[2]  # brpc / local predictor
        thread = int(sys.argv[3])
        device = sys.argv[4]
        gen_yml(device)
    elif sys.argv[1] == "run":
        mode = sys.argv[2]  # http / rpc
        thread = int(sys.argv[3])
        batch_size = int(sys.argv[4])
        if mode == "http":
            multithread_http(thread, batch_size)
        elif mode == "rpc":
            multithread_rpc(thread, batch_size)
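# Benchmark driver script (shell): regenerates the config, launches web_service.py,
# samples GPU/CPU utilization, and runs benchmark.py over HTTP and then RPC.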
alias python3="python3.7"
# HTTP
ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
sleep 3
python3 benchmark.py yaml local_predictor 1 gpu
for thread_num in 1
do
    for batch_size in 1
    do
        rm -rf PipelineServingLogs
        rm -rf cpu_utilization.py
        python3 web_service.py >web.log 2>&1 &
        sleep 3
        nvidia-smi --id=2 --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
        nvidia-smi --id=2 --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
        echo -e "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
        python3 benchmark.py run http $thread_num $batch_size
        python3 cpu_utilization.py
        echo "------------BERT pipeline benchmark (Thread: $thread_num) (BatchSize: $batch_size)"
        tail -n 25 PipelineServingLogs/pipeline.tracer
        awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "MAX_GPU_MEMORY:", max}' gpu_use.log >> profile_log_$1
        awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "GPU_UTILIZATION:", max}' gpu_utilization.log >> profile_log_$1
        #rm -rf gpu_use.log gpu_utilization.log
        ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
    done
done
# RPC
ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
sleep 3
python3 benchmark.py yaml local_predictor 1 gpu
for thread_num in 1
do
    for batch_size in 1
    do
        rm -rf PipelineServingLogs
        rm -rf cpu_utilization.py
        python3 web_service.py >web.log 2>&1 &
        sleep 3
        nvidia-smi --id=2 --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
        nvidia-smi --id=2 --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
        echo -e "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
        python3 benchmark.py run rpc $thread_num $batch_size
        python3 cpu_utilization.py
        echo "------------BERT pipeline benchmark (Thread: $thread_num) (BatchSize: $batch_size)"
        tail -n 25 PipelineServingLogs/pipeline.tracer
        awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "MAX_GPU_MEMORY:", max}' gpu_use.log >> profile_log_$1
        awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "GPU_UTILIZATION:", max}' gpu_utilization.log >> profile_log_$1
        #rm -rf gpu_use.log gpu_utilization.log
        ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
    done
done
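# config.yml: pipeline configuration read by gen_yml() above; the generated
# config2.yml is what web_service.py actually loads.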
dag:
    is_thread_op: false
    tracer:
        interval_s: 10
http_port: 18082
op:
    bert:
        local_service_conf:
            client_type: local_predictor
            concurrency: 2
            device_type: 1
            devices: '2'
            fetch_list:
            - pooled_output
            model_config: bert_seq128_model/
rpc_port: 9998
worker_num: 1
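# Model and test-data preparation: download the Chinese BERT model and the
# sample inputs (data-c.txt, vocab.txt) used by the reader and the clients.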
wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticModel/bert_chinese_L-12_H-768_A-12.tar.gz
tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz
mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model
mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client
wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate
wget https://paddle-serving.bj.bcebos.com/bert_example/vocab.txt --no-check-certificate
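# A minimal sketch of the manual workflow implied by the benchmark script above
# (file names are the ones that script uses); run once to smoke-test the service:
python3 benchmark.py yaml local_predictor 1 gpu   # write config2.yml for GPU serving
python3 web_service.py > web.log 2>&1 &           # start the pipeline service
sleep 3
python3 benchmark.py run http 1 1                 # 1 thread, batch size 1, over HTTP

# Standalone pipeline RPC client: streams data-c.txt to the service on port 9998.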
import sys
import os
import yaml
import requests
import time
import json
try:
from paddle_serving_server_gpu.pipeline import PipelineClient
except ImportError:
from paddle_serving_server.pipeline import PipelineClient
import numpy as np
client = PipelineClient()
client.connect(['127.0.0.1:9998'])
batch_size = 101
with open("data-c.txt", 'r') as fin:
lines = fin.readlines()
start_idx = 0
while start_idx < len(lines):
end_idx = min(len(lines), start_idx + batch_size)
feed = {}
for i in range(start_idx, end_idx):
feed[str(i - start_idx)] = lines[i]
ret = client.predict(feed_dict=feed, fetch=["res"])
print(ret)
start_idx += batch_size
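# For reference, the same prediction can be requested over HTTP; the pipeline web
# service expects a JSON body with parallel "key"/"value" lists (see run_http above).
# A minimal sketch with a placeholder input sentence:
curl -d '{"key": ["0"], "value": ["a sample sentence"]}' http://127.0.0.1:18082/bert/prediction

# web_service.py: the pipeline service definition used by the scripts above.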
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
    from paddle_serving_server_gpu.web_service import WebService, Op
except ImportError:
    from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import sys
from paddle_serving_app.reader import ChineseBertReader

_LOGGER = logging.getLogger()


class BertOp(Op):
    def init_op(self):
        self.reader = ChineseBertReader({
            "vocab_file": "vocab.txt",
            "max_seq_len": 128
        })

    def preprocess(self, input_dicts, data_id, log_id):
        (_, input_dict), = input_dicts.items()
        print("input dict", input_dict)
        batch_size = len(input_dict.keys())
        feed_res = []
        # Tokenize each sentence separately, then concatenate along the batch axis.
        for i in range(batch_size):
            feed_dict = self.reader.process(input_dict[str(i)].encode("utf-8"))
            for key in feed_dict.keys():
                feed_dict[key] = np.array(feed_dict[key]).reshape(
                    (1, len(feed_dict[key]), 1))
            feed_res.append(feed_dict)
        feed_dict = {}
        for key in feed_res[0].keys():
            feed_dict[key] = np.concatenate([x[key] for x in feed_res], axis=0)
            print(key, feed_dict[key].shape)
        return feed_dict, False, None, ""

    def postprocess(self, input_dicts, fetch_dict, log_id):
        fetch_dict["pooled_output"] = str(fetch_dict["pooled_output"])
        return fetch_dict, None, ""


class BertService(WebService):
    def get_pipeline_response(self, read_op):
        bert_op = BertOp(name="bert", input_ops=[read_op])
        return bert_op


bert_service = BertService(name="bert")
bert_service.prepare_pipeline_config("config2.yml")
bert_service.run_service()